server.go 59 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694
  1. package server
  2. import (
  3. "bytes"
  4. "context"
  5. "crypto/sha256"
  6. "embed"
  7. "encoding/base64"
  8. "encoding/json"
  9. "errors"
  10. "fmt"
  11. "heckel.io/ntfy/user"
  12. "io"
  13. "net"
  14. "net/http"
  15. "net/netip"
  16. "net/url"
  17. "os"
  18. "path"
  19. "path/filepath"
  20. "regexp"
  21. "sort"
  22. "strconv"
  23. "strings"
  24. "sync"
  25. "time"
  26. "unicode/utf8"
  27. "heckel.io/ntfy/log"
  28. "github.com/emersion/go-smtp"
  29. "github.com/gorilla/websocket"
  30. "golang.org/x/sync/errgroup"
  31. "heckel.io/ntfy/util"
  32. )
  33. /*
  34. TODO
  35. --
  36. - Rate limiting: Sensitive endpoints (account/login/change-password/...)
  37. - Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
  38. - Reservation (UI): Ask for confirmation when removing reservation (deadcade)
  39. - Reservation icons (UI)
  40. - reservation table delete button: dialog "keep or delete messages?"
  41. - UI: Flickering upgrade banner when logging in
  42. - JS constants
  43. races:
  44. - v.user --> see publishSyncEventAsync() test
  45. payments:
  46. - reconciliation
  47. delete messages + reserved topics on ResetTier delete attachments in access.go
  48. Limits & rate limiting:
  49. rate limiting: investigate unexpected rate-limiting behavior (root cause not yet understood)
  50. bandwidth limit must be in tier
  51. users without tier: should the stats be persisted? are they meaningful? -> test that the visitor is based on the IP address!
  52. when ResetStats() is run, reset messagesLimiter (and others)?
  53. Delete visitor when tier is changed to refresh rate limiters
  54. Make sure account endpoints make sense for admins
  55. Sync:
  56. - sync problems with "deleteAfter=0" and "displayName="
  57. Tests:
  58. - Payment endpoints (make mocks)
  59. - Message rate limiting and reset tests
  60. - Bandwidth limit test
  61. - test that the visitor is based on the IP address when a user has no tier
  62. */
  // Server is the main server, providing the UI and API for ntfy
  type Server struct {
  	config            *Config
  	httpServer        *http.Server                         // Created in Run when ListenHTTP is set; closed by Stop
  	httpsServer       *http.Server                         // Created in Run when ListenHTTPS is set; closed by Stop
  	unixListener      net.Listener                         // Created by Run's Unix-socket goroutine when ListenUnix is set; closed by Stop
  	smtpServer        *smtp.Server                         // Closed by Stop when non-nil
  	smtpServerBackend *smtpBackend
  	smtpSender        mailer                               // Nil unless SMTPSenderAddr is configured (see New)
  	topics            map[string]*topic                    // Initially loaded from the message cache in New
  	visitors          map[string]*visitor                  // ip:<ip> or user:<user>
  	firebaseClient    *firebaseClient                      // Nil unless FirebaseKeyFile is configured (see New)
  	messages          int64                                // Count of published messages; incremented under mu when publishing
  	userManager       *user.Manager                        // Might be nil!
  	messageCache      *messageCache                        // Database that stores the messages
  	fileCache         *fileCache                           // File system based cache that stores attachments; nil unless AttachmentCacheDir is set
  	stripe            stripeAPI                            // Stripe API, can be replaced with a mock; nil unless StripeSecretKey is set
  	priceCache        *util.LookupCache[map[string]string] // Stripe price ID -> formatted price
  	closeChan         chan bool                            // Created in Run, closed in Stop
  	mu                sync.Mutex                           // Guards the listeners, closeChan and messages (and presumably topics/visitors — TODO confirm)
  }
  // handleFunc extends the normal http.HandlerFunc to be able to easily return errors.
  // It additionally receives the visitor resolved for the request. A returned error is
  // propagated by handleInternal back to handle, which logs it and renders it to the client.
  type handleFunc func(http.ResponseWriter, *http.Request, *visitor) error
  86. var (
  87. // If changed, don't forget to update Android App and auth_sqlite.go
  88. topicRegex = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`) // No /!
  89. topicPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}$`) // Regex must match JS & Android app!
  90. externalTopicPathRegex = regexp.MustCompile(`^/[^/]+\.[^/]+/[-_A-Za-z0-9]{1,64}$`) // Extended topic path, for web-app, e.g. /example.com/mytopic
  91. jsonPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/json$`)
  92. ssePathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/sse$`)
  93. rawPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/raw$`)
  94. wsPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/ws$`)
  95. authPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/auth$`)
  96. publishPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}/(publish|send|trigger)$`)
  97. webConfigPath = "/config.js"
  98. accountPath = "/account"
  99. matrixPushPath = "/_matrix/push/v1/notify"
  100. apiHealthPath = "/v1/health"
  101. apiTiers = "/v1/tiers"
  102. apiAccountPath = "/v1/account"
  103. apiAccountTokenPath = "/v1/account/token"
  104. apiAccountPasswordPath = "/v1/account/password"
  105. apiAccountSettingsPath = "/v1/account/settings"
  106. apiAccountSubscriptionPath = "/v1/account/subscription"
  107. apiAccountReservationPath = "/v1/account/reservation"
  108. apiAccountBillingPortalPath = "/v1/account/billing/portal"
  109. apiAccountBillingWebhookPath = "/v1/account/billing/webhook"
  110. apiAccountBillingSubscriptionPath = "/v1/account/billing/subscription"
  111. apiAccountBillingSubscriptionCheckoutSuccessTemplate = "/v1/account/billing/subscription/success/{CHECKOUT_SESSION_ID}"
  112. apiAccountBillingSubscriptionCheckoutSuccessRegex = regexp.MustCompile(`/v1/account/billing/subscription/success/(.+)$`)
  113. apiAccountReservationSingleRegex = regexp.MustCompile(`/v1/account/reservation/([-_A-Za-z0-9]{1,64})$`)
  114. apiAccountSubscriptionSingleRegex = regexp.MustCompile(`^/v1/account/subscription/([-_A-Za-z0-9]{16})$`)
  115. staticRegex = regexp.MustCompile(`^/static/.+`)
  116. docsRegex = regexp.MustCompile(`^/docs(|/.*)$`)
  117. fileRegex = regexp.MustCompile(`^/file/([-_A-Za-z0-9]{1,64})(?:\.[A-Za-z0-9]{1,16})?$`)
  118. disallowedTopics = []string{"docs", "static", "file", "app", "account", "settings", "pricing", "signup", "login", "reset-password"} // If updated, also update in Android and web app
  119. urlRegex = regexp.MustCompile(`^https?://`)
  120. //go:embed site
  121. webFs embed.FS
  122. webFsCached = &util.CachingEmbedFS{ModTime: time.Now(), FS: webFs}
  123. webSiteDir = "/site"
  124. webHomeIndex = "/home.html" // Landing page, only if "web-root: home"
  125. webAppIndex = "/app.html" // React app
  126. //go:embed docs
  127. docsStaticFs embed.FS
  128. docsStaticCached = &util.CachingEmbedFS{ModTime: time.Now(), FS: docsStaticFs}
  129. )
  // Special Firebase topic names; keep in sync with the mobile apps.
  const (
  	firebaseControlTopic = "~control" // See Android if changed
  	firebasePollTopic    = "~poll"    // See iOS if changed
  )

  // Default message texts, the base64 encoding marker and the JSON body size limit
  // (presumably the maximum size of a JSON publish request body, in bytes — confirm at call sites).
  const (
  	emptyMessageBody         = "triggered"               // Used if message body is empty
  	newMessageBody           = "New message"             // Used in poll requests as generic message
  	defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment
  	encodingBase64           = "base64"                  // Used mainly for binary UnifiedPush messages
  	jsonBodyBytesLimit       = 16 * 1024                 // 16 KiB
  )
  // WebSocket constants: timeouts for writes and for awaiting pongs.
  const (
  	wsWriteWait = 2 * time.Second
  	wsPongWait  = 15 * time.Second
  )

  // WebSocket constants: buffer size and read limit (presumably bytes — confirm at call sites).
  const (
  	wsBufferSize = 1 << 10
  	wsReadLimit  = 64 // We only ever receive PINGs
  )
  146. // New instantiates a new Server. It creates the cache and adds a Firebase
  147. // subscriber (if configured).
  148. func New(conf *Config) (*Server, error) {
  149. var mailer mailer
  150. if conf.SMTPSenderAddr != "" {
  151. mailer = &smtpSender{config: conf}
  152. }
  153. var stripe stripeAPI
  154. if conf.StripeSecretKey != "" {
  155. stripe = newStripeAPI()
  156. }
  157. messageCache, err := createMessageCache(conf)
  158. if err != nil {
  159. return nil, err
  160. }
  161. topics, err := messageCache.Topics()
  162. if err != nil {
  163. return nil, err
  164. }
  165. var fileCache *fileCache
  166. if conf.AttachmentCacheDir != "" {
  167. fileCache, err = newFileCache(conf.AttachmentCacheDir, conf.AttachmentTotalSizeLimit)
  168. if err != nil {
  169. return nil, err
  170. }
  171. }
  172. var userManager *user.Manager
  173. if conf.AuthFile != "" {
  174. userManager, err = user.NewManager(conf.AuthFile, conf.AuthStartupQueries, conf.AuthDefault)
  175. if err != nil {
  176. return nil, err
  177. }
  178. }
  179. var firebaseClient *firebaseClient
  180. if conf.FirebaseKeyFile != "" {
  181. sender, err := newFirebaseSender(conf.FirebaseKeyFile)
  182. if err != nil {
  183. return nil, err
  184. }
  185. firebaseClient = newFirebaseClient(sender, userManager)
  186. }
  187. s := &Server{
  188. config: conf,
  189. messageCache: messageCache,
  190. fileCache: fileCache,
  191. firebaseClient: firebaseClient,
  192. smtpSender: mailer,
  193. topics: topics,
  194. userManager: userManager,
  195. visitors: make(map[string]*visitor),
  196. stripe: stripe,
  197. }
  198. s.priceCache = util.NewLookupCache(s.fetchStripePrices, conf.StripePriceCacheDuration)
  199. return s, nil
  200. }
  201. func createMessageCache(conf *Config) (*messageCache, error) {
  202. if conf.CacheDuration == 0 {
  203. return newNopCache()
  204. } else if conf.CacheFile != "" {
  205. return newSqliteCache(conf.CacheFile, conf.CacheStartupQueries, conf.CacheDuration, conf.CacheBatchSize, conf.CacheBatchTimeout, false)
  206. }
  207. return newMemCache()
  208. }
  209. // Run executes the main server. It listens on HTTP (+ HTTPS, if configured), and starts
  210. // a manager go routine to print stats and prune messages.
  211. func (s *Server) Run() error {
  212. var listenStr string
  213. if s.config.ListenHTTP != "" {
  214. listenStr += fmt.Sprintf(" %s[http]", s.config.ListenHTTP)
  215. }
  216. if s.config.ListenHTTPS != "" {
  217. listenStr += fmt.Sprintf(" %s[https]", s.config.ListenHTTPS)
  218. }
  219. if s.config.ListenUnix != "" {
  220. listenStr += fmt.Sprintf(" %s[unix]", s.config.ListenUnix)
  221. }
  222. if s.config.SMTPServerListen != "" {
  223. listenStr += fmt.Sprintf(" %s[smtp]", s.config.SMTPServerListen)
  224. }
  225. log.Info("Listening on%s, ntfy %s, log level is %s", listenStr, s.config.Version, log.CurrentLevel().String())
  226. mux := http.NewServeMux()
  227. mux.HandleFunc("/", s.handle)
  228. errChan := make(chan error)
  229. s.mu.Lock()
  230. s.closeChan = make(chan bool)
  231. if s.config.ListenHTTP != "" {
  232. s.httpServer = &http.Server{Addr: s.config.ListenHTTP, Handler: mux}
  233. go func() {
  234. errChan <- s.httpServer.ListenAndServe()
  235. }()
  236. }
  237. if s.config.ListenHTTPS != "" {
  238. s.httpsServer = &http.Server{Addr: s.config.ListenHTTPS, Handler: mux}
  239. go func() {
  240. errChan <- s.httpsServer.ListenAndServeTLS(s.config.CertFile, s.config.KeyFile)
  241. }()
  242. }
  243. if s.config.ListenUnix != "" {
  244. go func() {
  245. var err error
  246. s.mu.Lock()
  247. os.Remove(s.config.ListenUnix)
  248. s.unixListener, err = net.Listen("unix", s.config.ListenUnix)
  249. if err != nil {
  250. s.mu.Unlock()
  251. errChan <- err
  252. return
  253. }
  254. defer s.unixListener.Close()
  255. if s.config.ListenUnixMode > 0 {
  256. if err := os.Chmod(s.config.ListenUnix, s.config.ListenUnixMode); err != nil {
  257. s.mu.Unlock()
  258. errChan <- err
  259. return
  260. }
  261. }
  262. s.mu.Unlock()
  263. httpServer := &http.Server{Handler: mux}
  264. errChan <- httpServer.Serve(s.unixListener)
  265. }()
  266. }
  267. if s.config.SMTPServerListen != "" {
  268. go func() {
  269. errChan <- s.runSMTPServer()
  270. }()
  271. }
  272. s.mu.Unlock()
  273. go s.runManager()
  274. go s.runStatsResetter()
  275. go s.runDelayedSender()
  276. go s.runFirebaseKeepaliver()
  277. return <-errChan
  278. }
  279. // Stop stops HTTP (+HTTPS) server and all managers
  280. func (s *Server) Stop() {
  281. s.mu.Lock()
  282. defer s.mu.Unlock()
  283. if s.httpServer != nil {
  284. s.httpServer.Close()
  285. }
  286. if s.httpsServer != nil {
  287. s.httpsServer.Close()
  288. }
  289. if s.unixListener != nil {
  290. s.unixListener.Close()
  291. }
  292. if s.smtpServer != nil {
  293. s.smtpServer.Close()
  294. }
  295. close(s.closeChan)
  296. }
  297. func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
  298. v, err := s.visitor(r) // Note: Always returns v, even when error is returned
  299. if err == nil {
  300. log.Debug("%s Dispatching request", logHTTPPrefix(v, r))
  301. if log.IsTrace() {
  302. log.Trace("%s Entire request (headers and body):\n%s", logHTTPPrefix(v, r), renderHTTPRequest(r))
  303. }
  304. err = s.handleInternal(w, r, v)
  305. }
  306. if err != nil {
  307. if websocket.IsWebSocketUpgrade(r) {
  308. isNormalError := strings.Contains(err.Error(), "i/o timeout")
  309. if isNormalError {
  310. log.Debug("%s WebSocket error (this error is okay, it happens a lot): %s", logHTTPPrefix(v, r), err.Error())
  311. } else {
  312. log.Info("%s WebSocket error: %s", logHTTPPrefix(v, r), err.Error())
  313. }
  314. return // Do not attempt to write to upgraded connection
  315. }
  316. if matrixErr, ok := err.(*errMatrix); ok {
  317. writeMatrixError(w, r, v, matrixErr)
  318. return
  319. }
  320. httpErr, ok := err.(*errHTTP)
  321. if !ok {
  322. httpErr = errHTTPInternalError
  323. }
  324. isNormalError := httpErr.HTTPCode == http.StatusNotFound || httpErr.HTTPCode == http.StatusBadRequest
  325. if isNormalError {
  326. log.Debug("%s Connection closed with HTTP %d (ntfy error %d): %s", logHTTPPrefix(v, r), httpErr.HTTPCode, httpErr.Code, err.Error())
  327. } else {
  328. log.Info("%s Connection closed with HTTP %d (ntfy error %d): %s", logHTTPPrefix(v, r), httpErr.HTTPCode, httpErr.Code, err.Error())
  329. }
  330. w.Header().Set("Content-Type", "application/json")
  331. w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
  332. w.WriteHeader(httpErr.HTTPCode)
  333. io.WriteString(w, httpErr.JSON()+"\n")
  334. }
  335. }
  336. func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visitor) error {
  337. if r.Method == http.MethodGet && r.URL.Path == "/" {
  338. return s.ensureWebEnabled(s.handleHome)(w, r, v)
  339. } else if r.Method == http.MethodHead && r.URL.Path == "/" {
  340. return s.ensureWebEnabled(s.handleEmpty)(w, r, v)
  341. } else if r.Method == http.MethodGet && r.URL.Path == apiHealthPath {
  342. return s.handleHealth(w, r, v)
  343. } else if r.Method == http.MethodGet && r.URL.Path == webConfigPath {
  344. return s.ensureWebEnabled(s.handleWebConfig)(w, r, v)
  345. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountPath {
  346. return s.ensureUserManager(s.handleAccountCreate)(w, r, v)
  347. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountTokenPath {
  348. return s.ensureUser(s.handleAccountTokenIssue)(w, r, v)
  349. } else if r.Method == http.MethodGet && r.URL.Path == apiAccountPath {
  350. return s.handleAccountGet(w, r, v) // Allowed by anonymous
  351. } else if r.Method == http.MethodDelete && r.URL.Path == apiAccountPath {
  352. return s.ensureUser(s.withAccountSync(s.handleAccountDelete))(w, r, v)
  353. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountPasswordPath {
  354. return s.ensureUser(s.handleAccountPasswordChange)(w, r, v)
  355. } else if r.Method == http.MethodPatch && r.URL.Path == apiAccountTokenPath {
  356. return s.ensureUser(s.handleAccountTokenExtend)(w, r, v)
  357. } else if r.Method == http.MethodDelete && r.URL.Path == apiAccountTokenPath {
  358. return s.ensureUser(s.handleAccountTokenDelete)(w, r, v)
  359. } else if r.Method == http.MethodPatch && r.URL.Path == apiAccountSettingsPath {
  360. return s.ensureUser(s.withAccountSync(s.handleAccountSettingsChange))(w, r, v)
  361. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountSubscriptionPath {
  362. return s.ensureUser(s.withAccountSync(s.handleAccountSubscriptionAdd))(w, r, v)
  363. } else if r.Method == http.MethodPatch && apiAccountSubscriptionSingleRegex.MatchString(r.URL.Path) {
  364. return s.ensureUser(s.withAccountSync(s.handleAccountSubscriptionChange))(w, r, v)
  365. } else if r.Method == http.MethodDelete && apiAccountSubscriptionSingleRegex.MatchString(r.URL.Path) {
  366. return s.ensureUser(s.withAccountSync(s.handleAccountSubscriptionDelete))(w, r, v)
  367. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountReservationPath {
  368. return s.ensureUser(s.withAccountSync(s.handleAccountReservationAdd))(w, r, v)
  369. } else if r.Method == http.MethodDelete && apiAccountReservationSingleRegex.MatchString(r.URL.Path) {
  370. return s.ensureUser(s.withAccountSync(s.handleAccountReservationDelete))(w, r, v)
  371. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountBillingSubscriptionPath {
  372. return s.ensurePaymentsEnabled(s.ensureUser(s.handleAccountBillingSubscriptionCreate))(w, r, v) // Account sync via incoming Stripe webhook
  373. } else if r.Method == http.MethodGet && apiAccountBillingSubscriptionCheckoutSuccessRegex.MatchString(r.URL.Path) {
  374. return s.ensurePaymentsEnabled(s.ensureUserManager(s.handleAccountBillingSubscriptionCreateSuccess))(w, r, v) // No user context!
  375. } else if r.Method == http.MethodPut && r.URL.Path == apiAccountBillingSubscriptionPath {
  376. return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingSubscriptionUpdate))(w, r, v) // Account sync via incoming Stripe webhook
  377. } else if r.Method == http.MethodDelete && r.URL.Path == apiAccountBillingSubscriptionPath {
  378. return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingSubscriptionDelete))(w, r, v) // Account sync via incoming Stripe webhook
  379. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountBillingPortalPath {
  380. return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingPortalSessionCreate))(w, r, v)
  381. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountBillingWebhookPath {
  382. return s.ensurePaymentsEnabled(s.ensureUserManager(s.handleAccountBillingWebhook))(w, r, v) // This request comes from Stripe!
  383. } else if r.Method == http.MethodGet && r.URL.Path == apiTiers {
  384. return s.ensurePaymentsEnabled(s.handleBillingTiersGet)(w, r, v)
  385. } else if r.Method == http.MethodGet && r.URL.Path == matrixPushPath {
  386. return s.handleMatrixDiscovery(w)
  387. } else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) {
  388. return s.ensureWebEnabled(s.handleStatic)(w, r, v)
  389. } else if r.Method == http.MethodGet && docsRegex.MatchString(r.URL.Path) {
  390. return s.ensureWebEnabled(s.handleDocs)(w, r, v)
  391. } else if (r.Method == http.MethodGet || r.Method == http.MethodHead) && fileRegex.MatchString(r.URL.Path) && s.config.AttachmentCacheDir != "" {
  392. return s.limitRequests(s.handleFile)(w, r, v)
  393. } else if r.Method == http.MethodOptions {
  394. return s.ensureWebEnabled(s.handleOptions)(w, r, v)
  395. } else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && r.URL.Path == "/" {
  396. return s.limitRequests(s.transformBodyJSON(s.authorizeTopicWrite(s.handlePublish)))(w, r, v)
  397. } else if r.Method == http.MethodPost && r.URL.Path == matrixPushPath {
  398. return s.limitRequests(s.transformMatrixJSON(s.authorizeTopicWrite(s.handlePublishMatrix)))(w, r, v)
  399. } else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicPathRegex.MatchString(r.URL.Path) {
  400. return s.limitRequests(s.authorizeTopicWrite(s.handlePublish))(w, r, v)
  401. } else if r.Method == http.MethodGet && publishPathRegex.MatchString(r.URL.Path) {
  402. return s.limitRequests(s.authorizeTopicWrite(s.handlePublish))(w, r, v)
  403. } else if r.Method == http.MethodGet && jsonPathRegex.MatchString(r.URL.Path) {
  404. return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeJSON))(w, r, v)
  405. } else if r.Method == http.MethodGet && ssePathRegex.MatchString(r.URL.Path) {
  406. return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeSSE))(w, r, v)
  407. } else if r.Method == http.MethodGet && rawPathRegex.MatchString(r.URL.Path) {
  408. return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeRaw))(w, r, v)
  409. } else if r.Method == http.MethodGet && wsPathRegex.MatchString(r.URL.Path) {
  410. return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeWS))(w, r, v)
  411. } else if r.Method == http.MethodGet && authPathRegex.MatchString(r.URL.Path) {
  412. return s.limitRequests(s.authorizeTopicRead(s.handleTopicAuth))(w, r, v)
  413. } else if r.Method == http.MethodGet && (topicPathRegex.MatchString(r.URL.Path) || externalTopicPathRegex.MatchString(r.URL.Path)) {
  414. return s.ensureWebEnabled(s.handleTopic)(w, r, v)
  415. }
  416. return errHTTPNotFound
  417. }
  418. func (s *Server) handleHome(w http.ResponseWriter, r *http.Request, v *visitor) error {
  419. if s.config.WebRootIsApp {
  420. r.URL.Path = webAppIndex
  421. } else {
  422. r.URL.Path = webHomeIndex
  423. }
  424. return s.handleStatic(w, r, v)
  425. }
  426. func (s *Server) handleTopic(w http.ResponseWriter, r *http.Request, v *visitor) error {
  427. unifiedpush := readBoolParam(r, false, "x-unifiedpush", "unifiedpush", "up") // see PUT/POST too!
  428. if unifiedpush {
  429. w.Header().Set("Content-Type", "application/json")
  430. w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
  431. _, err := io.WriteString(w, `{"unifiedpush":{"version":1}}`+"\n")
  432. return err
  433. }
  434. r.URL.Path = webAppIndex
  435. return s.handleStatic(w, r, v)
  436. }
  437. func (s *Server) handleEmpty(_ http.ResponseWriter, _ *http.Request, _ *visitor) error {
  438. return nil
  439. }
  440. func (s *Server) handleTopicAuth(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
  441. return s.writeJSON(w, newSuccessResponse())
  442. }
  443. func (s *Server) handleHealth(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
  444. response := &apiHealthResponse{
  445. Healthy: true,
  446. }
  447. return s.writeJSON(w, response)
  448. }
  449. func (s *Server) handleWebConfig(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
  450. appRoot := "/"
  451. if !s.config.WebRootIsApp {
  452. appRoot = "/app"
  453. }
  454. response := &apiConfigResponse{
  455. BaseURL: "", // Will translate to window.location.origin
  456. AppRoot: appRoot,
  457. EnableLogin: s.config.EnableLogin,
  458. EnableSignup: s.config.EnableSignup,
  459. EnablePayments: s.config.StripeSecretKey != "",
  460. EnableReservations: s.config.EnableReservations,
  461. DisallowedTopics: disallowedTopics,
  462. }
  463. b, err := json.MarshalIndent(response, "", " ")
  464. if err != nil {
  465. return err
  466. }
  467. w.Header().Set("Content-Type", "text/javascript")
  468. _, err = io.WriteString(w, fmt.Sprintf("// Generated server configuration\nvar config = %s;\n", string(b)))
  469. return err
  470. }
  471. func (s *Server) handleStatic(w http.ResponseWriter, r *http.Request, _ *visitor) error {
  472. r.URL.Path = webSiteDir + r.URL.Path
  473. util.Gzip(http.FileServer(http.FS(webFsCached))).ServeHTTP(w, r)
  474. return nil
  475. }
  476. func (s *Server) handleDocs(w http.ResponseWriter, r *http.Request, _ *visitor) error {
  477. util.Gzip(http.FileServer(http.FS(docsStaticCached))).ServeHTTP(w, r)
  478. return nil
  479. }
  480. func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor) error {
  481. if s.config.AttachmentCacheDir == "" {
  482. return errHTTPInternalError
  483. }
  484. matches := fileRegex.FindStringSubmatch(r.URL.Path)
  485. if len(matches) != 2 {
  486. return errHTTPInternalErrorInvalidPath
  487. }
  488. messageID := matches[1]
  489. file := filepath.Join(s.config.AttachmentCacheDir, messageID)
  490. stat, err := os.Stat(file)
  491. if err != nil {
  492. return errHTTPNotFound
  493. }
  494. if r.Method == http.MethodGet {
  495. if err := v.BandwidthLimiter().Allow(stat.Size()); err != nil {
  496. return errHTTPTooManyRequestsLimitAttachmentBandwidth
  497. }
  498. }
  499. w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size()))
  500. w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
  501. if r.Method == http.MethodGet {
  502. f, err := os.Open(file)
  503. if err != nil {
  504. return err
  505. }
  506. defer f.Close()
  507. _, err = io.Copy(util.NewContentTypeWriter(w, r.URL.Path), f)
  508. return err
  509. }
  510. return nil
  511. }
  512. func (s *Server) handleMatrixDiscovery(w http.ResponseWriter) error {
  513. if s.config.BaseURL == "" {
  514. return errHTTPInternalErrorMissingBaseURL
  515. }
  516. return writeMatrixDiscoveryResponse(w)
  517. }
  // handlePublishWithoutResponse publishes the message described by request r on behalf of
  // visitor v and returns it, without writing a response. It is shared by handlePublish (JSON
  // response) and handlePublishMatrix (Matrix push gateway response).
  //
  // Order of operations:
  //  1. Resolve the topic from the path.
  //  2. Charge the visitor's message limiter.
  //  3. Peek at the body, bounded by MessageLimit.
  //  4. Parse the publish parameters.
  //  5. Fill in the user and expiry, then let handlePublishBody fill in the body.
  // A message not scheduled in the future is then published immediately, with optional
  // Firebase, email and upstream fan-out in background goroutines. A scheduled message is
  // only logged here. In both cases the message is added to the cache if caching was
  // requested, and the visitor and server counters are updated.
  func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*message, error) {
  	t, err := s.topicFromPath(r.URL.Path)
  	if err != nil {
  		return nil, err
  	}
  	// Checked before the body is read; the limiter's error is replaced by the ntfy error
  	if err := v.MessageAllowed(); err != nil {
  		return nil, errHTTPTooManyRequestsLimitMessages
  	}
  	// See util.Peek for exact semantics; bounded by the configured message limit
  	body, err := util.Peek(r.Body, s.config.MessageLimit)
  	if err != nil {
  		return nil, err
  	}
  	m := newDefaultMessage(t.ID, "")
  	cache, firebase, email, unifiedpush, err := s.parsePublishParams(r, v, m)
  	if err != nil {
  		return nil, err
  	}
  	// A poll request replaces the parsed message entirely with a poll-request message
  	if m.PollID != "" {
  		m = newPollRequestMessage(t.ID, m.PollID)
  	}
  	// NOTE(review): v.user is read without synchronization here and below (see "races" TODO)
  	if v.user != nil {
  		m.User = v.user.ID
  	}
  	m.Expires = time.Now().Add(v.Limits().MessagesExpiryDuration).Unix()
  	if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil {
  		return nil, err
  	}
  	if m.Message == "" {
  		m.Message = emptyMessageBody
  	}
  	delayed := m.Time > time.Now().Unix()
  	log.Debug("%s Received message: event=%s, user=%s, body=%d byte(s), delayed=%t, firebase=%t, cache=%t, up=%t, email=%s",
  		logMessagePrefix(v, m), m.Event, m.User, len(m.Message), delayed, firebase, cache, unifiedpush, email)
  	if log.IsTrace() {
  		log.Trace("%s Message body: %s", logMessagePrefix(v, m), util.MaybeMarshalJSON(m))
  	}
  	if !delayed {
  		if err := t.Publish(v, m); err != nil {
  			return nil, err
  		}
  		// Fan-out runs in background goroutines; their failures are only logged
  		if s.firebaseClient != nil && firebase {
  			go s.sendToFirebase(v, m)
  		}
  		if s.smtpSender != nil && email != "" {
  			v.IncrementEmails()
  			go s.sendEmail(v, m, email)
  		}
  		if s.config.UpstreamBaseURL != "" {
  			go s.forwardPollRequest(v, m)
  		}
  	} else {
  		// Presumably picked up later by runDelayedSender from the cache; this assumes
  		// parsePublishParams rejects delayed messages with caching disabled — confirm
  		log.Debug("%s Message delayed, will process later", logMessagePrefix(v, m))
  	}
  	if cache {
  		log.Debug("%s Adding message to cache", logMessagePrefix(v, m))
  		// NOTE(review): if this fails, a non-delayed message has already been delivered to
  		// live subscribers, yet the request reports an error
  		if err := s.messageCache.AddMessage(m); err != nil {
  			return nil, err
  		}
  	}
  	v.IncrementMessages()
  	if s.userManager != nil && v.user != nil {
  		s.userManager.EnqueueStats(v.user)
  	}
  	s.mu.Lock()
  	s.messages++
  	s.mu.Unlock()
  	return m, nil
  }
  586. func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visitor) error {
  587. m, err := s.handlePublishWithoutResponse(r, v)
  588. if err != nil {
  589. return err
  590. }
  591. return s.writeJSON(w, m)
  592. }
  593. func (s *Server) handlePublishMatrix(w http.ResponseWriter, r *http.Request, v *visitor) error {
  594. _, err := s.handlePublishWithoutResponse(r, v)
  595. if err != nil {
  596. return &errMatrix{pushKey: r.Header.Get(matrixPushKeyHeader), err: err}
  597. }
  598. return writeMatrixSuccess(w)
  599. }
  600. func (s *Server) sendToFirebase(v *visitor, m *message) {
  601. log.Debug("%s Publishing to Firebase", logMessagePrefix(v, m))
  602. if err := s.firebaseClient.Send(v, m); err != nil {
  603. if err == errFirebaseTemporarilyBanned {
  604. log.Debug("%s Unable to publish to Firebase: %v", logMessagePrefix(v, m), err.Error())
  605. } else {
  606. log.Warn("%s Unable to publish to Firebase: %v", logMessagePrefix(v, m), err.Error())
  607. }
  608. }
  609. }
  610. func (s *Server) sendEmail(v *visitor, m *message, email string) {
  611. log.Debug("%s Sending email to %s", logMessagePrefix(v, m), email)
  612. if err := s.smtpSender.Send(v, m, email); err != nil {
  613. log.Warn("%s Unable to send email to %s: %v", logMessagePrefix(v, m), email, err.Error())
  614. }
  615. }
  616. func (s *Server) forwardPollRequest(v *visitor, m *message) {
  617. topicURL := fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic)
  618. topicHash := fmt.Sprintf("%x", sha256.Sum256([]byte(topicURL)))
  619. forwardURL := fmt.Sprintf("%s/%s", s.config.UpstreamBaseURL, topicHash)
  620. log.Debug("%s Publishing poll request to %s", logMessagePrefix(v, m), forwardURL)
  621. req, err := http.NewRequest("POST", forwardURL, strings.NewReader(""))
  622. if err != nil {
  623. log.Warn("%s Unable to publish poll request: %v", logMessagePrefix(v, m), err.Error())
  624. return
  625. }
  626. req.Header.Set("X-Poll-ID", m.ID)
  627. var httpClient = &http.Client{
  628. Timeout: time.Second * 10,
  629. }
  630. response, err := httpClient.Do(req)
  631. if err != nil {
  632. log.Warn("%s Unable to publish poll request: %v", logMessagePrefix(v, m), err.Error())
  633. return
  634. } else if response.StatusCode != http.StatusOK {
  635. log.Warn("%s Unable to publish poll request, unexpected HTTP status: %d", logMessagePrefix(v, m), response.StatusCode)
  636. return
  637. }
  638. }
  639. func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (cache bool, firebase bool, email string, unifiedpush bool, err error) {
  640. cache = readBoolParam(r, true, "x-cache", "cache")
  641. firebase = readBoolParam(r, true, "x-firebase", "firebase")
  642. m.Title = readParam(r, "x-title", "title", "t")
  643. m.Click = readParam(r, "x-click", "click")
  644. icon := readParam(r, "x-icon", "icon")
  645. filename := readParam(r, "x-filename", "filename", "file", "f")
  646. attach := readParam(r, "x-attach", "attach", "a")
  647. if attach != "" || filename != "" {
  648. m.Attachment = &attachment{}
  649. }
  650. if filename != "" {
  651. m.Attachment.Name = filename
  652. }
  653. if attach != "" {
  654. if !urlRegex.MatchString(attach) {
  655. return false, false, "", false, errHTTPBadRequestAttachmentURLInvalid
  656. }
  657. m.Attachment.URL = attach
  658. if m.Attachment.Name == "" {
  659. u, err := url.Parse(m.Attachment.URL)
  660. if err == nil {
  661. m.Attachment.Name = path.Base(u.Path)
  662. if m.Attachment.Name == "." || m.Attachment.Name == "/" {
  663. m.Attachment.Name = ""
  664. }
  665. }
  666. }
  667. if m.Attachment.Name == "" {
  668. m.Attachment.Name = "attachment"
  669. }
  670. }
  671. if icon != "" {
  672. if !urlRegex.MatchString(icon) {
  673. return false, false, "", false, errHTTPBadRequestIconURLInvalid
  674. }
  675. m.Icon = icon
  676. }
  677. email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e")
  678. if email != "" {
  679. if err := v.EmailAllowed(); err != nil {
  680. return false, false, "", false, errHTTPTooManyRequestsLimitEmails
  681. }
  682. }
  683. if s.smtpSender == nil && email != "" {
  684. return false, false, "", false, errHTTPBadRequestEmailDisabled
  685. }
  686. messageStr := strings.ReplaceAll(readParam(r, "x-message", "message", "m"), "\\n", "\n")
  687. if messageStr != "" {
  688. m.Message = messageStr
  689. }
  690. m.Priority, err = util.ParsePriority(readParam(r, "x-priority", "priority", "prio", "p"))
  691. if err != nil {
  692. return false, false, "", false, errHTTPBadRequestPriorityInvalid
  693. }
  694. tagsStr := readParam(r, "x-tags", "tags", "tag", "ta")
  695. if tagsStr != "" {
  696. m.Tags = make([]string, 0)
  697. for _, s := range util.SplitNoEmpty(tagsStr, ",") {
  698. m.Tags = append(m.Tags, strings.TrimSpace(s))
  699. }
  700. }
  701. delayStr := readParam(r, "x-delay", "delay", "x-at", "at", "x-in", "in")
  702. if delayStr != "" {
  703. if !cache {
  704. return false, false, "", false, errHTTPBadRequestDelayNoCache
  705. }
  706. if email != "" {
  707. return false, false, "", false, errHTTPBadRequestDelayNoEmail // we cannot store the email address (yet)
  708. }
  709. delay, err := util.ParseFutureTime(delayStr, time.Now())
  710. if err != nil {
  711. return false, false, "", false, errHTTPBadRequestDelayCannotParse
  712. } else if delay.Unix() < time.Now().Add(s.config.MinDelay).Unix() {
  713. return false, false, "", false, errHTTPBadRequestDelayTooSmall
  714. } else if delay.Unix() > time.Now().Add(s.config.MaxDelay).Unix() {
  715. return false, false, "", false, errHTTPBadRequestDelayTooLarge
  716. }
  717. m.Time = delay.Unix()
  718. m.Sender = v.ip // Important for rate limiting
  719. }
  720. actionsStr := readParam(r, "x-actions", "actions", "action")
  721. if actionsStr != "" {
  722. m.Actions, err = parseActions(actionsStr)
  723. if err != nil {
  724. return false, false, "", false, wrapErrHTTP(errHTTPBadRequestActionsInvalid, err.Error())
  725. }
  726. }
  727. unifiedpush = readBoolParam(r, false, "x-unifiedpush", "unifiedpush", "up") // see GET too!
  728. if unifiedpush {
  729. firebase = false
  730. unifiedpush = true
  731. }
  732. m.PollID = readParam(r, "x-poll-id", "poll-id")
  733. if m.PollID != "" {
  734. unifiedpush = false
  735. cache = false
  736. email = ""
  737. }
  738. return cache, firebase, email, unifiedpush, nil
  739. }
  740. // handlePublishBody consumes the PUT/POST body and decides whether the body is an attachment or the message.
  741. //
  742. // 1. curl -X POST -H "Poll: 1234" ntfy.sh/...
  743. // If a message is flagged as poll request, the body does not matter and is discarded
  744. // 2. curl -T somebinarydata.bin "ntfy.sh/mytopic?up=1"
  745. // If body is binary, encode as base64, if not do not encode
  746. // 3. curl -H "Attach: http://example.com/file.jpg" ntfy.sh/mytopic
  747. // Body must be a message, because we attached an external URL
  748. // 4. curl -T short.txt -H "Filename: short.txt" ntfy.sh/mytopic
  749. // Body must be attachment, because we passed a filename
  750. // 5. curl -T file.txt ntfy.sh/mytopic
  751. // If file.txt is <= 4096 (message limit) and valid UTF-8, treat it as a message
  752. // 6. curl -T file.txt ntfy.sh/mytopic
  753. // If file.txt is > message limit, treat it as an attachment
  754. func (s *Server) handlePublishBody(r *http.Request, v *visitor, m *message, body *util.PeekedReadCloser, unifiedpush bool) error {
  755. if m.Event == pollRequestEvent { // Case 1
  756. return s.handleBodyDiscard(body)
  757. } else if unifiedpush {
  758. return s.handleBodyAsMessageAutoDetect(m, body) // Case 2
  759. } else if m.Attachment != nil && m.Attachment.URL != "" {
  760. return s.handleBodyAsTextMessage(m, body) // Case 3
  761. } else if m.Attachment != nil && m.Attachment.Name != "" {
  762. return s.handleBodyAsAttachment(r, v, m, body) // Case 4
  763. } else if !body.LimitReached && utf8.Valid(body.PeekedBytes) {
  764. return s.handleBodyAsTextMessage(m, body) // Case 5
  765. }
  766. return s.handleBodyAsAttachment(r, v, m, body) // Case 6
  767. }
  768. func (s *Server) handleBodyDiscard(body *util.PeekedReadCloser) error {
  769. _, err := io.Copy(io.Discard, body)
  770. _ = body.Close()
  771. return err
  772. }
  773. func (s *Server) handleBodyAsMessageAutoDetect(m *message, body *util.PeekedReadCloser) error {
  774. if utf8.Valid(body.PeekedBytes) {
  775. m.Message = string(body.PeekedBytes) // Do not trim
  776. } else {
  777. m.Message = base64.StdEncoding.EncodeToString(body.PeekedBytes)
  778. m.Encoding = encodingBase64
  779. }
  780. return nil
  781. }
  782. func (s *Server) handleBodyAsTextMessage(m *message, body *util.PeekedReadCloser) error {
  783. if !utf8.Valid(body.PeekedBytes) {
  784. return errHTTPBadRequestMessageNotUTF8
  785. }
  786. if len(body.PeekedBytes) > 0 { // Empty body should not override message (publish via GET!)
  787. m.Message = strings.TrimSpace(string(body.PeekedBytes)) // Truncates the message to the peek limit if required
  788. }
  789. if m.Attachment != nil && m.Attachment.Name != "" && m.Message == "" {
  790. m.Message = fmt.Sprintf(defaultAttachmentMessage, m.Attachment.Name)
  791. }
  792. return nil
  793. }
  794. func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message, body *util.PeekedReadCloser) error {
  795. if s.fileCache == nil || s.config.BaseURL == "" || s.config.AttachmentCacheDir == "" {
  796. return errHTTPBadRequestAttachmentsDisallowed
  797. }
  798. vinfo, err := v.Info()
  799. if err != nil {
  800. return err
  801. }
  802. attachmentExpiry := time.Now().Add(vinfo.Limits.AttachmentExpiryDuration).Unix()
  803. if m.Time > attachmentExpiry {
  804. return errHTTPBadRequestAttachmentsExpiryBeforeDelivery
  805. }
  806. contentLengthStr := r.Header.Get("Content-Length")
  807. if contentLengthStr != "" { // Early "do-not-trust" check, hard limit see below
  808. contentLength, err := strconv.ParseInt(contentLengthStr, 10, 64)
  809. if err == nil && (contentLength > vinfo.Stats.AttachmentTotalSizeRemaining || contentLength > vinfo.Limits.AttachmentFileSizeLimit) {
  810. return errHTTPEntityTooLargeAttachment
  811. }
  812. }
  813. if m.Attachment == nil {
  814. m.Attachment = &attachment{}
  815. }
  816. var ext string
  817. m.Sender = v.ip // Important for attachment rate limiting
  818. m.Attachment.Expires = attachmentExpiry
  819. m.Attachment.Type, ext = util.DetectContentType(body.PeekedBytes, m.Attachment.Name)
  820. m.Attachment.URL = fmt.Sprintf("%s/file/%s%s", s.config.BaseURL, m.ID, ext)
  821. if m.Attachment.Name == "" {
  822. m.Attachment.Name = fmt.Sprintf("attachment%s", ext)
  823. }
  824. if m.Message == "" {
  825. m.Message = fmt.Sprintf(defaultAttachmentMessage, m.Attachment.Name)
  826. }
  827. limiters := []util.Limiter{
  828. v.BandwidthLimiter(),
  829. util.NewFixedLimiter(vinfo.Limits.AttachmentFileSizeLimit),
  830. util.NewFixedLimiter(vinfo.Stats.AttachmentTotalSizeRemaining),
  831. }
  832. fmt.Printf("limiters = %#v\nv = %#v\n", limiters, v)
  833. m.Attachment.Size, err = s.fileCache.Write(m.ID, body, limiters...)
  834. if err == util.ErrLimitReached {
  835. return errHTTPEntityTooLargeAttachment
  836. } else if err != nil {
  837. return err
  838. }
  839. return nil
  840. }
  841. func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request, v *visitor) error {
  842. encoder := func(msg *message) (string, error) {
  843. var buf bytes.Buffer
  844. if err := json.NewEncoder(&buf).Encode(&msg); err != nil {
  845. return "", err
  846. }
  847. return buf.String(), nil
  848. }
  849. return s.handleSubscribeHTTP(w, r, v, "application/x-ndjson", encoder)
  850. }
  851. func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request, v *visitor) error {
  852. encoder := func(msg *message) (string, error) {
  853. var buf bytes.Buffer
  854. if err := json.NewEncoder(&buf).Encode(&msg); err != nil {
  855. return "", err
  856. }
  857. if msg.Event != messageEvent {
  858. return fmt.Sprintf("event: %s\ndata: %s\n", msg.Event, buf.String()), nil // Browser's .onmessage() does not fire on this!
  859. }
  860. return fmt.Sprintf("data: %s\n", buf.String()), nil
  861. }
  862. return s.handleSubscribeHTTP(w, r, v, "text/event-stream", encoder)
  863. }
  864. func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request, v *visitor) error {
  865. encoder := func(msg *message) (string, error) {
  866. if msg.Event == messageEvent { // only handle default events
  867. return strings.ReplaceAll(msg.Message, "\n", " ") + "\n", nil
  868. }
  869. return "\n", nil // "keepalive" and "open" events just send an empty line
  870. }
  871. return s.handleSubscribeHTTP(w, r, v, "text/plain", encoder)
  872. }
