Explorar el Código

Use context manager for locking

Unrud hace 9 años
padre
commit
bca6cec6b3
Se han modificado 2 ficheros con 31 adiciones y 38 borrados
  1. 26 34
      radicale/__init__.py
  2. 5 4
      radicale/storage.py

+ 26 - 34
radicale/__init__.py

@@ -282,45 +282,37 @@ class Application:
         is_authenticated = self.is_authenticated(user, password)
         is_valid_user = is_authenticated or not user
 
-        lock = None
-        try:
-            if is_valid_user:
-                if function in (self.do_GET, self.do_HEAD,
-                                self.do_OPTIONS, self.do_PROPFIND,
-                                self.do_REPORT):
-                    lock_mode = "r"
-                else:
-                    lock_mode = "w"
-                lock = self.Collection.acquire_lock(lock_mode)
+        # Get content
+        content_length = int(environ.get("CONTENT_LENGTH") or 0)
+        if content_length:
+            content = self.decode(
+                environ["wsgi.input"].read(content_length), environ)
+            self.logger.debug("Request content:\n%s" % content)
+        else:
+            content = None
 
+        if is_valid_user:
+            if function in (self.do_GET, self.do_HEAD,
+                            self.do_OPTIONS, self.do_PROPFIND,
+                            self.do_REPORT):
+                lock_mode = "r"
+            else:
+                lock_mode = "w"
+            with self.Collection.acquire_lock(lock_mode):
                 items = self.Collection.discover(
                     path, environ.get("HTTP_DEPTH", "0"))
                 read_allowed_items, write_allowed_items = (
                     self.collect_allowed_items(items, user))
-            else:
-                read_allowed_items, write_allowed_items = None, None
-
-            # Get content
-            content_length = int(environ.get("CONTENT_LENGTH") or 0)
-            if content_length:
-                content = self.decode(
-                    environ["wsgi.input"].read(content_length), environ)
-                self.logger.debug("Request content:\n%s" % content)
-            else:
-                content = None
-
-            if is_valid_user and (
-                    (read_allowed_items or write_allowed_items) or
-                    (is_authenticated and function == self.do_PROPFIND) or
-                    function == self.do_OPTIONS):
-                status, headers, answer = function(
-                    environ, read_allowed_items, write_allowed_items, content,
-                    user)
-            else:
-                status, headers, answer = NOT_ALLOWED
-        finally:
-            if lock:
-                lock.release()
+                if (read_allowed_items or write_allowed_items or
+                        is_authenticated and function == self.do_PROPFIND or
+                        function == self.do_OPTIONS):
+                    status, headers, answer = function(
+                        environ, read_allowed_items, write_allowed_items,
+                        content, user)
+                else:
+                    status, headers, answer = NOT_ALLOWED
+        else:
+            status, headers, answer = NOT_ALLOWED
 
         if (status, headers, answer) == NOT_ALLOWED and not is_authenticated:
             # Unknown or unauthorized user

+ 5 - 4
radicale/storage.py

@@ -277,14 +277,13 @@ class BaseCollection:
         raise NotImplementedError
 
     @classmethod
+    @contextmanager
     def acquire_lock(cls, mode):
-        """Lock the whole storage.
+        """Set a context manager to lock the whole storage.
 
         ``mode`` must either be "r" for shared access or "w" for exclusive
         access.
 
-        Returns an object which has a method ``release``.
-
         """
         raise NotImplementedError
 
@@ -521,6 +520,7 @@ class Collection(BaseCollection):
     _lock = threading.Lock()
 
     @classmethod
+    @contextmanager
     def acquire_lock(cls, mode):
         class Lock:
             def __init__(self, release_method):
@@ -574,4 +574,5 @@ class Collection(BaseCollection):
             # TODO: use readers–writer lock
             cls._lock.acquire()
             lock = Lock(cls._lock.release)
-        return lock
+        yield
+        lock.release()