diff --git a/pkg/email/email.go b/pkg/email/email.go index 77e6425..44741b2 100644 --- a/pkg/email/email.go +++ b/pkg/email/email.go @@ -7,6 +7,7 @@ import ( "encoding/base64" "fmt" "io" + "net" "net/smtp" "slices" "time" @@ -52,7 +53,66 @@ func NewInsecure(config MailServiceConfig) *EmailService { } } -var validCommonNames = []string{"ISRG Root X1", "R3", "R10", "R11", "E5", "DST Root CA X3"} +var validCommonNames = []string{ + "ISRG Root X1", + "R3", + "R10", + "R11", + "E5", + "DST Root CA X3", + "DigiCert Global Root G2", + "DigiCert Global G2 TLS RSA SHA256 2020 CA1", +} + +func customVerify(host string) func(cs tls.ConnectionState) error { + return func(cs tls.ConnectionState) error { + // Ensure we have at least one peer certificate + if len(cs.PeerCertificates) == 0 { + return fmt.Errorf("no peer certificates provided") + } + + now := time.Now() + + // Set up verification options with a DNSName check. + // This will perform hostname verification automatically. + opts := x509.VerifyOptions{ + CurrentTime: now, + DNSName: host, // assuming config.Host is accessible here + Intermediates: x509.NewCertPool(), + } + // Add all certificates except the leaf as intermediates. + for i := 1; i < len(cs.PeerCertificates); i++ { + opts.Intermediates.AddCert(cs.PeerCertificates[i]) + } + + // Verify the certificate chain (including hostname check via opts.DNSName) + if _, err := cs.PeerCertificates[0].Verify(opts); err != nil { + return fmt.Errorf("certificate chain verification failed: %w", err) + } + + // Perform additional custom checks + for _, cert := range cs.PeerCertificates { + if now.After(cert.NotAfter) { + return fmt.Errorf("certificate expired on %s", cert.NotAfter) + } + if now.Add(30 * 24 * time.Hour).After(cert.NotAfter) { + return fmt.Errorf("certificate will expire soon on %s", cert.NotAfter) + } + + // Check that the issuer's CommonName is in our allowed list. + if !slices.Contains(validCommonNames, cert.Issuer.CommonName) { + return fmt.Errorf("untrusted certificate issuer: %s", cert.Issuer.CommonName) + } + + // Check that the public key algorithm is RSA. + if cert.PublicKeyAlgorithm != x509.RSA { + return fmt.Errorf("unsupported public key algorithm: %v", cert.PublicKeyAlgorithm) + } + } + + return nil + } +} func NewSecure(config MailServiceConfig) *EmailService { return &EmailService{ @@ -63,54 +123,32 @@ func NewSecure(config MailServiceConfig) *EmailService { tlsconfig: &tls.Config{ InsecureSkipVerify: true, ServerName: config.Host, - VerifyConnection: func(cs tls.ConnectionState) error { - - // // Check the server's common name - // for _, cert := range cs.PeerCertificates { - // log.Println("cert.DNSNames", cert.DNSNames) - // if err := cert.VerifyHostname(config.Host); err != nil { - // return fmt.Errorf("invalid common name: %w", err) - // } - // } - - // Check the certificate chain - opts := x509.VerifyOptions{ - Intermediates: x509.NewCertPool(), - } - for _, cert := range cs.PeerCertificates[1:] { - opts.Intermediates.AddCert(cert) - } - _, err := cs.PeerCertificates[0].Verify(opts) - if err != nil { - return fmt.Errorf("invalid certificate chain: %w", err) - } - - // Iterate over the certificates again to perform custom checks - for _, cert := range cs.PeerCertificates { - // TODO: add more checks here... - if time.Now().After(cert.NotAfter) { - return fmt.Errorf("certificate has expired") - } - if time.Now().Add(30 * 24 * time.Hour).After(cert.NotAfter) { - return fmt.Errorf("certificate will expire within 30 days") - } - - if !slices.Contains(validCommonNames, cert.Issuer.CommonName) { - return fmt.Errorf("certificate is not issued by a trusted CA") - } - - if cert.PublicKeyAlgorithm != x509.RSA { - return fmt.Errorf("unsupported public key algorithm") - } - } - - return nil - }, + VerifyConnection: customVerify(config.Host), }, dial: dial, } } +func NewSecure465(config MailServiceConfig) *EmailService { + tlsCfg := tls.Config{ + // Ideally, InsecureSkipVerify: false, + // or do a proper certificate validation + InsecureSkipVerify: true, + ServerName: config.Host, + VerifyConnection: customVerify(config.Host), + } + return &EmailService{ + auth: config.Auth, + host: config.Host, + port: config.Port, + from: config.From, + tlsconfig: &tlsCfg, + dial: func(hostPort string) (SMTPClientIface, error) { + return dialTLS(hostPort, &tlsCfg) + }, + } +} + type MailServiceConfig struct { Auth smtp.Auth Host string @@ -126,6 +164,25 @@ func dial(hostPort string) (SMTPClientIface, error) { return client, nil } +func dialTLS(hostPort string, tlsConfig *tls.Config) (SMTPClientIface, error) { + // 1) Create a raw TCP connection + conn, err := net.Dial("tcp", hostPort) + if err != nil { + return nil, err + } + + // 2) Wrap it with TLS + tlsConn := tls.Client(conn, tlsConfig) + + // 3) Now create the SMTP client on this TLS connection + host, _, _ := net.SplitHostPort(hostPort) + c, err := smtp.NewClient(tlsConn, host) + if err != nil { + return nil, err + } + return c, nil +} + func (e *EmailService) SendEmail(emailData EmailMessage) error { msg, err := newMessage(e.from, emailData.To, emailData.Subject). withAttachments(emailData.Body, emailData.Attachments) @@ -134,7 +191,12 @@ func (e *EmailService) SendEmail(emailData EmailMessage) error { return fmt.Errorf("error while preparing email: %w", err) } - return e.send(emailData.To, msg) + switch e.port { + case "465": + return e.sendTLS(emailData.To, msg) + default: + return e.send(emailData.To, msg) + } } func (e *EmailService) send(to string, msg []byte) error { @@ -186,6 +248,53 @@ func (e *EmailService) send(to string, msg []byte) error { return nil } +func (e *EmailService) sendTLS(to string, msg []byte) error { + c, err := e.dial(e.host + ":" + e.port) // dialTLS + if err != nil { + return fmt.Errorf("DIAL: %s", err) + } + defer c.Close() + + // Auth + if err = c.Auth(e.auth); err != nil { + return fmt.Errorf("c.Auth: %s", err) + } + + // To && From + if err = c.Mail(e.from); err != nil { + return fmt.Errorf("c.Mail: %s", err) + } + + if err = c.Rcpt(to); err != nil { + return fmt.Errorf("c.Rcpt: %s", err) + } + + // Data + w, err := c.Data() + if err != nil { + return fmt.Errorf("c.Data: %s", err) + } + + written, err := w.Write(msg) + if err != nil { + return fmt.Errorf("w.Write: %s", err) + } + + if written <= 0 { + return fmt.Errorf("%d bytes written", written) + } + + if err = w.Close(); err != nil { + return fmt.Errorf("w.Close: %s", err) + } + + if err = c.Quit(); err != nil { + return fmt.Errorf("w.Quit: %s", err) + } + + return nil +} + type message struct { from string to string