1
0
Эх сурвалжийг харах

Upgrade smtp library, but not tests

binwiederhier 3 жил өмнө
parent
commit
36c0be1097

+ 17 - 19
server/smtp_server.go

@@ -34,6 +34,9 @@ type smtpBackend struct {
 	mu      sync.Mutex
 	mu      sync.Mutex
 }
 }
 
 
+var _ smtp.Backend = (*smtpBackend)(nil)
+var _ smtp.Session = (*smtpSession)(nil)
+
 func newMailBackend(conf *Config, handler func(http.ResponseWriter, *http.Request)) *smtpBackend {
 func newMailBackend(conf *Config, handler func(http.ResponseWriter, *http.Request)) *smtpBackend {
 	return &smtpBackend{
 	return &smtpBackend{
 		config:  conf,
 		config:  conf,
@@ -41,14 +44,9 @@ func newMailBackend(conf *Config, handler func(http.ResponseWriter, *http.Reques
 	}
 	}
 }
 }
 
 
-func (b *smtpBackend) Login(state *smtp.ConnectionState, username, password string) (smtp.Session, error) {
-	log.Debug("%s Incoming mail, login with user %s", logSMTPPrefix(state), username)
-	return &smtpSession{backend: b, state: state}, nil
-}
-
-func (b *smtpBackend) AnonymousLogin(state *smtp.ConnectionState) (smtp.Session, error) {
-	log.Debug("%s Incoming mail, anonymous login", logSMTPPrefix(state))
-	return &smtpSession{backend: b, state: state}, nil
+func (b *smtpBackend) NewSession(conn *smtp.Conn) (smtp.Session, error) {
+	log.Debug("%s Incoming mail", logSMTPPrefix(conn))
+	return &smtpSession{backend: b, conn: conn}, nil
 }
 }
 
 
 func (b *smtpBackend) Counts() (total int64, success int64, failure int64) {
 func (b *smtpBackend) Counts() (total int64, success int64, failure int64) {
@@ -60,23 +58,23 @@ func (b *smtpBackend) Counts() (total int64, success int64, failure int64) {
 // smtpSession is returned after EHLO.
 // smtpSession is returned after EHLO.
 type smtpSession struct {
 type smtpSession struct {
 	backend *smtpBackend
 	backend *smtpBackend
-	state   *smtp.ConnectionState
+	conn    *smtp.Conn
 	topic   string
 	topic   string
 	mu      sync.Mutex
 	mu      sync.Mutex
 }
 }
 
 
-func (s *smtpSession) AuthPlain(username, password string) error {
-	log.Debug("%s AUTH PLAIN (with username %s)", logSMTPPrefix(s.state), username)
+func (s *smtpSession) AuthPlain(username, _ string) error {
+	log.Debug("%s AUTH PLAIN (with username %s)", logSMTPPrefix(s.conn), username)
 	return nil
 	return nil
 }
 }
 
 
-func (s *smtpSession) Mail(from string, opts smtp.MailOptions) error {
-	log.Debug("%s MAIL FROM: %s (with options: %#v)", logSMTPPrefix(s.state), from, opts)
+func (s *smtpSession) Mail(from string, opts *smtp.MailOptions) error {
+	log.Debug("%s MAIL FROM: %s (with options: %#v)", logSMTPPrefix(s.conn), from, opts)
 	return nil
 	return nil
 }
 }
 
 
 func (s *smtpSession) Rcpt(to string) error {
 func (s *smtpSession) Rcpt(to string) error {
-	log.Debug("%s RCPT TO: %s", logSMTPPrefix(s.state), to)
+	log.Debug("%s RCPT TO: %s", logSMTPPrefix(s.conn), to)
 	return s.withFailCount(func() error {
 	return s.withFailCount(func() error {
 		conf := s.backend.config
 		conf := s.backend.config
 		addressList, err := mail.ParseAddressList(to)
 		addressList, err := mail.ParseAddressList(to)
@@ -114,9 +112,9 @@ func (s *smtpSession) Data(r io.Reader) error {
 			return err
 			return err
 		}
 		}
 		if log.IsTrace() {
 		if log.IsTrace() {
-			log.Trace("%s DATA: %s", logSMTPPrefix(s.state), string(b))
+			log.Trace("%s DATA: %s", logSMTPPrefix(s.conn), string(b))
 		} else if log.IsDebug() {
 		} else if log.IsDebug() {
-			log.Debug("%s DATA: %d byte(s)", logSMTPPrefix(s.state), len(b))
+			log.Debug("%s DATA: %d byte(s)", logSMTPPrefix(s.conn), len(b))
 		}
 		}
 		msg, err := mail.ReadMessage(bytes.NewReader(b))
 		msg, err := mail.ReadMessage(bytes.NewReader(b))
 		if err != nil {
 		if err != nil {
@@ -156,9 +154,9 @@ func (s *smtpSession) Data(r io.Reader) error {
 
 
 func (s *smtpSession) publishMessage(m *message) error {
 func (s *smtpSession) publishMessage(m *message) error {
 	// Extract remote address (for rate limiting)
 	// Extract remote address (for rate limiting)
-	remoteAddr, _, err := net.SplitHostPort(s.state.RemoteAddr.String())
+	remoteAddr, _, err := net.SplitHostPort(s.conn.Conn().RemoteAddr().String())
 	if err != nil {
 	if err != nil {
-		remoteAddr = s.state.RemoteAddr.String()
+		remoteAddr = s.conn.Conn().RemoteAddr().String()
 	}
 	}
 
 
 	// Call HTTP handler with fake HTTP request
 	// Call HTTP handler with fake HTTP request
@@ -198,7 +196,7 @@ func (s *smtpSession) withFailCount(fn func() error) error {
 	if err != nil {
 	if err != nil {
 		// Almost all of these errors are parse errors, and user input errors.
 		// Almost all of these errors are parse errors, and user input errors.
 		// We do not want to spam the log with WARN messages.
 		// We do not want to spam the log with WARN messages.
-		log.Debug("%s Incoming mail error: %s", logSMTPPrefix(s.state), err.Error())
+		log.Debug("%s Incoming mail error: %s", logSMTPPrefix(s.conn), err.Error())
 		s.backend.failure++
 		s.backend.failure++
 	}
 	}
 	return err
 	return err

+ 4 - 4
server/smtp_server_test.go

@@ -34,8 +34,8 @@ Content-Type: text/html; charset="UTF-8"
 		require.Equal(t, "and one more", r.Header.Get("Title"))
 		require.Equal(t, "and one more", r.Header.Get("Title"))
 		require.Equal(t, "what's up", readAll(t, r.Body))
 		require.Equal(t, "what's up", readAll(t, r.Body))
 	})
 	})
-	session, _ := backend.AnonymousLogin(fakeConnState(t, "1.2.3.4"))
-	require.Nil(t, session.Mail("phil@example.com", smtp.MailOptions{}))
+	session, _ := backend.NewSession(fakeConnState(t, "1.2.3.4"))
+	require.Nil(t, session.Mail("phil@example.com", &smtp.MailOptions{}))
 	require.Nil(t, session.Rcpt("ntfy-mytopic@ntfy.sh"))
 	require.Nil(t, session.Rcpt("ntfy-mytopic@ntfy.sh"))
 	require.Nil(t, session.Data(strings.NewReader(email)))
 	require.Nil(t, session.Data(strings.NewReader(email)))
 }
 }
@@ -303,12 +303,12 @@ func newTestBackend(t *testing.T, handler func(http.ResponseWriter, *http.Reques
 	return conf, backend
 	return conf, backend
 }
 }
 
 
-func fakeConnState(t *testing.T, remoteAddr string) *smtp.ConnectionState {
+func fakeConnState(t *testing.T, remoteAddr string) *smtp.Conn {
 	ip, err := net.ResolveIPAddr("ip", remoteAddr)
 	ip, err := net.ResolveIPAddr("ip", remoteAddr)
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}
-	return &smtp.ConnectionState{
+	return &smtp.Conn{
 		Hostname:   "myhostname",
 		Hostname:   "myhostname",
 		LocalAddr:  ip,
 		LocalAddr:  ip,
 		RemoteAddr: ip,
 		RemoteAddr: ip,

+ 2 - 2
server/util.go

@@ -57,8 +57,8 @@ func logHTTPPrefix(v *visitor, r *http.Request) string {
 	return fmt.Sprintf("%s HTTP %s %s", v.ip, r.Method, requestURI)
 	return fmt.Sprintf("%s HTTP %s %s", v.ip, r.Method, requestURI)
 }
 }
 
 
-func logSMTPPrefix(state *smtp.ConnectionState) string {
-	return fmt.Sprintf("%s/%s SMTP", state.Hostname, state.RemoteAddr.String())
+func logSMTPPrefix(conn *smtp.Conn) string {
+	return fmt.Sprintf("%s/%s SMTP", conn.Hostname(), conn.Conn().RemoteAddr().String())
 }
 }
 
 
 func renderHTTPRequest(r *http.Request) string {
 func renderHTTPRequest(r *http.Request) string {