Просмотр исходного кода

WIP: Auth in 80 lines of code :-)

Philipp Heckel 4 лет назад
Родитель
Сommit
2181227a6e
3 измененных файлов с 83 добавлено и 12 удалено
  1. 31 0
      server/auth_simple.go
  2. 1 0
      server/errors.go
  3. 51 12
      server/server.go

+ 31 - 0
server/auth_simple.go

@@ -0,0 +1,31 @@
+package server
+
+/*
+sqlite> create table user (id int auto increment, user text, password text not null);
+sqlite> create table user_topic (user_id int not null, topic text not null, allow_write int, allow_read int);
+sqlite> create table topic (topic text primary key, allow_anonymous_write int, allow_anonymous_read int);
+*/
+
+const (
+	permRead  = 1
+	permWrite = 2
+)
+
+type auther interface {
+	Authenticate(user, pass string) bool
+	Authorize(user, topic string, perm int) bool
+}
+
+type memAuther struct {
+}
+
+func (m memAuther) Authenticate(user, pass string) bool {
+	return user == "phil" && pass == "phil"
+}
+
+func (m memAuther) Authorize(user, topic string, perm int) bool {
+	if perm == permRead {
+		return true
+	}
+	return user == "phil" && topic == "mytopic"
+}

+ 1 - 0
server/errors.go

@@ -40,6 +40,7 @@ var (
 	errHTTPBadRequestAttachmentsExpiryBeforeDelivery = &errHTTP{40015, http.StatusBadRequest, "invalid request: attachment expiry before delayed delivery date", ""}
 	errHTTPBadRequestWebSocketsUpgradeHeaderMissing  = &errHTTP{40016, http.StatusBadRequest, "invalid request: client not using the websocket protocol", ""}
 	errHTTPNotFound                                  = &errHTTP{40401, http.StatusNotFound, "page not found", ""}
+	errHTTPUnauthorized                              = &errHTTP{40101, http.StatusUnauthorized, "unauthorized", ""}
 	errHTTPTooManyRequestsLimitRequests              = &errHTTP{42901, http.StatusTooManyRequests, "limit reached: too many requests, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
 	errHTTPTooManyRequestsLimitEmails                = &errHTTP{42902, http.StatusTooManyRequests, "limit reached: too many emails, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
 	errHTTPTooManyRequestsLimitSubscriptions         = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"}

+ 51 - 12
server/server.go

@@ -46,6 +46,7 @@ type Server struct {
 	firebase     subscriber
 	mailer       mailer
 	messages     int64
+	auther       auther
 	cache        cache
 	fileCache    *fileCache
 	closeChan    chan bool
@@ -57,6 +58,9 @@ type indexPage struct {
 	CacheDuration time.Duration
 }
 
+// handleFunc extends the normal http.HandlerFunc to be able to easily return errors
+type handleFunc func(http.ResponseWriter, *http.Request, *visitor) error
+
 var (
 	topicRegex       = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`)  // No /!
 	topicPathRegex   = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}$`) // Regex must match JS & Android app!
@@ -144,6 +148,7 @@ func New(conf *Config) (*Server, error) {
 		firebase:  firebaseSubscriber,
 		mailer:    mailer,
 		topics:    topics,
+		auther:    &memAuther{},
 		visitors:  make(map[string]*visitor),
 	}, nil
 }
@@ -312,6 +317,7 @@ func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
 }
 
 func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
+	v := s.visitor(r)
 	if r.Method == http.MethodGet && r.URL.Path == "/" {
 		return s.handleHome(w, r)
 	} else if r.Method == http.MethodGet && r.URL.Path == "/example.html" {
@@ -323,23 +329,23 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
 	} else if r.Method == http.MethodGet && docsRegex.MatchString(r.URL.Path) {
 		return s.handleDocs(w, r)
 	} else if r.Method == http.MethodGet && fileRegex.MatchString(r.URL.Path) && s.config.AttachmentCacheDir != "" {
-		return s.withRateLimit(w, r, s.handleFile)
+		return s.limitRequests(s.handleFile)(w, r, v)
 	} else if r.Method == http.MethodOptions {
 		return s.handleOptions(w, r)
 	} else if r.Method == http.MethodGet && topicPathRegex.MatchString(r.URL.Path) {
 		return s.handleTopic(w, r)
 	} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicPathRegex.MatchString(r.URL.Path) {
-		return s.withRateLimit(w, r, s.handlePublish)
+		return s.limitRequests(s.authWrite(s.handlePublish))(w, r, v)
 	} else if r.Method == http.MethodGet && publishPathRegex.MatchString(r.URL.Path) {
-		return s.withRateLimit(w, r, s.handlePublish)
+		return s.limitRequests(s.authWrite(s.handlePublish))(w, r, v)
 	} else if r.Method == http.MethodGet && jsonPathRegex.MatchString(r.URL.Path) {
-		return s.withRateLimit(w, r, s.handleSubscribeJSON)
+		return s.limitRequests(s.authRead(s.handleSubscribeJSON))(w, r, v)
 	} else if r.Method == http.MethodGet && ssePathRegex.MatchString(r.URL.Path) {
-		return s.withRateLimit(w, r, s.handleSubscribeSSE)
+		return s.limitRequests(s.authRead(s.handleSubscribeSSE))(w, r, v)
 	} else if r.Method == http.MethodGet && rawPathRegex.MatchString(r.URL.Path) {
-		return s.withRateLimit(w, r, s.handleSubscribeRaw)
+		return s.limitRequests(s.authRead(s.handleSubscribeRaw))(w, r, v)
 	} else if r.Method == http.MethodGet && wsPathRegex.MatchString(r.URL.Path) {
-		return s.withRateLimit(w, r, s.handleSubscribeWS)
+		return s.limitRequests(s.authRead(s.handleSubscribeWS))(w, r, v)
 	}
 	return errHTTPNotFound
 }
@@ -1094,12 +1100,45 @@ func (s *Server) sendDelayedMessages() error {
 	return nil
 }
 
-func (s *Server) withRateLimit(w http.ResponseWriter, r *http.Request, handler func(w http.ResponseWriter, r *http.Request, v *visitor) error) error {
-	v := s.visitor(r)
-	if err := v.RequestAllowed(); err != nil {
-		return errHTTPTooManyRequestsLimitRequests
+func (s *Server) limitRequests(next handleFunc) handleFunc {
+	return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
+		if err := v.RequestAllowed(); err != nil {
+			return errHTTPTooManyRequestsLimitRequests
+		}
+		return next(w, r, v)
+	}
+}
+
+func (s *Server) authWrite(next handleFunc) handleFunc {
+	return s.withAuth(next, permWrite)
+}
+
+func (s *Server) authRead(next handleFunc) handleFunc {
+	return s.withAuth(next, permRead)
+}
+
+func (s *Server) withAuth(next handleFunc, perm int) handleFunc {
+	return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
+		if s.auther == nil {
+			return next(w, r, v)
+		}
+		t, err := s.topicFromPath(r.URL.Path)
+		if err != nil {
+			return err
+		}
+		user, pass, ok := r.BasicAuth()
+		if ok {
+			if !s.auther.Authenticate(user, pass) {
+				return errHTTPUnauthorized
+			}
+		} else {
+			user = "" // Just in case
+		}
+		if !s.auther.Authorize(user, t.ID, perm) {
+			return errHTTPUnauthorized
+		}
+		return next(w, r, v)
 	}
-	return handler(w, r, v)
 }
 
 // visitor creates or retrieves a rate.Limiter for the given visitor.