summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authortjp <tjp@ctrl-c.club>2024-01-10 11:56:08 -0700
committertjp <tjp@ctrl-c.club>2024-01-10 11:56:08 -0700
commitcde393cdf50391ccac137a4cc6a9ed231ec3b6d1 (patch)
treef58eeca192af819ca419912e010ce86b4a5d00ae
parent9b4b34baa338dfa4c90497a037d6b2a297351df4 (diff)
prompt to update TOFU store on violations
-rw-r--r--actions.go60
-rw-r--r--tls.go18
2 files changed, 67 insertions, 11 deletions
diff --git a/actions.go b/actions.go
index 4196232..b8b3c87 100644
--- a/actions.go
+++ b/actions.go
@@ -127,13 +127,13 @@ func Reload(state *BrowserState, conf *Config) error {
body := io.LimitReader(bytes.NewBuffer(input), int64(len(input)))
state.Url.Fragment = ""
- response, err = upload(state.Url.String(), body, tlsConf)
+ response, err = upload(state, state.Url.String(), body, tlsConf)
state.Url.Fragment = "prompt"
if err != nil {
return err
}
} else {
- response, err = fetch(urlStr, tlsConf)
+ response, err = fetch(state, urlStr, tlsConf)
if err != nil {
return err
}
@@ -152,7 +152,7 @@ outer:
state.Url = response.Request.URL
state.Url.RawQuery = url.QueryEscape(strings.TrimRight(line, "\n"))
- response, err = fetch(state.Url.String(), tlsConf)
+ response, err = fetch(state, state.Url.String(), tlsConf)
if err != nil {
return err
}
@@ -164,7 +164,7 @@ outer:
state.Url = response.Request.URL
state.Url.RawQuery = url.QueryEscape(strings.TrimRight(string(line), "\n"))
- response, err = fetch(state.Url.String(), tlsConf)
+ response, err = fetch(state, state.Url.String(), tlsConf)
if err != nil {
return err
}
@@ -190,14 +190,58 @@ outer:
return HandleResource(state, conf)
}
-func fetch(u string, tlsConf *tls.Config) (*sliderule.Response, error) {
+func fetch(state *BrowserState, u string, tlsConf *tls.Config) (*sliderule.Response, error) {
tlsConf.ClientSessionCache = nil
- return sliderule.NewClient(tlsConf).Fetch(u)
+ response, err := sliderule.NewClient(tlsConf).Fetch(u)
+ var tofuErr *TOFUViolation
+ if errors.As(err, &tofuErr) {
+ writeError(err.Error())
+ state.Readline.SetPrompt("Trust new certificate instead (y/n)? [n] ")
+ line, err := state.Readline.Readline()
+ if err != nil {
+ return nil, err
+ }
+ if line != "y" {
+ return nil, tofuErr
+ }
+
+ tofuStore[tofuErr.domain] = tofuErr.got
+ if err := saveTofuStore(tofuStore); err != nil {
+ return nil, err
+ }
+
+ return sliderule.NewClient(tlsConf).Fetch(u)
+ } else if err != nil {
+ return nil, err
+ }
+ return response, nil
}
-func upload(u string, body io.Reader, tlsConf *tls.Config) (*sliderule.Response, error) {
+func upload(state *BrowserState, u string, body io.Reader, tlsConf *tls.Config) (*sliderule.Response, error) {
tlsConf.ClientSessionCache = nil
- return sliderule.NewClient(tlsConf).Upload(u, body)
+ response, err := sliderule.NewClient(tlsConf).Upload(u, body)
+ var tofuErr *TOFUViolation
+ if errors.As(err, &tofuErr) {
+ writeError(err.Error())
+ state.Readline.SetPrompt("Trust new certificate instead (y/n)? [n] ")
+ line, err := state.Readline.Readline()
+ if err != nil {
+ return nil, err
+ }
+ if line != "y" {
+ return nil, tofuErr
+ }
+
+ tofuStore[tofuErr.domain] = tofuErr.got
+ if err := saveTofuStore(tofuStore); err != nil {
+ return nil, err
+ }
+
+ return sliderule.NewClient(tlsConf).Upload(u, body)
+ } else if err != nil {
+ return nil, err
+ }
+ return response, nil
}
func externalMessage() ([]byte, error) {
diff --git a/tls.go b/tls.go
index fa25441..d4452f2 100644
--- a/tls.go
+++ b/tls.go
@@ -8,7 +8,7 @@ import (
"crypto/x509"
"crypto/x509/pkix"
"encoding/hex"
- "errors"
+ "fmt"
"math/big"
"os"
"time"
@@ -23,7 +23,15 @@ func tlsConfig(state *BrowserState) *tls.Config {
var tofuStore map[string]string
-var ErrTOFUViolation = errors.New("certificate for this domain has changed")
+type TOFUViolation struct {
+ domain string
+ expected string
+ got string
+}
+
+func (tv *TOFUViolation) Error() string {
+ return fmt.Sprintf("certificate for domain %s has changed from %s to %s", tv.domain, tv.expected, tv.got)
+}
var anonymousTLS = &tls.Config{
InsecureSkipVerify: true,
@@ -43,7 +51,11 @@ func tofuVerify(connState tls.ConnectionState) error {
}
if certhash != expected {
- return ErrTOFUViolation
+ return &TOFUViolation{
+ domain: connState.ServerName,
+ expected: expected,
+ got: certhash,
+ }
}
return nil
}