topic.go 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. package server
  2. import (
  3. "context"
  4. "errors"
  5. "log"
  6. "math/rand"
  7. "sync"
  8. "time"
  9. )
  10. // topic represents a channel to which subscribers can subscribe, and publishers
  11. // can publish a message
  12. type topic struct {
  13. id string
  14. subscribers map[int]subscriber
  15. messages int
  16. last time.Time
  17. ctx context.Context
  18. cancel context.CancelFunc
  19. mu sync.Mutex
  20. }
  21. // subscriber is a function that is called for every new message on a topic
  22. type subscriber func(msg *message) error
  23. // newTopic creates a new topic
  24. func newTopic(id string) *topic {
  25. ctx, cancel := context.WithCancel(context.Background())
  26. return &topic{
  27. id: id,
  28. subscribers: make(map[int]subscriber),
  29. last: time.Now(),
  30. ctx: ctx,
  31. cancel: cancel,
  32. }
  33. }
  34. func (t *topic) Subscribe(s subscriber) int {
  35. t.mu.Lock()
  36. defer t.mu.Unlock()
  37. subscriberID := rand.Int()
  38. t.subscribers[subscriberID] = s
  39. t.last = time.Now()
  40. return subscriberID
  41. }
  42. func (t *topic) Unsubscribe(id int) int {
  43. t.mu.Lock()
  44. defer t.mu.Unlock()
  45. delete(t.subscribers, id)
  46. return len(t.subscribers)
  47. }
  48. func (t *topic) Publish(m *message) error {
  49. t.mu.Lock()
  50. defer t.mu.Unlock()
  51. if len(t.subscribers) == 0 {
  52. return errors.New("no subscribers")
  53. }
  54. t.last = time.Now()
  55. t.messages++
  56. for _, s := range t.subscribers {
  57. if err := s(m); err != nil {
  58. log.Printf("error publishing message to subscriber")
  59. }
  60. }
  61. return nil
  62. }
  63. func (t *topic) Stats() (subscribers int, messages int) {
  64. t.mu.Lock()
  65. defer t.mu.Unlock()
  66. return len(t.subscribers), t.messages
  67. }
  68. func (t *topic) Close() {
  69. t.cancel()
  70. }