344 lines
		
	
	
		
			7.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			344 lines
		
	
	
		
			7.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package email
 | |
| 
 | |
| import (
 | |
| 	"bytes"
 | |
| 	"crypto/tls"
 | |
| 	"crypto/x509"
 | |
| 	"encoding/base64"
 | |
| 	"fmt"
 | |
| 	"io"
 | |
| 	"net"
 | |
| 	"net/smtp"
 | |
| 	"slices"
 | |
| 	"time"
 | |
| )
 | |
| 
 | |
// MIME building blocks shared by the message renderer.
const (
	// mime is a single-part HTML header preamble. It is not referenced in this
	// part of the file.
	// NOTE(review): "MIME-version: 1.0;" carries a trailing ';' that the RFC 2045
	// version syntax does not allow, and the lines end in bare LF. Left
	// unchanged because any callers are outside this view — revisit before use.
	mime = "MIME-version: 1.0;\nContent-Type: text/html; charset=\"UTF-8\";\n\n"

	// delimeter is the multipart/mixed boundary used by withAttachments (the
	// misspelled name is kept for compatibility).
	//
	// RFC 2046 §5.1.1 allows only letters, digits and '()+_,-./:=? in a
	// boundary (1-70 chars). The former value began with "**", which is
	// outside that set and is rejected by strict parsers such as Go's
	// multipart.Writer.SetBoundary. "=_" can never appear inside base64
	// output, so the boundary cannot collide with encoded attachment data.
	delimeter = "=_myohmy689407924327"
)
 | |
| 
 | |
// SMTPClientIface is the subset of *smtp.Client methods that EmailService
// drives. SmtpDialFn returns this interface, so any implementation can be
// plugged in; the dial functions in this file return *smtp.Client.
type SMTPClientIface interface {
	// Connection setup.
	StartTLS(config *tls.Config) error
	Auth(auth smtp.Auth) error

	// Envelope.
	Mail(from string) error
	Rcpt(to string) error

	// Message content.
	Data() (io.WriteCloser, error)

	// Teardown.
	Quit() error
	Close() error
}
 | |
| 
 | |
| type SmtpDialFn func(hostPort string) (SMTPClientIface, error)
 | |
| 
 | |
| type EmailService struct {
 | |
| 	auth      smtp.Auth
 | |
| 	host      string
 | |
| 	port      string
 | |
| 	from      string
 | |
| 	tlsconfig *tls.Config
 | |
| 	dial      SmtpDialFn
 | |
| }
 | |
| 
 | |
| func NewInsecure(config MailServiceConfig) *EmailService {
 | |
| 	return &EmailService{
 | |
| 		auth: config.Auth,
 | |
| 		host: config.Host,
 | |
| 		port: config.Port,
 | |
| 		from: config.From,
 | |
| 		tlsconfig: &tls.Config{
 | |
| 			InsecureSkipVerify: true,
 | |
| 			ServerName:         config.Host,
 | |
| 		},
 | |
| 		dial: dial,
 | |
| 	}
 | |
| }
 | |
| 
 | |
// validCommonNames lists the issuer CommonNames that customVerify accepts for
// every certificate the server presents.
var validCommonNames = []string{
	"ISRG Root X1",
	"R3",
	"R10",
	"R11",
	"E5",
	"DST Root CA X3",
	"DigiCert Global Root G2",
	"DigiCert Global G2 TLS RSA SHA256 2020 CA1",
}

// customVerify returns a tls.Config.VerifyConnection callback for host. The
// callback returns the first failing check:
//
//  1. at least one peer certificate must be presented;
//  2. the leaf must verify via x509.Certificate.Verify, with the other
//     presented certificates as intermediates, no explicit Roots, and
//     DNSName set to host (so the hostname is checked too);
//  3. every presented certificate must be unexpired, must not expire within
//     the next 30 days, must have an issuer CommonName in validCommonNames,
//     and must carry an RSA public key.
//
// NOTE(review): step 3 also applies to intermediates, so connections are
// refused during the last 30 days of any presented certificate. If "E5" is an
// ECDSA issuer, certificates it issues can never pass the RSA-only check.
// Confirm the intended policy.
func customVerify(host string) func(cs tls.ConnectionState) error {
	const renewalWindow = 30 * 24 * time.Hour

	return func(cs tls.ConnectionState) error {
		peers := cs.PeerCertificates
		if len(peers) == 0 {
			return fmt.Errorf("no peer certificates provided")
		}
		leaf, rest := peers[0], peers[1:]

		now := time.Now()

		pool := x509.NewCertPool()
		for _, inter := range rest {
			pool.AddCert(inter)
		}

		verifyOpts := x509.VerifyOptions{
			DNSName:       host,
			Intermediates: pool,
			CurrentTime:   now,
		}
		if _, err := leaf.Verify(verifyOpts); err != nil {
			return fmt.Errorf("certificate chain verification failed: %w", err)
		}

		// Local policy, evaluated per certificate in presentation order.
		for _, c := range peers {
			switch {
			case now.After(c.NotAfter):
				return fmt.Errorf("certificate expired on %s", c.NotAfter)
			case c.NotAfter.Before(now.Add(renewalWindow)):
				return fmt.Errorf("certificate will expire soon on %s", c.NotAfter)
			case !slices.Contains(validCommonNames, c.Issuer.CommonName):
				return fmt.Errorf("untrusted certificate issuer: %s", c.Issuer.CommonName)
			case c.PublicKeyAlgorithm != x509.RSA:
				return fmt.Errorf("unsupported public key algorithm: %v", c.PublicKeyAlgorithm)
			}
		}
		return nil
	}
}
 | |
| 
 | |
| func NewSecure(config MailServiceConfig) *EmailService {
 | |
| 	return &EmailService{
 | |
| 		auth: config.Auth,
 | |
| 		host: config.Host,
 | |
| 		port: config.Port,
 | |
| 		from: config.From,
 | |
| 		tlsconfig: &tls.Config{
 | |
| 			InsecureSkipVerify: true,
 | |
| 			ServerName:         config.Host,
 | |
| 			VerifyConnection:   customVerify(config.Host),
 | |
| 		},
 | |
| 		dial: dial,
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func NewSecure465(config MailServiceConfig) *EmailService {
 | |
| 	tlsCfg := tls.Config{
 | |
| 		// Ideally, InsecureSkipVerify: false,
 | |
| 		// or do a proper certificate validation
 | |
| 		InsecureSkipVerify: true,
 | |
| 		ServerName:         config.Host,
 | |
| 		VerifyConnection:   customVerify(config.Host),
 | |
| 	}
 | |
| 	return &EmailService{
 | |
| 		auth:      config.Auth,
 | |
| 		host:      config.Host,
 | |
| 		port:      config.Port,
 | |
| 		from:      config.From,
 | |
| 		tlsconfig: &tlsCfg,
 | |
| 		dial: func(hostPort string) (SMTPClientIface, error) {
 | |
| 			return dialTLS(hostPort, &tlsCfg)
 | |
| 		},
 | |
| 	}
 | |
| }
 | |
| 
 | |
// MailServiceConfig holds the settings consumed by NewInsecure, NewSecure and
// NewSecure465.
type MailServiceConfig struct {
	// Auth is passed to the SMTP client's Auth call before sending.
	Auth smtp.Auth

	// Host and Port are dialed as Host + ":" + Port. Host is also the TLS
	// ServerName and the name checked by customVerify. A Port of "465" makes
	// SendEmail use the implicit-TLS path.
	Host, Port string

	// From is the sender email address (MAIL FROM and the From header).
	From string
}
 | |
| 
 | |
| func dial(hostPort string) (SMTPClientIface, error) {
 | |
| 	client, err := smtp.Dial(hostPort)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	return client, nil
 | |
| }
 | |
| 
 | |
| func dialTLS(hostPort string, tlsConfig *tls.Config) (SMTPClientIface, error) {
 | |
| 	// 1) Create a raw TCP connection
 | |
| 	conn, err := net.Dial("tcp", hostPort)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	// 2) Wrap it with TLS
 | |
| 	tlsConn := tls.Client(conn, tlsConfig)
 | |
| 
 | |
| 	// 3) Now create the SMTP client on this TLS connection
 | |
| 	host, _, _ := net.SplitHostPort(hostPort)
 | |
| 	c, err := smtp.NewClient(tlsConn, host)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	return c, nil
 | |
| }
 | |
| 
 | |
| func (e *EmailService) SendEmail(emailData EmailMessage) error {
 | |
| 	msg, err := newMessage(e.from, emailData.To, emailData.Subject).
 | |
| 		withAttachments(emailData.Body, emailData.Attachments)
 | |
| 
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("error while preparing email: %w", err)
 | |
| 	}
 | |
| 
 | |
| 	switch e.port {
 | |
| 	case "465":
 | |
| 		return e.sendTLS(emailData.To, msg)
 | |
| 	default:
 | |
| 		return e.send(emailData.To, msg)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (e *EmailService) send(to string, msg []byte) error {
 | |
| 	c, err := e.dial(e.host + ":" + e.port)
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("DIAL: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	if err = c.StartTLS(e.tlsconfig); err != nil {
 | |
| 		return fmt.Errorf("c.StartTLS: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	// Auth
 | |
| 	if err = c.Auth(e.auth); err != nil {
 | |
| 		return fmt.Errorf("c.Auth: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	// To && From
 | |
| 	if err = c.Mail(e.from); err != nil {
 | |
| 		return fmt.Errorf("c.Mail: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	if err = c.Rcpt(to); err != nil {
 | |
| 		return fmt.Errorf("c.Rcpt: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	// Data
 | |
| 	w, err := c.Data()
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("c.Data: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	written, err := w.Write(msg)
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("w.Write: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	if written <= 0 {
 | |
| 		return fmt.Errorf("%d bytes written", written)
 | |
| 	}
 | |
| 
 | |
| 	if err = w.Close(); err != nil {
 | |
| 		return fmt.Errorf("w.Close: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	if err = c.Quit(); err != nil {
 | |
| 		return fmt.Errorf("w.Quit: %s", err)
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (e *EmailService) sendTLS(to string, msg []byte) error {
 | |
| 	c, err := e.dial(e.host + ":" + e.port) // dialTLS
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("DIAL: %s", err)
 | |
| 	}
 | |
| 	defer c.Close()
 | |
| 
 | |
| 	// Auth
 | |
| 	if err = c.Auth(e.auth); err != nil {
 | |
| 		return fmt.Errorf("c.Auth: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	// To && From
 | |
| 	if err = c.Mail(e.from); err != nil {
 | |
| 		return fmt.Errorf("c.Mail: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	if err = c.Rcpt(to); err != nil {
 | |
| 		return fmt.Errorf("c.Rcpt: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	// Data
 | |
| 	w, err := c.Data()
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("c.Data: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	written, err := w.Write(msg)
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("w.Write: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	if written <= 0 {
 | |
| 		return fmt.Errorf("%d bytes written", written)
 | |
| 	}
 | |
| 
 | |
| 	if err = w.Close(); err != nil {
 | |
| 		return fmt.Errorf("w.Close: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	if err = c.Quit(); err != nil {
 | |
| 		return fmt.Errorf("w.Quit: %s", err)
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
// message carries the addressing headers of one outgoing email. The body and
// attachments are supplied separately to withAttachments.
type message struct {
	from, to, subject string
}
 | |
| 
 | |
| func newMessage(from, to, subject string) message {
 | |
| 	return message{from: from, to: to, subject: subject}
 | |
| }
 | |
| 
 | |
| func (m message) withAttachments(body string, attachments []EmailAttachment) ([]byte, error) {
 | |
| 	headers := make(map[string]string)
 | |
| 	headers["From"] = m.from
 | |
| 	headers["To"] = m.to
 | |
| 	headers["Subject"] = m.subject
 | |
| 	headers["MIME-Version"] = "1.0"
 | |
| 
 | |
| 	var message bytes.Buffer
 | |
| 
 | |
| 	for k, v := range headers {
 | |
| 		message.WriteString(k)
 | |
| 		message.WriteString(": ")
 | |
| 		message.WriteString(v)
 | |
| 		message.WriteString("\r\n")
 | |
| 	}
 | |
| 
 | |
| 	message.WriteString("Content-Type: " + fmt.Sprintf("multipart/mixed; boundary=\"%s\"\r\n", delimeter))
 | |
| 	message.WriteString("--" + delimeter + "\r\n")
 | |
| 	message.WriteString("Content-Type: text/html; charset=\"UTF-8\"\r\n\r\n")
 | |
| 	message.WriteString(body + "\r\n\r\n")
 | |
| 
 | |
| 	for _, attachment := range attachments {
 | |
| 		attachmentRawFile, err := attachment.ReadContent()
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 		message.WriteString("--" + delimeter + "\r\n")
 | |
| 		message.WriteString("Content-Disposition: attachment; filename=\"" + attachment.Title + "\"\r\n")
 | |
| 		message.WriteString("Content-Type: application/octet-stream\r\n")
 | |
| 		message.WriteString("Content-Transfer-Encoding: base64\r\n\r\n")
 | |
| 		message.WriteString(base64.StdEncoding.EncodeToString(attachmentRawFile) + "\r\n")
 | |
| 	}
 | |
| 
 | |
| 	message.WriteString("--" + delimeter + "--") // End the message
 | |
| 	return message.Bytes(), nil
 | |
| }
 |