summaryrefslogtreecommitdiff
path: root/main.go
blob: d32fb014095a5466f891a6ac7896e0ef7720b951 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
package main

import (
	"bytes"
	"context"
	"crypto/tls"
	"flag"
	"io/ioutil"
	"log"
	"net"
	"net/mail"
	"net/smtp"
	"net/textproto"
	"strings"
	"time"

	"github.com/mhale/smtpd"
	"github.com/pkg/errors"
)

// Command-line flags.
var (
	// flagListen is the address the SMTP server listens on.
	flagListen = flag.String(`l`, `:25`, `Address to listen on`)
	// flagHostname is the hostname the SMTP server announces.
	flagHostname = flag.String(`h`, `HOSTNAME-NOT-SET`, `Server hostname`)
	// flagMap is a comma-separated list of recipient-prefix:target pairs.
	flagMap = flag.String(`m`, ``, `-m prefix-matcher1:target@targethost,prefix-matcher2:target@targethost`)
	// flagCertFile and flagKeyFile are passed to tls.LoadX509KeyPair; TLS is
	// enabled only when both are set.
	flagCertFile = flag.String(`c`, ``, `TLS certificate file (PEM); TLS is enabled only when -k is also set`)
	flagKeyFile  = flag.String(`k`, ``, `TLS private key file (PEM); TLS is enabled only when -c is also set`)
)

// logWrapper decorates handler with logging. It logs once when a message
// arrives, then again after handler returns: the error on failure, or a
// success line otherwise. The handler's error is returned unchanged.
func logWrapper(handler smtpd.Handler) smtpd.Handler {
	return func(remoteAddr net.Addr, from string, to []string, data []byte) error {
		log.Printf(`received email remoteAddr=%v from=%v`, remoteAddr, from)

		if handlerErr := handler(remoteAddr, from, to, data); handlerErr != nil {
			log.Printf(`failed to forward: %v`, handlerErr.Error())
			return handlerErr
		}

		log.Printf(`forwarded email remoteAddr=%v from=%v`, remoteAddr, from)
		return nil
	}
}

// transformEmail parses a raw email message and returns two things: a new
// header map holding the same entries as the parsed headers, and the
// message body bytes.
//
// Because the map is freshly allocated, callers may add or replace keys
// without affecting the parsed message. The per-key value slices are
// shared with it. Parse and body-read failures are wrapped with context.
func transformEmail(data []byte) (headers mail.Header, msgData []byte, err error) {
	parsed, parseErr := mail.ReadMessage(bytes.NewReader(data))
	if parseErr != nil {
		return nil, nil, errors.Wrap(parseErr, `failed to parse email`)
	}

	body, readErr := ioutil.ReadAll(parsed.Body)
	if readErr != nil {
		return nil, nil, errors.Wrap(readErr, `failed to parse email body`)
	}

	copied := make(mail.Header, len(parsed.Header))
	for key, values := range parsed.Header {
		copied[key] = values
	}
	return copied, body, nil
}

