Parcourir la source

Associate messages with a user

binwiederhier il y a 3 ans
Parent
commit
2b78a8cb51

+ 33 - 15
server/message_cache.go

@@ -40,6 +40,7 @@ const (
 			attachment_expires INT NOT NULL,
 			attachment_url TEXT NOT NULL,
 			sender TEXT NOT NULL,
+			user TEXT NOT NULL,		
 			encoding TEXT NOT NULL,
 			published INT NOT NULL
 		);
@@ -49,46 +50,47 @@ const (
 		COMMIT;
 	`
 	insertMessageQuery = `
-		INSERT INTO messages (mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding, published)
-		VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+		INSERT INTO messages (mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding, published)
+		VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
 	`
 	pruneMessagesQuery           = `DELETE FROM messages WHERE time < ? AND published = 1`
 	selectRowIDFromMessageID     = `SELECT id FROM messages WHERE mid = ?` // Do not include topic, see #336 and TestServer_PollSinceID_MultipleTopics
 	selectMessagesSinceTimeQuery = `
-		SELECT mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding
+		SELECT mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding
 		FROM messages 
 		WHERE topic = ? AND time >= ? AND published = 1
 		ORDER BY time, id
 	`
 	selectMessagesSinceTimeIncludeScheduledQuery = `
-		SELECT mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding
+		SELECT mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding
 		FROM messages 
 		WHERE topic = ? AND time >= ?
 		ORDER BY time, id
 	`
 	selectMessagesSinceIDQuery = `
-		SELECT mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding
+		SELECT mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding
 		FROM messages 
 		WHERE topic = ? AND id > ? AND published = 1 
 		ORDER BY time, id
 	`
 	selectMessagesSinceIDIncludeScheduledQuery = `
-		SELECT mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding
+		SELECT mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding
 		FROM messages 
 		WHERE topic = ? AND (id > ? OR published = 0)
 		ORDER BY time, id
 	`
 	selectMessagesDueQuery = `
-		SELECT mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, encoding
+		SELECT mid, time, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding
 		FROM messages 
 		WHERE time <= ? AND published = 0
 		ORDER BY time, id
 	`
-	updateMessagePublishedQuery     = `UPDATE messages SET published = 1 WHERE mid = ?`
-	selectMessagesCountQuery        = `SELECT COUNT(*) FROM messages`
-	selectMessageCountPerTopicQuery = `SELECT topic, COUNT(*) FROM messages GROUP BY topic`
-	selectTopicsQuery               = `SELECT topic FROM messages GROUP BY topic`
-	selectAttachmentsSizeQuery      = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE sender = ? AND attachment_expires >= ?`
+	updateMessagePublishedQuery        = `UPDATE messages SET published = 1 WHERE mid = ?`
+	selectMessagesCountQuery           = `SELECT COUNT(*) FROM messages`
+	selectMessageCountPerTopicQuery    = `SELECT topic, COUNT(*) FROM messages GROUP BY topic`
+	selectTopicsQuery                  = `SELECT topic FROM messages GROUP BY topic`
+	selectAttachmentsSizeBySenderQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE sender = ? AND attachment_expires >= ?`
+	selectAttachmentsSizeByUserQuery   = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?`
 )
 
 // Schema management queries
@@ -316,6 +318,7 @@ func (c *messageCache) addMessages(ms []*message) error {
 			attachmentExpires,
 			attachmentURL,
 			sender,
+			m.User,
 			m.Encoding,
 			published,
 		)
@@ -442,11 +445,23 @@ func (c *messageCache) Prune(olderThan time.Time) error {
 	return nil
 }
 
