Просмотр исходного кода

Add missing UIDs instead of failing

Unrud 8 лет назад
Родитель
Commit
2b3fd1fb9b
3 измененных файлов с 175 добавлено и 23 удалено
  1. 4 2
      radicale/__init__.py
  2. 83 21
      radicale/storage.py
  3. 88 0
      radicale/tests/test_base.py

+ 4 - 2
radicale/__init__.py

@@ -815,8 +815,10 @@ class Application:
 
 
             try:
             try:
                 items = list(vobject.readComponents(content or ""))
                 items = list(vobject.readComponents(content or ""))
-                for item in items:
-                    storage.check_item(item)
+                for i in items:
+                    storage.check_and_sanitize_item(
+                        i, is_collection=write_whole_collection, uid=item.uid
+                        if not write_whole_collection and item else None)
             except Exception as e:
             except Exception as e:
                 self.logger.warning(
                 self.logger.warning(
                     "Bad PUT request on %r: %s", path, e, exc_info=True)
                     "Bad PUT request on %r: %s", path, e, exc_info=True)

+ 83 - 21
radicale/storage.py

@@ -116,14 +116,44 @@ def load(configuration, logger):
     return CollectionCopy
     return CollectionCopy
 
 
 
 
-def check_item(vobject_item):
-    """Check vobject items for common errors."""
+def check_and_sanitize_item(vobject_item, is_collection=False, uid=None):
+    """Check vobject items for common errors and add missing UIDs.
+
+    ``is_collection`` indicates that the vobject_item contains unrelated
+    components.
+
+    If ``uid`` is not set, the UID is generated randomly.
+
+    """
     if vobject_item.name == "VCALENDAR":
     if vobject_item.name == "VCALENDAR":
+        component_name = None
+        object_uid = None
+        object_uid_set = False
         for component in vobject_item.components():
         for component in vobject_item.components():
-            if component.name not in ("VTODO", "VEVENT", "VJOURNAL"):
+            # https://tools.ietf.org/html/rfc4791#section-4.1
+            if component.name == "VTIMEZONE":
+                continue
+            if component_name is None or is_collection:
+                component_name = component.name
+            elif component_name != component.name:
+                raise ValueError("Muliple component types in object: %r, %r" %
+                                 (component_name, component.name))
+            if component_name not in ("VTODO", "VEVENT", "VJOURNAL"):
                 continue
                 continue
-            if not get_uid(component):
-                raise ValueError("UID in %s is missing" % component.name)
+            component_uid = get_uid(component)
+            if not object_uid_set or is_collection:
+                object_uid_set = True
+                object_uid = component_uid
+                if component_uid is None:
+                    component.add("UID").value = uid or random_uuid4()
+                elif not component_uid:
+                    component.uid.value = uid or random_uuid4()
+            elif not object_uid or not component_uid:
+                raise ValueError("Multiple %s components without UID in "
+                                 "object" % component_name)
+            elif object_uid != component_uid:
+                raise ValueError(
+                    "Muliple %s components with different UIDs in object: "
+                    "%r, %r" % (component_name, object_uid, component_uid))
             # vobject interprets recurrence rules on demand
             # vobject interprets recurrence rules on demand
             try:
             try:
                 component.rruleset
                 component.rruleset
@@ -131,8 +161,12 @@ def check_item(vobject_item):
                 raise ValueError("invalid recurrence rules in %s" %
                 raise ValueError("invalid recurrence rules in %s" %
                                  component.name) from e
                                  component.name) from e
     elif vobject_item.name == "VCARD":
     elif vobject_item.name == "VCARD":
-        if not get_uid(vobject_item):
-            raise ValueError("UID in VCARD is missing")
+        # https://tools.ietf.org/html/rfc6352#section-5.1
+        object_uid = get_uid(vobject_item)
+        if object_uid is None:
+            vobject_item.add("UID").value = uid or random_uuid4()
+        elif not object_uid:
+            vobject_item.uid.value = uid or random_uuid4()
     else:
     else:
         raise ValueError("Unknown item type: %r" % vobject_item.name)
         raise ValueError("Unknown item type: %r" % vobject_item.name)
 
 
@@ -175,9 +209,24 @@ def get_etag(text):
     return '"%s"' % etag.hexdigest()
     return '"%s"' % etag.hexdigest()
 
 
 
 
-def get_uid(item):
+def get_uid(vobject_component):
     """UID value of an item if defined."""
     """UID value of an item if defined."""
-    return (hasattr(item, "uid") or None) and item.uid.value
+    return ((hasattr(vobject_component, "uid") or None) and
+            vobject_component.uid.value)
+
+
+def get_uid_from_object(vobject_item):
+    """UID value of a calendar/addressbook object."""
+    if vobject_item.name == "VCALENDAR":
+        if hasattr(vobject_item, "vevent"):
+            return get_uid(vobject_item.vevent)
+        if hasattr(vobject_item, "vjournal"):
+            return get_uid(vobject_item.vjournal)
+        if hasattr(vobject_item, "vtodo"):
+            return get_uid(vobject_item.vtodo)
+    elif vobject_item.name == "VCARD":
+        return get_uid(vobject_item)
+    return None
 
 
 
 
 def sanitize_path(path):
 def sanitize_path(path):
@@ -272,7 +321,7 @@ class ComponentNotFoundError(ValueError):
 
 
 class Item:
 class Item:
     def __init__(self, collection, item=None, href=None, last_modified=None,
     def __init__(self, collection, item=None, href=None, last_modified=None,
-                 text=None, etag=None):
+                 text=None, etag=None, uid=None):
         """Initialize an item.
         """Initialize an item.
 
 
         ``collection`` the parent collection.
         ``collection`` the parent collection.
@@ -288,6 +337,8 @@ class Item:
 
 
         ``etag`` the etag of the item (optional). See ``get_etag``.
         ``etag`` the etag of the item (optional). See ``get_etag``.
 
 
+        ``uid`` the UID of the object (optional). See ``get_uid_from_object``.
+
         """
         """
         if text is None and item is None:
         if text is None and item is None:
             raise ValueError("at least one of 'text' or 'item' must be set")
             raise ValueError("at least one of 'text' or 'item' must be set")
@@ -297,6 +348,7 @@ class Item:
         self._text = text
         self._text = text
         self._item = item
         self._item = item
         self._etag = etag
         self._etag = etag
+        self._uid = uid
 
 
     def __getattr__(self, attr):
     def __getattr__(self, attr):
         return getattr(self.item, attr)
         return getattr(self.item, attr)
@@ -323,6 +375,12 @@ class Item:
             self._etag = get_etag(self.serialize())
             self._etag = get_etag(self.serialize())
         return self._etag
         return self._etag
 
 
+    @property
+    def uid(self):
+        if self._uid is None:
+            self._uid = get_uid_from_object(self.item)
+        return self._uid
+
 
 
 class BaseCollection:
 class BaseCollection:
 
 
@@ -1034,21 +1092,23 @@ class Collection(BaseCollection):
         input_hash = input_hash.hexdigest()
         input_hash = input_hash.hexdigest()
         cache_folder = os.path.join(self._filesystem_path, ".Radicale.cache",
         cache_folder = os.path.join(self._filesystem_path, ".Radicale.cache",
                                     "item")
                                     "item")
+        cinput_hash = cuid = cetag = ctext = ctag = cstart = cend = None
         try:
         try:
             with open(os.path.join(cache_folder, href), "rb") as f:
             with open(os.path.join(cache_folder, href), "rb") as f:
-                cinput_hash, cetag, ctext, ctag, cstart, cend = pickle.load(f)
-        except (FileNotFoundError, pickle.UnpicklingError, ValueError) as e:
-            if isinstance(e, (pickle.UnpicklingError, ValueError)):
-                self.logger.warning(
-                    "Failed to load item cache entry %r in %r: %s",
-                    href, self.path, e, exc_info=True)
-            cinput_hash = cetag = ctext = ctag = cstart = cend = None
+                cinput_hash, cuid, cetag, ctext, ctag, cstart, cend = \
+                    pickle.load(f)
+        except FileNotFoundError as e:
+            pass
+        except (pickle.UnpicklingError, ValueError) as e:
+            self.logger.warning(
+                "Failed to load item cache entry %r in %r: %s",
+                href, self.path, e, exc_info=True)
         vobject_item = None
         vobject_item = None
         if input_hash != cinput_hash:
         if input_hash != cinput_hash:
             try:
             try:
                 vobject_item = Item(self, href=href,
                 vobject_item = Item(self, href=href,
                                     text=btext.decode(self.encoding)).item
                                     text=btext.decode(self.encoding)).item
