【优化】租户功能,框架优化

This commit is contained in:
PandaGoAdmin
2022-07-18 18:17:11 +08:00
parent d33bd39570
commit ae38e7bcef
60 changed files with 861 additions and 647 deletions

View File

@@ -2,7 +2,6 @@ package biz
import (
"fmt"
"pandax/base/global"
"pandax/base/utils"
"reflect"
)
@@ -12,7 +11,6 @@ func ErrIsNil(err error, msg string, params ...any) {
if err.Error() == "record not found" {
return
}
global.Log.Error(msg + ": " + err.Error())
panic(any(NewBizErr(fmt.Sprintf(msg, params...))))
}
}
@@ -28,7 +26,6 @@ func IsNil(err error) {
case *BizError:
panic(any(t))
case error:
global.Log.Error("非业务异常: " + err.Error())
panic(any(NewBizErr(fmt.Sprintf("非业务异常: %s", err.Error()))))
}
}

View File

@@ -4,16 +4,15 @@ import (
"github.com/casbin/casbin/v2"
gormadapter "github.com/casbin/gorm-adapter/v3"
"pandax/base/biz"
"pandax/base/config"
"pandax/base/global"
"pandax/pkg/global"
"sync"
)
func UpdateCasbin(roleKey string, casbinInfos []CasbinRule) error {
ClearCasbin(0, roleKey)
func UpdateCasbin(tenantId string, roleKey string, casbinInfos []CasbinRule) error {
ClearCasbin(0, tenantId, roleKey)
rules := [][]string{}
for _, v := range casbinInfos {
rules = append(rules, []string{roleKey, v.Path, v.Method})
rules = append(rules, []string{tenantId, roleKey, v.Path, v.Method})
}
e := Casbin()
success, _ := e.AddPolicies(rules)
@@ -22,20 +21,20 @@ func UpdateCasbin(roleKey string, casbinInfos []CasbinRule) error {
}
func UpdateCasbinApi(oldPath string, newPath string, oldMethod string, newMethod string) {
err := global.Db.Table("casbin_rule").Model(&CasbinRule{}).Where("v1 = ? AND v2 = ?", oldPath, oldMethod).Updates(map[string]any{
"v1": newPath,
"v2": newMethod,
err := global.Db.Table("casbin_rule").Model(&CasbinRule{}).Where("v2 = ? AND v3 = ?", oldPath, oldMethod).Updates(map[string]any{
"v2": newPath,
"v3": newMethod,
}).Error
biz.ErrIsNil(err, "修改api失败")
}
func GetPolicyPathByRoleId(roleKey string) (pathMaps []CasbinRule) {
func GetPolicyPathByRoleId(tenantId, roleKey string) (pathMaps []CasbinRule) {
e := Casbin()
list := e.GetFilteredPolicy(0, roleKey)
list := e.GetFilteredPolicy(0, tenantId, roleKey)
for _, v := range list {
pathMaps = append(pathMaps, CasbinRule{
Path: v[1],
Method: v[2],
Path: v[2],
Method: v[3],
})
}
return pathMaps
@@ -57,7 +56,7 @@ func Casbin() *casbin.SyncedEnforcer {
once.Do(func() {
a, err := gormadapter.NewAdapterByDB(global.Db)
biz.ErrIsNil(err, "新建权限适配器失败")
syncedEnforcer, err = casbin.NewSyncedEnforcer(config.Conf.Casbin.ModelPath, a)
syncedEnforcer, err = casbin.NewSyncedEnforcer(global.Conf.Casbin.ModelPath, a)
biz.ErrIsNil(err, "新建权限适配器失败")
})
_ = syncedEnforcer.LoadPolicy()

View File

@@ -8,23 +8,19 @@ import (
"path/filepath"
)
// 配置文件映射对象
var Conf *Config
func init() {
configFilePath := flag.String("c", "./config.yml", "配置文件路径,默认为可执行文件目录")
flag.Parse()
func InitConfig(configFilePath string) *Config {
// 获取启动参数中,配置文件的绝对路径
path, _ := filepath.Abs(*configFilePath)
path, _ := filepath.Abs(configFilePath)
startConfigParam = &CmdConfigParam{ConfigFilePath: path}
// 读取配置文件信息
yc := &Config{}
if err := utils.LoadYml(startConfigParam.ConfigFilePath, yc); err != nil {
panic(fmt.Sprintf("读取配置文件[%s]失败: %s", startConfigParam.ConfigFilePath, err.Error()))
panic(any(fmt.Sprintf("读取配置文件[%s]失败: %s", startConfigParam.ConfigFilePath, err.Error())))
}
// 校验配置文件内容信息
yc.Valid()
Conf = yc
return yc
}
// 启动配置参数

View File

@@ -4,8 +4,8 @@ import (
"encoding/json"
"fmt"
"pandax/base/biz"
"pandax/base/logger"
"pandax/base/utils"
"pandax/pkg/global"
"reflect"
"runtime/debug"
@@ -42,10 +42,10 @@ func LogHandler(rc *ReqCtx) error {
lfs[req.Method] = req.URL.Path
if err := rc.Err; err != nil {
logger.Log.WithFields(lfs).Error(getErrMsg(rc, err))
global.Log.WithFields(lfs).Error(getErrMsg(rc, err))
return nil
}
logger.Log.WithFields(lfs).Info(getLogMsg(rc))
global.Log.WithFields(lfs).Info(getLogMsg(rc))
return nil
}

View File

@@ -1,8 +1,11 @@
package ctx
import (
"github.com/dgrijalva/jwt-go"
"pandax/base/biz"
"pandax/base/casbin"
"pandax/pkg/global"
"strconv"
)
type Permission struct {
@@ -34,7 +37,8 @@ func PermissionHandler(rc *ReqCtx) error {
if tokenStr == "" {
return biz.PermissionErr
}
loginAccount, err := ParseToken(tokenStr)
j := NewJWT("", []byte(global.Conf.Jwt.Key), jwt.SigningMethodHS256)
loginAccount, err := j.ParseToken(tokenStr)
if err != nil || loginAccount == nil {
return biz.PermissionErr
}
@@ -45,7 +49,8 @@ func PermissionHandler(rc *ReqCtx) error {
}
e := casbin.Casbin()
// 判断策略中是否存在
success, _ := e.Enforce(loginAccount.RoleKey, rc.GinCtx.Request.URL.Path, rc.GinCtx.Request.Method)
tenantId := strconv.Itoa(int(rc.LoginAccount.TenantId))
success, err := e.Enforce(tenantId, loginAccount.RoleKey, rc.GinCtx.Request.URL.Path, rc.GinCtx.Request.Method)
if !success {
return biz.CasbinErr
}

View File

@@ -2,17 +2,15 @@ package ctx
import (
"errors"
"fmt"
"github.com/dgrijalva/jwt-go"
"pandax/base/config"
"strings"
"time"
)
var (
jwtSecret = []byte(config.Conf.Jwt.Key)
)
type Claims struct {
UserId int64
TenantId int64
UserName string
RoleId int64
RoleKey string
@@ -21,43 +19,116 @@ type Claims struct {
jwt.StandardClaims
}
func CreateToken(claims Claims) (string, error) {
tokenClaims := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
token, err := tokenClaims.SignedString(jwtSecret)
return token, err
type JWT struct {
SignedKeyID string
SignedKey []byte
SignedMethod jwt.SigningMethod
}
func ParseToken(token string) (*Claims, error) {
tokenClaims, err := jwt.ParseWithClaims(token, &Claims{}, func(token *jwt.Token) (any, error) {
return jwtSecret, nil
})
var (
TokenExpired = errors.New("token is expired")
TokenNotValidYet = errors.New("token not active yet")
TokenMalformed = errors.New("that's not even a token")
TokenInvalid = errors.New("couldn't handle this token")
)
if tokenClaims != nil {
if claims, ok := tokenClaims.Claims.(*Claims); ok && tokenClaims.Valid {
return claims, nil
func NewJWT(kid string, key []byte, method jwt.SigningMethod) *JWT {
return &JWT{
SignedKeyID: kid,
SignedKey: key,
SignedMethod: method,
}
}
// CreateToken 创建一个token
func (j *JWT) CreateToken(claims Claims) (string, error) {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, &claims)
var key interface{}
if j.isEs() {
v, err := jwt.ParseECPrivateKeyFromPEM(j.SignedKey)
if err != nil {
return "", err
}
key = v
} else if j.isRsOrPS() {
v, err := jwt.ParseRSAPrivateKeyFromPEM(j.SignedKey)
if err != nil {
return "", err
}
key = v
} else if j.isHs() {
key = j.SignedKey
} else {
return "", errors.New("unsupported sign method")
}
return token.SignedString(key)
}
// ParseToken 解析 token
func (j *JWT) ParseToken(tokenString string) (*Claims, error) {
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (i interface{}, e error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("parse error")
}
return j.SignedKey, nil
})
if err != nil {
if ve, ok := err.(*jwt.ValidationError); ok {
if ve.Errors&jwt.ValidationErrorMalformed != 0 {
return nil, TokenMalformed
} else if ve.Errors&jwt.ValidationErrorExpired != 0 {
// Token is expired
return nil, TokenExpired
} else if ve.Errors&jwt.ValidationErrorNotValidYet != 0 {
return nil, TokenNotValidYet
} else {
return nil, TokenInvalid
}
}
}
if token != nil {
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
return claims, nil
}
return nil, TokenInvalid
} else {
return nil, TokenInvalid
}
return nil, err
}
// 更新token
func RefreshToken(tokenString string) (string, error) {
func (j *JWT) RefreshToken(tokenString string) (string, error) {
jwt.TimeFunc = func() time.Time {
return time.Unix(0, 0)
}
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (any, error) {
return jwtSecret, nil
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
return j.SignedKey, nil
})
if err != nil {
return "", err
}
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
jwt.TimeFunc = time.Now
claims.StandardClaims.ExpiresAt = time.Now().Add(1 * time.Hour).Unix()
return CreateToken(*claims)
claims.StandardClaims.ExpiresAt = time.Now().Unix() + 60*60*24*7
return j.CreateToken(*claims)
}
return "", errors.New("Couldn't handle this token:")
return "", TokenInvalid
}
func (a *JWT) isEs() bool {
return strings.HasPrefix(a.SignedMethod.Alg(), "ES")
}
func (a *JWT) isRsOrPS() bool {
isRs := strings.HasPrefix(a.SignedMethod.Alg(), "RS")
isPs := strings.HasPrefix(a.SignedMethod.Alg(), "PS")
return isRs || isPs
}
func (a *JWT) isHs() bool {
return strings.HasPrefix(a.SignedMethod.Alg(), "HS")
}

View File

@@ -4,8 +4,8 @@ import (
"encoding/json"
"net/http"
"pandax/base/biz"
"pandax/base/global"
"pandax/base/model"
"pandax/pkg/global"
"strconv"
"github.com/gin-gonic/gin"

View File

@@ -1,11 +0,0 @@
package global
import (
"github.com/sirupsen/logrus"
"gorm.io/gorm"
)
var (
Log *logrus.Logger // 日志
Db *gorm.DB // gorm
)

View File

@@ -4,24 +4,21 @@ import (
"fmt"
"os"
"pandax/base/config"
"pandax/base/global"
"strings"
"time"
"github.com/sirupsen/logrus"
)
var Log = logrus.New()
func init() {
func InitLog(logConf *config.Log) *logrus.Logger {
var Log = logrus.New()
Log.SetFormatter(new(LogFormatter))
Log.SetReportCaller(true)
logConf := config.Conf.Log
// 如果不存在日志配置信息则默认debug级别
if logConf == nil {
Log.SetLevel(logrus.DebugLevel)
return
return nil
}
// 根据配置文件设置日志级别
@@ -44,8 +41,7 @@ func init() {
Log.Out = file
}
global.Log = Log
return Log
}
type LogFormatter struct{}

View File

@@ -5,9 +5,10 @@ type AppContext struct {
type LoginAccount struct {
UserId int64
TenantId int64
RoleId int64
DeptId int64
PostId int64
Username string
Rolename string
RoleKey string
}

View File

@@ -3,7 +3,7 @@ package model
import (
"fmt"
"pandax/base/biz"
"pandax/base/global"
"pandax/pkg/global"
"strconv"
"strings"
@@ -48,7 +48,9 @@ func (m *Model) SetBaseInfo(account *LoginAccount) {
func Tx(funcs ...func(db *gorm.DB) error) (err error) {
tx := global.Db.Begin()
defer func() {
if r := recover(); r != nil {
var err any
err = recover()
if err != nil {
tx.Rollback()
err = fmt.Errorf("%v", err)
}

View File

@@ -6,8 +6,7 @@ import (
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"pandax/base/config"
"pandax/base/global"
"pandax/pkg/global"
"time"
_ "github.com/lib/pq"
@@ -23,7 +22,7 @@ func GormInit(ty string) *gorm.DB {
return nil
}
func GormMysql() *gorm.DB {
m := config.Conf.Mysql
m := global.Conf.Mysql
if m == nil || m.Dbname == "" {
global.Log.Panic("未找到数据库配置信息")
return nil
@@ -50,7 +49,7 @@ func GormMysql() *gorm.DB {
}
func GormPostgresql() *gorm.DB {
m := config.Conf.Postgresql
m := global.Conf.Postgresql
if m == nil || m.Dbname == "" {
global.Log.Panic("未找到数据库配置信息")
return nil

View File

@@ -2,15 +2,14 @@ package starter
import (
"fmt"
"pandax/base/config"
"pandax/base/global"
"pandax/pkg/global"
"github.com/go-redis/redis"
)
func ConnRedis() *redis.Client {
// 设置redis客户端
redisConf := config.Conf.Redis
redisConf := global.Conf.Redis
if redisConf == nil {
global.Log.Panic("未找到redis配置信息")
}

View File

@@ -1,16 +1,15 @@
package starter
import (
"pandax/base/config"
"pandax/base/global"
"pandax/pkg/global"
"github.com/gin-gonic/gin"
)
func RunWebServer(web *gin.Engine) {
server := config.Conf.Server
server := global.Conf.Server
port := server.GetPort()
if app := config.Conf.App; app != nil {
if app := global.Conf.App; app != nil {
global.Log.Infof("%s- Listening and serving HTTP on %s", app.GetAppInfo(), port)
} else {
global.Log.Infof("Listening and serving HTTP on %s", port)

View File

@@ -1,7 +1,6 @@
package utils
import (
"pandax/base/global"
"regexp"
"sync"
)
@@ -33,7 +32,6 @@ func GetRegexp(pattern string) (regex *regexp.Regexp, err error) {
// it compiles the pattern and creates one.
regex, err = regexp.Compile(pattern)
if err != nil {
global.Log.Warnf(`regexp.Compile failed for pattern "%s"`, pattern)
return
}
// Cache the result object using writing lock.

View File

@@ -3,7 +3,7 @@ package ws
import (
"encoding/json"
"net/http"
"pandax/base/global"
"pandax/pkg/global"
"time"
"github.com/gorilla/websocket"