From 6e1c25af361dde4c063eccbf769e966df4b65f23 Mon Sep 17 00:00:00 2001 From: tjpcc Date: Thu, 28 Sep 2023 08:08:48 -0600 Subject: config file refactor --- auth.go | 115 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 115 insertions(+) create mode 100644 auth.go (limited to 'auth.go') diff --git a/auth.go b/auth.go new file mode 100644 index 0000000..c482366 --- /dev/null +++ b/auth.go @@ -0,0 +1,115 @@ +package main + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "io" + "os" + "os/user" + "path/filepath" + "slices" + "strings" + + "tildegit.org/tjp/sliderule" + "tildegit.org/tjp/sliderule/gemini" +) + +func GeminiAuthMiddleware(auth *Auth) sliderule.Middleware { + if auth == nil { + return func(inner sliderule.Handler) sliderule.Handler { return inner } + } + + return func(inner sliderule.Handler) sliderule.Handler { + return sliderule.HandlerFunc(func(ctx context.Context, request *sliderule.Request) *sliderule.Response { + if auth.Strategy.Approve(ctx, request) { + return inner.Handle(ctx, request) + } + + if len(request.TLSState.PeerCertificates) == 0 { + return gemini.RequireCert("client certificate required") + } + return gemini.CertAuthFailure("client certificate rejected") + }) + } +} + +func ClientTLSFile(path string) (AuthStrategy, error) { + if strings.Contains(path, "~") { + return UserClientTLSAuth(path), nil + } + + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer func() { _ = f.Close() }() + + contents, err := io.ReadAll(f) + if err != nil { + return nil, err + } + + fingerprints := []string{} + for _, line := range strings.Split(string(contents), "\n") { + line = strings.Trim(line, " \t\r") + if len(line) == sha256.Size*2 { + fingerprints = append(fingerprints, line) + } + } + return ClientTLSAuth(fingerprints), nil +} + +func ClientTLS(raw string) AuthStrategy { + fingerprints := []string{} + for _, fp := range strings.Split(raw, ",") { + fp = strings.Trim(fp, " \t\r") + if len(fp) == sha256.Size*2 { + fingerprints = append(fingerprints, fp) + } + } + return ClientTLSAuth(fingerprints) +} + +type UserClientTLSAuth string + +func (ca UserClientTLSAuth) Approve(ctx context.Context, request *sliderule.Request) bool { + u, err := user.Lookup(sliderule.RouteParams(ctx)["username"]) + if err != nil { + return false + } + fpath := resolveTilde(string(ca), u) + + strat, err := ClientTLSFile(fpath) + if err != nil { + return false + } + return strat.Approve(ctx, request) +} + +func resolveTilde(path string, u *user.User) string { + if strings.HasPrefix(path, "~/") { + return filepath.Join(u.HomeDir, path[1:]) + } + return strings.ReplaceAll(path, "~", u.Username) +} + +type ClientTLSAuth []string + +func (ca ClientTLSAuth) Approve(_ context.Context, request *sliderule.Request) bool { + if request.TLSState == nil || len(request.TLSState.PeerCertificates) == 0 { + return false + } + return slices.Contains(ca, fingerprint(request.TLSState.PeerCertificates[0].Raw)) +} + +func fingerprint(raw []byte) string { + hash := sha256.Sum256(raw) + return hex.EncodeToString(hash[:]) +} + +type HasClientTLSAuth struct{} + +func (_ HasClientTLSAuth) Approve(_ context.Context, request *sliderule.Request) bool { + return request.TLSState != nil && len(request.TLSState.PeerCertificates) > 0 +} -- cgit v1.2.3