Commit 4b3670fb authored by Confusion-ymc's avatar Confusion-ymc

优化结构

parent 36737064
import datetime
from typing import Union
from fastapi import APIRouter
from motor.core import AgnosticCollection
from fastapi import APIRouter, Depends
from pymongo import ReturnDocument
from starlette.requests import Request
from db.mongodb_helper import AioMongodbManager
from exception.db import NotFundError
from model import Response
from model.fund import FundType, CreateFund, StakingFund, NormalFund, UpdateFund
from tools.jwt_tools import get_current_user
from dependencies import get_current_user, get_mongodb_manager
router = APIRouter()
def get_mongodb_client(request, db='pyfund', collect='fund') -> AgnosticCollection:
collection: AgnosticCollection = request.app.state.mongodb_manger.get_client(name='pyfund', db=db, collect=collect)
return collection
@router.post('/', response_model=Union[Response[StakingFund], Response[NormalFund]], summary='创建基金',
@router.post('/',
response_model=Union[Response[StakingFund], Response[NormalFund]],
summary='创建基金',
description='创建基金')
async def create(create_fund: CreateFund, request: Request):
async def create(
create_fund: CreateFund,
user: dict = Depends(get_current_user),
mongodb_manger: AioMongodbManager = Depends(get_mongodb_manager)
):
if create_fund.fund_type == FundType.staking:
create_model = StakingFund(**create_fund.dict()) # **user_payload)
create_model = StakingFund(**create_fund.dict(), **user)
response = Response[StakingFund](data=create_model.dict())
else:
create_model = NormalFund(**create_fund.dict())
create_model = NormalFund(**create_fund.dict(), **user)
response = Response[NormalFund](data=create_model.dict())
collection = get_mongodb_client(request)
client = mongodb_manger.get_client(name='pyfund', db='pyfund', collect='fund')
insert_data = create_model.dict()
await collection.insert_one(insert_data)
await client.insert_one(insert_data)
return response
@router.put('/{fund_id}/', response_model=Union[Response[StakingFund], Response[NormalFund]], summary='更新基金',
description='更新基金')
async def update(fund_id: str, update_fund: UpdateFund, request: Request):
collection = get_mongodb_client(request)
async def update(
fund_id: str,
update_fund: UpdateFund,
user: dict = Depends(get_current_user),
mongodb_manger: AioMongodbManager = Depends(get_mongodb_manager)
):
client = mongodb_manger.get_client(name='pyfund', db='pyfund', collect='fund')
db_update_data = update_fund.dict(exclude_unset=True)
db_update_data.update({
"update_time": int(datetime.datetime.utcnow().timestamp())
})
data = await collection.find_one_and_update({'id': fund_id}, {'$set': db_update_data},
data = await client.find_one_and_update({'id': fund_id, 'user_id': user['user_id']}, {'$set': db_update_data},
return_document=ReturnDocument.AFTER)
if data['fund_type'] == FundType.staking:
return Response[StakingFund](data=StakingFund(**data))
......@@ -53,10 +59,13 @@ async def update(fund_id: str, update_fund: UpdateFund, request: Request):
@router.get('/{fund_id}/', response_model=Union[Response[StakingFund], Response[NormalFund]], summary='查询基金',
description='查询基金')
async def get(fund_id: str, request: Request):
collection = get_mongodb_client(request)
# data = await collection.find_one({'id': fund_id, 'user_id': user_payload['user_id']})
data = await collection.find_one({'id': fund_id})
async def get(
fund_id: str,
user: dict = Depends(get_current_user),
mongodb_manger: AioMongodbManager = Depends(get_mongodb_manager)
):
client = mongodb_manger.get_client(name='pyfund', db='pyfund', collect='fund')
data = await client.find_one({'id': fund_id, 'user_id': user['user_id']})
if not data:
raise NotFundError()
if data['fund_type'] == FundType.staking:
......
......@@ -6,14 +6,14 @@ from loguru import logger
from motor.core import AgnosticCollection
from motor.motor_asyncio import AsyncIOMotorClient
from configs import settings
class AioMongodbManager:
def __init__(self):
self.mongodb_pool: Dict[str, AsyncIOMotorClient] = {}
def setup_pool(self, mongodb_url, name: str = None):
# addr = mongodb_url.split("@")[1]
# name = name or addr
if name not in self.mongodb_pool:
logger.debug(f'新创建Mongodb连接池 [{mongodb_url}] [{name}]')
else:
......@@ -27,3 +27,9 @@ class AioMongodbManager:
return self.mongodb_pool[name][db][collect].with_options(
codec_options=CodecOptions(tz_aware=True, tzinfo=pytz.UTC))
def register_mongodb(app):
mongodb_manger = AioMongodbManager()
mongodb_manger.setup_pool(settings.mongodb, 'pyfund')
app.state.mongodb_manger = mongodb_manger
from fastapi import Security
from fastapi.security import HTTPAuthorizationCredentials
from configs import settings
from db.mongodb_helper import AioMongodbManager
from tools import jwt_tools
from starlette.requests import Request
def get_current_user(credentials: HTTPAuthorizationCredentials = Security(jwt_tools.security)) -> dict:
if settings.env == 'LOCAL':
return {'user_id': "local_test", 'user_email': "local_test@qq.com"}
return jwt_tools.get_current_user(credentials)
def get_mongodb_manager(request: Request) -> AioMongodbManager:
return request.app.state.mongodb_manger
......@@ -5,13 +5,14 @@ from loguru import logger
class MyException(Exception):
default_error = '系统错误'
status_code = 400
message = '系统错误'
status = 400
def __init__(self, message: Optional[str] = None):
def __init__(self, message: Optional[str] = None, status: Optional[int] = None):
if not message:
logger.warning(traceback.format_exc())
self.message = message or self.default_error
self.message = message or self.message
self.status = status or self.status
def __str__(self):
return self.message
......@@ -5,6 +5,6 @@ from exception import MyException
class NotFundError(MyException):
status_code = status.HTTP_404_NOT_FOUND
default_error = '未找到数据'
status = status.HTTP_404_NOT_FOUND
message = '未找到数据'
import traceback
from typing import Union
import uvicorn as uvicorn
from fastapi import FastAPI
......@@ -7,10 +6,11 @@ from fastapi.exceptions import RequestValidationError
from loguru import logger
from starlette import status
from starlette.requests import Request
from starlette.responses import JSONResponse
from api import api_router
from configs import settings
from db.mongodb_helper import AioMongodbManager
from db.mongodb_helper import register_mongodb
from exception import MyException
from model import ErrorResponse
from tools.jwt_tools import get_identify_key
......@@ -24,16 +24,10 @@ else:
app = FastAPI(docs_url='/swagger', openapi_prefix=openapi_prefix, debug=debug)
mongodb_manger = AioMongodbManager()
mongodb_manger.setup_pool(settings.mongodb, 'pyfund')
app.state.mongodb_manger = mongodb_manger
# 添加路由
app.include_router(api_router)
@app.exception_handler(MyException)
async def not_fund_exception_handler(request: Request, exc: MyException):
return ErrorResponse(message=str(exc), status_code=exc.status_code)
return JSONResponse(ErrorResponse(message=str(exc), status=exc.status).dict())
@app.exception_handler(RequestValidationError)
......@@ -45,19 +39,25 @@ async def request_validation_exception_handler(request: Request, exc: RequestVal
:return:
"""
# 日志记录异常详细上下文
return ErrorResponse(message='参数错误 ' + str(exc), status_code=status.HTTP_400_BAD_REQUEST)
return JSONResponse(ErrorResponse(message='参数错误 ' + str(exc), status=status.HTTP_400_BAD_REQUEST).dict())
@app.exception_handler(Exception)
async def sys_exception_handler(request: Request, exc: Exception):
logger.error(f"全局异\n{request.method}URL{request.url}\nHeaders:{request.headers}\n{traceback.format_exc()}")
return ErrorResponse(message='系统异常' + f' {str(exc)}' if settings.name in ['本地环境', "测试环境"] else '',
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
return JSONResponse(
ErrorResponse(message='系统异常' + f' {str(exc)}' if settings.name in ['本地环境', "测试环境"] else '',
status=status.HTTP_500_INTERNAL_SERVER_ERROR).dict())
@app.on_event('startup')
async def startup():
# 鉴权中心获取公钥
await get_identify_key()
# 挂载 mongodb
register_mongodb(app)
# 添加路由
app.include_router(api_router)
if __name__ == '__main__':
......
......@@ -26,7 +26,7 @@ class BaseResponse(BaseModel):
status: int = 200
class ErrorResponse(BaseResponse):
class ErrorResponse(BaseModel):
data: Any
message: str = 'failed'
status: int = 500
......
......@@ -27,10 +27,13 @@ class CreateFund(BaseFundItem):
# 传入数据库类型
class NormalFund(BaseFundItem, BaseCreateModel):
pass
user_id: str
user_email: str
class StakingFund(BaseFundItem, BaseCreateModel):
user_id: str
user_email: str
nodes: List[BaseNode]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment