Commit c7b788dd authored by Confusion-ymc's avatar Confusion-ymc

修改鉴权后 返回User对象 而不是dict

parent 85673748
...@@ -9,6 +9,7 @@ from exception.db import NotFundError ...@@ -9,6 +9,7 @@ from exception.db import NotFundError
from model import Response from model import Response
from model.fund import FundType, CreateFund, StakingFund, NormalFund, UpdateFund from model.fund import FundType, CreateFund, StakingFund, NormalFund, UpdateFund
from dependencies import get_current_user, get_fund_collect from dependencies import get_current_user, get_fund_collect
from tools.jwt_tools import User
router = APIRouter() router = APIRouter()
...@@ -19,17 +20,16 @@ router = APIRouter() ...@@ -19,17 +20,16 @@ router = APIRouter()
description='创建基金') description='创建基金')
async def create( async def create(
create_fund: CreateFund, create_fund: CreateFund,
user: dict = Depends(get_current_user), user: User = Depends(get_current_user),
fund_collect: AgnosticCollection = Depends(get_fund_collect) fund_collect: AgnosticCollection = Depends(get_fund_collect)
): ):
if create_fund.fund_type == FundType.staking: if create_fund.fund_type == FundType.staking:
create_model = StakingFund(**create_fund.dict(), **user) create_model = StakingFund(**create_fund.dict(), **user.dict())
response = Response[StakingFund](data=create_model.dict()) response = Response[StakingFund](data=create_model.dict())
else: else:
create_model = NormalFund(**create_fund.dict(), **user) create_model = NormalFund(**create_fund.dict(), **user.dict())
response = Response[NormalFund](data=create_model.dict()) response = Response[NormalFund](data=create_model.dict())
insert_data = create_model.dict() insert_data = create_model.dict()
await fund_collect.insert_one(insert_data) await fund_collect.insert_one(insert_data)
return response return response
...@@ -39,14 +39,14 @@ async def create( ...@@ -39,14 +39,14 @@ async def create(
async def update( async def update(
fund_id: str, fund_id: str,
update_fund: UpdateFund, update_fund: UpdateFund,
user: dict = Depends(get_current_user), user: User = Depends(get_current_user),
fund_collect: AgnosticCollection = Depends(get_fund_collect) fund_collect: AgnosticCollection = Depends(get_fund_collect)
): ):
db_update_data = update_fund.dict(exclude_unset=True) db_update_data = update_fund.dict(exclude_unset=True)
db_update_data.update({ db_update_data.update({
"update_time": int(datetime.datetime.utcnow().timestamp()) "update_time": int(datetime.datetime.utcnow().timestamp())
}) })
data = await fund_collect.find_one_and_update({'id': fund_id, 'user_id': user['user_id']}, {'$set': db_update_data}, data = await fund_collect.find_one_and_update({'id': fund_id, 'user_id': user.id}, {'$set': db_update_data},
return_document=ReturnDocument.AFTER) return_document=ReturnDocument.AFTER)
if data['fund_type'] == FundType.staking: if data['fund_type'] == FundType.staking:
return Response[StakingFund](data=StakingFund(**data)) return Response[StakingFund](data=StakingFund(**data))
...@@ -58,10 +58,10 @@ async def update( ...@@ -58,10 +58,10 @@ async def update(
description='查询基金') description='查询基金')
async def get( async def get(
fund_id: str, fund_id: str,
user: dict = Depends(get_current_user), user: User = Depends(get_current_user),
fund_collect: AgnosticCollection = Depends(get_fund_collect) fund_collect: AgnosticCollection = Depends(get_fund_collect)
): ):
data = await fund_collect.find_one({'id': fund_id, 'user_id': user['user_id']}) data = await fund_collect.find_one({'id': fund_id, 'user_id': user.id})
if not data: if not data:
raise NotFundError() raise NotFundError()
if data['fund_type'] == FundType.staking: if data['fund_type'] == FundType.staking:
......
...@@ -8,10 +8,12 @@ from db.mongodb_helper import AioMongodbManager ...@@ -8,10 +8,12 @@ from db.mongodb_helper import AioMongodbManager
from tools import jwt_tools from tools import jwt_tools
from starlette.requests import Request from starlette.requests import Request
from tools.jwt_tools import User
def get_current_user(credentials: HTTPAuthorizationCredentials = Security(jwt_tools.security)) -> dict:
def get_current_user(credentials: HTTPAuthorizationCredentials = Security(jwt_tools.security)) -> User:
if settings.env == 'LOCAL': if settings.env == 'LOCAL':
return {'user_id': "local_test", 'user_email': "local_test@qq.com"} return User(id='local_test', email='local_test@qq.com')
return jwt_tools.get_current_user(credentials) return jwt_tools.get_current_user(credentials)
......
...@@ -16,6 +16,38 @@ from tools.http_helper import aio_request ...@@ -16,6 +16,38 @@ from tools.http_helper import aio_request
security = HTTPBearer() security = HTTPBearer()
class User(object):
def __init__(self, **kwargs):
self.nbf = None
self.exp = None
self.iss = None
self.aud = None
self.client_id = None
self.sub = None
self.auth_time = None
self.idp = None
self.id = None
self.email = None
self.role = None
self.FoundRole = None
self.scope = None
self.amr = None
for k, v in kwargs.items():
if hasattr(self, k):
self.__setattr__(k, v)
@property
def user_id(self):
return self.id
@property
def user_email(self):
return self.email
def dict(self):
return self.__dict__
async def get_identify_key(): async def get_identify_key():
""" """
生成公钥 生成公钥
...@@ -40,16 +72,16 @@ def decode_token(token): ...@@ -40,16 +72,16 @@ def decode_token(token):
return payload return payload
def get_current_user(credentials: HTTPAuthorizationCredentials = Security(security)) -> dict: def get_current_user(credentials: HTTPAuthorizationCredentials = Security(security)) -> User:
token = credentials.credentials token = credentials.credentials
try: try:
assert credentials.scheme == 'Bearer' assert credentials.scheme == 'Bearer'
payload = decode_token(token) # options={'verify_signature':False} payload = decode_token(token) # options={'verify_signature':False}
payload['user_id'] = payload.pop("id", None) user = User()
payload['user_email'] = payload.pop("email", None) user.__dict__ = payload
if payload['user_id'] is None: if not user.id:
raise TokenError('错误的Token') raise TokenError('错误的Token')
return payload return user
except ExpiredSignatureError: except ExpiredSignatureError:
raise TokenError('Token已过期') raise TokenError('Token已过期')
except Exception as e: except Exception as e:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment