package kate import ( "errors" "fmt" "io" "net/http" "net/http/httptest" "net/url" "strconv" "strings" "testing" ) // Mock implementation of MagicLinkMailer for testing type mockMagicLinkUserData[T any] struct { users map[string]T counter int tokens map[string]*MagicLinkToken[T] sentEmails []struct { user T token string } sendEmailFunc func(T, string) error } func newMockMagicLinkUserData[T any](users map[string]T, sendEmailFunc func(T, string) error) *mockMagicLinkUserData[T] { return &mockMagicLinkUserData[T]{ users: users, tokens: make(map[string]*MagicLinkToken[T]), sendEmailFunc: sendEmailFunc, } } func (m *mockMagicLinkUserData[T]) GenerateToken(username, redirectPath string) (*MagicLinkToken[T], error) { user, exists := m.users[username] if !exists { return nil, nil } token := &MagicLinkToken[T]{ Identifier: strconv.Itoa(m.counter), UserData: user, RedirectPath: redirectPath, } m.counter++ m.tokens[token.Identifier] = token return token, nil } func (m *mockMagicLinkUserData[T]) ValidateToken(identifier string) (*MagicLinkToken[T], error) { token, ok := m.tokens[identifier] if !ok { return nil, fmt.Errorf("no token %q", identifier) } delete(m.tokens, identifier) return token, nil } func (m *mockMagicLinkUserData[T]) SendEmail(userData T, token string) error { m.sentEmails = append(m.sentEmails, struct { user T token string }{userData, token}) if m.sendEmailFunc != nil { return m.sendEmailFunc(userData, token) } return nil } type testUser struct { Username string Hash string ID int } type testUserSerDes struct{} func (ts testUserSerDes) Serialize(w io.Writer, data testUser) error { _, err := fmt.Fprintf(w, "%s|%s|%d", data.Username, data.Hash, data.ID) return err } func (ts testUserSerDes) Deserialize(r io.Reader, data *testUser) error { buf, err := io.ReadAll(r) if err != nil { return err } parts := strings.Split(string(buf), "|") if len(parts) != 3 { return errors.New("invalid format") } data.Username = parts[0] data.Hash = parts[1] id, err := strconv.Atoi(parts[2]) if err != nil { return err } data.ID = id return nil } func TestMagicLinkHandler(t *testing.T) { auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", AuthConfig[testUser]{ SerDes: testUserSerDes{}, CookieName: "test_session", }) users := map[string]testUser{ "john@example.com": {Username: "john@example.com", Hash: "", ID: 1}, "jane@example.com": {Username: "jane@example.com", Hash: "", ID: 2}, } mockData := newMockMagicLinkUserData(users, nil) config := MagicLinkConfig[testUser]{ UserData: mockData, Redirects: Redirects{ Default: "/dashboard", AllowedPrefixes: []string{"/app/", "/admin/"}, }, } handler := auth.MagicLinkLoginHandler(config) tests := []struct { name string method string formData url.Values expectedStatus int expectedBody string checkEmail bool }{ { name: "successful magic link request", method: "POST", formData: url.Values{"email": {"john@example.com"}}, expectedStatus: http.StatusOK, expectedBody: "Magic link sent", checkEmail: true, }, { name: "magic link with redirect", method: "POST", formData: url.Values{"email": {"john@example.com"}, "redirect": {"/app/settings"}}, expectedStatus: http.StatusOK, expectedBody: "Magic link sent", checkEmail: true, }, { name: "nonexistent user returns success (no user enumeration)", method: "POST", formData: url.Values{"email": {"nobody@example.com"}}, expectedStatus: http.StatusOK, expectedBody: "Magic link sent", checkEmail: false, }, { name: "missing email", method: "POST", formData: url.Values{}, expectedStatus: http.StatusBadRequest, checkEmail: false, }, { name: "GET method not allowed", method: "GET", formData: url.Values{}, expectedStatus: http.StatusMethodNotAllowed, checkEmail: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mockData.sentEmails = nil // Reset var req *http.Request if tt.method == "POST" { req = httptest.NewRequest(tt.method, "/magic-link", strings.NewReader(tt.formData.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") } else { req = httptest.NewRequest(tt.method, "/magic-link", nil) } rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != tt.expectedStatus { t.Errorf("expected status %d, got %d", tt.expectedStatus, rr.Code) } if tt.expectedBody != "" { body := strings.TrimSpace(rr.Body.String()) if body != tt.expectedBody { t.Errorf("expected body %q, got %q", tt.expectedBody, body) } } if tt.checkEmail { if len(mockData.sentEmails) != 1 { t.Errorf("expected 1 email sent, got %d", len(mockData.sentEmails)) } else { email := mockData.sentEmails[0] if email.token == "" { t.Error("expected non-empty token") } } } else { if len(mockData.sentEmails) != 0 { t.Errorf("expected no emails sent, got %d", len(mockData.sentEmails)) } } }) } } func TestMagicLinkVerifyHandler(t *testing.T) { auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", AuthConfig[testUser]{ SerDes: testUserSerDes{}, CookieName: "test_session", }) users := map[string]testUser{ "john@example.com": {Username: "john@example.com", Hash: "", ID: 1}, } mockData := newMockMagicLinkUserData(users, nil) config := MagicLinkConfig[testUser]{ UserData: mockData, Redirects: Redirects{Default: "/dashbaord"}, } token, err := mockData.GenerateToken("john@example.com", "/app/settings") if err != nil { t.Fatal(err) } validToken := string(auth.enc.Encrypt([]byte(token.Identifier))) handler := auth.MagicLinkVerifyHandler(config) tests := []struct { name string method string token string expectedStatus int expectedRedirect string checkCookie bool }{ { name: "valid token with redirect", method: "GET", token: validToken, expectedStatus: http.StatusSeeOther, expectedRedirect: "/app/settings", checkCookie: true, }, { name: "missing token", method: "GET", token: "", expectedStatus: http.StatusBadRequest, checkCookie: false, }, { name: "invalid token", method: "GET", token: "invalid-token", expectedStatus: http.StatusUnauthorized, checkCookie: false, }, { name: "POST method not allowed", method: "POST", token: validToken, expectedStatus: http.StatusMethodNotAllowed, checkCookie: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { url := "/verify" if tt.token != "" { url += "?token=" + tt.token } req := httptest.NewRequest(tt.method, url, nil) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != tt.expectedStatus { t.Errorf("expected status %d, got %d", tt.expectedStatus, rr.Code) } if tt.expectedRedirect != "" { location := rr.Header().Get("Location") if location != tt.expectedRedirect { t.Errorf("expected redirect to %s, got %s", tt.expectedRedirect, location) } } if tt.checkCookie { cookies := rr.Result().Cookies() found := false for _, cookie := range cookies { if cookie.Name == "test_session" && cookie.Value != "" { found = true break } } if !found { t.Error("expected authentication cookie to be set") } } }) } } func TestMagicLinkConfigDefaults(t *testing.T) { config := MagicLinkConfig[testUser]{} config.setDefaults() if config.UsernameField != "email" { t.Errorf("expected UsernameField to be 'email', got %s", config.UsernameField) } if config.TokenField != "token" { t.Errorf("expected TokenField to be 'token', got %s", config.TokenField) } if config.TokenLocation != TokenLocationQuery { t.Errorf("expected TokenLocation to be TokenLocationQuery, got %v", config.TokenLocation) } } func TestMagicLinkVerifyHandlerPathToken(t *testing.T) { auth := New("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", AuthConfig[testUser]{ SerDes: testUserSerDes{}, CookieName: "test_session", }) users := map[string]testUser{ "test": {Username: "test", ID: 1}, } mockData := newMockMagicLinkUserData(users, nil) config := MagicLinkConfig[testUser]{ UserData: mockData, Redirects: Redirects{Default: "/dashboard"}, TokenLocation: TokenLocationPath, TokenField: "token", } handler := auth.MagicLinkVerifyHandler(config) if config.TokenLocation != TokenLocationPath { t.Errorf("expected TokenLocation to be TokenLocationPath, got %v", config.TokenLocation) } if config.TokenField != "token" { t.Errorf("expected TokenField to be 'token', got %s", config.TokenField) } req := httptest.NewRequest("GET", "/verify", nil) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusBadRequest { t.Errorf("expected status %d for missing token, got %d", http.StatusBadRequest, rr.Code) } } func TestTokenLocationEnum(t *testing.T) { if TokenLocationQuery.location != 0 { t.Errorf("expected TokenLocationQuery to be 0, got %d", TokenLocationQuery.location) } if TokenLocationPath.location != 1 { t.Errorf("expected TokenLocationPath to be 1, got %d", TokenLocationPath.location) } }