|
@@ -29,9 +29,11 @@ should have been included in this package.
|
|
|
import os
|
|
import os
|
|
|
import pprint
|
|
import pprint
|
|
|
import base64
|
|
import base64
|
|
|
|
|
+import contextlib
|
|
|
import socket
|
|
import socket
|
|
|
import socketserver
|
|
import socketserver
|
|
|
import ssl
|
|
import ssl
|
|
|
|
|
+import threading
|
|
|
import wsgiref.simple_server
|
|
import wsgiref.simple_server
|
|
|
import re
|
|
import re
|
|
|
import zlib
|
|
import zlib
|
|
@@ -54,6 +56,11 @@ WELL_KNOWN_RE = re.compile(r"/\.well-known/(carddav|caldav)/?$")
|
|
|
|
|
|
|
|
class HTTPServer(wsgiref.simple_server.WSGIServer):
|
|
class HTTPServer(wsgiref.simple_server.WSGIServer):
|
|
|
"""HTTP server."""
|
|
"""HTTP server."""
|
|
|
|
|
+
|
|
|
|
|
+ # These class attributes must be set before creating instance
|
|
|
|
|
+ client_timeout = None
|
|
|
|
|
+ max_connections = None
|
|
|
|
|
+
|
|
|
def __init__(self, address, handler, bind_and_activate=True):
|
|
def __init__(self, address, handler, bind_and_activate=True):
|
|
|
"""Create server."""
|
|
"""Create server."""
|
|
|
ipv6 = ":" in address[0]
|
|
ipv6 = ":" in address[0]
|
|
@@ -72,6 +79,20 @@ class HTTPServer(wsgiref.simple_server.WSGIServer):
|
|
|
self.server_bind()
|
|
self.server_bind()
|
|
|
self.server_activate()
|
|
self.server_activate()
|
|
|
|
|
|
|
|
|
|
+ if self.max_connections:
|
|
|
|
|
+ self.connections_guard = threading.BoundedSemaphore(
|
|
|
|
|
+ self.max_connections)
|
|
|
|
|
+ else:
|
|
|
|
|
+ # use dummy context manager
|
|
|
|
|
+ self.connections_guard = contextlib.suppress()
|
|
|
|
|
+
|
|
|
|
|
+ def get_request(self):
|
|
|
|
|
+ # Set timeout for client
|
|
|
|
|
+ _socket, address = super().get_request()
|
|
|
|
|
+ if self.client_timeout:
|
|
|
|
|
+ _socket.settimeout(self.client_timeout)
|
|
|
|
|
+ return _socket, address
|
|
|
|
|
+
|
|
|
|
|
|
|
|
class HTTPSServer(HTTPServer):
|
|
class HTTPSServer(HTTPServer):
|
|
|
"""HTTPS server."""
|
|
"""HTTPS server."""
|
|
@@ -95,11 +116,15 @@ class HTTPSServer(HTTPServer):
|
|
|
|
|
|
|
|
|
|
|
|
|
class ThreadedHTTPServer(socketserver.ThreadingMixIn, HTTPServer):
|
|
class ThreadedHTTPServer(socketserver.ThreadingMixIn, HTTPServer):
|
|
|
- pass
|
|
|
|
|
|
|
+ def process_request_thread(self, request, client_address):
|
|
|
|
|
+ with self.connections_guard:
|
|
|
|
|
+ return super().process_request_thread(request, client_address)
|
|
|
|
|
|
|
|
|
|
|
|
|
class ThreadedHTTPSServer(socketserver.ThreadingMixIn, HTTPSServer):
|
|
class ThreadedHTTPSServer(socketserver.ThreadingMixIn, HTTPSServer):
|
|
|
- pass
|
|
|
|
|
|
|
+ def process_request_thread(self, request, client_address):
|
|
|
|
|
+ with self.connections_guard:
|
|
|
|
|
+ return super().process_request_thread(request, client_address)
|
|
|
|
|
|
|
|
|
|
|
|
|
class RequestHandler(wsgiref.simple_server.WSGIRequestHandler):
|
|
class RequestHandler(wsgiref.simple_server.WSGIRequestHandler):
|
|
@@ -218,6 +243,15 @@ class Application:
|
|
|
|
|
|
|
|
def __call__(self, environ, start_response):
|
|
def __call__(self, environ, start_response):
|
|
|
"""Manage a request."""
|
|
"""Manage a request."""
|
|
|
|
|
+ def response(status, headers={}, answer=None):
|
|
|
|
|
+ # Start response
|
|
|
|
|
+ status = "%i %s" % (status,
|
|
|
|
|
+ client.responses.get(status, "Unknown"))
|
|
|
|
|
+ self.logger.debug("Answer status: %s" % status)
|
|
|
|
|
+ start_response(status, list(headers.items()))
|
|
|
|
|
+ # Return response content
|
|
|
|
|
+ return [answer] if answer else []
|
|
|
|
|
+
|
|
|
self.logger.info("%s request at %s received" % (
|
|
self.logger.info("%s request at %s received" % (
|
|
|
environ["REQUEST_METHOD"], environ["PATH_INFO"]))
|
|
environ["REQUEST_METHOD"], environ["PATH_INFO"]))
|
|
|
headers = pprint.pformat(self.headers_log(environ))
|
|
headers = pprint.pformat(self.headers_log(environ))
|
|
@@ -234,9 +268,7 @@ class Application:
|
|
|
# Request path not starting with base_prefix, not allowed
|
|
# Request path not starting with base_prefix, not allowed
|
|
|
self.logger.debug(
|
|
self.logger.debug(
|
|
|
"Path not starting with prefix: %s", environ["PATH_INFO"])
|
|
"Path not starting with prefix: %s", environ["PATH_INFO"])
|
|
|
- status, headers, _ = NOT_ALLOWED
|
|
|
|
|
- start_response(status, list(headers.items()))
|
|
|
|
|
- return []
|
|
|
|
|
|
|
+ return response(*NOT_ALLOWED)
|
|
|
|
|
|
|
|
# Sanitize request URI
|
|
# Sanitize request URI
|
|
|
environ["PATH_INFO"] = storage.sanitize_path(
|
|
environ["PATH_INFO"] = storage.sanitize_path(
|
|
@@ -275,10 +307,7 @@ class Application:
|
|
|
status = client.SEE_OTHER
|
|
status = client.SEE_OTHER
|
|
|
self.logger.info("/.well-known/ redirection to: %s" % redirect)
|
|
self.logger.info("/.well-known/ redirection to: %s" % redirect)
|
|
|
headers = {"Location": redirect}
|
|
headers = {"Location": redirect}
|
|
|
- status = "%i %s" % (
|
|
|
|
|
- status, client.responses.get(status, "Unknown"))
|
|
|
|
|
- start_response(status, list(headers.items()))
|
|
|
|
|
- return []
|
|
|
|
|
|
|
+ return response(status, headers)
|
|
|
|
|
|
|
|
is_authenticated = self.is_authenticated(user, password)
|
|
is_authenticated = self.is_authenticated(user, password)
|
|
|
is_valid_user = is_authenticated or not user
|
|
is_valid_user = is_authenticated or not user
|
|
@@ -286,8 +315,17 @@ class Application:
|
|
|
# Get content
|
|
# Get content
|
|
|
content_length = int(environ.get("CONTENT_LENGTH") or 0)
|
|
content_length = int(environ.get("CONTENT_LENGTH") or 0)
|
|
|
if content_length:
|
|
if content_length:
|
|
|
- content = self.decode(
|
|
|
|
|
- environ["wsgi.input"].read(content_length), environ)
|
|
|
|
|
|
|
+ max_content_length = self.configuration.getint(
|
|
|
|
|
+ "server", "max_content_length")
|
|
|
|
|
+ if max_content_length and content_length > max_content_length:
|
|
|
|
|
+ self.logger.debug(
|
|
|
|
|
+ "Request body too large: %d", content_length)
|
|
|
|
|
+ return response(client.REQUEST_ENTITY_TOO_LARGE)
|
|
|
|
|
+ try:
|
|
|
|
|
+ content = self.decode(
|
|
|
|
|
+ environ["wsgi.input"].read(content_length), environ)
|
|
|
|
|
+ except socket.timeout:
|
|
|
|
|
+ return response(client.REQUEST_TIMEOUT)
|
|
|
self.logger.debug("Request content:\n%s" % content)
|
|
self.logger.debug("Request content:\n%s" % content)
|
|
|
else:
|
|
else:
|
|
|
content = None
|
|
content = None
|
|
@@ -345,13 +383,7 @@ class Application:
|
|
|
for key in self.configuration.options("headers"):
|
|
for key in self.configuration.options("headers"):
|
|
|
headers[key] = self.configuration.get("headers", key)
|
|
headers[key] = self.configuration.get("headers", key)
|
|
|
|
|
|
|
|
- # Start response
|
|
|
|
|
- status = "%i %s" % (status, client.responses.get(status, "Unknown"))
|
|
|
|
|
- self.logger.debug("Answer status: %s" % status)
|
|
|
|
|
- start_response(status, list(headers.items()))
|
|
|
|
|
-
|
|
|
|
|
- # Return response content
|
|
|
|
|
- return [answer] if answer else []
|
|
|
|
|
|
|
+ return response(status, headers, answer)
|
|
|
|
|
|
|
|
# All these functions must have the same parameters, some are useless
|
|
# All these functions must have the same parameters, some are useless
|
|
|
# pylint: disable=W0612,W0613,R0201
|
|
# pylint: disable=W0612,W0613,R0201
|