gemini/serve.go
package gemini

import (
	"context"
	"crypto/tls"
	"io"
	"net"
	"sync"
)

// Server listens on a network and serves the Gemini protocol.
type Server struct {
	ctx      context.Context
	network  string
	address  string
	cancel   context.CancelFunc
	wg       *sync.WaitGroup
	listener net.Listener
	handler  Handler
}

// NewServer builds a server. It begins listening on the given network and
// address immediately, wrapping the listener with the provided TLS
// configuration.
func NewServer(
	ctx context.Context,
	tlsConfig *tls.Config,
	network string,
	address string,
	handler Handler,
) (*Server, error) {
	listener, err := net.Listen(network, address)
	if err != nil {
		return nil, err
	}

	s := &Server{
		ctx:      ctx,
		network:  network,
		address:  address,
		wg:       &sync.WaitGroup{},
		listener: tls.NewListener(listener, tlsConfig),
		handler:  handler,
	}

	return s, nil
}
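
// A minimal usage sketch for NewServer. Illustrative only: loadCertificate
// and handler are hypothetical stand-ins supplied by the caller, not part
// of this package.
//
//	cert, err := loadCertificate() // hypothetical: yields a tls.Certificate
//	if err != nil {
//		log.Fatal(err)
//	}
//	cfg := &tls.Config{Certificates: []tls.Certificate{cert}}
//	srv, err := NewServer(context.Background(), cfg, "tcp", "localhost:1965", handler)
//	if err != nil {
//		log.Fatal(err)
//	}
//	go func() { _ = srv.Serve() }()
//	defer srv.Close()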

// Serve starts the server and blocks until it is closed.
//
// Serve allocates resources that are not cleaned up until Close is called.
//
// It respects cancellation of the context the server was created with,
// but Close must still be called in that case to avoid dangling
// goroutines.
func (s *Server) Serve() error {
	s.wg.Add(1)
	defer s.wg.Done()

	s.ctx, s.cancel = context.WithCancel(s.ctx)

	s.wg.Add(1)
	go s.propagateCancel()

	for {
		conn, err := s.listener.Accept()
		if err != nil {
			// Accept fails once propagateCancel closes the listener;
			// treat cancellation as a clean shutdown, not an error.
			if s.closed() {
				err = nil
			}
			return err
		}

		s.wg.Add(1)
		go s.handleConn(conn)
	}
}
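
// Sketch of the shutdown contract described in Serve's documentation:
// cancelling the parent context unblocks Accept (propagateCancel closes
// the listener), but Close must still be called to wait for in-flight
// handlers. cfg and handler are placeholders from the caller's side.
//
//	ctx, cancel := context.WithCancel(context.Background())
//	srv, _ := NewServer(ctx, cfg, "tcp", "localhost:1965", handler)
//	go func() { _ = srv.Serve() }()
//	// ... later:
//	cancel()    // stops the accept loop
//	srv.Close() // waits for handlers and the watcher goroutine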

// Close begins a graceful shutdown of the server.
//
// It cancels the server's context which interrupts all concurrently running
// request handlers, if they support it. It then blocks until all resources
// have been cleaned up and all request handlers have completed.
//
// Close must not be called before Serve: the cancel function it invokes
// is only installed when Serve starts.
func (s *Server) Close() {
	s.cancel()
	s.wg.Wait()
}

// Network returns the network type on which the server is running.
func (s *Server) Network() string {
	return s.network
}

// Address returns the address on which the server is listening.
func (s *Server) Address() string {
	return s.address
}

// Hostname returns just the hostname portion of the listen address.
func (s *Server) Hostname() string {
	host, _, _ := net.SplitHostPort(s.address)
	return host
}

// Port returns the port on which the server is listening.
func (s *Server) Port() string {
	_, portStr, _ := net.SplitHostPort(s.address)
	return portStr
}
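
// Worked example for the two accessors above, given that they split the
// configured address (not the live listener's address):
//
//	srv, _ := NewServer(ctx, cfg, "tcp", "localhost:1965", handler)
//	srv.Hostname() // "localhost"
//	srv.Port()     // "1965"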

// handleConn parses a single request from conn, dispatches it to the
// handler, and streams the response back before closing the connection.
func (s *Server) handleConn(conn net.Conn) {
	defer s.wg.Done()
	defer conn.Close()

	req, err := ParseRequest(conn)
	if err != nil {
		_, _ = io.Copy(conn, BadRequest(err.Error()))
		return
	}

	req.Server = s
	req.RemoteAddr = conn.RemoteAddr()
	// Expose the TLS connection state (e.g. peer certificates) to the
	// handler when the connection is TLS-wrapped.
	if tlsconn, ok := conn.(*tls.Conn); ok {
		state := tlsconn.ConnectionState()
		req.TLSState = &state
	}

	resp := s.handler(s.ctx, req)
	defer resp.Close()

	_, _ = io.Copy(conn, resp)
}
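
// Sketch of a handler compatible with the flow above. The exact Handler
// and request/response types are defined elsewhere in this package; this
// assumes a signature of roughly func(context.Context, *Request) returning
// an io.ReadCloser-style response, as the io.Copy and deferred Close in
// handleConn imply. Success is a hypothetical constructor.
//
//	func hello(ctx context.Context, req *Request) Response {
//		return Success("text/gemini", strings.NewReader("# hello\r\n")) // hypothetical
//	}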

// propagateCancel closes the listener once the server's context is
// cancelled, unblocking the Accept call in Serve. It already runs in its
// own goroutine, started by Serve, so no nested goroutine is needed.
func (s *Server) propagateCancel() {
	defer s.wg.Done()

	<-s.ctx.Done()
	_ = s.listener.Close()
}

// closed reports whether the server's context has been cancelled.
func (s *Server) closed() bool {
	select {
	case <-s.ctx.Done():
		return true
	default:
		return false
	}
}