-                check_item(vobject_item)
+                check_and_sanitize_item(vobject_item, uid=cuid)
             except Exception as e:
             except Exception as e:
                 raise RuntimeError("Failed to load item %r in %r: %s" %
                 raise RuntimeError("Failed to load item %r in %r: %s" %
                                    (href, self.path, e)) from e
                                    (href, self.path, e)) from e
@@ -1056,6 +1116,7 @@ class Collection(BaseCollection):
             # The storage may have been edited externally.
             # The storage may have been edited externally.
             ctext = vobject_item.serialize()
             ctext = vobject_item.serialize()
             cetag = get_etag(ctext)
             cetag = get_etag(ctext)
+            cuid = get_uid_from_object(vobject_item)
             try:
             try:
                 try:
                 try:
                     ctag, cstart, cend = xmlutils.find_tag_and_time_range(
                     ctag, cstart, cend = xmlutils.find_tag_and_time_range(
@@ -1080,7 +1141,7 @@ class Collection(BaseCollection):
                 # file.
                 # file.
                 with self._atomic_write(os.path.join(cache_folder, href),
                 with self._atomic_write(os.path.join(cache_folder, href),
                                         "wb") as f:
                                         "wb") as f:
-                    pickle.dump((input_hash, cetag, ctext,
+                    pickle.dump((input_hash, cuid, cetag, ctext,
                                  ctag, cstart, cend), f)
                                  ctag, cstart, cend), f)
             except PermissionError:
             except PermissionError:
                 pass
                 pass
@@ -1095,8 +1156,9 @@ class Collection(BaseCollection):
         last_modified = time.strftime(
         last_modified = time.strftime(
             "%a, %d %b %Y %H:%M:%S GMT",
             "%a, %d %b %Y %H:%M:%S GMT",
             time.gmtime(os.path.getmtime(path)))
             time.gmtime(os.path.getmtime(path)))
-        return Item(self, href=href, last_modified=last_modified, etag=cetag,
-                    text=ctext, item=vobject_item), (ctag, cstart, cend)
+        return Item(
+            self, href=href, last_modified=last_modified, etag=cetag,
+            text=ctext, item=vobject_item, uid=cuid), (ctag, cstart, cend)
 
 
     def get_multi2(self, hrefs):
     def get_multi2(self, hrefs):
         # It's faster to check for file name collisions here, because
         # It's faster to check for file name collisions here, because

+ 88 - 0
radicale/tests/test_base.py

@@ -75,6 +75,30 @@ class BaseRequestsMixIn:
         assert "Event" in answer
         assert "Event" in answer
         assert "UID:event" in answer
         assert "UID:event" in answer
 
 
+    def test_add_event_without_uid(self):
+        """Add an event without UID."""
+        status, _, _ = self.request("MKCALENDAR", "/calendar.ics/")
+        assert status == 201
+        event = get_file_content("event1.ics").replace("UID:event1\n", "")
+        assert "\nUID:" not in event
+        path = "/calendar.ics/event.ics"
+        status, headers, answer = self.request("PUT", path, event)
+        assert status == 201
+        status, headers, answer = self.request("GET", path)
+        assert status == 200
+        uids = []
+        for line in answer.split("\r\n"):
+            if line.startswith("UID:"):
+                uids.append(line[len("UID:"):])
+        assert len(uids) == 1 and uids[0]
+        # Overwrite the event with an event without UID and check that the UID
+        # is still the same
+        status, _, _ = self.request("PUT", path, event)
+        assert status == 201
+        status, _, answer = self.request("GET", path)
+        assert status == 200
+        assert "\r\nUID:%s\r\n" % uids[0] in answer
+
     def test_add_todo(self):
     def test_add_todo(self):
         """Add a todo."""
         """Add a todo."""
         status, _, _ = self.request("MKCALENDAR", "/calendar.ics/")
         status, _, _ = self.request("MKCALENDAR", "/calendar.ics/")
@@ -121,6 +145,31 @@ class BaseRequestsMixIn:
         assert status == 200
         assert status == 200
         assert "UID:contact1" in answer
         assert "UID:contact1" in answer
 
 
+    def test_add_contact_without_uid(self):
+        """Add a contact without UID."""
+        status, _, _ = self._create_addressbook("/contacts.vcf/")
+        assert status == 201
+        contact = get_file_content("contact1.vcf").replace("UID:contact1\n",
+                                                           "")
+        assert "\nUID" not in contact
+        path = "/contacts.vcf/contact.vcf"
+        status, _, _ = self.request("PUT", path, contact)
+        assert status == 201
+        status, _, answer = self.request("GET", path)
+        assert status == 200
+        uids = []
+        for line in answer.split("\r\n"):
+            if line.startswith("UID:"):
+                uids.append(line[len("UID:"):])
+        assert len(uids) == 1 and uids[0]
+        # Overwrite the contact with a contact without UID and check that the
+        # UID is still the same
+        status, headers, answer = self.request("PUT", path, contact)
+        assert status == 201
+        status, headers, answer = self.request("GET", path)
+        assert status == 200
+        assert "\r\nUID:%s\r\n" % uids[0] in answer
+
     def test_update(self):
     def test_update(self):
         """Update an event."""
         """Update an event."""
         status, _, _ = self.request("MKCALENDAR", "/calendar.ics/")
         status, _, _ = self.request("MKCALENDAR", "/calendar.ics/")
@@ -176,6 +225,25 @@ class BaseRequestsMixIn:
         assert "\r\nUID:event\r\n" in answer and "\r\nUID:todo\r\n" in answer
         assert "\r\nUID:event\r\n" in answer and "\r\nUID:todo\r\n" in answer
         assert "\r\nUID:event1\r\n" not in answer
         assert "\r\nUID:event1\r\n" not in answer
 
 
+    def test_put_whole_calendar_without_uids(self):
+        """Create a whole calendar without UID."""
+        event = get_file_content("event_multiple.ics")
+        event = event.replace("UID:event\n", "").replace("UID:todo\n", "")
+        assert "\nUID:" not in event
+        status, _, _ = self.request("PUT", "/calendar.ics/", event)
+        assert status == 201
+        status, _, answer = self.request("GET", "/calendar.ics")
+        assert status == 200
+        uids = []
+        for line in answer.split("\r\n"):
+            if line.startswith("UID:"):
+                uids.append(line[len("UID:"):])
+        assert len(uids) == 2
+        for i, uid1 in enumerate(uids):
+            assert uid1
+            for uid2 in uids[i + 1:]:
+                assert uid1 != uid2
+
     def test_put_whole_addressbook(self):
     def test_put_whole_addressbook(self):
         """Create and overwrite a whole addressbook."""
         """Create and overwrite a whole addressbook."""
         contacts = get_file_content("contact_multiple.vcf")
         contacts = get_file_content("contact_multiple.vcf")
@@ -186,6 +254,26 @@ class BaseRequestsMixIn:
         assert ("\r\nUID:contact1\r\n" in answer and
         assert ("\r\nUID:contact1\r\n" in answer and
                 "\r\nUID:contact2\r\n" in answer)
                 "\r\nUID:contact2\r\n" in answer)
 
 
+    def test_put_whole_addressbook_without_uids(self):
+        """Create a whole addressbook without UID."""
+        contacts = get_file_content("contact_multiple.vcf")
+        contacts = contacts.replace("UID:contact1\n", "").replace(
+            "UID:contact2\n", "")
+        assert "\nUID:" not in contacts
+        status, _, _ = self.request("PUT", "/contacts.vcf/", contacts)
+        assert status == 201
+        status, _, answer = self.request("GET", "/contacts.vcf")
+        assert status == 200
+        uids = []
+        for line in answer.split("\r\n"):
+            if line.startswith("UID:"):
+                uids.append(line[len("UID:"):])
+        assert len(uids) == 2
+        for i, uid1 in enumerate(uids):
+            assert uid1
+            for uid2 in uids[i + 1:]:
+                assert uid1 != uid2
+
     def test_delete(self):
     def test_delete(self):
         """Delete an event."""
         """Delete an event."""
         status, _, _ = self.request("MKCALENDAR", "/calendar.ics/")
         status, _, _ = self.request("MKCALENDAR", "/calendar.ics/")