util.go 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. package server
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "mime"
  8. "net/http"
  9. "net/netip"
  10. "regexp"
  11. "slices"
  12. "strings"
  13. "heckel.io/ntfy/v2/util"
  14. )
  15. var (
  16. mimeDecoder mime.WordDecoder
  17. // priorityHeaderIgnoreRegex matches specific patterns of the "Priority" header (RFC 9218), so that it can be ignored
  18. priorityHeaderIgnoreRegex = regexp.MustCompile(`^u=\d,\s*(i|\d)$|^u=\d$`)
  19. // forwardedHeaderRegex parses IPv4 and IPv6 addresses from the "Forwarded" header (RFC 7239)
  20. // IPv6 addresses in Forwarded header are enclosed in square brackets, e.g. for="[2001:db8::1]"
  21. forwardedHeaderRegex = regexp.MustCompile(`(?i)\\bfor=\"?((?:[0-9]{1,3}\.){3}[0-9]{1,3}|\[[0-9a-fA-F:]+\])\"?`)
  22. )
  // readBoolParam reads the first non-empty header or query parameter among names
// (see readParam) and interprets it case-insensitively as a boolean via toBool.
// If none of the names is set, defaultValue is returned.
func readBoolParam(r *http.Request, defaultValue bool, names ...string) bool {
	raw := readParam(r, names...)
	if raw == "" {
		return defaultValue
	}
	return toBool(strings.ToLower(raw))
}
  // isBoolValue reports whether value is one of the accepted boolean spellings:
// "1", "yes", "true", "0", "no" or "false". The comparison is case-sensitive.
func isBoolValue(value string) bool {
	switch value {
	case "1", "yes", "true", "0", "no", "false":
		return true
	}
	return false
}
  // toBool converts value to a boolean: "1", "yes" and "true" are true, anything
// else (including other casings) is false. Callers lowercase input where needed.
func toBool(value string) bool {
	switch value {
	case "1", "yes", "true":
		return true
	default:
		return false
	}
}
  // readCommaSeparatedParam reads the first non-empty header or query parameter
// among names (see readParam), splits it on commas via util.SplitNoEmpty and
// trims whitespace around each element. It returns an empty, non-nil slice if
// no parameter is set.
func readCommaSeparatedParam(r *http.Request, names ...string) []string {
	paramStr := readParam(r, names...)
	if paramStr == "" {
		return []string{}
	}
	parts := util.SplitNoEmpty(paramStr, ",")
	return util.Map(parts, strings.TrimSpace)
}
  // readParam returns the first non-empty value for any of names, looking at the
// request headers first (readHeaderParam) and only then at the URL query
// (readQueryParam). It returns "" if neither source has a value.
func readParam(r *http.Request, names ...string) string {
	fromHeader := readHeaderParam(r, names...)
	if fromHeader == "" {
		return readQueryParam(r, names...)
	}
	return fromHeader
}
  // readHeaderParam returns the first header value among names that is non-empty
// after MIME-decoding (maybeDecodeHeader) and whitespace trimming. Names are tried
// in order; "" is returned if none yields a value.
func readHeaderParam(r *http.Request, names ...string) string {
	for _, headerName := range names {
		raw := r.Header.Get(headerName)
		decoded := maybeDecodeHeader(headerName, raw)
		if trimmed := strings.TrimSpace(decoded); trimmed != "" {
			return trimmed
		}
	}
	return ""
}
  // readQueryParam returns the first URL query parameter among names whose value is
// non-empty after whitespace trimming. Each name is lowercased before lookup, so
// "X-Title" looks up "x-title". It returns "" if none of the names yields a value.
func readQueryParam(r *http.Request, names ...string) string {
	// Parse the query string once rather than once per candidate name.
	query := r.URL.Query()
	for _, name := range names {
		// Trim before the emptiness check (like readHeaderParam), so a whitespace-only
		// value does not shadow a real value supplied under a later name.
		if value := strings.TrimSpace(query.Get(strings.ToLower(name))); value != "" {
			return value
		}
	}
	return ""
}
  // extractIPAddress determines the visitor's IP address for a request.
//
// When behindProxy is set and proxyForwardedHeader names a header, the address is
// taken from that header via extractIPAddressFromHeader (ignoring proxyTrustedAddresses).
// If that is not possible, or no proxy is configured, the address from r.RemoteAddr
// (host:port) is used. If RemoteAddr cannot be parsed, a warning is logged and the
// unspecified IPv4 address 0.0.0.0 is returned.
func extractIPAddress(r *http.Request, behindProxy bool, proxyForwardedHeader string, proxyTrustedAddresses []string) netip.Addr {
	useHeader := behindProxy && proxyForwardedHeader != ""
	if useHeader {
		headerAddr, headerErr := extractIPAddressFromHeader(r, proxyForwardedHeader, proxyTrustedAddresses)
		if headerErr == nil {
			return headerAddr
		}
		// Header missing or unusable: fall through to the socket address
	}
	remote, err := netip.ParseAddrPort(r.RemoteAddr)
	if err == nil {
		return remote.Addr()
	}
	logr(r).Err(err).Warn("unable to parse IP (%s), new visitor with unspecified IP (0.0.0.0) created", r.RemoteAddr)
	return netip.IPv4Unspecified()
}
  // extractIPAddressFromHeader extracts the last untrusted IP address from the specified header.
//
// It supports multiple formats:
//   - single IP address
//   - comma-separated list
//   - RFC 7239-style list (Forwarded header), using forwardedHeaderRegex to pick out the for= address
//
// Entries that cannot be parsed as an IP address are skipped. If there are multiple addresses,
// we first remove the trusted IP addresses from the list, and then take the right-most address
// in the list (as this is the one added by our proxy server).
// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For for details.
//
// Trusted addresses match either by their literal string or by their parsed IP value. As a
// result, non-canonical spellings in the configuration (e.g. "2001:DB8::1" or "2001:0db8::1")
// still match the canonical address parsed from the header. Trusted entries that are not
// valid IP addresses are only matched literally.
//
// It returns 0.0.0.0 and an error if the header is missing or empty, or if no untrusted
// address remains after filtering.
func extractIPAddressFromHeader(r *http.Request, forwardedHeader string, trustedAddresses []string) (netip.Addr, error) {
	value := strings.TrimSpace(r.Header.Get(forwardedHeader))
	if value == "" {
		return netip.IPv4Unspecified(), fmt.Errorf("no %s header found", forwardedHeader)
	}
	// Extract valid addresses
	addrsStrs := util.Map(util.SplitNoEmpty(value, ","), strings.TrimSpace)
	var validAddrs []netip.Addr
	for _, addrStr := range addrsStrs {
		// Handle Forwarded header with for="[IPv6]" or for="IPv4"
		if m := forwardedHeaderRegex.FindStringSubmatch(addrStr); len(m) == 2 {
			addrRaw := m[1]
			if strings.HasPrefix(addrRaw, "[") && strings.HasSuffix(addrRaw, "]") {
				addrRaw = addrRaw[1 : len(addrRaw)-1]
			}
			if addr, err := netip.ParseAddr(addrRaw); err == nil {
				validAddrs = append(validAddrs, addr)
			}
		} else if addr, err := netip.ParseAddr(addrStr); err == nil {
			validAddrs = append(validAddrs, addr)
		}
	}
	// Parse the trusted addresses once, so they can also be compared by value rather than by spelling
	trustedAddrs := make([]netip.Addr, 0, len(trustedAddresses))
	for _, trusted := range trustedAddresses {
		if addr, err := netip.ParseAddr(strings.TrimSpace(trusted)); err == nil {
			trustedAddrs = append(trustedAddrs, addr)
		}
	}
	// Filter out proxy addresses
	clientAddrs := util.Filter(validAddrs, func(addr netip.Addr) bool {
		isTrusted := slices.Contains(trustedAddresses, addr.String()) || slices.Contains(trustedAddrs, addr)
		return !isTrusted
	})
	if len(clientAddrs) == 0 {
		return netip.IPv4Unspecified(), fmt.Errorf("no client IP address found in %s header: %s", forwardedHeader, value)
	}
	return clientAddrs[len(clientAddrs)-1], nil
}
  // readJSONWithLimit decodes a JSON body of at most limit bytes from r into a new T
// using util.UnmarshalJSONWithLimit. It maps util.ErrUnmarshalJSON to
// errHTTPBadRequestJSONInvalid and util.ErrTooLargeJSON to errHTTPEntityTooLargeJSONBody.
// Any other error is returned unchanged.
func readJSONWithLimit[T any](r io.ReadCloser, limit int, allowEmpty bool) (*T, error) {
	obj, err := util.UnmarshalJSONWithLimit[T](r, limit, allowEmpty)
	switch {
	case errors.Is(err, util.ErrUnmarshalJSON):
		return nil, errHTTPBadRequestJSONInvalid
	case errors.Is(err, util.ErrTooLargeJSON):
		return nil, errHTTPEntityTooLargeJSONBody
	case err != nil:
		return nil, err
	}
	return obj, nil
}
  // withContext returns a shallow copy of r whose context carries every key/value
// pair from ctx, layered on top of r's existing context.
func withContext(r *http.Request, ctx map[contextKey]any) *http.Request {
	enriched := r.Context()
	for key, val := range ctx {
		enriched = context.WithValue(enriched, key, val)
	}
	return r.WithContext(enriched)
}
  142. func fromContext[T any](r *http.Request, key contextKey) (T, error) {
  143. t, ok := r.Context().Value(key).(T)
  144. if !ok {
  145. return t, fmt.Errorf("cannot find key %v in request context", key)
  146. }
  147. return t, nil
  148. }
  149. // maybeDecodeHeader decodes the given header value if it is MIME encoded, e.g. "=?utf-8?q?Hello_World?=",
  150. // or returns the original header value if it is not MIME encoded. It also calls maybeIgnoreSpecialHeader
  151. // to ignore the new HTTP "Priority" header.
  152. func maybeDecodeHeader(name, value string) string {
  153. decoded, err := mimeDecoder.DecodeHeader(value)
  154. if err != nil {
  155. return maybeIgnoreSpecialHeader(name, value)
  156. }
  157. return maybeIgnoreSpecialHeader(name, decoded)
  158. }
  // maybeIgnoreSpecialHeader blanks out the HTTP "Priority" header defined by RFC 9218
// (https://datatracker.ietf.org/doc/html/rfc9218) when it carries urgency values.
//
// Cloudflare (and potentially other providers) add values like "u=3, i" or "u=1" when
// forwarding to the backend (ntfy). These would otherwise be mistaken for an ntfy
// message priority. Returning "" lets the caller continue with the next candidate
// header (x-priority, prio, p) or the query parameters. All other headers and values
// are returned unchanged.
func maybeIgnoreSpecialHeader(name, value string) string {
	isPriority := strings.ToLower(name) == "priority"
	if !isPriority {
		return value
	}
	if priorityHeaderIgnoreRegex.MatchString(strings.TrimSpace(value)) {
		return ""
	}
	return value
}