Commit 36737064 authored by Confusion-ymc's avatar Confusion-ymc

优化代码

parent 71d892a7
import datetime
from typing import Union
from fastapi import APIRouter, Depends
from fastapi import APIRouter
from motor.core import AgnosticCollection
from pymongo import ReturnDocument
from starlette.requests import Request
from exception.db import NotFundError
from model import success_res
from model.fund import ApiCreateFund, DBCreateStakingFund, response_staking_fund_model, response_fund_model, FundType, \
DBCreateFund, ApiUpdateFund
from model import Response
from model.fund import FundType, CreateFund, StakingFund, NormalFund, UpdateFund
from tools.jwt_tools import get_current_user
router = APIRouter()
......@@ -19,30 +19,39 @@ def get_mongodb_client(request, db='pyfund', collect='fund') -> AgnosticCollecti
return collection
@router.post('/', response_model=Union[response_staking_fund_model, response_fund_model], summary='创建基金',
@router.post('/', response_model=Union[Response[StakingFund], Response[NormalFund]], summary='创建基金',
description='创建基金')
async def create(fund: ApiCreateFund, request: Request):
if fund.fund_type == FundType.staking:
# add_fund = DBCreateStakingFund(**fund.dict(), **user_payload)
add_fund = DBCreateStakingFund(**fund.dict())
async def create(create_fund: CreateFund, request: Request):
if create_fund.fund_type == FundType.staking:
create_model = StakingFund(**create_fund.dict()) # **user_payload)
response = Response[StakingFund](data=create_model.dict())
else:
# add_fund = DBCreateFund.from_orm(**fund.dict(), **user_payload)
add_fund = DBCreateFund(**fund.dict())
create_model = NormalFund(**create_fund.dict())
response = Response[NormalFund](data=create_model.dict())
collection = get_mongodb_client(request)
await collection.insert_one(add_fund.dict())
return success_res(data=add_fund)
insert_data = create_model.dict()
await collection.insert_one(insert_data)
return response
@router.put('/{fund_id}/', response_model=Union[response_fund_model, response_staking_fund_model], summary='更新基金',
@router.put('/{fund_id}/', response_model=Union[Response[StakingFund], Response[NormalFund]], summary='更新基金',
description='更新基金')
async def update(fund_id: str, update_fund: ApiUpdateFund, request: Request):
async def update(fund_id: str, update_fund: UpdateFund, request: Request):
collection = get_mongodb_client(request)
data = await collection.find_one_and_update({'id': fund_id}, {'$set': update_fund.dict(exclude_unset=True)},
db_update_data = update_fund.dict(exclude_unset=True)
db_update_data.update({
"update_time": int(datetime.datetime.utcnow().timestamp())
})
data = await collection.find_one_and_update({'id': fund_id}, {'$set': db_update_data},
return_document=ReturnDocument.AFTER)
return success_res(data=data)
if data['fund_type'] == FundType.staking:
return Response[StakingFund](data=StakingFund(**data))
else:
return Response[NormalFund](data=NormalFund(**data))
@router.get('/{fund_id}/', response_model=Union[response_staking_fund_model], summary='查询基金',
@router.get('/{fund_id}/', response_model=Union[Response[StakingFund], Response[NormalFund]], summary='查询基金',
description='查询基金')
async def get(fund_id: str, request: Request):
collection = get_mongodb_client(request)
......@@ -51,7 +60,7 @@ async def get(fund_id: str, request: Request):
if not data:
raise NotFundError()
if data['fund_type'] == FundType.staking:
res = DBCreateStakingFund(**data)
response = Response[StakingFund](data=StakingFund(**data).dict())
else:
res = DBCreateFund(**data)
return success_res(data=res)
response = Response[NormalFund](data=NormalFund(**data).dict())
return response
from exception import MyException
class ReqException(MyException):
class RequestHttpException(MyException):
pass
from exception import MyException
class RequestPubKeyError(MyException):
pass
class TokenError(MyException):
pass
import json
import traceback
from typing import Union
import httpx as httpx
import uvicorn as uvicorn
from fastapi import FastAPI
from fastapi.exceptions import RequestValidationError
from jwt.algorithms import get_default_algorithms
from loguru import logger
from starlette import status
from starlette.requests import Request
import configs
from api import api_router
from configs import settings
from db.mongodb_helper import AioMongodbManager
from exception import MyException
from model import error_response
from model import ErrorResponse
from tools.jwt_tools import get_identify_key
app = FastAPI()
if settings.env != 'LOCAL':
openapi_prefix = '/coinsdataapiv2'
debug = False
else:
openapi_prefix = ''
debug = True
app = FastAPI(docs_url='/swagger', openapi_prefix=openapi_prefix, debug=debug)
mongodb_manger = AioMongodbManager()
mongodb_manger.setup_pool(settings.mongodb, 'pyfund')
......@@ -29,7 +33,7 @@ app.include_router(api_router)
@app.exception_handler(MyException)
async def not_fund_exception_handler(request: Request, exc: MyException):
return error_response(str(exc), status_code=exc.status_code)
return ErrorResponse(message=str(exc), status_code=exc.status_code)
@app.exception_handler(RequestValidationError)
......@@ -41,13 +45,13 @@ async def request_validation_exception_handler(request: Request, exc: RequestVal
:return:
"""
# 日志记录异常详细上下文
return error_response('参数错误 ' + str(exc), status_code=status.HTTP_400_BAD_REQUEST)
return ErrorResponse(message='参数错误 ' + str(exc), status_code=status.HTTP_400_BAD_REQUEST)
@app.exception_handler(Exception)
async def sys_exception_handler(request: Request, exc: Exception):
logger.error(f"全局异\n{request.method}URL{request.url}\nHeaders:{request.headers}\n{traceback.format_exc()}")
return error_response('系统异常' + f' {str(exc)}' if settings.name in ['本地环境', "测试环境"] else '',
return ErrorResponse(message='系统异常' + f' {str(exc)}' if settings.name in ['本地环境', "测试环境"] else '',
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
......
import datetime
import json
import uuid
from typing import Any, List, Optional, TypeVar, Generic
from typing import Any, TypeVar, Generic
from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, Field
from pydantic.generics import GenericModel
from starlette import status
from starlette.responses import JSONResponse
DataT = TypeVar('DataT')
......@@ -19,71 +17,19 @@ class DataModel(BaseModel):
class Response(GenericModel, Generic[DataT]):
data: DataT | None
message: str = 'success'
status: int
status: int = status.HTTP_200_OK
class BaseJsonResponse(BaseModel):
class BaseResponse(BaseModel):
data: Any
message: str = 'success'
status: int = 200
class BasePageJsonResponse(BaseJsonResponse):
page_size: Optional[int]
page_num: Optional[int]
total: Optional[int]
def dynamic_response(data_type):
class DyResponse(BaseJsonResponse):
data: Optional[data_type]
dy_response_model = type('Response' + data_type.__name__, (DyResponse,), {})
return dy_response_model
def list_dynamic_response(data_type):
class DyResponse(BaseJsonResponse):
data: Optional[List[data_type]]
dy_response_model = type('ListResponse' + data_type.__name__, (DyResponse,), {})
return dy_response_model
def page_dynamic_response(data_type):
class DyPageResponse(BasePageJsonResponse):
data: Optional[List[data_type]]
page_size: Optional[int]
page_num: Optional[int]
total: Optional[int]
dy_response_model = type('PageResponse' + data_type.__name__, (DyPageResponse,), {})
return dy_response_model
def success_res(data=None, message='success', status_code=status.HTTP_200_OK, total=None, page_num=None,
page_size=None):
res = {
"data": data,
"message": message,
"status": status_code
}
if total is not None:
res.update({
"total": total,
"page_num": page_num,
"page_size": page_size
})
return res
def error_response(message='failed', status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, data=None):
res = {
"data": data,
"message": message,
"status": status_code
}
return JSONResponse(content=res, status_code=status_code)
class ErrorResponse(BaseResponse):
data: Any
message: str = 'failed'
status: int = 500
class BaseCreateModel(BaseModel):
......@@ -95,11 +41,3 @@ class BaseCreateModel(BaseModel):
class Config:
orm_mode = True
class BaseUpdateModel(BaseModel):
update_time: int = Field(default_factory=lambda: int(datetime.datetime.utcnow().timestamp()),
description='更新时间')
class Config:
orm_mode = True
......@@ -3,38 +3,38 @@ from typing import List, Optional
from pydantic import BaseModel, Field
from model import dynamic_response, BaseCreateModel, BaseUpdateModel, Response
from model.node import DBNode
from model import BaseCreateModel
from model.node import BaseNode
class FundType(str, Enum):
staking = 'staking'
other = 'other'
normal = 'normal'
class ApiCreateFund(BaseModel):
class BaseFundItem(BaseModel):
name: str = Field(..., description='基金名称')
fund_type: FundType = Field(default=FundType.staking, description='基金类型')
fund_type: str = Field(default=FundType.staking.value, description='基金类型')
base_coin: str = Field(default='USD', description='基准币种')
base_nav: float = Field(default=1, description='初始净值')
settlement_time: str = Field(default='08:00', description='结算时间')
# nodes: List[DBNode] = Field(default=[], description='绑定节点')
class Config:
orm_mode = True
# 接口传入模型
class CreateFund(BaseFundItem):
nodes: List[BaseNode] = Field(default=[], description='绑定节点')
class DBCreateFund(ApiCreateFund, BaseCreateModel):
user_id: str = Field(None, description='创建人')
user_email: str = Field(None, description='创建人')
nav: float = Field(default=1, description='当前净值')
# 传入数据库类型
class NormalFund(BaseFundItem, BaseCreateModel):
pass
class DBCreateStakingFund(DBCreateFund):
nodes: List[DBNode] = Field(default=[], description='绑定节点')
class StakingFund(BaseFundItem, BaseCreateModel):
nodes: List[BaseNode]
class ApiUpdateFund(BaseModel):
class UpdateFund(BaseModel):
name: Optional[str] = Field(None, description='基金名称')
fund_type: Optional[FundType] = Field(default=None, description='基金类型')
base_coin: Optional[str] = Field(None, description='基准币种')
......@@ -43,14 +43,3 @@ class ApiUpdateFund(BaseModel):
class Config:
orm_mode = True
class DBUpdateFund(ApiUpdateFund, BaseUpdateModel):
pass
# response_fund_model = dynamic_response(DBCreateFund)
response_fund_model = Response[DBCreateFund]
response_staking_fund_model = Response[DBCreateStakingFund]
# response_staking_fund_model = dynamic_response(DBCreateStakingFund)
# response_list_branch_model = list_dynamic_response(ResBranch)
from enum import Enum
from pydantic import Field, BaseModel
from pydantic import Field
from model import BaseCreateModel
......@@ -11,13 +11,6 @@ class NodeStatus(str, Enum):
stop = 'stop'
class ApiCreateNode(BaseModel):
class BaseNode(BaseCreateModel):
pub_key: str = Field(..., description='创建人')
status: NodeStatus = Field(default=NodeStatus.pending, description='创建人')
class DBNode(BaseCreateModel, ApiCreateNode):
pass
# response_node_model = dynamic_response(DBCreateFund)
# response_list_branch_model = list_dynamic_response(ResBranch)
import json
import aiohttp as aiohttp
from aiohttp import ClientTimeout
import httpx
from loguru import logger
from exception.http import ReqException
from exception.http import RequestHttpException
async def aio_request(url, method='GET', **kwargs):
headers = {
'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/93.0.4577.63 Safari/537.36',
'Content-Type': 'application/json',
'Accept': 'application/json'
}
headers.update(kwargs.pop('headers', {}))
method = method.upper()
logger.info(f"[请求内容] url={url}")
try:
async with aiohttp.ClientSession(headers=headers, timeout=ClientTimeout(total=120)) as session:
async with session.request(url=url, method=method, ssl=False, **kwargs) as r:
logger.info(f"[返回状态] url={url}, status={r.status}")
json_body = await r.json()
return json_body
async with httpx.AsyncClient() as client:
method = method.upper()
response = await client.request(method=method, url=url, **kwargs)
content = response.content
res = json.loads(content)
logger.info(f'请求成功 [{method}] [{url}]')
return res
except Exception as e:
raise ReqException(f'请求失败 url: {url} :: {e}')
\ No newline at end of file
logger.error(f'请求失败 [{method}] [{url}] [{e}]')
raise RequestHttpException(message=str(e))
import json
from pathlib import Path
from urllib.parse import urlparse
import httpx
import jwt
from fastapi import Security
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
......@@ -12,7 +10,8 @@ from jwt.algorithms import get_default_algorithms
from loguru import logger
from configs import settings
from exception.token import RequestPubKeyError, TokenError
from exception.token import TokenError
from tools.http_helper import aio_request
security = HTTPBearer()
......@@ -23,20 +22,13 @@ async def get_identify_key():
:return:
"""
# 请求key
content = ''
try:
async with httpx.AsyncClient() as client:
response = await client.get(settings.identify_jwt)
content = response.content
res = json.loads(content)
print('公钥获取成功')
res = await aio_request(settings.identify_jwt)
key_data = res['keys'][0]
rsa = get_default_algorithms()[key_data['alg']]
public_key = rsa.from_jwk(json.dumps(key_data))
settings.public_key = public_key
settings.algorithms = key_data['alg']
except Exception as e:
raise RequestPubKeyError(f'公钥获取失败, url:{settings.identify_jwt}, res:{content} {e}')
logger.info('公钥获取成功')
def decode_token(token):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment