Unrud 7 лет назад
Родитель
Сommit
24815255be
3 измененных файлов с 111 добавлено и 42 удалено
  1. 60 41
      radicale/__init__.py
  2. 50 1
      radicale/log.py
  3. 1 0
      radicale/tests/__init__.py

+ 60 - 41
radicale/__init__.py

@@ -204,19 +204,23 @@ class ThreadedHTTPSServer(socketserver.ThreadingMixIn, HTTPSServer):
             return super().process_request_thread(request, client_address)
 
 
+class ServerHandler(wsgiref.simple_server.ServerHandler):
+
+
+    def log_exception(self, exc_info):
+        logger.error("An exception occurred during request: %s",
+                     exc_info[1], exc_info=exc_info)
+
+
 class RequestHandler(wsgiref.simple_server.WSGIRequestHandler):
     """HTTP requests handler."""
 
-    def __init__(self, *args, **kwargs):
-        # Store exception for logging
-        self.error_stream = io.StringIO()
-        super().__init__(*args, **kwargs)
+    def log_request(self, code="-", size="-"):
+        """Disable request logging."""
 
-    def get_stderr(self):
-        return self.error_stream
-
-    def log_message(self, *args, **kwargs):
-        """Disable inner logging management."""
+    def log_error(self, format, *args):
+        msg = format % args
+        logger.error("An error occurred during request: %s" % msg)
 
     def get_environ(self):
         env = super().get_environ()
@@ -228,12 +232,24 @@ class RequestHandler(wsgiref.simple_server.WSGIRequestHandler):
         return env
 
     def handle(self):
-        super().handle()
-        # Log exception
-        error = self.error_stream.getvalue().strip("\n")
-        if error:
-            logger.error(
-                "An unhandled exception occurred during request:\n%s" % error)
+        """Copy of WSGIRequestHandler.handle with different ServerHandler"""
+
+        self.raw_requestline = self.rfile.readline(65537)
+        if len(self.raw_requestline) > 65536:
+            self.requestline = ''
+            self.request_version = ''
+            self.command = ''
+            self.send_error(414)
+            return
+
+        if not self.parse_request():
+            return
+
+        handler = ServerHandler(
+            self.rfile, self.wfile, self.get_stderr(), self.get_environ()
+        )
+        handler.request_handler = self
+        handler.run(self.server.get_app())
 
 
 class Application:
@@ -323,26 +339,28 @@ class Application:
         return read_allowed_items, write_allowed_items
 
     def __call__(self, environ, start_response):
-        try:
-            status, headers, answers = self._handle_request(environ)
-        except Exception as e:
+        with log.register_stream(environ["wsgi.errors"]):
             try:
-                method = str(environ["REQUEST_METHOD"])
-            except Exception:
-                method = "unknown"
-            try:
-                path = str(environ.get("PATH_INFO", ""))
-            except Exception:
-                path = ""
-            logger.error("An exception occurred during %s request on %r: "
-                         "%s", method, path, e, exc_info=True)
-            status, headers, answer = INTERNAL_SERVER_ERROR
-            answer = answer.encode("ascii")
-            status = "%d %s" % (
-                status, client.responses.get(status, "Unknown"))
-            headers = [("Content-Length", str(len(answer)))] + list(headers)
-            answers = [answer]
-        start_response(status, headers)
+                status, headers, answers = self._handle_request(environ)
+            except Exception as e:
+                try:
+                    method = str(environ["REQUEST_METHOD"])
+                except Exception:
+                    method = "unknown"
+                try:
+                    path = str(environ.get("PATH_INFO", ""))
+                except Exception:
+                    path = ""
+                logger.error("An exception occurred during %s request on %r: "
+                             "%s", method, path, e, exc_info=True)
+                status, headers, answer = INTERNAL_SERVER_ERROR
+                answer = answer.encode("ascii")
+                status = "%d %s" % (
+                    status, client.responses.get(status, "Unknown"))
+                headers = [
+                    ("Content-Length", str(len(answer)))] + list(headers)
+                answers = [answer]
+            start_response(status, headers)
         return answers
 
     def _handle_request(self, environ):
@@ -990,24 +1008,25 @@ _application_config_path = None
 _application_lock = threading.Lock()
 
 
-def _init_application(config_path):
+def _init_application(config_path, wsgi_errors):
     global _application, _application_config_path
     with _application_lock:
         if _application is not None:
             return
         log.setup()
-        _application_config_path = config_path
-        configuration = config.load([config_path] if config_path else [],
-                                    ignore_missing_paths=False)
-        log.set_debug(configuration.getboolean("logging", "debug"))
-        _application = Application(configuration)
+        with log.register_stream(wsgi_errors):
+            _application_config_path = config_path
+            configuration = config.load([config_path] if config_path else [],
+                                        ignore_missing_paths=False)
+            log.set_debug(configuration.getboolean("logging", "debug"))
+            _application = Application(configuration)
 
 
 def application(environ, start_response):
     config_path = environ.get("RADICALE_CONFIG",
                               os.environ.get("RADICALE_CONFIG"))
     if _application is None:
-        _init_application(config_path)
+        _init_application(config_path, environ["wsgi.errors"])
     if _application_config_path != config_path:
         raise ValueError("RADICALE_CONFIG must not change: %s != %s" %
                          (repr(config_path), repr(_application_config_path)))

+ 50 - 1
radicale/log.py

@@ -22,6 +22,7 @@ http://docs.python.org/library/logging.config.html
 
 """
 
+import contextlib
 import logging
 import sys
 import threading
@@ -43,16 +44,64 @@ class RemoveTracebackFilter(logging.Filter):
 removeTracebackFilter = RemoveTracebackFilter()
 
 
+class ThreadStreamsHandler(logging.Handler):
+
+    terminator = "\n"
+
+    def __init__(self, fallback_stream, fallback_handler):
+        super().__init__()
+        self._streams = {}
+        self.fallback_stream = fallback_stream
+        self.fallback_handler = fallback_handler
+
+    def setFormatter(self, form):
+        super().setFormatter(form)
+        self.fallback_handler.setFormatter(form)
+
+    def emit(self, record):
+        try:
+            stream = self._streams.get(threading.get_ident())
+            if stream is None:
+                self.fallback_handler.emit(record)
+            else:
+                msg = self.format(record)
+                stream.write(msg)
+                stream.write(self.terminator)
+                if hasattr(stream, "flush"):
+                    stream.flush()
+        except Exception:
+            self.handleError(record)
+
+    @contextlib.contextmanager
+    def register_stream(self, stream):
+        if stream == self.fallback_stream:
+            yield
+            return
+        key = threading.get_ident()
+        self._streams[key] = stream
+        try:
+            yield
+        finally:
+            del self._streams[key]
+
+
 def get_default_handler():
     handler = logging.StreamHandler(sys.stderr)
     return handler
 
 
+@contextlib.contextmanager
+def register_stream(stream):
+    """Register global errors stream for the current thread."""
+    yield
+
+
 def setup():
     """Set global logging up."""
     global register_stream, unregister_stream
-    handler = get_default_handler()
+    handler = ThreadStreamsHandler(sys.stderr, get_default_handler())
     logging.basicConfig(format=LOGGER_FORMAT, handlers=[handler])
+    register_stream = handler.register_stream
     set_debug(True)
 
 

+ 1 - 0
radicale/tests/__init__.py

@@ -43,6 +43,7 @@ class BaseTest:
             data = data.encode("utf-8")
             args["wsgi.input"] = BytesIO(data)
             args["CONTENT_LENGTH"] = str(len(data))
+        args["wsgi.errors"] = sys.stderr
         self.application._answer = self.application(args, self.start_response)
 
         return (