🐛Not write response body for HEAD requests (#277)

* Fixes

* XX

* Fix

* 🍱 Lint

* 🍱 remove useless comments

---------

Co-authored-by: Max Lerebourg <maxlerebourg@gmail.com>
This commit is contained in:
David
2025-10-06 11:19:19 +02:00
committed by GitHub
parent a2ecc95dc9
commit 65a2f79fb3
2 changed files with 155 additions and 9 deletions

View File

@@ -296,13 +296,13 @@ func (bouncer *Bouncer) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
remoteIP, err := ip.GetRemoteIP(req, bouncer.serverPoolStrategy, bouncer.forwardedCustomHeader)
if err != nil {
bouncer.log.Error(fmt.Sprintf("ServeHTTP:getRemoteIp ip:%s %s", remoteIP, err.Error()))
handleBanServeHTTP(bouncer, rw)
handleBanServeHTTP(bouncer, rw, req.Method)
return
}
isTrusted, err := bouncer.clientPoolStrategy.Checker.Contains(remoteIP)
if err != nil {
bouncer.log.Error(fmt.Sprintf("ServeHTTP:checkerContains ip:%s %s", remoteIP, err.Error()))
handleBanServeHTTP(bouncer, rw)
handleBanServeHTTP(bouncer, rw, req.Method)
return
}
// if our IP is in the trusted list we bypass the next checks
@@ -330,7 +330,7 @@ func (bouncer *Bouncer) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
if cacheErrString != cache.CacheMiss {
bouncer.log.Error(fmt.Sprintf("ServeHTTP:Get ip:%s %s", remoteIP, cacheErrString))
handleBanServeHTTP(bouncer, rw)
handleBanServeHTTP(bouncer, rw, req.Method)
return
}
} else {
@@ -350,7 +350,7 @@ func (bouncer *Bouncer) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
handleNextServeHTTP(bouncer, remoteIP, rw, req)
} else {
bouncer.log.Debug(fmt.Sprintf("ServeHTTP isCrowdsecStreamHealthy:false ip:%s updateFailure:%d", remoteIP, updateFailure))
handleBanServeHTTP(bouncer, rw)
handleBanServeHTTP(bouncer, rw, req.Method)
}
} else {
value, err := handleNoStreamCache(bouncer, remoteIP)
@@ -392,7 +392,7 @@ type Login struct {
}
// To append Headers we need to call rw.WriteHeader after set any header.
func handleBanServeHTTP(bouncer *Bouncer, rw http.ResponseWriter) {
func handleBanServeHTTP(bouncer *Bouncer, rw http.ResponseWriter, method string) {
atomic.AddInt64(&blockedRequests, 1)
if bouncer.remediationCustomHeader != "" {
@@ -404,15 +404,20 @@ func handleBanServeHTTP(bouncer *Bouncer, rw http.ResponseWriter) {
}
rw.Header().Set("Content-Type", "text/html; charset=utf-8")
rw.WriteHeader(bouncer.remediationStatusCode)
if method == http.MethodHead {
return
}
_, err := fmt.Fprint(rw, bouncer.banTemplateString)
if err != nil {
bouncer.log.Error("handleBanServeHTTP could not write template to ResponseWriter")
// use warn when https://github.com/maxlerebourg/crowdsec-bouncer-traefik-plugin/pull/276 is completed
bouncer.log.Error("handleBanServeHTTP could not write template to ResponseWriter: " + err.Error())
}
}
func handleRemediationServeHTTP(bouncer *Bouncer, remoteIP, remediation string, rw http.ResponseWriter, req *http.Request) {
bouncer.log.Debug(fmt.Sprintf("handleRemediationServeHTTP ip:%s remediation:%s", remoteIP, remediation))
if bouncer.captchaClient.Valid && remediation == cache.CaptchaValue {
if bouncer.captchaClient.Valid && remediation == cache.CaptchaValue && req.Method != http.MethodHead {
if bouncer.captchaClient.Check(remoteIP) {
handleNextServeHTTP(bouncer, remoteIP, rw, req)
return
@@ -421,14 +426,14 @@ func handleRemediationServeHTTP(bouncer *Bouncer, remoteIP, remediation string,
bouncer.captchaClient.ServeHTTP(rw, req, remoteIP)
return
}
handleBanServeHTTP(bouncer, rw)
handleBanServeHTTP(bouncer, rw, req.Method)
}
func handleNextServeHTTP(bouncer *Bouncer, remoteIP string, rw http.ResponseWriter, req *http.Request) {
if bouncer.appsecEnabled {
if err := appsecQuery(bouncer, remoteIP, req); err != nil {
bouncer.log.Debug(fmt.Sprintf("handleNextServeHTTP ip:%s isWaf:true %s", remoteIP, err.Error()))
handleBanServeHTTP(bouncer, rw)
handleBanServeHTTP(bouncer, rw, req.Method)
return
}
}

View File

@@ -186,3 +186,144 @@ func Test_crowdsecQuery(t *testing.T) {
})
}
}
func TestHandleBanServeHTTPWithDifferentMethods(t *testing.T) {
tests := []struct {
name string
method string
banTemplateString string
expectBodyContent bool
}{
{
name: "GET request should have body with template",
method: http.MethodGet,
banTemplateString: "<html>You are banned</html>",
expectBodyContent: true,
},
{
name: "HEAD request should NOT have body even with template",
method: http.MethodHead,
banTemplateString: "<html>You are banned</html>",
expectBodyContent: false,
},
{
name: "POST request should have body with template",
method: http.MethodPost,
banTemplateString: "<html>You are banned</html>",
expectBodyContent: true,
},
{
name: "PUT request should have body with template",
method: http.MethodPut,
banTemplateString: "<html>You are banned</html>",
expectBodyContent: true,
},
{
name: "DELETE request should have body with template",
method: http.MethodDelete,
banTemplateString: "<html>You are banned</html>",
expectBodyContent: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
bouncer := &Bouncer{
remediationStatusCode: 403,
remediationCustomHeader: "X-Test-Remediation",
banTemplateString: tt.banTemplateString,
}
rw := httptest.NewRecorder()
handleBanServeHTTP(bouncer, rw, tt.method)
// Check status code
if rw.Code != 403 {
t.Errorf("Expected status code 403, got %d", rw.Code)
}
// Check custom header
headerValue := rw.Header().Get("X-Test-Remediation")
if headerValue != "ban" {
t.Errorf("Expected header X-Test-Remediation to be 'ban', got %s", headerValue)
}
// Check body content
body := rw.Body.String()
hasBodyContent := len(body) > 0
if hasBodyContent != tt.expectBodyContent {
t.Errorf("Method %s: expected body content: %v, got body content: %v (body: %q)",
tt.method, tt.expectBodyContent, hasBodyContent, body)
}
// If we expect body content, verify it matches template
if tt.expectBodyContent && body != tt.banTemplateString {
t.Errorf("Expected body %q, got %q", tt.banTemplateString, body)
}
})
}
}
func TestCaptchaMethodBasedLogic(t *testing.T) {
tests := []struct {
name string
method string
remediation string
expectBanFallback bool
}{
{
name: "GET with captcha remediation should allow captcha",
method: http.MethodGet,
remediation: cache.CaptchaValue,
expectBanFallback: false,
},
{
name: "HEAD with captcha remediation should fallback to ban",
method: http.MethodHead,
remediation: cache.CaptchaValue,
expectBanFallback: true,
},
{
name: "POST with captcha remediation should allow captcha",
method: http.MethodPost,
remediation: cache.CaptchaValue,
expectBanFallback: false,
},
{
name: "PUT with captcha remediation should allow captcha",
method: http.MethodPut,
remediation: cache.CaptchaValue,
expectBanFallback: false,
},
{
name: "DELETE with captcha remediation should allow captcha",
method: http.MethodDelete,
remediation: cache.CaptchaValue,
expectBanFallback: false,
},
{
name: "PATCH with captcha remediation should allow captcha",
method: http.MethodPatch,
remediation: cache.CaptchaValue,
expectBanFallback: false,
},
{
name: "OPTIONS with captcha remediation should allow captcha",
method: http.MethodOptions,
remediation: cache.CaptchaValue,
expectBanFallback: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Test the core logic: captcha is served for all methods except HEAD
shouldUseCaptcha := tt.remediation == cache.CaptchaValue && tt.method != http.MethodHead
if shouldUseCaptcha == tt.expectBanFallback {
t.Errorf("Method %s with %s remediation: expected ban fallback %v, but logic would use captcha %v",
tt.method, tt.remediation, tt.expectBanFallback, shouldUseCaptcha)
}
})
}
}