forked from casdoor/casdoor
Compare commits
5 Commits
v2.335.1
...
copilot/ad
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d6ab3a18ea | ||
|
|
b7fa3a6194 | ||
|
|
ac1307e576 | ||
|
|
7f2c238eba | ||
|
|
2cf83d3b0c |
@@ -30,6 +30,8 @@ ldapsServerPort = 636
|
||||
radiusServerPort = 1812
|
||||
radiusDefaultOrganization = "built-in"
|
||||
radiusSecret = "secret"
|
||||
proxyHttpPort =
|
||||
proxyHttpsPort =
|
||||
quota = {"organization": -1, "user": -1, "application": -1, "provider": -1}
|
||||
logConfig = {"adapter":"file", "filename": "logs/casdoor.log", "maxdays":99999, "perm":"0770"}
|
||||
initDataNewOnly = false
|
||||
|
||||
2
main.go
2
main.go
@@ -72,6 +72,7 @@ func main() {
|
||||
object.InitFromFile()
|
||||
object.InitCasvisorConfig()
|
||||
object.InitCleanupTokens()
|
||||
object.InitApplicationMap()
|
||||
|
||||
util.SafeGoroutine(func() { object.RunSyncUsersJob() })
|
||||
util.SafeGoroutine(func() { controllers.InitCLIDownloader() })
|
||||
@@ -125,6 +126,7 @@ func main() {
|
||||
go ldap.StartLdapServer()
|
||||
go radius.StartRadiusServer()
|
||||
go object.ClearThroughputPerSecond()
|
||||
go proxy.StartProxyServer()
|
||||
|
||||
web.Run(fmt.Sprintf(":%v", port))
|
||||
}
|
||||
|
||||
@@ -173,6 +173,16 @@ func GetOrganizationApplicationCount(owner, organization, field, value string) (
|
||||
return session.Where("organization = ? or is_shared = ? ", organization, true).Count(&Application{})
|
||||
}
|
||||
|
||||
func GetGlobalApplications() ([]*Application, error) {
|
||||
applications := []*Application{}
|
||||
err := ormer.Engine.Desc("created_time").Find(&applications)
|
||||
if err != nil {
|
||||
return applications, err
|
||||
}
|
||||
|
||||
return applications, nil
|
||||
}
|
||||
|
||||
func GetApplications(owner string) ([]*Application, error) {
|
||||
applications := []*Application{}
|
||||
err := ormer.Engine.Desc("created_time").Find(&applications, &Application{Owner: owner})
|
||||
@@ -758,6 +768,12 @@ func UpdateApplication(id string, application *Application, isGlobalAdmin bool,
|
||||
return false, err
|
||||
}
|
||||
|
||||
if affected != 0 {
|
||||
if err := RefreshApplicationCache(); err != nil {
|
||||
fmt.Printf("Failed to refresh application cache after update: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
return affected != 0, nil
|
||||
}
|
||||
|
||||
@@ -809,6 +825,12 @@ func AddApplication(application *Application) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if affected != 0 {
|
||||
if err := RefreshApplicationCache(); err != nil {
|
||||
fmt.Printf("Failed to refresh application cache after add: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
return affected != 0, nil
|
||||
}
|
||||
|
||||
@@ -818,6 +840,12 @@ func deleteApplication(application *Application) (bool, error) {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if affected != 0 {
|
||||
if err := RefreshApplicationCache(); err != nil {
|
||||
fmt.Printf("Failed to refresh application cache after delete: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
return affected != 0, nil
|
||||
}
|
||||
|
||||
|
||||
85
object/application_cache.go
Normal file
85
object/application_cache.go
Normal file
@@ -0,0 +1,85 @@
|
||||
// Copyright 2021 The Casdoor Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package object
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/casdoor/casdoor/proxy"
|
||||
)
|
||||
|
||||
var (
|
||||
applicationMap = make(map[string]*Application)
|
||||
applicationMapMutex sync.RWMutex
|
||||
)
|
||||
|
||||
func InitApplicationMap() error {
|
||||
// Set up the application lookup function for the proxy package
|
||||
proxy.SetApplicationLookup(func(domain string) *proxy.Application {
|
||||
app := GetApplicationByDomain(domain)
|
||||
if app == nil {
|
||||
return nil
|
||||
}
|
||||
return &proxy.Application{
|
||||
Owner: app.Owner,
|
||||
Name: app.Name,
|
||||
UpstreamHost: app.UpstreamHost,
|
||||
}
|
||||
})
|
||||
|
||||
return refreshApplicationMap()
|
||||
}
|
||||
|
||||
func refreshApplicationMap() error {
|
||||
applications, err := GetGlobalApplications()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get global applications: %w", err)
|
||||
}
|
||||
|
||||
newApplicationMap := make(map[string]*Application)
|
||||
for _, app := range applications {
|
||||
if app.Domain != "" {
|
||||
newApplicationMap[strings.ToLower(app.Domain)] = app
|
||||
}
|
||||
for _, domain := range app.OtherDomains {
|
||||
if domain != "" {
|
||||
newApplicationMap[strings.ToLower(domain)] = app
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
applicationMapMutex.Lock()
|
||||
applicationMap = newApplicationMap
|
||||
applicationMapMutex.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetApplicationByDomain(domain string) *Application {
|
||||
applicationMapMutex.RLock()
|
||||
defer applicationMapMutex.RUnlock()
|
||||
|
||||
domain = strings.ToLower(domain)
|
||||
if app, ok := applicationMap[domain]; ok {
|
||||
return app
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func RefreshApplicationCache() error {
|
||||
return refreshApplicationMap()
|
||||
}
|
||||
229
proxy/reverse_proxy.go
Normal file
229
proxy/reverse_proxy.go
Normal file
@@ -0,0 +1,229 @@
|
||||
// Copyright 2021 The Casdoor Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/beego/beego/v2/core/logs"
|
||||
"github.com/casdoor/casdoor/conf"
|
||||
)
|
||||
|
||||
// Application represents a simplified application structure for reverse proxy
|
||||
type Application struct {
|
||||
Owner string
|
||||
Name string
|
||||
UpstreamHost string
|
||||
}
|
||||
|
||||
// ApplicationLookupFunc is a function type for looking up applications by domain
|
||||
type ApplicationLookupFunc func(domain string) *Application
|
||||
|
||||
var applicationLookup ApplicationLookupFunc
|
||||
|
||||
// SetApplicationLookup sets the function to use for looking up applications by domain
|
||||
func SetApplicationLookup(lookupFunc ApplicationLookupFunc) {
|
||||
applicationLookup = lookupFunc
|
||||
}
|
||||
|
||||
// getDomainWithoutPort removes the port from a domain string
|
||||
func getDomainWithoutPort(domain string) string {
|
||||
if !strings.Contains(domain, ":") {
|
||||
return domain
|
||||
}
|
||||
|
||||
tokens := strings.SplitN(domain, ":", 2)
|
||||
if len(tokens) > 1 {
|
||||
return tokens[0]
|
||||
}
|
||||
return domain
|
||||
}
|
||||
|
||||
// forwardHandler creates and configures a reverse proxy for the given target URL
|
||||
func forwardHandler(targetUrl string, writer http.ResponseWriter, request *http.Request) {
|
||||
target, err := url.Parse(targetUrl)
|
||||
if err != nil {
|
||||
logs.Error("Failed to parse target URL %s: %v", targetUrl, err)
|
||||
http.Error(writer, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
proxy := httputil.NewSingleHostReverseProxy(target)
|
||||
|
||||
// Configure the Director to set proper headers
|
||||
proxy.Director = func(r *http.Request) {
|
||||
r.URL.Scheme = target.Scheme
|
||||
r.URL.Host = target.Host
|
||||
r.Host = target.Host
|
||||
|
||||
// Set X-Real-IP and X-Forwarded-For headers
|
||||
if clientIP, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
r.Header.Set("X-Forwarded-For", fmt.Sprintf("%s, %s", xff, clientIP))
|
||||
} else {
|
||||
r.Header.Set("X-Forwarded-For", clientIP)
|
||||
}
|
||||
r.Header.Set("X-Real-IP", clientIP)
|
||||
}
|
||||
|
||||
// Set X-Forwarded-Proto header
|
||||
if r.TLS != nil {
|
||||
r.Header.Set("X-Forwarded-Proto", "https")
|
||||
} else {
|
||||
r.Header.Set("X-Forwarded-Proto", "http")
|
||||
}
|
||||
|
||||
// Set X-Forwarded-Host header
|
||||
r.Header.Set("X-Forwarded-Host", request.Host)
|
||||
}
|
||||
|
||||
// Handle ModifyResponse for security enhancements
|
||||
proxy.ModifyResponse = func(resp *http.Response) error {
|
||||
// Add Secure flag to all Set-Cookie headers in HTTPS responses
|
||||
if request.TLS != nil {
|
||||
// Add HSTS header for HTTPS responses if not already set by backend
|
||||
if resp.Header.Get("Strict-Transport-Security") == "" {
|
||||
resp.Header.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
|
||||
}
|
||||
|
||||
cookies := resp.Header["Set-Cookie"]
|
||||
if len(cookies) > 0 {
|
||||
// Clear existing Set-Cookie headers
|
||||
resp.Header.Del("Set-Cookie")
|
||||
// Add them back with Secure flag if not already present
|
||||
for _, cookie := range cookies {
|
||||
// Check if Secure attribute is already present (case-insensitive)
|
||||
cookieLower := strings.ToLower(cookie)
|
||||
hasSecure := strings.Contains(cookieLower, ";secure;") ||
|
||||
strings.Contains(cookieLower, "; secure;") ||
|
||||
strings.HasSuffix(cookieLower, ";secure") ||
|
||||
strings.HasSuffix(cookieLower, "; secure")
|
||||
if !hasSecure {
|
||||
cookie = cookie + "; Secure"
|
||||
}
|
||||
resp.Header.Add("Set-Cookie", cookie)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
proxy.ServeHTTP(writer, request)
|
||||
}
|
||||
|
||||
// HandleReverseProxy handles incoming requests and forwards them to the appropriate upstream
|
||||
func HandleReverseProxy(w http.ResponseWriter, r *http.Request) {
|
||||
domain := getDomainWithoutPort(r.Host)
|
||||
|
||||
if applicationLookup == nil {
|
||||
logs.Error("Application lookup function not set")
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Lookup the application by domain
|
||||
app := applicationLookup(domain)
|
||||
if app == nil {
|
||||
logs.Info("No application found for domain: %s", domain)
|
||||
http.Error(w, "Not Found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the application has an upstream host configured
|
||||
if app.UpstreamHost == "" {
|
||||
logs.Warn("Application %s/%s has no upstream host configured", app.Owner, app.Name)
|
||||
http.Error(w, "Not Found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Build the target URL - just use the upstream host, the actual path/query will be set by the proxy Director
|
||||
targetUrl := app.UpstreamHost
|
||||
if !strings.HasPrefix(targetUrl, "http://") && !strings.HasPrefix(targetUrl, "https://") {
|
||||
targetUrl = "http://" + targetUrl
|
||||
}
|
||||
|
||||
logs.Debug("Forwarding request from %s%s to %s", r.Host, r.RequestURI, targetUrl)
|
||||
forwardHandler(targetUrl, w, r)
|
||||
}
|
||||
|
||||
// StartProxyServer starts the HTTP and HTTPS proxy servers based on configuration
|
||||
func StartProxyServer() {
|
||||
proxyHttpPort := conf.GetConfigString("proxyHttpPort")
|
||||
proxyHttpsPort := conf.GetConfigString("proxyHttpsPort")
|
||||
|
||||
if proxyHttpPort == "" && proxyHttpsPort == "" {
|
||||
logs.Info("Reverse proxy not enabled (proxyHttpPort and proxyHttpsPort are empty)")
|
||||
return
|
||||
}
|
||||
|
||||
serverMux := http.NewServeMux()
|
||||
serverMux.HandleFunc("/", HandleReverseProxy)
|
||||
|
||||
// Start HTTP proxy if configured
|
||||
if proxyHttpPort != "" {
|
||||
go func() {
|
||||
addr := fmt.Sprintf(":%s", proxyHttpPort)
|
||||
logs.Info("Starting reverse proxy HTTP server on %s", addr)
|
||||
err := http.ListenAndServe(addr, serverMux)
|
||||
if err != nil {
|
||||
logs.Error("Failed to start HTTP proxy server: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Start HTTPS proxy if configured
|
||||
if proxyHttpsPort != "" {
|
||||
go func() {
|
||||
addr := fmt.Sprintf(":%s", proxyHttpsPort)
|
||||
|
||||
// For now, HTTPS will need certificate configuration
|
||||
// This can be enhanced later to use Application's SslCert field
|
||||
logs.Info("HTTPS proxy server on %s requires certificate configuration - not implemented yet", addr)
|
||||
|
||||
// When implemented, use code like:
|
||||
// server := &http.Server{
|
||||
// Handler: serverMux,
|
||||
// Addr: addr,
|
||||
// TLSConfig: &tls.Config{
|
||||
// MinVersion: tls.VersionTLS12,
|
||||
// PreferServerCipherSuites: true,
|
||||
// CipherSuites: []uint16{
|
||||
// tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
// tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
// tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
// tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
// tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
|
||||
// tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
|
||||
// },
|
||||
// CurvePreferences: []tls.CurveID{
|
||||
// tls.X25519,
|
||||
// tls.CurveP256,
|
||||
// tls.CurveP384,
|
||||
// },
|
||||
// },
|
||||
// }
|
||||
// err := server.ListenAndServeTLS("", "")
|
||||
// if err != nil {
|
||||
// logs.Error("Failed to start HTTPS proxy server: %v", err)
|
||||
// }
|
||||
}()
|
||||
}
|
||||
}
|
||||
210
proxy/reverse_proxy_integration_test.go
Normal file
210
proxy/reverse_proxy_integration_test.go
Normal file
@@ -0,0 +1,210 @@
|
||||
// Copyright 2021 The Casdoor Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestReverseProxyIntegration tests the reverse proxy with a real backend server
|
||||
func TestReverseProxyIntegration(t *testing.T) {
|
||||
// Create a test backend server that echoes the request path
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Verify headers
|
||||
headers := []string{
|
||||
"X-Forwarded-For",
|
||||
"X-Forwarded-Proto",
|
||||
"X-Real-IP",
|
||||
"X-Forwarded-Host",
|
||||
}
|
||||
|
||||
for _, header := range headers {
|
||||
if r.Header.Get(header) == "" {
|
||||
t.Errorf("Expected header %s to be set", header)
|
||||
}
|
||||
}
|
||||
|
||||
// Echo the path and query
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("Path: " + r.URL.Path + "\n"))
|
||||
w.Write([]byte("Query: " + r.URL.RawQuery + "\n"))
|
||||
w.Write([]byte("Host: " + r.Host + "\n"))
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
// Set up the application lookup
|
||||
SetApplicationLookup(func(domain string) *Application {
|
||||
if domain == "myapp.example.com" {
|
||||
return &Application{
|
||||
Owner: "test-owner",
|
||||
Name: "my-app",
|
||||
UpstreamHost: backend.URL,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Test various request paths
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
query string
|
||||
expected string
|
||||
}{
|
||||
{"Simple path", "/", "", "Path: /\n"},
|
||||
{"Path with segments", "/api/v1/users", "", "Path: /api/v1/users\n"},
|
||||
{"Path with query", "/search", "q=test&limit=10", "Query: q=test&limit=10\n"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
url := "http://myapp.example.com" + tt.path
|
||||
if tt.query != "" {
|
||||
url += "?" + tt.query
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", url, nil)
|
||||
req.Host = "myapp.example.com"
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
HandleReverseProxy(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(w.Body)
|
||||
bodyStr := string(body)
|
||||
|
||||
if !strings.Contains(bodyStr, tt.expected) {
|
||||
t.Errorf("Expected response to contain %q, got %q", tt.expected, bodyStr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestReverseProxyWebSocket tests that WebSocket upgrade headers are preserved
|
||||
func TestReverseProxyWebSocket(t *testing.T) {
|
||||
// Note: WebSocket upgrade through httptest.ResponseRecorder has limitations
|
||||
// This test verifies that WebSocket headers are passed through, but
|
||||
// full WebSocket functionality would need integration testing with real servers
|
||||
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Verify WebSocket headers are present
|
||||
if r.Header.Get("Upgrade") == "websocket" &&
|
||||
r.Header.Get("Connection") != "" &&
|
||||
r.Header.Get("Sec-WebSocket-Version") != "" &&
|
||||
r.Header.Get("Sec-WebSocket-Key") != "" {
|
||||
// Headers are present - this is what we're testing
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("WebSocket headers received"))
|
||||
} else {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte("Missing WebSocket headers"))
|
||||
}
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
SetApplicationLookup(func(domain string) *Application {
|
||||
if domain == "ws.example.com" {
|
||||
return &Application{
|
||||
Owner: "test-owner",
|
||||
Name: "ws-app",
|
||||
UpstreamHost: backend.URL,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "http://ws.example.com/ws", nil)
|
||||
req.Host = "ws.example.com"
|
||||
req.Header.Set("Upgrade", "websocket")
|
||||
req.Header.Set("Connection", "Upgrade")
|
||||
req.Header.Set("Sec-WebSocket-Version", "13")
|
||||
req.Header.Set("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
HandleReverseProxy(w, req)
|
||||
|
||||
body, _ := io.ReadAll(w.Body)
|
||||
bodyStr := string(body)
|
||||
|
||||
// We expect the headers to be passed through to the backend
|
||||
if !strings.Contains(bodyStr, "WebSocket headers received") {
|
||||
t.Errorf("WebSocket headers were not properly forwarded. Got: %s", bodyStr)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReverseProxyUpstreamHostVariations tests different UpstreamHost formats
|
||||
func TestReverseProxyUpstreamHostVariations(t *testing.T) {
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
// Parse backend URL to get host
|
||||
backendURL, err := url.Parse(backend.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse backend URL: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
upstreamHost string
|
||||
shouldWork bool
|
||||
}{
|
||||
{"Full URL", backend.URL, true},
|
||||
{"Host only", backendURL.Host, true},
|
||||
{"Empty", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
SetApplicationLookup(func(domain string) *Application {
|
||||
if domain == "test.example.com" {
|
||||
return &Application{
|
||||
Owner: "test-owner",
|
||||
Name: "test-app",
|
||||
UpstreamHost: tt.upstreamHost,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "http://test.example.com/", nil)
|
||||
req.Host = "test.example.com"
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
HandleReverseProxy(w, req)
|
||||
|
||||
if tt.shouldWork {
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
} else {
|
||||
if w.Code == http.StatusOK {
|
||||
t.Errorf("Expected failure, but got status 200")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
148
proxy/reverse_proxy_test.go
Normal file
148
proxy/reverse_proxy_test.go
Normal file
@@ -0,0 +1,148 @@
|
||||
// Copyright 2021 The Casdoor Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetDomainWithoutPort(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"example.com", "example.com"},
|
||||
{"example.com:8080", "example.com"},
|
||||
{"localhost:3000", "localhost"},
|
||||
{"subdomain.example.com:443", "subdomain.example.com"},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
result := getDomainWithoutPort(test.input)
|
||||
if result != test.expected {
|
||||
t.Errorf("getDomainWithoutPort(%s) = %s; want %s", test.input, result, test.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleReverseProxy(t *testing.T) {
|
||||
// Create a test backend server
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Check that headers are set correctly
|
||||
if r.Header.Get("X-Forwarded-For") == "" {
|
||||
t.Error("X-Forwarded-For header not set")
|
||||
}
|
||||
if r.Header.Get("X-Forwarded-Proto") == "" {
|
||||
t.Error("X-Forwarded-Proto header not set")
|
||||
}
|
||||
if r.Header.Get("X-Real-IP") == "" {
|
||||
t.Error("X-Real-IP header not set")
|
||||
}
|
||||
if r.Header.Get("X-Forwarded-Host") == "" {
|
||||
t.Error("X-Forwarded-Host header not set")
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, "Backend response")
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
// Set up a mock application lookup function
|
||||
SetApplicationLookup(func(domain string) *Application {
|
||||
if domain == "test.example.com" {
|
||||
return &Application{
|
||||
Owner: "test-owner",
|
||||
Name: "test-app",
|
||||
UpstreamHost: backend.URL,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Test successful proxy
|
||||
req := httptest.NewRequest("GET", "http://test.example.com/path", nil)
|
||||
req.Host = "test.example.com"
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
HandleReverseProxy(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
// Test domain not found
|
||||
req = httptest.NewRequest("GET", "http://unknown.example.com/path", nil)
|
||||
req.Host = "unknown.example.com"
|
||||
w = httptest.NewRecorder()
|
||||
|
||||
HandleReverseProxy(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("Expected status 404 for unknown domain, got %d", w.Code)
|
||||
}
|
||||
|
||||
// Test application without upstream host
|
||||
SetApplicationLookup(func(domain string) *Application {
|
||||
if domain == "no-upstream.example.com" {
|
||||
return &Application{
|
||||
Owner: "test-owner",
|
||||
Name: "test-app-no-upstream",
|
||||
UpstreamHost: "",
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
req = httptest.NewRequest("GET", "http://no-upstream.example.com/path", nil)
|
||||
req.Host = "no-upstream.example.com"
|
||||
w = httptest.NewRecorder()
|
||||
|
||||
HandleReverseProxy(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("Expected status 404 for app without upstream, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplicationLookup(t *testing.T) {
|
||||
// Test setting and using the application lookup function
|
||||
called := false
|
||||
SetApplicationLookup(func(domain string) *Application {
|
||||
called = true
|
||||
return &Application{
|
||||
Owner: "test",
|
||||
Name: "app",
|
||||
UpstreamHost: "http://localhost:8080",
|
||||
}
|
||||
})
|
||||
|
||||
if applicationLookup == nil {
|
||||
t.Error("applicationLookup should not be nil after SetApplicationLookup")
|
||||
}
|
||||
|
||||
app := applicationLookup("test.com")
|
||||
if !called {
|
||||
t.Error("applicationLookup function was not called")
|
||||
}
|
||||
if app == nil {
|
||||
t.Error("applicationLookup should return non-nil application")
|
||||
}
|
||||
if app.Owner != "test" {
|
||||
t.Errorf("Expected owner 'test', got '%s'", app.Owner)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user