Просмотр исходного кода

Use forking for internal server when available

Unrud 7 лет назад
Родитель
Сommit
30a9ecc06b
2 измененных файлов с 60 добавлено и 27 удалено
  1. 28 1
      radicale/log.py
  2. 32 26
      radicale/server.py

+ 28 - 1
radicale/log.py

@@ -25,6 +25,7 @@ http://docs.python.org/library/logging.config.html
 import contextlib
 import io
 import logging
+import multiprocessing
 import os
 import sys
 import threading
@@ -35,7 +36,7 @@ except ImportError:
     journal = None
 
 LOGGER_NAME = "radicale"
-LOGGER_FORMAT = "[%(processName)s/%(threadName)s] %(levelname)s: %(message)s"
+LOGGER_FORMAT = "[%(ident)s] %(levelname)s: %(message)s"
 
 root_logger = logging.getLogger()
 logger = logging.getLogger(LOGGER_NAME)
@@ -50,6 +51,27 @@ class RemoveTracebackFilter(logging.Filter):
 removeTracebackFilter = RemoveTracebackFilter()
 
 
+class IdentLogRecordFactory:
+    """LogRecordFactory that adds ``ident`` attribute."""
+
+    def __init__(self, upstream_factory):
+        self.upstream_factory = upstream_factory
+        self.main_pid = os.getpid()
+        self.main_thread_name = threading.current_thread().name
+
+    def __call__(self, *args, **kwargs):
+        record = self.upstream_factory(*args, **kwargs)
+        pid = os.getpid()
+        thread_name = threading.current_thread().name
+        ident = "%x" % self.main_pid
+        if pid != self.main_pid:
+            ident += "%+x" % (pid - self.main_pid)
+        if thread_name != self.main_thread_name:
+            ident += "/%s" % thread_name
+        record.ident = ident
+        return record
+
+
 class ThreadStreamsHandler(logging.Handler):
 
     terminator = "\n"
@@ -60,6 +82,9 @@ class ThreadStreamsHandler(logging.Handler):
         self.fallback_stream = fallback_stream
         self.fallback_handler = fallback_handler
 
+    def createLock(self):
+        self.lock = multiprocessing.Lock()
+
     def setFormatter(self, form):
         super().setFormatter(form)
         self.fallback_handler.setFormatter(form)
@@ -116,6 +141,8 @@ def setup():
     handler = ThreadStreamsHandler(sys.stderr, get_default_handler())
     logging.basicConfig(format=LOGGER_FORMAT, handlers=[handler])
     register_stream = handler.register_stream
+    log_record_factory = IdentLogRecordFactory(logging.getLogRecordFactory())
+    logging.setLogRecordFactory(log_record_factory)
     set_level(logging.DEBUG)
 
 

+ 32 - 26
radicale/server.py

@@ -22,6 +22,7 @@ Radicale WSGI server.
 """
 
 import contextlib
+import multiprocessing
 import os
 import select
 import signal
@@ -29,16 +30,20 @@ import socket
 import socketserver
 import ssl
 import sys
-import threading
 import wsgiref.simple_server
 from urllib.parse import unquote
 
 from radicale import Application
 from radicale.log import logger
 
+if hasattr(socketserver, "ForkingMixIn"):
+    ParallelizationMixIn = socketserver.ForkingMixIn
+else:
+    ParallelizationMixIn = socketserver.ThreadingMixIn
 
-class HTTPServer(wsgiref.simple_server.WSGIServer):
-    """HTTP server."""
+
+class ParallelHTTPServer(ParallelizationMixIn,
+                         wsgiref.simple_server.WSGIServer):
 
     # These class attributes must be set before creating instance
     client_timeout = None
@@ -59,7 +64,7 @@ class HTTPServer(wsgiref.simple_server.WSGIServer):
             self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
 
         if self.max_connections:
-            self.connections_guard = threading.BoundedSemaphore(
+            self.connections_guard = multiprocessing.BoundedSemaphore(
                 self.max_connections)
         else:
             # use dummy context manager
@@ -75,10 +80,14 @@ class HTTPServer(wsgiref.simple_server.WSGIServer):
 
     def get_request(self):
         # Set timeout for client
-        _socket, address = super().get_request()
+        socket_, address = super().get_request()
         if self.client_timeout:
-            _socket.settimeout(self.client_timeout)
-        return _socket, address
+            socket_.settimeout(self.client_timeout)
+        return socket_, address
+
+    def finish_request(self, request, client_address):
+        with self.connections_guard:
+            return super().finish_request(request, client_address)
 
     def handle_error(self, request, client_address):
         if issubclass(sys.exc_info()[0], socket.timeout):
@@ -88,8 +97,7 @@ class HTTPServer(wsgiref.simple_server.WSGIServer):
                          sys.exc_info()[1], exc_info=True)
 
 
-class HTTPSServer(HTTPServer):
-    """HTTPS server."""
+class ParallelHTTPSServer(ParallelHTTPServer):
 
     # These class attributes must be set before creating instance
     certificate = None
@@ -98,9 +106,11 @@ class HTTPSServer(HTTPServer):
     ciphers = None
     certificate_authority = None
 
-    def __init__(self, address, handler):
+    def __init__(self, address, handler, bind_and_activate=True):
         """Create server by wrapping HTTP socket in an SSL socket."""
-        super().__init__(address, handler, bind_and_activate=False)
+
+        # Do not bind and activate, as we change the socket
+        super().__init__(address, handler, False)
 
         self.socket = ssl.wrap_socket(
             self.socket, self.key, self.certificate, server_side=True,
@@ -110,18 +120,15 @@ class HTTPSServer(HTTPServer):
             ssl_version=self.protocol, ciphers=self.ciphers,
             do_handshake_on_connect=False)
 
-        self.server_bind()
-        self.server_activate()
-
-
-class ThreadedHTTPServer(socketserver.ThreadingMixIn, HTTPServer):
-    def process_request_thread(self, request, client_address):
-        with self.connections_guard:
-            return super().process_request_thread(request, client_address)
-
+        if bind_and_activate:
+            try:
+                self.server_bind()
+                self.server_activate()
+            except BaseException:
+                self.server_close()
+                raise
 
-class ThreadedHTTPSServer(socketserver.ThreadingMixIn, HTTPSServer):
-    def process_request_thread(self, request, client_address):
+    def finish_request(self, request, client_address):
         try:
             try:
                 request.do_handshake()
@@ -135,8 +142,7 @@ class ThreadedHTTPSServer(socketserver.ThreadingMixIn, HTTPSServer):
             finally:
                 self.shutdown_request(request)
             return
-        with self.connections_guard:
-            return super().process_request_thread(request, client_address)
+        return super().finish_request(request, client_address)
 
 
 class ServerHandler(wsgiref.simple_server.ServerHandler):
@@ -197,7 +203,7 @@ def serve(configuration):
     # Create collection servers
     servers = {}
     if configuration.getboolean("server", "ssl"):
-        server_class = ThreadedHTTPSServer
+        server_class = ParallelHTTPSServer
         server_class.certificate = configuration.get("server", "certificate")
         server_class.key = configuration.get("server", "key")
         server_class.certificate_authority = configuration.get(
@@ -216,7 +222,7 @@ def serve(configuration):
                 raise RuntimeError("Failed to read SSL %s %r: %s" %
                                    (name, filename, e)) from e
     else:
-        server_class = ThreadedHTTPServer
+        server_class = ParallelHTTPServer
     server_class.client_timeout = configuration.getint("server", "timeout")
     server_class.max_connections = configuration.getint(
         "server", "max_connections")