topic.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. package server
  2. import (
  3. "heckel.io/ntfy/log"
  4. "math/rand"
  5. "sync"
  6. "time"
  7. )
  // topicExpiryDuration is how long a topic may go without subscribers and without
  // being accessed before Stale reports it as stale.
  const topicExpiryDuration = 6 * time.Hour
  11. // topic represents a channel to which subscribers can subscribe, and publishers
  12. // can publish a message
  13. type topic struct {
  14. ID string
  15. subscribers map[int]*topicSubscriber
  16. rateVisitor *visitor
  17. lastAccess time.Time
  18. mu sync.RWMutex
  19. }
  20. type topicSubscriber struct {
  21. userID string // User ID associated with this subscription, may be empty
  22. subscriber subscriber
  23. cancel func()
  24. }
  25. // subscriber is a function that is called for every new message on a topic
  26. type subscriber func(v *visitor, msg *message) error
  27. // newTopic creates a new topic
  28. func newTopic(id string) *topic {
  29. return &topic{
  30. ID: id,
  31. subscribers: make(map[int]*topicSubscriber),
  32. lastAccess: time.Now(),
  33. }
  34. }
  35. // Subscribe subscribes to this topic
  36. func (t *topic) Subscribe(s subscriber, userID string, cancel func()) int {
  37. t.mu.Lock()
  38. defer t.mu.Unlock()
  39. subscriberID := rand.Int()
  40. t.subscribers[subscriberID] = &topicSubscriber{
  41. userID: userID, // May be empty
  42. subscriber: s,
  43. cancel: cancel,
  44. }
  45. t.lastAccess = time.Now()
  46. return subscriberID
  47. }
  48. func (t *topic) Stale() bool {
  49. t.mu.Lock()
  50. defer t.mu.Unlock()
  51. if t.rateVisitor != nil && !t.rateVisitor.Stale() {
  52. return false
  53. }
  54. return len(t.subscribers) == 0 && time.Since(t.lastAccess) > topicExpiryDuration
  55. }
  56. func (t *topic) SetRateVisitor(v *visitor) {
  57. t.mu.Lock()
  58. defer t.mu.Unlock()
  59. t.rateVisitor = v
  60. t.lastAccess = time.Now()
  61. }
  62. func (t *topic) RateVisitor() *visitor {
  63. t.mu.Lock()
  64. defer t.mu.Unlock()
  65. if t.rateVisitor != nil && t.rateVisitor.Stale() {
  66. t.rateVisitor = nil
  67. }
  68. return t.rateVisitor
  69. }
  70. // Unsubscribe removes the subscription from the list of subscribers
  71. func (t *topic) Unsubscribe(id int) {
  72. t.mu.Lock()
  73. defer t.mu.Unlock()
  74. delete(t.subscribers, id)
  75. }
  76. // Publish asynchronously publishes to all subscribers
  77. func (t *topic) Publish(v *visitor, m *message) error {
  78. go func() {
  79. // We want to lock the topic as short as possible, so we make a shallow copy of the
  80. // subscribers map here. Actually sending out the messages then doesn't have to lock.
  81. subscribers := t.subscribersCopy()
  82. if len(subscribers) > 0 {
  83. logvm(v, m).Tag(tagPublish).Debug("Forwarding to %d subscriber(s)", len(subscribers))
  84. for _, s := range subscribers {
  85. // We call the subscriber functions in their own Go routines because they are blocking, and
  86. // we don't want individual slow subscribers to be able to block others.
  87. go func(s subscriber) {
  88. if err := s(v, m); err != nil {
  89. logvm(v, m).Tag(tagPublish).Err(err).Warn("Error forwarding to subscriber")
  90. }
  91. }(s.subscriber)
  92. }
  93. } else {
  94. logvm(v, m).Tag(tagPublish).Trace("No stream or WebSocket subscribers, not forwarding")
  95. }
  96. t.Keepalive()
  97. }()
  98. return nil
  99. }
  100. // Stats returns the number of subscribers and last access to this topic
  101. func (t *topic) Stats() (int, time.Time) {
  102. t.mu.RLock()
  103. defer t.mu.RUnlock()
  104. return len(t.subscribers), t.lastAccess
  105. }
  106. // Keepalive sets the last access time and ensures that Stale does not return true
  107. func (t *topic) Keepalive() {
  108. t.mu.Lock()
  109. defer t.mu.Unlock()
  110. t.lastAccess = time.Now()
  111. }
  112. // CancelSubscribers calls the cancel function for all subscribers, forcing
  113. func (t *topic) CancelSubscribers(exceptUserID string) {
  114. t.mu.Lock()
  115. defer t.mu.Unlock()
  116. for _, s := range t.subscribers {
  117. if s.userID != exceptUserID {
  118. log.
  119. Tag(tagSubscribe).
  120. With(t).
  121. Fields(log.Context{
  122. "user_id": s.userID,
  123. }).
  124. Debug("Canceling subscriber %s", s.userID)
  125. s.cancel()
  126. }
  127. }
  128. }
  129. func (t *topic) Context() log.Context {
  130. t.mu.RLock()
  131. defer t.mu.RUnlock()
  132. fields := map[string]any{
  133. "topic": t.ID,
  134. "topic_subscribers": len(t.subscribers),
  135. }
  136. if t.rateVisitor != nil {
  137. for k, v := range t.rateVisitor.Context() {
  138. fields["topic_rate_"+k] = v
  139. }
  140. }
  141. return fields
  142. }
  143. // subscribersCopy returns a shallow copy of the subscribers map
  144. func (t *topic) subscribersCopy() map[int]*topicSubscriber {
  145. t.mu.Lock()
  146. defer t.mu.Unlock()
  147. subscribers := make(map[int]*topicSubscriber)
  148. for k, sub := range t.subscribers {
  149. subscribers[k] = &topicSubscriber{
  150. userID: sub.userID,
  151. subscriber: sub.subscriber,
  152. cancel: sub.cancel,
  153. }
  154. }
  155. return subscribers
  156. }