hgext/record.py
changeset 5037 b2607267236d
child 5040 4f34d9b2568e
equal deleted inserted replaced
5036:ca0d02222d6a 5037:b2607267236d
       
     1 # record.py
       
     2 #
       
     3 # Copyright 2007 Bryan O'Sullivan <bos@serpentine.com>
       
     4 #
       
     5 # This software may be used and distributed according to the terms of
       
     6 # the GNU General Public License, incorporated herein by reference.
       
     7 
       
     8 '''interactive change selection during commit'''
       
     9 
       
    10 from mercurial.i18n import _
       
    11 from mercurial import cmdutil, commands, cmdutil, hg, mdiff, patch, revlog
       
    12 from mercurial import util
       
    13 import copy, cStringIO, errno, operator, os, re, shutil, tempfile
       
    14 
       
    15 lines_re = re.compile(r'@@ -(\d+),(\d+) \+(\d+),(\d+) @@\s*(.*)')
       
    16 
       
    17 def scanpatch(fp):
       
    18     lr = patch.linereader(fp)
       
    19 
       
    20     def scanwhile(first, p):
       
    21         lines = [first]
       
    22         while True:
       
    23             line = lr.readline()
       
    24             if not line:
       
    25                 break
       
    26             if p(line):
       
    27                 lines.append(line)
       
    28             else:
       
    29                 lr.push(line)
       
    30                 break
       
    31         return lines
       
    32 
       
    33     while True:
       
    34         line = lr.readline()
       
    35         if not line:
       
    36             break
       
    37         if line.startswith('diff --git a/'):
       
    38             def notheader(line):
       
    39                 s = line.split(None, 1)
       
    40                 return not s or s[0] not in ('---', 'diff')
       
    41             header = scanwhile(line, notheader)
       
    42             fromfile = lr.readline()
       
    43             if fromfile.startswith('---'):
       
    44                 tofile = lr.readline()
       
    45                 header += [fromfile, tofile]
       
    46             else:
       
    47                 lr.push(fromfile)
       
    48             yield 'file', header
       
    49         elif line[0] == ' ':
       
    50             yield 'context', scanwhile(line, lambda l: l[0] in ' \\')
       
    51         elif line[0] in '-+':
       
    52             yield 'hunk', scanwhile(line, lambda l: l[0] in '-+\\')
       
    53         else:
       
    54             m = lines_re.match(line)
       
    55             if m:
       
    56                 yield 'range', m.groups()
       
    57             else:
       
    58                 raise patch.PatchError('unknown patch content: %r' % line)
       
    59 
       
    60 class header(object):
       
    61     diff_re = re.compile('diff --git a/(.*) b/(.*)$')
       
    62     allhunks_re = re.compile('(?:index|new file|deleted file) ')
       
    63     pretty_re = re.compile('(?:new file|deleted file) ')
       
    64     special_re = re.compile('(?:index|new|deleted|copy|rename) ')
       
    65 
       
    66     def __init__(self, header):
       
    67         self.header = header
       
    68         self.hunks = []
       
    69 
       
    70     def binary(self):
       
    71         for h in self.header:
       
    72             if h.startswith('index '):
       
    73                 return True
       
    74         
       
    75     def pretty(self, fp):
       
    76         for h in self.header:
       
    77             if h.startswith('index '):
       
    78                 fp.write(_('this modifies a binary file (all or nothing)\n'))
       
    79                 break
       
    80             if self.pretty_re.match(h):
       
    81                 fp.write(h)
       
    82                 if self.binary():
       
    83                     fp.write(_('this is a binary file\n'))
       
    84                 break
       
    85             if h.startswith('---'):
       
    86                 fp.write(_('%d hunks, %d lines changed\n') %
       
    87                          (len(self.hunks),
       
    88                           sum([h.added + h.removed for h in self.hunks])))
       
    89                 break
       
    90             fp.write(h)
       
    91 
       
    92     def write(self, fp):
       
    93         fp.write(''.join(self.header))
       
    94 
       
    95     def allhunks(self):
       
    96         for h in self.header:
       
    97             if self.allhunks_re.match(h):
       
    98                 return True
       
    99 
       
   100     def files(self):
       
   101         fromfile, tofile = self.diff_re.match(self.header[0]).groups()
       
   102         if fromfile == tofile:
       
   103             return [fromfile]
       
   104         return [fromfile, tofile]
       
   105 
       
   106     def filename(self):
       
   107         return self.files()[-1]
       
   108 
       
   109     def __repr__(self):
       
   110         return '<header %s>' % (' '.join(map(repr, self.files())))
       
   111 
       
   112     def special(self):
       
   113         for h in self.header:
       
   114             if self.special_re.match(h):
       
   115                 return True
       
   116 
       
   117 def countchanges(hunk):
       
   118     add = len([h for h in hunk if h[0] == '+'])
       
   119     rem = len([h for h in hunk if h[0] == '-'])
       
   120     return add, rem
       
   121 
       
   122 class hunk(object):
       
   123     maxcontext = 3
       
   124 
       
   125     def __init__(self, header, fromline, toline, proc, before, hunk, after):
       
   126         def trimcontext(number, lines):
       
   127             delta = len(lines) - self.maxcontext
       
   128             if False and delta > 0:
       
   129                 return number + delta, lines[:self.maxcontext]
       
   130             return number, lines
       
   131 
       
   132         self.header = header
       
   133         self.fromline, self.before = trimcontext(fromline, before)
       
   134         self.toline, self.after = trimcontext(toline, after)
       
   135         self.proc = proc
       
   136         self.hunk = hunk
       
   137         self.added, self.removed = countchanges(self.hunk)
       
   138 
       
   139     def write(self, fp):
       
   140         delta = len(self.before) + len(self.after)
       
   141         fromlen = delta + self.removed
       
   142         tolen = delta + self.added
       
   143         fp.write('@@ -%d,%d +%d,%d @@%s\n' %
       
   144                  (self.fromline, fromlen, self.toline, tolen,
       
   145                   self.proc and (' ' + self.proc)))
       
   146         fp.write(''.join(self.before + self.hunk + self.after))
       
   147 
       
   148     pretty = write
       
   149 
       
   150     def filename(self):
       
   151         return self.header.filename()
       
   152 
       
   153     def __repr__(self):
       
   154         return '<hunk %r@%d>' % (self.filename(), self.fromline)
       
   155 
       
   156 def parsepatch(fp):
       
   157     class parser(object):
       
   158         def __init__(self):
       
   159             self.fromline = 0
       
   160             self.toline = 0
       
   161             self.proc = ''
       
   162             self.header = None
       
   163             self.context = []
       
   164             self.before = []
       
   165             self.hunk = []
       
   166             self.stream = []
       
   167 
       
   168         def addrange(self, (fromstart, fromend, tostart, toend, proc)):
       
   169             self.fromline = int(fromstart)
       
   170             self.toline = int(tostart)
       
   171             self.proc = proc
       
   172 
       
   173         def addcontext(self, context):
       
   174             if self.hunk:
       
   175                 h = hunk(self.header, self.fromline, self.toline, self.proc,
       
   176                          self.before, self.hunk, context)
       
   177                 self.header.hunks.append(h)
       
   178                 self.stream.append(h)
       
   179                 self.fromline += len(self.before) + h.removed
       
   180                 self.toline += len(self.before) + h.added
       
   181                 self.before = []
       
   182                 self.hunk = []
       
   183                 self.proc = ''
       
   184             self.context = context
       
   185 
       
   186         def addhunk(self, hunk):
       
   187             if self.context:
       
   188                 self.before = self.context
       
   189                 self.context = []
       
   190             self.hunk = data
       
   191 
       
   192         def newfile(self, hdr):
       
   193             self.addcontext([])
       
   194             h = header(hdr)
       
   195             self.stream.append(h)
       
   196             self.header = h
       
   197 
       
   198         def finished(self):
       
   199             self.addcontext([])
       
   200             return self.stream
       
   201 
       
   202         transitions = {
       
   203             'file': {'context': addcontext,
       
   204                      'file': newfile,
       
   205                      'hunk': addhunk,
       
   206                      'range': addrange},
       
   207             'context': {'file': newfile,
       
   208                         'hunk': addhunk,
       
   209                         'range': addrange},
       
   210             'hunk': {'context': addcontext,
       
   211                      'file': newfile,
       
   212                      'range': addrange},
       
   213             'range': {'context': addcontext,
       
   214                       'hunk': addhunk},
       
   215             }
       
   216              
       
   217     p = parser()
       
   218 
       
   219     state = 'context'
       
   220     for newstate, data in scanpatch(fp):
       
   221         try:
       
   222             p.transitions[state][newstate](p, data)
       
   223         except KeyError:
       
   224             raise patch.PatchError('unhandled transition: %s -> %s' %
       
   225                                    (state, newstate))
       
   226         state = newstate
       
   227     return p.finished()
       
   228 
       
   229 def filterpatch(ui, chunks):
       
   230     chunks = list(chunks)
       
   231     chunks.reverse()
       
   232     seen = {}
       
   233     def consumefile():
       
   234         consumed = []
       
   235         while chunks:
       
   236             if isinstance(chunks[-1], header):
       
   237                 break
       
   238             else:
       
   239                 consumed.append(chunks.pop())
       
   240         return consumed
       
   241     resp = None
       
   242     applied = {}
       
   243     while chunks:
       
   244         chunk = chunks.pop()
       
   245         if isinstance(chunk, header):
       
   246             fixoffset = 0
       
   247             hdr = ''.join(chunk.header)
       
   248             if hdr in seen:
       
   249                 consumefile()
       
   250                 continue
       
   251             seen[hdr] = True
       
   252             if not resp:
       
   253                 chunk.pretty(ui)
       
   254             r = resp or ui.prompt(_('record changes to %s? [y]es [n]o') %
       
   255                                   _(' and ').join(map(repr, chunk.files())),
       
   256                                   '(?:|[yYnNqQaA])$') or 'y'
       
   257             if r in 'aA':
       
   258                 r = 'y'
       
   259                 resp = 'y'
       
   260             if r in 'qQ':
       
   261                 raise util.Abort(_('user quit'))
       
   262             if r in 'yY':
       
   263                 applied[chunk.filename()] = [chunk]
       
   264                 if chunk.allhunks():
       
   265                     applied[chunk.filename()] += consumefile()
       
   266             else:
       
   267                 consumefile()
       
   268         else:
       
   269             if not resp:
       
   270                 chunk.pretty(ui)
       
   271             r = resp or ui.prompt(_('record this change to %r? [y]es [n]o') %
       
   272                                   chunk.filename(), '(?:|[yYnNqQaA])$') or 'y'
       
   273             if r in 'aA':
       
   274                 r = 'y'
       
   275                 resp = 'y'
       
   276             if r in 'qQ':
       
   277                 raise util.Abort(_('user quit'))
       
   278             if r in 'yY':
       
   279                 if fixoffset:
       
   280                     chunk = copy.copy(chunk)
       
   281                     chunk.toline += fixoffset
       
   282                 applied[chunk.filename()].append(chunk)
       
   283             else:
       
   284                 fixoffset += chunk.removed - chunk.added
       
   285     return reduce(operator.add, [h for h in applied.itervalues()
       
   286                                  if h[0].special() or len(h) > 1], [])
       
   287 
       
   288 def record(ui, repo, *pats, **opts):
       
   289     '''interactively select changes to commit'''
       
   290 
       
   291     if not ui.interactive:
       
   292         raise util.Abort(_('running non-interactively, use commit instead'))
       
   293 
       
   294     def recordfunc(ui, repo, files, message, match, opts):
       
   295         if files:
       
   296             changes = None
       
   297         else:
       
   298             changes = repo.status(files=files, match=match)[:5]
       
   299             modified, added, removed = changes[:3]
       
   300             files = modified + added + removed
       
   301         diffopts = mdiff.diffopts(git=True, nodates=True)
       
   302         fp = cStringIO.StringIO()
       
   303         patch.diff(repo, repo.dirstate.parents()[0], files=files,
       
   304                    match=match, changes=changes, opts=diffopts, fp=fp)
       
   305         fp.seek(0)
       
   306 
       
   307         chunks = filterpatch(ui, parsepatch(fp))
       
   308         del fp
       
   309 
       
   310         contenders = {}
       
   311         for h in chunks:
       
   312             try: contenders.update(dict.fromkeys(h.files()))
       
   313             except AttributeError: pass
       
   314             
       
   315         newfiles = [f for f in files if f in contenders]
       
   316 
       
   317         if not newfiles:
       
   318             ui.status(_('no changes to record\n'))
       
   319             return 0
       
   320 
       
   321         if changes is None:
       
   322             changes = repo.status(files=newfiles, match=match)[:5]
       
   323         modified = dict.fromkeys(changes[0])
       
   324 
       
   325         backups = {}
       
   326         backupdir = repo.join('record-backups')
       
   327         try:
       
   328             os.mkdir(backupdir)
       
   329         except OSError, err:
       
   330             if err.errno == errno.EEXIST:
       
   331                 pass
       
   332         try:
       
   333             for f in newfiles:
       
   334                 if f not in modified:
       
   335                     continue
       
   336                 fd, tmpname = tempfile.mkstemp(prefix=f.replace('/', '_')+'.',
       
   337                                                dir=backupdir)
       
   338                 os.close(fd)
       
   339                 ui.debug('backup %r as %r\n' % (f, tmpname))
       
   340                 util.copyfile(repo.wjoin(f), tmpname)
       
   341                 backups[f] = tmpname
       
   342 
       
   343             fp = cStringIO.StringIO()
       
   344             for c in chunks:
       
   345                 if c.filename() in backups:
       
   346                     c.write(fp)
       
   347             dopatch = fp.tell()
       
   348             fp.seek(0)
       
   349 
       
   350             if backups:
       
   351                 hg.revert(repo, repo.dirstate.parents()[0], backups.has_key)
       
   352 
       
   353             if dopatch:
       
   354                 ui.debug('applying patch\n')
       
   355                 ui.debug(fp.getvalue())
       
   356                 patch.internalpatch(fp, ui, 1, repo.root)
       
   357             del fp
       
   358 
       
   359             repo.commit(newfiles, message, opts['user'], opts['date'], match,
       
   360                         force_editor=opts.get('force_editor'))
       
   361             return 0
       
   362         finally:
       
   363             try:
       
   364                 for realname, tmpname in backups.iteritems():
       
   365                     ui.debug('restoring %r to %r\n' % (tmpname, realname))
       
   366                     util.copyfile(tmpname, realname)
       
   367                     os.unlink(tmpname)
       
   368                 os.rmdir(backupdir)
       
   369             except OSError:
       
   370                 pass
       
   371     return cmdutil.commit(ui, repo, recordfunc, pats, opts)
       
   372 
       
   373 cmdtable = {
       
   374     'record':
       
   375     (record, [('A', 'addremove', None,
       
   376                _('mark new/missing files as added/removed before committing')),
       
   377               ('d', 'date', '', _('record datecode as commit date')),
       
   378               ('u', 'user', '', _('record user as commiter')),
       
   379               ] + commands.walkopts + commands.commitopts,
       
   380      _('hg record [FILE]...')),
       
   381     }