server.go 59 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688
  1. package server
  2. import (
  3. "bytes"
  4. "context"
  5. "crypto/sha256"
  6. "embed"
  7. "encoding/base64"
  8. "encoding/json"
  9. "errors"
  10. "fmt"
  11. "heckel.io/ntfy/user"
  12. "io"
  13. "net"
  14. "net/http"
  15. "net/netip"
  16. "net/url"
  17. "os"
  18. "path"
  19. "path/filepath"
  20. "regexp"
  21. "sort"
  22. "strconv"
  23. "strings"
  24. "sync"
  25. "time"
  26. "unicode/utf8"
  27. "heckel.io/ntfy/log"
  28. "github.com/emersion/go-smtp"
  29. "github.com/gorilla/websocket"
  30. "golang.org/x/sync/errgroup"
  31. "heckel.io/ntfy/util"
  32. )
  33. /*
  34. TODO
  35. --
  36. - HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...)
  37. - HIGH Rate limiting: dailyLimitToRate is wrong? + TESTS
  38. - HIGH Rate limiting: Bandwidth limit must be in tier + TESTS
  39. - HIGH Sync problems with "deleteAfter=0" and "displayName="
  40. - Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
  41. - Reservation (UI): Ask for confirmation when removing reservation (deadcade)
  42. - Reservation table delete button: dialog "keep or delete messages?"
  43. - UI: Flickering upgrade banner when logging in
  44. - JS constants
  45. races:
  46. - v.user --> see publishSyncEventAsync() test
  47. payments:
  48. - reconciliation
  49. delete messages + reserved topics on ResetTier delete attachments in access.go
  50. Limits & rate limiting:
  51. users without tier: should the stats be persisted? are they meaningful? -> test that the visitor is based on the IP address!
  52. when ResetStats() is run, reset messagesLimiter (and others)?
  53. Delete visitor when tier is changed to refresh rate limiters
  54. Make sure account endpoints make sense for admins
  55. Tests:
  56. - Payment endpoints (make mocks)
  57. - test that the visitor is based on the IP address when a user has no tier
  58. */
  59. // Server is the main server, providing the UI and API for ntfy
  60. type Server struct {
  61. config *Config
  62. httpServer *http.Server
  63. httpsServer *http.Server
  64. unixListener net.Listener
  65. smtpServer *smtp.Server
  66. smtpServerBackend *smtpBackend
  67. smtpSender mailer
  68. topics map[string]*topic
  69. visitors map[string]*visitor // ip:<ip> or user:<user>
  70. firebaseClient *firebaseClient
  71. messages int64
  72. userManager *user.Manager // Might be nil!
  73. messageCache *messageCache // Database that stores the messages
  74. fileCache *fileCache // File system based cache that stores attachments
  75. stripe stripeAPI // Stripe API, can be replaced with a mock
  76. priceCache *util.LookupCache[map[string]string] // Stripe price ID -> formatted price
  77. closeChan chan bool
  78. mu sync.Mutex
  79. }
  80. // handleFunc extends the normal http.HandlerFunc to be able to easily return errors
  81. type handleFunc func(http.ResponseWriter, *http.Request, *visitor) error
  82. var (
  83. // If changed, don't forget to update Android App and auth_sqlite.go
  84. topicRegex = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`) // No /!
  85. topicPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}$`) // Regex must match JS & Android app!
  86. externalTopicPathRegex = regexp.MustCompile(`^/[^/]+\.[^/]+/[-_A-Za-z0-9]{1,64}$`) // Extended topic path, for web-app, e.g. /example.com/mytopic
  87. jsonPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/json$`)
  88. ssePathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/sse$`)
  89. rawPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/raw$`)
  90. wsPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/ws$`)
  91. authPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/auth$`)
  92. publishPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}/(publish|send|trigger)$`)
  93. webConfigPath = "/config.js"
  94. accountPath = "/account"
  95. matrixPushPath = "/_matrix/push/v1/notify"
  96. apiHealthPath = "/v1/health"
  97. apiTiers = "/v1/tiers"
  98. apiAccountPath = "/v1/account"
  99. apiAccountTokenPath = "/v1/account/token"
  100. apiAccountPasswordPath = "/v1/account/password"
  101. apiAccountSettingsPath = "/v1/account/settings"
  102. apiAccountSubscriptionPath = "/v1/account/subscription"
  103. apiAccountReservationPath = "/v1/account/reservation"
  104. apiAccountBillingPortalPath = "/v1/account/billing/portal"
  105. apiAccountBillingWebhookPath = "/v1/account/billing/webhook"
  106. apiAccountBillingSubscriptionPath = "/v1/account/billing/subscription"
  107. apiAccountBillingSubscriptionCheckoutSuccessTemplate = "/v1/account/billing/subscription/success/{CHECKOUT_SESSION_ID}"
  108. apiAccountBillingSubscriptionCheckoutSuccessRegex = regexp.MustCompile(`/v1/account/billing/subscription/success/(.+)$`)
  109. apiAccountReservationSingleRegex = regexp.MustCompile(`/v1/account/reservation/([-_A-Za-z0-9]{1,64})$`)
  110. apiAccountSubscriptionSingleRegex = regexp.MustCompile(`^/v1/account/subscription/([-_A-Za-z0-9]{16})$`)
  111. staticRegex = regexp.MustCompile(`^/static/.+`)
  112. docsRegex = regexp.MustCompile(`^/docs(|/.*)$`)
  113. fileRegex = regexp.MustCompile(`^/file/([-_A-Za-z0-9]{1,64})(?:\.[A-Za-z0-9]{1,16})?$`)
  114. disallowedTopics = []string{"docs", "static", "file", "app", "account", "settings", "pricing", "signup", "login", "reset-password"} // If updated, also update in Android and web app
  115. urlRegex = regexp.MustCompile(`^https?://`)
  116. //go:embed site
  117. webFs embed.FS
  118. webFsCached = &util.CachingEmbedFS{ModTime: time.Now(), FS: webFs}
  119. webSiteDir = "/site"
  120. webHomeIndex = "/home.html" // Landing page, only if "web-root: home"
  121. webAppIndex = "/app.html" // React app
  122. //go:embed docs
  123. docsStaticFs embed.FS
  124. docsStaticCached = &util.CachingEmbedFS{ModTime: time.Now(), FS: docsStaticFs}
  125. )
  const (
	// Special topics used by the mobile apps.
	firebaseControlTopic = "~control" // See Android if changed
	firebasePollTopic    = "~poll"    // See iOS if changed

	// Default message bodies.
	emptyMessageBody         = "triggered"               // Used if message body is empty
	newMessageBody           = "New message"             // Used in poll requests as generic message
	defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment

	encodingBase64     = "base64"  // Used mainly for binary UnifiedPush messages
	jsonBodyBytesLimit = 16 * 1024 // 16 KiB
)
  // WebSocket tuning constants.
const (
	wsWriteWait  = 2 * time.Second  // Write deadline for WebSocket frames
	wsBufferSize = 1 << 10          // 1024 bytes
	wsReadLimit  = 64               // We only ever receive PINGs
	wsPongWait   = 15 * time.Second // How long to wait for a pong
)
  142. // New instantiates a new Server. It creates the cache and adds a Firebase
  143. // subscriber (if configured).
  144. func New(conf *Config) (*Server, error) {
  145. var mailer mailer
  146. if conf.SMTPSenderAddr != "" {
  147. mailer = &smtpSender{config: conf}
  148. }
  149. var stripe stripeAPI
  150. if conf.StripeSecretKey != "" {
  151. stripe = newStripeAPI()
  152. }
  153. messageCache, err := createMessageCache(conf)
  154. if err != nil {
  155. return nil, err
  156. }
  157. topics, err := messageCache.Topics()
  158. if err != nil {
  159. return nil, err
  160. }
  161. var fileCache *fileCache
  162. if conf.AttachmentCacheDir != "" {
  163. fileCache, err = newFileCache(conf.AttachmentCacheDir, conf.AttachmentTotalSizeLimit)
  164. if err != nil {
  165. return nil, err
  166. }
  167. }
  168. var userManager *user.Manager
  169. if conf.AuthFile != "" {
  170. userManager, err = user.NewManager(conf.AuthFile, conf.AuthStartupQueries, conf.AuthDefault)
  171. if err != nil {
  172. return nil, err
  173. }
  174. }
  175. var firebaseClient *firebaseClient
  176. if conf.FirebaseKeyFile != "" {
  177. sender, err := newFirebaseSender(conf.FirebaseKeyFile)
  178. if err != nil {
  179. return nil, err
  180. }
  181. firebaseClient = newFirebaseClient(sender, userManager)
  182. }
  183. s := &Server{
  184. config: conf,
  185. messageCache: messageCache,
  186. fileCache: fileCache,
  187. firebaseClient: firebaseClient,
  188. smtpSender: mailer,
  189. topics: topics,
  190. userManager: userManager,
  191. visitors: make(map[string]*visitor),
  192. stripe: stripe,
  193. }
  194. s.priceCache = util.NewLookupCache(s.fetchStripePrices, conf.StripePriceCacheDuration)
  195. return s, nil
  196. }
  197. func createMessageCache(conf *Config) (*messageCache, error) {
  198. if conf.CacheDuration == 0 {
  199. return newNopCache()
  200. } else if conf.CacheFile != "" {
  201. return newSqliteCache(conf.CacheFile, conf.CacheStartupQueries, conf.CacheDuration, conf.CacheBatchSize, conf.CacheBatchTimeout, false)
  202. }
  203. return newMemCache()
  204. }
  205. // Run executes the main server. It listens on HTTP (+ HTTPS, if configured), and starts
  206. // a manager go routine to print stats and prune messages.
  207. func (s *Server) Run() error {
  208. var listenStr string
  209. if s.config.ListenHTTP != "" {
  210. listenStr += fmt.Sprintf(" %s[http]", s.config.ListenHTTP)
  211. }
  212. if s.config.ListenHTTPS != "" {
  213. listenStr += fmt.Sprintf(" %s[https]", s.config.ListenHTTPS)
  214. }
  215. if s.config.ListenUnix != "" {
  216. listenStr += fmt.Sprintf(" %s[unix]", s.config.ListenUnix)
  217. }
  218. if s.config.SMTPServerListen != "" {
  219. listenStr += fmt.Sprintf(" %s[smtp]", s.config.SMTPServerListen)
  220. }
  221. log.Info("Listening on%s, ntfy %s, log level is %s", listenStr, s.config.Version, log.CurrentLevel().String())
  222. mux := http.NewServeMux()
  223. mux.HandleFunc("/", s.handle)
  224. errChan := make(chan error)
  225. s.mu.Lock()
  226. s.closeChan = make(chan bool)
  227. if s.config.ListenHTTP != "" {
  228. s.httpServer = &http.Server{Addr: s.config.ListenHTTP, Handler: mux}
  229. go func() {
  230. errChan <- s.httpServer.ListenAndServe()
  231. }()
  232. }
  233. if s.config.ListenHTTPS != "" {
  234. s.httpsServer = &http.Server{Addr: s.config.ListenHTTPS, Handler: mux}
  235. go func() {
  236. errChan <- s.httpsServer.ListenAndServeTLS(s.config.CertFile, s.config.KeyFile)
  237. }()
  238. }
  239. if s.config.ListenUnix != "" {
  240. go func() {
  241. var err error
  242. s.mu.Lock()
  243. os.Remove(s.config.ListenUnix)
  244. s.unixListener, err = net.Listen("unix", s.config.ListenUnix)
  245. if err != nil {
  246. s.mu.Unlock()
  247. errChan <- err
  248. return
  249. }
  250. defer s.unixListener.Close()
  251. if s.config.ListenUnixMode > 0 {
  252. if err := os.Chmod(s.config.ListenUnix, s.config.ListenUnixMode); err != nil {
  253. s.mu.Unlock()
  254. errChan <- err
  255. return
  256. }
  257. }
  258. s.mu.Unlock()
  259. httpServer := &http.Server{Handler: mux}
  260. errChan <- httpServer.Serve(s.unixListener)
  261. }()
  262. }
  263. if s.config.SMTPServerListen != "" {
  264. go func() {
  265. errChan <- s.runSMTPServer()
  266. }()
  267. }
  268. s.mu.Unlock()
  269. go s.runManager()
  270. go s.runStatsResetter()
  271. go s.runDelayedSender()
  272. go s.runFirebaseKeepaliver()
  273. return <-errChan
  274. }
  275. // Stop stops HTTP (+HTTPS) server and all managers
  276. func (s *Server) Stop() {
  277. s.mu.Lock()
  278. defer s.mu.Unlock()
  279. if s.httpServer != nil {
  280. s.httpServer.Close()
  281. }
  282. if s.httpsServer != nil {
  283. s.httpsServer.Close()
  284. }
  285. if s.unixListener != nil {
  286. s.unixListener.Close()
  287. }
  288. if s.smtpServer != nil {
  289. s.smtpServer.Close()
  290. }
  291. close(s.closeChan)
  292. }
  293. func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
  294. v, err := s.visitor(r) // Note: Always returns v, even when error is returned
  295. if err == nil {
  296. log.Debug("%s Dispatching request", logHTTPPrefix(v, r))
  297. if log.IsTrace() {
  298. log.Trace("%s Entire request (headers and body):\n%s", logHTTPPrefix(v, r), renderHTTPRequest(r))
  299. }
  300. err = s.handleInternal(w, r, v)
  301. }
  302. if err != nil {
  303. if websocket.IsWebSocketUpgrade(r) {
  304. isNormalError := strings.Contains(err.Error(), "i/o timeout")
  305. if isNormalError {
  306. log.Debug("%s WebSocket error (this error is okay, it happens a lot): %s", logHTTPPrefix(v, r), err.Error())
  307. } else {
  308. log.Info("%s WebSocket error: %s", logHTTPPrefix(v, r), err.Error())
  309. }
  310. return // Do not attempt to write to upgraded connection
  311. }
  312. if matrixErr, ok := err.(*errMatrix); ok {
  313. writeMatrixError(w, r, v, matrixErr)
  314. return
  315. }
  316. httpErr, ok := err.(*errHTTP)
  317. if !ok {
  318. httpErr = errHTTPInternalError
  319. }
  320. isNormalError := httpErr.HTTPCode == http.StatusNotFound || httpErr.HTTPCode == http.StatusBadRequest
  321. if isNormalError {
  322. log.Debug("%s Connection closed with HTTP %d (ntfy error %d): %s", logHTTPPrefix(v, r), httpErr.HTTPCode, httpErr.Code, err.Error())
  323. } else {
  324. log.Info("%s Connection closed with HTTP %d (ntfy error %d): %s", logHTTPPrefix(v, r), httpErr.HTTPCode, httpErr.Code, err.Error())
  325. }
  326. w.Header().Set("Content-Type", "application/json")
  327. w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
  328. w.WriteHeader(httpErr.HTTPCode)
  329. io.WriteString(w, httpErr.JSON()+"\n")
  330. }
  331. }
  332. func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visitor) error {
  333. if r.Method == http.MethodGet && r.URL.Path == "/" {
  334. return s.ensureWebEnabled(s.handleHome)(w, r, v)
  335. } else if r.Method == http.MethodHead && r.URL.Path == "/" {
  336. return s.ensureWebEnabled(s.handleEmpty)(w, r, v)
  337. } else if r.Method == http.MethodGet && r.URL.Path == apiHealthPath {
  338. return s.handleHealth(w, r, v)
  339. } else if r.Method == http.MethodGet && r.URL.Path == webConfigPath {
  340. return s.ensureWebEnabled(s.handleWebConfig)(w, r, v)
  341. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountPath {
  342. return s.ensureUserManager(s.handleAccountCreate)(w, r, v)
  343. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountTokenPath {
  344. return s.ensureUser(s.handleAccountTokenIssue)(w, r, v)
  345. } else if r.Method == http.MethodGet && r.URL.Path == apiAccountPath {
  346. return s.handleAccountGet(w, r, v) // Allowed by anonymous
  347. } else if r.Method == http.MethodDelete && r.URL.Path == apiAccountPath {
  348. return s.ensureUser(s.withAccountSync(s.handleAccountDelete))(w, r, v)
  349. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountPasswordPath {
  350. return s.ensureUser(s.handleAccountPasswordChange)(w, r, v)
  351. } else if r.Method == http.MethodPatch && r.URL.Path == apiAccountTokenPath {
  352. return s.ensureUser(s.handleAccountTokenExtend)(w, r, v)
  353. } else if r.Method == http.MethodDelete && r.URL.Path == apiAccountTokenPath {
  354. return s.ensureUser(s.handleAccountTokenDelete)(w, r, v)
  355. } else if r.Method == http.MethodPatch && r.URL.Path == apiAccountSettingsPath {
  356. return s.ensureUser(s.withAccountSync(s.handleAccountSettingsChange))(w, r, v)
  357. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountSubscriptionPath {
  358. return s.ensureUser(s.withAccountSync(s.handleAccountSubscriptionAdd))(w, r, v)
  359. } else if r.Method == http.MethodPatch && apiAccountSubscriptionSingleRegex.MatchString(r.URL.Path) {
  360. return s.ensureUser(s.withAccountSync(s.handleAccountSubscriptionChange))(w, r, v)
  361. } else if r.Method == http.MethodDelete && apiAccountSubscriptionSingleRegex.MatchString(r.URL.Path) {
  362. return s.ensureUser(s.withAccountSync(s.handleAccountSubscriptionDelete))(w, r, v)
  363. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountReservationPath {
  364. return s.ensureUser(s.withAccountSync(s.handleAccountReservationAdd))(w, r, v)
  365. } else if r.Method == http.MethodDelete && apiAccountReservationSingleRegex.MatchString(r.URL.Path) {
  366. return s.ensureUser(s.withAccountSync(s.handleAccountReservationDelete))(w, r, v)
  367. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountBillingSubscriptionPath {
  368. return s.ensurePaymentsEnabled(s.ensureUser(s.handleAccountBillingSubscriptionCreate))(w, r, v) // Account sync via incoming Stripe webhook
  369. } else if r.Method == http.MethodGet && apiAccountBillingSubscriptionCheckoutSuccessRegex.MatchString(r.URL.Path) {
  370. return s.ensurePaymentsEnabled(s.ensureUserManager(s.handleAccountBillingSubscriptionCreateSuccess))(w, r, v) // No user context!
  371. } else if r.Method == http.MethodPut && r.URL.Path == apiAccountBillingSubscriptionPath {
  372. return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingSubscriptionUpdate))(w, r, v) // Account sync via incoming Stripe webhook
  373. } else if r.Method == http.MethodDelete && r.URL.Path == apiAccountBillingSubscriptionPath {
  374. return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingSubscriptionDelete))(w, r, v) // Account sync via incoming Stripe webhook
  375. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountBillingPortalPath {
  376. return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingPortalSessionCreate))(w, r, v)
  377. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountBillingWebhookPath {
  378. return s.ensurePaymentsEnabled(s.ensureUserManager(s.handleAccountBillingWebhook))(w, r, v) // This request comes from Stripe!
  379. } else if r.Method == http.MethodGet && r.URL.Path == apiTiers {
  380. return s.ensurePaymentsEnabled(s.handleBillingTiersGet)(w, r, v)
  381. } else if r.Method == http.MethodGet && r.URL.Path == matrixPushPath {
  382. return s.handleMatrixDiscovery(w)
  383. } else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) {
  384. return s.ensureWebEnabled(s.handleStatic)(w, r, v)
  385. } else if r.Method == http.MethodGet && docsRegex.MatchString(r.URL.Path) {
  386. return s.ensureWebEnabled(s.handleDocs)(w, r, v)
  387. } else if (r.Method == http.MethodGet || r.Method == http.MethodHead) && fileRegex.MatchString(r.URL.Path) && s.config.AttachmentCacheDir != "" {
  388. return s.limitRequests(s.handleFile)(w, r, v)
  389. } else if r.Method == http.MethodOptions {
  390. return s.ensureWebEnabled(s.handleOptions)(w, r, v)
  391. } else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && r.URL.Path == "/" {
  392. return s.limitRequests(s.transformBodyJSON(s.authorizeTopicWrite(s.handlePublish)))(w, r, v)
  393. } else if r.Method == http.MethodPost && r.URL.Path == matrixPushPath {
  394. return s.limitRequests(s.transformMatrixJSON(s.authorizeTopicWrite(s.handlePublishMatrix)))(w, r, v)
  395. } else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicPathRegex.MatchString(r.URL.Path) {
  396. return s.limitRequests(s.authorizeTopicWrite(s.handlePublish))(w, r, v)
  397. } else if r.Method == http.MethodGet && publishPathRegex.MatchString(r.URL.Path) {
  398. return s.limitRequests(s.authorizeTopicWrite(s.handlePublish))(w, r, v)
  399. } else if r.Method == http.MethodGet && jsonPathRegex.MatchString(r.URL.Path) {
  400. return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeJSON))(w, r, v)
  401. } else if r.Method == http.MethodGet && ssePathRegex.MatchString(r.URL.Path) {
  402. return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeSSE))(w, r, v)
  403. } else if r.Method == http.MethodGet && rawPathRegex.MatchString(r.URL.Path) {
  404. return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeRaw))(w, r, v)
  405. } else if r.Method == http.MethodGet && wsPathRegex.MatchString(r.URL.Path) {
  406. return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeWS))(w, r, v)
  407. } else if r.Method == http.MethodGet && authPathRegex.MatchString(r.URL.Path) {
  408. return s.limitRequests(s.authorizeTopicRead(s.handleTopicAuth))(w, r, v)
  409. } else if r.Method == http.MethodGet && (topicPathRegex.MatchString(r.URL.Path) || externalTopicPathRegex.MatchString(r.URL.Path)) {
  410. return s.ensureWebEnabled(s.handleTopic)(w, r, v)
  411. }
  412. return errHTTPNotFound
  413. }
  414. func (s *Server) handleHome(w http.ResponseWriter, r *http.Request, v *visitor) error {
  415. if s.config.WebRootIsApp {
  416. r.URL.Path = webAppIndex
  417. } else {
  418. r.URL.Path = webHomeIndex
  419. }
  420. return s.handleStatic(w, r, v)
  421. }
  422. func (s *Server) handleTopic(w http.ResponseWriter, r *http.Request, v *visitor) error {
  423. unifiedpush := readBoolParam(r, false, "x-unifiedpush", "unifiedpush", "up") // see PUT/POST too!
  424. if unifiedpush {
  425. w.Header().Set("Content-Type", "application/json")
  426. w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
  427. _, err := io.WriteString(w, `{"unifiedpush":{"version":1}}`+"\n")
  428. return err
  429. }
  430. r.URL.Path = webAppIndex
  431. return s.handleStatic(w, r, v)
  432. }
  433. func (s *Server) handleEmpty(_ http.ResponseWriter, _ *http.Request, _ *visitor) error {
  434. return nil
  435. }
  436. func (s *Server) handleTopicAuth(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
  437. return s.writeJSON(w, newSuccessResponse())
  438. }
  439. func (s *Server) handleHealth(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
  440. response := &apiHealthResponse{
  441. Healthy: true,
  442. }
  443. return s.writeJSON(w, response)
  444. }
  445. func (s *Server) handleWebConfig(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
  446. appRoot := "/"
  447. if !s.config.WebRootIsApp {
  448. appRoot = "/app"
  449. }
  450. response := &apiConfigResponse{
  451. BaseURL: "", // Will translate to window.location.origin
  452. AppRoot: appRoot,
  453. EnableLogin: s.config.EnableLogin,
  454. EnableSignup: s.config.EnableSignup,
  455. EnablePayments: s.config.StripeSecretKey != "",
  456. EnableReservations: s.config.EnableReservations,
  457. DisallowedTopics: disallowedTopics,
  458. }
  459. b, err := json.MarshalIndent(response, "", " ")
  460. if err != nil {
  461. return err
  462. }
  463. w.Header().Set("Content-Type", "text/javascript")
  464. _, err = io.WriteString(w, fmt.Sprintf("// Generated server configuration\nvar config = %s;\n", string(b)))
  465. return err
  466. }
  467. func (s *Server) handleStatic(w http.ResponseWriter, r *http.Request, _ *visitor) error {
  468. r.URL.Path = webSiteDir + r.URL.Path
  469. util.Gzip(http.FileServer(http.FS(webFsCached))).ServeHTTP(w, r)
  470. return nil
  471. }
  472. func (s *Server) handleDocs(w http.ResponseWriter, r *http.Request, _ *visitor) error {
  473. util.Gzip(http.FileServer(http.FS(docsStaticCached))).ServeHTTP(w, r)
  474. return nil
  475. }
  476. func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor) error {
  477. if s.config.AttachmentCacheDir == "" {
  478. return errHTTPInternalError
  479. }
  480. matches := fileRegex.FindStringSubmatch(r.URL.Path)
  481. if len(matches) != 2 {
  482. return errHTTPInternalErrorInvalidPath
  483. }
  484. messageID := matches[1]
  485. file := filepath.Join(s.config.AttachmentCacheDir, messageID)
  486. stat, err := os.Stat(file)
  487. if err != nil {
  488. return errHTTPNotFound
  489. }
  490. if r.Method == http.MethodGet {
  491. if err := v.BandwidthLimiter().Allow(stat.Size()); err != nil {
  492. return errHTTPTooManyRequestsLimitAttachmentBandwidth
  493. }
  494. }
  495. w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size()))
  496. w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
  497. if r.Method == http.MethodGet {
  498. f, err := os.Open(file)
  499. if err != nil {
  500. return err
  501. }
  502. defer f.Close()
  503. _, err = io.Copy(util.NewContentTypeWriter(w, r.URL.Path), f)
  504. return err
  505. }
  506. return nil
  507. }
  508. func (s *Server) handleMatrixDiscovery(w http.ResponseWriter) error {
  509. if s.config.BaseURL == "" {
  510. return errHTTPInternalErrorMissingBaseURL
  511. }
  512. return writeMatrixDiscoveryResponse(w)
  513. }
// handlePublishWithoutResponse parses and publishes a message from a publish request without
// writing anything to the response. It is shared by handlePublish and handlePublishMatrix, which
// each write their own response.
//
// Steps: resolve the topic from the path, check the visitor's message rate limit, peek at the body
// (up to MessageLimit bytes), apply the publish parameters, and decide what the body is (see
// handlePublishBody). If the message is not scheduled for the future, it is published to the
// topic right away and fanned out to Firebase, email and the upstream server in background
// goroutines. If caching is enabled (always the case for delayed messages), it is stored in the
// message cache. On success, the visitor's and the server's message counters are incremented and
// the message is returned.
func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*message, error) {
	t, err := s.topicFromPath(r.URL.Path)
	if err != nil {
		return nil, err
	}
	if err := v.MessageAllowed(); err != nil {
		return nil, errHTTPTooManyRequestsLimitMessages
	}
	// body.PeekedBytes holds at most MessageLimit bytes; the reader itself is handed on to
	// handlePublishBody, which may stream the full body into the attachment cache.
	body, err := util.Peek(r.Body, s.config.MessageLimit)
	if err != nil {
		return nil, err
	}
	m := newDefaultMessage(t.ID, "")
	cache, firebase, email, unifiedpush, err := s.parsePublishParams(r, v, m)
	if err != nil {
		return nil, err
	}
	if m.PollID != "" {
		// Poll requests replace the parsed message entirely; only the topic and poll ID are kept
		m = newPollRequestMessage(t.ID, m.PollID)
	}
	if v.user != nil {
		m.User = v.user.ID
	}
	m.Expires = time.Now().Add(v.Limits().MessagesExpiryDuration).Unix()
	if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil {
		return nil, err
	}
	if m.Message == "" {
		m.Message = emptyMessageBody
	}
	// A message time in the future means it was scheduled via the delay parameter
	delayed := m.Time > time.Now().Unix()
	log.Debug("%s Received message: event=%s, user=%s, body=%d byte(s), delayed=%t, firebase=%t, cache=%t, up=%t, email=%s",
		logMessagePrefix(v, m), m.Event, m.User, len(m.Message), delayed, firebase, cache, unifiedpush, email)
	if log.IsTrace() {
		log.Trace("%s Message body: %s", logMessagePrefix(v, m), util.MaybeMarshalJSON(m))
	}
	if !delayed {
		if err := t.Publish(v, m); err != nil {
			return nil, err
		}
		if s.firebaseClient != nil && firebase {
			go s.sendToFirebase(v, m)
		}
		if s.smtpSender != nil && email != "" {
			// Counted before sending, so the counter is bumped even if delivery later fails
			v.IncrementEmails()
			go s.sendEmail(v, m, email)
		}
		if s.config.UpstreamBaseURL != "" {
			go s.forwardPollRequest(v, m)
		}
	} else {
		// Delayed messages are only cached here; they are published later by the delayed sender
		log.Debug("%s Message delayed, will process later", logMessagePrefix(v, m))
	}
	if cache {
		log.Debug("%s Adding message to cache", logMessagePrefix(v, m))
		if err := s.messageCache.AddMessage(m); err != nil {
			return nil, err
		}
	}
	v.IncrementMessages()
	if s.userManager != nil && v.user != nil {
		s.userManager.EnqueueStats(v.user)
	}
	s.mu.Lock()
	s.messages++
	s.mu.Unlock()
	return m, nil
}
  582. func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visitor) error {
  583. m, err := s.handlePublishWithoutResponse(r, v)
  584. if err != nil {
  585. return err
  586. }
  587. return s.writeJSON(w, m)
  588. }
  589. func (s *Server) handlePublishMatrix(w http.ResponseWriter, r *http.Request, v *visitor) error {
  590. _, err := s.handlePublishWithoutResponse(r, v)
  591. if err != nil {
  592. return &errMatrix{pushKey: r.Header.Get(matrixPushKeyHeader), err: err}
  593. }
  594. return writeMatrixSuccess(w)
  595. }
  596. func (s *Server) sendToFirebase(v *visitor, m *message) {
  597. log.Debug("%s Publishing to Firebase", logMessagePrefix(v, m))
  598. if err := s.firebaseClient.Send(v, m); err != nil {
  599. if err == errFirebaseTemporarilyBanned {
  600. log.Debug("%s Unable to publish to Firebase: %v", logMessagePrefix(v, m), err.Error())
  601. } else {
  602. log.Warn("%s Unable to publish to Firebase: %v", logMessagePrefix(v, m), err.Error())
  603. }
  604. }
  605. }
  606. func (s *Server) sendEmail(v *visitor, m *message, email string) {
  607. log.Debug("%s Sending email to %s", logMessagePrefix(v, m), email)
  608. if err := s.smtpSender.Send(v, m, email); err != nil {
  609. log.Warn("%s Unable to send email to %s: %v", logMessagePrefix(v, m), email, err.Error())
  610. }
  611. }
  612. func (s *Server) forwardPollRequest(v *visitor, m *message) {
  613. topicURL := fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic)
  614. topicHash := fmt.Sprintf("%x", sha256.Sum256([]byte(topicURL)))
  615. forwardURL := fmt.Sprintf("%s/%s", s.config.UpstreamBaseURL, topicHash)
  616. log.Debug("%s Publishing poll request to %s", logMessagePrefix(v, m), forwardURL)
  617. req, err := http.NewRequest("POST", forwardURL, strings.NewReader(""))
  618. if err != nil {
  619. log.Warn("%s Unable to publish poll request: %v", logMessagePrefix(v, m), err.Error())
  620. return
  621. }
  622. req.Header.Set("X-Poll-ID", m.ID)
  623. var httpClient = &http.Client{
  624. Timeout: time.Second * 10,
  625. }
  626. response, err := httpClient.Do(req)
  627. if err != nil {
  628. log.Warn("%s Unable to publish poll request: %v", logMessagePrefix(v, m), err.Error())
  629. return
  630. } else if response.StatusCode != http.StatusOK {
  631. log.Warn("%s Unable to publish poll request, unexpected HTTP status: %d", logMessagePrefix(v, m), response.StatusCode)
  632. return
  633. }
  634. }
  635. func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (cache bool, firebase bool, email string, unifiedpush bool, err error) {
  636. cache = readBoolParam(r, true, "x-cache", "cache")
  637. firebase = readBoolParam(r, true, "x-firebase", "firebase")
  638. m.Title = readParam(r, "x-title", "title", "t")
  639. m.Click = readParam(r, "x-click", "click")
  640. icon := readParam(r, "x-icon", "icon")
  641. filename := readParam(r, "x-filename", "filename", "file", "f")
  642. attach := readParam(r, "x-attach", "attach", "a")
  643. if attach != "" || filename != "" {
  644. m.Attachment = &attachment{}
  645. }
  646. if filename != "" {
  647. m.Attachment.Name = filename
  648. }
  649. if attach != "" {
  650. if !urlRegex.MatchString(attach) {
  651. return false, false, "", false, errHTTPBadRequestAttachmentURLInvalid
  652. }
  653. m.Attachment.URL = attach
  654. if m.Attachment.Name == "" {
  655. u, err := url.Parse(m.Attachment.URL)
  656. if err == nil {
  657. m.Attachment.Name = path.Base(u.Path)
  658. if m.Attachment.Name == "." || m.Attachment.Name == "/" {
  659. m.Attachment.Name = ""
  660. }
  661. }
  662. }
  663. if m.Attachment.Name == "" {
  664. m.Attachment.Name = "attachment"
  665. }
  666. }
  667. if icon != "" {
  668. if !urlRegex.MatchString(icon) {
  669. return false, false, "", false, errHTTPBadRequestIconURLInvalid
  670. }
  671. m.Icon = icon
  672. }
  673. email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e")
  674. if email != "" {
  675. if err := v.EmailAllowed(); err != nil {
  676. return false, false, "", false, errHTTPTooManyRequestsLimitEmails
  677. }
  678. }
  679. if s.smtpSender == nil && email != "" {
  680. return false, false, "", false, errHTTPBadRequestEmailDisabled
  681. }
  682. messageStr := strings.ReplaceAll(readParam(r, "x-message", "message", "m"), "\\n", "\n")
  683. if messageStr != "" {
  684. m.Message = messageStr
  685. }
  686. m.Priority, err = util.ParsePriority(readParam(r, "x-priority", "priority", "prio", "p"))
  687. if err != nil {
  688. return false, false, "", false, errHTTPBadRequestPriorityInvalid
  689. }
  690. tagsStr := readParam(r, "x-tags", "tags", "tag", "ta")
  691. if tagsStr != "" {
  692. m.Tags = make([]string, 0)
  693. for _, s := range util.SplitNoEmpty(tagsStr, ",") {
  694. m.Tags = append(m.Tags, strings.TrimSpace(s))
  695. }
  696. }
  697. delayStr := readParam(r, "x-delay", "delay", "x-at", "at", "x-in", "in")
  698. if delayStr != "" {
  699. if !cache {
  700. return false, false, "", false, errHTTPBadRequestDelayNoCache
  701. }
  702. if email != "" {
  703. return false, false, "", false, errHTTPBadRequestDelayNoEmail // we cannot store the email address (yet)
  704. }
  705. delay, err := util.ParseFutureTime(delayStr, time.Now())
  706. if err != nil {
  707. return false, false, "", false, errHTTPBadRequestDelayCannotParse
  708. } else if delay.Unix() < time.Now().Add(s.config.MinDelay).Unix() {
  709. return false, false, "", false, errHTTPBadRequestDelayTooSmall
  710. } else if delay.Unix() > time.Now().Add(s.config.MaxDelay).Unix() {
  711. return false, false, "", false, errHTTPBadRequestDelayTooLarge
  712. }
  713. m.Time = delay.Unix()
  714. m.Sender = v.ip // Important for rate limiting
  715. }
  716. actionsStr := readParam(r, "x-actions", "actions", "action")
  717. if actionsStr != "" {
  718. m.Actions, err = parseActions(actionsStr)
  719. if err != nil {
  720. return false, false, "", false, wrapErrHTTP(errHTTPBadRequestActionsInvalid, err.Error())
  721. }
  722. }
  723. unifiedpush = readBoolParam(r, false, "x-unifiedpush", "unifiedpush", "up") // see GET too!
  724. if unifiedpush {
  725. firebase = false
  726. unifiedpush = true
  727. }
  728. m.PollID = readParam(r, "x-poll-id", "poll-id")
  729. if m.PollID != "" {
  730. unifiedpush = false
  731. cache = false
  732. email = ""
  733. }
  734. return cache, firebase, email, unifiedpush, nil
  735. }
  736. // handlePublishBody consumes the PUT/POST body and decides whether the body is an attachment or the message.
  737. //
  738. // 1. curl -X POST -H "Poll: 1234" ntfy.sh/...
  739. // If a message is flagged as poll request, the body does not matter and is discarded
  740. // 2. curl -T somebinarydata.bin "ntfy.sh/mytopic?up=1"
  741. // If body is binary, encode as base64, if not do not encode
  742. // 3. curl -H "Attach: http://example.com/file.jpg" ntfy.sh/mytopic
  743. // Body must be a message, because we attached an external URL
  744. // 4. curl -T short.txt -H "Filename: short.txt" ntfy.sh/mytopic
  745. // Body must be attachment, because we passed a filename
  746. // 5. curl -T file.txt ntfy.sh/mytopic
  747. // If file.txt is <= 4096 (message limit) and valid UTF-8, treat it as a message
  748. // 6. curl -T file.txt ntfy.sh/mytopic
  749. // If file.txt is > message limit, treat it as an attachment
  750. func (s *Server) handlePublishBody(r *http.Request, v *visitor, m *message, body *util.PeekedReadCloser, unifiedpush bool) error {
  751. if m.Event == pollRequestEvent { // Case 1
  752. return s.handleBodyDiscard(body)
  753. } else if unifiedpush {
  754. return s.handleBodyAsMessageAutoDetect(m, body) // Case 2
  755. } else if m.Attachment != nil && m.Attachment.URL != "" {
  756. return s.handleBodyAsTextMessage(m, body) // Case 3
  757. } else if m.Attachment != nil && m.Attachment.Name != "" {
  758. return s.handleBodyAsAttachment(r, v, m, body) // Case 4
  759. } else if !body.LimitReached && utf8.Valid(body.PeekedBytes) {
  760. return s.handleBodyAsTextMessage(m, body) // Case 5
  761. }
  762. return s.handleBodyAsAttachment(r, v, m, body) // Case 6
  763. }
  764. func (s *Server) handleBodyDiscard(body *util.PeekedReadCloser) error {
  765. _, err := io.Copy(io.Discard, body)
  766. _ = body.Close()
  767. return err
  768. }
  769. func (s *Server) handleBodyAsMessageAutoDetect(m *message, body *util.PeekedReadCloser) error {
  770. if utf8.Valid(body.PeekedBytes) {
  771. m.Message = string(body.PeekedBytes) // Do not trim
  772. } else {
  773. m.Message = base64.StdEncoding.EncodeToString(body.PeekedBytes)
  774. m.Encoding = encodingBase64
  775. }
  776. return nil
  777. }
  778. func (s *Server) handleBodyAsTextMessage(m *message, body *util.PeekedReadCloser) error {
  779. if !utf8.Valid(body.PeekedBytes) {
  780. return errHTTPBadRequestMessageNotUTF8
  781. }
  782. if len(body.PeekedBytes) > 0 { // Empty body should not override message (publish via GET!)
  783. m.Message = strings.TrimSpace(string(body.PeekedBytes)) // Truncates the message to the peek limit if required
  784. }
  785. if m.Attachment != nil && m.Attachment.Name != "" && m.Message == "" {
  786. m.Message = fmt.Sprintf(defaultAttachmentMessage, m.Attachment.Name)
  787. }
  788. return nil
  789. }
  790. func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message, body *util.PeekedReadCloser) error {
  791. if s.fileCache == nil || s.config.BaseURL == "" || s.config.AttachmentCacheDir == "" {
  792. return errHTTPBadRequestAttachmentsDisallowed
  793. }
  794. vinfo, err := v.Info()
  795. if err != nil {
  796. return err
  797. }
  798. attachmentExpiry := time.Now().Add(vinfo.Limits.AttachmentExpiryDuration).Unix()
  799. if m.Time > attachmentExpiry {
  800. return errHTTPBadRequestAttachmentsExpiryBeforeDelivery
  801. }
  802. contentLengthStr := r.Header.Get("Content-Length")
  803. if contentLengthStr != "" { // Early "do-not-trust" check, hard limit see below
  804. contentLength, err := strconv.ParseInt(contentLengthStr, 10, 64)
  805. if err == nil && (contentLength > vinfo.Stats.AttachmentTotalSizeRemaining || contentLength > vinfo.Limits.AttachmentFileSizeLimit) {
  806. return errHTTPEntityTooLargeAttachment
  807. }
  808. }
  809. if m.Attachment == nil {
  810. m.Attachment = &attachment{}
  811. }
  812. var ext string
  813. m.Sender = v.ip // Important for attachment rate limiting
  814. m.Attachment.Expires = attachmentExpiry
  815. m.Attachment.Type, ext = util.DetectContentType(body.PeekedBytes, m.Attachment.Name)
  816. m.Attachment.URL = fmt.Sprintf("%s/file/%s%s", s.config.BaseURL, m.ID, ext)
  817. if m.Attachment.Name == "" {
  818. m.Attachment.Name = fmt.Sprintf("attachment%s", ext)
  819. }
  820. if m.Message == "" {
  821. m.Message = fmt.Sprintf(defaultAttachmentMessage, m.Attachment.Name)
  822. }
  823. limiters := []util.Limiter{
  824. v.BandwidthLimiter(),
  825. util.NewFixedLimiter(vinfo.Limits.AttachmentFileSizeLimit),
  826. util.NewFixedLimiter(vinfo.Stats.AttachmentTotalSizeRemaining),
  827. }
  828. fmt.Printf("limiters = %#v\nv = %#v\n", limiters, v)
  829. m.Attachment.Size, err = s.fileCache.Write(m.ID, body, limiters...)
  830. if err == util.ErrLimitReached {
  831. return errHTTPEntityTooLargeAttachment
  832. } else if err != nil {
  833. return err
  834. }
  835. return nil
  836. }
  837. func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request, v *visitor) error {
  838. encoder := func(msg *message) (string, error) {
  839. var buf bytes.Buffer
  840. if err := json.NewEncoder(&buf).Encode(&msg); err != nil {
  841. return "", err
  842. }
  843. return buf.String(), nil
  844. }
  845. return s.handleSubscribeHTTP(w, r, v, "application/x-ndjson", encoder)
  846. }
  847. func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request, v *visitor) error {
  848. encoder := func(msg *message) (string, error) {
  849. var buf bytes.Buffer
  850. if err := json.NewEncoder(&buf).Encode(&msg); err != nil {
  851. return "", err
  852. }
  853. if msg.Event != messageEvent {
  854. return fmt.Sprintf("event: %s\ndata: %s\n", msg.Event, buf.String()), nil // Browser's .onmessage() does not fire on this!
  855. }
  856. return fmt.Sprintf("data: %s\n", buf.String()), nil
  857. }
  858. return s.handleSubscribeHTTP(w, r, v, "text/event-stream", encoder)
  859. }
  860. func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request, v *visitor) error {
  861. encoder := func(msg *message) (string, error) {
  862. if msg.Event == messageEvent { // only handle default events
  863. return strings.ReplaceAll(msg.Message, "\n", " ") + "\n", nil
  864. }
  865. return "\n", nil // "keepalive" and "open" events just send an empty line
  866. }
  867. return s.handleSubscribeHTTP(w, r, v, "text/plain", encoder)
  868. }
// handleSubscribeHTTP implements the streaming HTTP subscription endpoints (JSON, SSE, raw), which
// differ only in contentType and in how encoder serializes each message.
//
// It checks the visitor's subscription limit, resolves all (comma-separated) topics from the path
// and parses the subscribe parameters. With poll enabled, it only writes the cached messages
// selected by "since" and returns. Otherwise it subscribes to every topic, writes an "open"
// message followed by the selected cached messages, and then keeps the stream open, sending a
// keepalive message every KeepaliveInterval, until the subscription is canceled or the client's
// request context is done.
func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *visitor, contentType string, encoder messageEncoder) error {
	log.Debug("%s HTTP stream connection opened", logHTTPPrefix(v, r))
	defer log.Debug("%s HTTP stream connection closed", logHTTPPrefix(v, r))
	if err := v.SubscriptionAllowed(); err != nil {
		return errHTTPTooManyRequestsLimitSubscriptions
	}
	defer v.RemoveSubscription()
	topics, topicsStr, err := s.topicsFromPath(r.URL.Path)
	if err != nil {
		return err
	}
	poll, since, scheduled, filters, err := parseSubscribeParams(r)
	if err != nil {
		return err
	}
	// wlock serializes writes to w: sub is registered with each topic via t.Subscribe (and is
	// presumably invoked from publishing goroutines) and is also called here for open/keepalive
	// messages.
	var wlock sync.Mutex
	defer func() {
		// Hack: This is the fix for a horrible data race that I have not been able to figure out in quite some time.
		// It appears to be happening when the Go HTTP code reads from the socket when closing the request (i.e. AFTER
		// this function returns), and causes a data race with the ResponseWriter. Locking wlock here silences the
		// data race detector. See https://github.com/binwiederhier/ntfy/issues/338#issuecomment-1163425889.
		wlock.TryLock()
	}()
	// sub writes one message to the stream if it passes the query filters, flushing immediately
	// so that the client receives it without buffering delay.
	sub := func(v *visitor, msg *message) error {
		if !filters.Pass(msg) {
			return nil
		}
		m, err := encoder(msg)
		if err != nil {
			return err
		}
		wlock.Lock()
		defer wlock.Unlock()
		if _, err := w.Write([]byte(m)); err != nil {
			return err
		}
		if fl, ok := w.(http.Flusher); ok {
			fl.Flush()
		}
		return nil
	}
	w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
	w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset!
	if poll {
		return s.sendOldMessages(topics, since, scheduled, v, sub)
	}
	// cancel is handed to every topic so that the subscription can be canceled externally, which
	// ends the loop below via ctx.Done()
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	subscriberIDs := make([]int, 0)
	for _, t := range topics {
		subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v.MaybeUserID(), cancel))
	}
	defer func() {
		// subscriberIDs[i] was returned by topics[i].Subscribe, so the indexes line up
		for i, subscriberID := range subscriberIDs {
			topics[i].Unsubscribe(subscriberID) // Order!
		}
	}()
	if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message
		return err
	}
	if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
		return err
	}
	for {
		select {
		case <-ctx.Done():
			return nil
		case <-r.Context().Done():
			return nil
		case <-time.After(s.config.KeepaliveInterval):
			log.Trace("%s Sending keepalive message", logHTTPPrefix(v, r))
			v.Keepalive()
			if err := sub(v, newKeepaliveMessage(topicsStr)); err != nil { // Send keepalive message
				return err
			}
		}
	}
}
// handleSubscribeWS implements the WebSocket subscription endpoint.
//
// After the usual checks (Upgrade header, subscription limit, topics, subscribe parameters), the
// connection is upgraded and three goroutines run in an errgroup:
//   - one waits for the subscription to be canceled externally and then closes the connection,
//   - one reads from the connection (to process pongs and detect closure), extending the read
//     deadline on every pong,
//   - one sends a WebSocket ping every KeepaliveInterval.
//
// With poll enabled, only the cached messages selected by "since" are written. Otherwise the
// visitor is subscribed to all topics, an "open" message and the selected cached messages are
// sent, and the function blocks until one of the goroutines fails. Normal, going-away and
// abnormal (1006) closures are not reported as errors.
func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *visitor) error {
	if strings.ToLower(r.Header.Get("Upgrade")) != "websocket" {
		return errHTTPBadRequestWebSocketsUpgradeHeaderMissing
	}
	if err := v.SubscriptionAllowed(); err != nil {
		return errHTTPTooManyRequestsLimitSubscriptions
	}
	defer v.RemoveSubscription()
	log.Debug("%s WebSocket connection opened", logHTTPPrefix(v, r))
	defer log.Debug("%s WebSocket connection closed", logHTTPPrefix(v, r))
	topics, topicsStr, err := s.topicsFromPath(r.URL.Path)
	if err != nil {
		return err
	}
	poll, since, scheduled, filters, err := parseSubscribeParams(r)
	if err != nil {
		return err
	}
	upgrader := &websocket.Upgrader{
		ReadBufferSize:  wsBufferSize,
		WriteBufferSize: wsBufferSize,
		CheckOrigin: func(r *http.Request) bool {
			return true // We're open for business!
		},
	}
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		return err
	}
	defer conn.Close()
	// Subscription connections can be canceled externally, see topic.CancelSubscribers
	subscriberContext, cancel := context.WithCancel(context.Background())
	defer cancel()
	// Use errgroup to run WebSocket reader and writer in Go routines
	// wlock serializes all writes to conn (pings and messages)
	var wlock sync.Mutex
	g, gctx := errgroup.WithContext(context.Background())
	g.Go(func() error {
		// Returning an error here cancels gctx, which stops the other goroutines
		<-subscriberContext.Done()
		log.Trace("%s Cancel received, closing subscriber connection", logHTTPPrefix(v, r))
		conn.Close()
		return &websocket.CloseError{Code: websocket.CloseNormalClosure, Text: "subscription was canceled"}
	})
	g.Go(func() error {
		// Reader: if no pong arrives within pongWait, the read deadline expires and NextReader fails
		pongWait := s.config.KeepaliveInterval + wsPongWait
		conn.SetReadLimit(wsReadLimit)
		if err := conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil {
			return err
		}
		conn.SetPongHandler(func(appData string) error {
			log.Trace("%s Received WebSocket pong", logHTTPPrefix(v, r))
			return conn.SetReadDeadline(time.Now().Add(pongWait))
		})
		for {
			_, _, err := conn.NextReader()
			if err != nil {
				return err
			}
			select {
			case <-gctx.Done():
				return nil
			default:
			}
		}
	})
	g.Go(func() error {
		// Writer: sends a ping every KeepaliveInterval until the group is canceled
		ping := func() error {
			wlock.Lock()
			defer wlock.Unlock()
			if err := conn.SetWriteDeadline(time.Now().Add(wsWriteWait)); err != nil {
				return err
			}
			log.Trace("%s Sending WebSocket ping", logHTTPPrefix(v, r))
			return conn.WriteMessage(websocket.PingMessage, nil)
		}
		for {
			select {
			case <-gctx.Done():
				return nil
			case <-time.After(s.config.KeepaliveInterval):
				v.Keepalive()
				if err := ping(); err != nil {
					return err
				}
			}
		}
	})
	// sub writes one message as JSON if it passes the query filters
	sub := func(v *visitor, msg *message) error {
		if !filters.Pass(msg) {
			return nil
		}
		wlock.Lock()
		defer wlock.Unlock()
		if err := conn.SetWriteDeadline(time.Now().Add(wsWriteWait)); err != nil {
			return err
		}
		return conn.WriteJSON(msg)
	}
	// NOTE(review): this runs after upgrader.Upgrade, which has already written the handshake
	// response; presumably the header has no effect here and would have to be passed as the
	// responseHeader argument of Upgrade instead — TODO confirm.
	w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
	if poll {
		// NOTE(review): g.Wait() is not called on this path; the goroutines wind down after the
		// deferred cancel/conn.Close run.
		return s.sendOldMessages(topics, since, scheduled, v, sub)
	}
	subscriberIDs := make([]int, 0)
	for _, t := range topics {
		subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v.MaybeUserID(), cancel))
	}
	defer func() {
		// subscriberIDs[i] was returned by topics[i].Subscribe, so the indexes line up
		for i, subscriberID := range subscriberIDs {
			topics[i].Unsubscribe(subscriberID) // Order!
		}
	}()
	if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message
		return err
	}
	if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
		return err
	}
	err = g.Wait()
	if err != nil && websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
		log.Trace("%s WebSocket connection closed: %s", logHTTPPrefix(v, r), err.Error())
		return nil // Normal closures are not errors; note: "1006 (abnormal closure)" is treated as normal, because people disconnect a lot
	}
	return err
}
  1070. func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, scheduled bool, filters *queryFilter, err error) {
  1071. poll = readBoolParam(r, false, "x-poll", "poll", "po")
  1072. scheduled = readBoolParam(r, false, "x-scheduled", "scheduled", "sched")
  1073. since, err = parseSince(r, poll)
  1074. if err != nil {
  1075. return
  1076. }
  1077. filters, err = parseQueryFilters(r)
  1078. if err != nil {
  1079. return
  1080. }
  1081. return
  1082. }
  1083. // sendOldMessages selects old messages from the messageCache and calls sub for each of them. It uses since as the
  1084. // marker, returning only messages that are newer than the marker.
  1085. func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, v *visitor, sub subscriber) error {
  1086. if since.IsNone() {
  1087. return nil
  1088. }
  1089. messages := make([]*message, 0)
  1090. for _, t := range topics {
  1091. topicMessages, err := s.messageCache.Messages(t.ID, since, scheduled)
  1092. if err != nil {
  1093. return err
  1094. }
  1095. messages = append(messages, topicMessages...)
  1096. }
  1097. sort.Slice(messages, func(i, j int) bool {
  1098. return messages[i].Time < messages[j].Time
  1099. })
  1100. for _, m := range messages {
  1101. if err := sub(v, m); err != nil {
  1102. return err
  1103. }
  1104. }
  1105. return nil
  1106. }
  1107. // parseSince returns a timestamp identifying the time span from which cached messages should be received.
  1108. //
  1109. // Values in the "since=..." parameter can be either a unix timestamp or a duration (e.g. 12h), or
  1110. // "all" for all messages.
  1111. func parseSince(r *http.Request, poll bool) (sinceMarker, error) {
  1112. since := readParam(r, "x-since", "since", "si")
  1113. // Easy cases (empty, all, none)
  1114. if since == "" {
  1115. if poll {
  1116. return sinceAllMessages, nil
  1117. }
  1118. return sinceNoMessages, nil
  1119. } else if since == "all" {
  1120. return sinceAllMessages, nil
  1121. } else if since == "none" {
  1122. return sinceNoMessages, nil
  1123. }
  1124. // ID, timestamp, duration
  1125. if validMessageID(since) {
  1126. return newSinceID(since), nil
  1127. } else if s, err := strconv.ParseInt(since, 10, 64); err == nil {
  1128. return newSinceTime(s), nil
  1129. } else if d, err := time.ParseDuration(since); err == nil {
  1130. return newSinceTime(time.Now().Add(-1 * d).Unix()), nil
  1131. }
  1132. return sinceNoMessages, errHTTPBadRequestSinceInvalid
  1133. }
  1134. func (s *Server) handleOptions(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
  1135. w.Header().Set("Access-Control-Allow-Methods", "GET, PUT, POST, PATCH, DELETE")
  1136. w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
  1137. w.Header().Set("Access-Control-Allow-Headers", "*") // CORS, allow auth via JS // FIXME is this terrible?
  1138. return nil
  1139. }
  1140. func (s *Server) topicFromPath(path string) (*topic, error) {
  1141. parts := strings.Split(path, "/")
  1142. if len(parts) < 2 {
  1143. return nil, errHTTPBadRequestTopicInvalid
  1144. }
  1145. return s.topicFromID(parts[1])
  1146. }
  1147. func (s *Server) topicsFromPath(path string) ([]*topic, string, error) {
  1148. parts := strings.Split(path, "/")
  1149. if len(parts) < 2 {
  1150. return nil, "", errHTTPBadRequestTopicInvalid
  1151. }
  1152. topicIDs := util.SplitNoEmpty(parts[1], ",")
  1153. topics, err := s.topicsFromIDs(topicIDs...)
  1154. if err != nil {
  1155. return nil, "", errHTTPBadRequestTopicInvalid
  1156. }
  1157. return topics, parts[1], nil
  1158. }
  1159. func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) {
  1160. s.mu.Lock()
  1161. defer s.mu.Unlock()
  1162. topics := make([]*topic, 0)
  1163. for _, id := range ids {
  1164. if util.Contains(disallowedTopics, id) {
  1165. return nil, errHTTPBadRequestTopicDisallowed
  1166. }
  1167. if _, ok := s.topics[id]; !ok {
  1168. if len(s.topics) >= s.config.TotalTopicLimit {
  1169. return nil, errHTTPTooManyRequestsLimitTotalTopics
  1170. }
  1171. s.topics[id] = newTopic(id)
  1172. }
  1173. topics = append(topics, s.topics[id])
  1174. }
  1175. return topics, nil
  1176. }
  1177. func (s *Server) topicFromID(id string) (*topic, error) {
  1178. topics, err := s.topicsFromIDs(id)
  1179. if err != nil {
  1180. return nil, err
  1181. }
  1182. return topics[0], nil
  1183. }
  1184. func (s *Server) execManager() {
  1185. log.Debug("Manager: Starting")
  1186. defer log.Debug("Manager: Finished")
  1187. // WARNING: Make sure to only selectively lock with the mutex, and be aware that this
  1188. // there is no mutex for the entire function.
  1189. // Expire visitors from rate visitors map
  1190. s.mu.Lock()
  1191. staleVisitors := 0
  1192. for ip, v := range s.visitors {
  1193. if v.Stale() {
  1194. log.Trace("Deleting stale visitor %s", v.ip)
  1195. delete(s.visitors, ip)
  1196. staleVisitors++
  1197. }
  1198. }
  1199. s.mu.Unlock()
  1200. log.Debug("Manager: Deleted %d stale visitor(s)", staleVisitors)
  1201. // Delete expired user tokens and users
  1202. if s.userManager != nil {
  1203. if err := s.userManager.RemoveExpiredTokens(); err != nil {
  1204. log.Warn("Error expiring user tokens: %s", err.Error())
  1205. }
  1206. if err := s.userManager.RemoveDeletedUsers(); err != nil {
  1207. log.Warn("Error deleting soft-deleted users: %s", err.Error())
  1208. }
  1209. }
  1210. // Delete expired attachments
  1211. if s.fileCache != nil {
  1212. ids, err := s.messageCache.AttachmentsExpired()
  1213. if err != nil {
  1214. log.Warn("Manager: Error retrieving expired attachments: %s", err.Error())
  1215. } else if len(ids) > 0 {
  1216. if log.IsDebug() {
  1217. log.Debug("Manager: Deleting attachments %s", strings.Join(ids, ", "))
  1218. }
  1219. if err := s.fileCache.Remove(ids...); err != nil {
  1220. log.Warn("Manager: Error deleting attachments: %s", err.Error())
  1221. }
  1222. if err := s.messageCache.MarkAttachmentsDeleted(ids...); err != nil {
  1223. log.Warn("Manager: Error marking attachments deleted: %s", err.Error())
  1224. }
  1225. } else {
  1226. log.Debug("Manager: No expired attachments to delete")
  1227. }
  1228. }
  1229. // Prune messages
  1230. log.Debug("Manager: Pruning messages")
  1231. expiredMessageIDs, err := s.messageCache.MessagesExpired()
  1232. if err != nil {
  1233. log.Warn("Manager: Error retrieving expired messages: %s", err.Error())
  1234. } else if len(expiredMessageIDs) > 0 {
  1235. if err := s.fileCache.Remove(expiredMessageIDs...); err != nil {
  1236. log.Warn("Manager: Error deleting attachments for expired messages: %s", err.Error())
  1237. }
  1238. if err := s.messageCache.DeleteMessages(expiredMessageIDs...); err != nil {
  1239. log.Warn("Manager: Error marking attachments deleted: %s", err.Error())
  1240. }
  1241. } else {
  1242. log.Debug("Manager: No expired messages to delete")
  1243. }
  1244. // Message count per topic
  1245. var messages int
  1246. messageCounts, err := s.messageCache.MessageCounts()
  1247. if err != nil {
  1248. log.Warn("Manager: Cannot get message counts: %s", err.Error())
  1249. messageCounts = make(map[string]int) // Empty, so we can continue
  1250. }
  1251. for _, count := range messageCounts {
  1252. messages += count
  1253. }
  1254. // Remove subscriptions without subscribers
  1255. s.mu.Lock()
  1256. var subscribers int
  1257. for _, t := range s.topics {
  1258. subs := t.SubscribersCount()
  1259. msgs, exists := messageCounts[t.ID]
  1260. if subs == 0 && (!exists || msgs == 0) {
  1261. log.Trace("Deleting empty topic %s", t.ID)
  1262. delete(s.topics, t.ID)
  1263. continue
  1264. }
  1265. subscribers += subs
  1266. }
  1267. s.mu.Unlock()
  1268. // Mail stats
  1269. var receivedMailTotal, receivedMailSuccess, receivedMailFailure int64
  1270. if s.smtpServerBackend != nil {
  1271. receivedMailTotal, receivedMailSuccess, receivedMailFailure = s.smtpServerBackend.Counts()
  1272. }
  1273. var sentMailTotal, sentMailSuccess, sentMailFailure int64
  1274. if s.smtpSender != nil {
  1275. sentMailTotal, sentMailSuccess, sentMailFailure = s.smtpSender.Counts()
  1276. }
  1277. // Print stats
  1278. s.mu.Lock()
  1279. messagesCount, topicsCount, visitorsCount := s.messages, len(s.topics), len(s.visitors)
  1280. s.mu.Unlock()
  1281. log.Info("Stats: %d messages published, %d in cache, %d topic(s) active, %d subscriber(s), %d visitor(s), %d mails received (%d successful, %d failed), %d mails sent (%d successful, %d failed)",
  1282. messagesCount, messages, topicsCount, subscribers, visitorsCount,
  1283. receivedMailTotal, receivedMailSuccess, receivedMailFailure,
  1284. sentMailTotal, sentMailSuccess, sentMailFailure)
  1285. }
  1286. func (s *Server) runSMTPServer() error {
  1287. s.smtpServerBackend = newMailBackend(s.config, s.handle)
  1288. s.smtpServer = smtp.NewServer(s.smtpServerBackend)
  1289. s.smtpServer.Addr = s.config.SMTPServerListen
  1290. s.smtpServer.Domain = s.config.SMTPServerDomain
  1291. s.smtpServer.ReadTimeout = 10 * time.Second
  1292. s.smtpServer.WriteTimeout = 10 * time.Second
  1293. s.smtpServer.MaxMessageBytes = 1024 * 1024 // Must be much larger than message size (headers, multipart, etc.)
  1294. s.smtpServer.MaxRecipients = 1
  1295. s.smtpServer.AllowInsecureAuth = true
  1296. return s.smtpServer.ListenAndServe()
  1297. }
  1298. func (s *Server) runManager() {
  1299. for {
  1300. select {
  1301. case <-time.After(s.config.ManagerInterval):
  1302. s.execManager()
  1303. case <-s.closeChan:
  1304. return
  1305. }
  1306. }
  1307. }
  1308. // runStatsResetter runs once a day (usually midnight UTC) to reset all the visitor's message and
  1309. // email counters. The stats are used to display the counters in the web app, as well as for rate limiting.
  1310. func (s *Server) runStatsResetter() {
  1311. for {
  1312. runAt := util.NextOccurrenceUTC(s.config.VisitorStatsResetTime, time.Now())
  1313. timer := time.NewTimer(time.Until(runAt))
  1314. log.Debug("Stats resetter: Waiting until %v to reset visitor stats", runAt)
  1315. select {
  1316. case <-timer.C:
  1317. s.resetStats()
  1318. case <-s.closeChan:
  1319. timer.Stop()
  1320. return
  1321. }
  1322. }
  1323. }
  1324. func (s *Server) resetStats() {
  1325. log.Info("Resetting all visitor stats (daily task)")
  1326. s.mu.Lock()
  1327. defer s.mu.Unlock() // Includes the database query to avoid races with other processes
  1328. for _, v := range s.visitors {
  1329. v.ResetStats()
  1330. }
  1331. if s.userManager != nil {
  1332. if err := s.userManager.ResetStats(); err != nil {
  1333. log.Warn("Failed to write to database: %s", err.Error())
  1334. }
  1335. }
  1336. }
  1337. func (s *Server) runFirebaseKeepaliver() {
  1338. if s.firebaseClient == nil {
  1339. return
  1340. }
  1341. v := newVisitor(s.config, s.messageCache, s.userManager, netip.IPv4Unspecified(), nil) // Background process, not a real visitor, uses IP 0.0.0.0
  1342. for {
  1343. select {
  1344. case <-time.After(s.config.FirebaseKeepaliveInterval):
  1345. s.sendToFirebase(v, newKeepaliveMessage(firebaseControlTopic))
  1346. case <-time.After(s.config.FirebasePollInterval):
  1347. s.sendToFirebase(v, newKeepaliveMessage(firebasePollTopic))
  1348. case <-s.closeChan:
  1349. return
  1350. }
  1351. }
  1352. }
  1353. func (s *Server) runDelayedSender() {
  1354. for {
  1355. select {
  1356. case <-time.After(s.config.DelayedSenderInterval):
  1357. if err := s.sendDelayedMessages(); err != nil {
  1358. log.Warn("Error sending delayed messages: %s", err.Error())
  1359. }
  1360. case <-s.closeChan:
  1361. return
  1362. }
  1363. }
  1364. }
  1365. func (s *Server) sendDelayedMessages() error {
  1366. messages, err := s.messageCache.MessagesDue()
  1367. if err != nil {
  1368. return err
  1369. }
  1370. for _, m := range messages {
  1371. var v *visitor
  1372. if s.userManager != nil && m.User != "" {
  1373. u, err := s.userManager.User(m.User)
  1374. if err != nil {
  1375. log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error())
  1376. continue
  1377. }
  1378. v = s.visitorFromUser(u, m.Sender)
  1379. } else {
  1380. v = s.visitorFromIP(m.Sender)
  1381. }
  1382. if err := s.sendDelayedMessage(v, m); err != nil {
  1383. log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error())
  1384. }
  1385. }
  1386. return nil
  1387. }
  1388. func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
  1389. log.Debug("%s Sending delayed message", logMessagePrefix(v, m))
  1390. s.mu.Lock()
  1391. t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published
  1392. s.mu.Unlock()
  1393. if ok {
  1394. go func() {
  1395. // We do not rate-limit messages here, since we've rate limited them in the PUT/POST handler
  1396. if err := t.Publish(v, m); err != nil {
  1397. log.Warn("%s Unable to publish message: %v", logMessagePrefix(v, m), err.Error())
  1398. }
  1399. }()
  1400. }
  1401. if s.firebaseClient != nil { // Firebase subscribers may not show up in topics map
  1402. go s.sendToFirebase(v, m)
  1403. }
  1404. if s.config.UpstreamBaseURL != "" {
  1405. go s.forwardPollRequest(v, m)
  1406. }
  1407. if err := s.messageCache.MarkPublished(m); err != nil {
  1408. return err
  1409. }
  1410. return nil
  1411. }
  1412. func (s *Server) limitRequests(next handleFunc) handleFunc {
  1413. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  1414. if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
  1415. return next(w, r, v)
  1416. } else if err := v.RequestAllowed(); err != nil {
  1417. return errHTTPTooManyRequestsLimitRequests
  1418. }
  1419. return next(w, r, v)
  1420. }
  1421. }
  1422. // transformBodyJSON peeks the request body, reads the JSON, and converts it to headers
  1423. // before passing it on to the next handler. This is meant to be used in combination with handlePublish.
  1424. func (s *Server) transformBodyJSON(next handleFunc) handleFunc {
  1425. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  1426. m, err := readJSONWithLimit[publishMessage](r.Body, s.config.MessageLimit*2) // 2x to account for JSON format overhead
  1427. if err != nil {
  1428. return err
  1429. }
  1430. if !topicRegex.MatchString(m.Topic) {
  1431. return errHTTPBadRequestTopicInvalid
  1432. }
  1433. if m.Message == "" {
  1434. m.Message = emptyMessageBody
  1435. }
  1436. r.URL.Path = "/" + m.Topic
  1437. r.Body = io.NopCloser(strings.NewReader(m.Message))
  1438. if m.Title != "" {
  1439. r.Header.Set("X-Title", m.Title)
  1440. }
  1441. if m.Priority != 0 {
  1442. r.Header.Set("X-Priority", fmt.Sprintf("%d", m.Priority))
  1443. }
  1444. if m.Tags != nil && len(m.Tags) > 0 {
  1445. r.Header.Set("X-Tags", strings.Join(m.Tags, ","))
  1446. }
  1447. if m.Attach != "" {
  1448. r.Header.Set("X-Attach", m.Attach)
  1449. }
  1450. if m.Filename != "" {
  1451. r.Header.Set("X-Filename", m.Filename)
  1452. }
  1453. if m.Click != "" {
  1454. r.Header.Set("X-Click", m.Click)
  1455. }
  1456. if m.Icon != "" {
  1457. r.Header.Set("X-Icon", m.Icon)
  1458. }
  1459. if len(m.Actions) > 0 {
  1460. actionsStr, err := json.Marshal(m.Actions)
  1461. if err != nil {
  1462. return errHTTPBadRequestMessageJSONInvalid
  1463. }
  1464. r.Header.Set("X-Actions", string(actionsStr))
  1465. }
  1466. if m.Email != "" {
  1467. r.Header.Set("X-Email", m.Email)
  1468. }
  1469. if m.Delay != "" {
  1470. r.Header.Set("X-Delay", m.Delay)
  1471. }
  1472. return next(w, r, v)
  1473. }
  1474. }
  1475. func (s *Server) transformMatrixJSON(next handleFunc) handleFunc {
  1476. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  1477. newRequest, err := newRequestFromMatrixJSON(r, s.config.BaseURL, s.config.MessageLimit)
  1478. if err != nil {
  1479. return err
  1480. }
  1481. if err := next(w, newRequest, v); err != nil {
  1482. return &errMatrix{pushKey: newRequest.Header.Get(matrixPushKeyHeader), err: err}
  1483. }
  1484. return nil
  1485. }
  1486. }
  1487. func (s *Server) authorizeTopicWrite(next handleFunc) handleFunc {
  1488. return s.autorizeTopic(next, user.PermissionWrite)
  1489. }
  1490. func (s *Server) authorizeTopicRead(next handleFunc) handleFunc {
  1491. return s.autorizeTopic(next, user.PermissionRead)
  1492. }
  1493. func (s *Server) autorizeTopic(next handleFunc, perm user.Permission) handleFunc {
  1494. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  1495. if s.userManager == nil {
  1496. return next(w, r, v)
  1497. }
  1498. topics, _, err := s.topicsFromPath(r.URL.Path)
  1499. if err != nil {
  1500. return err
  1501. }
  1502. for _, t := range topics {
  1503. if err := s.userManager.Authorize(v.user, t.ID, perm); err != nil {
  1504. log.Info("unauthorized: %s", err.Error())
  1505. return errHTTPForbidden
  1506. }
  1507. }
  1508. return next(w, r, v)
  1509. }
  1510. }
  1511. // visitor creates or retrieves a rate.Limiter for the given visitor.
  1512. // Note that this function will always return a visitor, even if an error occurs.
  1513. func (s *Server) visitor(r *http.Request) (v *visitor, err error) {
  1514. ip := extractIPAddress(r, s.config.BehindProxy)
  1515. var u *user.User // may stay nil if no auth header!
  1516. if u, err = s.authenticate(r); err != nil {
  1517. log.Debug("authentication failed: %s", err.Error())
  1518. err = errHTTPUnauthorized // Always return visitor, even when error occurs!
  1519. }
  1520. if u != nil {
  1521. v = s.visitorFromUser(u, ip)
  1522. } else {
  1523. v = s.visitorFromIP(ip)
  1524. }
  1525. v.mu.Lock()
  1526. v.user = u
  1527. v.mu.Unlock()
  1528. return v, err // Always return visitor, even when error occurs!
  1529. }
  1530. // authenticate a user based on basic auth username/password (Authorization: Basic ...), or token auth (Authorization: Bearer ...).
  1531. // The Authorization header can be passed as a header or the ?auth=... query param. The latter is required only to
  1532. // support the WebSocket JavaScript class, which does not support passing headers during the initial request. The auth
  1533. // query param is effectively double base64 encoded. Its format is base64(Basic base64(user:pass)).
  1534. func (s *Server) authenticate(r *http.Request) (user *user.User, err error) {
  1535. value := strings.TrimSpace(r.Header.Get("Authorization"))
  1536. queryParam := readQueryParam(r, "authorization", "auth")
  1537. if queryParam != "" {
  1538. a, err := base64.RawURLEncoding.DecodeString(queryParam)
  1539. if err != nil {
  1540. return nil, err
  1541. }
  1542. value = strings.TrimSpace(string(a))
  1543. }
  1544. if value == "" {
  1545. return nil, nil
  1546. } else if s.userManager == nil {
  1547. return nil, errHTTPUnauthorized
  1548. }
  1549. if strings.HasPrefix(value, "Bearer") {
  1550. return s.authenticateBearerAuth(value)
  1551. }
  1552. return s.authenticateBasicAuth(r, value)
  1553. }
  1554. func (s *Server) authenticateBasicAuth(r *http.Request, value string) (user *user.User, err error) {
  1555. r.Header.Set("Authorization", value)
  1556. username, password, ok := r.BasicAuth()
  1557. if !ok {
  1558. return nil, errors.New("invalid basic auth")
  1559. }
  1560. return s.userManager.Authenticate(username, password)
  1561. }
  1562. func (s *Server) authenticateBearerAuth(value string) (user *user.User, err error) {
  1563. token := strings.TrimSpace(strings.TrimPrefix(value, "Bearer"))
  1564. return s.userManager.AuthenticateToken(token)
  1565. }
  1566. func (s *Server) visitorFromID(visitorID string, ip netip.Addr, user *user.User) *visitor {
  1567. s.mu.Lock()
  1568. defer s.mu.Unlock()
  1569. v, exists := s.visitors[visitorID]
  1570. if !exists {
  1571. s.visitors[visitorID] = newVisitor(s.config, s.messageCache, s.userManager, ip, user)
  1572. return s.visitors[visitorID]
  1573. }
  1574. v.Keepalive()
  1575. return v
  1576. }
  1577. func (s *Server) visitorFromIP(ip netip.Addr) *visitor {
  1578. return s.visitorFromID(fmt.Sprintf("ip:%s", ip.String()), ip, nil)
  1579. }
  1580. func (s *Server) visitorFromUser(user *user.User, ip netip.Addr) *visitor {
  1581. return s.visitorFromID(fmt.Sprintf("user:%s", user.ID), ip, user)
  1582. }
  1583. func (s *Server) writeJSON(w http.ResponseWriter, v any) error {
  1584. w.Header().Set("Content-Type", "application/json")
  1585. w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
  1586. if err := json.NewEncoder(w).Encode(v); err != nil {
  1587. return err
  1588. }
  1589. return nil
  1590. }