changeset 2396:8d44649df03b

refactor ssh server.
author Vadim Gelfer <vadim.gelfer@gmail.com>
date Sun, 04 Jun 2006 10:26:05 -0700
parents a8f1049d1d2d
children e9d402506514
files mercurial/commands.py mercurial/sshserver.py
diffstat 2 files changed, 103 insertions(+), 71 deletions(-) [+]
line wrap: on
line diff
--- a/mercurial/commands.py
+++ b/mercurial/commands.py
@@ -13,7 +13,7 @@ demandload(globals(), "fancyopts ui hg u
 demandload(globals(), "fnmatch mdiff random signal tempfile time")
 demandload(globals(), "traceback errno socket version struct atexit sets bz2")
 demandload(globals(), "archival changegroup")
-demandload(globals(), "hgweb.server")
+demandload(globals(), "hgweb.server sshserver")
 
 class UnknownCommand(Exception):
     """Exception raised if command is not in the command table."""
@@ -2452,76 +2452,8 @@ def serve(ui, repo, **opts):
     if opts["stdio"]:
         if repo is None:
             raise hg.RepoError(_('no repo found'))
-        fin, fout = sys.stdin, sys.stdout
-        sys.stdout = sys.stderr
-
-        # Prevent insertion/deletion of CRs
-        util.set_binary(fin)
-        util.set_binary(fout)
-
-        def getarg():
-            argline = fin.readline()[:-1]
-            arg, l = argline.split()
-            val = fin.read(int(l))
-            return arg, val
-        def respond(v):
-            fout.write("%d\n" % len(v))
-            fout.write(v)
-            fout.flush()
-
-        lock = None
-
-        while 1:
-            cmd = fin.readline()[:-1]
-            if cmd == '':
-                return
-            if cmd == "heads":
-                h = repo.heads()
-                respond(" ".join(map(hex, h)) + "\n")
-            if cmd == "lock":
-                lock = repo.lock()
-                respond("")
-            if cmd == "unlock":
-                if lock:
-                    lock.release()
-                lock = None
-                respond("")
-            elif cmd == "branches":
-                arg, nodes = getarg()
-                nodes = map(bin, nodes.split(" "))
-                r = []
-                for b in repo.branches(nodes):
-                    r.append(" ".join(map(hex, b)) + "\n")
-                respond("".join(r))
-            elif cmd == "between":
-                arg, pairs = getarg()
-                pairs = [map(bin, p.split("-")) for p in pairs.split(" ")]
-                r = []
-                for b in repo.between(pairs):
-                    r.append(" ".join(map(hex, b)) + "\n")
-                respond("".join(r))
-            elif cmd == "changegroup":
-                nodes = []
-                arg, roots = getarg()
-                nodes = map(bin, roots.split(" "))
-
-                cg = repo.changegroup(nodes, 'serve')
-                while 1:
-                    d = cg.read(4096)
-                    if not d:
-                        break
-                    fout.write(d)
-
-                fout.flush()
-
-            elif cmd == "addchangegroup":
-                if not lock:
-                    respond("not locked")
-                    continue
-                respond("")
-
-                r = repo.addchangegroup(fin, 'serve')
-                respond(str(r))
+        s = sshserver.sshserver(ui, repo)
+        s.serve_forever()
 
     optlist = ("name templates style address port ipv6"
                " accesslog errorlog webdir_conf")
new file mode 100644
--- /dev/null
+++ b/mercurial/sshserver.py
@@ -0,0 +1,100 @@
+# commands.py - command processing for mercurial
+#
+# Copyright 2005 Matt Mackall <mpm@selenic.com>
+#
+# This software may be used and distributed according to the terms
+# of the GNU General Public License, incorporated herein by reference.
+
+from demandload import demandload
+from i18n import gettext as _
+from node import *
+demandload(globals(), "sys util")
+
+class sshserver(object):
+    def __init__(self, ui, repo):
+        self.ui = ui
+        self.repo = repo
+        self.lock = None
+        self.fin = sys.stdin
+        self.fout = sys.stdout
+
+        sys.stdout = sys.stderr
+
+        # Prevent insertion/deletion of CRs
+        util.set_binary(self.fin)
+        util.set_binary(self.fout)
+
+    def getarg(self):
+        argline = self.fin.readline()[:-1]
+        arg, l = argline.split()
+        val = self.fin.read(int(l))
+        return arg, val
+
+    def respond(self, v):
+        self.fout.write("%d\n" % len(v))
+        self.fout.write(v)
+        self.fout.flush()
+
+    def serve_forever(self):
+        while self.serve_one(): pass
+        sys.exit(0)
+
+    def serve_one(self):
+        cmd = self.fin.readline()[:-1]
+        if cmd:
+            impl = getattr(self, 'do_' + cmd, None)
+            if impl: impl()
+        return cmd != ''
+
+    def do_heads(self):
+        h = self.repo.heads()
+        self.respond(" ".join(map(hex, h)) + "\n")
+
+    def do_lock(self):
+        self.lock = self.repo.lock()
+        self.respond("")
+
+    def do_unlock(self):
+        if self.lock:
+            self.lock.release()
+        self.lock = None
+        self.respond("")
+
+    def do_branches(self):
+        arg, nodes = self.getarg()
+        nodes = map(bin, nodes.split(" "))
+        r = []
+        for b in self.repo.branches(nodes):
+            r.append(" ".join(map(hex, b)) + "\n")
+        self.respond("".join(r))
+
+    def do_between(self):
+        arg, pairs = self.getarg()
+        pairs = [map(bin, p.split("-")) for p in pairs.split(" ")]
+        r = []
+        for b in self.repo.between(pairs):
+            r.append(" ".join(map(hex, b)) + "\n")
+        self.respond("".join(r))
+
+    def do_changegroup(self):
+        nodes = []
+        arg, roots = self.getarg()
+        nodes = map(bin, roots.split(" "))
+
+        cg = self.repo.changegroup(nodes, 'serve')
+        while True:
+            d = cg.read(4096)
+            if not d:
+                break
+            self.fout.write(d)
+
+        self.fout.flush()
+
+    def do_addchangegroup(self):
+        if not self.lock:
+            self.respond("not locked")
+            return
+
+        self.respond("")
+        r = self.repo.addchangegroup(self.fin, 'serve')
+        self.respond(str(r))