forked from casdoor/casdoor
feat: wrap xorm-adapter RemovePolicy to prevent mass deletion on empty fields (#5282)
This commit is contained in:
@@ -41,6 +41,7 @@ type Adapter struct {
|
||||
Database string `xorm:"varchar(100)" json:"database"`
|
||||
|
||||
*xormadapter.Adapter `xorm:"-" json:"-"`
|
||||
engine *xorm.Engine
|
||||
}
|
||||
|
||||
func GetAdapterCount(owner, field, value string) (int64, error) {
|
||||
@@ -146,7 +147,7 @@ func (adapter *Adapter) GetId() string {
|
||||
}
|
||||
|
||||
func (adapter *Adapter) InitAdapter() error {
|
||||
if adapter.Adapter != nil {
|
||||
if adapter.Adapter != nil && adapter.engine != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -199,11 +200,15 @@ func (adapter *Adapter) InitAdapter() error {
|
||||
|
||||
tableName := adapter.Table
|
||||
|
||||
adapter.Adapter, err = xormadapter.NewAdapterByEngineWithTableName(engine, tableName, "")
|
||||
xa, err := xormadapter.NewAdapterByEngineWithTableName(engine, tableName, "")
|
||||
if err != nil {
|
||||
_ = engine.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
adapter.engine = engine
|
||||
adapter.Adapter = xa
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
91
object/adapter_safe.go
Normal file
91
object/adapter_safe.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package object
|
||||
|
||||
import (
|
||||
xormadapter "github.com/casdoor/xorm-adapter/v3"
|
||||
"github.com/xorm-io/xorm"
|
||||
)
|
||||
|
||||
type SafeAdapter struct {
|
||||
*xormadapter.Adapter
|
||||
engine *xorm.Engine
|
||||
tableName string
|
||||
}
|
||||
|
||||
func NewSafeAdapter(a *Adapter) *SafeAdapter {
|
||||
if a == nil || a.Adapter == nil || a.engine == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &SafeAdapter{
|
||||
Adapter: a.Adapter,
|
||||
engine: a.engine,
|
||||
tableName: a.Table,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *SafeAdapter) RemovePolicy(sec string, ptype string, rule []string) error {
|
||||
line := a.buildCasbinRule(ptype, rule)
|
||||
|
||||
session := a.engine.NewSession()
|
||||
defer session.Close()
|
||||
|
||||
if a.tableName != "" {
|
||||
session = session.Table(a.tableName)
|
||||
}
|
||||
|
||||
_, err := session.
|
||||
MustCols("ptype", "v0", "v1", "v2", "v3", "v4", "v5").
|
||||
Delete(line)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (a *SafeAdapter) RemovePolicies(sec string, ptype string, rules [][]string) error {
|
||||
_, err := a.engine.Transaction(func(tx *xorm.Session) (interface{}, error) {
|
||||
for _, rule := range rules {
|
||||
line := a.buildCasbinRule(ptype, rule)
|
||||
|
||||
var session *xorm.Session
|
||||
if a.tableName != "" {
|
||||
session = tx.Table(a.tableName)
|
||||
} else {
|
||||
session = tx
|
||||
}
|
||||
|
||||
_, err := session.
|
||||
MustCols("ptype", "v0", "v1", "v2", "v3", "v4", "v5").
|
||||
Delete(line)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (a *SafeAdapter) buildCasbinRule(ptype string, rule []string) *xormadapter.CasbinRule {
|
||||
line := xormadapter.CasbinRule{Ptype: ptype}
|
||||
|
||||
l := len(rule)
|
||||
if l > 0 {
|
||||
line.V0 = rule[0]
|
||||
}
|
||||
if l > 1 {
|
||||
line.V1 = rule[1]
|
||||
}
|
||||
if l > 2 {
|
||||
line.V2 = rule[2]
|
||||
}
|
||||
if l > 3 {
|
||||
line.V3 = rule[3]
|
||||
}
|
||||
if l > 4 {
|
||||
line.V4 = rule[4]
|
||||
}
|
||||
if l > 5 {
|
||||
line.V5 = rule[5]
|
||||
}
|
||||
|
||||
return &line
|
||||
}
|
||||
@@ -171,7 +171,7 @@ func (enforcer *Enforcer) InitEnforcer() error {
|
||||
return err
|
||||
}
|
||||
|
||||
casbinEnforcer, err := casbin.NewEnforcer(m.Model, a.Adapter)
|
||||
casbinEnforcer, err := casbin.NewEnforcer(m.Model, NewSafeAdapter(a))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user