|
|
@@ -43,7 +43,7 @@ type Server struct {
|
|
|
smtpServerBackend *smtpBackend
|
|
|
smtpSender mailer
|
|
|
topics map[string]*topic
|
|
|
- visitors map[netip.Addr]*visitor
|
|
|
+ visitors map[string]*visitor // ip:<ip> or user:<user>
|
|
|
firebaseClient *firebaseClient
|
|
|
messages int64
|
|
|
auth auth.Auther
|
|
|
@@ -69,7 +69,9 @@ var (
|
|
|
publishPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}/(publish|send|trigger)$`)
|
|
|
|
|
|
webConfigPath = "/config.js"
|
|
|
- userStatsPath = "/user/stats"
|
|
|
+ userStatsPath = "/user/stats" // FIXME get rid of this in favor of /user/account
|
|
|
+ userAuthPath = "/user/auth"
|
|
|
+ userAccountPath = "/user/account"
|
|
|
matrixPushPath = "/_matrix/push/v1/notify"
|
|
|
staticRegex = regexp.MustCompile(`^/static/.+`)
|
|
|
docsRegex = regexp.MustCompile(`^/docs(|/.*)$`)
|
|
|
@@ -151,7 +153,7 @@ func New(conf *Config) (*Server, error) {
|
|
|
smtpSender: mailer,
|
|
|
topics: topics,
|
|
|
auth: auther,
|
|
|
- visitors: make(map[netip.Addr]*visitor),
|
|
|
+ visitors: make(map[string]*visitor),
|
|
|
}, nil
|
|
|
}
|
|
|
|
|
|
@@ -255,12 +257,15 @@ func (s *Server) Stop() {
|
|
|
}
|
|
|
|
|
|
func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
|
|
|
- v := s.visitor(r)
|
|
|
- log.Debug("%s Dispatching request", logHTTPPrefix(v, r))
|
|
|
- if log.IsTrace() {
|
|
|
- log.Trace("%s Entire request (headers and body):\n%s", logHTTPPrefix(v, r), renderHTTPRequest(r))
|
|
|
+ v, err := s.visitor(r) // Note: Always returns v, even when error is returned
|
|
|
+ if err == nil {
|
|
|
+ log.Debug("%s Dispatching request", logHTTPPrefix(v, r))
|
|
|
+ if log.IsTrace() {
|
|
|
+ log.Trace("%s Entire request (headers and body):\n%s", logHTTPPrefix(v, r), renderHTTPRequest(r))
|
|
|
+ }
|
|
|
+ err = s.handleInternal(w, r, v)
|
|
|
}
|
|
|
- if err := s.handleInternal(w, r, v); err != nil {
|
|
|
+ if err != nil {
|
|
|
if websocket.IsWebSocketUpgrade(r) {
|
|
|
isNormalError := strings.Contains(err.Error(), "i/o timeout")
|
|
|
if isNormalError {
|
|
|
@@ -300,6 +305,10 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
|
|
|
return s.ensureWebEnabled(s.handleWebConfig)(w, r, v)
|
|
|
} else if r.Method == http.MethodGet && r.URL.Path == userStatsPath {
|
|
|
return s.handleUserStats(w, r, v)
|
|
|
+ } else if r.Method == http.MethodGet && r.URL.Path == userAuthPath {
|
|
|
+ return s.handleUserAuth(w, r, v)
|
|
|
+ } else if r.Method == http.MethodGet && r.URL.Path == userAccountPath {
|
|
|
+ return s.handleUserAccount(w, r, v)
|
|
|
} else if r.Method == http.MethodGet && r.URL.Path == matrixPushPath {
|
|
|
return s.handleMatrixDiscovery(w)
|
|
|
} else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) {
|
|
|
@@ -394,6 +403,72 @@ func (s *Server) handleUserStats(w http.ResponseWriter, r *http.Request, v *visi
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
+var sessions = make(map[string]*auth.User) // token-> user
|
|
|
+
|
|
|
+type tokenAuthResponse struct {
|
|
|
+ Token string `json:"token"`
|
|
|
+}
|
|
|
+
|
|
|
+func (s *Server) handleUserAuth(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
|
|
+ // TODO rate limit
|
|
|
+ if v.user == nil {
|
|
|
+ return errHTTPUnauthorized
|
|
|
+ }
|
|
|
+ token := util.RandomString(32)
|
|
|
+ sessions[token] = v.user
|
|
|
+ w.Header().Set("Content-Type", "text/json")
|
|
|
+ w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
|
|
+ response := &tokenAuthResponse{
|
|
|
+ Token: token,
|
|
|
+ }
|
|
|
+ if err := json.NewEncoder(w).Encode(response); err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+type userSubscriptionResponse struct {
|
|
|
+ BaseURL string `json:"base_url"`
|
|
|
+ Topic string `json:"topic"`
|
|
|
+}
|
|
|
+
|
|
|
+type userAccountResponse struct {
|
|
|
+ Username string `json:"username"`
|
|
|
+ Role string `json:"role,omitempty"`
|
|
|
+ Language string `json:"language,omitempty"`
|
|
|
+ Plan struct {
|
|
|
+ Id int `json:"id"`
|
|
|
+ Name string `json:"name"`
|
|
|
+ } `json:"plan,omitempty"`
|
|
|
+ Notification struct {
|
|
|
+ Sound string `json:"sound"`
|
|
|
+ MinPriority string `json:"min_priority"`
|
|
|
+ DeleteAfter int `json:"delete_after"`
|
|
|
+ } `json:"notification,omitempty"`
|
|
|
+ Subscriptions []*userSubscriptionResponse `json:"subscriptions,omitempty"`
|
|
|
+}
|
|
|
+
|
|
|
+func (s *Server) handleUserAccount(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
|
|
+ w.Header().Set("Content-Type", "text/json")
|
|
|
+ w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
|
|
+ var response *userAccountResponse
|
|
|
+ if v.user != nil {
|
|
|
+ response = &userAccountResponse{
|
|
|
+ Username: v.user.Name,
|
|
|
+ Role: string(v.user.Role),
|
|
|
+ Language: "en_US",
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ response = &userAccountResponse{
|
|
|
+ Username: "anonymous",
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if err := json.NewEncoder(w).Encode(response); err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
func (s *Server) handleStatic(w http.ResponseWriter, r *http.Request, _ *visitor) error {
|
|
|
r.URL.Path = webSiteDir + r.URL.Path
|
|
|
util.Gzip(http.FileServer(http.FS(webFsCached))).ServeHTTP(w, r)
|
|
|
@@ -1221,7 +1296,7 @@ func (s *Server) runFirebaseKeepaliver() {
|
|
|
if s.firebaseClient == nil {
|
|
|
return
|
|
|
}
|
|
|
- v := newVisitor(s.config, s.messageCache, netip.IPv4Unspecified()) // Background process, not a real visitor, uses IP 0.0.0.0
|
|
|
+ v := newVisitor(s.config, s.messageCache, netip.IPv4Unspecified(), nil) // Background process, not a real visitor, uses IP 0.0.0.0
|
|
|
for {
|
|
|
select {
|
|
|
case <-time.After(s.config.FirebaseKeepaliveInterval):
|
|
|
@@ -1253,7 +1328,7 @@ func (s *Server) sendDelayedMessages() error {
|
|
|
return err
|
|
|
}
|
|
|
for _, m := range messages {
|
|
|
- v := s.visitorFromIP(m.Sender)
|
|
|
+ v := s.visitorFromID(fmt.Sprintf("ip:%s", m.Sender.String()), m.Sender, nil) // FIXME: This is wrong wrong wrong
|
|
|
if err := s.sendDelayedMessage(v, m); err != nil {
|
|
|
log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error())
|
|
|
}
|
|
|
@@ -1395,16 +1470,8 @@ func (s *Server) withAuth(next handleFunc, perm auth.Permission) handleFunc {
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
- var user *auth.User // may stay nil if no auth header!
|
|
|
- username, password, ok := extractUserPass(r)
|
|
|
- if ok {
|
|
|
- if user, err = s.auth.Authenticate(username, password); err != nil {
|
|
|
- log.Info("authentication failed: %s", err.Error())
|
|
|
- return errHTTPUnauthorized
|
|
|
- }
|
|
|
- }
|
|
|
for _, t := range topics {
|
|
|
- if err := s.auth.Authorize(user, t.ID, perm); err != nil {
|
|
|
+ if err := s.auth.Authorize(v.user, t.ID, perm); err != nil {
|
|
|
log.Info("unauthorized: %s", err.Error())
|
|
|
return errHTTPForbidden
|
|
|
}
|
|
|
@@ -1435,8 +1502,39 @@ func extractUserPass(r *http.Request) (username string, password string, ok bool
|
|
|
}
|
|
|
|
|
|
// visitor creates or retrieves a rate.Limiter for the given visitor.
|
|
|
-// This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT).
|
|
|
-func (s *Server) visitor(r *http.Request) *visitor {
|
|
|
+// Note that this function will always return a visitor, even if an error occurs.
|
|
|
+func (s *Server) visitor(r *http.Request) (v *visitor, err error) {
|
|
|
+ ip := s.extractIPAddress(r)
|
|
|
+ visitorID := fmt.Sprintf("ip:%s", ip.String())
|
|
|
+
|
|
|
+ var user *auth.User // may stay nil if no auth header!
|
|
|
+ username, password, ok := extractUserPass(r)
|
|
|
+ if ok {
|
|
|
+ if user, err = s.auth.Authenticate(username, password); err != nil {
|
|
|
+ log.Debug("authentication failed: %s", err.Error())
|
|
|
+ err = errHTTPUnauthorized // Always return visitor, even when error occurs!
|
|
|
+ } else {
|
|
|
+ visitorID = fmt.Sprintf("user:%s", user.Name)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ v = s.visitorFromID(visitorID, ip, user)
|
|
|
+ v.user = user // Update user -- FIXME this is ugly, do "newVisitorFromUser" instead
|
|
|
+ return v, err // Always return visitor, even when error occurs!
|
|
|
+}
|
|
|
+
|
|
|
+func (s *Server) visitorFromID(visitorID string, ip netip.Addr, user *auth.User) *visitor {
|
|
|
+ s.mu.Lock()
|
|
|
+ defer s.mu.Unlock()
|
|
|
+ v, exists := s.visitors[visitorID]
|
|
|
+ if !exists {
|
|
|
+ s.visitors[visitorID] = newVisitor(s.config, s.messageCache, ip, user)
|
|
|
+ return s.visitors[visitorID]
|
|
|
+ }
|
|
|
+ v.Keepalive()
|
|
|
+ return v
|
|
|
+}
|
|
|
+
|
|
|
+func (s *Server) extractIPAddress(r *http.Request) netip.Addr {
|
|
|
remoteAddr := r.RemoteAddr
|
|
|
addrPort, err := netip.ParseAddrPort(remoteAddr)
|
|
|
ip := addrPort.Addr()
|
|
|
@@ -1461,17 +1559,5 @@ func (s *Server) visitor(r *http.Request) *visitor {
|
|
|
ip = realIP
|
|
|
}
|
|
|
}
|
|
|
- return s.visitorFromIP(ip)
|
|
|
-}
|
|
|
-
|
|
|
-func (s *Server) visitorFromIP(ip netip.Addr) *visitor {
|
|
|
- s.mu.Lock()
|
|
|
- defer s.mu.Unlock()
|
|
|
- v, exists := s.visitors[ip]
|
|
|
- if !exists {
|
|
|
- s.visitors[ip] = newVisitor(s.config, s.messageCache, ip)
|
|
|
- return s.visitors[ip]
|
|
|
- }
|
|
|
- v.Keepalive()
|
|
|
- return v
|
|
|
+ return ip
|
|
|
}
|