-func (c *messageCache) AttachmentBytesUsed(sender string) (int64, error) {
-	rows, err := c.db.Query(selectAttachmentsSizeQuery, sender, time.Now().Unix())
+func (c *messageCache) AttachmentBytesUsedBySender(sender string) (int64, error) {
+	rows, err := c.db.Query(selectAttachmentsSizeBySenderQuery, sender, time.Now().Unix())
 	if err != nil {
 		return 0, err
 	}
+	return c.readAttachmentBytesUsed(rows)
+}
+
+func (c *messageCache) AttachmentBytesUsedByUser(user string) (int64, error) {
+	rows, err := c.db.Query(selectAttachmentsSizeByUserQuery, user, time.Now().Unix())
+	if err != nil {
+		return 0, err
+	}
+	return c.readAttachmentBytesUsed(rows)
+}
+
+func (c *messageCache) readAttachmentBytesUsed(rows *sql.Rows) (int64, error) {
 	defer rows.Close()
 	var size int64
 	if !rows.Next() {
@@ -477,7 +492,7 @@ func readMessages(rows *sql.Rows) ([]*message, error) {
 	for rows.Next() {
 		var timestamp, attachmentSize, attachmentExpires int64
 		var priority int
-		var id, topic, msg, title, tagsStr, click, icon, actionsStr, attachmentName, attachmentType, attachmentURL, sender, encoding string
+		var id, topic, msg, title, tagsStr, click, icon, actionsStr, attachmentName, attachmentType, attachmentURL, sender, user, encoding string
 		err := rows.Scan(
 			&id,
 			&timestamp,
@@ -495,6 +510,7 @@ func readMessages(rows *sql.Rows) ([]*message, error) {
 			&attachmentExpires,
 			&attachmentURL,
 			&sender,
+			&user,
 			&encoding,
 		)
 		if err != nil {
@@ -538,6 +554,7 @@ func readMessages(rows *sql.Rows) ([]*message, error) {
 			Actions:    actions,
 			Attachment: att,
 			Sender:     senderIP, // Must parse assuming database must be correct
+			User:       user,
 			Encoding:   encoding,
 		})
 	}
@@ -598,6 +615,7 @@ func setupCacheDB(db *sql.DB, startupQueries string) error {
 	} else if schemaVersion == 8 {
 		return migrateFrom8(db)
 	}
+	// TODO add user column
 	return fmt.Errorf("unexpected schema version found: %d", schemaVersion)
 }
 

+ 2 - 2
server/message_cache_test.go

@@ -343,11 +343,11 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
 	require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL)
 	require.Equal(t, "1.2.3.4", messages[1].Sender.String())
 
-	size, err := c.AttachmentBytesUsed("1.2.3.4")
+	size, err := c.AttachmentBytesUsedBySender("1.2.3.4")
 	require.Nil(t, err)
 	require.Equal(t, int64(30000), size)
 
-	size, err = c.AttachmentBytesUsed("5.6.7.8")
+	size, err = c.AttachmentBytesUsedBySender("5.6.7.8")
 	require.Nil(t, err)
 	require.Equal(t, int64(0), size)
 }

+ 6 - 2
server/server.go

@@ -495,6 +495,10 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
 	if m.PollID != "" {
 		m = newPollRequestMessage(t.ID, m.PollID)
 	}
+	if v.user != nil {
+		log.Info("user is %s", v.user.Name)
+		m.User = v.user.Name
+	}
 	if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil {
 		return nil, err
 	}
@@ -502,8 +506,8 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
 		m.Message = emptyMessageBody
 	}
 	delayed := m.Time > time.Now().Unix()
-	log.Debug("%s Received message: event=%s, body=%d byte(s), delayed=%t, firebase=%t, cache=%t, up=%t, email=%s",
-		logMessagePrefix(v, m), m.Event, len(m.Message), delayed, firebase, cache, unifiedpush, email)
+	log.Debug("%s Received message: event=%s, user=%s, body=%d byte(s), delayed=%t, firebase=%t, cache=%t, up=%t, email=%s",
+		logMessagePrefix(v, m), m.Event, m.User, len(m.Message), delayed, firebase, cache, unifiedpush, email)
 	if log.IsTrace() {
 		log.Trace("%s Message body: %s", logMessagePrefix(v, m), util.MaybeMarshalJSON(m))
 	}

+ 9 - 10
server/server_account.go

@@ -75,19 +75,18 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, r *http.Request, v *vis
 				Code:       v.user.Plan.Code,
 				Upgradable: v.user.Plan.Upgradable,
 			}