// forward relays a copy of the raw message data to targetEmail. It speaks
// SMTP directly to the first MX host returned for the target's domain.
//
// All original headers are kept except To, From and Subject:
//   - To is replaced with targetEmail.
//   - From is replaced with forwarder@localnet.cc.
//   - Subject is replaced with "Forwarded: " plus the original subject.
//
// Header order in the output follows Go map iteration order, so it is not
// stable between runs.
//
// SMTP replies in the 4xx (transient) range are retried, up to 5 attempts
// in total, with a 120 second pause between attempts. Any other failure, or
// a transient failure on the last attempt, is returned as an error.
func forward(targetEmail string, data []byte) error {
	toAddress := strings.SplitN(targetEmail, `@`, 2)
	if len(toAddress) != 2 {
		return errors.New(`invalid targetEmail address: ` + targetEmail)
	}
	mxes, err := net.DefaultResolver.LookupMX(context.Background(), toAddress[1])
	if err != nil {
		return errors.Wrap(err, `failed targetEmail resolve mx`)
	}
	if len(mxes) == 0 {
		// Checked separately: errors.Wrap(nil, ...) returns nil, which would
		// otherwise make an empty MX set look like a successful delivery.
		return errors.New(`failed targetEmail resolve mx: no mx records for ` + toAddress[1])
	}

	headers, msgData, err := transformEmail(data)
	if err != nil {
		return errors.Wrap(err, `failed targetEmail transform email`)
	}
	headers[textproto.CanonicalMIMEHeaderKey(`To`)] = []string{targetEmail}
	headers[textproto.CanonicalMIMEHeaderKey(`From`)] = []string{`forwarder@localnet.cc`}
	headers[textproto.CanonicalMIMEHeaderKey(`Subject`)] = []string{`Forwarded: ` + headers.Get(`subject`)}

	var builder bytes.Buffer
	for headerName, headerValues := range headers {
		for _, value := range headerValues {
			builder.WriteString(textproto.CanonicalMIMEHeaderKey(headerName))
			builder.WriteString(`: `)
			builder.WriteString(value)
			builder.WriteString("\r\n")
		}
	}
	builder.WriteString("\r\n")
	builder.Write(msgData)

	msg := builder.Bytes()
	mxAddr := mxes[0].Host + ":25"
	const maxAttempts = 5
	for attempt := 1; ; attempt++ {
		err = smtp.SendMail(mxAddr, nil, `forwarder@localnet.cc`, []string{targetEmail}, msg)
		if err == nil {
			break // delivered; the original loop never left on success and resent forever
		}
		protoErr, ok := err.(*textproto.Error)
		if !ok || protoErr.Code < 400 || protoErr.Code >= 500 {
			// Not a transient 4xx SMTP reply: retrying will not help.
			return errors.Wrap(err, `failed targetEmail send mail via smtp`)
		}
		if attempt == maxAttempts {
			// Out of retries: report the failure instead of claiming success.
			return errors.Wrapf(err, `failed targetEmail send mail via smtp after %d attempts`, maxAttempts)
		}
		log.Printf(`retry sleep 120s count=%v code=%v`, maxAttempts-attempt+1, protoErr.Code)
		time.Sleep(120 * time.Second)
	}
	log.Printf("forwarded targetEmail=%v", targetEmail)

	return nil
}

// makeHandleEmail returns an smtpd.Handler that forwards incoming messages
// according to mapping. The map goes from recipient prefix to target
// address.
//
// For each (prefix, target) pair, if any envelope recipient starts with the
// prefix, the message is forwarded once to the target. A target receives
// at most one copy per incoming message, even when several recipients or
// prefixes map to it.
//
// Forwarding failures are logged and otherwise ignored. The handler always
// returns nil, so the sending server treats the message as accepted.
func makeHandleEmail(mapping map[string]string) smtpd.Handler {
	return func(remoteAddr net.Addr, from string, to []string, data []byte) error {
		forwarded := make(map[string]struct{}, len(mapping))
		for prefix, targetEmail := range mapping {
			if _, done := forwarded[targetEmail]; done {
				continue // another prefix already delivered to this target
			}
			for _, recipient := range to {
				if !strings.HasPrefix(recipient, prefix) {
					continue
				}
				forwarded[targetEmail] = struct{}{}
				if err := forward(targetEmail, data); err != nil {
					log.Printf(`forwarded failed %v to %v: %v`, recipient, targetEmail, err)
				}
				break // one copy per target is enough
			}
		}
		return nil
	}
}

// main does the following, in order:
//   - parses the command-line flags;
//   - builds the recipient-prefix → target mapping from -m;
//   - loads a TLS certificate when both -c and -k are set;
//   - runs the SMTP server.
//
// Any error returned by the server is fatal.
func main() {
	flag.Parse()

	mapping := make(map[string]string)
	for _, entry := range strings.Split(*flagMap, `,`) {
		if entry == `` {
			continue
		}
		parts := strings.SplitN(entry, `:`, 2)
		if len(parts) != 2 {
			panic(`invalid flag -m: ` + *flagMap)
		}
		mapping[parts[0]] = parts[1]
	}

	var tlsConfig *tls.Config
	switch {
	case *flagCertFile != `` && *flagKeyFile != ``:
		cert, err := tls.LoadX509KeyPair(*flagCertFile, *flagKeyFile)
		if err != nil {
			panic(err)
		}
		tlsConfig = &tls.Config{Certificates: []tls.Certificate{cert}}
	case *flagCertFile != `` || *flagKeyFile != ``:
		// Only one half of the key pair was given: previously TLS was dropped silently.
		log.Print(`TLS disabled: both -c and -k must be set`)
	}

	server := &smtpd.Server{
		TLSConfig: tlsConfig,
		Addr:      *flagListen,
		Hostname:  *flagHostname,
		Handler:   logWrapper(makeHandleEmail(mapping)),
	}
	// If serving stops with an error (e.g. the address cannot be bound),
	// report it and exit non-zero instead of exiting silently with status 0.
	if err := server.ListenAndServe(); err != nil {
		log.Fatalf(`smtp server stopped: %v`, err)
	}
}