Commit 36737064 authored by Confusion-ymc's avatar Confusion-ymc

优化代码

parent 71d892a7
import datetime
from typing import Union from typing import Union
from fastapi import APIRouter, Depends from fastapi import APIRouter
from motor.core import AgnosticCollection from motor.core import AgnosticCollection
from pymongo import ReturnDocument from pymongo import ReturnDocument
from starlette.requests import Request from starlette.requests import Request
from exception.db import NotFundError from exception.db import NotFundError
from model import success_res from model import Response
from model.fund import ApiCreateFund, DBCreateStakingFund, response_staking_fund_model, response_fund_model, FundType, \ from model.fund import FundType, CreateFund, StakingFund, NormalFund, UpdateFund
DBCreateFund, ApiUpdateFund
from tools.jwt_tools import get_current_user from tools.jwt_tools import get_current_user
router = APIRouter() router = APIRouter()
...@@ -19,30 +19,39 @@ def get_mongodb_client(request, db='pyfund', collect='fund') -> AgnosticCollecti ...@@ -19,30 +19,39 @@ def get_mongodb_client(request, db='pyfund', collect='fund') -> AgnosticCollecti
return collection return collection
@router.post('/', response_model=Union[response_staking_fund_model, response_fund_model], summary='创建基金', @router.post('/', response_model=Union[Response[StakingFund], Response[NormalFund]], summary='创建基金',
description='创建基金') description='创建基金')
async def create(fund: ApiCreateFund, request: Request): async def create(create_fund: CreateFund, request: Request):
if fund.fund_type == FundType.staking: if create_fund.fund_type == FundType.staking:
# add_fund = DBCreateStakingFund(**fund.dict(), **user_payload) create_model = StakingFund(**create_fund.dict()) # **user_payload)
add_fund = DBCreateStakingFund(**fund.dict()) response = Response[StakingFund](data=create_model.dict())
else: else:
# add_fund = DBCreateFund.from_orm(**fund.dict(), **user_payload) create_model = NormalFund(**create_fund.dict())
add_fund = DBCreateFund(**fund.dict()) response = Response[NormalFund](data=create_model.dict())
collection = get_mongodb_client(request) collection = get_mongodb_client(request)
await collection.insert_one(add_fund.dict()) insert_data = create_model.dict()
return success_res(data=add_fund)
await collection.insert_one(insert_data)
return response
@router.put('/{fund_id}/', response_model=Union[response_fund_model, response_staking_fund_model], summary='更新基金',
@router.put('/{fund_id}/', response_model=Union[Response[StakingFund], Response[NormalFund]], summary='更新基金',
description='更新基金') description='更新基金')
async def update(fund_id: str, update_fund: ApiUpdateFund, request: Request): async def update(fund_id: str, update_fund: UpdateFund, request: Request):
collection = get_mongodb_client(request) collection = get_mongodb_client(request)
data = await collection.find_one_and_update({'id': fund_id}, {'$set': update_fund.dict(exclude_unset=True)}, db_update_data = update_fund.dict(exclude_unset=True)
db_update_data.update({
"update_time": int(datetime.datetime.utcnow().timestamp())
})
data = await collection.find_one_and_update({'id': fund_id}, {'$set': db_update_data},
return_document=ReturnDocument.AFTER) return_document=ReturnDocument.AFTER)
return success_res(data=data) if data['fund_type'] == FundType.staking:
return Response[StakingFund](data=StakingFund(**data))
else:
return Response[NormalFund](data=NormalFund(**data))
@router.get('/{fund_id}/', response_model=Union[response_staking_fund_model], summary='查询基金', @router.get('/{fund_id}/', response_model=Union[Response[StakingFund], Response[NormalFund]], summary='查询基金',
description='查询基金') description='查询基金')
async def get(fund_id: str, request: Request): async def get(fund_id: str, request: Request):
collection = get_mongodb_client(request) collection = get_mongodb_client(request)
...@@ -51,7 +60,7 @@ async def get(fund_id: str, request: Request): ...@@ -51,7 +60,7 @@ async def get(fund_id: str, request: Request):
if not data: if not data:
raise NotFundError() raise NotFundError()
if data['fund_type'] == FundType.staking: if data['fund_type'] == FundType.staking:
res = DBCreateStakingFund(**data) response = Response[StakingFund](data=StakingFund(**data).dict())
else: else:
res = DBCreateFund(**data) response = Response[NormalFund](data=NormalFund(**data).dict())
return success_res(data=res) return response
from exception import MyException from exception import MyException
class ReqException(MyException): class RequestHttpException(MyException):
pass pass
from exception import MyException from exception import MyException
class RequestPubKeyError(MyException):
pass
class TokenError(MyException): class TokenError(MyException):
pass pass
import json
import traceback import traceback
from typing import Union
import httpx as httpx
import uvicorn as uvicorn import uvicorn as uvicorn
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from jwt.algorithms import get_default_algorithms
from loguru import logger from loguru import logger
from starlette import status from starlette import status
from starlette.requests import Request from starlette.requests import Request
import configs
from api import api_router from api import api_router
from configs import settings from configs import settings
from db.mongodb_helper import AioMongodbManager from db.mongodb_helper import AioMongodbManager
from exception import MyException from exception import MyException
from model import error_response from model import ErrorResponse
from tools.jwt_tools import get_identify_key from tools.jwt_tools import get_identify_key
app = FastAPI() if settings.env != 'LOCAL':
openapi_prefix = '/coinsdataapiv2'
debug = False
else:
openapi_prefix = ''
debug = True
app = FastAPI(docs_url='/swagger', openapi_prefix=openapi_prefix, debug=debug)
mongodb_manger = AioMongodbManager() mongodb_manger = AioMongodbManager()
mongodb_manger.setup_pool(settings.mongodb, 'pyfund') mongodb_manger.setup_pool(settings.mongodb, 'pyfund')
...@@ -29,7 +33,7 @@ app.include_router(api_router) ...@@ -29,7 +33,7 @@ app.include_router(api_router)
@app.exception_handler(MyException) @app.exception_handler(MyException)
async def not_fund_exception_handler(request: Request, exc: MyException): async def not_fund_exception_handler(request: Request, exc: MyException):
return error_response(str(exc), status_code=exc.status_code) return ErrorResponse(message=str(exc), status_code=exc.status_code)
@app.exception_handler(RequestValidationError) @app.exception_handler(RequestValidationError)
...@@ -41,13 +45,13 @@ async def request_validation_exception_handler(request: Request, exc: RequestVal ...@@ -41,13 +45,13 @@ async def request_validation_exception_handler(request: Request, exc: RequestVal
:return: :return:
""" """
# 日志记录异常详细上下文 # 日志记录异常详细上下文
return error_response('参数错误 ' + str(exc), status_code=status.HTTP_400_BAD_REQUEST) return ErrorResponse(message='参数错误 ' + str(exc), status_code=status.HTTP_400_BAD_REQUEST)
@app.exception_handler(Exception) @app.exception_handler(Exception)
async def sys_exception_handler(request: Request, exc: Exception): async def sys_exception_handler(request: Request, exc: Exception):
logger.error(f"全局异\n{request.method}URL{request.url}\nHeaders:{request.headers}\n{traceback.format_exc()}") logger.error(f"全局异\n{request.method}URL{request.url}\nHeaders:{request.headers}\n{traceback.format_exc()}")
return error_response('系统异常' + f' {str(exc)}' if settings.name in ['本地环境', "测试环境"] else '', return ErrorResponse(message='系统异常' + f' {str(exc)}' if settings.name in ['本地环境', "测试环境"] else '',
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
......
import datetime import datetime
import json
import uuid import uuid
from typing import Any, List, Optional, TypeVar, Generic from typing import Any, TypeVar, Generic
from pydantic import BaseModel, Field, validator from pydantic import BaseModel, Field
from pydantic.generics import GenericModel from pydantic.generics import GenericModel
from starlette import status from starlette import status
from starlette.responses import JSONResponse
DataT = TypeVar('DataT') DataT = TypeVar('DataT')
...@@ -19,71 +17,19 @@ class DataModel(BaseModel): ...@@ -19,71 +17,19 @@ class DataModel(BaseModel):
class Response(GenericModel, Generic[DataT]): class Response(GenericModel, Generic[DataT]):
data: DataT | None data: DataT | None
message: str = 'success' message: str = 'success'
status: int status: int = status.HTTP_200_OK
class BaseJsonResponse(BaseModel): class BaseResponse(BaseModel):
data: Any data: Any
message: str = 'success' message: str = 'success'
status: int = 200 status: int = 200
class BasePageJsonResponse(BaseJsonResponse): class ErrorResponse(BaseResponse):
page_size: Optional[int] data: Any
page_num: Optional[int] message: str = 'failed'
total: Optional[int] status: int = 500
def dynamic_response(data_type):
class DyResponse(BaseJsonResponse):
data: Optional[data_type]
dy_response_model = type('Response' + data_type.__name__, (DyResponse,), {})
return dy_response_model
def list_dynamic_response(data_type):
class DyResponse(BaseJsonResponse):
data: Optional[List[data_type]]
dy_response_model = type('ListResponse' + data_type.__name__, (DyResponse,), {})
return dy_response_model
def page_dynamic_response(data_type):
class DyPageResponse(BasePageJsonResponse):
data: Optional[List[data_type]]
page_size: Optional[int]
page_num: Optional[int]
total: Optional[int]
dy_response_model = type('PageResponse' + data_type.__name__, (DyPageResponse,), {})
return dy_response_model
def success_res(data=None, message='success', status_code=status.HTTP_200_OK, total=None, page_num=None,
page_size=None):
res = {
"data": data,
"message": message,
"status": status_code
}
if total is not None:
res.update({
"total": total,
"page_num": page_num,
"page_size": page_size
})
return res
def error_response(message='failed', status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, data=None):
res = {
"data": data,
"message": message,
"status": status_code
}
return JSONResponse(content=res, status_code=status_code)
class BaseCreateModel(BaseModel): class BaseCreateModel(BaseModel):
...@@ -95,11 +41,3 @@ class BaseCreateModel(BaseModel): ...@@ -95,11 +41,3 @@ class BaseCreateModel(BaseModel):
class Config: class Config:
orm_mode = True orm_mode = True
class BaseUpdateModel(BaseModel):
update_time: int = Field(default_factory=lambda: int(datetime.datetime.utcnow().timestamp()),
description='更新时间')
class Config:
orm_mode = True
...@@ -3,38 +3,38 @@ from typing import List, Optional ...@@ -3,38 +3,38 @@ from typing import List, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from model import dynamic_response, BaseCreateModel, BaseUpdateModel, Response from model import BaseCreateModel
from model.node import DBNode from model.node import BaseNode
class FundType(str, Enum): class FundType(str, Enum):
staking = 'staking' staking = 'staking'
other = 'other' normal = 'normal'
class ApiCreateFund(BaseModel): class BaseFundItem(BaseModel):
name: str = Field(..., description='基金名称') name: str = Field(..., description='基金名称')
fund_type: FundType = Field(default=FundType.staking, description='基金类型') fund_type: str = Field(default=FundType.staking.value, description='基金类型')
base_coin: str = Field(default='USD', description='基准币种') base_coin: str = Field(default='USD', description='基准币种')
base_nav: float = Field(default=1, description='初始净值') base_nav: float = Field(default=1, description='初始净值')
settlement_time: str = Field(default='08:00', description='结算时间') settlement_time: str = Field(default='08:00', description='结算时间')
# nodes: List[DBNode] = Field(default=[], description='绑定节点')
class Config:
orm_mode = True
# 接口传入模型
class CreateFund(BaseFundItem):
nodes: List[BaseNode] = Field(default=[], description='绑定节点')
class DBCreateFund(ApiCreateFund, BaseCreateModel):
user_id: str = Field(None, description='创建人')
user_email: str = Field(None, description='创建人')
nav: float = Field(default=1, description='当前净值')
# 传入数据库类型
class NormalFund(BaseFundItem, BaseCreateModel):
pass
class DBCreateStakingFund(DBCreateFund):
nodes: List[DBNode] = Field(default=[], description='绑定节点')
class StakingFund(BaseFundItem, BaseCreateModel):
nodes: List[BaseNode]
class ApiUpdateFund(BaseModel):
class UpdateFund(BaseModel):
name: Optional[str] = Field(None, description='基金名称') name: Optional[str] = Field(None, description='基金名称')
fund_type: Optional[FundType] = Field(default=None, description='基金类型') fund_type: Optional[FundType] = Field(default=None, description='基金类型')
base_coin: Optional[str] = Field(None, description='基准币种') base_coin: Optional[str] = Field(None, description='基准币种')
...@@ -43,14 +43,3 @@ class ApiUpdateFund(BaseModel): ...@@ -43,14 +43,3 @@ class ApiUpdateFund(BaseModel):
class Config: class Config:
orm_mode = True orm_mode = True
class DBUpdateFund(ApiUpdateFund, BaseUpdateModel):
pass
# response_fund_model = dynamic_response(DBCreateFund)
response_fund_model = Response[DBCreateFund]
response_staking_fund_model = Response[DBCreateStakingFund]
# response_staking_fund_model = dynamic_response(DBCreateStakingFund)
# response_list_branch_model = list_dynamic_response(ResBranch)
from enum import Enum from enum import Enum
from pydantic import Field, BaseModel from pydantic import Field
from model import BaseCreateModel from model import BaseCreateModel
...@@ -11,13 +11,6 @@ class NodeStatus(str, Enum): ...@@ -11,13 +11,6 @@ class NodeStatus(str, Enum):
stop = 'stop' stop = 'stop'
class ApiCreateNode(BaseModel): class BaseNode(BaseCreateModel):
pub_key: str = Field(..., description='创建人') pub_key: str = Field(..., description='创建人')
status: NodeStatus = Field(default=NodeStatus.pending, description='创建人') status: NodeStatus = Field(default=NodeStatus.pending, description='创建人')
class DBNode(BaseCreateModel, ApiCreateNode):
pass
# response_node_model = dynamic_response(DBCreateFund)
# response_list_branch_model = list_dynamic_response(ResBranch)
import json import json
import aiohttp as aiohttp import httpx
from aiohttp import ClientTimeout
from loguru import logger from loguru import logger
from exception.http import ReqException from exception.http import RequestHttpException
async def aio_request(url, method='GET', **kwargs): async def aio_request(url, method='GET', **kwargs):
headers = {
'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/93.0.4577.63 Safari/537.36',
'Content-Type': 'application/json',
'Accept': 'application/json'
}
headers.update(kwargs.pop('headers', {}))
method = method.upper()
logger.info(f"[请求内容] url={url}")
try: try:
async with aiohttp.ClientSession(headers=headers, timeout=ClientTimeout(total=120)) as session: async with httpx.AsyncClient() as client:
async with session.request(url=url, method=method, ssl=False, **kwargs) as r: method = method.upper()
logger.info(f"[返回状态] url={url}, status={r.status}") response = await client.request(method=method, url=url, **kwargs)
json_body = await r.json() content = response.content
return json_body res = json.loads(content)
logger.info(f'请求成功 [{method}] [{url}]')
return res
except Exception as e: except Exception as e:
raise ReqException(f'请求失败 url: {url} :: {e}') logger.error(f'请求失败 [{method}] [{url}] [{e}]')
\ No newline at end of file raise RequestHttpException(message=str(e))
import json import json
from pathlib import Path
from urllib.parse import urlparse from urllib.parse import urlparse
import httpx
import jwt import jwt
from fastapi import Security from fastapi import Security
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
...@@ -12,7 +10,8 @@ from jwt.algorithms import get_default_algorithms ...@@ -12,7 +10,8 @@ from jwt.algorithms import get_default_algorithms
from loguru import logger from loguru import logger
from configs import settings from configs import settings
from exception.token import RequestPubKeyError, TokenError from exception.token import TokenError
from tools.http_helper import aio_request
security = HTTPBearer() security = HTTPBearer()
...@@ -23,20 +22,13 @@ async def get_identify_key(): ...@@ -23,20 +22,13 @@ async def get_identify_key():
:return: :return:
""" """
# 请求key # 请求key
content = '' res = await aio_request(settings.identify_jwt)
try:
async with httpx.AsyncClient() as client:
response = await client.get(settings.identify_jwt)
content = response.content
res = json.loads(content)
print('公钥获取成功')
key_data = res['keys'][0] key_data = res['keys'][0]
rsa = get_default_algorithms()[key_data['alg']] rsa = get_default_algorithms()[key_data['alg']]
public_key = rsa.from_jwk(json.dumps(key_data)) public_key = rsa.from_jwk(json.dumps(key_data))
settings.public_key = public_key settings.public_key = public_key
settings.algorithms = key_data['alg'] settings.algorithms = key_data['alg']
except Exception as e: logger.info('公钥获取成功')
raise RequestPubKeyError(f'公钥获取失败, url:{settings.identify_jwt}, res:{content} {e}')
def decode_token(token): def decode_token(token):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment