package server

import (
	"context"
	"net/http"
	"strings"

	"codeberg.org/pronounscc/pronouns.cc/backend/log"
	"codeberg.org/pronounscc/pronouns.cc/backend/server/auth"
	"github.com/go-chi/render"
)

// maybeAuth is a globally-used middleware that authenticates a request only
// when it carries a token.
//
// The token is read from the Authorization header with any "Bearer " prefix
// stripped. If nothing is left, the request is handed to next unchanged. If
// s.Auth.Claims rejects the token, or s.DB.TokenValid reports it as not valid,
// the response is the ErrInvalidToken API error; if TokenValid itself fails,
// the error is logged and the ErrInternalServerError API error is sent.
// Otherwise the claims are stored in the request context under ctxKeyClaims
// and the request continues to next.
func (s *Server) maybeAuth(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// abort writes apiErr as JSON with the given HTTP status.
		abort := func(status int, apiErr APIError) {
			render.Status(r, status)
			render.JSON(w, r, apiErr)
		}

		header := r.Header.Get("Authorization")
		rawToken := strings.TrimPrefix(header, "Bearer ")
		if rawToken == "" {
			// No credentials supplied: continue without claims.
			next.ServeHTTP(w, r)
			return
		}

		claims, err := s.Auth.Claims(rawToken)
		if err != nil {
			abort(errCodeStatuses[ErrInvalidToken], APIError{
				Code:    ErrInvalidToken,
				Message: errCodeMessages[ErrInvalidToken],
			})
			return
		}

		// "valid" here refers to existence and expiry date, not whether the
		// token is known.
		valid, err := s.DB.TokenValid(r.Context(), claims.UserID, claims.TokenID)
		switch {
		case err != nil:
			log.Errorf("validating token for user %v: %v", claims.UserID, err)
			abort(errCodeStatuses[ErrInternalServerError], APIError{
				Code:    ErrInternalServerError,
				Message: errCodeMessages[ErrInternalServerError],
			})
			return
		case !valid:
			abort(errCodeStatuses[ErrInvalidToken], APIError{
				Code:    ErrInvalidToken,
				Message: errCodeMessages[ErrInvalidToken],
			})
			return
		}

		// Make the claims available to downstream handlers.
		authedReq := r.WithContext(context.WithValue(r.Context(), ctxKeyClaims, claims))
		next.ServeHTTP(w, authedReq)
	})
}

// MustAuth is middleware that makes a valid token required: a request only
// reaches next when ClaimsFromContext finds claims in its context. Any other
// request is answered with the ErrForbidden API error and never reaches next.
func MustAuth(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if _, authed := ClaimsFromContext(r.Context()); authed {
			next.ServeHTTP(w, r)
			return
		}

		// No claims in the context: refuse the request.
		render.Status(r, errCodeStatuses[ErrForbidden])
		render.JSON(w, r, APIError{
			Code:    ErrForbidden,
			Message: errCodeMessages[ErrForbidden],
		})
	})
}

// ClaimsFromContext looks up the value stored under ctxKeyClaims in ctx and
// returns it as auth.Claims. The boolean is false, and the claims are the zero
// value, when nothing is stored or the stored value is not an auth.Claims.
func ClaimsFromContext(ctx context.Context) (auth.Claims, bool) {
	// The comma-ok assertion also covers a missing (nil) value: it yields the
	// zero auth.Claims and false.
	stored, found := ctx.Value(ctxKeyClaims).(auth.Claims)
	return stored, found
}