diff mercurial/util.py @ 5141:d316124ebbea

Make audit_path more stringent. The following properties of a path are now checked for: - under top-level .hg - starts at the root of a windows drive - contains ".." - traverses a symlink (e.g. a/symlink_here/b) - inside a nested repository If any of these is true, the path is rejected. The check for traversing a symlink is arguably stricter than necessary; perhaps we should be checking for symlinks that point outside the repository.
author Bryan O'Sullivan <bos@serpentine.com>
date Fri, 10 Aug 2007 10:46:03 -0700
parents a2c11f49e989
children d84329a11fdd
line wrap: on
line diff
--- a/mercurial/util.py
+++ b/mercurial/util.py
@@ -13,8 +13,8 @@ platform-specific details from the core.
 """
 
 from i18n import _
-import cStringIO, errno, getpass, popen2, re, shutil, sys, tempfile
-import os, threading, time, calendar, ConfigParser, locale, glob
+import cStringIO, errno, getpass, popen2, re, shutil, sys, tempfile, strutil
+import os, stat, threading, time, calendar, ConfigParser, locale, glob
 
 try:
     set = set
@@ -366,6 +366,7 @@ def canonpath(root, cwd, myname):
     if not os.path.isabs(name):
         name = os.path.join(root, cwd, name)
     name = os.path.normpath(name)
+    audit_path = path_auditor(root)
     if name != rootsep and name.startswith(rootsep):
         name = name[len(rootsep):]
         audit_path(name)
@@ -680,12 +681,45 @@ def copyfiles(src, dst, hardlink=None):
         else:
             shutil.copy(src, dst)
 
-def audit_path(path):
-    """Abort if path contains dangerous components"""
-    parts = os.path.normcase(path).split(os.sep)
-    if (os.path.splitdrive(path)[0] or parts[0] in ('.hg', '')
-        or os.pardir in parts):
-        raise Abort(_("path contains illegal component: %s") % path)
+class path_auditor(object):
+    '''ensure that a filesystem path contains no banned components.
+    the following properties of a path are checked:
+
+    - under top-level .hg
+    - starts at the root of a windows drive
+    - contains ".."
+    - traverses a symlink (e.g. a/symlink_here/b)
+    - inside a nested repository'''
+
+    def __init__(self, root):
+        self.audited = {}
+        self.root = root
+
+    def __call__(self, path):
+        if path in self.audited:
+            return
+        parts = os.path.normcase(path).split(os.sep)
+        if (os.path.splitdrive(path)[0] or parts[0] in ('.hg', '')
+            or os.pardir in parts):
+            raise Abort(_("path contains illegal component: %s") % path)
+        def check(prefix):
+            curpath = os.path.join(self.root, prefix)
+            try:
+                st = os.lstat(curpath)
+            except OSError, err:
+                if err.errno != errno.ENOENT:
+                    raise
+            else:
+                if stat.S_ISLNK(st.st_mode):
+                    raise Abort(_('path %r traverses symbolic link %r') %
+                                (path, prefix))
+                if os.path.exists(os.path.join(curpath, '.hg')):
+                    raise Abort(_('path %r is inside repo %r') %
+                                (path, prefix))
+            self.audited[prefix] = True
+        for c in strutil.rfindall(path, os.sep):
+            check(path[:c])
+        self.audited[path] = True
 
 def _makelock_file(info, pathname):
     ld = os.open(pathname, os.O_CREAT | os.O_WRONLY | os.O_EXCL)
@@ -1262,7 +1296,10 @@ class opener(object):
     """
     def __init__(self, base, audit=True):
         self.base = base
-        self.audit = audit
+        if audit:
+            self.audit_path = path_auditor(base)
+        else:
+            self.audit_path = always
 
     def __getattr__(self, name):
         if name == '_can_symlink':
@@ -1271,8 +1308,7 @@ class opener(object):
         raise AttributeError(name)
 
     def __call__(self, path, mode="r", text=False, atomictemp=False):
-        if self.audit:
-            audit_path(path)
+        self.audit_path(path)
         f = os.path.join(self.base, path)
 
         if not text and "b" not in mode:
@@ -1293,8 +1329,7 @@ class opener(object):
         return posixfile(f, mode)
 
     def symlink(self, src, dst):
-        if self.audit:
-            audit_path(dst)
+        self.audit_path(dst)
         linkname = os.path.join(self.base, dst)
         try:
             os.unlink(linkname)