diff --git a/bouncer.go b/bouncer.go index aaeaeb0..8b6f611 100644 --- a/bouncer.go +++ b/bouncer.go @@ -296,13 +296,13 @@ func (bouncer *Bouncer) ServeHTTP(rw http.ResponseWriter, req *http.Request) { remoteIP, err := ip.GetRemoteIP(req, bouncer.serverPoolStrategy, bouncer.forwardedCustomHeader) if err != nil { bouncer.log.Error(fmt.Sprintf("ServeHTTP:getRemoteIp ip:%s %s", remoteIP, err.Error())) - handleBanServeHTTP(bouncer, rw) + handleBanServeHTTP(bouncer, rw, req.Method) return } isTrusted, err := bouncer.clientPoolStrategy.Checker.Contains(remoteIP) if err != nil { bouncer.log.Error(fmt.Sprintf("ServeHTTP:checkerContains ip:%s %s", remoteIP, err.Error())) - handleBanServeHTTP(bouncer, rw) + handleBanServeHTTP(bouncer, rw, req.Method) return } // if our IP is in the trusted list we bypass the next checks @@ -330,7 +330,7 @@ func (bouncer *Bouncer) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } if cacheErrString != cache.CacheMiss { bouncer.log.Error(fmt.Sprintf("ServeHTTP:Get ip:%s %s", remoteIP, cacheErrString)) - handleBanServeHTTP(bouncer, rw) + handleBanServeHTTP(bouncer, rw, req.Method) return } } else { @@ -350,7 +350,7 @@ func (bouncer *Bouncer) ServeHTTP(rw http.ResponseWriter, req *http.Request) { handleNextServeHTTP(bouncer, remoteIP, rw, req) } else { bouncer.log.Debug(fmt.Sprintf("ServeHTTP isCrowdsecStreamHealthy:false ip:%s updateFailure:%d", remoteIP, updateFailure)) - handleBanServeHTTP(bouncer, rw) + handleBanServeHTTP(bouncer, rw, req.Method) } } else { value, err := handleNoStreamCache(bouncer, remoteIP) @@ -392,7 +392,7 @@ type Login struct { } // To append Headers we need to call rw.WriteHeader after set any header. -func handleBanServeHTTP(bouncer *Bouncer, rw http.ResponseWriter) { +func handleBanServeHTTP(bouncer *Bouncer, rw http.ResponseWriter, method string) { atomic.AddInt64(&blockedRequests, 1) if bouncer.remediationCustomHeader != "" { @@ -404,15 +404,20 @@ func handleBanServeHTTP(bouncer *Bouncer, rw http.ResponseWriter) { } rw.Header().Set("Content-Type", "text/html; charset=utf-8") rw.WriteHeader(bouncer.remediationStatusCode) + + if method == http.MethodHead { + return + } _, err := fmt.Fprint(rw, bouncer.banTemplateString) if err != nil { - bouncer.log.Error("handleBanServeHTTP could not write template to ResponseWriter") + // use warn when https://github.com/maxlerebourg/crowdsec-bouncer-traefik-plugin/pull/276 is completed + bouncer.log.Error("handleBanServeHTTP could not write template to ResponseWriter: " + err.Error()) } } func handleRemediationServeHTTP(bouncer *Bouncer, remoteIP, remediation string, rw http.ResponseWriter, req *http.Request) { bouncer.log.Debug(fmt.Sprintf("handleRemediationServeHTTP ip:%s remediation:%s", remoteIP, remediation)) - if bouncer.captchaClient.Valid && remediation == cache.CaptchaValue { + if bouncer.captchaClient.Valid && remediation == cache.CaptchaValue && req.Method != http.MethodHead { if bouncer.captchaClient.Check(remoteIP) { handleNextServeHTTP(bouncer, remoteIP, rw, req) return @@ -421,14 +426,14 @@ func handleRemediationServeHTTP(bouncer *Bouncer, remoteIP, remediation string, bouncer.captchaClient.ServeHTTP(rw, req, remoteIP) return } - handleBanServeHTTP(bouncer, rw) + handleBanServeHTTP(bouncer, rw, req.Method) } func handleNextServeHTTP(bouncer *Bouncer, remoteIP string, rw http.ResponseWriter, req *http.Request) { if bouncer.appsecEnabled { if err := appsecQuery(bouncer, remoteIP, req); err != nil { bouncer.log.Debug(fmt.Sprintf("handleNextServeHTTP ip:%s isWaf:true %s", remoteIP, err.Error())) - handleBanServeHTTP(bouncer, rw) + handleBanServeHTTP(bouncer, rw, req.Method) return } } diff --git a/bouncer_test.go b/bouncer_test.go index 9499506..62ee7b5 100644 --- a/bouncer_test.go +++ b/bouncer_test.go @@ -186,3 +186,144 @@ func Test_crowdsecQuery(t *testing.T) { }) } } + +func TestHandleBanServeHTTPWithDifferentMethods(t *testing.T) { + tests := []struct { + name string + method string + banTemplateString string + expectBodyContent bool + }{ + { + name: "GET request should have body with template", + method: http.MethodGet, + banTemplateString: "You are banned", + expectBodyContent: true, + }, + { + name: "HEAD request should NOT have body even with template", + method: http.MethodHead, + banTemplateString: "You are banned", + expectBodyContent: false, + }, + { + name: "POST request should have body with template", + method: http.MethodPost, + banTemplateString: "You are banned", + expectBodyContent: true, + }, + { + name: "PUT request should have body with template", + method: http.MethodPut, + banTemplateString: "You are banned", + expectBodyContent: true, + }, + { + name: "DELETE request should have body with template", + method: http.MethodDelete, + banTemplateString: "You are banned", + expectBodyContent: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + bouncer := &Bouncer{ + remediationStatusCode: 403, + remediationCustomHeader: "X-Test-Remediation", + banTemplateString: tt.banTemplateString, + } + + rw := httptest.NewRecorder() + handleBanServeHTTP(bouncer, rw, tt.method) + + // Check status code + if rw.Code != 403 { + t.Errorf("Expected status code 403, got %d", rw.Code) + } + + // Check custom header + headerValue := rw.Header().Get("X-Test-Remediation") + if headerValue != "ban" { + t.Errorf("Expected header X-Test-Remediation to be 'ban', got %s", headerValue) + } + + // Check body content + body := rw.Body.String() + hasBodyContent := len(body) > 0 + + if hasBodyContent != tt.expectBodyContent { + t.Errorf("Method %s: expected body content: %v, got body content: %v (body: %q)", + tt.method, tt.expectBodyContent, hasBodyContent, body) + } + + // If we expect body content, verify it matches template + if tt.expectBodyContent && body != tt.banTemplateString { + t.Errorf("Expected body %q, got %q", tt.banTemplateString, body) + } + }) + } +} +func TestCaptchaMethodBasedLogic(t *testing.T) { + tests := []struct { + name string + method string + remediation string + expectBanFallback bool + }{ + { + name: "GET with captcha remediation should allow captcha", + method: http.MethodGet, + remediation: cache.CaptchaValue, + expectBanFallback: false, + }, + { + name: "HEAD with captcha remediation should fallback to ban", + method: http.MethodHead, + remediation: cache.CaptchaValue, + expectBanFallback: true, + }, + { + name: "POST with captcha remediation should allow captcha", + method: http.MethodPost, + remediation: cache.CaptchaValue, + expectBanFallback: false, + }, + { + name: "PUT with captcha remediation should allow captcha", + method: http.MethodPut, + remediation: cache.CaptchaValue, + expectBanFallback: false, + }, + { + name: "DELETE with captcha remediation should allow captcha", + method: http.MethodDelete, + remediation: cache.CaptchaValue, + expectBanFallback: false, + }, + { + name: "PATCH with captcha remediation should allow captcha", + method: http.MethodPatch, + remediation: cache.CaptchaValue, + expectBanFallback: false, + }, + { + name: "OPTIONS with captcha remediation should allow captcha", + method: http.MethodOptions, + remediation: cache.CaptchaValue, + expectBanFallback: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test the core logic: captcha is served for all methods except HEAD + shouldUseCaptcha := tt.remediation == cache.CaptchaValue && tt.method != http.MethodHead + + if shouldUseCaptcha == tt.expectBanFallback { + t.Errorf("Method %s with %s remediation: expected ban fallback %v, but logic would use captcha %v", + tt.method, tt.remediation, tt.expectBanFallback, shouldUseCaptcha) + } + }) + } +}