server.go 58 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672
  1. package server
  2. import (
  3. "bytes"
  4. "context"
  5. "crypto/sha256"
  6. "embed"
  7. "encoding/base64"
  8. "encoding/json"
  9. "errors"
  10. "fmt"
  11. "heckel.io/ntfy/user"
  12. "io"
  13. "net"
  14. "net/http"
  15. "net/netip"
  16. "net/url"
  17. "os"
  18. "path"
  19. "path/filepath"
  20. "regexp"
  21. "sort"
  22. "strconv"
  23. "strings"
  24. "sync"
  25. "time"
  26. "unicode/utf8"
  27. "heckel.io/ntfy/log"
  28. "github.com/emersion/go-smtp"
  29. "github.com/gorilla/websocket"
  30. "golang.org/x/sync/errgroup"
  31. "heckel.io/ntfy/util"
  32. )
  33. /*
  34. TODO
  35. --
  36. - Reservation: Kill existing subscribers when topic is reserved (deadcade)
  37. - Rate limiting: Sensitive endpoints (account/login/change-password/...)
  38. - Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
  39. - Reservation (UI): Ask for confirmation when removing reservation (deadcade)
  40. - Reservation icons (UI)
  41. races:
  42. - v.user --> see publishSyncEventAsync() test
  43. payments:
  44. - reconciliation
  45. delete messages + reserved topics on ResetTier delete attachments in access.go
  46. Limits & rate limiting:
  47. rate limiting weirdness. wth is going on?
  48. bandwidth limit must be in tier
  49. users without tier: should the stats be persisted? are they meaningful? -> test that the visitor is based on the IP address!
  50. when ResetStats() is run, reset messagesLimiter (and others)?
  51. Delete visitor when tier is changed to refresh rate limiters
  52. Make sure account endpoints make sense for admins
  53. UI:
  54. -
  55. - reservation table delete button: dialog "keep or delete messages?"
  56. - flicker of upgrade banner
  57. - JS constants
  58. Sync:
  59. - sync problems with "deleteAfter=0" and "displayName="
  60. Tests:
  61. - Payment endpoints (make mocks)
  62. - Message rate limiting and reset tests
  63. - Bandwidth limit test
  64. - test that the visitor is based on the IP address when a user has no tier
  65. */
  66. // Server is the main server, providing the UI and API for ntfy
  //
  // A Server is built by New and started by Run. New leaves the optional
  // components (smtpSender, firebaseClient, userManager, fileCache, stripe) nil
  // when the corresponding config option is not set, so callers must nil-check them.
  67. type Server struct {
  68. config *Config
  // httpServer, httpsServer and unixListener are created in Run (only if the
  // matching listen option is set) and closed in Stop.
  69. httpServer *http.Server
  70. httpsServer *http.Server
  71. unixListener net.Listener
  // smtpServer/smtpServerBackend are not assigned in this part of the file;
  // presumably populated by runSMTPServer — TODO confirm. Stop closes smtpServer if non-nil.
  72. smtpServer *smtp.Server
  73. smtpServerBackend *smtpBackend
  74. smtpSender mailer
  // topics is seeded in New from messageCache.Topics().
  75. topics map[string]*topic
  76. visitors map[string]*visitor // ip:<ip> or user:<user>
  77. firebaseClient *firebaseClient
  // messages counts published messages; incremented under mu in handlePublishWithoutResponse.
  78. messages int64
  79. userManager *user.Manager // Might be nil!
  80. messageCache *messageCache // Database that stores the messages
  81. fileCache *fileCache // File system based cache that stores attachments
  82. stripe stripeAPI // Stripe API, can be replaced with a mock
  83. priceCache *util.LookupCache[map[string]string] // Stripe price ID -> formatted price
  // closeChan is created in Run and closed in Stop; presumably the background
  // managers started by Run watch it to exit — confirm.
  84. closeChan chan bool
  // mu guards (at least) the listener/server fields, closeChan and messages,
  // as seen in Run, Stop and handlePublishWithoutResponse.
  85. mu sync.Mutex
  86. }
  87. // handleFunc extends the normal http.HandlerFunc to be able to easily return errors
  // It also receives the *visitor resolved for the request. A non-nil error is
  // logged and turned into the HTTP error response by Server.handle.
  88. type handleFunc func(http.ResponseWriter, *http.Request, *visitor) error
  89. var (
  90. // If changed, don't forget to update Android App and auth_sqlite.go
  91. topicRegex = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`) // No /!
  92. topicPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}$`) // Regex must match JS & Android app!
  93. externalTopicPathRegex = regexp.MustCompile(`^/[^/]+\.[^/]+/[-_A-Za-z0-9]{1,64}$`) // Extended topic path, for web-app, e.g. /example.com/mytopic
  94. jsonPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/json$`)
  95. ssePathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/sse$`)
  96. rawPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/raw$`)
  97. wsPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/ws$`)
  98. authPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/auth$`)
  99. publishPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}/(publish|send|trigger)$`)
  100. webConfigPath = "/config.js"
  101. accountPath = "/account"
  102. matrixPushPath = "/_matrix/push/v1/notify"
  103. apiHealthPath = "/v1/health"
  104. apiTiers = "/v1/tiers"
  105. apiAccountPath = "/v1/account"
  106. apiAccountTokenPath = "/v1/account/token"
  107. apiAccountPasswordPath = "/v1/account/password"
  108. apiAccountSettingsPath = "/v1/account/settings"
  109. apiAccountSubscriptionPath = "/v1/account/subscription"
  110. apiAccountReservationPath = "/v1/account/reservation"
  111. apiAccountBillingPortalPath = "/v1/account/billing/portal"
  112. apiAccountBillingWebhookPath = "/v1/account/billing/webhook"
  113. apiAccountBillingSubscriptionPath = "/v1/account/billing/subscription"
  114. apiAccountBillingSubscriptionCheckoutSuccessTemplate = "/v1/account/billing/subscription/success/{CHECKOUT_SESSION_ID}"
  115. apiAccountBillingSubscriptionCheckoutSuccessRegex = regexp.MustCompile(`/v1/account/billing/subscription/success/(.+)$`)
  116. apiAccountReservationSingleRegex = regexp.MustCompile(`/v1/account/reservation/([-_A-Za-z0-9]{1,64})$`)
  117. apiAccountSubscriptionSingleRegex = regexp.MustCompile(`^/v1/account/subscription/([-_A-Za-z0-9]{16})$`)
  118. staticRegex = regexp.MustCompile(`^/static/.+`)
  119. docsRegex = regexp.MustCompile(`^/docs(|/.*)$`)
  120. fileRegex = regexp.MustCompile(`^/file/([-_A-Za-z0-9]{1,64})(?:\.[A-Za-z0-9]{1,16})?$`)
  121. disallowedTopics = []string{"docs", "static", "file", "app", "account", "settings", "pricing", "signup", "login", "reset-password"} // If updated, also update in Android and web app
  122. urlRegex = regexp.MustCompile(`^https?://`)
  123. //go:embed site
  124. webFs embed.FS
  125. webFsCached = &util.CachingEmbedFS{ModTime: time.Now(), FS: webFs}
  126. webSiteDir = "/site"
  127. webHomeIndex = "/home.html" // Landing page, only if "web-root: home"
  128. webAppIndex = "/app.html" // React app
  129. //go:embed docs
  130. docsStaticFs embed.FS
  131. docsStaticCached = &util.CachingEmbedFS{ModTime: time.Now(), FS: docsStaticFs}
  132. )
  // Special topic names, default message texts and size limits used when publishing.
  133. const (
  134. firebaseControlTopic = "~control" // See Android if changed
  135. firebasePollTopic = "~poll" // See iOS if changed
  136. emptyMessageBody = "triggered" // Used if message body is empty
  137. newMessageBody = "New message" // Used in poll requests as generic message
  138. defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment
  139. encodingBase64 = "base64" // Used mainly for binary UnifiedPush messages
  // NOTE(review): not referenced in this part of the file; presumably the maximum
  // number of bytes accepted for a JSON publish body (see transformBodyJSON) — confirm.
  140. jsonBodyBytesLimit = 16384
  141. )
  142. // WebSocket constants
  // Not referenced in this part of the file; presumably used by the WebSocket
  // subscriber (handleSubscribeWS) for write/read deadlines, buffer sizes and
  // the maximum inbound message size — confirm there.
  143. const (
  144. wsWriteWait = 2 * time.Second
  145. wsBufferSize = 1024
  146. wsReadLimit = 64 // We only ever receive PINGs
  147. wsPongWait = 15 * time.Second
  148. )
  149. // New instantiates a new Server. It creates the cache and adds a Firebase
  150. // subscriber (if configured).
  151. func New(conf *Config) (*Server, error) {
  152. var mailer mailer
  153. if conf.SMTPSenderAddr != "" {
  154. mailer = &smtpSender{config: conf}
  155. }
  156. var stripe stripeAPI
  157. if conf.StripeSecretKey != "" {
  158. stripe = newStripeAPI()
  159. }
  160. messageCache, err := createMessageCache(conf)
  161. if err != nil {
  162. return nil, err
  163. }
  164. topics, err := messageCache.Topics()
  165. if err != nil {
  166. return nil, err
  167. }
  168. var fileCache *fileCache
  169. if conf.AttachmentCacheDir != "" {
  170. fileCache, err = newFileCache(conf.AttachmentCacheDir, conf.AttachmentTotalSizeLimit)
  171. if err != nil {
  172. return nil, err
  173. }
  174. }
  175. var userManager *user.Manager
  176. if conf.AuthFile != "" {
  177. userManager, err = user.NewManager(conf.AuthFile, conf.AuthStartupQueries, conf.AuthDefault)
  178. if err != nil {
  179. return nil, err
  180. }
  181. }
  182. var firebaseClient *firebaseClient
  183. if conf.FirebaseKeyFile != "" {
  184. sender, err := newFirebaseSender(conf.FirebaseKeyFile)
  185. if err != nil {
  186. return nil, err
  187. }
  188. firebaseClient = newFirebaseClient(sender, userManager)
  189. }
  190. s := &Server{
  191. config: conf,
  192. messageCache: messageCache,
  193. fileCache: fileCache,
  194. firebaseClient: firebaseClient,
  195. smtpSender: mailer,
  196. topics: topics,
  197. userManager: userManager,
  198. visitors: make(map[string]*visitor),
  199. stripe: stripe,
  200. }
  201. s.priceCache = util.NewLookupCache(s.fetchStripePrices, conf.StripePriceCacheDuration)
  202. return s, nil
  203. }
  204. func createMessageCache(conf *Config) (*messageCache, error) {
  205. if conf.CacheDuration == 0 {
  206. return newNopCache()
  207. } else if conf.CacheFile != "" {
  208. return newSqliteCache(conf.CacheFile, conf.CacheStartupQueries, conf.CacheDuration, conf.CacheBatchSize, conf.CacheBatchTimeout, false)
  209. }
  210. return newMemCache()
  211. }
  212. // Run executes the main server. It listens on HTTP (+ HTTPS, if configured), and starts
  213. // a manager go routine to print stats and prune messages.
  214. func (s *Server) Run() error {
  215. var listenStr string
  216. if s.config.ListenHTTP != "" {
  217. listenStr += fmt.Sprintf(" %s[http]", s.config.ListenHTTP)
  218. }
  219. if s.config.ListenHTTPS != "" {
  220. listenStr += fmt.Sprintf(" %s[https]", s.config.ListenHTTPS)
  221. }
  222. if s.config.ListenUnix != "" {
  223. listenStr += fmt.Sprintf(" %s[unix]", s.config.ListenUnix)
  224. }
  225. if s.config.SMTPServerListen != "" {
  226. listenStr += fmt.Sprintf(" %s[smtp]", s.config.SMTPServerListen)
  227. }
  228. log.Info("Listening on%s, ntfy %s, log level is %s", listenStr, s.config.Version, log.CurrentLevel().String())
  229. mux := http.NewServeMux()
  230. mux.HandleFunc("/", s.handle)
  231. errChan := make(chan error)
  232. s.mu.Lock()
  233. s.closeChan = make(chan bool)
  234. if s.config.ListenHTTP != "" {
  235. s.httpServer = &http.Server{Addr: s.config.ListenHTTP, Handler: mux}
  236. go func() {
  237. errChan <- s.httpServer.ListenAndServe()
  238. }()
  239. }
  240. if s.config.ListenHTTPS != "" {
  241. s.httpsServer = &http.Server{Addr: s.config.ListenHTTPS, Handler: mux}
  242. go func() {
  243. errChan <- s.httpsServer.ListenAndServeTLS(s.config.CertFile, s.config.KeyFile)
  244. }()
  245. }
  246. if s.config.ListenUnix != "" {
  247. go func() {
  248. var err error
  249. s.mu.Lock()
  250. os.Remove(s.config.ListenUnix)
  251. s.unixListener, err = net.Listen("unix", s.config.ListenUnix)
  252. if err != nil {
  253. s.mu.Unlock()
  254. errChan <- err
  255. return
  256. }
  257. defer s.unixListener.Close()
  258. if s.config.ListenUnixMode > 0 {
  259. if err := os.Chmod(s.config.ListenUnix, s.config.ListenUnixMode); err != nil {
  260. s.mu.Unlock()
  261. errChan <- err
  262. return
  263. }
  264. }
  265. s.mu.Unlock()
  266. httpServer := &http.Server{Handler: mux}
  267. errChan <- httpServer.Serve(s.unixListener)
  268. }()
  269. }
  270. if s.config.SMTPServerListen != "" {
  271. go func() {
  272. errChan <- s.runSMTPServer()
  273. }()
  274. }
  275. s.mu.Unlock()
  276. go s.runManager()
  277. go s.runStatsResetter()
  278. go s.runDelayedSender()
  279. go s.runFirebaseKeepaliver()
  280. return <-errChan
  281. }
  282. // Stop stops HTTP (+HTTPS) server and all managers
  283. func (s *Server) Stop() {
  284. s.mu.Lock()
  285. defer s.mu.Unlock()
  286. if s.httpServer != nil {
  287. s.httpServer.Close()
  288. }
  289. if s.httpsServer != nil {
  290. s.httpsServer.Close()
  291. }
  292. if s.unixListener != nil {
  293. s.unixListener.Close()
  294. }
  295. if s.smtpServer != nil {
  296. s.smtpServer.Close()
  297. }
  298. close(s.closeChan)
  299. }
  300. func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
  301. v, err := s.visitor(r) // Note: Always returns v, even when error is returned
  302. if err == nil {
  303. log.Debug("%s Dispatching request", logHTTPPrefix(v, r))
  304. if log.IsTrace() {
  305. log.Trace("%s Entire request (headers and body):\n%s", logHTTPPrefix(v, r), renderHTTPRequest(r))
  306. }
  307. err = s.handleInternal(w, r, v)
  308. }
  309. if err != nil {
  310. if websocket.IsWebSocketUpgrade(r) {
  311. isNormalError := strings.Contains(err.Error(), "i/o timeout")
  312. if isNormalError {
  313. log.Debug("%s WebSocket error (this error is okay, it happens a lot): %s", logHTTPPrefix(v, r), err.Error())
  314. } else {
  315. log.Info("%s WebSocket error: %s", logHTTPPrefix(v, r), err.Error())
  316. }
  317. return // Do not attempt to write to upgraded connection
  318. }
  319. if matrixErr, ok := err.(*errMatrix); ok {
  320. writeMatrixError(w, r, v, matrixErr)
  321. return
  322. }
  323. httpErr, ok := err.(*errHTTP)
  324. if !ok {
  325. httpErr = errHTTPInternalError
  326. }
  327. isNormalError := httpErr.HTTPCode == http.StatusNotFound || httpErr.HTTPCode == http.StatusBadRequest
  328. if isNormalError {
  329. log.Debug("%s Connection closed with HTTP %d (ntfy error %d): %s", logHTTPPrefix(v, r), httpErr.HTTPCode, httpErr.Code, err.Error())
  330. } else {
  331. log.Info("%s Connection closed with HTTP %d (ntfy error %d): %s", logHTTPPrefix(v, r), httpErr.HTTPCode, httpErr.Code, err.Error())
  332. }
  333. w.Header().Set("Content-Type", "application/json")
  334. w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
  335. w.WriteHeader(httpErr.HTTPCode)
  336. io.WriteString(w, httpErr.JSON()+"\n")
  337. }
  338. }
  339. func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visitor) error {
  340. if r.Method == http.MethodGet && r.URL.Path == "/" {
  341. return s.ensureWebEnabled(s.handleHome)(w, r, v)
  342. } else if r.Method == http.MethodHead && r.URL.Path == "/" {
  343. return s.ensureWebEnabled(s.handleEmpty)(w, r, v)
  344. } else if r.Method == http.MethodGet && r.URL.Path == apiHealthPath {
  345. return s.handleHealth(w, r, v)
  346. } else if r.Method == http.MethodGet && r.URL.Path == webConfigPath {
  347. return s.ensureWebEnabled(s.handleWebConfig)(w, r, v)
  348. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountPath {
  349. return s.ensureUserManager(s.handleAccountCreate)(w, r, v)
  350. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountTokenPath {
  351. return s.ensureUser(s.handleAccountTokenIssue)(w, r, v)
  352. } else if r.Method == http.MethodGet && r.URL.Path == apiAccountPath {
  353. return s.handleAccountGet(w, r, v) // Allowed by anonymous
  354. } else if r.Method == http.MethodDelete && r.URL.Path == apiAccountPath {
  355. return s.ensureUser(s.withAccountSync(s.handleAccountDelete))(w, r, v)
  356. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountPasswordPath {
  357. return s.ensureUser(s.handleAccountPasswordChange)(w, r, v)
  358. } else if r.Method == http.MethodPatch && r.URL.Path == apiAccountTokenPath {
  359. return s.ensureUser(s.handleAccountTokenExtend)(w, r, v)
  360. } else if r.Method == http.MethodDelete && r.URL.Path == apiAccountTokenPath {
  361. return s.ensureUser(s.handleAccountTokenDelete)(w, r, v)
  362. } else if r.Method == http.MethodPatch && r.URL.Path == apiAccountSettingsPath {
  363. return s.ensureUser(s.withAccountSync(s.handleAccountSettingsChange))(w, r, v)
  364. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountSubscriptionPath {
  365. return s.ensureUser(s.withAccountSync(s.handleAccountSubscriptionAdd))(w, r, v)
  366. } else if r.Method == http.MethodPatch && apiAccountSubscriptionSingleRegex.MatchString(r.URL.Path) {
  367. return s.ensureUser(s.withAccountSync(s.handleAccountSubscriptionChange))(w, r, v)
  368. } else if r.Method == http.MethodDelete && apiAccountSubscriptionSingleRegex.MatchString(r.URL.Path) {
  369. return s.ensureUser(s.withAccountSync(s.handleAccountSubscriptionDelete))(w, r, v)
  370. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountReservationPath {
  371. return s.ensureUser(s.withAccountSync(s.handleAccountReservationAdd))(w, r, v)
  372. } else if r.Method == http.MethodDelete && apiAccountReservationSingleRegex.MatchString(r.URL.Path) {
  373. return s.ensureUser(s.withAccountSync(s.handleAccountReservationDelete))(w, r, v)
  374. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountBillingSubscriptionPath {
  375. return s.ensurePaymentsEnabled(s.ensureUser(s.handleAccountBillingSubscriptionCreate))(w, r, v) // Account sync via incoming Stripe webhook
  376. } else if r.Method == http.MethodGet && apiAccountBillingSubscriptionCheckoutSuccessRegex.MatchString(r.URL.Path) {
  377. return s.ensurePaymentsEnabled(s.ensureUserManager(s.handleAccountBillingSubscriptionCreateSuccess))(w, r, v) // No user context!
  378. } else if r.Method == http.MethodPut && r.URL.Path == apiAccountBillingSubscriptionPath {
  379. return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingSubscriptionUpdate))(w, r, v) // Account sync via incoming Stripe webhook
  380. } else if r.Method == http.MethodDelete && r.URL.Path == apiAccountBillingSubscriptionPath {
  381. return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingSubscriptionDelete))(w, r, v) // Account sync via incoming Stripe webhook
  382. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountBillingPortalPath {
  383. return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingPortalSessionCreate))(w, r, v)
  384. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountBillingWebhookPath {
  385. return s.ensurePaymentsEnabled(s.ensureUserManager(s.handleAccountBillingWebhook))(w, r, v) // This request comes from Stripe!
  386. } else if r.Method == http.MethodGet && r.URL.Path == apiTiers {
  387. return s.ensurePaymentsEnabled(s.handleBillingTiersGet)(w, r, v)
  388. } else if r.Method == http.MethodGet && r.URL.Path == matrixPushPath {
  389. return s.handleMatrixDiscovery(w)
  390. } else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) {
  391. return s.ensureWebEnabled(s.handleStatic)(w, r, v)
  392. } else if r.Method == http.MethodGet && docsRegex.MatchString(r.URL.Path) {
  393. return s.ensureWebEnabled(s.handleDocs)(w, r, v)
  394. } else if (r.Method == http.MethodGet || r.Method == http.MethodHead) && fileRegex.MatchString(r.URL.Path) && s.config.AttachmentCacheDir != "" {
  395. return s.limitRequests(s.handleFile)(w, r, v)
  396. } else if r.Method == http.MethodOptions {
  397. return s.ensureWebEnabled(s.handleOptions)(w, r, v)
  398. } else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && r.URL.Path == "/" {
  399. return s.limitRequests(s.transformBodyJSON(s.authorizeTopicWrite(s.handlePublish)))(w, r, v)
  400. } else if r.Method == http.MethodPost && r.URL.Path == matrixPushPath {
  401. return s.limitRequests(s.transformMatrixJSON(s.authorizeTopicWrite(s.handlePublishMatrix)))(w, r, v)
  402. } else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicPathRegex.MatchString(r.URL.Path) {
  403. return s.limitRequests(s.authorizeTopicWrite(s.handlePublish))(w, r, v)
  404. } else if r.Method == http.MethodGet && publishPathRegex.MatchString(r.URL.Path) {
  405. return s.limitRequests(s.authorizeTopicWrite(s.handlePublish))(w, r, v)
  406. } else if r.Method == http.MethodGet && jsonPathRegex.MatchString(r.URL.Path) {
  407. return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeJSON))(w, r, v)
  408. } else if r.Method == http.MethodGet && ssePathRegex.MatchString(r.URL.Path) {
  409. return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeSSE))(w, r, v)
  410. } else if r.Method == http.MethodGet && rawPathRegex.MatchString(r.URL.Path) {
  411. return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeRaw))(w, r, v)
  412. } else if r.Method == http.MethodGet && wsPathRegex.MatchString(r.URL.Path) {
  413. return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeWS))(w, r, v)
  414. } else if r.Method == http.MethodGet && authPathRegex.MatchString(r.URL.Path) {
  415. return s.limitRequests(s.authorizeTopicRead(s.handleTopicAuth))(w, r, v)
  416. } else if r.Method == http.MethodGet && (topicPathRegex.MatchString(r.URL.Path) || externalTopicPathRegex.MatchString(r.URL.Path)) {
  417. return s.ensureWebEnabled(s.handleTopic)(w, r, v)
  418. }
  419. return errHTTPNotFound
  420. }
  421. func (s *Server) handleHome(w http.ResponseWriter, r *http.Request, v *visitor) error {
  422. if s.config.WebRootIsApp {
  423. r.URL.Path = webAppIndex
  424. } else {
  425. r.URL.Path = webHomeIndex
  426. }
  427. return s.handleStatic(w, r, v)
  428. }
  429. func (s *Server) handleTopic(w http.ResponseWriter, r *http.Request, v *visitor) error {
  430. unifiedpush := readBoolParam(r, false, "x-unifiedpush", "unifiedpush", "up") // see PUT/POST too!
  431. if unifiedpush {
  432. w.Header().Set("Content-Type", "application/json")
  433. w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
  434. _, err := io.WriteString(w, `{"unifiedpush":{"version":1}}`+"\n")
  435. return err
  436. }
  437. r.URL.Path = webAppIndex
  438. return s.handleStatic(w, r, v)
  439. }
  440. func (s *Server) handleEmpty(_ http.ResponseWriter, _ *http.Request, _ *visitor) error {
  441. return nil
  442. }
  443. func (s *Server) handleTopicAuth(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
  444. return s.writeJSON(w, newSuccessResponse())
  445. }
  446. func (s *Server) handleHealth(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
  447. response := &apiHealthResponse{
  448. Healthy: true,
  449. }
  450. return s.writeJSON(w, response)
  451. }
  452. func (s *Server) handleWebConfig(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
  453. appRoot := "/"
  454. if !s.config.WebRootIsApp {
  455. appRoot = "/app"
  456. }
  457. response := &apiConfigResponse{
  458. BaseURL: "", // Will translate to window.location.origin
  459. AppRoot: appRoot,
  460. EnableLogin: s.config.EnableLogin,
  461. EnableSignup: s.config.EnableSignup,
  462. EnablePayments: s.config.StripeSecretKey != "",
  463. EnableReservations: s.config.EnableReservations,
  464. DisallowedTopics: disallowedTopics,
  465. }
  466. b, err := json.MarshalIndent(response, "", " ")
  467. if err != nil {
  468. return err
  469. }
  470. w.Header().Set("Content-Type", "text/javascript")
  471. _, err = io.WriteString(w, fmt.Sprintf("// Generated server configuration\nvar config = %s;\n", string(b)))
  472. return err
  473. }
  474. func (s *Server) handleStatic(w http.ResponseWriter, r *http.Request, _ *visitor) error {
  475. r.URL.Path = webSiteDir + r.URL.Path
  476. util.Gzip(http.FileServer(http.FS(webFsCached))).ServeHTTP(w, r)
  477. return nil
  478. }
  479. func (s *Server) handleDocs(w http.ResponseWriter, r *http.Request, _ *visitor) error {
  480. util.Gzip(http.FileServer(http.FS(docsStaticCached))).ServeHTTP(w, r)
  481. return nil
  482. }
  483. func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor) error {
  484. if s.config.AttachmentCacheDir == "" {
  485. return errHTTPInternalError
  486. }
  487. matches := fileRegex.FindStringSubmatch(r.URL.Path)
  488. if len(matches) != 2 {
  489. return errHTTPInternalErrorInvalidPath
  490. }
  491. messageID := matches[1]
  492. file := filepath.Join(s.config.AttachmentCacheDir, messageID)
  493. stat, err := os.Stat(file)
  494. if err != nil {
  495. return errHTTPNotFound
  496. }
  497. if r.Method == http.MethodGet {
  498. if err := v.BandwidthLimiter().Allow(stat.Size()); err != nil {
  499. return errHTTPTooManyRequestsLimitAttachmentBandwidth
  500. }
  501. }
  502. w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size()))
  503. w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
  504. if r.Method == http.MethodGet {
  505. f, err := os.Open(file)
  506. if err != nil {
  507. return err
  508. }
  509. defer f.Close()
  510. _, err = io.Copy(util.NewContentTypeWriter(w, r.URL.Path), f)
  511. return err
  512. }
  513. return nil
  514. }
  515. func (s *Server) handleMatrixDiscovery(w http.ResponseWriter) error {
  516. if s.config.BaseURL == "" {
  517. return errHTTPInternalErrorMissingBaseURL
  518. }
  519. return writeMatrixDiscoveryResponse(w)
  520. }
  // handlePublishWithoutResponse processes a publish request for the topic named
  // in r.URL.Path and returns the resulting message without writing an HTTP
  // response. handlePublish and handlePublishMatrix each write their own.
  //
  // Order of operations: resolve the topic, check the visitor's message limit,
  // read the body (bounded by s.config.MessageLimit), parse publish parameters,
  // build the message, then either publish it now (plus Firebase, email and
  // upstream forwarding, each in its own goroutine) or leave it for later if
  // delayed. Next, optionally store it in the message cache, and finally update
  // the visitor/user/server counters.
  521. func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*message, error) {
  522. t, err := s.topicFromPath(r.URL.Path)
  523. if err != nil {
  524. return nil, err
  525. }
  // Message rate limit is checked before the body is read.
  526. if err := v.MessageAllowed(); err != nil {
  527. return nil, errHTTPTooManyRequestsLimitMessages
  528. }
  // NOTE(review): util.Peek presumably reads at most MessageLimit bytes while
  // keeping the rest of r.Body available (e.g. for attachments) — confirm in util.
  529. body, err := util.Peek(r.Body, s.config.MessageLimit)
  530. if err != nil {
  531. return nil, err
  532. }
  533. m := newDefaultMessage(t.ID, "")
  534. cache, firebase, email, unifiedpush, err := s.parsePublishParams(r, v, m)
  535. if err != nil {
  536. return nil, err
  537. }
  // A poll request replaces m with a fresh message that carries only the poll ID.
  // Other fields parsed into the previous m are dropped, but the
  // cache/firebase/email/unifiedpush flags returned above still apply.
  538. if m.PollID != "" {
  539. m = newPollRequestMessage(t.ID, m.PollID)
  540. }
  541. if v.user != nil {
  542. m.User = v.user.ID
  543. }
  // Expiry is derived from the visitor's limits.
  544. m.Expires = time.Now().Add(v.Limits().MessagesExpiryDuration).Unix()
  545. if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil {
  546. return nil, err
  547. }
  548. if m.Message == "" {
  549. m.Message = emptyMessageBody
  550. }
  // A message time in the future marks the message as scheduled for later delivery.
  551. delayed := m.Time > time.Now().Unix()
  552. log.Debug("%s Received message: event=%s, user=%s, body=%d byte(s), delayed=%t, firebase=%t, cache=%t, up=%t, email=%s",
  553. logMessagePrefix(v, m), m.Event, m.User, len(m.Message), delayed, firebase, cache, unifiedpush, email)
  554. if log.IsTrace() {
  555. log.Trace("%s Message body: %s", logMessagePrefix(v, m), util.MaybeMarshalJSON(m))
  556. }
  // Delayed messages are neither published nor forwarded here.
  // NOTE(review): their later delivery presumably relies on the message cache
  // (runDelayedSender). If cache is false, a delayed message would never be sent
  // unless parsePublishParams rejects that combination — confirm.
  557. if !delayed {
  558. if err := t.Publish(v, m); err != nil {
  559. return nil, err
  560. }
  561. if s.firebaseClient != nil && firebase {
  562. go s.sendToFirebase(v, m)
  563. }
  // The email counter is incremented before the send goroutine runs, i.e.
  // regardless of whether sending later succeeds.
  564. if s.smtpSender != nil && email != "" {
  565. v.IncrementEmails()
  566. go s.sendEmail(v, m, email)
  567. }
  568. if s.config.UpstreamBaseURL != "" {
  569. go s.forwardPollRequest(v, m)
  570. }
  571. } else {
  572. log.Debug("%s Message delayed, will process later", logMessagePrefix(v, m))
  573. }
  // NOTE(review): a cache failure here returns an error to the publisher even
  // though a non-delayed message has already been published to subscribers.
  574. if cache {
  575. log.Debug("%s Adding message to cache", logMessagePrefix(v, m))
  576. if err := s.messageCache.AddMessage(m); err != nil {
  577. return nil, err
  578. }
  579. }
  580. v.IncrementMessages()
  581. if s.userManager != nil && v.user != nil {
  582. s.userManager.EnqueueStats(v.user)
  583. }
  // Server-wide message counter, guarded by s.mu.
  584. s.mu.Lock()
  585. s.messages++
  586. s.mu.Unlock()
  587. return m, nil
  588. }
  589. func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visitor) error {
  590. m, err := s.handlePublishWithoutResponse(r, v)
  591. if err != nil {
  592. return err
  593. }
  594. return s.writeJSON(w, m)
  595. }
  596. func (s *Server) handlePublishMatrix(w http.ResponseWriter, r *http.Request, v *visitor) error {
  597. _, err := s.handlePublishWithoutResponse(r, v)
  598. if err != nil {
  599. return &errMatrix{pushKey: r.Header.Get(matrixPushKeyHeader), err: err}
  600. }
  601. return writeMatrixSuccess(w)
  602. }
  603. func (s *Server) sendToFirebase(v *visitor, m *message) {
  604. log.Debug("%s Publishing to Firebase", logMessagePrefix(v, m))
  605. if err := s.firebaseClient.Send(v, m); err != nil {
  606. if err == errFirebaseTemporarilyBanned {
  607. log.Debug("%s Unable to publish to Firebase: %v", logMessagePrefix(v, m), err.Error())
  608. } else {
  609. log.Warn("%s Unable to publish to Firebase: %v", logMessagePrefix(v, m), err.Error())
  610. }
  611. }
  612. }
  613. func (s *Server) sendEmail(v *visitor, m *message, email string) {
  614. log.Debug("%s Sending email to %s", logMessagePrefix(v, m), email)
  615. if err := s.smtpSender.Send(v, m, email); err != nil {
  616. log.Warn("%s Unable to send email to %s: %v", logMessagePrefix(v, m), email, err.Error())
  617. }
  618. }
  619. func (s *Server) forwardPollRequest(v *visitor, m *message) {
  620. topicURL := fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic)
  621. topicHash := fmt.Sprintf("%x", sha256.Sum256([]byte(topicURL)))
  622. forwardURL := fmt.Sprintf("%s/%s", s.config.UpstreamBaseURL, topicHash)
  623. log.Debug("%s Publishing poll request to %s", logMessagePrefix(v, m), forwardURL)
  624. req, err := http.NewRequest("POST", forwardURL, strings.NewReader(""))
  625. if err != nil {
  626. log.Warn("%s Unable to publish poll request: %v", logMessagePrefix(v, m), err.Error())
  627. return
  628. }
  629. req.Header.Set("X-Poll-ID", m.ID)
  630. var httpClient = &http.Client{
  631. Timeout: time.Second * 10,
  632. }
  633. response, err := httpClient.Do(req)
  634. if err != nil {
  635. log.Warn("%s Unable to publish poll request: %v", logMessagePrefix(v, m), err.Error())
  636. return
  637. } else if response.StatusCode != http.StatusOK {
  638. log.Warn("%s Unable to publish poll request, unexpected HTTP status: %d", logMessagePrefix(v, m), response.StatusCode)
  639. return
  640. }
  641. }
  642. func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (cache bool, firebase bool, email string, unifiedpush bool, err error) {
  643. cache = readBoolParam(r, true, "x-cache", "cache")
  644. firebase = readBoolParam(r, true, "x-firebase", "firebase")
  645. m.Title = readParam(r, "x-title", "title", "t")
  646. m.Click = readParam(r, "x-click", "click")
  647. icon := readParam(r, "x-icon", "icon")
  648. filename := readParam(r, "x-filename", "filename", "file", "f")
  649. attach := readParam(r, "x-attach", "attach", "a")
  650. if attach != "" || filename != "" {
  651. m.Attachment = &attachment{}
  652. }
  653. if filename != "" {
  654. m.Attachment.Name = filename
  655. }
  656. if attach != "" {
  657. if !urlRegex.MatchString(attach) {
  658. return false, false, "", false, errHTTPBadRequestAttachmentURLInvalid
  659. }
  660. m.Attachment.URL = attach
  661. if m.Attachment.Name == "" {
  662. u, err := url.Parse(m.Attachment.URL)
  663. if err == nil {
  664. m.Attachment.Name = path.Base(u.Path)
  665. if m.Attachment.Name == "." || m.Attachment.Name == "/" {
  666. m.Attachment.Name = ""
  667. }
  668. }
  669. }
  670. if m.Attachment.Name == "" {
  671. m.Attachment.Name = "attachment"
  672. }
  673. }
  674. if icon != "" {
  675. if !urlRegex.MatchString(icon) {
  676. return false, false, "", false, errHTTPBadRequestIconURLInvalid
  677. }
  678. m.Icon = icon
  679. }
  680. email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e")
  681. if email != "" {
  682. if err := v.EmailAllowed(); err != nil {
  683. return false, false, "", false, errHTTPTooManyRequestsLimitEmails
  684. }
  685. }
  686. if s.smtpSender == nil && email != "" {
  687. return false, false, "", false, errHTTPBadRequestEmailDisabled
  688. }
  689. messageStr := strings.ReplaceAll(readParam(r, "x-message", "message", "m"), "\\n", "\n")
  690. if messageStr != "" {
  691. m.Message = messageStr
  692. }
  693. m.Priority, err = util.ParsePriority(readParam(r, "x-priority", "priority", "prio", "p"))
  694. if err != nil {
  695. return false, false, "", false, errHTTPBadRequestPriorityInvalid
  696. }
  697. tagsStr := readParam(r, "x-tags", "tags", "tag", "ta")
  698. if tagsStr != "" {
  699. m.Tags = make([]string, 0)
  700. for _, s := range util.SplitNoEmpty(tagsStr, ",") {
  701. m.Tags = append(m.Tags, strings.TrimSpace(s))
  702. }
  703. }
  704. delayStr := readParam(r, "x-delay", "delay", "x-at", "at", "x-in", "in")
  705. if delayStr != "" {
  706. if !cache {
  707. return false, false, "", false, errHTTPBadRequestDelayNoCache
  708. }
  709. if email != "" {
  710. return false, false, "", false, errHTTPBadRequestDelayNoEmail // we cannot store the email address (yet)
  711. }
  712. delay, err := util.ParseFutureTime(delayStr, time.Now())
  713. if err != nil {
  714. return false, false, "", false, errHTTPBadRequestDelayCannotParse
  715. } else if delay.Unix() < time.Now().Add(s.config.MinDelay).Unix() {
  716. return false, false, "", false, errHTTPBadRequestDelayTooSmall
  717. } else if delay.Unix() > time.Now().Add(s.config.MaxDelay).Unix() {
  718. return false, false, "", false, errHTTPBadRequestDelayTooLarge
  719. }
  720. m.Time = delay.Unix()
  721. m.Sender = v.ip // Important for rate limiting
  722. }
  723. actionsStr := readParam(r, "x-actions", "actions", "action")
  724. if actionsStr != "" {
  725. m.Actions, err = parseActions(actionsStr)
  726. if err != nil {
  727. return false, false, "", false, wrapErrHTTP(errHTTPBadRequestActionsInvalid, err.Error())
  728. }
  729. }
  730. unifiedpush = readBoolParam(r, false, "x-unifiedpush", "unifiedpush", "up") // see GET too!
  731. if unifiedpush {
  732. firebase = false
  733. unifiedpush = true
  734. }
  735. m.PollID = readParam(r, "x-poll-id", "poll-id")
  736. if m.PollID != "" {
  737. unifiedpush = false
  738. cache = false
  739. email = ""
  740. }
  741. return cache, firebase, email, unifiedpush, nil
  742. }
  743. // handlePublishBody consumes the PUT/POST body and decides whether the body is an attachment or the message.
  744. //
  745. // 1. curl -X POST -H "Poll: 1234" ntfy.sh/...
  746. // If a message is flagged as poll request, the body does not matter and is discarded
  747. // 2. curl -T somebinarydata.bin "ntfy.sh/mytopic?up=1"
  748. // If body is binary, encode as base64, if not do not encode
  749. // 3. curl -H "Attach: http://example.com/file.jpg" ntfy.sh/mytopic
  750. // Body must be a message, because we attached an external URL
  751. // 4. curl -T short.txt -H "Filename: short.txt" ntfy.sh/mytopic
  752. // Body must be attachment, because we passed a filename
  753. // 5. curl -T file.txt ntfy.sh/mytopic
  754. // If file.txt is <= 4096 (message limit) and valid UTF-8, treat it as a message
  755. // 6. curl -T file.txt ntfy.sh/mytopic
  756. // If file.txt is > message limit, treat it as an attachment
  757. func (s *Server) handlePublishBody(r *http.Request, v *visitor, m *message, body *util.PeekedReadCloser, unifiedpush bool) error {
  758. if m.Event == pollRequestEvent { // Case 1
  759. return s.handleBodyDiscard(body)
  760. } else if unifiedpush {
  761. return s.handleBodyAsMessageAutoDetect(m, body) // Case 2
  762. } else if m.Attachment != nil && m.Attachment.URL != "" {
  763. return s.handleBodyAsTextMessage(m, body) // Case 3
  764. } else if m.Attachment != nil && m.Attachment.Name != "" {
  765. return s.handleBodyAsAttachment(r, v, m, body) // Case 4
  766. } else if !body.LimitReached && utf8.Valid(body.PeekedBytes) {
  767. return s.handleBodyAsTextMessage(m, body) // Case 5
  768. }
  769. return s.handleBodyAsAttachment(r, v, m, body) // Case 6
  770. }
  771. func (s *Server) handleBodyDiscard(body *util.PeekedReadCloser) error {
  772. _, err := io.Copy(io.Discard, body)
  773. _ = body.Close()
  774. return err
  775. }
  776. func (s *Server) handleBodyAsMessageAutoDetect(m *message, body *util.PeekedReadCloser) error {
  777. if utf8.Valid(body.PeekedBytes) {
  778. m.Message = string(body.PeekedBytes) // Do not trim
  779. } else {
  780. m.Message = base64.StdEncoding.EncodeToString(body.PeekedBytes)
  781. m.Encoding = encodingBase64
  782. }
  783. return nil
  784. }
  785. func (s *Server) handleBodyAsTextMessage(m *message, body *util.PeekedReadCloser) error {
  786. if !utf8.Valid(body.PeekedBytes) {
  787. return errHTTPBadRequestMessageNotUTF8
  788. }
  789. if len(body.PeekedBytes) > 0 { // Empty body should not override message (publish via GET!)
  790. m.Message = strings.TrimSpace(string(body.PeekedBytes)) // Truncates the message to the peek limit if required
  791. }
  792. if m.Attachment != nil && m.Attachment.Name != "" && m.Message == "" {
  793. m.Message = fmt.Sprintf(defaultAttachmentMessage, m.Attachment.Name)
  794. }
  795. return nil
  796. }
  797. func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message, body *util.PeekedReadCloser) error {
  798. if s.fileCache == nil || s.config.BaseURL == "" || s.config.AttachmentCacheDir == "" {
  799. return errHTTPBadRequestAttachmentsDisallowed
  800. }
  801. vinfo, err := v.Info()
  802. if err != nil {
  803. return err
  804. }
  805. attachmentExpiry := time.Now().Add(vinfo.Limits.AttachmentExpiryDuration).Unix()
  806. if m.Time > attachmentExpiry {
  807. return errHTTPBadRequestAttachmentsExpiryBeforeDelivery
  808. }
  809. contentLengthStr := r.Header.Get("Content-Length")
  810. if contentLengthStr != "" { // Early "do-not-trust" check, hard limit see below
  811. contentLength, err := strconv.ParseInt(contentLengthStr, 10, 64)
  812. if err == nil && (contentLength > vinfo.Stats.AttachmentTotalSizeRemaining || contentLength > vinfo.Limits.AttachmentFileSizeLimit) {
  813. return errHTTPEntityTooLargeAttachment
  814. }
  815. }
  816. if m.Attachment == nil {
  817. m.Attachment = &attachment{}
  818. }
  819. var ext string
  820. m.Sender = v.ip // Important for attachment rate limiting
  821. m.Attachment.Expires = attachmentExpiry
  822. m.Attachment.Type, ext = util.DetectContentType(body.PeekedBytes, m.Attachment.Name)
  823. m.Attachment.URL = fmt.Sprintf("%s/file/%s%s", s.config.BaseURL, m.ID, ext)
  824. if m.Attachment.Name == "" {
  825. m.Attachment.Name = fmt.Sprintf("attachment%s", ext)
  826. }
  827. if m.Message == "" {
  828. m.Message = fmt.Sprintf(defaultAttachmentMessage, m.Attachment.Name)
  829. }
  830. limiters := []util.Limiter{
  831. v.BandwidthLimiter(),
  832. util.NewFixedLimiter(vinfo.Limits.AttachmentFileSizeLimit),
  833. util.NewFixedLimiter(vinfo.Stats.AttachmentTotalSizeRemaining),
  834. }
  835. fmt.Printf("limiters = %#v\nv = %#v\n", limiters, v)
  836. m.Attachment.Size, err = s.fileCache.Write(m.ID, body, limiters...)
  837. if err == util.ErrLimitReached {
  838. return errHTTPEntityTooLargeAttachment
  839. } else if err != nil {
  840. return err
  841. }
  842. return nil
  843. }
  844. func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request, v *visitor) error {
  845. encoder := func(msg *message) (string, error) {
  846. var buf bytes.Buffer
  847. if err := json.NewEncoder(&buf).Encode(&msg); err != nil {
  848. return "", err
  849. }
  850. return buf.String(), nil
  851. }
  852. return s.handleSubscribeHTTP(w, r, v, "application/x-ndjson", encoder)
  853. }
  854. func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request, v *visitor) error {
  855. encoder := func(msg *message) (string, error) {
  856. var buf bytes.Buffer
  857. if err := json.NewEncoder(&buf).Encode(&msg); err != nil {
  858. return "", err
  859. }
  860. if msg.Event != messageEvent {
  861. return fmt.Sprintf("event: %s\ndata: %s\n", msg.Event, buf.String()), nil // Browser's .onmessage() does not fire on this!
  862. }
  863. return fmt.Sprintf("data: %s\n", buf.String()), nil
  864. }
  865. return s.handleSubscribeHTTP(w, r, v, "text/event-stream", encoder)
  866. }
  867. func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request, v *visitor) error {
  868. encoder := func(msg *message) (string, error) {
  869. if msg.Event == messageEvent { // only handle default events
  870. return strings.ReplaceAll(msg.Message, "\n", " ") + "\n", nil
  871. }
  872. return "\n", nil // "keepalive" and "open" events just send an empty line
  873. }
  874. return s.handleSubscribeHTTP(w, r, v, "text/plain", encoder)
  875. }
// handleSubscribeHTTP implements the streaming HTTP subscribe endpoints (used by
// handleSubscribeJSON, handleSubscribeSSE and handleSubscribeRaw). Every message sent to the
// client is converted to text by encoder and written to w, served with the given content type.
//
// After checking the visitor's subscription limit and parsing the topics and subscribe
// parameters, it either:
//   - (poll) sends the cached messages selected by since/scheduled and returns, or
//   - subscribes to all topics, sends an "open" message and the cached messages selected by
//     since, and then blocks until the request context is done, sending a keepalive message
//     every config.KeepaliveInterval. A failed write ends the function with that error.
func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *visitor, contentType string, encoder messageEncoder) error {
	log.Debug("%s HTTP stream connection opened", logHTTPPrefix(v, r))
	defer log.Debug("%s HTTP stream connection closed", logHTTPPrefix(v, r))
	if err := v.SubscriptionAllowed(); err != nil {
		return errHTTPTooManyRequestsLimitSubscriptions
	}
	defer v.RemoveSubscription() // presumably releases what SubscriptionAllowed reserved
	topics, topicsStr, err := s.topicsFromPath(r.URL.Path)
	if err != nil {
		return err
	}
	poll, since, scheduled, filters, err := parseSubscribeParams(r)
	if err != nil {
		return err
	}
	// wlock serializes writes to w (see sub below)
	var wlock sync.Mutex
	defer func() {
		// Hack: This is the fix for a horrible data race that I have not been able to figure out in quite some time.
		// It appears to be happening when the Go HTTP code reads from the socket when closing the request (i.e. AFTER
		// this function returns), and causes a data race with the ResponseWriter. Locking wlock here silences the
		// data race detector. See https://github.com/binwiederhier/ntfy/issues/338#issuecomment-1163425889.
		wlock.TryLock()
	}()
	// sub writes one message to the client, skipping messages rejected by the query filters.
	// It is called from this function (open, old and keepalive messages) and is also
	// registered with every topic via t.Subscribe, so it may presumably be invoked
	// concurrently by publishers; hence the lock around the write.
	sub := func(v *visitor, msg *message) error {
		if !filters.Pass(msg) {
			return nil
		}
		m, err := encoder(msg)
		if err != nil {
			return err
		}
		wlock.Lock()
		defer wlock.Unlock()
		if _, err := w.Write([]byte(m)); err != nil {
			return err
		}
		if fl, ok := w.(http.Flusher); ok { // push the data out immediately rather than buffering it
			fl.Flush()
		}
		return nil
	}
	w.Header().Set("Access-Control-Allow-Origin", "*")            // CORS, allow cross-origin requests
	w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset!
	if poll {
		return s.sendOldMessages(topics, since, scheduled, v, sub)
	}
	subscriberIDs := make([]int, 0)
	for _, t := range topics {
		subscriberIDs = append(subscriberIDs, t.Subscribe(sub))
	}
	defer func() {
		for i, subscriberID := range subscriberIDs {
			topics[i].Unsubscribe(subscriberID) // Order! subscriberIDs[i] was returned by topics[i].Subscribe
		}
	}()
	if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message
		return err
	}
	if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
		return err
	}
	// Block until the client goes away; a keepalive is sent after each quiet KeepaliveInterval
	for {
		select {
		case <-r.Context().Done():
			return nil
		case <-time.After(s.config.KeepaliveInterval):
			log.Trace("%s Sending keepalive message", logHTTPPrefix(v, r))
			v.Keepalive()
			if err := sub(v, newKeepaliveMessage(topicsStr)); err != nil { // Send keepalive message
				return err
			}
		}
	}
}
// handleSubscribeWS implements the WebSocket subscribe endpoint. It requires an
// "Upgrade: websocket" header and enforces the visitor's subscription limit, then parses
// the topics and subscribe parameters and upgrades the connection.
//
// Two goroutines run in an errgroup:
//   - a reader that ignores incoming data and extends the read deadline
//     (KeepaliveInterval + wsPongWait) each time a pong arrives, returning the first read error;
//   - a pinger that sends a WebSocket ping every KeepaliveInterval until the group's
//     context is done or a ping fails.
//
// Messages are written as JSON by sub. Writes are serialized with the pings via wlock.
// With poll set, only cached messages are sent before returning. Otherwise it subscribes to
// all topics, sends the "open" message and the cached messages, and waits for the errgroup.
// Normal, going-away and abnormal (1006) close errors are reported as a clean disconnect (nil).
func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *visitor) error {
	if strings.ToLower(r.Header.Get("Upgrade")) != "websocket" {
		return errHTTPBadRequestWebSocketsUpgradeHeaderMissing
	}
	if err := v.SubscriptionAllowed(); err != nil {
		return errHTTPTooManyRequestsLimitSubscriptions
	}
	defer v.RemoveSubscription() // presumably releases what SubscriptionAllowed reserved
	log.Debug("%s WebSocket connection opened", logHTTPPrefix(v, r))
	defer log.Debug("%s WebSocket connection closed", logHTTPPrefix(v, r))
	topics, topicsStr, err := s.topicsFromPath(r.URL.Path)
	if err != nil {
		return err
	}
	poll, since, scheduled, filters, err := parseSubscribeParams(r)
	if err != nil {
		return err
	}
	upgrader := &websocket.Upgrader{
		ReadBufferSize:  wsBufferSize,
		WriteBufferSize: wsBufferSize,
		CheckOrigin: func(r *http.Request) bool {
			return true // We're open for business!
		},
	}
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		return err
	}
	// Closing the connection also makes the reader goroutine's NextReader fail, which (via
	// errgroup.WithContext) cancels ctx and thereby stops the pinger.
	defer conn.Close()
	var wlock sync.Mutex // serializes all writes to conn (pings and messages)
	g, ctx := errgroup.WithContext(context.Background())
	g.Go(func() error {
		pongWait := s.config.KeepaliveInterval + wsPongWait
		conn.SetReadLimit(wsReadLimit)
		if err := conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil {
			return err
		}
		conn.SetPongHandler(func(appData string) error {
			log.Trace("%s Received WebSocket pong", logHTTPPrefix(v, r))
			return conn.SetReadDeadline(time.Now().Add(pongWait))
		})
		// Incoming data is not used; reading is only needed to process control frames (pongs,
		// close) and to detect a dead connection via the read deadline.
		for {
			_, _, err := conn.NextReader()
			if err != nil {
				return err
			}
		}
	})
	g.Go(func() error {
		ping := func() error {
			wlock.Lock()
			defer wlock.Unlock()
			if err := conn.SetWriteDeadline(time.Now().Add(wsWriteWait)); err != nil {
				return err
			}
			log.Trace("%s Sending WebSocket ping", logHTTPPrefix(v, r))
			return conn.WriteMessage(websocket.PingMessage, nil)
		}
		for {
			select {
			case <-ctx.Done():
				return nil
			case <-time.After(s.config.KeepaliveInterval):
				v.Keepalive()
				if err := ping(); err != nil {
					return err
				}
			}
		}
	})
	// sub writes one message as JSON, skipping messages rejected by the query filters
	sub := func(v *visitor, msg *message) error {
		if !filters.Pass(msg) {
			return nil
		}
		wlock.Lock()
		defer wlock.Unlock()
		if err := conn.SetWriteDeadline(time.Now().Add(wsWriteWait)); err != nil {
			return err
		}
		return conn.WriteJSON(msg)
	}
	// NOTE(review): the upgrade response has already been written by upgrader.Upgrade above,
	// so setting a response header here likely has no effect.
	w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
	if poll {
		return s.sendOldMessages(topics, since, scheduled, v, sub)
	}
	subscriberIDs := make([]int, 0)
	for _, t := range topics {
		subscriberIDs = append(subscriberIDs, t.Subscribe(sub))
	}
	defer func() {
		for i, subscriberID := range subscriberIDs {
			topics[i].Unsubscribe(subscriberID) // Order! subscriberIDs[i] was returned by topics[i].Subscribe
		}
	}()
	if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message
		return err
	}
	if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
		return err
	}
	// Block until the reader or pinger fails (e.g. the client disconnects)
	err = g.Wait()
	if err != nil && websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
		log.Trace("%s WebSocket connection closed: %s", logHTTPPrefix(v, r), err.Error())
		return nil // Normal closures are not errors; note: "1006 (abnormal closure)" is treated as normal, because people disconnect a lot
	}
	return err
}
  1058. func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, scheduled bool, filters *queryFilter, err error) {
  1059. poll = readBoolParam(r, false, "x-poll", "poll", "po")
  1060. scheduled = readBoolParam(r, false, "x-scheduled", "scheduled", "sched")
  1061. since, err = parseSince(r, poll)
  1062. if err != nil {
  1063. return
  1064. }
  1065. filters, err = parseQueryFilters(r)
  1066. if err != nil {
  1067. return
  1068. }
  1069. return
  1070. }
  1071. // sendOldMessages selects old messages from the messageCache and calls sub for each of them. It uses since as the
  1072. // marker, returning only messages that are newer than the marker.
  1073. func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, v *visitor, sub subscriber) error {
  1074. if since.IsNone() {
  1075. return nil
  1076. }
  1077. messages := make([]*message, 0)
  1078. for _, t := range topics {
  1079. topicMessages, err := s.messageCache.Messages(t.ID, since, scheduled)
  1080. if err != nil {
  1081. return err
  1082. }
  1083. messages = append(messages, topicMessages...)
  1084. }
  1085. sort.Slice(messages, func(i, j int) bool {
  1086. return messages[i].Time < messages[j].Time
  1087. })
  1088. for _, m := range messages {
  1089. if err := sub(v, m); err != nil {
  1090. return err
  1091. }
  1092. }
  1093. return nil
  1094. }
  1095. // parseSince returns a timestamp identifying the time span from which cached messages should be received.
  1096. //
  1097. // Values in the "since=..." parameter can be either a unix timestamp or a duration (e.g. 12h), or
  1098. // "all" for all messages.
  1099. func parseSince(r *http.Request, poll bool) (sinceMarker, error) {
  1100. since := readParam(r, "x-since", "since", "si")
  1101. // Easy cases (empty, all, none)
  1102. if since == "" {
  1103. if poll {
  1104. return sinceAllMessages, nil
  1105. }
  1106. return sinceNoMessages, nil
  1107. } else if since == "all" {
  1108. return sinceAllMessages, nil
  1109. } else if since == "none" {
  1110. return sinceNoMessages, nil
  1111. }
  1112. // ID, timestamp, duration
  1113. if validMessageID(since) {
  1114. return newSinceID(since), nil
  1115. } else if s, err := strconv.ParseInt(since, 10, 64); err == nil {
  1116. return newSinceTime(s), nil
  1117. } else if d, err := time.ParseDuration(since); err == nil {
  1118. return newSinceTime(time.Now().Add(-1 * d).Unix()), nil
  1119. }
  1120. return sinceNoMessages, errHTTPBadRequestSinceInvalid
  1121. }
  1122. func (s *Server) handleOptions(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
  1123. w.Header().Set("Access-Control-Allow-Methods", "GET, PUT, POST, PATCH, DELETE")
  1124. w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
  1125. w.Header().Set("Access-Control-Allow-Headers", "*") // CORS, allow auth via JS // FIXME is this terrible?
  1126. return nil
  1127. }
  1128. func (s *Server) topicFromPath(path string) (*topic, error) {
  1129. parts := strings.Split(path, "/")
  1130. if len(parts) < 2 {
  1131. return nil, errHTTPBadRequestTopicInvalid
  1132. }
  1133. topics, err := s.topicsFromIDs(parts[1])
  1134. if err != nil {
  1135. return nil, err
  1136. }
  1137. return topics[0], nil
  1138. }
  1139. func (s *Server) topicsFromPath(path string) ([]*topic, string, error) {
  1140. parts := strings.Split(path, "/")
  1141. if len(parts) < 2 {
  1142. return nil, "", errHTTPBadRequestTopicInvalid
  1143. }
  1144. topicIDs := util.SplitNoEmpty(parts[1], ",")
  1145. topics, err := s.topicsFromIDs(topicIDs...)
  1146. if err != nil {
  1147. return nil, "", errHTTPBadRequestTopicInvalid
  1148. }
  1149. return topics, parts[1], nil
  1150. }
  1151. func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) {
  1152. s.mu.Lock()
  1153. defer s.mu.Unlock()
  1154. topics := make([]*topic, 0)
  1155. for _, id := range ids {
  1156. if util.Contains(disallowedTopics, id) {
  1157. return nil, errHTTPBadRequestTopicDisallowed
  1158. }
  1159. if _, ok := s.topics[id]; !ok {
  1160. if len(s.topics) >= s.config.TotalTopicLimit {
  1161. return nil, errHTTPTooManyRequestsLimitTotalTopics
  1162. }
  1163. s.topics[id] = newTopic(id)
  1164. }
  1165. topics = append(topics, s.topics[id])
  1166. }
  1167. return topics, nil
  1168. }
  1169. func (s *Server) execManager() {
  1170. log.Debug("Manager: Starting")
  1171. defer log.Debug("Manager: Finished")
  1172. // WARNING: Make sure to only selectively lock with the mutex, and be aware that this
  1173. // there is no mutex for the entire function.
  1174. // Expire visitors from rate visitors map
  1175. s.mu.Lock()
  1176. staleVisitors := 0
  1177. for ip, v := range s.visitors {
  1178. if v.Stale() {
  1179. log.Trace("Deleting stale visitor %s", v.ip)
  1180. delete(s.visitors, ip)
  1181. staleVisitors++
  1182. }
  1183. }
  1184. s.mu.Unlock()
  1185. log.Debug("Manager: Deleted %d stale visitor(s)", staleVisitors)
  1186. // Delete expired user tokens and users
  1187. if s.userManager != nil {
  1188. if err := s.userManager.RemoveExpiredTokens(); err != nil {
  1189. log.Warn("Error expiring user tokens: %s", err.Error())
  1190. }
  1191. if err := s.userManager.RemoveDeletedUsers(); err != nil {
  1192. log.Warn("Error deleting soft-deleted users: %s", err.Error())
  1193. }
  1194. }
  1195. // Delete expired attachments
  1196. if s.fileCache != nil {
  1197. ids, err := s.messageCache.AttachmentsExpired()
  1198. if err != nil {
  1199. log.Warn("Manager: Error retrieving expired attachments: %s", err.Error())
  1200. } else if len(ids) > 0 {
  1201. if log.IsDebug() {
  1202. log.Debug("Manager: Deleting attachments %s", strings.Join(ids, ", "))
  1203. }
  1204. if err := s.fileCache.Remove(ids...); err != nil {
  1205. log.Warn("Manager: Error deleting attachments: %s", err.Error())
  1206. }
  1207. if err := s.messageCache.MarkAttachmentsDeleted(ids...); err != nil {
  1208. log.Warn("Manager: Error marking attachments deleted: %s", err.Error())
  1209. }
  1210. } else {
  1211. log.Debug("Manager: No expired attachments to delete")
  1212. }
  1213. }
  1214. // Prune messages
  1215. log.Debug("Manager: Pruning messages")
  1216. expiredMessageIDs, err := s.messageCache.MessagesExpired()
  1217. if err != nil {
  1218. log.Warn("Manager: Error retrieving expired messages: %s", err.Error())
  1219. } else if len(expiredMessageIDs) > 0 {
  1220. if err := s.fileCache.Remove(expiredMessageIDs...); err != nil {
  1221. log.Warn("Manager: Error deleting attachments for expired messages: %s", err.Error())
  1222. }
  1223. if err := s.messageCache.DeleteMessages(expiredMessageIDs...); err != nil {
  1224. log.Warn("Manager: Error marking attachments deleted: %s", err.Error())
  1225. }
  1226. } else {
  1227. log.Debug("Manager: No expired messages to delete")
  1228. }
  1229. // Message count per topic
  1230. var messages int
  1231. messageCounts, err := s.messageCache.MessageCounts()
  1232. if err != nil {
  1233. log.Warn("Manager: Cannot get message counts: %s", err.Error())
  1234. messageCounts = make(map[string]int) // Empty, so we can continue
  1235. }
  1236. for _, count := range messageCounts {
  1237. messages += count
  1238. }
  1239. // Remove subscriptions without subscribers
  1240. s.mu.Lock()
  1241. var subscribers int
  1242. for _, t := range s.topics {
  1243. subs := t.SubscribersCount()
  1244. msgs, exists := messageCounts[t.ID]
  1245. if subs == 0 && (!exists || msgs == 0) {
  1246. log.Trace("Deleting empty topic %s", t.ID)
  1247. delete(s.topics, t.ID)
  1248. continue
  1249. }
  1250. subscribers += subs
  1251. }
  1252. s.mu.Unlock()
  1253. // Mail stats
  1254. var receivedMailTotal, receivedMailSuccess, receivedMailFailure int64
  1255. if s.smtpServerBackend != nil {
  1256. receivedMailTotal, receivedMailSuccess, receivedMailFailure = s.smtpServerBackend.Counts()
  1257. }
  1258. var sentMailTotal, sentMailSuccess, sentMailFailure int64
  1259. if s.smtpSender != nil {
  1260. sentMailTotal, sentMailSuccess, sentMailFailure = s.smtpSender.Counts()
  1261. }
  1262. // Print stats
  1263. s.mu.Lock()
  1264. messagesCount, topicsCount, visitorsCount := s.messages, len(s.topics), len(s.visitors)
  1265. s.mu.Unlock()
  1266. log.Info("Stats: %d messages published, %d in cache, %d topic(s) active, %d subscriber(s), %d visitor(s), %d mails received (%d successful, %d failed), %d mails sent (%d successful, %d failed)",
  1267. messagesCount, messages, topicsCount, subscribers, visitorsCount,
  1268. receivedMailTotal, receivedMailSuccess, receivedMailFailure,
  1269. sentMailTotal, sentMailSuccess, sentMailFailure)
  1270. }
  1271. func (s *Server) runSMTPServer() error {
  1272. s.smtpServerBackend = newMailBackend(s.config, s.handle)
  1273. s.smtpServer = smtp.NewServer(s.smtpServerBackend)
  1274. s.smtpServer.Addr = s.config.SMTPServerListen
  1275. s.smtpServer.Domain = s.config.SMTPServerDomain
  1276. s.smtpServer.ReadTimeout = 10 * time.Second
  1277. s.smtpServer.WriteTimeout = 10 * time.Second
  1278. s.smtpServer.MaxMessageBytes = 1024 * 1024 // Must be much larger than message size (headers, multipart, etc.)
  1279. s.smtpServer.MaxRecipients = 1
  1280. s.smtpServer.AllowInsecureAuth = true
  1281. return s.smtpServer.ListenAndServe()
  1282. }
  1283. func (s *Server) runManager() {
  1284. for {
  1285. select {
  1286. case <-time.After(s.config.ManagerInterval):
  1287. s.execManager()
  1288. case <-s.closeChan:
  1289. return
  1290. }
  1291. }
  1292. }
  1293. // runStatsResetter runs once a day (usually midnight UTC) to reset all the visitor's message and
  1294. // email counters. The stats are used to display the counters in the web app, as well as for rate limiting.
  1295. func (s *Server) runStatsResetter() {
  1296. for {
  1297. runAt := util.NextOccurrenceUTC(s.config.VisitorStatsResetTime, time.Now())
  1298. timer := time.NewTimer(time.Until(runAt))
  1299. log.Debug("Stats resetter: Waiting until %v to reset visitor stats", runAt)
  1300. select {
  1301. case <-timer.C:
  1302. s.resetStats()
  1303. case <-s.closeChan:
  1304. timer.Stop()
  1305. return
  1306. }
  1307. }
  1308. }
  1309. func (s *Server) resetStats() {
  1310. log.Info("Resetting all visitor stats (daily task)")
  1311. s.mu.Lock()
  1312. defer s.mu.Unlock() // Includes the database query to avoid races with other processes
  1313. for _, v := range s.visitors {
  1314. v.ResetStats()
  1315. }
  1316. if s.userManager != nil {
  1317. if err := s.userManager.ResetStats(); err != nil {
  1318. log.Warn("Failed to write to database: %s", err.Error())
  1319. }
  1320. }
  1321. }
  1322. func (s *Server) runFirebaseKeepaliver() {
  1323. if s.firebaseClient == nil {
  1324. return
  1325. }
  1326. v := newVisitor(s.config, s.messageCache, s.userManager, netip.IPv4Unspecified(), nil) // Background process, not a real visitor, uses IP 0.0.0.0
  1327. for {
  1328. select {
  1329. case <-time.After(s.config.FirebaseKeepaliveInterval):
  1330. s.sendToFirebase(v, newKeepaliveMessage(firebaseControlTopic))
  1331. case <-time.After(s.config.FirebasePollInterval):
  1332. s.sendToFirebase(v, newKeepaliveMessage(firebasePollTopic))
  1333. case <-s.closeChan:
  1334. return
  1335. }
  1336. }
  1337. }
  1338. func (s *Server) runDelayedSender() {
  1339. for {
  1340. select {
  1341. case <-time.After(s.config.DelayedSenderInterval):
  1342. if err := s.sendDelayedMessages(); err != nil {
  1343. log.Warn("Error sending delayed messages: %s", err.Error())
  1344. }
  1345. case <-s.closeChan:
  1346. return
  1347. }
  1348. }
  1349. }
  1350. func (s *Server) sendDelayedMessages() error {
  1351. messages, err := s.messageCache.MessagesDue()
  1352. if err != nil {
  1353. return err
  1354. }
  1355. for _, m := range messages {
  1356. var v *visitor
  1357. if s.userManager != nil && m.User != "" {
  1358. u, err := s.userManager.User(m.User)
  1359. if err != nil {
  1360. log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error())
  1361. continue
  1362. }
  1363. v = s.visitorFromUser(u, m.Sender)
  1364. } else {
  1365. v = s.visitorFromIP(m.Sender)
  1366. }
  1367. if err := s.sendDelayedMessage(v, m); err != nil {
  1368. log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error())
  1369. }
  1370. }
  1371. return nil
  1372. }
  1373. func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
  1374. log.Debug("%s Sending delayed message", logMessagePrefix(v, m))
  1375. s.mu.Lock()
  1376. t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published
  1377. s.mu.Unlock()
  1378. if ok {
  1379. go func() {
  1380. // We do not rate-limit messages here, since we've rate limited them in the PUT/POST handler
  1381. if err := t.Publish(v, m); err != nil {
  1382. log.Warn("%s Unable to publish message: %v", logMessagePrefix(v, m), err.Error())
  1383. }
  1384. }()
  1385. }
  1386. if s.firebaseClient != nil { // Firebase subscribers may not show up in topics map
  1387. go s.sendToFirebase(v, m)
  1388. }
  1389. if s.config.UpstreamBaseURL != "" {
  1390. go s.forwardPollRequest(v, m)
  1391. }
  1392. if err := s.messageCache.MarkPublished(m); err != nil {
  1393. return err
  1394. }
  1395. return nil
  1396. }
  1397. func (s *Server) limitRequests(next handleFunc) handleFunc {
  1398. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  1399. if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
  1400. return next(w, r, v)
  1401. } else if err := v.RequestAllowed(); err != nil {
  1402. return errHTTPTooManyRequestsLimitRequests
  1403. }
  1404. return next(w, r, v)
  1405. }
  1406. }
  1407. // transformBodyJSON peeks the request body, reads the JSON, and converts it to headers
  1408. // before passing it on to the next handler. This is meant to be used in combination with handlePublish.
  1409. func (s *Server) transformBodyJSON(next handleFunc) handleFunc {
  1410. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  1411. m, err := readJSONWithLimit[publishMessage](r.Body, s.config.MessageLimit*2) // 2x to account for JSON format overhead
  1412. if err != nil {
  1413. return err
  1414. }
  1415. if !topicRegex.MatchString(m.Topic) {
  1416. return errHTTPBadRequestTopicInvalid
  1417. }
  1418. if m.Message == "" {
  1419. m.Message = emptyMessageBody
  1420. }
  1421. r.URL.Path = "/" + m.Topic
  1422. r.Body = io.NopCloser(strings.NewReader(m.Message))
  1423. if m.Title != "" {
  1424. r.Header.Set("X-Title", m.Title)
  1425. }
  1426. if m.Priority != 0 {
  1427. r.Header.Set("X-Priority", fmt.Sprintf("%d", m.Priority))
  1428. }
  1429. if m.Tags != nil && len(m.Tags) > 0 {
  1430. r.Header.Set("X-Tags", strings.Join(m.Tags, ","))
  1431. }
  1432. if m.Attach != "" {
  1433. r.Header.Set("X-Attach", m.Attach)
  1434. }
  1435. if m.Filename != "" {
  1436. r.Header.Set("X-Filename", m.Filename)
  1437. }
  1438. if m.Click != "" {
  1439. r.Header.Set("X-Click", m.Click)
  1440. }
  1441. if m.Icon != "" {
  1442. r.Header.Set("X-Icon", m.Icon)
  1443. }
  1444. if len(m.Actions) > 0 {
  1445. actionsStr, err := json.Marshal(m.Actions)
  1446. if err != nil {
  1447. return errHTTPBadRequestMessageJSONInvalid
  1448. }
  1449. r.Header.Set("X-Actions", string(actionsStr))
  1450. }
  1451. if m.Email != "" {
  1452. r.Header.Set("X-Email", m.Email)
  1453. }
  1454. if m.Delay != "" {
  1455. r.Header.Set("X-Delay", m.Delay)
  1456. }
  1457. return next(w, r, v)
  1458. }
  1459. }
  1460. func (s *Server) transformMatrixJSON(next handleFunc) handleFunc {
  1461. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  1462. newRequest, err := newRequestFromMatrixJSON(r, s.config.BaseURL, s.config.MessageLimit)
  1463. if err != nil {
  1464. return err
  1465. }
  1466. if err := next(w, newRequest, v); err != nil {
  1467. return &errMatrix{pushKey: newRequest.Header.Get(matrixPushKeyHeader), err: err}
  1468. }
  1469. return nil
  1470. }
  1471. }
  1472. func (s *Server) authorizeTopicWrite(next handleFunc) handleFunc {
  1473. return s.autorizeTopic(next, user.PermissionWrite)
  1474. }
  1475. func (s *Server) authorizeTopicRead(next handleFunc) handleFunc {
  1476. return s.autorizeTopic(next, user.PermissionRead)
  1477. }
  1478. func (s *Server) autorizeTopic(next handleFunc, perm user.Permission) handleFunc {
  1479. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  1480. if s.userManager == nil {
  1481. return next(w, r, v)
  1482. }
  1483. topics, _, err := s.topicsFromPath(r.URL.Path)
  1484. if err != nil {
  1485. return err
  1486. }
  1487. for _, t := range topics {
  1488. if err := s.userManager.Authorize(v.user, t.ID, perm); err != nil {
  1489. log.Info("unauthorized: %s", err.Error())
  1490. return errHTTPForbidden
  1491. }
  1492. }
  1493. return next(w, r, v)
  1494. }
  1495. }
  1496. // visitor creates or retrieves a rate.Limiter for the given visitor.
  1497. // Note that this function will always return a visitor, even if an error occurs.
  1498. func (s *Server) visitor(r *http.Request) (v *visitor, err error) {
  1499. ip := extractIPAddress(r, s.config.BehindProxy)
  1500. var u *user.User // may stay nil if no auth header!
  1501. if u, err = s.authenticate(r); err != nil {
  1502. log.Debug("authentication failed: %s", err.Error())
  1503. err = errHTTPUnauthorized // Always return visitor, even when error occurs!
  1504. }
  1505. if u != nil {
  1506. v = s.visitorFromUser(u, ip)
  1507. } else {
  1508. v = s.visitorFromIP(ip)
  1509. }
  1510. v.mu.Lock()
  1511. v.user = u
  1512. v.mu.Unlock()
  1513. return v, err // Always return visitor, even when error occurs!
  1514. }
  1515. // authenticate a user based on basic auth username/password (Authorization: Basic ...), or token auth (Authorization: Bearer ...).
  1516. // The Authorization header can be passed as a header or the ?auth=... query param. The latter is required only to
  1517. // support the WebSocket JavaScript class, which does not support passing headers during the initial request. The auth
  1518. // query param is effectively double base64 encoded. Its format is base64(Basic base64(user:pass)).
  1519. func (s *Server) authenticate(r *http.Request) (user *user.User, err error) {
  1520. value := strings.TrimSpace(r.Header.Get("Authorization"))
  1521. queryParam := readQueryParam(r, "authorization", "auth")
  1522. if queryParam != "" {
  1523. a, err := base64.RawURLEncoding.DecodeString(queryParam)
  1524. if err != nil {
  1525. return nil, err
  1526. }
  1527. value = strings.TrimSpace(string(a))
  1528. }
  1529. if value == "" {
  1530. return nil, nil
  1531. } else if s.userManager == nil {
  1532. return nil, errHTTPUnauthorized
  1533. }
  1534. if strings.HasPrefix(value, "Bearer") {
  1535. return s.authenticateBearerAuth(value)
  1536. }
  1537. return s.authenticateBasicAuth(r, value)
  1538. }
  1539. func (s *Server) authenticateBasicAuth(r *http.Request, value string) (user *user.User, err error) {
  1540. r.Header.Set("Authorization", value)
  1541. username, password, ok := r.BasicAuth()
  1542. if !ok {
  1543. return nil, errors.New("invalid basic auth")
  1544. }
  1545. return s.userManager.Authenticate(username, password)
  1546. }
  1547. func (s *Server) authenticateBearerAuth(value string) (user *user.User, err error) {
  1548. token := strings.TrimSpace(strings.TrimPrefix(value, "Bearer"))
  1549. return s.userManager.AuthenticateToken(token)
  1550. }
  1551. func (s *Server) visitorFromID(visitorID string, ip netip.Addr, user *user.User) *visitor {
  1552. s.mu.Lock()
  1553. defer s.mu.Unlock()
  1554. v, exists := s.visitors[visitorID]
  1555. if !exists {
  1556. s.visitors[visitorID] = newVisitor(s.config, s.messageCache, s.userManager, ip, user)
  1557. return s.visitors[visitorID]
  1558. }
  1559. v.Keepalive()
  1560. return v
  1561. }
  1562. func (s *Server) visitorFromIP(ip netip.Addr) *visitor {
  1563. return s.visitorFromID(fmt.Sprintf("ip:%s", ip.String()), ip, nil)
  1564. }
  1565. func (s *Server) visitorFromUser(user *user.User, ip netip.Addr) *visitor {
  1566. return s.visitorFromID(fmt.Sprintf("user:%s", user.ID), ip, user)
  1567. }
  1568. func (s *Server) writeJSON(w http.ResponseWriter, v any) error {
  1569. w.Header().Set("Content-Type", "application/json")
  1570. w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
  1571. if err := json.NewEncoder(w).Encode(v); err != nil {
  1572. return err
  1573. }
  1574. return nil
  1575. }