// main.go (2.6 KB)
  1. package main
  2. import (
  3. "context"
  4. "encoding/json"
  5. "errors"
  6. "io"
  7. "log"
  8. "math/rand"
  9. "net/http"
  10. "sync"
  11. "time"
  12. )
  // Message is the JSON object streamed to every subscriber of a channel for
  // each PUT: when the server handled the publish, plus the raw request body.
  // Field order matters: it is the key order of the encoded JSON.
  type Message struct {
  	Time    int64  `json:"time"`    // Unix time in milliseconds
  	Message string `json:"message"` // PUT body, verbatim
  }
  17. type Channel struct {
  18. id string
  19. listeners map[int]listener
  20. last time.Time
  21. ctx context.Context
  22. mu sync.Mutex
  23. }
  24. type Server struct {
  25. channels map[string]*Channel
  26. mu sync.Mutex
  27. }
  28. type listener func(msg *Message)
  // main serves the pub/sub API on :9997 and logs the number of live channels
  // every five seconds.
  func main() {
  	s := &Server{
  		channels: make(map[string]*Channel),
  	}

  	// Process-lifetime reporter; the ticker is never stopped because the
  	// loop only ends when the process exits.
  	go func() {
  		ticker := time.NewTicker(5 * time.Second)
  		for range ticker.C {
  			s.mu.Lock()
  			n := len(s.channels)
  			s.mu.Unlock()
  			log.Printf("channels: %d", n)
  		}
  	}()

  	http.HandleFunc("/", s.handle)
  	srv := &http.Server{
  		Addr: ":9997",
  		// Bound how long a client may take to send request headers, so idle
  		// half-open connections cannot pile up. No WriteTimeout: GET
  		// responses are long-lived streams.
  		ReadHeaderTimeout: 10 * time.Second,
  	}
  	if err := srv.ListenAndServe(); err != nil {
  		log.Fatalln(err)
  	}
  }
  // handle is the http.HandlerFunc for every path. Any error produced by
  // handleInternal becomes a 500 response whose body is the error text.
  func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
  	err := s.handleInternal(w, r)
  	if err == nil {
  		return
  	}
  	w.WriteHeader(http.StatusInternalServerError)
  	// Best effort: the client may already have disconnected.
  	_, _ = io.WriteString(w, err.Error())
  }
  // handleInternal routes a request for channel "/<id>".
  //
  // GET subscribes to the channel, creating it if needed. PUT publishes to an
  // existing channel. Requests are validated BEFORE any channel is created.
  // Otherwise invalid methods, non-streamable GETs and PUTs to unknown ids
  // would each leave an empty channel in s.channels that nothing ever removes.
  func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
  	if len(r.URL.Path) == 0 {
  		return errors.New("invalid path")
  	}
  	id := r.URL.Path[1:] // drop the leading "/"

  	switch r.Method {
  	case http.MethodGet:
  		if _, ok := w.(http.Flusher); !ok {
  			return errors.New("not a flusher")
  		}
  		return s.handleGET(w, r, s.channel(id))
  	case http.MethodPut:
  		// Look up only; a PUT must never create a channel.
  		s.mu.Lock()
  		ch, ok := s.channels[id]
  		s.mu.Unlock()
  		if !ok {
  			return errors.New("no listeners")
  		}
  		return s.handlePUT(w, r, ch)
  	default:
  		return errors.New("invalid method")
  	}
  }
  66. func (s *Server) handleGET(w http.ResponseWriter, r *http.Request, ch *Channel) error {
  67. fl, ok := w.(http.Flusher)
  68. if !ok {
  69. return errors.New("not a flusher")
  70. }
  71. listenerID := rand.Int()
  72. l := func (msg *Message) {
  73. json.NewEncoder(w).Encode(&msg)
  74. fl.Flush()
  75. }
  76. ch.mu.Lock()
  77. ch.listeners[listenerID] = l
  78. ch.last = time.Now()
  79. ch.mu.Unlock()
  80. select {
  81. case <-ch.ctx.Done():
  82. case <-r.Context().Done():
  83. }
  84. ch.mu.Lock()
  85. delete(ch.listeners, listenerID)
  86. if len(ch.listeners) == 0 {
  87. s.mu.Lock()
  88. delete(s.channels, ch.id)
  89. s.mu.Unlock()
  90. }
  91. ch.mu.Unlock()
  92. return nil
  93. }
  // handlePUT publishes the request body to every current subscriber of ch.
  // All subscribers receive the same Message, stamped with a single
  // publish time. It fails with "no listeners" if ch has no subscribers.
  func (s *Server) handlePUT(w http.ResponseWriter, r *http.Request, ch *Channel) error {
  	defer r.Body.Close()

  	// Read the body BEFORE taking ch.mu. Reading under the lock let one slow
  	// uploader stall every subscribe, unsubscribe and publish on the channel.
  	// NOTE(review): the body size is still unbounded (untrusted input);
  	// consider wrapping r.Body in http.MaxBytesReader.
  	body, err := io.ReadAll(r.Body)
  	if err != nil {
  		return err
  	}

  	ch.mu.Lock()
  	defer ch.mu.Unlock()
  	if len(ch.listeners) == 0 {
  		return errors.New("no listeners")
  	}

  	now := time.Now()
  	ch.last = now
  	msg := &Message{
  		Time:    now.UnixMilli(),
  		Message: string(body),
  	}
  	// Listeners run synchronously under ch.mu, so deliveries to one channel
  	// are serialized. A subscriber with a slow connection therefore delays
  	// the publisher.
  	for _, l := range ch.listeners {
  		l(msg)
  	}
  	return nil
  }
  // channel returns the channel registered under channelID. If none exists
  // yet, it creates an empty one and registers it.
  func (s *Server) channel(channelID string) *Channel {
  	s.mu.Lock()
  	defer s.mu.Unlock()
  	if c, ok := s.channels[channelID]; ok {
  		return c
  	}
  	c := &Channel{
  		id:        channelID,
  		listeners: make(map[int]listener),
  		last:      time.Now(),
  		// Nothing in this file ever cancels a channel, so a never-done
  		// Background context is used. context.WithCancel is deliberately
  		// avoided: there is nowhere to keep its cancel func, and discarding
  		// it is rejected by go vet's lostcancel check.
  		// TODO: store a cancel func on Channel if channels must be closable.
  		ctx: context.Background(),
  	}
  	s.channels[channelID] = c
  	return c
  }