mercurial/hg.py
changeset 635 85e2209d401c
parent 634 da5378d39269
child 636 ac0ec421e3a5
--- a/mercurial/hg.py
+++ b/mercurial/hg.py
@@ -1025,35 +1025,6 @@ class localrepository:
         return remote.addchangegroup(cg)
 
     def changegroup(self, basenodes):
-        nodes = self.newer(basenodes)
-
-        # construct the link map
-        linkmap = {}
-        for n in nodes:
-            linkmap[self.changelog.rev(n)] = n
-
-        # construct a list of all changed files
-        changed = {}
-        for n in nodes:
-            c = self.changelog.read(n)
-            for f in c[3]:
-                changed[f] = 1
-        changed = changed.keys()
-        changed.sort()
-
-        # the changegroup is changesets + manifests + all file revs
-        revs = [ self.changelog.rev(n) for n in nodes ]
-
-        for y in self.changelog.group(linkmap): yield y
-        for y in self.manifest.group(linkmap): yield y
-        for f in changed:
-            yield struct.pack(">l", len(f) + 4) + f
-            g = self.file(f).group(linkmap)
-            for y in g:
-                yield y
-
-    def addchangegroup(self, generator):
-
         class genread:
             def __init__(self, generator):
                 self.g = generator
@@ -1067,6 +1038,40 @@ class localrepository:
                 d, self.buf = self.buf[:l], self.buf[l:]
                 return d
 
+        def gengroup():
+            nodes = self.newer(basenodes)
+
+            # construct the link map
+            linkmap = {}
+            for n in nodes:
+                linkmap[self.changelog.rev(n)] = n
+
+            # construct a list of all changed files
+            changed = {}
+            for n in nodes:
+                c = self.changelog.read(n)
+                for f in c[3]:
+                    changed[f] = 1
+            changed = changed.keys()
+            changed.sort()
+
+            # the changegroup is changesets + manifests + all file revs
+            revs = [ self.changelog.rev(n) for n in nodes ]
+
+            for y in self.changelog.group(linkmap): yield y
+            for y in self.manifest.group(linkmap): yield y
+            for f in changed:
+                yield struct.pack(">l", len(f) + 4) + f
+                g = self.file(f).group(linkmap)
+                for y in g:
+                    yield y
+
+            yield struct.pack(">l", 0)
+
+        return genread(gengroup())
+
+    def addchangegroup(self, source):
+
         def getchunk():
             d = source.read(4)
             if not d: return ""
@@ -1087,10 +1092,9 @@ class localrepository:
         def revmap(x):
             return self.changelog.rev(x)
 
-        if not generator: return
+        if not source: return
         changesets = files = revisions = 0
 
-        source = genread(generator)
         tr = self.transaction()
 
         # pull off the changeset group
@@ -1592,17 +1596,27 @@ class httprepository:
 
     def changegroup(self, nodes):
         n = " ".join(map(hex, nodes))
-        zd = zlib.decompressobj()
         f = self.do_cmd("changegroup", roots=n)
         bytes = 0
-        while 1:
-            d = f.read(4096)
-            bytes += len(d)
-            if not d:
-                yield zd.flush()
-                break
-            yield zd.decompress(d)
-        self.ui.note("%d bytes of data transfered\n" % bytes)
+
+        class zread:
+            def __init__(self, f):
+                self.zd = zlib.decompressobj()
+                self.f = f
+                self.buf = ""
+            def read(self, l):
+                while l > len(self.buf):
+                    r = f.read(4096)
+                    if r:
+                        self.buf += self.zd.decompress(r)
+                    else:
+                        self.buf += self.zd.flush()
+                        break
+                d, self.buf = self.buf[:l], self.buf[l:]
+                return d
+
+        return zread(f)
+
 
 class sshrepository:
     def __init__(self, ui, path):
@@ -1680,14 +1694,7 @@ class sshrepository:
     def changegroup(self, nodes):
         n = " ".join(map(hex, nodes))
         f = self.do_cmd("changegroup", roots=n)
-        bytes = 0
-        while 1:
-            l = struct.unpack(">l", f.read(4))[0]
-            if l == -1: break
-            d = f.read(l)
-            bytes += len(d)
-            yield d
-        self.ui.note("%d bytes of data transfered\n" % bytes)
+        return self.pipei
 
 def repository(ui, path=None, create=0):
     if path: