server.go 59 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673
  1. package server
  2. import (
  3. "bytes"
  4. "context"
  5. "crypto/sha256"
  6. "embed"
  7. "encoding/base64"
  8. "encoding/json"
  9. "errors"
  10. "fmt"
  11. "heckel.io/ntfy/user"
  12. "io"
  13. "net"
  14. "net/http"
  15. "net/netip"
  16. "net/url"
  17. "os"
  18. "path"
  19. "path/filepath"
  20. "regexp"
  21. "sort"
  22. "strconv"
  23. "strings"
  24. "sync"
  25. "time"
  26. "unicode/utf8"
  27. "heckel.io/ntfy/log"
  28. "github.com/emersion/go-smtp"
  29. "github.com/gorilla/websocket"
  30. "golang.org/x/sync/errgroup"
  31. "heckel.io/ntfy/util"
  32. )
  33. /*
  34. TODO
  35. --
  36. UAT results (round 1):
  37. - Security: Account re-creation leads to terrible behavior. Use user ID instead of user name for (a) visitor map, (b) messages.user column, (c) Stripe checkout session
  38. - Account: Changing password should confirm the old password (Thorben)
  39. - Signup: Re-add password confirmation (Thorben & deadcade)
  40. - Reservation: Kill existing subscribers when topic is reserved (deadcade)
  41. - Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
  42. - Reservation (UI): Ask for confirmation when removing reservation (deadcade)
  43. - Logging: Add detailed logging with username/customerID for all Stripe events (phil)
  44. races:
  45. - v.user --> see publishSyncEventAsync() test
  46. payments:
  47. - reconciliation
  48. delete messages + reserved topics on ResetTier delete attachments in access.go
  49. account deletion should delete messages and reservations and attachments
  50. Limits & rate limiting:
  51. rate limiting weirdness. wth is going on?
  52. bandwidth limit must be in tier
  53. users without tier: should the stats be persisted? are they meaningful? -> test that the visitor is based on the IP address!
  54. login/account endpoints
  55. when ResetStats() is run, reset messagesLimiter (and others)?
  56. Delete visitor when tier is changed to refresh rate limiters
  57. Make sure account endpoints make sense for admins
  58. UI:
  59. -
  60. - reservation icons
  61. - reservation table delete button: dialog "keep or delete messages?"
  62. - flicker of upgrade banner
  63. - JS constants
  64. Sync:
  65. - sync problems with "deleteAfter=0" and "displayName="
  66. Tests:
  67. - Payment endpoints (make mocks)
  68. - Message rate limiting and reset tests
  69. - Bandwidth limit test
  70. - test that the visitor is based on the IP address when a user has no tier
  71. */
// Server is the main server, providing the UI and API for ntfy.
//
// Instances are created with New and started with Run; Stop shuts the listeners
// down again. Fields marked "may be nil" are only populated when the matching
// configuration option is set (see New).
type Server struct {
	config            *Config
	httpServer        *http.Server // Set in Run when ListenHTTP is configured, closed in Stop
	httpsServer       *http.Server // Set in Run when ListenHTTPS is configured, closed in Stop
	unixListener      net.Listener // Set by Run's Unix-socket goroutine when ListenUnix is configured
	smtpServer        *smtp.Server // Closed in Stop if non-nil; presumably set by runSMTPServer (not visible here)
	smtpServerBackend *smtpBackend
	smtpSender        mailer                // May be nil: only set when SMTPSenderAddr is configured
	topics            map[string]*topic     // Initially loaded from the message cache in New
	visitors          map[string]*visitor   // ip:<ip> or user:<user>
	firebaseClient    *firebaseClient       // May be nil: only set when FirebaseKeyFile is configured
	messages          int64                 // Number of messages published; incremented under mu
	userManager       *user.Manager         // Might be nil!
	messageCache      *messageCache         // Database that stores the messages
	fileCache         *fileCache            // File system based cache that stores attachments; nil unless AttachmentCacheDir is set
	stripe            stripeAPI             // Stripe API, can be replaced with a mock; nil unless StripeSecretKey is set
	priceCache        *util.LookupCache[map[string]string] // Stripe price ID -> formatted price
	closeChan         chan bool             // Created in Run, closed in Stop; presumably watched by the background run* goroutines
	mu                sync.Mutex            // Guards at least the listener fields, closeChan and messages (others: TODO confirm)
}
// handleFunc extends the normal http.HandlerFunc to be able to easily return errors.
// In addition to the writer and request it receives the resolved visitor. A non-nil
// error returned from a handleFunc is rendered as an error response by Server.handle.
type handleFunc func(http.ResponseWriter, *http.Request, *visitor) error
  95. var (
  96. // If changed, don't forget to update Android App and auth_sqlite.go
  97. topicRegex = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`) // No /!
  98. topicPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}$`) // Regex must match JS & Android app!
  99. externalTopicPathRegex = regexp.MustCompile(`^/[^/]+\.[^/]+/[-_A-Za-z0-9]{1,64}$`) // Extended topic path, for web-app, e.g. /example.com/mytopic
  100. jsonPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/json$`)
  101. ssePathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/sse$`)
  102. rawPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/raw$`)
  103. wsPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/ws$`)
  104. authPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/auth$`)
  105. publishPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}/(publish|send|trigger)$`)
  106. webConfigPath = "/config.js"
  107. accountPath = "/account"
  108. matrixPushPath = "/_matrix/push/v1/notify"
  109. apiHealthPath = "/v1/health"
  110. apiTiers = "/v1/tiers"
  111. apiAccountPath = "/v1/account"
  112. apiAccountTokenPath = "/v1/account/token"
  113. apiAccountPasswordPath = "/v1/account/password"
  114. apiAccountSettingsPath = "/v1/account/settings"
  115. apiAccountSubscriptionPath = "/v1/account/subscription"
  116. apiAccountReservationPath = "/v1/account/reservation"
  117. apiAccountBillingPortalPath = "/v1/account/billing/portal"
  118. apiAccountBillingWebhookPath = "/v1/account/billing/webhook"
  119. apiAccountBillingSubscriptionPath = "/v1/account/billing/subscription"
  120. apiAccountBillingSubscriptionCheckoutSuccessTemplate = "/v1/account/billing/subscription/success/{CHECKOUT_SESSION_ID}"
  121. apiAccountBillingSubscriptionCheckoutSuccessRegex = regexp.MustCompile(`/v1/account/billing/subscription/success/(.+)$`)
  122. apiAccountReservationSingleRegex = regexp.MustCompile(`/v1/account/reservation/([-_A-Za-z0-9]{1,64})$`)
  123. apiAccountSubscriptionSingleRegex = regexp.MustCompile(`^/v1/account/subscription/([-_A-Za-z0-9]{16})$`)
  124. staticRegex = regexp.MustCompile(`^/static/.+`)
  125. docsRegex = regexp.MustCompile(`^/docs(|/.*)$`)
  126. fileRegex = regexp.MustCompile(`^/file/([-_A-Za-z0-9]{1,64})(?:\.[A-Za-z0-9]{1,16})?$`)
  127. disallowedTopics = []string{"docs", "static", "file", "app", "account", "settings", "pricing", "signup", "login", "reset-password"} // If updated, also update in Android and web app
  128. urlRegex = regexp.MustCompile(`^https?://`)
  129. //go:embed site
  130. webFs embed.FS
  131. webFsCached = &util.CachingEmbedFS{ModTime: time.Now(), FS: webFs}
  132. webSiteDir = "/site"
  133. webHomeIndex = "/home.html" // Landing page, only if "web-root: home"
  134. webAppIndex = "/app.html" // React app
  135. //go:embed docs
  136. docsStaticFs embed.FS
  137. docsStaticCached = &util.CachingEmbedFS{ModTime: time.Now(), FS: docsStaticFs}
  138. )
  139. const (
  140. firebaseControlTopic = "~control" // See Android if changed
  141. firebasePollTopic = "~poll" // See iOS if changed
  142. emptyMessageBody = "triggered" // Used if message body is empty
  143. newMessageBody = "New message" // Used in poll requests as generic message
  144. defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment
  145. encodingBase64 = "base64" // Used mainly for binary UnifiedPush messages
  146. jsonBodyBytesLimit = 16384
  147. )
  148. // WebSocket constants
  149. const (
  150. wsWriteWait = 2 * time.Second
  151. wsBufferSize = 1024
  152. wsReadLimit = 64 // We only ever receive PINGs
  153. wsPongWait = 15 * time.Second
  154. )
  155. // New instantiates a new Server. It creates the cache and adds a Firebase
  156. // subscriber (if configured).
  157. func New(conf *Config) (*Server, error) {
  158. var mailer mailer
  159. if conf.SMTPSenderAddr != "" {
  160. mailer = &smtpSender{config: conf}
  161. }
  162. var stripe stripeAPI
  163. if conf.StripeSecretKey != "" {
  164. stripe = newStripeAPI()
  165. }
  166. messageCache, err := createMessageCache(conf)
  167. if err != nil {
  168. return nil, err
  169. }
  170. topics, err := messageCache.Topics()
  171. if err != nil {
  172. return nil, err
  173. }
  174. var fileCache *fileCache
  175. if conf.AttachmentCacheDir != "" {
  176. fileCache, err = newFileCache(conf.AttachmentCacheDir, conf.AttachmentTotalSizeLimit)
  177. if err != nil {
  178. return nil, err
  179. }
  180. }
  181. var userManager *user.Manager
  182. if conf.AuthFile != "" {
  183. userManager, err = user.NewManager(conf.AuthFile, conf.AuthStartupQueries, conf.AuthDefault)
  184. if err != nil {
  185. return nil, err
  186. }
  187. }
  188. var firebaseClient *firebaseClient
  189. if conf.FirebaseKeyFile != "" {
  190. sender, err := newFirebaseSender(conf.FirebaseKeyFile)
  191. if err != nil {
  192. return nil, err
  193. }
  194. firebaseClient = newFirebaseClient(sender, userManager)
  195. }
  196. s := &Server{
  197. config: conf,
  198. messageCache: messageCache,
  199. fileCache: fileCache,
  200. firebaseClient: firebaseClient,
  201. smtpSender: mailer,
  202. topics: topics,
  203. userManager: userManager,
  204. visitors: make(map[string]*visitor),
  205. stripe: stripe,
  206. }
  207. s.priceCache = util.NewLookupCache(s.fetchStripePrices, conf.StripePriceCacheDuration)
  208. return s, nil
  209. }
  210. func createMessageCache(conf *Config) (*messageCache, error) {
  211. if conf.CacheDuration == 0 {
  212. return newNopCache()
  213. } else if conf.CacheFile != "" {
  214. return newSqliteCache(conf.CacheFile, conf.CacheStartupQueries, conf.CacheDuration, conf.CacheBatchSize, conf.CacheBatchTimeout, false)
  215. }
  216. return newMemCache()
  217. }
  218. // Run executes the main server. It listens on HTTP (+ HTTPS, if configured), and starts
  219. // a manager go routine to print stats and prune messages.
  220. func (s *Server) Run() error {
  221. var listenStr string
  222. if s.config.ListenHTTP != "" {
  223. listenStr += fmt.Sprintf(" %s[http]", s.config.ListenHTTP)
  224. }
  225. if s.config.ListenHTTPS != "" {
  226. listenStr += fmt.Sprintf(" %s[https]", s.config.ListenHTTPS)
  227. }
  228. if s.config.ListenUnix != "" {
  229. listenStr += fmt.Sprintf(" %s[unix]", s.config.ListenUnix)
  230. }
  231. if s.config.SMTPServerListen != "" {
  232. listenStr += fmt.Sprintf(" %s[smtp]", s.config.SMTPServerListen)
  233. }
  234. log.Info("Listening on%s, ntfy %s, log level is %s", listenStr, s.config.Version, log.CurrentLevel().String())
  235. mux := http.NewServeMux()
  236. mux.HandleFunc("/", s.handle)
  237. errChan := make(chan error)
  238. s.mu.Lock()
  239. s.closeChan = make(chan bool)
  240. if s.config.ListenHTTP != "" {
  241. s.httpServer = &http.Server{Addr: s.config.ListenHTTP, Handler: mux}
  242. go func() {
  243. errChan <- s.httpServer.ListenAndServe()
  244. }()
  245. }
  246. if s.config.ListenHTTPS != "" {
  247. s.httpsServer = &http.Server{Addr: s.config.ListenHTTPS, Handler: mux}
  248. go func() {
  249. errChan <- s.httpsServer.ListenAndServeTLS(s.config.CertFile, s.config.KeyFile)
  250. }()
  251. }
  252. if s.config.ListenUnix != "" {
  253. go func() {
  254. var err error
  255. s.mu.Lock()
  256. os.Remove(s.config.ListenUnix)
  257. s.unixListener, err = net.Listen("unix", s.config.ListenUnix)
  258. if err != nil {
  259. s.mu.Unlock()
  260. errChan <- err
  261. return
  262. }
  263. defer s.unixListener.Close()
  264. if s.config.ListenUnixMode > 0 {
  265. if err := os.Chmod(s.config.ListenUnix, s.config.ListenUnixMode); err != nil {
  266. s.mu.Unlock()
  267. errChan <- err
  268. return
  269. }
  270. }
  271. s.mu.Unlock()
  272. httpServer := &http.Server{Handler: mux}
  273. errChan <- httpServer.Serve(s.unixListener)
  274. }()
  275. }
  276. if s.config.SMTPServerListen != "" {
  277. go func() {
  278. errChan <- s.runSMTPServer()
  279. }()
  280. }
  281. s.mu.Unlock()
  282. go s.runManager()
  283. go s.runStatsResetter()
  284. go s.runDelayedSender()
  285. go s.runFirebaseKeepaliver()
  286. return <-errChan
  287. }
  288. // Stop stops HTTP (+HTTPS) server and all managers
  289. func (s *Server) Stop() {
  290. s.mu.Lock()
  291. defer s.mu.Unlock()
  292. if s.httpServer != nil {
  293. s.httpServer.Close()
  294. }
  295. if s.httpsServer != nil {
  296. s.httpsServer.Close()
  297. }
  298. if s.unixListener != nil {
  299. s.unixListener.Close()
  300. }
  301. if s.smtpServer != nil {
  302. s.smtpServer.Close()
  303. }
  304. close(s.closeChan)
  305. }
  306. func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
  307. v, err := s.visitor(r) // Note: Always returns v, even when error is returned
  308. if err == nil {
  309. log.Debug("%s Dispatching request", logHTTPPrefix(v, r))
  310. if log.IsTrace() {
  311. log.Trace("%s Entire request (headers and body):\n%s", logHTTPPrefix(v, r), renderHTTPRequest(r))
  312. }
  313. err = s.handleInternal(w, r, v)
  314. }
  315. if err != nil {
  316. if websocket.IsWebSocketUpgrade(r) {
  317. isNormalError := strings.Contains(err.Error(), "i/o timeout")
  318. if isNormalError {
  319. log.Debug("%s WebSocket error (this error is okay, it happens a lot): %s", logHTTPPrefix(v, r), err.Error())
  320. } else {
  321. log.Info("%s WebSocket error: %s", logHTTPPrefix(v, r), err.Error())
  322. }
  323. return // Do not attempt to write to upgraded connection
  324. }
  325. if matrixErr, ok := err.(*errMatrix); ok {
  326. writeMatrixError(w, r, v, matrixErr)
  327. return
  328. }
  329. httpErr, ok := err.(*errHTTP)
  330. if !ok {
  331. httpErr = errHTTPInternalError
  332. }
  333. isNormalError := httpErr.HTTPCode == http.StatusNotFound || httpErr.HTTPCode == http.StatusBadRequest
  334. if isNormalError {
  335. log.Debug("%s Connection closed with HTTP %d (ntfy error %d): %s", logHTTPPrefix(v, r), httpErr.HTTPCode, httpErr.Code, err.Error())
  336. } else {
  337. log.Info("%s Connection closed with HTTP %d (ntfy error %d): %s", logHTTPPrefix(v, r), httpErr.HTTPCode, httpErr.Code, err.Error())
  338. }
  339. w.Header().Set("Content-Type", "application/json")
  340. w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
  341. w.WriteHeader(httpErr.HTTPCode)
  342. io.WriteString(w, httpErr.JSON()+"\n")
  343. }
  344. }
  345. func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visitor) error {
  346. if r.Method == http.MethodGet && r.URL.Path == "/" {
  347. return s.ensureWebEnabled(s.handleHome)(w, r, v)
  348. } else if r.Method == http.MethodHead && r.URL.Path == "/" {
  349. return s.ensureWebEnabled(s.handleEmpty)(w, r, v)
  350. } else if r.Method == http.MethodGet && r.URL.Path == apiHealthPath {
  351. return s.handleHealth(w, r, v)
  352. } else if r.Method == http.MethodGet && r.URL.Path == webConfigPath {
  353. return s.ensureWebEnabled(s.handleWebConfig)(w, r, v)
  354. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountPath {
  355. return s.ensureUserManager(s.handleAccountCreate)(w, r, v)
  356. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountTokenPath {
  357. return s.ensureUser(s.handleAccountTokenIssue)(w, r, v)
  358. } else if r.Method == http.MethodGet && r.URL.Path == apiAccountPath {
  359. return s.handleAccountGet(w, r, v) // Allowed by anonymous
  360. } else if r.Method == http.MethodDelete && r.URL.Path == apiAccountPath {
  361. return s.ensureUser(s.withAccountSync(s.handleAccountDelete))(w, r, v)
  362. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountPasswordPath {
  363. return s.ensureUser(s.handleAccountPasswordChange)(w, r, v)
  364. } else if r.Method == http.MethodPatch && r.URL.Path == apiAccountTokenPath {
  365. return s.ensureUser(s.handleAccountTokenExtend)(w, r, v)
  366. } else if r.Method == http.MethodDelete && r.URL.Path == apiAccountTokenPath {
  367. return s.ensureUser(s.handleAccountTokenDelete)(w, r, v)
  368. } else if r.Method == http.MethodPatch && r.URL.Path == apiAccountSettingsPath {
  369. return s.ensureUser(s.withAccountSync(s.handleAccountSettingsChange))(w, r, v)
  370. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountSubscriptionPath {
  371. return s.ensureUser(s.withAccountSync(s.handleAccountSubscriptionAdd))(w, r, v)
  372. } else if r.Method == http.MethodPatch && apiAccountSubscriptionSingleRegex.MatchString(r.URL.Path) {
  373. return s.ensureUser(s.withAccountSync(s.handleAccountSubscriptionChange))(w, r, v)
  374. } else if r.Method == http.MethodDelete && apiAccountSubscriptionSingleRegex.MatchString(r.URL.Path) {
  375. return s.ensureUser(s.withAccountSync(s.handleAccountSubscriptionDelete))(w, r, v)
  376. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountReservationPath {
  377. return s.ensureUser(s.withAccountSync(s.handleAccountReservationAdd))(w, r, v)
  378. } else if r.Method == http.MethodDelete && apiAccountReservationSingleRegex.MatchString(r.URL.Path) {
  379. return s.ensureUser(s.withAccountSync(s.handleAccountReservationDelete))(w, r, v)
  380. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountBillingSubscriptionPath {
  381. return s.ensurePaymentsEnabled(s.ensureUser(s.handleAccountBillingSubscriptionCreate))(w, r, v) // Account sync via incoming Stripe webhook
  382. } else if r.Method == http.MethodGet && apiAccountBillingSubscriptionCheckoutSuccessRegex.MatchString(r.URL.Path) {
  383. return s.ensurePaymentsEnabled(s.ensureUserManager(s.handleAccountBillingSubscriptionCreateSuccess))(w, r, v) // No user context!
  384. } else if r.Method == http.MethodPut && r.URL.Path == apiAccountBillingSubscriptionPath {
  385. return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingSubscriptionUpdate))(w, r, v) // Account sync via incoming Stripe webhook
  386. } else if r.Method == http.MethodDelete && r.URL.Path == apiAccountBillingSubscriptionPath {
  387. return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingSubscriptionDelete))(w, r, v) // Account sync via incoming Stripe webhook
  388. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountBillingPortalPath {
  389. return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingPortalSessionCreate))(w, r, v)
  390. } else if r.Method == http.MethodPost && r.URL.Path == apiAccountBillingWebhookPath {
  391. return s.ensurePaymentsEnabled(s.ensureUserManager(s.handleAccountBillingWebhook))(w, r, v) // This request comes from Stripe!
  392. } else if r.Method == http.MethodGet && r.URL.Path == apiTiers {
  393. return s.ensurePaymentsEnabled(s.handleBillingTiersGet)(w, r, v)
  394. } else if r.Method == http.MethodGet && r.URL.Path == matrixPushPath {
  395. return s.handleMatrixDiscovery(w)
  396. } else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) {
  397. return s.ensureWebEnabled(s.handleStatic)(w, r, v)
  398. } else if r.Method == http.MethodGet && docsRegex.MatchString(r.URL.Path) {
  399. return s.ensureWebEnabled(s.handleDocs)(w, r, v)
  400. } else if (r.Method == http.MethodGet || r.Method == http.MethodHead) && fileRegex.MatchString(r.URL.Path) && s.config.AttachmentCacheDir != "" {
  401. return s.limitRequests(s.handleFile)(w, r, v)
  402. } else if r.Method == http.MethodOptions {
  403. return s.ensureWebEnabled(s.handleOptions)(w, r, v)
  404. } else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && r.URL.Path == "/" {
  405. return s.limitRequests(s.transformBodyJSON(s.authorizeTopicWrite(s.handlePublish)))(w, r, v)
  406. } else if r.Method == http.MethodPost && r.URL.Path == matrixPushPath {
  407. return s.limitRequests(s.transformMatrixJSON(s.authorizeTopicWrite(s.handlePublishMatrix)))(w, r, v)
  408. } else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicPathRegex.MatchString(r.URL.Path) {
  409. return s.limitRequests(s.authorizeTopicWrite(s.handlePublish))(w, r, v)
  410. } else if r.Method == http.MethodGet && publishPathRegex.MatchString(r.URL.Path) {
  411. return s.limitRequests(s.authorizeTopicWrite(s.handlePublish))(w, r, v)
  412. } else if r.Method == http.MethodGet && jsonPathRegex.MatchString(r.URL.Path) {
  413. return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeJSON))(w, r, v)
  414. } else if r.Method == http.MethodGet && ssePathRegex.MatchString(r.URL.Path) {
  415. return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeSSE))(w, r, v)
  416. } else if r.Method == http.MethodGet && rawPathRegex.MatchString(r.URL.Path) {
  417. return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeRaw))(w, r, v)
  418. } else if r.Method == http.MethodGet && wsPathRegex.MatchString(r.URL.Path) {
  419. return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeWS))(w, r, v)
  420. } else if r.Method == http.MethodGet && authPathRegex.MatchString(r.URL.Path) {
  421. return s.limitRequests(s.authorizeTopicRead(s.handleTopicAuth))(w, r, v)
  422. } else if r.Method == http.MethodGet && (topicPathRegex.MatchString(r.URL.Path) || externalTopicPathRegex.MatchString(r.URL.Path)) {
  423. return s.ensureWebEnabled(s.handleTopic)(w, r, v)
  424. }
  425. return errHTTPNotFound
  426. }
  427. func (s *Server) handleHome(w http.ResponseWriter, r *http.Request, v *visitor) error {
  428. if s.config.WebRootIsApp {
  429. r.URL.Path = webAppIndex
  430. } else {
  431. r.URL.Path = webHomeIndex
  432. }
  433. return s.handleStatic(w, r, v)
  434. }
  435. func (s *Server) handleTopic(w http.ResponseWriter, r *http.Request, v *visitor) error {
  436. unifiedpush := readBoolParam(r, false, "x-unifiedpush", "unifiedpush", "up") // see PUT/POST too!
  437. if unifiedpush {
  438. w.Header().Set("Content-Type", "application/json")
  439. w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
  440. _, err := io.WriteString(w, `{"unifiedpush":{"version":1}}`+"\n")
  441. return err
  442. }
  443. r.URL.Path = webAppIndex
  444. return s.handleStatic(w, r, v)
  445. }
  446. func (s *Server) handleEmpty(_ http.ResponseWriter, _ *http.Request, _ *visitor) error {
  447. return nil
  448. }
  449. func (s *Server) handleTopicAuth(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
  450. return s.writeJSON(w, newSuccessResponse())
  451. }
  452. func (s *Server) handleHealth(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
  453. response := &apiHealthResponse{
  454. Healthy: true,
  455. }
  456. return s.writeJSON(w, response)
  457. }
  458. func (s *Server) handleWebConfig(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
  459. appRoot := "/"
  460. if !s.config.WebRootIsApp {
  461. appRoot = "/app"
  462. }
  463. response := &apiConfigResponse{
  464. BaseURL: "", // Will translate to window.location.origin
  465. AppRoot: appRoot,
  466. EnableLogin: s.config.EnableLogin,
  467. EnableSignup: s.config.EnableSignup,
  468. EnablePayments: s.config.StripeSecretKey != "",
  469. EnableReservations: s.config.EnableReservations,
  470. DisallowedTopics: disallowedTopics,
  471. }
  472. b, err := json.MarshalIndent(response, "", " ")
  473. if err != nil {
  474. return err
  475. }
  476. w.Header().Set("Content-Type", "text/javascript")
  477. _, err = io.WriteString(w, fmt.Sprintf("// Generated server configuration\nvar config = %s;\n", string(b)))
  478. return err
  479. }
  480. func (s *Server) handleStatic(w http.ResponseWriter, r *http.Request, _ *visitor) error {
  481. r.URL.Path = webSiteDir + r.URL.Path
  482. util.Gzip(http.FileServer(http.FS(webFsCached))).ServeHTTP(w, r)
  483. return nil
  484. }
  485. func (s *Server) handleDocs(w http.ResponseWriter, r *http.Request, _ *visitor) error {
  486. util.Gzip(http.FileServer(http.FS(docsStaticCached))).ServeHTTP(w, r)
  487. return nil
  488. }
  489. func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor) error {
  490. if s.config.AttachmentCacheDir == "" {
  491. return errHTTPInternalError
  492. }
  493. matches := fileRegex.FindStringSubmatch(r.URL.Path)
  494. if len(matches) != 2 {
  495. return errHTTPInternalErrorInvalidPath
  496. }
  497. messageID := matches[1]
  498. file := filepath.Join(s.config.AttachmentCacheDir, messageID)
  499. stat, err := os.Stat(file)
  500. if err != nil {
  501. return errHTTPNotFound
  502. }
  503. if r.Method == http.MethodGet {
  504. if err := v.BandwidthLimiter().Allow(stat.Size()); err != nil {
  505. return errHTTPTooManyRequestsLimitAttachmentBandwidth
  506. }
  507. }
  508. w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size()))
  509. w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
  510. if r.Method == http.MethodGet {
  511. f, err := os.Open(file)
  512. if err != nil {
  513. return err
  514. }
  515. defer f.Close()
  516. _, err = io.Copy(util.NewContentTypeWriter(w, r.URL.Path), f)
  517. return err
  518. }
  519. return nil
  520. }
  521. func (s *Server) handleMatrixDiscovery(w http.ResponseWriter) error {
  522. if s.config.BaseURL == "" {
  523. return errHTTPInternalErrorMissingBaseURL
  524. }
  525. return writeMatrixDiscoveryResponse(w)
  526. }
// handlePublishWithoutResponse publishes the message contained in request r on behalf of visitor v, without
// writing an HTTP response. It is shared by handlePublish and handlePublishMatrix, which write their own responses.
//
// In order, it:
//   - resolves the topic from the URL path;
//   - checks the visitor's message rate limit;
//   - peeks at up to config.MessageLimit bytes of the body;
//   - parses the publish parameters into the message;
//   - lets handlePublishBody decide whether the body is the message text or an attachment.
//
// It then either publishes the message right away, fanning it out to Firebase, email and the upstream server
// if enabled, or, if the message is scheduled for later, leaves it to be sent from the message cache.
// Cacheable messages are added to the message cache in both cases.
//
// It returns the message, or an error (typically one of the errHTTP... values) for the HTTP handler to return.
func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*message, error) {
	t, err := s.topicFromPath(r.URL.Path)
	if err != nil {
		return nil, err
	}
	if err := v.MessageAllowed(); err != nil {
		return nil, errHTTPTooManyRequestsLimitMessages
	}
	// Peek at the first MessageLimit bytes; handlePublishBody decides what to do with them (and the rest)
	body, err := util.Peek(r.Body, s.config.MessageLimit)
	if err != nil {
		return nil, err
	}
	m := newDefaultMessage(t.ID, "")
	cache, firebase, email, unifiedpush, err := s.parsePublishParams(r, v, m)
	if err != nil {
		return nil, err
	}
	if m.PollID != "" {
		// Poll requests are replaced by a new poll request message built from the topic and the poll ID
		m = newPollRequestMessage(t.ID, m.PollID)
	}
	if v.user != nil {
		m.User = v.user.Name
	}
	// The message expiry depends on the visitor's limits
	m.Expires = time.Now().Add(v.Limits().MessagesExpiryDuration).Unix()
	if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil {
		return nil, err
	}
	if m.Message == "" {
		m.Message = emptyMessageBody
	}
	// A message time in the future means the message was scheduled via the delay parameter (see parsePublishParams)
	delayed := m.Time > time.Now().Unix()
	log.Debug("%s Received message: event=%s, user=%s, body=%d byte(s), delayed=%t, firebase=%t, cache=%t, up=%t, email=%s",
		logMessagePrefix(v, m), m.Event, m.User, len(m.Message), delayed, firebase, cache, unifiedpush, email)
	if log.IsTrace() {
		log.Trace("%s Message body: %s", logMessagePrefix(v, m), util.MaybeMarshalJSON(m))
	}
	if !delayed {
		if err := t.Publish(v, m); err != nil {
			return nil, err
		}
		// The fan-out below runs asynchronously; failures there are only logged by the respective functions
		if s.firebaseClient != nil && firebase {
			go s.sendToFirebase(v, m)
		}
		if s.smtpSender != nil && email != "" {
			v.IncrementEmails()
			go s.sendEmail(v, m, email)
		}
		if s.config.UpstreamBaseURL != "" {
			go s.forwardPollRequest(v, m)
		}
	} else {
		log.Debug("%s Message delayed, will process later", logMessagePrefix(v, m))
	}
	if cache {
		// Delayed messages are sent later from the cache (see sendDelayedMessages); parsePublishParams rejects
		// a delay without caching
		log.Debug("%s Adding message to cache", logMessagePrefix(v, m))
		if err := s.messageCache.AddMessage(m); err != nil {
			return nil, err
		}
	}
	v.IncrementMessages() // Counted for delayed messages as well
	if s.userManager != nil && v.user != nil {
		s.userManager.EnqueueStats(v.user)
	}
	s.mu.Lock()
	s.messages++ // Server-wide counter, logged by execManager as "messages published"
	s.mu.Unlock()
	return m, nil
}
  595. func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visitor) error {
  596. m, err := s.handlePublishWithoutResponse(r, v)
  597. if err != nil {
  598. return err
  599. }
  600. return s.writeJSON(w, m)
  601. }
  602. func (s *Server) handlePublishMatrix(w http.ResponseWriter, r *http.Request, v *visitor) error {
  603. _, err := s.handlePublishWithoutResponse(r, v)
  604. if err != nil {
  605. return &errMatrix{pushKey: r.Header.Get(matrixPushKeyHeader), err: err}
  606. }
  607. return writeMatrixSuccess(w)
  608. }
  609. func (s *Server) sendToFirebase(v *visitor, m *message) {
  610. log.Debug("%s Publishing to Firebase", logMessagePrefix(v, m))
  611. if err := s.firebaseClient.Send(v, m); err != nil {
  612. if err == errFirebaseTemporarilyBanned {
  613. log.Debug("%s Unable to publish to Firebase: %v", logMessagePrefix(v, m), err.Error())
  614. } else {
  615. log.Warn("%s Unable to publish to Firebase: %v", logMessagePrefix(v, m), err.Error())
  616. }
  617. }
  618. }
  619. func (s *Server) sendEmail(v *visitor, m *message, email string) {
  620. log.Debug("%s Sending email to %s", logMessagePrefix(v, m), email)
  621. if err := s.smtpSender.Send(v, m, email); err != nil {
  622. log.Warn("%s Unable to send email to %s: %v", logMessagePrefix(v, m), email, err.Error())
  623. }
  624. }
  625. func (s *Server) forwardPollRequest(v *visitor, m *message) {
  626. topicURL := fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic)
  627. topicHash := fmt.Sprintf("%x", sha256.Sum256([]byte(topicURL)))
  628. forwardURL := fmt.Sprintf("%s/%s", s.config.UpstreamBaseURL, topicHash)
  629. log.Debug("%s Publishing poll request to %s", logMessagePrefix(v, m), forwardURL)
  630. req, err := http.NewRequest("POST", forwardURL, strings.NewReader(""))
  631. if err != nil {
  632. log.Warn("%s Unable to publish poll request: %v", logMessagePrefix(v, m), err.Error())
  633. return
  634. }
  635. req.Header.Set("X-Poll-ID", m.ID)
  636. var httpClient = &http.Client{
  637. Timeout: time.Second * 10,
  638. }
  639. response, err := httpClient.Do(req)
  640. if err != nil {
  641. log.Warn("%s Unable to publish poll request: %v", logMessagePrefix(v, m), err.Error())
  642. return
  643. } else if response.StatusCode != http.StatusOK {
  644. log.Warn("%s Unable to publish poll request, unexpected HTTP status: %d", logMessagePrefix(v, m), response.StatusCode)
  645. return
  646. }
  647. }
  648. func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (cache bool, firebase bool, email string, unifiedpush bool, err error) {
  649. cache = readBoolParam(r, true, "x-cache", "cache")
  650. firebase = readBoolParam(r, true, "x-firebase", "firebase")
  651. m.Title = readParam(r, "x-title", "title", "t")
  652. m.Click = readParam(r, "x-click", "click")
  653. icon := readParam(r, "x-icon", "icon")
  654. filename := readParam(r, "x-filename", "filename", "file", "f")
  655. attach := readParam(r, "x-attach", "attach", "a")
  656. if attach != "" || filename != "" {
  657. m.Attachment = &attachment{}
  658. }
  659. if filename != "" {
  660. m.Attachment.Name = filename
  661. }
  662. if attach != "" {
  663. if !urlRegex.MatchString(attach) {
  664. return false, false, "", false, errHTTPBadRequestAttachmentURLInvalid
  665. }
  666. m.Attachment.URL = attach
  667. if m.Attachment.Name == "" {
  668. u, err := url.Parse(m.Attachment.URL)
  669. if err == nil {
  670. m.Attachment.Name = path.Base(u.Path)
  671. if m.Attachment.Name == "." || m.Attachment.Name == "/" {
  672. m.Attachment.Name = ""
  673. }
  674. }
  675. }
  676. if m.Attachment.Name == "" {
  677. m.Attachment.Name = "attachment"
  678. }
  679. }
  680. if icon != "" {
  681. if !urlRegex.MatchString(icon) {
  682. return false, false, "", false, errHTTPBadRequestIconURLInvalid
  683. }
  684. m.Icon = icon
  685. }
  686. email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e")
  687. if email != "" {
  688. if err := v.EmailAllowed(); err != nil {
  689. return false, false, "", false, errHTTPTooManyRequestsLimitEmails
  690. }
  691. }
  692. if s.smtpSender == nil && email != "" {
  693. return false, false, "", false, errHTTPBadRequestEmailDisabled
  694. }
  695. messageStr := strings.ReplaceAll(readParam(r, "x-message", "message", "m"), "\\n", "\n")
  696. if messageStr != "" {
  697. m.Message = messageStr
  698. }
  699. m.Priority, err = util.ParsePriority(readParam(r, "x-priority", "priority", "prio", "p"))
  700. if err != nil {
  701. return false, false, "", false, errHTTPBadRequestPriorityInvalid
  702. }
  703. tagsStr := readParam(r, "x-tags", "tags", "tag", "ta")
  704. if tagsStr != "" {
  705. m.Tags = make([]string, 0)
  706. for _, s := range util.SplitNoEmpty(tagsStr, ",") {
  707. m.Tags = append(m.Tags, strings.TrimSpace(s))
  708. }
  709. }
  710. delayStr := readParam(r, "x-delay", "delay", "x-at", "at", "x-in", "in")
  711. if delayStr != "" {
  712. if !cache {
  713. return false, false, "", false, errHTTPBadRequestDelayNoCache
  714. }
  715. if email != "" {
  716. return false, false, "", false, errHTTPBadRequestDelayNoEmail // we cannot store the email address (yet)
  717. }
  718. delay, err := util.ParseFutureTime(delayStr, time.Now())
  719. if err != nil {
  720. return false, false, "", false, errHTTPBadRequestDelayCannotParse
  721. } else if delay.Unix() < time.Now().Add(s.config.MinDelay).Unix() {
  722. return false, false, "", false, errHTTPBadRequestDelayTooSmall
  723. } else if delay.Unix() > time.Now().Add(s.config.MaxDelay).Unix() {
  724. return false, false, "", false, errHTTPBadRequestDelayTooLarge
  725. }
  726. m.Time = delay.Unix()
  727. m.Sender = v.ip // Important for rate limiting
  728. }
  729. actionsStr := readParam(r, "x-actions", "actions", "action")
  730. if actionsStr != "" {
  731. m.Actions, err = parseActions(actionsStr)
  732. if err != nil {
  733. return false, false, "", false, wrapErrHTTP(errHTTPBadRequestActionsInvalid, err.Error())
  734. }
  735. }
  736. unifiedpush = readBoolParam(r, false, "x-unifiedpush", "unifiedpush", "up") // see GET too!
  737. if unifiedpush {
  738. firebase = false
  739. unifiedpush = true
  740. }
  741. m.PollID = readParam(r, "x-poll-id", "poll-id")
  742. if m.PollID != "" {
  743. unifiedpush = false
  744. cache = false
  745. email = ""
  746. }
  747. return cache, firebase, email, unifiedpush, nil
  748. }
  749. // handlePublishBody consumes the PUT/POST body and decides whether the body is an attachment or the message.
  750. //
  751. // 1. curl -X POST -H "Poll: 1234" ntfy.sh/...
  752. // If a message is flagged as poll request, the body does not matter and is discarded
  753. // 2. curl -T somebinarydata.bin "ntfy.sh/mytopic?up=1"
  754. // If body is binary, encode as base64, if not do not encode
  755. // 3. curl -H "Attach: http://example.com/file.jpg" ntfy.sh/mytopic
  756. // Body must be a message, because we attached an external URL
  757. // 4. curl -T short.txt -H "Filename: short.txt" ntfy.sh/mytopic
  758. // Body must be attachment, because we passed a filename
  759. // 5. curl -T file.txt ntfy.sh/mytopic
  760. // If file.txt is <= 4096 (message limit) and valid UTF-8, treat it as a message
  761. // 6. curl -T file.txt ntfy.sh/mytopic
  762. // If file.txt is > message limit, treat it as an attachment
  763. func (s *Server) handlePublishBody(r *http.Request, v *visitor, m *message, body *util.PeekedReadCloser, unifiedpush bool) error {
  764. if m.Event == pollRequestEvent { // Case 1
  765. return s.handleBodyDiscard(body)
  766. } else if unifiedpush {
  767. return s.handleBodyAsMessageAutoDetect(m, body) // Case 2
  768. } else if m.Attachment != nil && m.Attachment.URL != "" {
  769. return s.handleBodyAsTextMessage(m, body) // Case 3
  770. } else if m.Attachment != nil && m.Attachment.Name != "" {
  771. return s.handleBodyAsAttachment(r, v, m, body) // Case 4
  772. } else if !body.LimitReached && utf8.Valid(body.PeekedBytes) {
  773. return s.handleBodyAsTextMessage(m, body) // Case 5
  774. }
  775. return s.handleBodyAsAttachment(r, v, m, body) // Case 6
  776. }
  777. func (s *Server) handleBodyDiscard(body *util.PeekedReadCloser) error {
  778. _, err := io.Copy(io.Discard, body)
  779. _ = body.Close()
  780. return err
  781. }
  782. func (s *Server) handleBodyAsMessageAutoDetect(m *message, body *util.PeekedReadCloser) error {
  783. if utf8.Valid(body.PeekedBytes) {
  784. m.Message = string(body.PeekedBytes) // Do not trim
  785. } else {
  786. m.Message = base64.StdEncoding.EncodeToString(body.PeekedBytes)
  787. m.Encoding = encodingBase64
  788. }
  789. return nil
  790. }
  791. func (s *Server) handleBodyAsTextMessage(m *message, body *util.PeekedReadCloser) error {
  792. if !utf8.Valid(body.PeekedBytes) {
  793. return errHTTPBadRequestMessageNotUTF8
  794. }
  795. if len(body.PeekedBytes) > 0 { // Empty body should not override message (publish via GET!)
  796. m.Message = strings.TrimSpace(string(body.PeekedBytes)) // Truncates the message to the peek limit if required
  797. }
  798. if m.Attachment != nil && m.Attachment.Name != "" && m.Message == "" {
  799. m.Message = fmt.Sprintf(defaultAttachmentMessage, m.Attachment.Name)
  800. }
  801. return nil
  802. }
  803. func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message, body *util.PeekedReadCloser) error {
  804. if s.fileCache == nil || s.config.BaseURL == "" || s.config.AttachmentCacheDir == "" {
  805. return errHTTPBadRequestAttachmentsDisallowed
  806. }
  807. vinfo, err := v.Info()
  808. if err != nil {
  809. return err
  810. }
  811. attachmentExpiry := time.Now().Add(vinfo.Limits.AttachmentExpiryDuration).Unix()
  812. if m.Time > attachmentExpiry {
  813. return errHTTPBadRequestAttachmentsExpiryBeforeDelivery
  814. }
  815. contentLengthStr := r.Header.Get("Content-Length")
  816. if contentLengthStr != "" { // Early "do-not-trust" check, hard limit see below
  817. contentLength, err := strconv.ParseInt(contentLengthStr, 10, 64)
  818. if err == nil && (contentLength > vinfo.Stats.AttachmentTotalSizeRemaining || contentLength > vinfo.Limits.AttachmentFileSizeLimit) {
  819. return errHTTPEntityTooLargeAttachment
  820. }
  821. }
  822. if m.Attachment == nil {
  823. m.Attachment = &attachment{}
  824. }
  825. var ext string
  826. m.Sender = v.ip // Important for attachment rate limiting
  827. m.Attachment.Expires = attachmentExpiry
  828. m.Attachment.Type, ext = util.DetectContentType(body.PeekedBytes, m.Attachment.Name)
  829. m.Attachment.URL = fmt.Sprintf("%s/file/%s%s", s.config.BaseURL, m.ID, ext)
  830. if m.Attachment.Name == "" {
  831. m.Attachment.Name = fmt.Sprintf("attachment%s", ext)
  832. }
  833. if m.Message == "" {
  834. m.Message = fmt.Sprintf(defaultAttachmentMessage, m.Attachment.Name)
  835. }
  836. limiters := []util.Limiter{
  837. v.BandwidthLimiter(),
  838. util.NewFixedLimiter(vinfo.Limits.AttachmentFileSizeLimit),
  839. util.NewFixedLimiter(vinfo.Stats.AttachmentTotalSizeRemaining),
  840. }
  841. m.Attachment.Size, err = s.fileCache.Write(m.ID, body, limiters...)
  842. if err == util.ErrLimitReached {
  843. return errHTTPEntityTooLargeAttachment
  844. } else if err != nil {
  845. return err
  846. }
  847. return nil
  848. }
  849. func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request, v *visitor) error {
  850. encoder := func(msg *message) (string, error) {
  851. var buf bytes.Buffer
  852. if err := json.NewEncoder(&buf).Encode(&msg); err != nil {
  853. return "", err
  854. }
  855. return buf.String(), nil
  856. }
  857. return s.handleSubscribeHTTP(w, r, v, "application/x-ndjson", encoder)
  858. }
  859. func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request, v *visitor) error {
  860. encoder := func(msg *message) (string, error) {
  861. var buf bytes.Buffer
  862. if err := json.NewEncoder(&buf).Encode(&msg); err != nil {
  863. return "", err
  864. }
  865. if msg.Event != messageEvent {
  866. return fmt.Sprintf("event: %s\ndata: %s\n", msg.Event, buf.String()), nil // Browser's .onmessage() does not fire on this!
  867. }
  868. return fmt.Sprintf("data: %s\n", buf.String()), nil
  869. }
  870. return s.handleSubscribeHTTP(w, r, v, "text/event-stream", encoder)
  871. }
  872. func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request, v *visitor) error {
  873. encoder := func(msg *message) (string, error) {
  874. if msg.Event == messageEvent { // only handle default events
  875. return strings.ReplaceAll(msg.Message, "\n", " ") + "\n", nil
  876. }
  877. return "\n", nil // "keepalive" and "open" events just send an empty line
  878. }
  879. return s.handleSubscribeHTTP(w, r, v, "text/plain", encoder)
  880. }
// handleSubscribeHTTP implements the streaming HTTP subscribe endpoints (JSON, SSE and raw; see the callers).
// Every message that passes the query filters is converted with encoder, then written and flushed to w.
//
// It reserves a subscription for the visitor (errHTTPTooManyRequestsLimitSubscriptions if not allowed) and
// parses the topics and subscribe parameters. It then does one of two things:
//   - For a poll request, it answers with the cached messages selected by "since" and returns.
//   - Otherwise, it subscribes to all topics and sends an "open" message, followed by the cached messages
//     selected by "since". It then keeps the stream open, sending a keepalive message every KeepaliveInterval,
//     until the request context is done (e.g. the client disconnected).
func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *visitor, contentType string, encoder messageEncoder) error {
	log.Debug("%s HTTP stream connection opened", logHTTPPrefix(v, r))
	defer log.Debug("%s HTTP stream connection closed", logHTTPPrefix(v, r))
	if err := v.SubscriptionAllowed(); err != nil {
		return errHTTPTooManyRequestsLimitSubscriptions
	}
	defer v.RemoveSubscription() // Presumably releases the subscription reserved by SubscriptionAllowed
	topics, topicsStr, err := s.topicsFromPath(r.URL.Path)
	if err != nil {
		return err
	}
	poll, since, scheduled, filters, err := parseSubscribeParams(r)
	if err != nil {
		return err
	}
	// wlock serializes writes to w: once subscribed, sub may also be called from other goroutines
	// (e.g. sendDelayedMessage publishes from its own goroutine), not just from this one
	var wlock sync.Mutex
	defer func() {
		// Hack: This is the fix for a horrible data race that I have not been able to figure out in quite some time.
		// It appears to be happening when the Go HTTP code reads from the socket when closing the request (i.e. AFTER
		// this function returns), and causes a data race with the ResponseWriter. Locking wlock here silences the
		// data race detector. See https://github.com/binwiederhier/ntfy/issues/338#issuecomment-1163425889.
		wlock.TryLock()
	}()
	// sub filters, encodes and writes a single message; it is registered with every topic below
	sub := func(v *visitor, msg *message) error {
		if !filters.Pass(msg) {
			return nil
		}
		m, err := encoder(msg)
		if err != nil {
			return err
		}
		wlock.Lock()
		defer wlock.Unlock()
		if _, err := w.Write([]byte(m)); err != nil {
			return err
		}
		if fl, ok := w.(http.Flusher); ok {
			fl.Flush() // Push each message to the client immediately instead of buffering it
		}
		return nil
	}
	w.Header().Set("Access-Control-Allow-Origin", "*")            // CORS, allow cross-origin requests
	w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset!
	if poll {
		return s.sendOldMessages(topics, since, scheduled, v, sub)
	}
	subscriberIDs := make([]int, 0)
	for _, t := range topics {
		subscriberIDs = append(subscriberIDs, t.Subscribe(sub))
	}
	defer func() {
		for i, subscriberID := range subscriberIDs {
			topics[i].Unsubscribe(subscriberID) // Order! subscriberIDs[i] was returned by topics[i].Subscribe
		}
	}()
	if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message
		return err
	}
	if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
		return err
	}
	// Messages are delivered via sub, not through this loop, so regular traffic does not reset the keepalive timer
	for {
		select {
		case <-r.Context().Done():
			return nil
		case <-time.After(s.config.KeepaliveInterval):
			log.Trace("%s Sending keepalive message", logHTTPPrefix(v, r))
			v.Keepalive()
			if err := sub(v, newKeepaliveMessage(topicsStr)); err != nil { // Send keepalive message
				return err
			}
		}
	}
}
// handleSubscribeWS implements the WebSocket subscribe endpoint. Each message that passes the query filters is
// sent to the client as a JSON text frame.
//
// After the checks and parameter parsing (same as handleSubscribeHTTP), the connection is upgraded. Two
// goroutines then run in an errgroup:
//   - a reader that keeps the read deadline alive via pongs and returns once reading fails (e.g. the client
//     closed the connection);
//   - a pinger that sends a ping every KeepaliveInterval until the errgroup context is canceled.
//
// A poll request is answered with the cached messages only. Otherwise the handler subscribes to the topics,
// sends an "open" message and the cached messages selected by "since", and waits for the errgroup. Normal
// closures (including 1006) are not reported as errors.
func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *visitor) error {
	if strings.ToLower(r.Header.Get("Upgrade")) != "websocket" {
		return errHTTPBadRequestWebSocketsUpgradeHeaderMissing
	}
	if err := v.SubscriptionAllowed(); err != nil {
		return errHTTPTooManyRequestsLimitSubscriptions
	}
	defer v.RemoveSubscription() // Presumably releases the subscription reserved by SubscriptionAllowed
	log.Debug("%s WebSocket connection opened", logHTTPPrefix(v, r))
	defer log.Debug("%s WebSocket connection closed", logHTTPPrefix(v, r))
	topics, topicsStr, err := s.topicsFromPath(r.URL.Path)
	if err != nil {
		return err
	}
	poll, since, scheduled, filters, err := parseSubscribeParams(r)
	if err != nil {
		return err
	}
	upgrader := &websocket.Upgrader{
		ReadBufferSize:  wsBufferSize,
		WriteBufferSize: wsBufferSize,
		CheckOrigin: func(r *http.Request) bool {
			return true // We're open for business!
		},
	}
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		return err
	}
	// Closing the connection on return also makes the reader goroutine's NextReader fail, which cancels ctx and
	// thereby stops the pinger, even on the early-return paths that never call g.Wait()
	defer conn.Close()
	var wlock sync.Mutex // Serializes all writes to conn (pings from the pinger, messages via sub)
	// The group context derives from context.Background(), not r.Context(); the goroutines end via connection errors
	g, ctx := errgroup.WithContext(context.Background())
	g.Go(func() error {
		// Reader: the client is not expected to send data, but reading is what processes incoming control
		// frames (pong, close) in gorilla/websocket. Each pong extends the read deadline.
		pongWait := s.config.KeepaliveInterval + wsPongWait
		conn.SetReadLimit(wsReadLimit)
		if err := conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil {
			return err
		}
		conn.SetPongHandler(func(appData string) error {
			log.Trace("%s Received WebSocket pong", logHTTPPrefix(v, r))
			return conn.SetReadDeadline(time.Now().Add(pongWait))
		})
		for {
			_, _, err := conn.NextReader()
			if err != nil {
				return err
			}
		}
	})
	g.Go(func() error {
		// Pinger: sends a ping every KeepaliveInterval until the group context is canceled
		ping := func() error {
			wlock.Lock()
			defer wlock.Unlock()
			if err := conn.SetWriteDeadline(time.Now().Add(wsWriteWait)); err != nil {
				return err
			}
			log.Trace("%s Sending WebSocket ping", logHTTPPrefix(v, r))
			return conn.WriteMessage(websocket.PingMessage, nil)
		}
		for {
			select {
			case <-ctx.Done():
				return nil
			case <-time.After(s.config.KeepaliveInterval):
				v.Keepalive()
				if err := ping(); err != nil {
					return err
				}
			}
		}
	})
	// sub filters a message and writes it as JSON; it is registered with every topic below
	sub := func(v *visitor, msg *message) error {
		if !filters.Pass(msg) {
			return nil
		}
		wlock.Lock()
		defer wlock.Unlock()
		if err := conn.SetWriteDeadline(time.Now().Add(wsWriteWait)); err != nil {
			return err
		}
		return conn.WriteJSON(msg)
	}
	// NOTE(review): this runs after upgrader.Upgrade has already written the handshake response, so the header
	// presumably has no effect here (CheckOrigin above already accepts every origin) — confirm and move/remove
	w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
	if poll {
		return s.sendOldMessages(topics, since, scheduled, v, sub)
	}
	subscriberIDs := make([]int, 0)
	for _, t := range topics {
		subscriberIDs = append(subscriberIDs, t.Subscribe(sub))
	}
	defer func() {
		for i, subscriberID := range subscriberIDs {
			topics[i].Unsubscribe(subscriberID) // Order! subscriberIDs[i] was returned by topics[i].Subscribe
		}
	}()
	if err := sub(v, newOpenMessage(topicsStr)); err != nil { // Send out open message
		return err
	}
	if err := s.sendOldMessages(topics, since, scheduled, v, sub); err != nil {
		return err
	}
	// Blocks until the reader fails (connection closed/broken) or a ping fails
	err = g.Wait()
	if err != nil && websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
		log.Trace("%s WebSocket connection closed: %s", logHTTPPrefix(v, r), err.Error())
		return nil // Normal closures are not errors; note: "1006 (abnormal closure)" is treated as normal, because people disconnect a lot
	}
	return err
}
  1063. func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, scheduled bool, filters *queryFilter, err error) {
  1064. poll = readBoolParam(r, false, "x-poll", "poll", "po")
  1065. scheduled = readBoolParam(r, false, "x-scheduled", "scheduled", "sched")
  1066. since, err = parseSince(r, poll)
  1067. if err != nil {
  1068. return
  1069. }
  1070. filters, err = parseQueryFilters(r)
  1071. if err != nil {
  1072. return
  1073. }
  1074. return
  1075. }
  1076. // sendOldMessages selects old messages from the messageCache and calls sub for each of them. It uses since as the
  1077. // marker, returning only messages that are newer than the marker.
  1078. func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, v *visitor, sub subscriber) error {
  1079. if since.IsNone() {
  1080. return nil
  1081. }
  1082. messages := make([]*message, 0)
  1083. for _, t := range topics {
  1084. topicMessages, err := s.messageCache.Messages(t.ID, since, scheduled)
  1085. if err != nil {
  1086. return err
  1087. }
  1088. messages = append(messages, topicMessages...)
  1089. }
  1090. sort.Slice(messages, func(i, j int) bool {
  1091. return messages[i].Time < messages[j].Time
  1092. })
  1093. for _, m := range messages {
  1094. if err := sub(v, m); err != nil {
  1095. return err
  1096. }
  1097. }
  1098. return nil
  1099. }
  1100. // parseSince returns a timestamp identifying the time span from which cached messages should be received.
  1101. //
  1102. // Values in the "since=..." parameter can be either a unix timestamp or a duration (e.g. 12h), or
  1103. // "all" for all messages.
  1104. func parseSince(r *http.Request, poll bool) (sinceMarker, error) {
  1105. since := readParam(r, "x-since", "since", "si")
  1106. // Easy cases (empty, all, none)
  1107. if since == "" {
  1108. if poll {
  1109. return sinceAllMessages, nil
  1110. }
  1111. return sinceNoMessages, nil
  1112. } else if since == "all" {
  1113. return sinceAllMessages, nil
  1114. } else if since == "none" {
  1115. return sinceNoMessages, nil
  1116. }
  1117. // ID, timestamp, duration
  1118. if validMessageID(since) {
  1119. return newSinceID(since), nil
  1120. } else if s, err := strconv.ParseInt(since, 10, 64); err == nil {
  1121. return newSinceTime(s), nil
  1122. } else if d, err := time.ParseDuration(since); err == nil {
  1123. return newSinceTime(time.Now().Add(-1 * d).Unix()), nil
  1124. }
  1125. return sinceNoMessages, errHTTPBadRequestSinceInvalid
  1126. }
  1127. func (s *Server) handleOptions(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
  1128. w.Header().Set("Access-Control-Allow-Methods", "GET, PUT, POST, PATCH, DELETE")
  1129. w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
  1130. w.Header().Set("Access-Control-Allow-Headers", "*") // CORS, allow auth via JS // FIXME is this terrible?
  1131. return nil
  1132. }
  1133. func (s *Server) topicFromPath(path string) (*topic, error) {
  1134. parts := strings.Split(path, "/")
  1135. if len(parts) < 2 {
  1136. return nil, errHTTPBadRequestTopicInvalid
  1137. }
  1138. topics, err := s.topicsFromIDs(parts[1])
  1139. if err != nil {
  1140. return nil, err
  1141. }
  1142. return topics[0], nil
  1143. }
  1144. func (s *Server) topicsFromPath(path string) ([]*topic, string, error) {
  1145. parts := strings.Split(path, "/")
  1146. if len(parts) < 2 {
  1147. return nil, "", errHTTPBadRequestTopicInvalid
  1148. }
  1149. topicIDs := util.SplitNoEmpty(parts[1], ",")
  1150. topics, err := s.topicsFromIDs(topicIDs...)
  1151. if err != nil {
  1152. return nil, "", errHTTPBadRequestTopicInvalid
  1153. }
  1154. return topics, parts[1], nil
  1155. }
  1156. func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) {
  1157. s.mu.Lock()
  1158. defer s.mu.Unlock()
  1159. topics := make([]*topic, 0)
  1160. for _, id := range ids {
  1161. if util.Contains(disallowedTopics, id) {
  1162. return nil, errHTTPBadRequestTopicDisallowed
  1163. }
  1164. if _, ok := s.topics[id]; !ok {
  1165. if len(s.topics) >= s.config.TotalTopicLimit {
  1166. return nil, errHTTPTooManyRequestsLimitTotalTopics
  1167. }
  1168. s.topics[id] = newTopic(id)
  1169. }
  1170. topics = append(topics, s.topics[id])
  1171. }
  1172. return topics, nil
  1173. }
  1174. func (s *Server) execManager() {
  1175. log.Debug("Manager: Starting")
  1176. defer log.Debug("Manager: Finished")
  1177. // WARNING: Make sure to only selectively lock with the mutex, and be aware that this
  1178. // there is no mutex for the entire function.
  1179. // Expire visitors from rate visitors map
  1180. s.mu.Lock()
  1181. staleVisitors := 0
  1182. for ip, v := range s.visitors {
  1183. if v.Stale() {
  1184. log.Trace("Deleting stale visitor %s", v.ip)
  1185. delete(s.visitors, ip)
  1186. staleVisitors++
  1187. }
  1188. }
  1189. s.mu.Unlock()
  1190. log.Debug("Manager: Deleted %d stale visitor(s)", staleVisitors)
  1191. // Delete expired user tokens
  1192. if s.userManager != nil {
  1193. if err := s.userManager.RemoveExpiredTokens(); err != nil {
  1194. log.Warn("Error expiring user tokens: %s", err.Error())
  1195. }
  1196. }
  1197. // Delete expired attachments
  1198. if s.fileCache != nil {
  1199. ids, err := s.messageCache.AttachmentsExpired()
  1200. if err != nil {
  1201. log.Warn("Manager: Error retrieving expired attachments: %s", err.Error())
  1202. } else if len(ids) > 0 {
  1203. if log.IsDebug() {
  1204. log.Debug("Manager: Deleting attachments %s", strings.Join(ids, ", "))
  1205. }
  1206. if err := s.fileCache.Remove(ids...); err != nil {
  1207. log.Warn("Manager: Error deleting attachments: %s", err.Error())
  1208. }
  1209. if err := s.messageCache.MarkAttachmentsDeleted(ids...); err != nil {
  1210. log.Warn("Manager: Error marking attachments deleted: %s", err.Error())
  1211. }
  1212. } else {
  1213. log.Debug("Manager: No expired attachments to delete")
  1214. }
  1215. }
  1216. // DeleteMessages message cache
  1217. log.Debug("Manager: Pruning messages")
  1218. expiredMessageIDs, err := s.messageCache.MessagesExpired()
  1219. if err != nil {
  1220. log.Warn("Manager: Error retrieving expired messages: %s", err.Error())
  1221. } else if len(expiredMessageIDs) > 0 {
  1222. if err := s.fileCache.Remove(expiredMessageIDs...); err != nil {
  1223. log.Warn("Manager: Error deleting attachments for expired messages: %s", err.Error())
  1224. }
  1225. if err := s.messageCache.DeleteMessages(expiredMessageIDs...); err != nil {
  1226. log.Warn("Manager: Error marking attachments deleted: %s", err.Error())
  1227. }
  1228. } else {
  1229. log.Debug("Manager: No expired messages to delete")
  1230. }
  1231. // Message count per topic
  1232. var messages int
  1233. messageCounts, err := s.messageCache.MessageCounts()
  1234. if err != nil {
  1235. log.Warn("Manager: Cannot get message counts: %s", err.Error())
  1236. messageCounts = make(map[string]int) // Empty, so we can continue
  1237. }
  1238. for _, count := range messageCounts {
  1239. messages += count
  1240. }
  1241. // Remove subscriptions without subscribers
  1242. s.mu.Lock()
  1243. var subscribers int
  1244. for _, t := range s.topics {
  1245. subs := t.SubscribersCount()
  1246. msgs, exists := messageCounts[t.ID]
  1247. if subs == 0 && (!exists || msgs == 0) {
  1248. log.Trace("Deleting empty topic %s", t.ID)
  1249. delete(s.topics, t.ID)
  1250. continue
  1251. }
  1252. subscribers += subs
  1253. }
  1254. s.mu.Unlock()
  1255. // Mail stats
  1256. var receivedMailTotal, receivedMailSuccess, receivedMailFailure int64
  1257. if s.smtpServerBackend != nil {
  1258. receivedMailTotal, receivedMailSuccess, receivedMailFailure = s.smtpServerBackend.Counts()
  1259. }
  1260. var sentMailTotal, sentMailSuccess, sentMailFailure int64
  1261. if s.smtpSender != nil {
  1262. sentMailTotal, sentMailSuccess, sentMailFailure = s.smtpSender.Counts()
  1263. }
  1264. // Print stats
  1265. s.mu.Lock()
  1266. messagesCount, topicsCount, visitorsCount := s.messages, len(s.topics), len(s.visitors)
  1267. s.mu.Unlock()
  1268. log.Info("Stats: %d messages published, %d in cache, %d topic(s) active, %d subscriber(s), %d visitor(s), %d mails received (%d successful, %d failed), %d mails sent (%d successful, %d failed)",
  1269. messagesCount, messages, topicsCount, subscribers, visitorsCount,
  1270. receivedMailTotal, receivedMailSuccess, receivedMailFailure,
  1271. sentMailTotal, sentMailSuccess, sentMailFailure)
  1272. }
  1273. func (s *Server) runSMTPServer() error {
  1274. s.smtpServerBackend = newMailBackend(s.config, s.handle)
  1275. s.smtpServer = smtp.NewServer(s.smtpServerBackend)
  1276. s.smtpServer.Addr = s.config.SMTPServerListen
  1277. s.smtpServer.Domain = s.config.SMTPServerDomain
  1278. s.smtpServer.ReadTimeout = 10 * time.Second
  1279. s.smtpServer.WriteTimeout = 10 * time.Second
  1280. s.smtpServer.MaxMessageBytes = 1024 * 1024 // Must be much larger than message size (headers, multipart, etc.)
  1281. s.smtpServer.MaxRecipients = 1
  1282. s.smtpServer.AllowInsecureAuth = true
  1283. return s.smtpServer.ListenAndServe()
  1284. }
  1285. func (s *Server) runManager() {
  1286. for {
  1287. select {
  1288. case <-time.After(s.config.ManagerInterval):
  1289. s.execManager()
  1290. case <-s.closeChan:
  1291. return
  1292. }
  1293. }
  1294. }
  1295. // runStatsResetter runs once a day (usually midnight UTC) to reset all the visitor's message and
  1296. // email counters. The stats are used to display the counters in the web app, as well as for rate limiting.
  1297. func (s *Server) runStatsResetter() {
  1298. for {
  1299. runAt := util.NextOccurrenceUTC(s.config.VisitorStatsResetTime, time.Now())
  1300. timer := time.NewTimer(time.Until(runAt))
  1301. log.Debug("Stats resetter: Waiting until %v to reset visitor stats", runAt)
  1302. select {
  1303. case <-timer.C:
  1304. s.resetStats()
  1305. case <-s.closeChan:
  1306. timer.Stop()
  1307. return
  1308. }
  1309. }
  1310. }
  1311. func (s *Server) resetStats() {
  1312. log.Info("Resetting all visitor stats (daily task)")
  1313. s.mu.Lock()
  1314. defer s.mu.Unlock() // Includes the database query to avoid races with other processes
  1315. for _, v := range s.visitors {
  1316. v.ResetStats()
  1317. }
  1318. if s.userManager != nil {
  1319. if err := s.userManager.ResetStats(); err != nil {
  1320. log.Warn("Failed to write to database: %s", err.Error())
  1321. }
  1322. }
  1323. }
  1324. func (s *Server) runFirebaseKeepaliver() {
  1325. if s.firebaseClient == nil {
  1326. return
  1327. }
  1328. v := newVisitor(s.config, s.messageCache, s.userManager, netip.IPv4Unspecified(), nil) // Background process, not a real visitor, uses IP 0.0.0.0
  1329. for {
  1330. select {
  1331. case <-time.After(s.config.FirebaseKeepaliveInterval):
  1332. s.sendToFirebase(v, newKeepaliveMessage(firebaseControlTopic))
  1333. case <-time.After(s.config.FirebasePollInterval):
  1334. s.sendToFirebase(v, newKeepaliveMessage(firebasePollTopic))
  1335. case <-s.closeChan:
  1336. return
  1337. }
  1338. }
  1339. }
  1340. func (s *Server) runDelayedSender() {
  1341. for {
  1342. select {
  1343. case <-time.After(s.config.DelayedSenderInterval):
  1344. if err := s.sendDelayedMessages(); err != nil {
  1345. log.Warn("Error sending delayed messages: %s", err.Error())
  1346. }
  1347. case <-s.closeChan:
  1348. return
  1349. }
  1350. }
  1351. }
  1352. func (s *Server) sendDelayedMessages() error {
  1353. messages, err := s.messageCache.MessagesDue()
  1354. if err != nil {
  1355. return err
  1356. }
  1357. for _, m := range messages {
  1358. var v *visitor
  1359. if s.userManager != nil && m.User != "" {
  1360. u, err := s.userManager.User(m.User)
  1361. if err != nil {
  1362. log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error())
  1363. continue
  1364. }
  1365. v = s.visitorFromUser(u, m.Sender)
  1366. } else {
  1367. v = s.visitorFromIP(m.Sender)
  1368. }
  1369. if err := s.sendDelayedMessage(v, m); err != nil {
  1370. log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error())
  1371. }
  1372. }
  1373. return nil
  1374. }
  1375. func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
  1376. log.Debug("%s Sending delayed message", logMessagePrefix(v, m))
  1377. s.mu.Lock()
  1378. t, ok := s.topics[m.Topic] // If no subscribers, just mark message as published
  1379. s.mu.Unlock()
  1380. if ok {
  1381. go func() {
  1382. // We do not rate-limit messages here, since we've rate limited them in the PUT/POST handler
  1383. if err := t.Publish(v, m); err != nil {
  1384. log.Warn("%s Unable to publish message: %v", logMessagePrefix(v, m), err.Error())
  1385. }
  1386. }()
  1387. }
  1388. if s.firebaseClient != nil { // Firebase subscribers may not show up in topics map
  1389. go s.sendToFirebase(v, m)
  1390. }
  1391. if s.config.UpstreamBaseURL != "" {
  1392. go s.forwardPollRequest(v, m)
  1393. }
  1394. if err := s.messageCache.MarkPublished(m); err != nil {
  1395. return err
  1396. }
  1397. return nil
  1398. }
  1399. func (s *Server) limitRequests(next handleFunc) handleFunc {
  1400. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  1401. if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
  1402. return next(w, r, v)
  1403. } else if err := v.RequestAllowed(); err != nil {
  1404. return errHTTPTooManyRequestsLimitRequests
  1405. }
  1406. return next(w, r, v)
  1407. }
  1408. }
  1409. // transformBodyJSON peeks the request body, reads the JSON, and converts it to headers
  1410. // before passing it on to the next handler. This is meant to be used in combination with handlePublish.
  1411. func (s *Server) transformBodyJSON(next handleFunc) handleFunc {
  1412. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  1413. m, err := readJSONWithLimit[publishMessage](r.Body, s.config.MessageLimit*2) // 2x to account for JSON format overhead
  1414. if err != nil {
  1415. return err
  1416. }
  1417. if !topicRegex.MatchString(m.Topic) {
  1418. return errHTTPBadRequestTopicInvalid
  1419. }
  1420. if m.Message == "" {
  1421. m.Message = emptyMessageBody
  1422. }
  1423. r.URL.Path = "/" + m.Topic
  1424. r.Body = io.NopCloser(strings.NewReader(m.Message))
  1425. if m.Title != "" {
  1426. r.Header.Set("X-Title", m.Title)
  1427. }
  1428. if m.Priority != 0 {
  1429. r.Header.Set("X-Priority", fmt.Sprintf("%d", m.Priority))
  1430. }
  1431. if m.Tags != nil && len(m.Tags) > 0 {
  1432. r.Header.Set("X-Tags", strings.Join(m.Tags, ","))
  1433. }
  1434. if m.Attach != "" {
  1435. r.Header.Set("X-Attach", m.Attach)
  1436. }
  1437. if m.Filename != "" {
  1438. r.Header.Set("X-Filename", m.Filename)
  1439. }
  1440. if m.Click != "" {
  1441. r.Header.Set("X-Click", m.Click)
  1442. }
  1443. if m.Icon != "" {
  1444. r.Header.Set("X-Icon", m.Icon)
  1445. }
  1446. if len(m.Actions) > 0 {
  1447. actionsStr, err := json.Marshal(m.Actions)
  1448. if err != nil {
  1449. return errHTTPBadRequestMessageJSONInvalid
  1450. }
  1451. r.Header.Set("X-Actions", string(actionsStr))
  1452. }
  1453. if m.Email != "" {
  1454. r.Header.Set("X-Email", m.Email)
  1455. }
  1456. if m.Delay != "" {
  1457. r.Header.Set("X-Delay", m.Delay)
  1458. }
  1459. return next(w, r, v)
  1460. }
  1461. }
  1462. func (s *Server) transformMatrixJSON(next handleFunc) handleFunc {
  1463. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  1464. newRequest, err := newRequestFromMatrixJSON(r, s.config.BaseURL, s.config.MessageLimit)
  1465. if err != nil {
  1466. return err
  1467. }
  1468. if err := next(w, newRequest, v); err != nil {
  1469. return &errMatrix{pushKey: newRequest.Header.Get(matrixPushKeyHeader), err: err}
  1470. }
  1471. return nil
  1472. }
  1473. }
  1474. func (s *Server) authorizeTopicWrite(next handleFunc) handleFunc {
  1475. return s.autorizeTopic(next, user.PermissionWrite)
  1476. }
  1477. func (s *Server) authorizeTopicRead(next handleFunc) handleFunc {
  1478. return s.autorizeTopic(next, user.PermissionRead)
  1479. }
  1480. func (s *Server) autorizeTopic(next handleFunc, perm user.Permission) handleFunc {
  1481. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  1482. if s.userManager == nil {
  1483. return next(w, r, v)
  1484. }
  1485. topics, _, err := s.topicsFromPath(r.URL.Path)
  1486. if err != nil {
  1487. return err
  1488. }
  1489. for _, t := range topics {
  1490. if err := s.userManager.Authorize(v.user, t.ID, perm); err != nil {
  1491. log.Info("unauthorized: %s", err.Error())
  1492. return errHTTPForbidden
  1493. }
  1494. }
  1495. return next(w, r, v)
  1496. }
  1497. }
  1498. // visitor creates or retrieves a rate.Limiter for the given visitor.
  1499. // Note that this function will always return a visitor, even if an error occurs.
  1500. func (s *Server) visitor(r *http.Request) (v *visitor, err error) {
  1501. ip := extractIPAddress(r, s.config.BehindProxy)
  1502. var u *user.User // may stay nil if no auth header!
  1503. if u, err = s.authenticate(r); err != nil {
  1504. log.Debug("authentication failed: %s", err.Error())
  1505. err = errHTTPUnauthorized // Always return visitor, even when error occurs!
  1506. }
  1507. if u != nil {
  1508. v = s.visitorFromUser(u, ip)
  1509. } else {
  1510. v = s.visitorFromIP(ip)
  1511. }
  1512. v.mu.Lock()
  1513. v.user = u
  1514. v.mu.Unlock()
  1515. return v, err // Always return visitor, even when error occurs!
  1516. }
  1517. // authenticate a user based on basic auth username/password (Authorization: Basic ...), or token auth (Authorization: Bearer ...).
  1518. // The Authorization header can be passed as a header or the ?auth=... query param. The latter is required only to
  1519. // support the WebSocket JavaScript class, which does not support passing headers during the initial request. The auth
  1520. // query param is effectively double base64 encoded. Its format is base64(Basic base64(user:pass)).
  1521. func (s *Server) authenticate(r *http.Request) (user *user.User, err error) {
  1522. value := strings.TrimSpace(r.Header.Get("Authorization"))
  1523. queryParam := readQueryParam(r, "authorization", "auth")
  1524. if queryParam != "" {
  1525. a, err := base64.RawURLEncoding.DecodeString(queryParam)
  1526. if err != nil {
  1527. return nil, err
  1528. }
  1529. value = strings.TrimSpace(string(a))
  1530. }
  1531. if value == "" {
  1532. return nil, nil
  1533. } else if s.userManager == nil {
  1534. return nil, errHTTPUnauthorized
  1535. }
  1536. if strings.HasPrefix(value, "Bearer") {
  1537. return s.authenticateBearerAuth(value)
  1538. }
  1539. return s.authenticateBasicAuth(r, value)
  1540. }
  1541. func (s *Server) authenticateBasicAuth(r *http.Request, value string) (user *user.User, err error) {
  1542. r.Header.Set("Authorization", value)
  1543. username, password, ok := r.BasicAuth()
  1544. if !ok {
  1545. return nil, errors.New("invalid basic auth")
  1546. }
  1547. return s.userManager.Authenticate(username, password)
  1548. }
  1549. func (s *Server) authenticateBearerAuth(value string) (user *user.User, err error) {
  1550. token := strings.TrimSpace(strings.TrimPrefix(value, "Bearer"))
  1551. return s.userManager.AuthenticateToken(token)
  1552. }
  1553. func (s *Server) visitorFromID(visitorID string, ip netip.Addr, user *user.User) *visitor {
  1554. s.mu.Lock()
  1555. defer s.mu.Unlock()
  1556. v, exists := s.visitors[visitorID]
  1557. if !exists {
  1558. s.visitors[visitorID] = newVisitor(s.config, s.messageCache, s.userManager, ip, user)
  1559. return s.visitors[visitorID]
  1560. }
  1561. v.Keepalive()
  1562. return v
  1563. }
  1564. func (s *Server) visitorFromIP(ip netip.Addr) *visitor {
  1565. return s.visitorFromID(fmt.Sprintf("ip:%s", ip.String()), ip, nil)
  1566. }
  1567. func (s *Server) visitorFromUser(user *user.User, ip netip.Addr) *visitor {
  1568. return s.visitorFromID(fmt.Sprintf("user:%s", user.Name), ip, user)
  1569. }
  1570. func (s *Server) writeJSON(w http.ResponseWriter, v any) error {
  1571. w.Header().Set("Content-Type", "application/json")
  1572. w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
  1573. if err := json.NewEncoder(w).Encode(v); err != nil {
  1574. return err
  1575. }
  1576. return nil
  1577. }