summaryrefslogtreecommitdiff
path: root/auth.go
diff options
context:
space:
mode:
Diffstat (limited to 'auth.go')
-rw-r--r--auth.go115
1 files changed, 115 insertions, 0 deletions
diff --git a/auth.go b/auth.go
new file mode 100644
index 0000000..c482366
--- /dev/null
+++ b/auth.go
@@ -0,0 +1,115 @@
+package main
+
+import (
+ "context"
+ "crypto/sha256"
+ "encoding/hex"
+ "io"
+ "os"
+ "os/user"
+ "path/filepath"
+ "slices"
+ "strings"
+
+ "tildegit.org/tjp/sliderule"
+ "tildegit.org/tjp/sliderule/gemini"
+)
+
+func GeminiAuthMiddleware(auth *Auth) sliderule.Middleware {
+ if auth == nil {
+ return func(inner sliderule.Handler) sliderule.Handler { return inner }
+ }
+
+ return func(inner sliderule.Handler) sliderule.Handler {
+ return sliderule.HandlerFunc(func(ctx context.Context, request *sliderule.Request) *sliderule.Response {
+ if auth.Strategy.Approve(ctx, request) {
+ return inner.Handle(ctx, request)
+ }
+
+ if len(request.TLSState.PeerCertificates) == 0 {
+ return gemini.RequireCert("client certificate required")
+ }
+ return gemini.CertAuthFailure("client certificate rejected")
+ })
+ }
+}
+
+func ClientTLSFile(path string) (AuthStrategy, error) {
+ if strings.Contains(path, "~") {
+ return UserClientTLSAuth(path), nil
+ }
+
+ f, err := os.Open(path)
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = f.Close() }()
+
+ contents, err := io.ReadAll(f)
+ if err != nil {
+ return nil, err
+ }
+
+ fingerprints := []string{}
+ for _, line := range strings.Split(string(contents), "\n") {
+ line = strings.Trim(line, " \t\r")
+ if len(line) == sha256.Size*2 {
+ fingerprints = append(fingerprints, line)
+ }
+ }
+ return ClientTLSAuth(fingerprints), nil
+}
+
+func ClientTLS(raw string) AuthStrategy {
+ fingerprints := []string{}
+ for _, fp := range strings.Split(raw, ",") {
+ fp = strings.Trim(fp, " \t\r")
+ if len(fp) == sha256.Size*2 {
+ fingerprints = append(fingerprints, fp)
+ }
+ }
+ return ClientTLSAuth(fingerprints)
+}
+
+type UserClientTLSAuth string
+
+func (ca UserClientTLSAuth) Approve(ctx context.Context, request *sliderule.Request) bool {
+ u, err := user.Lookup(sliderule.RouteParams(ctx)["username"])
+ if err != nil {
+ return false
+ }
+ fpath := resolveTilde(string(ca), u)
+
+ strat, err := ClientTLSFile(fpath)
+ if err != nil {
+ return false
+ }
+ return strat.Approve(ctx, request)
+}
+
+func resolveTilde(path string, u *user.User) string {
+ if strings.HasPrefix(path, "~/") {
+ return filepath.Join(u.HomeDir, path[1:])
+ }
+ return strings.ReplaceAll(path, "~", u.Username)
+}
+
+type ClientTLSAuth []string
+
+func (ca ClientTLSAuth) Approve(_ context.Context, request *sliderule.Request) bool {
+ if request.TLSState == nil || len(request.TLSState.PeerCertificates) == 0 {
+ return false
+ }
+ return slices.Contains(ca, fingerprint(request.TLSState.PeerCertificates[0].Raw))
+}
+
+func fingerprint(raw []byte) string {
+ hash := sha256.Sum256(raw)
+ return hex.EncodeToString(hash[:])
+}
+
+type HasClientTLSAuth struct{}
+
+func (_ HasClientTLSAuth) Approve(_ context.Context, request *sliderule.Request) bool {
+ return request.TLSState != nil && len(request.TLSState.PeerCertificates) > 0
+}