mirror of
https://gitee.com/XM-GO/PandaX.git
synced 2026-04-29 15:41:25 +08:00
【优化】租户功能,框架优化
This commit is contained in:
@@ -2,7 +2,6 @@ package biz
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"pandax/base/global"
|
||||
"pandax/base/utils"
|
||||
"reflect"
|
||||
)
|
||||
@@ -12,7 +11,6 @@ func ErrIsNil(err error, msg string, params ...any) {
|
||||
if err.Error() == "record not found" {
|
||||
return
|
||||
}
|
||||
global.Log.Error(msg + ": " + err.Error())
|
||||
panic(any(NewBizErr(fmt.Sprintf(msg, params...))))
|
||||
}
|
||||
}
|
||||
@@ -28,7 +26,6 @@ func IsNil(err error) {
|
||||
case *BizError:
|
||||
panic(any(t))
|
||||
case error:
|
||||
global.Log.Error("非业务异常: " + err.Error())
|
||||
panic(any(NewBizErr(fmt.Sprintf("非业务异常: %s", err.Error()))))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,16 +4,15 @@ import (
|
||||
"github.com/casbin/casbin/v2"
|
||||
gormadapter "github.com/casbin/gorm-adapter/v3"
|
||||
"pandax/base/biz"
|
||||
"pandax/base/config"
|
||||
"pandax/base/global"
|
||||
"pandax/pkg/global"
|
||||
"sync"
|
||||
)
|
||||
|
||||
func UpdateCasbin(roleKey string, casbinInfos []CasbinRule) error {
|
||||
ClearCasbin(0, roleKey)
|
||||
func UpdateCasbin(tenantId string, roleKey string, casbinInfos []CasbinRule) error {
|
||||
ClearCasbin(0, tenantId, roleKey)
|
||||
rules := [][]string{}
|
||||
for _, v := range casbinInfos {
|
||||
rules = append(rules, []string{roleKey, v.Path, v.Method})
|
||||
rules = append(rules, []string{tenantId, roleKey, v.Path, v.Method})
|
||||
}
|
||||
e := Casbin()
|
||||
success, _ := e.AddPolicies(rules)
|
||||
@@ -22,20 +21,20 @@ func UpdateCasbin(roleKey string, casbinInfos []CasbinRule) error {
|
||||
}
|
||||
|
||||
func UpdateCasbinApi(oldPath string, newPath string, oldMethod string, newMethod string) {
|
||||
err := global.Db.Table("casbin_rule").Model(&CasbinRule{}).Where("v1 = ? AND v2 = ?", oldPath, oldMethod).Updates(map[string]any{
|
||||
"v1": newPath,
|
||||
"v2": newMethod,
|
||||
err := global.Db.Table("casbin_rule").Model(&CasbinRule{}).Where("v2 = ? AND v3 = ?", oldPath, oldMethod).Updates(map[string]any{
|
||||
"v2": newPath,
|
||||
"v3": newMethod,
|
||||
}).Error
|
||||
biz.ErrIsNil(err, "修改api失败")
|
||||
}
|
||||
|
||||
func GetPolicyPathByRoleId(roleKey string) (pathMaps []CasbinRule) {
|
||||
func GetPolicyPathByRoleId(tenantId, roleKey string) (pathMaps []CasbinRule) {
|
||||
e := Casbin()
|
||||
list := e.GetFilteredPolicy(0, roleKey)
|
||||
list := e.GetFilteredPolicy(0, tenantId, roleKey)
|
||||
for _, v := range list {
|
||||
pathMaps = append(pathMaps, CasbinRule{
|
||||
Path: v[1],
|
||||
Method: v[2],
|
||||
Path: v[2],
|
||||
Method: v[3],
|
||||
})
|
||||
}
|
||||
return pathMaps
|
||||
@@ -57,7 +56,7 @@ func Casbin() *casbin.SyncedEnforcer {
|
||||
once.Do(func() {
|
||||
a, err := gormadapter.NewAdapterByDB(global.Db)
|
||||
biz.ErrIsNil(err, "新建权限适配器失败")
|
||||
syncedEnforcer, err = casbin.NewSyncedEnforcer(config.Conf.Casbin.ModelPath, a)
|
||||
syncedEnforcer, err = casbin.NewSyncedEnforcer(global.Conf.Casbin.ModelPath, a)
|
||||
biz.ErrIsNil(err, "新建权限适配器失败")
|
||||
})
|
||||
_ = syncedEnforcer.LoadPolicy()
|
||||
|
||||
@@ -8,23 +8,19 @@ import (
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// 配置文件映射对象
|
||||
var Conf *Config
|
||||
|
||||
func init() {
|
||||
configFilePath := flag.String("c", "./config.yml", "配置文件路径,默认为可执行文件目录")
|
||||
flag.Parse()
|
||||
func InitConfig(configFilePath string) *Config {
|
||||
// 获取启动参数中,配置文件的绝对路径
|
||||
path, _ := filepath.Abs(*configFilePath)
|
||||
path, _ := filepath.Abs(configFilePath)
|
||||
startConfigParam = &CmdConfigParam{ConfigFilePath: path}
|
||||
// 读取配置文件信息
|
||||
yc := &Config{}
|
||||
if err := utils.LoadYml(startConfigParam.ConfigFilePath, yc); err != nil {
|
||||
panic(fmt.Sprintf("读取配置文件[%s]失败: %s", startConfigParam.ConfigFilePath, err.Error()))
|
||||
panic(any(fmt.Sprintf("读取配置文件[%s]失败: %s", startConfigParam.ConfigFilePath, err.Error())))
|
||||
}
|
||||
// 校验配置文件内容信息
|
||||
yc.Valid()
|
||||
Conf = yc
|
||||
return yc
|
||||
|
||||
}
|
||||
|
||||
// 启动配置参数
|
||||
|
||||
@@ -4,8 +4,8 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"pandax/base/biz"
|
||||
"pandax/base/logger"
|
||||
"pandax/base/utils"
|
||||
"pandax/pkg/global"
|
||||
"reflect"
|
||||
"runtime/debug"
|
||||
|
||||
@@ -42,10 +42,10 @@ func LogHandler(rc *ReqCtx) error {
|
||||
lfs[req.Method] = req.URL.Path
|
||||
|
||||
if err := rc.Err; err != nil {
|
||||
logger.Log.WithFields(lfs).Error(getErrMsg(rc, err))
|
||||
global.Log.WithFields(lfs).Error(getErrMsg(rc, err))
|
||||
return nil
|
||||
}
|
||||
logger.Log.WithFields(lfs).Info(getLogMsg(rc))
|
||||
global.Log.WithFields(lfs).Info(getLogMsg(rc))
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
package ctx
|
||||
|
||||
import (
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"pandax/base/biz"
|
||||
"pandax/base/casbin"
|
||||
"pandax/pkg/global"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type Permission struct {
|
||||
@@ -34,7 +37,8 @@ func PermissionHandler(rc *ReqCtx) error {
|
||||
if tokenStr == "" {
|
||||
return biz.PermissionErr
|
||||
}
|
||||
loginAccount, err := ParseToken(tokenStr)
|
||||
j := NewJWT("", []byte(global.Conf.Jwt.Key), jwt.SigningMethodHS256)
|
||||
loginAccount, err := j.ParseToken(tokenStr)
|
||||
if err != nil || loginAccount == nil {
|
||||
return biz.PermissionErr
|
||||
}
|
||||
@@ -45,7 +49,8 @@ func PermissionHandler(rc *ReqCtx) error {
|
||||
}
|
||||
e := casbin.Casbin()
|
||||
// 判断策略中是否存在
|
||||
success, _ := e.Enforce(loginAccount.RoleKey, rc.GinCtx.Request.URL.Path, rc.GinCtx.Request.Method)
|
||||
tenantId := strconv.Itoa(int(rc.LoginAccount.TenantId))
|
||||
success, err := e.Enforce(tenantId, loginAccount.RoleKey, rc.GinCtx.Request.URL.Path, rc.GinCtx.Request.Method)
|
||||
if !success {
|
||||
return biz.CasbinErr
|
||||
}
|
||||
|
||||
@@ -2,17 +2,15 @@ package ctx
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"pandax/base/config"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
jwtSecret = []byte(config.Conf.Jwt.Key)
|
||||
)
|
||||
|
||||
type Claims struct {
|
||||
UserId int64
|
||||
TenantId int64
|
||||
UserName string
|
||||
RoleId int64
|
||||
RoleKey string
|
||||
@@ -21,43 +19,116 @@ type Claims struct {
|
||||
jwt.StandardClaims
|
||||
}
|
||||
|
||||
func CreateToken(claims Claims) (string, error) {
|
||||
|
||||
tokenClaims := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
token, err := tokenClaims.SignedString(jwtSecret)
|
||||
|
||||
return token, err
|
||||
type JWT struct {
|
||||
SignedKeyID string
|
||||
SignedKey []byte
|
||||
SignedMethod jwt.SigningMethod
|
||||
}
|
||||
|
||||
func ParseToken(token string) (*Claims, error) {
|
||||
tokenClaims, err := jwt.ParseWithClaims(token, &Claims{}, func(token *jwt.Token) (any, error) {
|
||||
return jwtSecret, nil
|
||||
})
|
||||
var (
|
||||
TokenExpired = errors.New("token is expired")
|
||||
TokenNotValidYet = errors.New("token not active yet")
|
||||
TokenMalformed = errors.New("that's not even a token")
|
||||
TokenInvalid = errors.New("couldn't handle this token")
|
||||
)
|
||||
|
||||
if tokenClaims != nil {
|
||||
if claims, ok := tokenClaims.Claims.(*Claims); ok && tokenClaims.Valid {
|
||||
return claims, nil
|
||||
func NewJWT(kid string, key []byte, method jwt.SigningMethod) *JWT {
|
||||
return &JWT{
|
||||
SignedKeyID: kid,
|
||||
SignedKey: key,
|
||||
SignedMethod: method,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateToken 创建一个token
|
||||
func (j *JWT) CreateToken(claims Claims) (string, error) {
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, &claims)
|
||||
var key interface{}
|
||||
if j.isEs() {
|
||||
v, err := jwt.ParseECPrivateKeyFromPEM(j.SignedKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
key = v
|
||||
} else if j.isRsOrPS() {
|
||||
v, err := jwt.ParseRSAPrivateKeyFromPEM(j.SignedKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
key = v
|
||||
} else if j.isHs() {
|
||||
key = j.SignedKey
|
||||
} else {
|
||||
return "", errors.New("unsupported sign method")
|
||||
}
|
||||
return token.SignedString(key)
|
||||
}
|
||||
|
||||
// ParseToken 解析 token
|
||||
func (j *JWT) ParseToken(tokenString string) (*Claims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (i interface{}, e error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("parse error")
|
||||
}
|
||||
return j.SignedKey, nil
|
||||
})
|
||||
if err != nil {
|
||||
if ve, ok := err.(*jwt.ValidationError); ok {
|
||||
if ve.Errors&jwt.ValidationErrorMalformed != 0 {
|
||||
return nil, TokenMalformed
|
||||
} else if ve.Errors&jwt.ValidationErrorExpired != 0 {
|
||||
// Token is expired
|
||||
return nil, TokenExpired
|
||||
} else if ve.Errors&jwt.ValidationErrorNotValidYet != 0 {
|
||||
return nil, TokenNotValidYet
|
||||
} else {
|
||||
return nil, TokenInvalid
|
||||
}
|
||||
}
|
||||
}
|
||||
if token != nil {
|
||||
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
|
||||
return claims, nil
|
||||
}
|
||||
return nil, TokenInvalid
|
||||
|
||||
} else {
|
||||
return nil, TokenInvalid
|
||||
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 更新token
|
||||
func RefreshToken(tokenString string) (string, error) {
|
||||
func (j *JWT) RefreshToken(tokenString string) (string, error) {
|
||||
jwt.TimeFunc = func() time.Time {
|
||||
return time.Unix(0, 0)
|
||||
}
|
||||
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (any, error) {
|
||||
return jwtSecret, nil
|
||||
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return j.SignedKey, nil
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
|
||||
jwt.TimeFunc = time.Now
|
||||
claims.StandardClaims.ExpiresAt = time.Now().Add(1 * time.Hour).Unix()
|
||||
return CreateToken(*claims)
|
||||
claims.StandardClaims.ExpiresAt = time.Now().Unix() + 60*60*24*7
|
||||
return j.CreateToken(*claims)
|
||||
}
|
||||
return "", errors.New("Couldn't handle this token:")
|
||||
return "", TokenInvalid
|
||||
}
|
||||
|
||||
func (a *JWT) isEs() bool {
|
||||
return strings.HasPrefix(a.SignedMethod.Alg(), "ES")
|
||||
}
|
||||
|
||||
func (a *JWT) isRsOrPS() bool {
|
||||
isRs := strings.HasPrefix(a.SignedMethod.Alg(), "RS")
|
||||
isPs := strings.HasPrefix(a.SignedMethod.Alg(), "PS")
|
||||
return isRs || isPs
|
||||
}
|
||||
|
||||
func (a *JWT) isHs() bool {
|
||||
return strings.HasPrefix(a.SignedMethod.Alg(), "HS")
|
||||
}
|
||||
|
||||
@@ -4,8 +4,8 @@ import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"pandax/base/biz"
|
||||
"pandax/base/global"
|
||||
"pandax/base/model"
|
||||
"pandax/pkg/global"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
package global
|
||||
|
||||
import (
|
||||
"github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
Log *logrus.Logger // 日志
|
||||
Db *gorm.DB // gorm
|
||||
)
|
||||
@@ -4,24 +4,21 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"pandax/base/config"
|
||||
"pandax/base/global"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var Log = logrus.New()
|
||||
|
||||
func init() {
|
||||
func InitLog(logConf *config.Log) *logrus.Logger {
|
||||
var Log = logrus.New()
|
||||
Log.SetFormatter(new(LogFormatter))
|
||||
Log.SetReportCaller(true)
|
||||
|
||||
logConf := config.Conf.Log
|
||||
// 如果不存在日志配置信息,则默认debug级别
|
||||
if logConf == nil {
|
||||
Log.SetLevel(logrus.DebugLevel)
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
// 根据配置文件设置日志级别
|
||||
@@ -44,8 +41,7 @@ func init() {
|
||||
|
||||
Log.Out = file
|
||||
}
|
||||
|
||||
global.Log = Log
|
||||
return Log
|
||||
}
|
||||
|
||||
type LogFormatter struct{}
|
||||
|
||||
@@ -5,9 +5,10 @@ type AppContext struct {
|
||||
|
||||
type LoginAccount struct {
|
||||
UserId int64
|
||||
TenantId int64
|
||||
RoleId int64
|
||||
DeptId int64
|
||||
PostId int64
|
||||
Username string
|
||||
Rolename string
|
||||
RoleKey string
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ package model
|
||||
import (
|
||||
"fmt"
|
||||
"pandax/base/biz"
|
||||
"pandax/base/global"
|
||||
"pandax/pkg/global"
|
||||
"strconv"
|
||||
|
||||
"strings"
|
||||
@@ -48,7 +48,9 @@ func (m *Model) SetBaseInfo(account *LoginAccount) {
|
||||
func Tx(funcs ...func(db *gorm.DB) error) (err error) {
|
||||
tx := global.Db.Begin()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
var err any
|
||||
err = recover()
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
err = fmt.Errorf("%v", err)
|
||||
}
|
||||
|
||||
@@ -6,8 +6,7 @@ import (
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
"pandax/base/config"
|
||||
"pandax/base/global"
|
||||
"pandax/pkg/global"
|
||||
"time"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
@@ -23,7 +22,7 @@ func GormInit(ty string) *gorm.DB {
|
||||
return nil
|
||||
}
|
||||
func GormMysql() *gorm.DB {
|
||||
m := config.Conf.Mysql
|
||||
m := global.Conf.Mysql
|
||||
if m == nil || m.Dbname == "" {
|
||||
global.Log.Panic("未找到数据库配置信息")
|
||||
return nil
|
||||
@@ -50,7 +49,7 @@ func GormMysql() *gorm.DB {
|
||||
}
|
||||
|
||||
func GormPostgresql() *gorm.DB {
|
||||
m := config.Conf.Postgresql
|
||||
m := global.Conf.Postgresql
|
||||
if m == nil || m.Dbname == "" {
|
||||
global.Log.Panic("未找到数据库配置信息")
|
||||
return nil
|
||||
|
||||
@@ -2,15 +2,14 @@ package starter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"pandax/base/config"
|
||||
"pandax/base/global"
|
||||
"pandax/pkg/global"
|
||||
|
||||
"github.com/go-redis/redis"
|
||||
)
|
||||
|
||||
func ConnRedis() *redis.Client {
|
||||
// 设置redis客户端
|
||||
redisConf := config.Conf.Redis
|
||||
redisConf := global.Conf.Redis
|
||||
if redisConf == nil {
|
||||
global.Log.Panic("未找到redis配置信息")
|
||||
}
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
package starter
|
||||
|
||||
import (
|
||||
"pandax/base/config"
|
||||
"pandax/base/global"
|
||||
"pandax/pkg/global"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func RunWebServer(web *gin.Engine) {
|
||||
server := config.Conf.Server
|
||||
server := global.Conf.Server
|
||||
port := server.GetPort()
|
||||
if app := config.Conf.App; app != nil {
|
||||
if app := global.Conf.App; app != nil {
|
||||
global.Log.Infof("%s- Listening and serving HTTP on %s", app.GetAppInfo(), port)
|
||||
} else {
|
||||
global.Log.Infof("Listening and serving HTTP on %s", port)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"pandax/base/global"
|
||||
"regexp"
|
||||
"sync"
|
||||
)
|
||||
@@ -33,7 +32,6 @@ func GetRegexp(pattern string) (regex *regexp.Regexp, err error) {
|
||||
// it compiles the pattern and creates one.
|
||||
regex, err = regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
global.Log.Warnf(`regexp.Compile failed for pattern "%s"`, pattern)
|
||||
return
|
||||
}
|
||||
// Cache the result object using writing lock.
|
||||
|
||||
@@ -3,7 +3,7 @@ package ws
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"pandax/base/global"
|
||||
"pandax/pkg/global"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
Reference in New Issue
Block a user