Преглед изворни кода

WIP: persist message stats

binwiederhier пре 2 година
родитељ
комит
6be95f8285
4 измењених фајлова са 130 додато и 31 уклоњено
  1. 66 11
      server/message_cache.go
  2. 44 18
      server/server.go
  3. 15 2
      server/server_manager.go
  4. 5 0
      server/types.go

+ 66 - 11
server/message_cache.go

@@ -17,6 +17,7 @@ import (
 var (
 	errUnexpectedMessageType = errors.New("unexpected message type")
 	errMessageNotFound       = errors.New("message not found")
+	errNoRows                = errors.New("no rows found")
 )
 
 // Messages cache
@@ -54,6 +55,11 @@ const (
 		CREATE INDEX IF NOT EXISTS idx_sender ON messages (sender);
 		CREATE INDEX IF NOT EXISTS idx_user ON messages (user);
 		CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires);
+		CREATE TABLE IF NOT EXISTS stats (
+			key TEXT PRIMARY KEY,
+			value INT
+		);
+		INSERT INTO stats (key, value) VALUES ('messages', 0);
 		COMMIT;
 	`
 	insertMessageQuery = `
@@ -108,11 +114,14 @@ const (
 	selectAttachmentsExpiredQuery      = `SELECT mid FROM messages WHERE attachment_expires > 0 AND attachment_expires <= ? AND attachment_deleted = 0`
 	selectAttachmentsSizeBySenderQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = '' AND sender = ? AND attachment_expires >= ?`
 	selectAttachmentsSizeByUserIDQuery = `SELECT IFNULL(SUM(attachment_size), 0) FROM messages WHERE user = ? AND attachment_expires >= ?`
+
+	selectStatsQuery = `SELECT value FROM stats WHERE key = 'messages'`
+	updateStatsQuery = `UPDATE stats SET value = ? WHERE key = 'messages'`
 )
 
 // Schema management queries
 const (
-	currentSchemaVersion          = 10
+	currentSchemaVersion          = 11
 	createSchemaVersionTableQuery = `
 		CREATE TABLE IF NOT EXISTS schemaVersion (
 			id INT PRIMARY KEY,
@@ -222,20 +231,30 @@ const (
 		CREATE INDEX IF NOT EXISTS idx_attachment_expires ON messages (attachment_expires);
 	`
 	migrate9To10UpdateMessageExpiryQuery = `UPDATE messages SET expires = time + ?`
+
+	// 10 -> 11
+	migrate10To11AlterMessagesTableQuery = `
+		CREATE TABLE IF NOT EXISTS stats (
+			key TEXT PRIMARY KEY,
+			value INT
+		);
+		INSERT INTO stats (key, value) VALUES ('messages', 0);
+	`
 )
 
 var (
 	migrations = map[int]func(db *sql.DB, cacheDuration time.Duration) error{
-		0: migrateFrom0,
-		1: migrateFrom1,
-		2: migrateFrom2,
-		3: migrateFrom3,
-		4: migrateFrom4,
-		5: migrateFrom5,
-		6: migrateFrom6,
-		7: migrateFrom7,
-		8: migrateFrom8,
-		9: migrateFrom9,
+		0:  migrateFrom0,
+		1:  migrateFrom1,
+		2:  migrateFrom2,
+		3:  migrateFrom3,
+		4:  migrateFrom4,
+		5:  migrateFrom5,
+		6:  migrateFrom6,
+		7:  migrateFrom7,
+		8:  migrateFrom8,
+		9:  migrateFrom9,
+		10: migrateFrom10,
 	}
 )
 
@@ -706,6 +725,26 @@ func readMessage(rows *sql.Rows) (*message, error) {
 	}, nil
 }
 
+func (c *messageCache) UpdateStats(messages int64) error {
+	_, err := c.db.Exec(updateStatsQuery, messages)
+	return err
+}
+
+func (c *messageCache) Stats() (messages int64, err error) {
+	rows, err := c.db.Query(selectStatsQuery)
+	if err != nil {
+		return 0, err
+	}
+	defer rows.Close()
+	if !rows.Next() {
+		return 0, errNoRows
+	}
+	if err := rows.Scan(&messages); err != nil {
+		return 0, err
+	}
+	return messages, nil
+}
+
 func (c *messageCache) Close() error {
 	return c.db.Close()
 }
@@ -889,3 +928,19 @@ func migrateFrom9(db *sql.DB, cacheDuration time.Duration) error {
 	}
 	return tx.Commit()
 }
+
+func migrateFrom10(db *sql.DB, cacheDuration time.Duration) error {
+	log.Tag(tagMessageCache).Info("Migrating cache database schema: from 10 to 11")
+	tx, err := db.Begin()
+	if err != nil {
+		return err
+	}
+	defer tx.Rollback()
+	if _, err := tx.Exec(migrate10To11AlterMessagesTableQuery); err != nil {
+		return err
+	}
+	if _, err := tx.Exec(updateSchemaVersion, 11); err != nil {
+		return err
+	}
+	return tx.Commit()
+}

+ 44 - 18
server/server.go

@@ -48,7 +48,8 @@ type Server struct {
 	topics            map[string]*topic
 	visitors          map[string]*visitor // ip:<ip> or user:<user>
 	firebaseClient    *firebaseClient
-	messages          int64
+	messages          int64                               // Total number of messages (persisted if messageCache enabled)
+	messagesHistory   []int64                             // Last n values of the messages counter, used to determine rate
 	userManager       *user.Manager                       // Might be nil!
 	messageCache      *messageCache                       // Database that stores the messages
 	fileCache         *fileCache                          // File system based cache that stores attachments
@@ -56,7 +57,7 @@ type Server struct {
 	priceCache        *util.LookupCache[map[string]int64] // Stripe price ID -> price as cents (USD implied!)
 	metricsHandler    http.Handler                        // Handles /metrics if enable-metrics set, and listen-metrics-http not set
 	closeChan         chan bool
-	mu                sync.Mutex
+	mu                sync.RWMutex
 }
 
 // handleFunc extends the normal http.HandlerFunc to be able to easily return errors
@@ -79,7 +80,8 @@ var (
 	matrixPushPath                                       = "/_matrix/push/v1/notify"
 	metricsPath                                          = "/metrics"
 	apiHealthPath                                        = "/v1/health"
-	apiTiers                                             = "/v1/tiers"
+	apiStatsPath                                         = "/v1/stats"
+	apiTiersPath                                         = "/v1/tiers"
 	apiAccountPath                                       = "/v1/account"
 	apiAccountTokenPath                                  = "/v1/account/token"
 	apiAccountPasswordPath                               = "/v1/account/password"
@@ -116,9 +118,10 @@ const (
 	newMessageBody           = "New message"             // Used in poll requests as generic message
 	defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment
 	encodingBase64           = "base64"                  // Used mainly for binary UnifiedPush messages
-	jsonBodyBytesLimit       = 16384
-	unifiedPushTopicPrefix   = "up" // Temporarily, we rate limit all "up*" topics based on the subscriber
-	unifiedPushTopicLength   = 14
+	jsonBodyBytesLimit       = 16384                     // Max number of bytes for a JSON request body
+	unifiedPushTopicPrefix   = "up"                      // Temporarily, we rate limit all "up*" topics based on the subscriber
+	unifiedPushTopicLength   = 14                        // Length of UnifiedPush topics, including the "up" part
+	messagesHistoryMax       = 10                        // Number of message count values to keep in memory
 )
 
 // WebSocket constants
@@ -148,6 +151,10 @@ func New(conf *Config) (*Server, error) {
 	if err != nil {
 		return nil, err
 	}
+	messages, err := messageCache.Stats()
+	if err != nil {
+		return nil, err
+	}
 	var fileCache *fileCache
 	if conf.AttachmentCacheDir != "" {
 		fileCache, err = newFileCache(conf.AttachmentCacheDir, conf.AttachmentTotalSizeLimit)
@@ -177,15 +184,17 @@ func New(conf *Config) (*Server, error) {
 		firebaseClient = newFirebaseClient(sender, auther)
 	}
 	s := &Server{
-		config:         conf,
-		messageCache:   messageCache,
-		fileCache:      fileCache,
-		firebaseClient: firebaseClient,
-		smtpSender:     mailer,
-		topics:         topics,
-		userManager:    userManager,
-		visitors:       make(map[string]*visitor),
-		stripe:         stripe,
+		config:          conf,
+		messageCache:    messageCache,
+		fileCache:       fileCache,
+		firebaseClient:  firebaseClient,
+		smtpSender:      mailer,
+		topics:          topics,
+		userManager:     userManager,
+		messages:        messages,
+		messagesHistory: []int64{messages},
+		visitors:        make(map[string]*visitor),
+		stripe:          stripe,
 	}
 	s.priceCache = util.NewLookupCache(s.fetchStripePrices, conf.StripePriceCacheDuration)
 	return s, nil
@@ -441,7 +450,9 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
 		return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingPortalSessionCreate))(w, r, v)
 	} else if r.Method == http.MethodPost && r.URL.Path == apiAccountBillingWebhookPath {
 		return s.ensurePaymentsEnabled(s.ensureUserManager(s.handleAccountBillingWebhook))(w, r, v) // This request comes from Stripe!
-	} else if r.Method == http.MethodGet && r.URL.Path == apiTiers {
+	} else if r.Method == http.MethodGet && r.URL.Path == apiStatsPath {
+		return s.handleStats(w, r, v)
+	} else if r.Method == http.MethodGet && r.URL.Path == apiTiersPath {
 		return s.ensurePaymentsEnabled(s.handleBillingTiersGet)(w, r, v)
 	} else if r.Method == http.MethodGet && r.URL.Path == matrixPushPath {
 		return s.handleMatrixDiscovery(w)
@@ -546,17 +557,32 @@ func (s *Server) handleMetrics(w http.ResponseWriter, r *http.Request, _ *visito
 	return nil
 }
 
+// handleStatic returns all static resources (excluding the docs), including the web app
 func (s *Server) handleStatic(w http.ResponseWriter, r *http.Request, _ *visitor) error {
 	r.URL.Path = webSiteDir + r.URL.Path
 	util.Gzip(http.FileServer(http.FS(webFsCached))).ServeHTTP(w, r)
 	return nil
 }
 
+// handleDocs returns static resources related to the docs
 func (s *Server) handleDocs(w http.ResponseWriter, r *http.Request, _ *visitor) error {
 	util.Gzip(http.FileServer(http.FS(docsStaticCached))).ServeHTTP(w, r)
 	return nil
 }
 
+// handleStats returns the publicly available server stats
+func (s *Server) handleStats(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
+	s.mu.RLock()
+	n := len(s.messagesHistory)
+	rate := float64(s.messagesHistory[n-1]-s.messagesHistory[0]) / (float64(n-1) * s.config.ManagerInterval.Seconds())
+	response := &apiStatsResponse{
+		Messages:     s.messages,
+		MessagesRate: rate,
+	}
+	s.mu.RUnlock()
+	return s.writeJSON(w, response)
+}
+
 // handleFile processes the download of attachment files. The method handles GET and HEAD requests against a file.
 // Before streaming the file to a client, it locates uploader (m.Sender or m.User) in the message cache, so it
 // can associate the download bandwidth with the uploader.
@@ -1580,9 +1606,9 @@ func (s *Server) sendDelayedMessages() error {
 
 func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
 	logvm(v, m).Debug("Sending delayed message")
-	s.mu.Lock()
+	s.mu.RLock()
 	t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published
-	s.mu.Unlock()
+	s.mu.RUnlock()
 	if ok {
 		go func() {
 			// We do not rate-limit messages here, since we've rate limited them in the PUT/POST handler

+ 15 - 2
server/server_manager.go

@@ -73,9 +73,9 @@ func (s *Server) execManager() {
 	}
 
 	// Print stats
-	s.mu.Lock()
+	s.mu.RLock()
 	messagesCount, topicsCount, visitorsCount := s.messages, len(s.topics), len(s.visitors)
-	s.mu.Unlock()
+	s.mu.RUnlock()
 	log.
 		Tag(tagManager).
 		Fields(log.Context{
@@ -98,6 +98,19 @@ func (s *Server) execManager() {
 	mset(metricUsers, usersCount)
 	mset(metricSubscribers, subscribers)
 	mset(metricTopics, topicsCount)
+
+	// Write stats
+	s.mu.Lock()
+	s.messagesHistory = append(s.messagesHistory, messagesCount)
+	if len(s.messagesHistory) > messagesHistoryMax {
+		s.messagesHistory = s.messagesHistory[1:]
+	}
+	s.mu.Unlock()
+	go func() {
+		if err := s.messageCache.UpdateStats(messagesCount); err != nil {
+			log.Tag(tagManager).Err(err).Warn("Cannot write messages stats")
+		}
+	}()
 }
 
 func (s *Server) pruneVisitors() {

+ 5 - 0
server/types.go

@@ -239,6 +239,11 @@ type apiHealthResponse struct {
 	Healthy bool `json:"healthy"`
 }
 
+type apiStatsResponse struct {
+	Messages     int64   `json:"messages"`
+	MessagesRate float64 `json:"messages_rate"` // Average number of messages per second
+}
+
 type apiAccountCreateRequest struct {
 	Username string `json:"username"`
 	Password string `json:"password"`