Files
casdoor/object/syncer_database.go
2026-04-06 12:26:29 +08:00

184 lines
5.7 KiB
Go

// Copyright 2025 The Casdoor Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package object
import (
	"context"
	"database/sql"
	"database/sql/driver"
	"fmt"
	"net/url"
	"strings"

	"github.com/casdoor/casdoor/util"
	"github.com/go-sql-driver/mysql"
	"golang.org/x/crypto/ssh"
)
// DatabaseSyncerProvider implements SyncerProvider for database-based syncers.
//
// It holds no state of its own: connection settings are read from Syncer, and
// the database adapter (Syncer.Ormer) and any SSH tunnel (Syncer.SshClient)
// created by InitAdapter are stored back on Syncer.
type DatabaseSyncerProvider struct {
// Syncer is the syncer configuration this provider operates on. Every method
// dereferences it without a nil check, so it must be set before use.
Syncer *Syncer
}
// InitAdapter lazily initializes p.Syncer.Ormer, the database adapter used by
// the other DatabaseSyncerProvider methods. It is a no-op when an adapter
// already exists.
//
// The data source name depends on Syncer.DatabaseType:
//   - "mssql":    a "sqlserver://user:pass@host:port?database=db" URL
//   - "postgres": a keyword/value string (sslmode defaults to "disable")
//   - otherwise:  a go-sql-driver/mysql DSN without a database name
//
// When SshType is non-empty and DatabaseType is "mysql", "postgres" or "mssql",
// the connection is tunneled through SSH. SshType "password" authenticates with
// SshPassword; any other value uses the certificate Owner+"/"+Cert. The SSH
// client is stored in Syncer.SshClient so that Close can release it.
//
// Unless isCloudIntranet is set, every "dbi." in the DSN is rewritten to "db.".
// The rewrite is applied before any connection is opened, so the SSH connector
// and the adapter see the same DSN.
//
// If the adapter cannot be created, the *sql.DB and SSH client opened by this
// call are closed and Syncer.Ormer / Syncer.SshClient are reset. A later retry
// therefore starts clean instead of leaking the previous tunnel.
func (p *DatabaseSyncerProvider) InitAdapter() error {
	if p.Syncer.Ormer != nil {
		return nil
	}

	s := p.Syncer

	var dataSourceName string
	switch s.DatabaseType {
	case "mssql":
		// Build the URL with net/url so credentials and the database name are
		// percent-encoded. Characters such as '@', ':', '/', '?', '#' or '%' in
		// a password would otherwise corrupt the URL. Plain values produce the
		// same string as the hand-formatted version.
		u := url.URL{
			Scheme:   "sqlserver",
			User:     url.UserPassword(s.User, s.Password),
			Host:     fmt.Sprintf("%s:%d", s.Host, s.Port),
			RawQuery: url.Values{"database": {s.Database}}.Encode(),
		}
		dataSourceName = u.String()
	case "postgres":
		sslMode := "disable"
		if s.SslMode != "" {
			sslMode = s.SslMode
		}
		dataSourceName = fmt.Sprintf("user=%s password=%s host=%s port=%d sslmode=%s dbname=%s", s.User, s.Password, s.Host, s.Port, sslMode, s.Database)
	default:
		dataSourceName = fmt.Sprintf("%s:%s@tcp(%s:%d)/", s.User, s.Password, s.Host, s.Port)
	}

	viaSSH := s.SshType != "" && (s.DatabaseType == "mysql" || s.DatabaseType == "postgres" || s.DatabaseType == "mssql")

	var sshClient *ssh.Client
	if viaSSH {
		var err error
		if s.SshType == "password" {
			sshClient, err = DialWithPassword(s.SshUser, s.SshPassword, s.SshHost, s.SshPort)
		} else {
			sshClient, err = DialWithCert(s.SshUser, s.Owner+"/"+s.Cert, s.SshHost, s.SshPort)
		}
		if err != nil {
			return err
		}
		// Store SSH client for proper cleanup by Close.
		s.SshClient = sshClient

		if s.DatabaseType == "mysql" {
			// The mysql driver is routed through the tunnel by registering a
			// custom network name and using it in place of "tcp" in the DSN.
			network := s.Owner + s.Name
			dataSourceName = fmt.Sprintf("%s:%s@%s(%s:%d)/", s.User, s.Password, network, s.Host, s.Port)
			mysql.RegisterDialContext(network, (&ViaSSHDialer{Client: sshClient, Context: nil}).MysqlDial)
		}
	}

	if !isCloudIntranet {
		dataSourceName = strings.ReplaceAll(dataSourceName, "dbi.", "db.")
	}

	// postgres and mssql are tunneled by handing database/sql a connector whose
	// driver dials through SSH. This happens after the rewrite above so that the
	// connector uses the same DSN as the adapter.
	var db *sql.DB
	if sshClient != nil && (s.DatabaseType == "postgres" || s.DatabaseType == "mssql") {
		db = sql.OpenDB(dsnConnector{dsn: dataSourceName, driver: &ViaSSHDialer{Client: sshClient, Context: nil, DatabaseType: s.DatabaseType}})
	}

	var err error
	if db != nil {
		s.Ormer, err = NewAdapterFromDb(s.DatabaseType, dataSourceName, s.Database, db)
	} else {
		s.Ormer, err = NewAdapter(s.DatabaseType, dataSourceName, s.Database)
	}
	if err != nil {
		// Undo everything this call opened; otherwise a retry (Ormer is still
		// nil) would dial a new tunnel and orphan this one.
		s.Ormer = nil
		if db != nil {
			_ = db.Close() // best effort; the adapter error is the one to report
		}
		if sshClient != nil {
			_ = sshClient.Close() // best effort, same reason
			s.SshClient = nil
		}
		return err
	}
	return nil
}
// GetOriginalUsers reads every row of the syncer's table and converts the rows
// into OriginalUser values via getOriginalUsersFromMap.
//
// Each row is scanned as a map from column name to sql.NullString. Once the
// conversion is done, the row maps are emptied as a memory-leak workaround
// (see https://github.com/casdoor/casdoor/issues/1256).
func (p *DatabaseSyncerProvider) GetOriginalUsers() ([]*OriginalUser, error) {
	var rows []map[string]sql.NullString
	if err := p.Syncer.Ormer.Engine.Table(p.Syncer.getTable()).Find(&rows); err != nil {
		return nil, err
	}

	users := p.Syncer.getOriginalUsersFromMap(rows)

	// Memory leak problem handling (issue #1256): empty each converted row map.
	for _, row := range rows {
		for column := range row {
			delete(row, column)
		}
	}

	return users, nil
}
// AddUser inserts user into the syncer's table as one row built by
// getMapFromOriginalUser. It reports whether the insert affected any row;
// database errors are returned unchanged together with false.
func (p *DatabaseSyncerProvider) AddUser(user *OriginalUser) (bool, error) {
	row := p.Syncer.getMapFromOriginalUser(user)
	table := p.Syncer.getTable()

	inserted, err := p.Syncer.Ormer.Engine.Table(table).Insert(row)
	if err != nil {
		return false, err
	}
	return inserted != 0, nil
}
// UpdateUser updates the row of the syncer's table whose primary-key column
// (getTargetTablePrimaryKey) equals that column's value in user's column map.
// The primary-key entry is removed from the map before the update, so it is
// not written. It reports whether any row was affected.
//
// The column name is interpolated into the WHERE clause, so it is checked with
// util.FilterSQLIdentifier first; a rejected name yields an error.
// NOTE(review): if the map has no entry for the key, the WHERE value is the
// map's zero value. Confirm getMapFromOriginalUser always sets it.
func (p *DatabaseSyncerProvider) UpdateUser(user *OriginalUser) (bool, error) {
	pk := p.Syncer.getTargetTablePrimaryKey()
	if ok := util.FilterSQLIdentifier(pk); !ok {
		return false, fmt.Errorf("object.UpdateUser: invalid primary key column name: %s", pk)
	}

	columns := p.Syncer.getMapFromOriginalUser(user)
	pkValue := columns[pk]
	delete(columns, pk)

	condition := fmt.Sprintf("%s = ?", pk)
	updated, err := p.Syncer.Ormer.Engine.Table(p.Syncer.getTable()).Where(condition, pkValue).Update(&columns)
	if err != nil {
		return false, err
	}
	return updated != 0, nil
}
// TestConnection makes sure the adapter is initialized (see InitAdapter) and
// then pings the database. It returns the first error encountered, or nil.
func (p *DatabaseSyncerProvider) TestConnection() error {
	if err := p.InitAdapter(); err != nil {
		return err
	}
	return p.Syncer.Ormer.Engine.Ping()
}
// Close releases the resources held by the underlying syncer by delegating to
// Syncer.Close, and returns its error.
func (p *DatabaseSyncerProvider) Close() error {
	syncer := p.Syncer
	return syncer.Close()
}
// dsnConnector adapts a driver.Driver value plus a fixed data source name to
// the driver.Connector interface. This lets sql.OpenDB be used with a driver
// instance rather than a driver name registered via sql.Register.
type dsnConnector struct {
	dsn    string        // data source name passed to driver.Open
	driver driver.Driver // driver that opens every new connection
}

// Compile-time check that dsnConnector satisfies driver.Connector.
var _ driver.Connector = dsnConnector{}

// Connect opens a new connection by calling Open on the wrapped driver with
// the stored data source name. The context is not consulted.
func (c dsnConnector) Connect(_ context.Context) (driver.Conn, error) {
	return c.driver.Open(c.dsn)
}

// Driver returns the wrapped driver.
func (c dsnConnector) Driver() driver.Driver {
	return c.driver
}
// GetOriginalGroups is not implemented for database syncers yet. It always
// returns an empty, non-nil slice and a nil error.
func (p *DatabaseSyncerProvider) GetOriginalGroups() ([]*OriginalGroup, error) {
	// TODO: Implement Database group sync
	groups := make([]*OriginalGroup, 0)
	return groups, nil
}
// GetOriginalUserGroups is not implemented for database syncers yet. userId is
// ignored, and the result is always an empty, non-nil slice with a nil error.
func (p *DatabaseSyncerProvider) GetOriginalUserGroups(userId string) ([]string, error) {
	// TODO: Implement Database user group membership sync
	groupIDs := make([]string, 0)
	return groupIDs, nil
}