Browse Source

recommended fixes [1 of 2]

Karmanyaah Malhotra 3 years ago
parent
commit
de2ca33700
4 changed files with 36 additions and 31 deletions
  1. 22 23
      cmd/serve.go
  2. 6 1
      server/message_cache.go
  3. 5 5
      server/server.go
  4. 3 2
      util/util.go

+ 22 - 23
cmd/serve.go

@@ -306,32 +306,31 @@ func sigHandlerConfigReload(config string) {
 }
 
 func parseIPHostPrefix(host string) (prefixes []netip.Prefix, err error) {
-        //try parsing as prefix
-        prefix, err := netip.ParsePrefix(host)
-        if err == nil {
-                prefixes = append(prefixes, prefix.Masked()) // masked and canonical for easy of debugging, shouldn't matter
-                return prefixes, nil                         // success
-        }
+	//try parsing as prefix
+	prefix, err := netip.ParsePrefix(host)
+	if err == nil {
+		prefixes = append(prefixes, prefix.Masked()) // Masked returns the prefix in its canonical form,  the same for every ip in the range. This exists for ease of debugging. For example, 10.1.2.3/16 is 10.1.0.0/16.
+		return prefixes, nil                         // success
+	}
 
-        // not a prefix, parse as host or IP
-        // LookupHost forwards through if it's an IP
-        ips, err := net.LookupHost(host)
-        if err == nil {
-                for _, i := range ips {
-                        ip, err := netip.ParseAddr(i)
-                        if err == nil {
-                                prefix, err := ip.Prefix(ip.BitLen())
-                                if err != nil {
-                                        return prefixes, errors.New(fmt.Sprint("ip", ip, " successfully parsed as IP but unable to turn into prefix. THIS SHOULD NEVER HAPPEN. err:", err.Error()))
-                                }
-                                prefixes = append(prefixes, prefix.Masked()) //also masked canonical ip
-                        }
-                }
-        }
-        return
+	// not a prefix, parse as host or IP
+	// LookupHost forwards through if it's an IP
+	ips, err := net.LookupHost(host)
+	if err == nil {
+		for _, i := range ips {
+			ip, err := netip.ParseAddr(i)
+			if err == nil {
+				prefix, err := ip.Prefix(ip.BitLen())
+				if err != nil {
+					return prefixes, errors.New(fmt.Sprint("ip", ip, " successfully parsed as IP but unable to turn into prefix. THIS SHOULD NEVER HAPPEN. err:", err.Error()))
+				}
+				prefixes = append(prefixes, prefix.Masked()) //also masked canonical ip
+			}
+		}
+	}
+	return
 }
 
-
 func reloadLogLevel(inputSource altsrc.InputSourceContext) {
 	newLevelStr, err := inputSource.String("log-level")
 	if err != nil {

+ 6 - 1
server/message_cache.go

@@ -456,6 +456,11 @@ func readMessages(rows *sql.Rows) ([]*message, error) {
 				return nil, err
 			}
 		}
+		senderIP, err := netip.ParseAddr(sender)
+		if err != nil {
+			senderIP = netip.IPv4Unspecified() // if no IP stored in database, 0.0.0.0
+		}
+
 		var att *attachment
 		if attachmentName != "" && attachmentURL != "" {
 			att = &attachment{
@@ -479,7 +484,7 @@ func readMessages(rows *sql.Rows) ([]*message, error) {
 			Icon:       icon,
 			Actions:    actions,
 			Attachment: att,
-			Sender:     netip.MustParseAddr(sender), // Must parse assuming database must be correct
+			Sender:     senderIP, // Must parse assuming database must be correct
 			Encoding:   encoding,
 		})
 	}

+ 5 - 5
server/server.go

@@ -643,8 +643,8 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca
 			return false, false, "", false, errHTTPBadRequestDelayTooLarge
 		}
 		m.Time = delay.Unix()
+		m.Sender = v.ip // Important for rate limiting
 	}
-	m.Sender = v.ip // Important for rate limiting
 	actionsStr := readParam(r, "x-actions", "actions", "action")
 	if actionsStr != "" {
 		m.Actions, err = parseActions(actionsStr)
@@ -1220,7 +1220,7 @@ func (s *Server) runFirebaseKeepaliver() {
 	if s.firebaseClient == nil {
 		return
 	}
-	v := newVisitor(s.config, s.messageCache, netip.MustParseAddr("0.0.0.0")) // Background process, not a real visitor
+	v := newVisitor(s.config, s.messageCache, netip.IPv4Unspecified()) // Background process, not a real visitor, uses IP 0.0.0.0
 	for {
 		select {
 		case <-time.After(s.config.FirebaseKeepaliveInterval):
@@ -1287,7 +1287,7 @@ func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
 
 func (s *Server) limitRequests(next handleFunc) handleFunc {
 	return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
-		if util.ContainsContains(s.config.VisitorRequestExemptIPAddrs, v.ip) {
+		if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
 			return next(w, r, v)
 		} else if err := v.RequestAllowed(); err != nil {
 			return errHTTPTooManyRequestsLimitRequests
@@ -1449,8 +1449,8 @@ func (s *Server) visitor(r *http.Request) *visitor {
 		ips := util.SplitNoEmpty(r.Header.Get("X-Forwarded-For"), ",")
 		myip, err := netip.ParseAddr(strings.TrimSpace(util.LastString(ips, remoteAddr)))
 		if err != nil {
-			log.Error("Invalid IP Address Received from proxy in X-Forwarded-For header. This should NEVER happen, your proxy is seriously broken: ", ip, err)
-			// fall back to regular remote address if x forwarded for is damaged
+			log.Error("invalid IP address %s received in X-Forwarded-For header: %s", ip, err.Error())
+			// fall back to regular remote address if X-Forwarded-For is damaged
 		} else {
 			ip = myip
 		}

+ 3 - 2
util/util.go

@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"io"
 	"math/rand"
+	"net/netip"
 	"os"
 	"regexp"
 	"strconv"
@@ -46,8 +47,8 @@ func Contains[T comparable](haystack []T, needle T) bool {
 	return false
 }
 
-// ContainsContains returns true if any element of haystack .Contains(needle).
-func ContainsContains[T interface{ Contains(U) bool }, U any](haystack []T, needle U) bool {
+// ContainsIP returns true if any one of the of prefixes contains the ip.
+func ContainsIP(haystack []netip.Prefix, needle netip.Addr) bool {
 	for _, s := range haystack {
 		if s.Contains(needle) {
 			return true