ソースを参照

Merge pull request #428 from Unrud/patch-22

Add timeout to connections, limit size of request body and limit number of parallel connections
Guillaume Ayoub 9 年 前
コミット
ef63865e31
4 ファイル変更65 行追加18 行削除
  1. 9 0
      config
  2. 50 18
      radicale/__init__.py
  3. 3 0
      radicale/__main__.py
  4. 3 0
      radicale/config.py

+ 9 - 0
config

@@ -24,6 +24,15 @@
 # File storing the PID in daemon mode
 #pid =
 
+# Max parallel connections
+#max_connections = 20
+
+# Max size of request body (bytes)
+#max_content_length = 10000000
+
+# Socket timeout (seconds)
+#timeout = 10
+
 # SSL flag, enable HTTPS protocol
 #ssl = False
 

+ 50 - 18
radicale/__init__.py

@@ -29,9 +29,11 @@ should have been included in this package.
 import os
 import pprint
 import base64
+import contextlib
 import socket
 import socketserver
 import ssl
+import threading
 import wsgiref.simple_server
 import re
 import zlib
@@ -54,6 +56,11 @@ WELL_KNOWN_RE = re.compile(r"/\.well-known/(carddav|caldav)/?$")
 
 class HTTPServer(wsgiref.simple_server.WSGIServer):
     """HTTP server."""
+
+    # These class attributes must be set before creating instance
+    client_timeout = None
+    max_connections = None
+
     def __init__(self, address, handler, bind_and_activate=True):
         """Create server."""
         ipv6 = ":" in address[0]
@@ -72,6 +79,20 @@ class HTTPServer(wsgiref.simple_server.WSGIServer):
             self.server_bind()
             self.server_activate()
 
+        if self.max_connections:
+            self.connections_guard = threading.BoundedSemaphore(
+                self.max_connections)
+        else:
+            # use dummy context manager
+            self.connections_guard = contextlib.suppress()
+
+    def get_request(self):
+        # Set timeout for client
+        _socket, address = super().get_request()
+        if self.client_timeout:
+            _socket.settimeout(self.client_timeout)
+        return _socket, address
+
 
 class HTTPSServer(HTTPServer):
     """HTTPS server."""
@@ -95,11 +116,15 @@ class HTTPSServer(HTTPServer):
 
 
 class ThreadedHTTPServer(socketserver.ThreadingMixIn, HTTPServer):
-    pass
+    def process_request_thread(self, request, client_address):
+        with self.connections_guard:
+            return super().process_request_thread(request, client_address)
 
 
 class ThreadedHTTPSServer(socketserver.ThreadingMixIn, HTTPSServer):
-    pass
+    def process_request_thread(self, request, client_address):
+        with self.connections_guard:
+            return super().process_request_thread(request, client_address)
 
 
 class RequestHandler(wsgiref.simple_server.WSGIRequestHandler):
@@ -218,6 +243,15 @@ class Application:
 
     def __call__(self, environ, start_response):
         """Manage a request."""
+        def response(status, headers={}, answer=None):
+            # Start response
+            status = "%i %s" % (status,
+                                client.responses.get(status, "Unknown"))
+            self.logger.debug("Answer status: %s" % status)
+            start_response(status, list(headers.items()))
+            # Return response content
+            return [answer] if answer else []
+
         self.logger.info("%s request at %s received" % (
             environ["REQUEST_METHOD"], environ["PATH_INFO"]))
         headers = pprint.pformat(self.headers_log(environ))
@@ -234,9 +268,7 @@ class Application:
             # Request path not starting with base_prefix, not allowed
             self.logger.debug(
                 "Path not starting with prefix: %s", environ["PATH_INFO"])
-            status, headers, _ = NOT_ALLOWED
-            start_response(status, list(headers.items()))
-            return []
+            return response(*NOT_ALLOWED)
 
         # Sanitize request URI
         environ["PATH_INFO"] = storage.sanitize_path(
@@ -275,10 +307,7 @@ class Application:
                 status = client.SEE_OTHER
                 self.logger.info("/.well-known/ redirection to: %s" % redirect)
                 headers = {"Location": redirect}
-            status = "%i %s" % (
-                status, client.responses.get(status, "Unknown"))
-            start_response(status, list(headers.items()))
-            return []
+            return response(status, headers)
 
         is_authenticated = self.is_authenticated(user, password)
         is_valid_user = is_authenticated or not user
@@ -286,8 +315,17 @@ class Application:
         # Get content
         content_length = int(environ.get("CONTENT_LENGTH") or 0)
         if content_length:
-            content = self.decode(
-                environ["wsgi.input"].read(content_length), environ)
+            max_content_length = self.configuration.getint(
+                "server", "max_content_length")
+            if max_content_length and content_length > max_content_length:
+                self.logger.debug(
+                    "Request body too large: %d", content_length)
+                return response(client.REQUEST_ENTITY_TOO_LARGE)
+            try:
+                content = self.decode(
+                    environ["wsgi.input"].read(content_length), environ)
+            except socket.timeout:
+                return response(client.REQUEST_TIMEOUT)
             self.logger.debug("Request content:\n%s" % content)
         else:
             content = None
@@ -345,13 +383,7 @@ class Application:
             for key in self.configuration.options("headers"):
                 headers[key] = self.configuration.get("headers", key)
 
-        # Start response
-        status = "%i %s" % (status, client.responses.get(status, "Unknown"))
-        self.logger.debug("Answer status: %s" % status)
-        start_response(status, list(headers.items()))
-
-        # Return response content
-        return [answer] if answer else []
+        return response(status, headers, answer)
 
     # All these functions must have the same parameters, some are useless
     # pylint: disable=W0612,W0613,R0201

+ 3 - 0
radicale/__main__.py

@@ -175,6 +175,9 @@ def serve(configuration, logger):
                         name, filename, exception))
     else:
         server_class = ThreadedHTTPServer
+    server_class.client_timeout = configuration.getint("server", "timeout")
+    server_class.max_connections = configuration.getint("server",
+                                                        "max_connections")
 
     if not configuration.getboolean("server", "dns_lookup"):
         RequestHandler.address_string = lambda self: self.client_address[0]

+ 3 - 0
radicale/config.py

@@ -32,6 +32,9 @@ INITIAL_CONFIG = {
         "hosts": "0.0.0.0:5232",
         "daemon": "False",
         "pid": "",
+        "max_connections": "20",
+        "max_content_length": "10000000",
+        "timeout": "10",
         "ssl": "False",
         "certificate": "/etc/apache2/ssl/server.crt",
         "key": "/etc/apache2/ssl/server.key",