Go Authentication Middleware

Authentication and authorization middleware patterns for Go web applications. Includes JWT, OAuth2, Auth0, and CORS implementations.

Use Case

  • Protect API endpoints with authentication
  • Implement role-based access control
  • Integrate with OAuth providers (Auth0, Google, GitHub)
  • Handle CORS for frontend applications

JWT Middleware

Basic JWT Middleware

  1package middleware
  2
  3import (
  4    "context"
  5    "net/http"
  6    "strings"
  7    
  8    "github.com/golang-jwt/jwt/v5"
  9)
 10
 11type contextKey string
 12
 13const UserContextKey contextKey = "user"
 14
 15type Claims struct {
 16    UserID string   `json:"user_id"`
 17    Email  string   `json:"email"`
 18    Roles  []string `json:"roles"`
 19    jwt.RegisteredClaims
 20}
 21
 22// JWTMiddleware validates JWT tokens
 23func JWTMiddleware(secret []byte) func(http.Handler) http.Handler {
 24    return func(next http.Handler) http.Handler {
 25        return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 26            // Extract token from Authorization header
 27            authHeader := r.Header.Get("Authorization")
 28            if authHeader == "" {
 29                http.Error(w, "Missing authorization header", http.StatusUnauthorized)
 30                return
 31            }
 32            
 33            // Bearer token format: "Bearer <token>"
 34            parts := strings.Split(authHeader, " ")
 35            if len(parts) != 2 || parts[0] != "Bearer" {
 36                http.Error(w, "Invalid authorization header format", http.StatusUnauthorized)
 37                return
 38            }
 39            
 40            tokenString := parts[1]
 41            
 42            // Parse and validate token
 43            token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
 44                // Validate signing method
 45                if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
 46                    return nil, jwt.ErrSignatureInvalid
 47                }
 48                return secret, nil
 49            })
 50            
 51            if err != nil {
 52                http.Error(w, "Invalid token", http.StatusUnauthorized)
 53                return
 54            }
 55            
 56            if !token.Valid {
 57                http.Error(w, "Token is not valid", http.StatusUnauthorized)
 58                return
 59            }
 60            
 61            // Extract claims
 62            claims, ok := token.Claims.(*Claims)
 63            if !ok {
 64                http.Error(w, "Invalid token claims", http.StatusUnauthorized)
 65                return
 66            }
 67            
 68            // Add claims to context
 69            ctx := context.WithValue(r.Context(), UserContextKey, claims)
 70            next.ServeHTTP(w, r.WithContext(ctx))
 71        })
 72    }
 73}
 74
 75// GetUserFromContext extracts user claims from context
 76func GetUserFromContext(ctx context.Context) (*Claims, bool) {
 77    claims, ok := ctx.Value(UserContextKey).(*Claims)
 78    return claims, ok
 79}
 80
 81// RequireRoles middleware checks if user has required roles
 82func RequireRoles(roles ...string) func(http.Handler) http.Handler {
 83    return func(next http.Handler) http.Handler {
 84        return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 85            claims, ok := GetUserFromContext(r.Context())
 86            if !ok {
 87                http.Error(w, "Unauthorized", http.StatusUnauthorized)
 88                return
 89            }
 90            
 91            // Check if user has any of the required roles
 92            hasRole := false
 93            for _, requiredRole := range roles {
 94                for _, userRole := range claims.Roles {
 95                    if userRole == requiredRole {
 96                        hasRole = true
 97                        break
 98                    }
 99                }
100                if hasRole {
101                    break
102                }
103            }
104            
105            if !hasRole {
106                http.Error(w, "Forbidden", http.StatusForbidden)
107                return
108            }
109            
110            next.ServeHTTP(w, r)
111        })
112    }
113}

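Usage Example

This example is a separate `package main`. When the middleware above lives in its own package, import it and qualify the identifiers used here (`middleware.JWTMiddleware`, `middleware.RequireRoles`, `middleware.Claims`, `middleware.GetUserFromContext`).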

 1package main
 2
 3import (
 4    "encoding/json"
 5    "net/http"
 6    "time"
 7    
 8    "github.com/golang-jwt/jwt/v5"
 9    "github.com/gorilla/mux"
10)
11
12var jwtSecret = []byte("your-secret-key-change-this")
13
14func main() {
15    r := mux.NewRouter()
16    
17    // Public routes
18    r.HandleFunc("/login", loginHandler).Methods("POST")
19    
20    // Protected routes
21    api := r.PathPrefix("/api").Subrouter()
22    api.Use(JWTMiddleware(jwtSecret))
23    
24    api.HandleFunc("/profile", profileHandler).Methods("GET")
25    
26    // Admin-only routes
27    admin := api.PathPrefix("/admin").Subrouter()
28    admin.Use(RequireRoles("admin"))
29    admin.HandleFunc("/users", listUsersHandler).Methods("GET")
30    
31    http.ListenAndServe(":8080", r)
32}
33
34func loginHandler(w http.ResponseWriter, r *http.Request) {
35    var creds struct {
36        Email    string `json:"email"`
37        Password string `json:"password"`
38    }
39    
40    if err := json.NewDecoder(r.Body).Decode(&creds); err != nil {
41        http.Error(w, "Invalid request", http.StatusBadRequest)
42        return
43    }
44    
45    // Validate credentials (check database)
46    // This is a simplified example
47    if creds.Email != "user@example.com" || creds.Password != "password" {
48        http.Error(w, "Invalid credentials", http.StatusUnauthorized)
49        return
50    }
51    
52    // Create JWT token
53    claims := &Claims{
54        UserID: "user123",
55        Email:  creds.Email,
56        Roles:  []string{"user"},
57        RegisteredClaims: jwt.RegisteredClaims{
58            ExpiresAt: jwt.NewNumericDate(time.Now().Add(15 * time.Minute)),
59            IssuedAt:  jwt.NewNumericDate(time.Now()),
60        },
61    }
62    
63    token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
64    tokenString, err := token.SignedString(jwtSecret)
65    if err != nil {
66        http.Error(w, "Error generating token", http.StatusInternalServerError)
67        return
68    }
69    
70    json.NewEncoder(w).Encode(map[string]string{
71        "token": tokenString,
72    })
73}
74
75func profileHandler(w http.ResponseWriter, r *http.Request) {
76    claims, _ := GetUserFromContext(r.Context())
77    json.NewEncoder(w).Encode(claims)
78}
79
// listUsersHandler is a placeholder admin endpoint that returns a fixed JSON
// message. Access control is not checked here; main mounts it behind
// RequireRoles("admin").
func listUsersHandler(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	_ = json.NewEncoder(w).Encode(map[string]string{
		"message": "List of users (admin only)",
	})
}

Auth0 Integration

  1package middleware
  2
  3import (
  4    "context"
  5    "encoding/json"
  6    "errors"
  7    "net/http"
  8    "net/url"
  9    "strings"
 10    "time"
 11    
 12    "github.com/golang-jwt/jwt/v5"
 13)
 14
 15type Auth0Config struct {
 16    Domain   string
 17    Audience string
 18}
 19
 20type JWKS struct {
 21    Keys []JSONWebKey `json:"keys"`
 22}
 23
 24type JSONWebKey struct {
 25    Kty string   `json:"kty"`
 26    Kid string   `json:"kid"`
 27    Use string   `json:"use"`
 28    N   string   `json:"n"`
 29    E   string   `json:"e"`
 30    X5c []string `json:"x5c"`
 31}
 32
 33// Auth0Middleware validates Auth0 JWT tokens
 34func Auth0Middleware(config Auth0Config) func(http.Handler) http.Handler {
 35    return func(next http.Handler) http.Handler {
 36        return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 37            // Extract token
 38            authHeader := r.Header.Get("Authorization")
 39            if authHeader == "" {
 40                http.Error(w, "Missing authorization header", http.StatusUnauthorized)
 41                return
 42            }
 43            
 44            parts := strings.Split(authHeader, " ")
 45            if len(parts) != 2 || parts[0] != "Bearer" {
 46                http.Error(w, "Invalid authorization header", http.StatusUnauthorized)
 47                return
 48            }
 49            
 50            tokenString := parts[1]
 51            
 52            // Parse token without validation to get kid
 53            token, _, err := new(jwt.Parser).ParseUnverified(tokenString, jwt.MapClaims{})
 54            if err != nil {
 55                http.Error(w, "Invalid token", http.StatusUnauthorized)
 56                return
 57            }
 58            
 59            // Get kid from token header
 60            kid, ok := token.Header["kid"].(string)
 61            if !ok {
 62                http.Error(w, "Invalid token header", http.StatusUnauthorized)
 63                return
 64            }
 65            
 66            // Fetch JWKS from Auth0
 67            jwks, err := fetchJWKS(config.Domain)
 68            if err != nil {
 69                http.Error(w, "Error fetching JWKS", http.StatusInternalServerError)
 70                return
 71            }
 72            
 73            // Find matching key
 74            var jwk *JSONWebKey
 75            for _, key := range jwks.Keys {
 76                if key.Kid == kid {
 77                    jwk = &key
 78                    break
 79                }
 80            }
 81            
 82            if jwk == nil {
 83                http.Error(w, "Unable to find appropriate key", http.StatusUnauthorized)
 84                return
 85            }
 86            
 87            // Parse and validate token with public key
 88            parsedToken, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
 89                // Verify signing method
 90                if token.Method.Alg() != "RS256" {
 91                    return nil, errors.New("unexpected signing method")
 92                }
 93                
 94                // Convert JWK to public key
 95                cert := "-----BEGIN CERTIFICATE-----\n" + jwk.X5c[0] + "\n-----END CERTIFICATE-----"
 96                return jwt.ParseRSAPublicKeyFromPEM([]byte(cert))
 97            })
 98            
 99            if err != nil || !parsedToken.Valid {
100                http.Error(w, "Invalid token", http.StatusUnauthorized)
101                return
102            }
103            
104            // Validate claims
105            claims, ok := parsedToken.Claims.(jwt.MapClaims)
106            if !ok {
107                http.Error(w, "Invalid claims", http.StatusUnauthorized)
108                return
109            }
110            
111            // Validate audience
112            if !claims.VerifyAudience(config.Audience, true) {
113                http.Error(w, "Invalid audience", http.StatusUnauthorized)
114                return
115            }
116            
117            // Validate issuer
118            expectedIssuer := "https://" + config.Domain + "/"
119            if !claims.VerifyIssuer(expectedIssuer, true) {
120                http.Error(w, "Invalid issuer", http.StatusUnauthorized)
121                return
122            }
123            
124            // Add claims to context
125            ctx := context.WithValue(r.Context(), UserContextKey, claims)
126            next.ServeHTTP(w, r.WithContext(ctx))
127        })
128    }
129}
130
131func fetchJWKS(domain string) (*JWKS, error) {
132    jwksURL := "https://" + domain + "/.well-known/jwks.json"
133    
134    client := &http.Client{Timeout: 10 * time.Second}
135    resp, err := client.Get(jwksURL)
136    if err != nil {
137        return nil, err
138    }
139    defer resp.Body.Close()
140    
141    var jwks JWKS
142    if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
143        return nil, err
144    }
145    
146    return &jwks, nil
147}
148
149// Usage
150func main() {
151    auth0Config := Auth0Config{
152        Domain:   "your-tenant.auth0.com",
153        Audience: "https://your-api.com",
154    }
155    
156    r := mux.NewRouter()
157    
158    api := r.PathPrefix("/api").Subrouter()
159    api.Use(Auth0Middleware(auth0Config))
160    api.HandleFunc("/protected", protectedHandler).Methods("GET")
161    
162    http.ListenAndServe(":8080", r)
163}

OAuth2 Middleware (Google, GitHub)

  1package middleware
  2
  3import (
  4    "context"
  5    "encoding/json"
  6    "net/http"
  7    
  8    "golang.org/x/oauth2"
  9    "golang.org/x/oauth2/google"
 10    "golang.org/x/oauth2/github"
 11)
 12
 13type OAuthConfig struct {
 14    ClientID     string
 15    ClientSecret string
 16    RedirectURL  string
 17    Scopes       []string
 18}
 19
 20// Google OAuth
 21func GoogleOAuthConfig(config OAuthConfig) *oauth2.Config {
 22    return &oauth2.Config{
 23        ClientID:     config.ClientID,
 24        ClientSecret: config.ClientSecret,
 25        RedirectURL:  config.RedirectURL,
 26        Scopes:       config.Scopes,
 27        Endpoint:     google.Endpoint,
 28    }
 29}
 30
 31// GitHub OAuth
 32func GitHubOAuthConfig(config OAuthConfig) *oauth2.Config {
 33    return &oauth2.Config{
 34        ClientID:     config.ClientID,
 35        ClientSecret: config.ClientSecret,
 36        RedirectURL:  config.RedirectURL,
 37        Scopes:       config.Scopes,
 38        Endpoint:     github.Endpoint,
 39    }
 40}
 41
 42// OAuth handlers
 43func OAuthLoginHandler(oauthConfig *oauth2.Config) http.HandlerFunc {
 44    return func(w http.ResponseWriter, r *http.Request) {
 45        // Generate random state for CSRF protection
 46        state := generateRandomState() // Implement this
 47        
 48        // Store state in session/cookie
 49        http.SetCookie(w, &http.Cookie{
 50            Name:     "oauth_state",
 51            Value:    state,
 52            MaxAge:   300, // 5 minutes
 53            HttpOnly: true,
 54            Secure:   true,
 55            SameSite: http.SameSiteLaxMode,
 56        })
 57        
 58        // Redirect to OAuth provider
 59        url := oauthConfig.AuthCodeURL(state)
 60        http.Redirect(w, r, url, http.StatusTemporaryRedirect)
 61    }
 62}
 63
 64func OAuthCallbackHandler(oauthConfig *oauth2.Config) http.HandlerFunc {
 65    return func(w http.ResponseWriter, r *http.Request) {
 66        // Validate state (CSRF protection)
 67        stateCookie, err := r.Cookie("oauth_state")
 68        if err != nil {
 69            http.Error(w, "State cookie not found", http.StatusBadRequest)
 70            return
 71        }
 72        
 73        state := r.URL.Query().Get("state")
 74        if state != stateCookie.Value {
 75            http.Error(w, "Invalid state parameter", http.StatusBadRequest)
 76            return
 77        }
 78        
 79        // Exchange code for token
 80        code := r.URL.Query().Get("code")
 81        token, err := oauthConfig.Exchange(context.Background(), code)
 82        if err != nil {
 83            http.Error(w, "Failed to exchange token", http.StatusInternalServerError)
 84            return
 85        }
 86        
 87        // Fetch user info
 88        client := oauthConfig.Client(context.Background(), token)
 89        resp, err := client.Get("https://www.googleapis.com/oauth2/v2/userinfo") // For Google
 90        if err != nil {
 91            http.Error(w, "Failed to get user info", http.StatusInternalServerError)
 92            return
 93        }
 94        defer resp.Body.Close()
 95        
 96        var userInfo map[string]interface{}
 97        if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
 98            http.Error(w, "Failed to decode user info", http.StatusInternalServerError)
 99            return
100        }
101        
102        // Create session or JWT token
103        // Store user info in database
104        // Redirect to application
105        
106        json.NewEncoder(w).Encode(userInfo)
107    }
108}
109
110// Usage
111func main() {
112    googleConfig := GoogleOAuthConfig(OAuthConfig{
113        ClientID:     "your-client-id",
114        ClientSecret: "your-client-secret",
115        RedirectURL:  "http://localhost:8080/auth/google/callback",
116        Scopes:       []string{"email", "profile"},
117    })
118    
119    r := mux.NewRouter()
120    r.HandleFunc("/auth/google", OAuthLoginHandler(googleConfig)).Methods("GET")
121    r.HandleFunc("/auth/google/callback", OAuthCallbackHandler(googleConfig)).Methods("GET")
122    
123    http.ListenAndServe(":8080", r)
124}

CORS Middleware

 1package middleware
 2
import (
	"fmt"
	"net/http"
	"strings"
)
 6
 7type CORSConfig struct {
 8    AllowedOrigins   []string
 9    AllowedMethods   []string
10    AllowedHeaders   []string
11    ExposedHeaders   []string
12    AllowCredentials bool
13    MaxAge           int
14}
15
16func DefaultCORSConfig() CORSConfig {
17    return CORSConfig{
18        AllowedOrigins:   []string{"*"},
19        AllowedMethods:   []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
20        AllowedHeaders:   []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
21        ExposedHeaders:   []string{"Link"},
22        AllowCredentials: false,
23        MaxAge:           300,
24    }
25}
26
27func CORSMiddleware(config CORSConfig) func(http.Handler) http.Handler {
28    return func(next http.Handler) http.Handler {
29        return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
30            origin := r.Header.Get("Origin")
31            
32            // Check if origin is allowed
33            allowed := false
34            for _, allowedOrigin := range config.AllowedOrigins {
35                if allowedOrigin == "*" || allowedOrigin == origin {
36                    allowed = true
37                    break
38                }
39            }
40            
41            if !allowed {
42                next.ServeHTTP(w, r)
43                return
44            }
45            
46            // Set CORS headers
47            if origin != "" {
48                w.Header().Set("Access-Control-Allow-Origin", origin)
49            } else if len(config.AllowedOrigins) == 1 {
50                w.Header().Set("Access-Control-Allow-Origin", config.AllowedOrigins[0])
51            }
52            
53            w.Header().Set("Access-Control-Allow-Methods", strings.Join(config.AllowedMethods, ", "))
54            w.Header().Set("Access-Control-Allow-Headers", strings.Join(config.AllowedHeaders, ", "))
55            w.Header().Set("Access-Control-Expose-Headers", strings.Join(config.ExposedHeaders, ", "))
56            
57            if config.AllowCredentials {
58                w.Header().Set("Access-Control-Allow-Credentials", "true")
59            }
60            
61            if config.MaxAge > 0 {
62                w.Header().Set("Access-Control-Max-Age", fmt.Sprintf("%d", config.MaxAge))
63            }
64            
65            // Handle preflight requests
66            if r.Method == "OPTIONS" {
67                w.WriteHeader(http.StatusNoContent)
68                return
69            }
70            
71            next.ServeHTTP(w, r)
72        })
73    }
74}
75
76// Usage
77func main() {
78    corsConfig := CORSConfig{
79        AllowedOrigins:   []string{"http://localhost:3000", "https://app.example.com"},
80        AllowedMethods:   []string{"GET", "POST", "PUT", "DELETE"},
81        AllowedHeaders:   []string{"Authorization", "Content-Type"},
82        AllowCredentials: true,
83        MaxAge:           3600,
84    }
85    
86    r := mux.NewRouter()
87    r.Use(CORSMiddleware(corsConfig))
88    
89    r.HandleFunc("/api/data", dataHandler).Methods("GET", "POST")
90    
91    http.ListenAndServe(":8080", r)
92}

General Middleware Pattern

 1package middleware
 2
 3import (
 4    "log"
 5    "net/http"
 6    "time"
 7)
 8
 9// Middleware type
10type Middleware func(http.Handler) http.Handler
11
12// Chain multiple middleware
13func Chain(middlewares ...Middleware) Middleware {
14    return func(final http.Handler) http.Handler {
15        for i := len(middlewares) - 1; i >= 0; i-- {
16            final = middlewares[i](final)
17        }
18        return final
19    }
20}
21
// LoggingMiddleware logs the method, request URI and handling time of each
// request after the wrapped handler returns. A request whose handler panics
// is not logged here.
func LoggingMiddleware(next http.Handler) http.Handler {
	logRequest := func(w http.ResponseWriter, r *http.Request) {
		began := time.Now()
		next.ServeHTTP(w, r)
		elapsed := time.Since(began)
		log.Printf("%s %s %s", r.Method, r.RequestURI, elapsed)
	}
	return http.HandlerFunc(logRequest)
}
33
// RecoveryMiddleware turns a panic in the wrapped handler into a logged
// 500 Internal Server Error, so the client gets a response instead of the
// connection being torn down by net/http's own panic handling.
//
// http.ErrAbortHandler is re-panicked: handlers raise it deliberately to
// abort a response, and net/http handles it without logging a stack trace.
func RecoveryMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		defer func() {
			rec := recover()
			if rec == nil {
				return
			}
			if rec == http.ErrAbortHandler {
				panic(rec)
			}
			log.Printf("Panic: %v", rec)
			// If the handler had already written a status this cannot change
			// it; net/http then only logs a superfluous WriteHeader call.
			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		}()

		next.ServeHTTP(w, r)
	})
}
47
48// Usage: Chain multiple middleware
49func main() {
50    handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
51        w.Write([]byte("Hello, World!"))
52    })
53    
54    // Apply middleware in order
55    wrapped := Chain(
56        RecoveryMiddleware,
57        LoggingMiddleware,
58        CORSMiddleware(DefaultCORSConfig()),
59        JWTMiddleware(jwtSecret),
60    )(handler)
61    
62    http.ListenAndServe(":8080", wrapped)
63}

Notes

Security Checklist:

  • ✅ Always validate JWT signatures
  • ✅ Use HTTPS in production
  • ✅ Set secure cookie flags (HttpOnly, Secure, SameSite)
  • ✅ Implement rate limiting on auth endpoints
  • ✅ Use short-lived access tokens (15 min) + refresh tokens
  • ✅ Validate token expiration
  • ✅ Implement CSRF protection when authentication relies on cookies
  • ✅ Allowlist specific CORS origins (don't use * in production, and never combine it with credentials)
  • ✅ Log authentication failures
  • ✅ Implement account lockout after failed attempts

Common Pitfalls:

  • ❌ Storing the JWT secret in code, as the examples above do for brevity (load it from an environment variable or secret manager instead)
  • ❌ Not validating token expiration
  • ❌ Using weak signing algorithms
  • ❌ Exposing sensitive data in JWT payload
  • ❌ Not implementing token refresh
  • ❌ Allowing CORS * with credentials

Related Snippets