Commit c7b788dd authored by Confusion-ymc's avatar Confusion-ymc

修改鉴权后 返回User对象 而不是dict

parent 85673748
......@@ -9,6 +9,7 @@ from exception.db import NotFundError
from model import Response
from model.fund import FundType, CreateFund, StakingFund, NormalFund, UpdateFund
from dependencies import get_current_user, get_fund_collect
from tools.jwt_tools import User
router = APIRouter()
......@@ -19,17 +20,16 @@ router = APIRouter()
description='创建基金')
async def create(
create_fund: CreateFund,
user: dict = Depends(get_current_user),
user: User = Depends(get_current_user),
fund_collect: AgnosticCollection = Depends(get_fund_collect)
):
if create_fund.fund_type == FundType.staking:
create_model = StakingFund(**create_fund.dict(), **user)
create_model = StakingFund(**create_fund.dict(), **user.dict())
response = Response[StakingFund](data=create_model.dict())
else:
create_model = NormalFund(**create_fund.dict(), **user)
create_model = NormalFund(**create_fund.dict(), **user.dict())
response = Response[NormalFund](data=create_model.dict())
insert_data = create_model.dict()
await fund_collect.insert_one(insert_data)
return response
......@@ -39,14 +39,14 @@ async def create(
async def update(
fund_id: str,
update_fund: UpdateFund,
user: dict = Depends(get_current_user),
user: User = Depends(get_current_user),
fund_collect: AgnosticCollection = Depends(get_fund_collect)
):
db_update_data = update_fund.dict(exclude_unset=True)
db_update_data.update({
"update_time": int(datetime.datetime.utcnow().timestamp())
})
data = await fund_collect.find_one_and_update({'id': fund_id, 'user_id': user['user_id']}, {'$set': db_update_data},
data = await fund_collect.find_one_and_update({'id': fund_id, 'user_id': user.id}, {'$set': db_update_data},
return_document=ReturnDocument.AFTER)
if data['fund_type'] == FundType.staking:
return Response[StakingFund](data=StakingFund(**data))
......@@ -58,10 +58,10 @@ async def update(
description='查询基金')
async def get(
fund_id: str,
user: dict = Depends(get_current_user),
user: User = Depends(get_current_user),
fund_collect: AgnosticCollection = Depends(get_fund_collect)
):
data = await fund_collect.find_one({'id': fund_id, 'user_id': user['user_id']})
data = await fund_collect.find_one({'id': fund_id, 'user_id': user.id})
if not data:
raise NotFundError()
if data['fund_type'] == FundType.staking:
......
......@@ -8,10 +8,12 @@ from db.mongodb_helper import AioMongodbManager
from tools import jwt_tools
from starlette.requests import Request
from tools.jwt_tools import User
def get_current_user(credentials: HTTPAuthorizationCredentials = Security(jwt_tools.security)) -> dict:
def get_current_user(credentials: HTTPAuthorizationCredentials = Security(jwt_tools.security)) -> User:
if settings.env == 'LOCAL':
return {'user_id': "local_test", 'user_email': "local_test@qq.com"}
return User(id='local_test', email='local_test@qq.com')
return jwt_tools.get_current_user(credentials)
......
......@@ -16,6 +16,38 @@ from tools.http_helper import aio_request
security = HTTPBearer()
class User(object):
def __init__(self, **kwargs):
self.nbf = None
self.exp = None
self.iss = None
self.aud = None
self.client_id = None
self.sub = None
self.auth_time = None
self.idp = None
self.id = None
self.email = None
self.role = None
self.FoundRole = None
self.scope = None
self.amr = None
for k, v in kwargs.items():
if hasattr(self, k):
self.__setattr__(k, v)
@property
def user_id(self):
return self.id
@property
def user_email(self):
return self.email
def dict(self):
return self.__dict__
async def get_identify_key():
"""
生成公钥
......@@ -40,16 +72,16 @@ def decode_token(token):
return payload
def get_current_user(credentials: HTTPAuthorizationCredentials = Security(security)) -> dict:
def get_current_user(credentials: HTTPAuthorizationCredentials = Security(security)) -> User:
token = credentials.credentials
try:
assert credentials.scheme == 'Bearer'
payload = decode_token(token) # options={'verify_signature':False}
payload['user_id'] = payload.pop("id", None)
payload['user_email'] = payload.pop("email", None)
if payload['user_id'] is None:
user = User()
user.__dict__ = payload
if not user.id:
raise TokenError('错误的Token')
return payload
return user
except ExpiredSignatureError:
raise TokenError('Token已过期')
except Exception as e:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment