summaryrefslogtreecommitdiff
path: root/spartan/serve.go
blob: 677d76c3ff0162b8f681363506d8ff56d28acbed (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
package spartan

import (
	"bufio"
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"strings"

	"tildegit.org/tjp/gus"
	"tildegit.org/tjp/gus/internal"
	"tildegit.org/tjp/gus/logging"
)

// Unexported key types keep these context entries from colliding with keys
// defined by any other package.
type (
	spartanRequestBodyKey    struct{}
	spartanRequestBodyLenKey struct{}
)

var (
	// SpartanRequestBody is the context key under which a handler finds the
	// spartan request body.
	//
	// The value stored under it is a *bufio.Reader that yields at most the
	// request's declared content-length bytes. When the request declares no
	// body (content-length 0), the stored value is a nil *bufio.Reader.
	SpartanRequestBody = spartanRequestBodyKey{}

	// SpartanRequestBodyLen is the context key under which a handler finds the
	// content-length declared by the request.
	//
	// The value stored under it is an int.
	SpartanRequestBodyLen = spartanRequestBodyLenKey{}
)

// spartanServer is the gus.Server implementation for the spartan protocol.
//
// The embedded internal.Server is built by NewServer with handleConn as its
// connection callback, and provides the Ctx and LogError members used while
// serving. handler produces the response for each successfully parsed request.
type spartanServer struct {
	internal.Server
	handler gus.Handler
}

// Protocol returns "SPARTAN".
//
// It uses a pointer receiver, matching handleConn, so that every method of
// spartanServer is defined on *spartanServer and calling it does not copy the
// embedded internal.Server.
func (ss *spartanServer) Protocol() string { return "SPARTAN" }

// NewServer builds a spartan server.
//
// If hostname contains no ':' it is treated as lacking a port, and port 300
// (the spartan default) is appended with net.JoinHostPort. The server's
// handleConn method is registered as the per-connection callback of the
// underlying internal server, and it routes parsed requests to handler.
// Any error from constructing the internal server is returned unchanged.
func NewServer(
	ctx context.Context,
	hostname string,
	network string,
	address string,
	handler gus.Handler,
	errLog logging.Logger,
) (gus.Server, error) {
	if !strings.ContainsRune(hostname, ':') {
		hostname = net.JoinHostPort(hostname, "300")
	}

	srv := &spartanServer{handler: handler}
	inner, err := internal.NewServer(ctx, hostname, network, address, errLog, srv.handleConn)
	if err != nil {
		return nil, err
	}
	srv.Server = inner

	return srv, nil
}

// handleConn serves a single spartan request read from conn.
//
// If the request cannot be parsed, a client error response built from the
// parse error is written. Otherwise the handler is invoked with a context
// derived from the server's Ctx that carries the request body under
// SpartanRequestBody and its declared length under SpartanRequestBodyLen,
// and the handler's response is written to conn. A nil response from the
// handler is answered with a client error ("Resource does not exist.").
// If the handler, or writing its response, panics, the panic is logged and
// a server error response is written to conn.
func (ss *spartanServer) handleConn(conn net.Conn) {
	buf := bufio.NewReader(conn)

	var response *gus.Response
	request, clen, err := ParseRequest(buf)
	if err != nil {
		response = ClientError(err)
	} else {
		request.Server = ss
		request.RemoteAddr = conn.RemoteAddr()

		// With no declared body the context carries a nil *bufio.Reader.
		var body *bufio.Reader
		if clen > 0 {
			// Cap reads at the declared length so the handler cannot consume
			// bytes beyond this request's body.
			body = bufio.NewReader(io.LimitReader(buf, int64(clen)))
		}
		ctx := context.WithValue(ss.Ctx, SpartanRequestBody, body)
		ctx = context.WithValue(ctx, SpartanRequestBodyLen, clen)

		defer func() {
			if r := recover(); r != nil {
				// %v renders any panic value sensibly; %s would turn
				// non-string values (e.g. panic(42)) into "%!s(int=42)".
				panicErr := fmt.Errorf("%v", r)
				_ = ss.LogError("msg", "panic in handler", "err", panicErr)
				rdr := NewResponseReader(ServerError(errors.New("Server error")))
				// Best effort: the client may already have gone away.
				_, _ = io.Copy(conn, rdr)
			}
		}()
		response = ss.handler.Handle(ctx, request)
		if response == nil {
			response = ClientError(errors.New("Resource does not exist."))
		}
	}

	defer response.Close()
	// Write errors are deliberately ignored: once the response cannot be
	// delivered there is no remaining way to report a failure to the client.
	_, _ = io.Copy(conn, NewResponseReader(response))
}