refactor: add secure with tls
This commit is contained in:
parent
85ca7512c6
commit
2028190971
@ -7,6 +7,7 @@ import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/smtp"
|
||||
"slices"
|
||||
"time"
|
||||
@ -52,7 +53,66 @@ func NewInsecure(config MailServiceConfig) *EmailService {
|
||||
}
|
||||
}
|
||||
|
||||
var validCommonNames = []string{"ISRG Root X1", "R3", "R10", "R11", "E5", "DST Root CA X3"}
|
||||
var validCommonNames = []string{
|
||||
"ISRG Root X1",
|
||||
"R3",
|
||||
"R10",
|
||||
"R11",
|
||||
"E5",
|
||||
"DST Root CA X3",
|
||||
"DigiCert Global Root G2",
|
||||
"DigiCert Global G2 TLS RSA SHA256 2020 CA1",
|
||||
}
|
||||
|
||||
func customVerify(host string) func(cs tls.ConnectionState) error {
|
||||
return func(cs tls.ConnectionState) error {
|
||||
// Ensure we have at least one peer certificate
|
||||
if len(cs.PeerCertificates) == 0 {
|
||||
return fmt.Errorf("no peer certificates provided")
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// Set up verification options with a DNSName check.
|
||||
// This will perform hostname verification automatically.
|
||||
opts := x509.VerifyOptions{
|
||||
CurrentTime: now,
|
||||
DNSName: host, // assuming config.Host is accessible here
|
||||
Intermediates: x509.NewCertPool(),
|
||||
}
|
||||
// Add all certificates except the leaf as intermediates.
|
||||
for i := 1; i < len(cs.PeerCertificates); i++ {
|
||||
opts.Intermediates.AddCert(cs.PeerCertificates[i])
|
||||
}
|
||||
|
||||
// Verify the certificate chain (including hostname check via opts.DNSName)
|
||||
if _, err := cs.PeerCertificates[0].Verify(opts); err != nil {
|
||||
return fmt.Errorf("certificate chain verification failed: %w", err)
|
||||
}
|
||||
|
||||
// Perform additional custom checks
|
||||
for _, cert := range cs.PeerCertificates {
|
||||
if now.After(cert.NotAfter) {
|
||||
return fmt.Errorf("certificate expired on %s", cert.NotAfter)
|
||||
}
|
||||
if now.Add(30 * 24 * time.Hour).After(cert.NotAfter) {
|
||||
return fmt.Errorf("certificate will expire soon on %s", cert.NotAfter)
|
||||
}
|
||||
|
||||
// Check that the issuer's CommonName is in our allowed list.
|
||||
if !slices.Contains(validCommonNames, cert.Issuer.CommonName) {
|
||||
return fmt.Errorf("untrusted certificate issuer: %s", cert.Issuer.CommonName)
|
||||
}
|
||||
|
||||
// Check that the public key algorithm is RSA.
|
||||
if cert.PublicKeyAlgorithm != x509.RSA {
|
||||
return fmt.Errorf("unsupported public key algorithm: %v", cert.PublicKeyAlgorithm)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func NewSecure(config MailServiceConfig) *EmailService {
|
||||
return &EmailService{
|
||||
@ -63,54 +123,32 @@ func NewSecure(config MailServiceConfig) *EmailService {
|
||||
tlsconfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: config.Host,
|
||||
VerifyConnection: func(cs tls.ConnectionState) error {
|
||||
|
||||
// // Check the server's common name
|
||||
// for _, cert := range cs.PeerCertificates {
|
||||
// log.Println("cert.DNSNames", cert.DNSNames)
|
||||
// if err := cert.VerifyHostname(config.Host); err != nil {
|
||||
// return fmt.Errorf("invalid common name: %w", err)
|
||||
// }
|
||||
// }
|
||||
|
||||
// Check the certificate chain
|
||||
opts := x509.VerifyOptions{
|
||||
Intermediates: x509.NewCertPool(),
|
||||
}
|
||||
for _, cert := range cs.PeerCertificates[1:] {
|
||||
opts.Intermediates.AddCert(cert)
|
||||
}
|
||||
_, err := cs.PeerCertificates[0].Verify(opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid certificate chain: %w", err)
|
||||
}
|
||||
|
||||
// Iterate over the certificates again to perform custom checks
|
||||
for _, cert := range cs.PeerCertificates {
|
||||
// TODO: add more checks here...
|
||||
if time.Now().After(cert.NotAfter) {
|
||||
return fmt.Errorf("certificate has expired")
|
||||
}
|
||||
if time.Now().Add(30 * 24 * time.Hour).After(cert.NotAfter) {
|
||||
return fmt.Errorf("certificate will expire within 30 days")
|
||||
}
|
||||
|
||||
if !slices.Contains(validCommonNames, cert.Issuer.CommonName) {
|
||||
return fmt.Errorf("certificate is not issued by a trusted CA")
|
||||
}
|
||||
|
||||
if cert.PublicKeyAlgorithm != x509.RSA {
|
||||
return fmt.Errorf("unsupported public key algorithm")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
VerifyConnection: customVerify(config.Host),
|
||||
},
|
||||
dial: dial,
|
||||
}
|
||||
}
|
||||
|
||||
func NewSecure465(config MailServiceConfig) *EmailService {
|
||||
tlsCfg := tls.Config{
|
||||
// Ideally, InsecureSkipVerify: false,
|
||||
// or do a proper certificate validation
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: config.Host,
|
||||
VerifyConnection: customVerify(config.Host),
|
||||
}
|
||||
return &EmailService{
|
||||
auth: config.Auth,
|
||||
host: config.Host,
|
||||
port: config.Port,
|
||||
from: config.From,
|
||||
tlsconfig: &tlsCfg,
|
||||
dial: func(hostPort string) (SMTPClientIface, error) {
|
||||
return dialTLS(hostPort, &tlsCfg)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type MailServiceConfig struct {
|
||||
Auth smtp.Auth
|
||||
Host string
|
||||
@ -126,6 +164,25 @@ func dial(hostPort string) (SMTPClientIface, error) {
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func dialTLS(hostPort string, tlsConfig *tls.Config) (SMTPClientIface, error) {
|
||||
// 1) Create a raw TCP connection
|
||||
conn, err := net.Dial("tcp", hostPort)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2) Wrap it with TLS
|
||||
tlsConn := tls.Client(conn, tlsConfig)
|
||||
|
||||
// 3) Now create the SMTP client on this TLS connection
|
||||
host, _, _ := net.SplitHostPort(hostPort)
|
||||
c, err := smtp.NewClient(tlsConn, host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (e *EmailService) SendEmail(emailData EmailMessage) error {
|
||||
msg, err := newMessage(e.from, emailData.To, emailData.Subject).
|
||||
withAttachments(emailData.Body, emailData.Attachments)
|
||||
@ -134,7 +191,12 @@ func (e *EmailService) SendEmail(emailData EmailMessage) error {
|
||||
return fmt.Errorf("error while preparing email: %w", err)
|
||||
}
|
||||
|
||||
return e.send(emailData.To, msg)
|
||||
switch e.port {
|
||||
case "465":
|
||||
return e.sendTLS(emailData.To, msg)
|
||||
default:
|
||||
return e.send(emailData.To, msg)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *EmailService) send(to string, msg []byte) error {
|
||||
@ -186,6 +248,53 @@ func (e *EmailService) send(to string, msg []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *EmailService) sendTLS(to string, msg []byte) error {
|
||||
c, err := e.dial(e.host + ":" + e.port) // dialTLS
|
||||
if err != nil {
|
||||
return fmt.Errorf("DIAL: %s", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
// Auth
|
||||
if err = c.Auth(e.auth); err != nil {
|
||||
return fmt.Errorf("c.Auth: %s", err)
|
||||
}
|
||||
|
||||
// To && From
|
||||
if err = c.Mail(e.from); err != nil {
|
||||
return fmt.Errorf("c.Mail: %s", err)
|
||||
}
|
||||
|
||||
if err = c.Rcpt(to); err != nil {
|
||||
return fmt.Errorf("c.Rcpt: %s", err)
|
||||
}
|
||||
|
||||
// Data
|
||||
w, err := c.Data()
|
||||
if err != nil {
|
||||
return fmt.Errorf("c.Data: %s", err)
|
||||
}
|
||||
|
||||
written, err := w.Write(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("w.Write: %s", err)
|
||||
}
|
||||
|
||||
if written <= 0 {
|
||||
return fmt.Errorf("%d bytes written", written)
|
||||
}
|
||||
|
||||
if err = w.Close(); err != nil {
|
||||
return fmt.Errorf("w.Close: %s", err)
|
||||
}
|
||||
|
||||
if err = c.Quit(); err != nil {
|
||||
return fmt.Errorf("w.Quit: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type message struct {
|
||||
from string
|
||||
to string
|
||||
|
Loading…
x
Reference in New Issue
Block a user