forked from casdoor/casdoor
Compare commits
5 Commits
v2.323.0
...
copilot/ad
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d6ab3a18ea | ||
|
|
b7fa3a6194 | ||
|
|
ac1307e576 | ||
|
|
7f2c238eba | ||
|
|
2cf83d3b0c |
@@ -1,5 +1,5 @@
|
||||
<h1 align="center" style="border-bottom: none;">📦⚡️ Casdoor</h1>
|
||||
<h3 align="center">An open-source AI-first Identity and Access Management (IAM) /AI MCP gateway and auth server with web UI supporting MCP, A2A, OAuth 2.1, OIDC, SAML, CAS, LDAP, SCIM, WebAuthn, TOTP, MFA, Face ID, Google Workspace, Azure AD</h3>
|
||||
<h3 align="center">An open-source UI-first Identity and Access Management (IAM) / Single-Sign-On (SSO) platform with web UI supporting OAuth 2.0, OIDC, SAML, CAS, LDAP, SCIM, WebAuthn, TOTP, MFA and RADIUS</h3>
|
||||
<p align="center">
|
||||
<a href="#badge">
|
||||
<img alt="semantic-release" src="https://img.shields.io/badge/%20%20%F0%9F%93%A6%F0%9F%9A%80-semantic--release-e10079.svg">
|
||||
|
||||
@@ -30,6 +30,8 @@ ldapsServerPort = 636
|
||||
radiusServerPort = 1812
|
||||
radiusDefaultOrganization = "built-in"
|
||||
radiusSecret = "secret"
|
||||
proxyHttpPort =
|
||||
proxyHttpsPort =
|
||||
quota = {"organization": -1, "user": -1, "application": -1, "provider": -1}
|
||||
logConfig = {"adapter":"file", "filename": "logs/casdoor.log", "maxdays":99999, "perm":"0770"}
|
||||
initDataNewOnly = false
|
||||
|
||||
@@ -739,11 +739,7 @@ func (c *ApiController) Login() {
|
||||
}
|
||||
} else if provider.Category == "OAuth" || provider.Category == "Web3" {
|
||||
// OAuth
|
||||
idpInfo, err := object.FromProviderToIdpInfo(c.Ctx, provider)
|
||||
if err != nil {
|
||||
c.ResponseError(err.Error())
|
||||
return
|
||||
}
|
||||
idpInfo := object.FromProviderToIdpInfo(c.Ctx, provider)
|
||||
idpInfo.CodeVerifier = authForm.CodeVerifier
|
||||
var idProvider idp.IdProvider
|
||||
idProvider, err = idp.GetIdProvider(idpInfo, authForm.RedirectUri)
|
||||
|
||||
@@ -264,31 +264,27 @@ func rsaSignWithRSA256(signContent string, privateKey string) (string, error) {
|
||||
|
||||
// privateKey in database is a string, format it to PEM style
|
||||
func formatPrivateKey(privateKey string) string {
|
||||
// Check if the key is already in PEM format
|
||||
if strings.HasPrefix(privateKey, "-----BEGIN PRIVATE KEY-----") ||
|
||||
strings.HasPrefix(privateKey, "-----BEGIN RSA PRIVATE KEY-----") {
|
||||
// Key is already in PEM format, return as is
|
||||
return privateKey
|
||||
}
|
||||
|
||||
// Remove any whitespace from the key
|
||||
privateKey = strings.ReplaceAll(privateKey, "\n", "")
|
||||
privateKey = strings.ReplaceAll(privateKey, "\r", "")
|
||||
privateKey = strings.ReplaceAll(privateKey, " ", "")
|
||||
|
||||
// Format the key with line breaks every 64 characters using strings.Builder
|
||||
var builder strings.Builder
|
||||
for i := 0; i < len(privateKey); i += 64 {
|
||||
end := i + 64
|
||||
if end > len(privateKey) {
|
||||
end = len(privateKey)
|
||||
}
|
||||
builder.WriteString(privateKey[i:end])
|
||||
if end < len(privateKey) {
|
||||
builder.WriteString("\n")
|
||||
// each line length is 64
|
||||
preFmtPrivateKey := ""
|
||||
for i := 0; ; {
|
||||
if i+64 <= len(privateKey) {
|
||||
preFmtPrivateKey = preFmtPrivateKey + privateKey[i:i+64] + "\n"
|
||||
i += 64
|
||||
} else {
|
||||
preFmtPrivateKey = preFmtPrivateKey + privateKey[i:]
|
||||
break
|
||||
}
|
||||
}
|
||||
privateKey = strings.Trim(preFmtPrivateKey, "\n")
|
||||
|
||||
// add pkcs#8 BEGIN and END
|
||||
return "-----BEGIN PRIVATE KEY-----\n" + builder.String() + "\n-----END PRIVATE KEY-----"
|
||||
PemBegin := "-----BEGIN PRIVATE KEY-----\n"
|
||||
PemEnd := "\n-----END PRIVATE KEY-----"
|
||||
if !strings.HasPrefix(privateKey, PemBegin) {
|
||||
privateKey = PemBegin + privateKey
|
||||
}
|
||||
if !strings.HasSuffix(privateKey, PemEnd) {
|
||||
privateKey = privateKey + PemEnd
|
||||
}
|
||||
return privateKey
|
||||
}
|
||||
|
||||
2
main.go
2
main.go
@@ -72,6 +72,7 @@ func main() {
|
||||
object.InitFromFile()
|
||||
object.InitCasvisorConfig()
|
||||
object.InitCleanupTokens()
|
||||
object.InitApplicationMap()
|
||||
|
||||
util.SafeGoroutine(func() { object.RunSyncUsersJob() })
|
||||
util.SafeGoroutine(func() { controllers.InitCLIDownloader() })
|
||||
@@ -125,6 +126,7 @@ func main() {
|
||||
go ldap.StartLdapServer()
|
||||
go radius.StartRadiusServer()
|
||||
go object.ClearThroughputPerSecond()
|
||||
go proxy.StartProxyServer()
|
||||
|
||||
web.Run(fmt.Sprintf(":%v", port))
|
||||
}
|
||||
|
||||
@@ -173,6 +173,16 @@ func GetOrganizationApplicationCount(owner, organization, field, value string) (
|
||||
return session.Where("organization = ? or is_shared = ? ", organization, true).Count(&Application{})
|
||||
}
|
||||
|
||||
func GetGlobalApplications() ([]*Application, error) {
|
||||
applications := []*Application{}
|
||||
err := ormer.Engine.Desc("created_time").Find(&applications)
|
||||
if err != nil {
|
||||
return applications, err
|
||||
}
|
||||
|
||||
return applications, nil
|
||||
}
|
||||
|
||||
func GetApplications(owner string) ([]*Application, error) {
|
||||
applications := []*Application{}
|
||||
err := ormer.Engine.Desc("created_time").Find(&applications, &Application{Owner: owner})
|
||||
@@ -758,6 +768,12 @@ func UpdateApplication(id string, application *Application, isGlobalAdmin bool,
|
||||
return false, err
|
||||
}
|
||||
|
||||
if affected != 0 {
|
||||
if err := RefreshApplicationCache(); err != nil {
|
||||
fmt.Printf("Failed to refresh application cache after update: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
return affected != 0, nil
|
||||
}
|
||||
|
||||
@@ -809,6 +825,12 @@ func AddApplication(application *Application) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if affected != 0 {
|
||||
if err := RefreshApplicationCache(); err != nil {
|
||||
fmt.Printf("Failed to refresh application cache after add: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
return affected != 0, nil
|
||||
}
|
||||
|
||||
@@ -818,6 +840,12 @@ func deleteApplication(application *Application) (bool, error) {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if affected != 0 {
|
||||
if err := RefreshApplicationCache(); err != nil {
|
||||
fmt.Printf("Failed to refresh application cache after delete: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
return affected != 0, nil
|
||||
}
|
||||
|
||||
|
||||
85
object/application_cache.go
Normal file
85
object/application_cache.go
Normal file
@@ -0,0 +1,85 @@
|
||||
// Copyright 2021 The Casdoor Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package object
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/casdoor/casdoor/proxy"
|
||||
)
|
||||
|
||||
var (
|
||||
applicationMap = make(map[string]*Application)
|
||||
applicationMapMutex sync.RWMutex
|
||||
)
|
||||
|
||||
func InitApplicationMap() error {
|
||||
// Set up the application lookup function for the proxy package
|
||||
proxy.SetApplicationLookup(func(domain string) *proxy.Application {
|
||||
app := GetApplicationByDomain(domain)
|
||||
if app == nil {
|
||||
return nil
|
||||
}
|
||||
return &proxy.Application{
|
||||
Owner: app.Owner,
|
||||
Name: app.Name,
|
||||
UpstreamHost: app.UpstreamHost,
|
||||
}
|
||||
})
|
||||
|
||||
return refreshApplicationMap()
|
||||
}
|
||||
|
||||
func refreshApplicationMap() error {
|
||||
applications, err := GetGlobalApplications()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get global applications: %w", err)
|
||||
}
|
||||
|
||||
newApplicationMap := make(map[string]*Application)
|
||||
for _, app := range applications {
|
||||
if app.Domain != "" {
|
||||
newApplicationMap[strings.ToLower(app.Domain)] = app
|
||||
}
|
||||
for _, domain := range app.OtherDomains {
|
||||
if domain != "" {
|
||||
newApplicationMap[strings.ToLower(domain)] = app
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
applicationMapMutex.Lock()
|
||||
applicationMap = newApplicationMap
|
||||
applicationMapMutex.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetApplicationByDomain(domain string) *Application {
|
||||
applicationMapMutex.RLock()
|
||||
defer applicationMapMutex.RUnlock()
|
||||
|
||||
domain = strings.ToLower(domain)
|
||||
if app, ok := applicationMap[domain]; ok {
|
||||
return app
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func RefreshApplicationCache() error {
|
||||
return refreshApplicationMap()
|
||||
}
|
||||
@@ -564,7 +564,7 @@ func providerChangeTrigger(oldName string, newName string) error {
|
||||
return session.Commit()
|
||||
}
|
||||
|
||||
func FromProviderToIdpInfo(ctx *context.Context, provider *Provider) (*idp.ProviderInfo, error) {
|
||||
func FromProviderToIdpInfo(ctx *context.Context, provider *Provider) *idp.ProviderInfo {
|
||||
providerInfo := &idp.ProviderInfo{
|
||||
Type: provider.Type,
|
||||
SubType: provider.SubType,
|
||||
@@ -588,19 +588,9 @@ func FromProviderToIdpInfo(ctx *context.Context, provider *Provider) (*idp.Provi
|
||||
}
|
||||
} else if provider.Type == "ADFS" || provider.Type == "AzureAD" || provider.Type == "AzureADB2C" || provider.Type == "Casdoor" || provider.Type == "Okta" {
|
||||
providerInfo.HostUrl = provider.Domain
|
||||
} else if provider.Type == "Alipay" && provider.Cert != "" {
|
||||
// For Alipay with certificate mode, load private key from certificate
|
||||
cert, err := GetCert(util.GetId(provider.Owner, provider.Cert))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load certificate for Alipay provider %s: %w", provider.Name, err)
|
||||
}
|
||||
if cert == nil {
|
||||
return nil, fmt.Errorf("certificate not found for Alipay provider %s", provider.Name)
|
||||
}
|
||||
providerInfo.ClientSecret = cert.PrivateKey
|
||||
}
|
||||
|
||||
return providerInfo, nil
|
||||
return providerInfo
|
||||
}
|
||||
|
||||
func GetIdvProviderFromProvider(provider *Provider) idv.IdvProvider {
|
||||
|
||||
229
proxy/reverse_proxy.go
Normal file
229
proxy/reverse_proxy.go
Normal file
@@ -0,0 +1,229 @@
|
||||
// Copyright 2021 The Casdoor Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/beego/beego/v2/core/logs"
|
||||
"github.com/casdoor/casdoor/conf"
|
||||
)
|
||||
|
||||
// Application represents a simplified application structure for reverse proxy
|
||||
type Application struct {
|
||||
Owner string
|
||||
Name string
|
||||
UpstreamHost string
|
||||
}
|
||||
|
||||
// ApplicationLookupFunc is a function type for looking up applications by domain
|
||||
type ApplicationLookupFunc func(domain string) *Application
|
||||
|
||||
var applicationLookup ApplicationLookupFunc
|
||||
|
||||
// SetApplicationLookup sets the function to use for looking up applications by domain
|
||||
func SetApplicationLookup(lookupFunc ApplicationLookupFunc) {
|
||||
applicationLookup = lookupFunc
|
||||
}
|
||||
|
||||
// getDomainWithoutPort removes the port from a domain string
|
||||
func getDomainWithoutPort(domain string) string {
|
||||
if !strings.Contains(domain, ":") {
|
||||
return domain
|
||||
}
|
||||
|
||||
tokens := strings.SplitN(domain, ":", 2)
|
||||
if len(tokens) > 1 {
|
||||
return tokens[0]
|
||||
}
|
||||
return domain
|
||||
}
|
||||
|
||||
// forwardHandler creates and configures a reverse proxy for the given target URL
|
||||
func forwardHandler(targetUrl string, writer http.ResponseWriter, request *http.Request) {
|
||||
target, err := url.Parse(targetUrl)
|
||||
if err != nil {
|
||||
logs.Error("Failed to parse target URL %s: %v", targetUrl, err)
|
||||
http.Error(writer, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
proxy := httputil.NewSingleHostReverseProxy(target)
|
||||
|
||||
// Configure the Director to set proper headers
|
||||
proxy.Director = func(r *http.Request) {
|
||||
r.URL.Scheme = target.Scheme
|
||||
r.URL.Host = target.Host
|
||||
r.Host = target.Host
|
||||
|
||||
// Set X-Real-IP and X-Forwarded-For headers
|
||||
if clientIP, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
r.Header.Set("X-Forwarded-For", fmt.Sprintf("%s, %s", xff, clientIP))
|
||||
} else {
|
||||
r.Header.Set("X-Forwarded-For", clientIP)
|
||||
}
|
||||
r.Header.Set("X-Real-IP", clientIP)
|
||||
}
|
||||
|
||||
// Set X-Forwarded-Proto header
|
||||
if r.TLS != nil {
|
||||
r.Header.Set("X-Forwarded-Proto", "https")
|
||||
} else {
|
||||
r.Header.Set("X-Forwarded-Proto", "http")
|
||||
}
|
||||
|
||||
// Set X-Forwarded-Host header
|
||||
r.Header.Set("X-Forwarded-Host", request.Host)
|
||||
}
|
||||
|
||||
// Handle ModifyResponse for security enhancements
|
||||
proxy.ModifyResponse = func(resp *http.Response) error {
|
||||
// Add Secure flag to all Set-Cookie headers in HTTPS responses
|
||||
if request.TLS != nil {
|
||||
// Add HSTS header for HTTPS responses if not already set by backend
|
||||
if resp.Header.Get("Strict-Transport-Security") == "" {
|
||||
resp.Header.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
|
||||
}
|
||||
|
||||
cookies := resp.Header["Set-Cookie"]
|
||||
if len(cookies) > 0 {
|
||||
// Clear existing Set-Cookie headers
|
||||
resp.Header.Del("Set-Cookie")
|
||||
// Add them back with Secure flag if not already present
|
||||
for _, cookie := range cookies {
|
||||
// Check if Secure attribute is already present (case-insensitive)
|
||||
cookieLower := strings.ToLower(cookie)
|
||||
hasSecure := strings.Contains(cookieLower, ";secure;") ||
|
||||
strings.Contains(cookieLower, "; secure;") ||
|
||||
strings.HasSuffix(cookieLower, ";secure") ||
|
||||
strings.HasSuffix(cookieLower, "; secure")
|
||||
if !hasSecure {
|
||||
cookie = cookie + "; Secure"
|
||||
}
|
||||
resp.Header.Add("Set-Cookie", cookie)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
proxy.ServeHTTP(writer, request)
|
||||
}
|
||||
|
||||
// HandleReverseProxy handles incoming requests and forwards them to the appropriate upstream
|
||||
func HandleReverseProxy(w http.ResponseWriter, r *http.Request) {
|
||||
domain := getDomainWithoutPort(r.Host)
|
||||
|
||||
if applicationLookup == nil {
|
||||
logs.Error("Application lookup function not set")
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Lookup the application by domain
|
||||
app := applicationLookup(domain)
|
||||
if app == nil {
|
||||
logs.Info("No application found for domain: %s", domain)
|
||||
http.Error(w, "Not Found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the application has an upstream host configured
|
||||
if app.UpstreamHost == "" {
|
||||
logs.Warn("Application %s/%s has no upstream host configured", app.Owner, app.Name)
|
||||
http.Error(w, "Not Found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Build the target URL - just use the upstream host, the actual path/query will be set by the proxy Director
|
||||
targetUrl := app.UpstreamHost
|
||||
if !strings.HasPrefix(targetUrl, "http://") && !strings.HasPrefix(targetUrl, "https://") {
|
||||
targetUrl = "http://" + targetUrl
|
||||
}
|
||||
|
||||
logs.Debug("Forwarding request from %s%s to %s", r.Host, r.RequestURI, targetUrl)
|
||||
forwardHandler(targetUrl, w, r)
|
||||
}
|
||||
|
||||
// StartProxyServer starts the HTTP and HTTPS proxy servers based on configuration
|
||||
func StartProxyServer() {
|
||||
proxyHttpPort := conf.GetConfigString("proxyHttpPort")
|
||||
proxyHttpsPort := conf.GetConfigString("proxyHttpsPort")
|
||||
|
||||
if proxyHttpPort == "" && proxyHttpsPort == "" {
|
||||
logs.Info("Reverse proxy not enabled (proxyHttpPort and proxyHttpsPort are empty)")
|
||||
return
|
||||
}
|
||||
|
||||
serverMux := http.NewServeMux()
|
||||
serverMux.HandleFunc("/", HandleReverseProxy)
|
||||
|
||||
// Start HTTP proxy if configured
|
||||
if proxyHttpPort != "" {
|
||||
go func() {
|
||||
addr := fmt.Sprintf(":%s", proxyHttpPort)
|
||||
logs.Info("Starting reverse proxy HTTP server on %s", addr)
|
||||
err := http.ListenAndServe(addr, serverMux)
|
||||
if err != nil {
|
||||
logs.Error("Failed to start HTTP proxy server: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Start HTTPS proxy if configured
|
||||
if proxyHttpsPort != "" {
|
||||
go func() {
|
||||
addr := fmt.Sprintf(":%s", proxyHttpsPort)
|
||||
|
||||
// For now, HTTPS will need certificate configuration
|
||||
// This can be enhanced later to use Application's SslCert field
|
||||
logs.Info("HTTPS proxy server on %s requires certificate configuration - not implemented yet", addr)
|
||||
|
||||
// When implemented, use code like:
|
||||
// server := &http.Server{
|
||||
// Handler: serverMux,
|
||||
// Addr: addr,
|
||||
// TLSConfig: &tls.Config{
|
||||
// MinVersion: tls.VersionTLS12,
|
||||
// PreferServerCipherSuites: true,
|
||||
// CipherSuites: []uint16{
|
||||
// tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
// tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
// tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
// tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
// tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
|
||||
// tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
|
||||
// },
|
||||
// CurvePreferences: []tls.CurveID{
|
||||
// tls.X25519,
|
||||
// tls.CurveP256,
|
||||
// tls.CurveP384,
|
||||
// },
|
||||
// },
|
||||
// }
|
||||
// err := server.ListenAndServeTLS("", "")
|
||||
// if err != nil {
|
||||
// logs.Error("Failed to start HTTPS proxy server: %v", err)
|
||||
// }
|
||||
}()
|
||||
}
|
||||
}
|
||||
210
proxy/reverse_proxy_integration_test.go
Normal file
210
proxy/reverse_proxy_integration_test.go
Normal file
@@ -0,0 +1,210 @@
|
||||
// Copyright 2021 The Casdoor Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestReverseProxyIntegration tests the reverse proxy with a real backend server
|
||||
func TestReverseProxyIntegration(t *testing.T) {
|
||||
// Create a test backend server that echoes the request path
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Verify headers
|
||||
headers := []string{
|
||||
"X-Forwarded-For",
|
||||
"X-Forwarded-Proto",
|
||||
"X-Real-IP",
|
||||
"X-Forwarded-Host",
|
||||
}
|
||||
|
||||
for _, header := range headers {
|
||||
if r.Header.Get(header) == "" {
|
||||
t.Errorf("Expected header %s to be set", header)
|
||||
}
|
||||
}
|
||||
|
||||
// Echo the path and query
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("Path: " + r.URL.Path + "\n"))
|
||||
w.Write([]byte("Query: " + r.URL.RawQuery + "\n"))
|
||||
w.Write([]byte("Host: " + r.Host + "\n"))
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
// Set up the application lookup
|
||||
SetApplicationLookup(func(domain string) *Application {
|
||||
if domain == "myapp.example.com" {
|
||||
return &Application{
|
||||
Owner: "test-owner",
|
||||
Name: "my-app",
|
||||
UpstreamHost: backend.URL,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Test various request paths
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
query string
|
||||
expected string
|
||||
}{
|
||||
{"Simple path", "/", "", "Path: /\n"},
|
||||
{"Path with segments", "/api/v1/users", "", "Path: /api/v1/users\n"},
|
||||
{"Path with query", "/search", "q=test&limit=10", "Query: q=test&limit=10\n"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
url := "http://myapp.example.com" + tt.path
|
||||
if tt.query != "" {
|
||||
url += "?" + tt.query
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", url, nil)
|
||||
req.Host = "myapp.example.com"
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
HandleReverseProxy(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(w.Body)
|
||||
bodyStr := string(body)
|
||||
|
||||
if !strings.Contains(bodyStr, tt.expected) {
|
||||
t.Errorf("Expected response to contain %q, got %q", tt.expected, bodyStr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestReverseProxyWebSocket tests that WebSocket upgrade headers are preserved
|
||||
func TestReverseProxyWebSocket(t *testing.T) {
|
||||
// Note: WebSocket upgrade through httptest.ResponseRecorder has limitations
|
||||
// This test verifies that WebSocket headers are passed through, but
|
||||
// full WebSocket functionality would need integration testing with real servers
|
||||
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Verify WebSocket headers are present
|
||||
if r.Header.Get("Upgrade") == "websocket" &&
|
||||
r.Header.Get("Connection") != "" &&
|
||||
r.Header.Get("Sec-WebSocket-Version") != "" &&
|
||||
r.Header.Get("Sec-WebSocket-Key") != "" {
|
||||
// Headers are present - this is what we're testing
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("WebSocket headers received"))
|
||||
} else {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte("Missing WebSocket headers"))
|
||||
}
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
SetApplicationLookup(func(domain string) *Application {
|
||||
if domain == "ws.example.com" {
|
||||
return &Application{
|
||||
Owner: "test-owner",
|
||||
Name: "ws-app",
|
||||
UpstreamHost: backend.URL,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "http://ws.example.com/ws", nil)
|
||||
req.Host = "ws.example.com"
|
||||
req.Header.Set("Upgrade", "websocket")
|
||||
req.Header.Set("Connection", "Upgrade")
|
||||
req.Header.Set("Sec-WebSocket-Version", "13")
|
||||
req.Header.Set("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
HandleReverseProxy(w, req)
|
||||
|
||||
body, _ := io.ReadAll(w.Body)
|
||||
bodyStr := string(body)
|
||||
|
||||
// We expect the headers to be passed through to the backend
|
||||
if !strings.Contains(bodyStr, "WebSocket headers received") {
|
||||
t.Errorf("WebSocket headers were not properly forwarded. Got: %s", bodyStr)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReverseProxyUpstreamHostVariations tests different UpstreamHost formats
|
||||
func TestReverseProxyUpstreamHostVariations(t *testing.T) {
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
// Parse backend URL to get host
|
||||
backendURL, err := url.Parse(backend.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse backend URL: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
upstreamHost string
|
||||
shouldWork bool
|
||||
}{
|
||||
{"Full URL", backend.URL, true},
|
||||
{"Host only", backendURL.Host, true},
|
||||
{"Empty", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
SetApplicationLookup(func(domain string) *Application {
|
||||
if domain == "test.example.com" {
|
||||
return &Application{
|
||||
Owner: "test-owner",
|
||||
Name: "test-app",
|
||||
UpstreamHost: tt.upstreamHost,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "http://test.example.com/", nil)
|
||||
req.Host = "test.example.com"
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
HandleReverseProxy(w, req)
|
||||
|
||||
if tt.shouldWork {
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
} else {
|
||||
if w.Code == http.StatusOK {
|
||||
t.Errorf("Expected failure, but got status 200")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
148
proxy/reverse_proxy_test.go
Normal file
148
proxy/reverse_proxy_test.go
Normal file
@@ -0,0 +1,148 @@
|
||||
// Copyright 2021 The Casdoor Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetDomainWithoutPort(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"example.com", "example.com"},
|
||||
{"example.com:8080", "example.com"},
|
||||
{"localhost:3000", "localhost"},
|
||||
{"subdomain.example.com:443", "subdomain.example.com"},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
result := getDomainWithoutPort(test.input)
|
||||
if result != test.expected {
|
||||
t.Errorf("getDomainWithoutPort(%s) = %s; want %s", test.input, result, test.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleReverseProxy(t *testing.T) {
|
||||
// Create a test backend server
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Check that headers are set correctly
|
||||
if r.Header.Get("X-Forwarded-For") == "" {
|
||||
t.Error("X-Forwarded-For header not set")
|
||||
}
|
||||
if r.Header.Get("X-Forwarded-Proto") == "" {
|
||||
t.Error("X-Forwarded-Proto header not set")
|
||||
}
|
||||
if r.Header.Get("X-Real-IP") == "" {
|
||||
t.Error("X-Real-IP header not set")
|
||||
}
|
||||
if r.Header.Get("X-Forwarded-Host") == "" {
|
||||
t.Error("X-Forwarded-Host header not set")
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, "Backend response")
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
// Set up a mock application lookup function
|
||||
SetApplicationLookup(func(domain string) *Application {
|
||||
if domain == "test.example.com" {
|
||||
return &Application{
|
||||
Owner: "test-owner",
|
||||
Name: "test-app",
|
||||
UpstreamHost: backend.URL,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Test successful proxy
|
||||
req := httptest.NewRequest("GET", "http://test.example.com/path", nil)
|
||||
req.Host = "test.example.com"
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
HandleReverseProxy(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
// Test domain not found
|
||||
req = httptest.NewRequest("GET", "http://unknown.example.com/path", nil)
|
||||
req.Host = "unknown.example.com"
|
||||
w = httptest.NewRecorder()
|
||||
|
||||
HandleReverseProxy(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("Expected status 404 for unknown domain, got %d", w.Code)
|
||||
}
|
||||
|
||||
// Test application without upstream host
|
||||
SetApplicationLookup(func(domain string) *Application {
|
||||
if domain == "no-upstream.example.com" {
|
||||
return &Application{
|
||||
Owner: "test-owner",
|
||||
Name: "test-app-no-upstream",
|
||||
UpstreamHost: "",
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
req = httptest.NewRequest("GET", "http://no-upstream.example.com/path", nil)
|
||||
req.Host = "no-upstream.example.com"
|
||||
w = httptest.NewRecorder()
|
||||
|
||||
HandleReverseProxy(w, req)
|
||||
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("Expected status 404 for app without upstream, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplicationLookup(t *testing.T) {
|
||||
// Test setting and using the application lookup function
|
||||
called := false
|
||||
SetApplicationLookup(func(domain string) *Application {
|
||||
called = true
|
||||
return &Application{
|
||||
Owner: "test",
|
||||
Name: "app",
|
||||
UpstreamHost: "http://localhost:8080",
|
||||
}
|
||||
})
|
||||
|
||||
if applicationLookup == nil {
|
||||
t.Error("applicationLookup should not be nil after SetApplicationLookup")
|
||||
}
|
||||
|
||||
app := applicationLookup("test.com")
|
||||
if !called {
|
||||
t.Error("applicationLookup function was not called")
|
||||
}
|
||||
if app == nil {
|
||||
t.Error("applicationLookup should return non-nil application")
|
||||
}
|
||||
if app.Owner != "test" {
|
||||
t.Errorf("Expected owner 'test', got '%s'", app.Owner)
|
||||
}
|
||||
}
|
||||
@@ -18,6 +18,7 @@ import {InfoCircleTwoTone} from "@ant-design/icons";
|
||||
import * as PaymentBackend from "./backend/PaymentBackend";
|
||||
import * as Setting from "./Setting";
|
||||
import i18next from "i18next";
|
||||
import * as ProductBackend from "./backend/ProductBackend";
|
||||
|
||||
const {Option} = Select;
|
||||
|
||||
@@ -29,6 +30,7 @@ class PaymentEditPage extends React.Component {
|
||||
organizationName: props.organizationName !== undefined ? props.organizationName : props.match.params.organizationName,
|
||||
paymentName: props.match.params.paymentName,
|
||||
payment: null,
|
||||
products: [],
|
||||
isModalVisible: false,
|
||||
isInvoiceLoading: false,
|
||||
mode: props.location.mode !== undefined ? props.location.mode : "edit",
|
||||
@@ -37,6 +39,7 @@ class PaymentEditPage extends React.Component {
|
||||
|
||||
UNSAFE_componentWillMount() {
|
||||
this.getPayment();
|
||||
this.getProducts();
|
||||
}
|
||||
|
||||
getPayment() {
|
||||
@@ -55,6 +58,19 @@ class PaymentEditPage extends React.Component {
|
||||
});
|
||||
}
|
||||
|
||||
getProducts() {
|
||||
ProductBackend.getProducts(this.state.organizationName)
|
||||
.then((res) => {
|
||||
if (res.status === "ok") {
|
||||
this.setState({
|
||||
products: res.data,
|
||||
});
|
||||
} else {
|
||||
Setting.showMessage("error", `Failed to get products: ${res.msg}`);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
goToViewOrder() {
|
||||
const payment = this.state.payment;
|
||||
if (payment && payment.order) {
|
||||
@@ -224,6 +240,29 @@ class PaymentEditPage extends React.Component {
|
||||
}} />
|
||||
</Col>
|
||||
</Row>
|
||||
<Row style={{marginTop: "20px"}} >
|
||||
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}>
|
||||
{Setting.getLabel(i18next.t("general:Products"), i18next.t("payment:Products - Tooltip"))} :
|
||||
</Col>
|
||||
<Col span={22} >
|
||||
<Select
|
||||
mode="multiple"
|
||||
style={{width: "100%"}}
|
||||
value={this.state.payment?.products || []}
|
||||
disabled={isViewMode}
|
||||
allowClear
|
||||
options={(this.state.products || [])
|
||||
.map((p) => ({
|
||||
label: Setting.getLanguageText(p?.displayName) || p?.name,
|
||||
value: p?.name,
|
||||
}))
|
||||
.filter((o) => o.value)}
|
||||
onChange={(value) => {
|
||||
this.updatePaymentField("products", value);
|
||||
}}
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
<Row style={{marginTop: "20px"}} >
|
||||
<Col style={{marginTop: "5px"}} span={(Setting.isMobile()) ? 22 : 2}>
|
||||
{Setting.getLabel(i18next.t("order:Price"), i18next.t("plan:Price - Tooltip"))} :
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
import React from "react";
|
||||
import {Link} from "react-router-dom";
|
||||
import {Button, Col, List, Row, Table, Tooltip} from "antd";
|
||||
import {Button, List, Table, Tooltip} from "antd";
|
||||
import moment from "moment";
|
||||
import * as Setting from "./Setting";
|
||||
import * as PaymentBackend from "./backend/PaymentBackend";
|
||||
@@ -195,31 +195,21 @@ class PaymentListPage extends BaseListPage {
|
||||
paddingBottom: 8,
|
||||
}}
|
||||
renderItem={(productInfo, i) => {
|
||||
const price = productInfo.price || 0;
|
||||
const number = productInfo.quantity || 1;
|
||||
const price = productInfo.price * (productInfo.quantity || 1);
|
||||
const currency = record.currency || "USD";
|
||||
const productName = productInfo.displayName || productInfo.name;
|
||||
return (
|
||||
<List.Item>
|
||||
<Row style={{width: "100%"}} wrap={false} gutter={[12, 0]}>
|
||||
<Col flex="auto" style={{minWidth: 0}}>
|
||||
<div style={{display: "flex", alignItems: "center", minWidth: 0}}>
|
||||
<Tooltip placement="topLeft" title={i18next.t("general:Edit")}>
|
||||
<Button style={{marginRight: "5px"}} icon={<EditOutlined />} size="small" onClick={() => Setting.goToLinkSoft(this, `/products/${record.owner}/${productInfo.name}`)} />
|
||||
</Tooltip>
|
||||
<Tooltip placement="topLeft" title={productName}>
|
||||
<Link to={`/products/${record.owner}/${productInfo.name}`} style={{display: "inline-block", maxWidth: "100%", minWidth: 0, overflow: "hidden", textOverflow: "ellipsis", whiteSpace: "nowrap"}}>
|
||||
{productName}
|
||||
</Link>
|
||||
</Tooltip>
|
||||
</div>
|
||||
</Col>
|
||||
<Col flex="none" style={{whiteSpace: "nowrap"}}>
|
||||
<span style={{color: "#666"}}>
|
||||
{Setting.getCurrencySymbol(currency)}{price} ({Setting.getCurrencyText(currency)}) × {number}
|
||||
</span>
|
||||
</Col>
|
||||
</Row>
|
||||
<div style={{display: "inline"}}>
|
||||
<Tooltip placement="topLeft" title={i18next.t("general:Edit")}>
|
||||
<Button style={{marginRight: "5px"}} icon={<EditOutlined />} size="small" onClick={() => Setting.goToLinkSoft(this, `/products/${record.owner}/${productInfo.name}`)} />
|
||||
</Tooltip>
|
||||
<Link to={`/products/${record.owner}/${productInfo.name}`}>
|
||||
{productInfo.displayName || productInfo.name}
|
||||
</Link>
|
||||
<span style={{marginLeft: "8px", color: "#666"}}>
|
||||
{Setting.getPriceDisplay(price, currency)}
|
||||
</span>
|
||||
</div>
|
||||
</List.Item>
|
||||
);
|
||||
}}
|
||||
|
||||
@@ -44,15 +44,20 @@ function generateCodeChallenge(verifier) {
|
||||
}
|
||||
|
||||
function storeCodeVerifier(state, verifier) {
|
||||
localStorage.setItem(`pkce_verifier_${state}`, verifier);
|
||||
localStorage.setItem("pkce_verifier", `${state}#${verifier}`);
|
||||
}
|
||||
|
||||
export function getCodeVerifier(state) {
|
||||
return localStorage.getItem(`pkce_verifier_${state}`);
|
||||
const verifierStore = localStorage.getItem("pkce_verifier");
|
||||
const [storedState, verifier] = verifierStore ? verifierStore.split("#") : [null, null];
|
||||
if (storedState !== state) {
|
||||
return null;
|
||||
}
|
||||
return verifier;
|
||||
}
|
||||
|
||||
export function clearCodeVerifier(state) {
|
||||
localStorage.removeItem(`pkce_verifier_${state}`);
|
||||
localStorage.removeItem("pkce_verifier");
|
||||
}
|
||||
|
||||
const authInfo = {
|
||||
|
||||
Reference in New Issue
Block a user