feat: wrap xorm-adapter RemovePolicy to prevent mass deletion on empty fields (#5282)

This commit is contained in:
Modo
2026-03-18 17:32:31 +08:00
committed by GitHub
parent 5965e75610
commit 75bc8e6b0d
3 changed files with 99 additions and 3 deletions

View File

@@ -41,6 +41,7 @@ type Adapter struct {
Database string `xorm:"varchar(100)" json:"database"`
*xormadapter.Adapter `xorm:"-" json:"-"`
engine *xorm.Engine
}
func GetAdapterCount(owner, field, value string) (int64, error) {
@@ -146,7 +147,7 @@ func (adapter *Adapter) GetId() string {
}
func (adapter *Adapter) InitAdapter() error {
if adapter.Adapter != nil {
if adapter.Adapter != nil && adapter.engine != nil {
return nil
}
@@ -199,11 +200,15 @@ func (adapter *Adapter) InitAdapter() error {
tableName := adapter.Table
adapter.Adapter, err = xormadapter.NewAdapterByEngineWithTableName(engine, tableName, "")
xa, err := xormadapter.NewAdapterByEngineWithTableName(engine, tableName, "")
if err != nil {
_ = engine.Close()
return err
}
adapter.engine = engine
adapter.Adapter = xa
return nil
}

91
object/adapter_safe.go Normal file
View File

@@ -0,0 +1,91 @@
package object
import (
xormadapter "github.com/casdoor/xorm-adapter/v3"
"github.com/xorm-io/xorm"
)
type SafeAdapter struct {
*xormadapter.Adapter
engine *xorm.Engine
tableName string
}
func NewSafeAdapter(a *Adapter) *SafeAdapter {
if a == nil || a.Adapter == nil || a.engine == nil {
return nil
}
return &SafeAdapter{
Adapter: a.Adapter,
engine: a.engine,
tableName: a.Table,
}
}
func (a *SafeAdapter) RemovePolicy(sec string, ptype string, rule []string) error {
line := a.buildCasbinRule(ptype, rule)
session := a.engine.NewSession()
defer session.Close()
if a.tableName != "" {
session = session.Table(a.tableName)
}
_, err := session.
MustCols("ptype", "v0", "v1", "v2", "v3", "v4", "v5").
Delete(line)
return err
}
func (a *SafeAdapter) RemovePolicies(sec string, ptype string, rules [][]string) error {
_, err := a.engine.Transaction(func(tx *xorm.Session) (interface{}, error) {
for _, rule := range rules {
line := a.buildCasbinRule(ptype, rule)
var session *xorm.Session
if a.tableName != "" {
session = tx.Table(a.tableName)
} else {
session = tx
}
_, err := session.
MustCols("ptype", "v0", "v1", "v2", "v3", "v4", "v5").
Delete(line)
if err != nil {
return nil, err
}
}
return nil, nil
})
return err
}
func (a *SafeAdapter) buildCasbinRule(ptype string, rule []string) *xormadapter.CasbinRule {
line := xormadapter.CasbinRule{Ptype: ptype}
l := len(rule)
if l > 0 {
line.V0 = rule[0]
}
if l > 1 {
line.V1 = rule[1]
}
if l > 2 {
line.V2 = rule[2]
}
if l > 3 {
line.V3 = rule[3]
}
if l > 4 {
line.V4 = rule[4]
}
if l > 5 {
line.V5 = rule[5]
}
return &line
}

View File

@@ -171,7 +171,7 @@ func (enforcer *Enforcer) InitEnforcer() error {
return err
}
casbinEnforcer, err := casbin.NewEnforcer(m.Model, a.Adapter)
casbinEnforcer, err := casbin.NewEnforcer(m.Model, NewSafeAdapter(a))
if err != nil {
return err
}