// handleSubscribeHTTP serves a streaming HTTP subscription for the topics in the request path. Each message
// that passes the subscription's filters is encoded with encoder, then written and flushed to w. The response
// uses the given content type (with charset=utf-8).
//
// With poll=1, only cached messages (selected by since/scheduled) are sent and the function returns.
// Otherwise it works as follows:
//   - subscribe to all topics,
//   - send an "open" message, then any cached messages,
//   - send a keepalive message every config.KeepaliveInterval,
//   - stop when the request context is done or the subscription is canceled via the cancel func handed to
//     topic.Subscribe.
func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *visitor, contentType string, encoder messageEncoder) error {
	log.Debug("%s HTTP stream connection opened", logHTTPPrefix(v, r))
	defer log.Debug("%s HTTP stream connection closed", logHTTPPrefix(v, r))
	if err := v.SubscriptionAllowed(); err != nil {
		return errHTTPTooManyRequestsLimitSubscriptions
	}
	// Presumably releases what SubscriptionAllowed accounted for; only deferred once it has succeeded
	defer v.RemoveSubscription()
	topics, topicsStr, err := s.topicsFromPath(r.URL.Path)
	if err != nil {
		return err
	}
	poll, since, scheduled, filters, err := parseSubscribeParams(r)
	if err != nil {
		return err
	}
	// wlock serializes writes to w; sub may be called concurrently from topic publishers and the keepalive loop
	var wlock sync.Mutex
	defer func() {
		// Hack: This is the fix for a horrible data race that I have not been able to figure out in quite some time.
		// It appears to be happening when the Go HTTP code reads from the socket when closing the request (i.e. AFTER
		// this function returns), and causes a data race with the ResponseWriter. Locking wlock here silences the
		// data race detector. See https://github.com/binwiederhier/ntfy/issues/338#issuecomment-1163425889.
		wlock.TryLock()
	}()
	sub := func(v *visitor, msg *message) error {
		if !filters.Pass(msg) {
			return nil
		}
		m, err := encoder(msg)
		if err != nil {
			return err
		}
		wlock.Lock()
		defer wlock.Unlock()
		if _, err := w.Write([]byte(m)); err != nil {
			return err
		}
		if fl, ok := w.(http.Flusher); ok {
			fl.Flush()
		}
		return nil
	}
	w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
	w.Header().Set("Content-Type", contentType+"; charset=utf-8")                  // Android/Volley client needs charset!
	if poll {
		return s.sendOldMessages(topics, since, scheduled, v, sub)
	}
	// ctx is canceled when the function returns, or externally via the cancel func passed to t.Subscribe
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	subscriberIDs := make([]int, 0)
	for _, t := range topics {
		subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v.MaybeUserID(), cancel))
	}
	defer func() {
		// subscriberIDs[i] belongs to topics[i], since both slices were built in the same order
		for i, subscriberID := range subscriberIDs {
			topics[i].Unsubscribe(subscriberID) // Order!
		}
	}()
	if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message
		return err
	}
	if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
		return err
	}
	for {
		select {
		case <-ctx.Done():
			return nil
		case <-r.Context().Done():
			return nil
		case <-time.After(s.config.KeepaliveInterval):
			log.Trace("%s Sending keepalive message", logHTTPPrefix(v, r))
			v.Keepalive()
			if err := sub(v, newKeepaliveMessage(topicsStr)); err != nil { // Send keepalive message
				return err
			}
		}
	}
}
  951. func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *visitor) error {
  952. if strings.ToLower(r.Header.Get("Upgrade")) != "websocket" {
  953. return errHTTPBadRequestWebSocketsUpgradeHeaderMissing
  954. }
  955. if err := v.SubscriptionAllowed(); err != nil {
  956. return errHTTPTooManyRequestsLimitSubscriptions
  957. }
  958. defer v.RemoveSubscription()
  959. log.Debug("%s WebSocket connection opened", logHTTPPrefix(v, r))
  960. defer log.Debug("%s WebSocket connection closed", logHTTPPrefix(v, r))
  961. topics, topicsStr, err := s.topicsFromPath(r.URL.Path)
  962. if err != nil {
  963. return err
  964. }
  965. poll, since, scheduled, filters, err := parseSubscribeParams(r)
  966. if err != nil {
  967. return err
  968. }
  969. upgrader := &websocket.Upgrader{
  970. ReadBufferSize: wsBufferSize,
  971. WriteBufferSize: wsBufferSize,
  972. CheckOrigin: func(r *http.Request) bool {
  973. return true // We're open for business!
  974. },
  975. }
  976. conn, err := upgrader.Upgrade(w, r, nil)
  977. if err != nil {
  978. return err
  979. }
  980. defer conn.Close()
  981. // Subscription connections can be canceled externally, see topic.CancelSubscribers
  982. subscriberContext, cancel := context.WithCancel(context.Background())
  983. defer cancel()
  984. // Use errgroup to run WebSocket reader and writer in Go routines
  985. var wlock sync.Mutex
  986. g, gctx := errgroup.WithContext(context.Background())
  987. g.Go(func() error {
  988. <-subscriberContext.Done()
  989. log.Trace("%s Cancel received, closing subscriber connection", logHTTPPrefix(v, r))
  990. conn.Close()
  991. return &websocket.CloseError{Code: websocket.CloseNormalClosure, Text: "subscription was canceled"}
  992. })
  993. g.Go(func() error {
  994. pongWait := s.config.KeepaliveInterval + wsPongWait
  995. conn.SetReadLimit(wsReadLimit)
  996. if err := conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil {
  997. return err
  998. }
  999. conn.SetPongHandler(func(appData string) error {
  1000. log.Trace("%s Received WebSocket pong", logHTTPPrefix(v, r))
  1001. return conn.SetReadDeadline(time.Now().Add(pongWait))
  1002. })
  1003. for {
  1004. _, _, err := conn.NextReader()
  1005. if err != nil {
  1006. return err
  1007. }
  1008. select {
  1009. case <-gctx.Done():
  1010. return nil
  1011. default:
  1012. }
  1013. }
  1014. })
  1015. g.Go(func() error {
  1016. ping := func() error {
  1017. wlock.Lock()
  1018. defer wlock.Unlock()
  1019. if err := conn.SetWriteDeadline(time.Now().Add(wsWriteWait)); err != nil {
  1020. return err
  1021. }
  1022. log.Trace("%s Sending WebSocket ping", logHTTPPrefix(v, r))
  1023. return conn.WriteMessage(websocket.PingMessage, nil)
  1024. }
  1025. for {
  1026. select {
  1027. case <-gctx.Done():
  1028. return nil
  1029. case <-time.After(s.config.KeepaliveInterval):
  1030. v.Keepalive()
  1031. if err := ping(); err != nil {
  1032. return err
  1033. }
  1034. }
  1035. }
  1036. })
  1037. sub := func(v *visitor, msg *message) error {
  1038. if !filters.Pass(msg) {
  1039. return nil
  1040. }
  1041. wlock.Lock()
  1042. defer wlock.Unlock()
  1043. if err := conn.SetWriteDeadline(time.Now().Add(wsWriteWait)); err != nil {
  1044. return err
  1045. }
  1046. return conn.WriteJSON(msg)
  1047. }
  1048. w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
  1049. if poll {
  1050. return s.sendOldMessages(topics, since, scheduled, v, sub)
  1051. }
  1052. subscriberIDs := make([]int, 0)
  1053. for _, t := range topics {
  1054. subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v.MaybeUserID(), cancel))
  1055. }
  1056. defer func() {
  1057. for i, subscriberID := range subscriberIDs {
  1058. topics[i].Unsubscribe(subscriberID) // Order!
  1059. }
  1060. }()
  1061. if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message
  1062. return err
  1063. }
  1064. if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
  1065. return err
  1066. }
  1067. err = g.Wait()
  1068. if err != nil && websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
  1069. log.Trace("%s WebSocket connection closed: %s", logHTTPPrefix(v, r), err.Error())
  1070. return nil // Normal closures are not errors; note: "1006 (abnormal closure)" is treated as normal, because people disconnect a lot
  1071. }
  1072. return err
  1073. }
  1074. func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, scheduled bool, filters *queryFilter, err error) {
  1075. poll = readBoolParam(r, false, "x-poll", "poll", "po")
  1076. scheduled = readBoolParam(r, false, "x-scheduled", "scheduled", "sched")
  1077. since, err = parseSince(r, poll)
  1078. if err != nil {
  1079. return
  1080. }
  1081. filters, err = parseQueryFilters(r)
  1082. if err != nil {
  1083. return
  1084. }
  1085. return
  1086. }
  1087. // sendOldMessages selects old messages from the messageCache and calls sub for each of them. It uses since as the
  1088. // marker, returning only messages that are newer than the marker.
  1089. func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, v *visitor, sub subscriber) error {
  1090. if since.IsNone() {
  1091. return nil
  1092. }
  1093. messages := make([]*message, 0)
  1094. for _, t := range topics {
  1095. topicMessages, err := s.messageCache.Messages(t.ID, since, scheduled)
  1096. if err != nil {
  1097. return err
  1098. }
  1099. messages = append(messages, topicMessages...)
  1100. }
  1101. sort.Slice(messages, func(i, j int) bool {
  1102. return messages[i].Time < messages[j].Time
  1103. })
  1104. for _, m := range messages {
  1105. if err := sub(v, m); err != nil {
  1106. return err
  1107. }
  1108. }
  1109. return nil
  1110. }
  1111. // parseSince returns a timestamp identifying the time span from which cached messages should be received.
  1112. //
  1113. // Values in the "since=..." parameter can be either a unix timestamp or a duration (e.g. 12h), or
  1114. // "all" for all messages.
  1115. func parseSince(r *http.Request, poll bool) (sinceMarker, error) {
  1116. since := readParam(r, "x-since", "since", "si")
  1117. // Easy cases (empty, all, none)
  1118. if since == "" {
  1119. if poll {
  1120. return sinceAllMessages, nil
  1121. }
  1122. return sinceNoMessages, nil
  1123. } else if since == "all" {
  1124. return sinceAllMessages, nil
  1125. } else if since == "none" {
  1126. return sinceNoMessages, nil
  1127. }
  1128. // ID, timestamp, duration
  1129. if validMessageID(since) {
  1130. return newSinceID(since), nil
  1131. } else if s, err := strconv.ParseInt(since, 10, 64); err == nil {
  1132. return newSinceTime(s), nil
  1133. } else if d, err := time.ParseDuration(since); err == nil {
  1134. return newSinceTime(time.Now().Add(-1 * d).Unix()), nil
  1135. }
  1136. return sinceNoMessages, errHTTPBadRequestSinceInvalid
  1137. }
  1138. func (s *Server) handleOptions(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
  1139. w.Header().Set("Access-Control-Allow-Methods", "GET, PUT, POST, PATCH, DELETE")
  1140. w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
  1141. w.Header().Set("Access-Control-Allow-Headers", "*") // CORS, allow auth via JS // FIXME is this terrible?
  1142. return nil
  1143. }
  1144. func (s *Server) topicFromPath(path string) (*topic, error) {
  1145. parts := strings.Split(path, "/")
  1146. if len(parts) < 2 {
  1147. return nil, errHTTPBadRequestTopicInvalid
  1148. }
  1149. return s.topicFromID(parts[1])
  1150. }
  1151. func (s *Server) topicsFromPath(path string) ([]*topic, string, error) {
  1152. parts := strings.Split(path, "/")
  1153. if len(parts) < 2 {
  1154. return nil, "", errHTTPBadRequestTopicInvalid
  1155. }
  1156. topicIDs := util.SplitNoEmpty(parts[1], ",")
  1157. topics, err := s.topicsFromIDs(topicIDs...)
  1158. if err != nil {
  1159. return nil, "", errHTTPBadRequestTopicInvalid
  1160. }
  1161. return topics, parts[1], nil
  1162. }
  1163. func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) {
  1164. s.mu.Lock()
  1165. defer s.mu.Unlock()
  1166. topics := make([]*topic, 0)
  1167. for _, id := range ids {
  1168. if util.Contains(disallowedTopics, id) {
  1169. return nil, errHTTPBadRequestTopicDisallowed
  1170. }
  1171. if _, ok := s.topics[id]; !ok {
  1172. if len(s.topics) >= s.config.TotalTopicLimit {
  1173. return nil, errHTTPTooManyRequestsLimitTotalTopics
  1174. }
  1175. s.topics[id] = newTopic(id)
  1176. }
  1177. topics = append(topics, s.topics[id])
  1178. }
  1179. return topics, nil
  1180. }
  1181. func (s *Server) topicFromID(id string) (*topic, error) {
  1182. topics, err := s.topicsFromIDs(id)
  1183. if err != nil {
  1184. return nil, err
  1185. }
  1186. return topics[0], nil
  1187. }
  1188. func (s *Server) execManager() {
  1189. log.Debug("Manager: Starting")
  1190. defer log.Debug("Manager: Finished")
  1191. // WARNING: Make sure to only selectively lock with the mutex, and be aware that this
  1192. // there is no mutex for the entire function.
  1193. // Expire visitors from rate visitors map
  1194. s.mu.Lock()
  1195. staleVisitors := 0
  1196. for ip, v := range s.visitors {
  1197. if v.Stale() {
  1198. log.Trace("Deleting stale visitor %s", v.ip)
  1199. delete(s.visitors, ip)
  1200. staleVisitors++
  1201. }
  1202. }
  1203. s.mu.Unlock()
  1204. log.Debug("Manager: Deleted %d stale visitor(s)", staleVisitors)
  1205. // Delete expired user tokens and users
  1206. if s.userManager != nil {
  1207. if err := s.userManager.RemoveExpiredTokens(); err != nil {
  1208. log.Warn("Error expiring user tokens: %s", err.Error())
  1209. }
  1210. if err := s.userManager.RemoveDeletedUsers(); err != nil {
  1211. log.Warn("Error deleting soft-deleted users: %s", err.Error())
  1212. }
  1213. }
  1214. // Delete expired attachments
  1215. if s.fileCache != nil {
  1216. ids, err := s.messageCache.AttachmentsExpired()
  1217. if err != nil {
  1218. log.Warn("Manager: Error retrieving expired attachments: %s", err.Error())
  1219. } else if len(ids) > 0 {
  1220. if log.IsDebug() {
  1221. log.Debug("Manager: Deleting attachments %s", strings.Join(ids, ", "))
  1222. }
  1223. if err := s.fileCache.Remove(ids...); err != nil {
  1224. log.Warn("Manager: Error deleting attachments: %s", err.Error())
  1225. }
  1226. if err := s.messageCache.MarkAttachmentsDeleted(ids...); err != nil {
  1227. log.Warn("Manager: Error marking attachments deleted: %s", err.Error())
  1228. }
  1229. } else {
  1230. log.Debug("Manager: No expired attachments to delete")
  1231. }
  1232. }
  1233. // Prune messages
  1234. log.Debug("Manager: Pruning messages")
  1235. expiredMessageIDs, err := s.messageCache.MessagesExpired()
  1236. if err != nil {
  1237. log.Warn("Manager: Error retrieving expired messages: %s", err.Error())
  1238. } else if len(expiredMessageIDs) > 0 {
  1239. if err := s.fileCache.Remove(expiredMessageIDs...); err != nil {
  1240. log.Warn("Manager: Error deleting attachments for expired messages: %s", err.Error())
  1241. }
  1242. if err := s.messageCache.DeleteMessages(expiredMessageIDs...); err != nil {
  1243. log.Warn("Manager: Error marking attachments deleted: %s", err.Error())
  1244. }
  1245. } else {
  1246. log.Debug("Manager: No expired messages to delete")
  1247. }
  1248. // Message count per topic
  1249. var messages int
  1250. messageCounts, err := s.messageCache.MessageCounts()
  1251. if err != nil {
  1252. log.Warn("Manager: Cannot get message counts: %s", err.Error())
  1253. messageCounts = make(map[string]int) // Empty, so we can continue
  1254. }
  1255. for _, count := range messageCounts {
  1256. messages += count
  1257. }
  1258. // Remove subscriptions without subscribers
  1259. s.mu.Lock()
  1260. var subscribers int
  1261. for _, t := range s.topics {
  1262. subs := t.SubscribersCount()
  1263. msgs, exists := messageCounts[t.ID]
  1264. if subs == 0 && (!exists || msgs == 0) {
  1265. log.Trace("Deleting empty topic %s", t.ID)
  1266. delete(s.topics, t.ID)
  1267. continue
  1268. }
  1269. subscribers += subs
  1270. }
  1271. s.mu.Unlock()
  1272. // Mail stats
  1273. var receivedMailTotal, receivedMailSuccess, receivedMailFailure int64
  1274. if s.smtpServerBackend != nil {
  1275. receivedMailTotal, receivedMailSuccess, receivedMailFailure = s.smtpServerBackend.Counts()
  1276. }
  1277. var sentMailTotal, sentMailSuccess, sentMailFailure int64
  1278. if s.smtpSender != nil {
  1279. sentMailTotal, sentMailSuccess, sentMailFailure = s.smtpSender.Counts()
  1280. }
  1281. // Print stats
  1282. s.mu.Lock()
  1283. messagesCount, topicsCount, visitorsCount := s.messages, len(s.topics), len(s.visitors)
  1284. s.mu.Unlock()
  1285. log.Info("Stats: %d messages published, %d in cache, %d topic(s) active, %d subscriber(s), %d visitor(s), %d mails received (%d successful, %d failed), %d mails sent (%d successful, %d failed)",
  1286. messagesCount, messages, topicsCount, subscribers, visitorsCount,
  1287. receivedMailTotal, receivedMailSuccess, receivedMailFailure,
  1288. sentMailTotal, sentMailSuccess, sentMailFailure)
  1289. }
  1290. func (s *Server) runSMTPServer() error {
  1291. s.smtpServerBackend = newMailBackend(s.config, s.handle)
  1292. s.smtpServer = smtp.NewServer(s.smtpServerBackend)
  1293. s.smtpServer.Addr = s.config.SMTPServerListen
  1294. s.smtpServer.Domain = s.config.SMTPServerDomain
  1295. s.smtpServer.ReadTimeout = 10 * time.Second
  1296. s.smtpServer.WriteTimeout = 10 * time.Second
  1297. s.smtpServer.MaxMessageBytes = 1024 * 1024 // Must be much larger than message size (headers, multipart, etc.)
  1298. s.smtpServer.MaxRecipients = 1
  1299. s.smtpServer.AllowInsecureAuth = true
  1300. return s.smtpServer.ListenAndServe()
  1301. }
  1302. func (s *Server) runManager() {
  1303. for {
  1304. select {
  1305. case <-time.After(s.config.ManagerInterval):
  1306. s.execManager()
  1307. case <-s.closeChan:
  1308. return
  1309. }
  1310. }
  1311. }
  1312. // runStatsResetter runs once a day (usually midnight UTC) to reset all the visitor's message and
  1313. // email counters. The stats are used to display the counters in the web app, as well as for rate limiting.
  1314. func (s *Server) runStatsResetter() {
  1315. for {
  1316. runAt := util.NextOccurrenceUTC(s.config.VisitorStatsResetTime, time.Now())
  1317. timer := time.NewTimer(time.Until(runAt))
  1318. log.Debug("Stats resetter: Waiting until %v to reset visitor stats", runAt)
  1319. select {
  1320. case <-timer.C:
  1321. s.resetStats()
  1322. case <-s.closeChan:
  1323. timer.Stop()
  1324. return
  1325. }
  1326. }
  1327. }
  1328. func (s *Server) resetStats() {
  1329. log.Info("Resetting all visitor stats (daily task)")
  1330. s.mu.Lock()
  1331. defer s.mu.Unlock() // Includes the database query to avoid races with other processes
  1332. for _, v := range s.visitors {
  1333. v.ResetStats()
  1334. }
  1335. if s.userManager != nil {
  1336. if err := s.userManager.ResetStats(); err != nil {
  1337. log.Warn("Failed to write to database: %s", err.Error())
  1338. }
  1339. }
  1340. }
  1341. func (s *Server) runFirebaseKeepaliver() {
  1342. if s.firebaseClient == nil {
  1343. return
  1344. }
  1345. v := newVisitor(s.config, s.messageCache, s.userManager, netip.IPv4Unspecified(), nil) // Background process, not a real visitor, uses IP 0.0.0.0
  1346. for {
  1347. select {
  1348. case <-time.After(s.config.FirebaseKeepaliveInterval):
  1349. s.sendToFirebase(v, newKeepaliveMessage(firebaseControlTopic))
  1350. case <-time.After(s.config.FirebasePollInterval):
  1351. s.sendToFirebase(v, newKeepaliveMessage(firebasePollTopic))
  1352. case <-s.closeChan:
  1353. return
  1354. }
  1355. }
  1356. }
  1357. func (s *Server) runDelayedSender() {
  1358. for {
  1359. select {
  1360. case <-time.After(s.config.DelayedSenderInterval):
  1361. if err := s.sendDelayedMessages(); err != nil {
  1362. log.Warn("Error sending delayed messages: %s", err.Error())
  1363. }
  1364. case <-s.closeChan:
  1365. return
  1366. }
  1367. }
  1368. }
  1369. func (s *Server) sendDelayedMessages() error {
  1370. messages, err := s.messageCache.MessagesDue()
  1371. if err != nil {
  1372. return err
  1373. }
  1374. for _, m := range messages {
  1375. var v *visitor
  1376. if s.userManager != nil && m.User != "" {
  1377. u, err := s.userManager.User(m.User)
  1378. if err != nil {
  1379. log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error())
  1380. continue
  1381. }
  1382. v = s.visitorFromUser(u, m.Sender)
  1383. } else {
  1384. v = s.visitorFromIP(m.Sender)
  1385. }
  1386. if err := s.sendDelayedMessage(v, m); err != nil {
  1387. log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error())
  1388. }
  1389. }
  1390. return nil
  1391. }
  1392. func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
  1393. log.Debug("%s Sending delayed message", logMessagePrefix(v, m))
  1394. s.mu.Lock()
  1395. t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published
  1396. s.mu.Unlock()
  1397. if ok {
  1398. go func() {
  1399. // We do not rate-limit messages here, since we've rate limited them in the PUT/POST handler
  1400. if err := t.Publish(v, m); err != nil {
  1401. log.Warn("%s Unable to publish message: %v", logMessagePrefix(v, m), err.Error())
  1402. }
  1403. }()
  1404. }
  1405. if s.firebaseClient != nil { // Firebase subscribers may not show up in topics map
  1406. go s.sendToFirebase(v, m)
  1407. }
  1408. if s.config.UpstreamBaseURL != "" {
  1409. go s.forwardPollRequest(v, m)
  1410. }
  1411. if err := s.messageCache.MarkPublished(m); err != nil {
  1412. return err
  1413. }
  1414. return nil
  1415. }
  1416. func (s *Server) limitRequests(next handleFunc) handleFunc {
  1417. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  1418. if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
  1419. return next(w, r, v)
  1420. } else if err := v.RequestAllowed(); err != nil {
  1421. return errHTTPTooManyRequestsLimitRequests
  1422. }
  1423. return next(w, r, v)
  1424. }
  1425. }
  1426. // transformBodyJSON peeks the request body, reads the JSON, and converts it to headers
  1427. // before passing it on to the next handler. This is meant to be used in combination with handlePublish.
  1428. func (s *Server) transformBodyJSON(next handleFunc) handleFunc {
  1429. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  1430. m, err := readJSONWithLimit[publishMessage](r.Body, s.config.MessageLimit*2) // 2x to account for JSON format overhead
  1431. if err != nil {
  1432. return err
  1433. }
  1434. if !topicRegex.MatchString(m.Topic) {
  1435. return errHTTPBadRequestTopicInvalid
  1436. }
  1437. if m.Message == "" {
  1438. m.Message = emptyMessageBody
  1439. }
  1440. r.URL.Path = "/" + m.Topic
  1441. r.Body = io.NopCloser(strings.NewReader(m.Message))
  1442. if m.Title != "" {
  1443. r.Header.Set("X-Title", m.Title)
  1444. }
  1445. if m.Priority != 0 {
  1446. r.Header.Set("X-Priority", fmt.Sprintf("%d", m.Priority))
  1447. }
  1448. if m.Tags != nil && len(m.Tags) > 0 {
  1449. r.Header.Set("X-Tags", strings.Join(m.Tags, ","))
  1450. }
  1451. if m.Attach != "" {
  1452. r.Header.Set("X-Attach", m.Attach)
  1453. }
  1454. if m.Filename != "" {
  1455. r.Header.Set("X-Filename", m.Filename)
  1456. }
  1457. if m.Click != "" {
  1458. r.Header.Set("X-Click", m.Click)
  1459. }
  1460. if m.Icon != "" {
  1461. r.Header.Set("X-Icon", m.Icon)
  1462. }
  1463. if len(m.Actions) > 0 {
  1464. actionsStr, err := json.Marshal(m.Actions)
  1465. if err != nil {
  1466. return errHTTPBadRequestMessageJSONInvalid
  1467. }
  1468. r.Header.Set("X-Actions", string(actionsStr))
  1469. }
  1470. if m.Email != "" {
  1471. r.Header.Set("X-Email", m.Email)
  1472. }
  1473. if m.Delay != "" {
  1474. r.Header.Set("X-Delay", m.Delay)
  1475. }
  1476. return next(w, r, v)
  1477. }
  1478. }
  1479. func (s *Server) transformMatrixJSON(next handleFunc) handleFunc {
  1480. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  1481. newRequest, err := newRequestFromMatrixJSON(r, s.config.BaseURL, s.config.MessageLimit)
  1482. if err != nil {
  1483. return err
  1484. }
  1485. if err := next(w, newRequest, v); err != nil {
  1486. return &errMatrix{pushKey: newRequest.Header.Get(matrixPushKeyHeader), err: err}
  1487. }
  1488. return nil
  1489. }
  1490. }
  1491. func (s *Server) authorizeTopicWrite(next handleFunc) handleFunc {
  1492. return s.autorizeTopic(next, user.PermissionWrite)
  1493. }
  1494. func (s *Server) authorizeTopicRead(next handleFunc) handleFunc {
  1495. return s.autorizeTopic(next, user.PermissionRead)
  1496. }
  1497. func (s *Server) autorizeTopic(next handleFunc, perm user.Permission) handleFunc {
  1498. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  1499. if s.userManager == nil {
  1500. return next(w, r, v)
  1501. }
  1502. topics, _, err := s.topicsFromPath(r.URL.Path)
  1503. if err != nil {
  1504. return err
  1505. }
  1506. for _, t := range topics {
  1507. if err := s.userManager.Authorize(v.user, t.ID, perm); err != nil {
  1508. log.Info("unauthorized: %s", err.Error())
  1509. return errHTTPForbidden
  1510. }
  1511. }
  1512. return next(w, r, v)
  1513. }
  1514. }
  1515. // visitor creates or retrieves a rate.Limiter for the given visitor.
  1516. // Note that this function will always return a visitor, even if an error occurs.
  1517. func (s *Server) visitor(r *http.Request) (v *visitor, err error) {
  1518. ip := extractIPAddress(r, s.config.BehindProxy)
  1519. var u *user.User // may stay nil if no auth header!
  1520. if u, err = s.authenticate(r); err != nil {
  1521. log.Debug("authentication failed: %s", err.Error())
  1522. err = errHTTPUnauthorized // Always return visitor, even when error occurs!
  1523. }
  1524. if u != nil {
  1525. v = s.visitorFromUser(u, ip)
  1526. } else {
  1527. v = s.visitorFromIP(ip)
  1528. }
  1529. v.mu.Lock()
  1530. v.user = u
  1531. v.mu.Unlock()
  1532. return v, err // Always return visitor, even when error occurs!
  1533. }
  1534. // authenticate a user based on basic auth username/password (Authorization: Basic ...), or token auth (Authorization: Bearer ...).
  1535. // The Authorization header can be passed as a header or the ?auth=... query param. The latter is required only to
  1536. // support the WebSocket JavaScript class, which does not support passing headers during the initial request. The auth
  1537. // query param is effectively double base64 encoded. Its format is base64(Basic base64(user:pass)).
  1538. func (s *Server) authenticate(r *http.Request) (user *user.User, err error) {
  1539. value := strings.TrimSpace(r.Header.Get("Authorization"))
  1540. queryParam := readQueryParam(r, "authorization", "auth")
  1541. if queryParam != "" {
  1542. a, err := base64.RawURLEncoding.DecodeString(queryParam)
  1543. if err != nil {
  1544. return nil, err
  1545. }
  1546. value = strings.TrimSpace(string(a))
  1547. }
  1548. if value == "" {
  1549. return nil, nil
  1550. } else if s.userManager == nil {
  1551. return nil, errHTTPUnauthorized
  1552. }
  1553. if strings.HasPrefix(value, "Bearer") {
  1554. return s.authenticateBearerAuth(value)
  1555. }
  1556. return s.authenticateBasicAuth(r, value)
  1557. }
  1558. func (s *Server) authenticateBasicAuth(r *http.Request, value string) (user *user.User, err error) {
  1559. r.Header.Set("Authorization", value)
  1560. username, password, ok := r.BasicAuth()
  1561. if !ok {
  1562. return nil, errors.New("invalid basic auth")
  1563. }
  1564. return s.userManager.Authenticate(username, password)
  1565. }
  1566. func (s *Server) authenticateBearerAuth(value string) (user *user.User, err error) {
  1567. token := strings.TrimSpace(strings.TrimPrefix(value, "Bearer"))
  1568. return s.userManager.AuthenticateToken(token)
  1569. }
  1570. func (s *Server) visitorFromID(visitorID string, ip netip.Addr, user *user.User) *visitor {
  1571. s.mu.Lock()
  1572. defer s.mu.Unlock()
  1573. v, exists := s.visitors[visitorID]
  1574. if !exists {
  1575. s.visitors[visitorID] = newVisitor(s.config, s.messageCache, s.userManager, ip, user)
  1576. return s.visitors[visitorID]
  1577. }
  1578. v.Keepalive()
  1579. return v
  1580. }
  1581. func (s *Server) visitorFromIP(ip netip.Addr) *visitor {
  1582. return s.visitorFromID(fmt.Sprintf("ip:%s", ip.String()), ip, nil)
  1583. }
  1584. func (s *Server) visitorFromUser(user *user.User, ip netip.Addr) *visitor {
  1585. return s.visitorFromID(fmt.Sprintf("user:%s", user.ID), ip, user)
  1586. }
  1587. func (s *Server) writeJSON(w http.ResponseWriter, v any) error {
  1588. w.Header().Set("Content-Type", "application/json")
  1589. w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
  1590. if err := json.NewEncoder(w).Encode(v); err != nil {
  1591. return err
  1592. }
  1593. return nil
  1594. }