package email import ( "bytes" "crypto/tls" "crypto/x509" "encoding/base64" "fmt" "io" "net/smtp" "slices" "time" ) const ( mime = "MIME-version: 1.0;\nContent-Type: text/html; charset=\"UTF-8\";\n\n" delimeter = "**=myohmy689407924327" ) type SMTPClientIface interface { StartTLS(*tls.Config) error Auth(a smtp.Auth) error Close() error Data() (io.WriteCloser, error) Mail(from string) error Quit() error Rcpt(to string) error } type SmtpDialFn func(hostPort string) (SMTPClientIface, error) type EmailService struct { auth smtp.Auth host string port string from string tlsconfig *tls.Config dial SmtpDialFn } func NewInsecure(config MailServiceConfig) *EmailService { return &EmailService{ auth: config.Auth, host: config.Host, port: config.Port, from: config.From, tlsconfig: &tls.Config{ InsecureSkipVerify: true, ServerName: config.Host, }, dial: dial, } } var validCommonNames = []string{"ISRG Root X1", "R3", "E5", "DST Root CA X3"} func NewSecure(config MailServiceConfig) *EmailService { return &EmailService{ auth: config.Auth, host: config.Host, port: config.Port, from: config.From, tlsconfig: &tls.Config{ InsecureSkipVerify: true, ServerName: config.Host, VerifyConnection: func(cs tls.ConnectionState) error { // // Check the server's common name // for _, cert := range cs.PeerCertificates { // log.Println("cert.DNSNames", cert.DNSNames) // if err := cert.VerifyHostname(config.Host); err != nil { // return fmt.Errorf("invalid common name: %w", err) // } // } // Check the certificate chain opts := x509.VerifyOptions{ Intermediates: x509.NewCertPool(), } for _, cert := range cs.PeerCertificates[1:] { opts.Intermediates.AddCert(cert) } _, err := cs.PeerCertificates[0].Verify(opts) if err != nil { return fmt.Errorf("invalid certificate chain: %w", err) } // Iterate over the certificates again to perform custom checks for _, cert := range cs.PeerCertificates { // TODO: add more checks here... if time.Now().After(cert.NotAfter) { return fmt.Errorf("certificate has expired") } if time.Now().Add(30 * 24 * time.Hour).After(cert.NotAfter) { return fmt.Errorf("certificate will expire within 30 days") } if !slices.Contains(validCommonNames, cert.Issuer.CommonName) { return fmt.Errorf("certificate is not issued by a trusted CA") } if cert.PublicKeyAlgorithm != x509.RSA { return fmt.Errorf("unsupported public key algorithm") } } return nil }, }, dial: dial, } } type MailServiceConfig struct { Auth smtp.Auth Host string Port string From string // Sender email address } func dial(hostPort string) (SMTPClientIface, error) { client, err := smtp.Dial(hostPort) if err != nil { return nil, err } return client, nil } func (e *EmailService) SendEmail(emailData EmailMessage) error { msg, err := newMessage(e.from, emailData.To, emailData.Subject). withAttachments(emailData.Body, emailData.Attachments) if err != nil { return fmt.Errorf("error while preparing email: %w", err) } return e.send(emailData.To, msg) } func (e *EmailService) send(to string, msg []byte) error { c, err := e.dial(e.host + ":" + e.port) if err != nil { return fmt.Errorf("DIAL: %s", err) } if err = c.StartTLS(e.tlsconfig); err != nil { return fmt.Errorf("c.StartTLS: %s", err) } // Auth if err = c.Auth(e.auth); err != nil { return fmt.Errorf("c.Auth: %s", err) } // To && From if err = c.Mail(e.from); err != nil { return fmt.Errorf("c.Mail: %s", err) } if err = c.Rcpt(to); err != nil { return fmt.Errorf("c.Rcpt: %s", err) } // Data w, err := c.Data() if err != nil { return fmt.Errorf("c.Data: %s", err) } written, err := w.Write(msg) if err != nil { return fmt.Errorf("w.Write: %s", err) } if written <= 0 { return fmt.Errorf("%d bytes written", written) } if err = w.Close(); err != nil { return fmt.Errorf("w.Close: %s", err) } if err = c.Quit(); err != nil { return fmt.Errorf("w.Quit: %s", err) } return nil } type message struct { from string to string subject string } func newMessage(from, to, subject string) message { return message{from: from, to: to, subject: subject} } func (m message) withAttachments(body string, attachments []EmailAttachment) ([]byte, error) { headers := make(map[string]string) headers["From"] = m.from headers["To"] = m.to headers["Subject"] = m.subject headers["MIME-Version"] = "1.0" var message bytes.Buffer for k, v := range headers { message.WriteString(k) message.WriteString(": ") message.WriteString(v) message.WriteString("\r\n") } message.WriteString("Content-Type: " + fmt.Sprintf("multipart/mixed; boundary=\"%s\"\r\n", delimeter)) message.WriteString("--" + delimeter + "\r\n") message.WriteString("Content-Type: text/html; charset=\"UTF-8\"\r\n\r\n") message.WriteString(body + "\r\n\r\n") for _, attachment := range attachments { attachmentRawFile, err := attachment.ReadContent() if err != nil { return nil, err } message.WriteString("--" + delimeter + "\r\n") message.WriteString("Content-Disposition: attachment; filename=\"" + attachment.Title + "\"\r\n") message.WriteString("Content-Type: application/octet-stream\r\n") message.WriteString("Content-Transfer-Encoding: base64\r\n\r\n") message.WriteString(base64.StdEncoding.EncodeToString(attachmentRawFile) + "\r\n") } message.WriteString("--" + delimeter + "--") // End the message return message.Bytes(), nil }