+		} else if v.user.Role == auth.RoleAdmin {
+			response.Plan = &apiAccountPlan{
+				Code:       string(auth.PlanUnlimited),
+				Upgradable: false,
+			}
 		} else {
-			if v.user.Role == auth.RoleAdmin {
-				response.Plan = &apiAccountPlan{
-					Code:       string(auth.PlanUnlimited),
-					Upgradable: false,
-				}
-			} else {
-				response.Plan = &apiAccountPlan{
-					Code:       string(auth.PlanDefault),
-					Upgradable: true,
-				}
+			response.Plan = &apiAccountPlan{
+				Code:       string(auth.PlanDefault),
+				Upgradable: true,
 			}
 		}
+
 	} else {
 		response.Username = auth.Everyone
 		response.Role = string(auth.RoleAnonymous)

+ 3 - 3
server/server_test.go

@@ -1151,7 +1151,7 @@ func TestServer_PublishAttachment(t *testing.T) {
 	require.Equal(t, "", response.Body.String())
 
 	// Slightly unrelated cross-test: make sure we add an owner for internal attachments
-	size, err := s.messageCache.AttachmentBytesUsed("9.9.9.9") // See request()
+	size, err := s.messageCache.AttachmentBytesUsedBySender("9.9.9.9") // See request()
 	require.Nil(t, err)
 	require.Equal(t, int64(5000), size)
 }
@@ -1180,7 +1180,7 @@ func TestServer_PublishAttachmentShortWithFilename(t *testing.T) {
 	require.Equal(t, content, response.Body.String())
 
 	// Slightly unrelated cross-test: make sure we add an owner for internal attachments
-	size, err := s.messageCache.AttachmentBytesUsed("1.2.3.4")
+	size, err := s.messageCache.AttachmentBytesUsedBySender("1.2.3.4")
 	require.Nil(t, err)
 	require.Equal(t, int64(21), size)
 }
@@ -1200,7 +1200,7 @@ func TestServer_PublishAttachmentExternalWithoutFilename(t *testing.T) {
 	require.Equal(t, netip.Addr{}, msg.Sender)
 
 	// Slightly unrelated cross-test: make sure we don't add an owner for external attachments
-	size, err := s.messageCache.AttachmentBytesUsed("127.0.0.1")
+	size, err := s.messageCache.AttachmentBytesUsedBySender("127.0.0.1")
 	require.Nil(t, err)
 	require.Equal(t, int64(0), size)
 }

+ 2 - 1
server/types.go

@@ -36,8 +36,9 @@ type message struct {
 	Actions    []*action   `json:"actions,omitempty"`
 	Attachment *attachment `json:"attachment,omitempty"`
 	PollID     string      `json:"poll_id,omitempty"`
-	Sender     netip.Addr  `json:"-"`                  // IP address of uploader, used for rate limiting
 	Encoding   string      `json:"encoding,omitempty"` // empty for raw UTF-8, or "base64" for encoded bytes
+	Sender     netip.Addr  `json:"-"`                  // IP address of uploader, used for rate limiting
+	User       string      `json:"-"`                  // Username of the uploader, used to associated attachments
 }
 
 type attachment struct {

+ 16 - 8
server/visitor.go

@@ -151,12 +151,10 @@ func (v *visitor) IncrEmails() {
 }
 
 func (v *visitor) Stats() (*visitorStats, error) {
-	attachmentsBytesUsed, err := v.messageCache.AttachmentBytesUsed(v.ip.String())
-	if err != nil {
-		return nil, err
-	}
 	v.mu.Lock()
-	defer v.mu.Unlock()
+	messages := v.messages
+	emails := v.emails
+	v.mu.Unlock()
 	stats := &visitorStats{}
 	if v.user != nil && v.user.Role == auth.RoleAdmin {
 		stats.Basis = "role"
@@ -174,12 +172,22 @@ func (v *visitor) Stats() (*visitorStats, error) {
 		stats.Basis = "ip"
 		stats.MessagesLimit = replenishDurationToDailyLimit(v.config.VisitorRequestLimitReplenish)
 		stats.EmailsLimit = replenishDurationToDailyLimit(v.config.VisitorEmailLimitReplenish)
-		stats.AttachmentTotalSizeLimit = v.config.AttachmentTotalSizeLimit
+		stats.AttachmentTotalSizeLimit = v.config.VisitorAttachmentTotalSizeLimit
 		stats.AttachmentFileSizeLimit = v.config.AttachmentFileSizeLimit
 	}
-	stats.Messages = v.messages
+	var attachmentsBytesUsed int64
+	var err error
+	if v.user != nil {
+		attachmentsBytesUsed, err = v.messageCache.AttachmentBytesUsedByUser(v.user.Name)
+	} else {
+		attachmentsBytesUsed, err = v.messageCache.AttachmentBytesUsedBySender(v.ip.String())
+	}
+	if err != nil {
+		return nil, err
+	}
+	stats.Messages = messages
 	stats.MessagesRemaining = zeroIfNegative(stats.MessagesLimit - stats.MessagesLimit)
-	stats.Emails = v.emails
+	stats.Emails = emails
 	stats.EmailsRemaining = zeroIfNegative(stats.EmailsLimit - stats.EmailsRemaining)
 	stats.AttachmentTotalSize = attachmentsBytesUsed
 	stats.AttachmentTotalSizeRemaining = zeroIfNegative(stats.AttachmentTotalSizeLimit - stats.AttachmentTotalSize)