Commit f819bb2b authored by confusion's avatar confusion

修改查询接口

parent c903b6cd
import datetime from typing import Union, List, Any
from loguru import logger from loguru import logger
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends, Query
from motor.core import AgnosticCollection from motor.core import AgnosticCollection
from pymongo import ReturnDocument
from pymongo.operations import UpdateOne from pymongo.operations import UpdateOne
from dependencies import get_current_user, get_fund_collect, get_bill_collect from dependencies import get_current_user, get_fund_collect, get_bill_collect
from exception.db import NotFundError from exception.db import NotFundError
from model import Response, Page, PageResponse from model import Response, Page, PageResponse, SortParams, FilterTime
from model.bill import PCFBill, ExchangeBill, BillType, CreatePCFBill, CreateExchangeBill, StakingBill, CreateStaking, \ from model.bill import PCFBill, ExchangeBill, BillType, CreatePCFBill, CreateExchangeBill, StakingBill, \
AdjustBill, CreateAdjustBill, UpdatePCFBill, UpdateExchangeBill, UpdateStakingBill, UpdateAdjustBill AdjustBill, CreateAdjustBill, UpdatePCFBill, UpdateExchangeBill, UpdateStakingBill, UpdateAdjustBill
from service.bill import update_bill from service.bill import update_bill
from tools.jwt_tools import User from tools.jwt_tools import User
...@@ -17,6 +17,7 @@ router = APIRouter() ...@@ -17,6 +17,7 @@ router = APIRouter()
@router.post('/pcf/', @router.post('/pcf/',
response_model=Response[PCFBill], response_model=Response[PCFBill],
tags=['新增'],
summary='添加申购赎回账目', summary='添加申购赎回账目',
description='添加申购赎回账目') description='添加申购赎回账目')
async def create_pcf( async def create_pcf(
...@@ -46,6 +47,7 @@ async def create_pcf( ...@@ -46,6 +47,7 @@ async def create_pcf(
@router.post('/exchange/', @router.post('/exchange/',
response_model=Response[ExchangeBill], response_model=Response[ExchangeBill],
tags=['新增'],
summary='添加置换币账目', summary='添加置换币账目',
description='添加置换币账目') description='添加置换币账目')
async def create_exchange( async def create_exchange(
...@@ -83,6 +85,7 @@ async def create_exchange( ...@@ -83,6 +85,7 @@ async def create_exchange(
@router.post('/adjust/', @router.post('/adjust/',
response_model=Response[AdjustBill], response_model=Response[AdjustBill],
tags=['新增'],
summary='添加调整账目', summary='添加调整账目',
description='添加调整账目') description='添加调整账目')
async def create_adjust( async def create_adjust(
...@@ -108,83 +111,65 @@ async def create_adjust( ...@@ -108,83 +111,65 @@ async def create_adjust(
return response return response
@router.get('/exchange/{fund_id}/', # @router.get('/exchange/{fund_id}/',
response_model=PageResponse[ExchangeBill], # response_model=PageResponse[ExchangeBill],
summary='查询置换记录', # summary='查询置换记录',
description='') # description='')
async def query_exchange_bill( # async def query_exchange(
fund_id: str, # fund_id: str,
page: Page = Depends(Page), # page: Page = Depends(Page),
user: User = Depends(get_current_user), # user: User = Depends(get_current_user),
bill_collect: AgnosticCollection = Depends(get_bill_collect), # bill_collect: AgnosticCollection = Depends(get_bill_collect),
): # ):
skip = (page.page - 1) * page.page_size # skip = (page.page - 1) * page.page_size
cursor = bill_collect.find( # cursor = bill_collect.find(
{"fund_id": fund_id, "user_id": user.id, "bill_type": BillType.exchange}) # {"fund_id": fund_id, "user_id": user.id, "bill_type": BillType.exchange})
cursor = cursor.skip(skip).sort([('create_time', -1)]).limit(page.page_size) # cursor = cursor.skip(skip).sort([('create_time', -1)]).limit(page.page_size)
result = await cursor.to_list(length=None) # result = await cursor.to_list(length=None)
response = PageResponse[ExchangeBill](data=result, **page.dict(), total=len(result)) # response = PageResponse[ExchangeBill](data=result, **page.dict(), total=len(result))
return response # return response
@router.get('/exchange/{fund_id}/', # @router.get('/staking/{fund_id}/',
response_model=PageResponse[ExchangeBill], # response_model=PageResponse[StakingBill],
summary='查询置换记录', # summary='查询质押记录',
description='') # description='')
async def query_exchange( # async def query_staking(
fund_id: str, # fund_id: str,
page: Page = Depends(Page), # page: Page = Depends(Page),
user: User = Depends(get_current_user), # user: User = Depends(get_current_user),
bill_collect: AgnosticCollection = Depends(get_bill_collect), # bill_collect: AgnosticCollection = Depends(get_bill_collect),
): # ):
skip = (page.page - 1) * page.page_size # skip = (page.page - 1) * page.page_size
cursor = bill_collect.find( # cursor = bill_collect.find(
{"fund_id": fund_id, "user_id": user.id, "bill_type": BillType.exchange}) # {"fund_id": fund_id, "user_id": user.id, "bill_type": BillType.staking})
cursor = cursor.skip(skip).sort([('create_time', -1)]).limit(page.page_size) # cursor = cursor.skip(skip).sort([('create_time', -1)]).limit(page.page_size)
result = await cursor.to_list(length=None) # result = await cursor.to_list(length=None)
response = PageResponse[ExchangeBill](data=result, **page.dict(), total=len(result)) # response = PageResponse[StakingBill](data=result, **page.dict(), total=len(result))
return response # return response
@router.get('/staking/{fund_id}/',
response_model=PageResponse[StakingBill],
summary='查询质押记录',
description='')
async def query_staking(
fund_id: str,
page: Page = Depends(Page),
user: User = Depends(get_current_user),
bill_collect: AgnosticCollection = Depends(get_bill_collect),
):
skip = (page.page - 1) * page.page_size
cursor = bill_collect.find(
{"fund_id": fund_id, "user_id": user.id, "bill_type": BillType.staking})
cursor = cursor.skip(skip).sort([('create_time', -1)]).limit(page.page_size)
result = await cursor.to_list(length=None)
response = PageResponse[StakingBill](data=result, **page.dict(), total=len(result))
return response
@router.get('/adjust/{fund_id}/', # @router.get('/adjust/{fund_id}/',
response_model=PageResponse[AdjustBill], # response_model=PageResponse[AdjustBill],
summary='查询调整记录', # summary='查询调整记录',
description='') # description='')
async def query_adjust( # async def query_adjust(
fund_id: str, # fund_id: str,
page: Page = Depends(Page), # page: Page = Depends(Page),
user: User = Depends(get_current_user), # user: User = Depends(get_current_user),
bill_collect: AgnosticCollection = Depends(get_bill_collect), # bill_collect: AgnosticCollection = Depends(get_bill_collect),
): # ):
skip = (page.page - 1) * page.page_size # skip = (page.page - 1) * page.page_size
cursor = bill_collect.find( # cursor = bill_collect.find(
{"fund_id": fund_id, "user_id": user.id, "bill_type": BillType.adjust}) # {"fund_id": fund_id, "user_id": user.id, "bill_type": BillType.adjust})
cursor = cursor.skip(skip).sort([('create_time', -1)]).limit(page.page_size) # cursor = cursor.skip(skip).sort([('create_time', -1)]).limit(page.page_size)
result = await cursor.to_list(length=None) # result = await cursor.to_list(length=None)
response = PageResponse[AdjustBill](data=result, **page.dict(), total=len(result)) # response = PageResponse[AdjustBill](data=result, **page.dict(), total=len(result))
return response # return response
@router.put('/{fund_id}/pcf/', @router.put('/pcf/{fund_id}/',
tags=['更新'],
response_model=Response[PCFBill], response_model=Response[PCFBill],
summary='更新申购赎回记录', summary='更新申购赎回记录',
description='') description='')
...@@ -208,7 +193,8 @@ async def update_pcf_bill( ...@@ -208,7 +193,8 @@ async def update_pcf_bill(
return response return response
@router.put('/{fund_id}/exchange/', @router.put('/exchange/{fund_id}/',
tags=['更新'],
response_model=Response[ExchangeBill], response_model=Response[ExchangeBill],
summary='更新置换记录', summary='更新置换记录',
description='') description='')
...@@ -232,7 +218,8 @@ async def update_exchange_bill( ...@@ -232,7 +218,8 @@ async def update_exchange_bill(
return response return response
@router.put('/{fund_id}/staking/', @router.put('/staking/{fund_id}/',
tags=['更新'],
response_model=Response[StakingBill], response_model=Response[StakingBill],
summary='更新申购赎回调整记录', summary='更新申购赎回调整记录',
description='') description='')
...@@ -256,7 +243,8 @@ async def update_staking_bill( ...@@ -256,7 +243,8 @@ async def update_staking_bill(
return response return response
@router.put('/{fund_id}/adjust/', @router.put('/adjust/{fund_id}/',
tags=['更新'],
response_model=Response[AdjustBill], response_model=Response[AdjustBill],
summary='更新调整记录', summary='更新调整记录',
description='') description='')
...@@ -278,3 +266,29 @@ async def update_adjust_bill( ...@@ -278,3 +266,29 @@ async def update_adjust_bill(
res_model=AdjustBill res_model=AdjustBill
) )
return response return response
@router.get('/{fund_id}/',
tags=["查询"],
response_model=PageResponse[Union[PCFBill, ExchangeBill, StakingBill, AdjustBill]],
summary='查询账单记录',
description='查询账单记录')
async def query_bill(
fund_id: str,
sort_by: SortParams = Depends(SortParams),
filter_time: FilterTime = Depends(FilterTime),
query: List[BillType] = Query(default=BillType.all(), description='账单类型'),
page: Page = Depends(Page),
user: User = Depends(get_current_user),
bill_collect: AgnosticCollection = Depends(get_bill_collect),
):
query = {"fund_id": fund_id, "user_id": user.id, "bill_type": {'$in': query}}
if filter_time.start_time and filter_time.end_time:
query.update({'create_time': filter_time.to_mongodb_query()})
skip = (page.page - 1) * page.page_size
cursor = bill_collect.find(query)
cursor = cursor.skip(skip).sort([(sort_by.sort_field, sort_by.sort_direction)]).limit(page.page_size)
result = await cursor.to_list(length=None)
response = PageResponse[Any](data=result, **page.dict(), total=len(result))
return response
...@@ -33,12 +33,12 @@ async def create( ...@@ -33,12 +33,12 @@ async def create(
create_model = fund_type_map[create_fund.fund_type](**create_fund.dict(), nodes=[], **user.db_save()) create_model = fund_type_map[create_fund.fund_type](**create_fund.dict(), nodes=[], **user.db_save())
create_model.nav = create_model.base_nav create_model.nav = create_model.base_nav
response = Response[fund_type_map[create_fund.fund_type]](data=create_model.dict()) data = create_model.dict()
insert_data = create_model.dict() response_model = fund_type_map[data['fund_type']]
await fund_collect.insert_one(insert_data) await fund_collect.insert_one(data)
await calculate_nav_task(create_model.id, scheduler, fund_collect, user.id) await calculate_nav_task(data['id'], scheduler, fund_collect, user.id)
return response return Response[response_model](data=response_model(**data))
@router.put('/{fund_id}/', response_model=Union[Response[StakingFund], Response[NormalFund]], summary='更新基金', @router.put('/{fund_id}/', response_model=Union[Response[StakingFund], Response[NormalFund]], summary='更新基金',
...@@ -56,10 +56,9 @@ async def update( ...@@ -56,10 +56,9 @@ async def update(
data = await fund_collect.find_one_and_update({'id': fund_id, 'user_id': user.id}, {'$set': db_update_data}, data = await fund_collect.find_one_and_update({'id': fund_id, 'user_id': user.id}, {'$set': db_update_data},
return_document=ReturnDocument.AFTER) return_document=ReturnDocument.AFTER)
assert data, NotFundError() assert data, NotFundError()
if data['fund_type'] == FundType.staking: # return fund_type_map[data['fund_type']](data=fund_type_map[data['fund_type']](**data))
return Response[StakingFund](data=StakingFund(**data)) response_model = fund_type_map[data['fund_type']]
else: return Response[response_model](data=response_model(**data))
return Response[NormalFund](data=NormalFund(**data))
@router.get('/{fund_id}/', response_model=Union[Response[StakingFund], Response[NormalFund]], summary='查询基金', @router.get('/{fund_id}/', response_model=Union[Response[StakingFund], Response[NormalFund]], summary='查询基金',
...@@ -71,11 +70,9 @@ async def get( ...@@ -71,11 +70,9 @@ async def get(
): ):
data = await fund_collect.find_one({'id': fund_id, 'user_id': user.id}) data = await fund_collect.find_one({'id': fund_id, 'user_id': user.id})
assert data, NotFundError() assert data, NotFundError()
if data['fund_type'] == FundType.staking:
response = Response[StakingFund](data=StakingFund(**data).dict()) response_model = fund_type_map[data['fund_type']]
else: return Response[response_model](data=response_model(**data))
response = Response[NormalFund](data=NormalFund(**data).dict())
return response
@router.get('/', @router.get('/',
......
import traceback import traceback
import pytz import pytz
import uvicorn as uvicorn import uvicorn
from apscheduler.triggers import interval from apscheduler.triggers import interval
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
...@@ -83,6 +83,8 @@ async def startup(): ...@@ -83,6 +83,8 @@ async def startup():
misfire_grace_time=600 * 3 misfire_grace_time=600 * 3
) )
if settings.env == 'LOCAL':
return
app.state.scheduler.start() app.state.scheduler.start()
app.state.scheduler.print_jobs() app.state.scheduler.print_jobs()
......
import datetime import datetime
import uuid import uuid
from enum import IntEnum
from typing import Any, TypeVar, Generic, Optional, List from typing import Any, TypeVar, Generic, Optional, List
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
...@@ -14,6 +15,24 @@ class Page(BaseModel): ...@@ -14,6 +15,24 @@ class Page(BaseModel):
page_size: int = Field(default=10, description="每页数据条数") page_size: int = Field(default=10, description="每页数据条数")
class SortDirection(IntEnum):
desc = -1
asc = 1
class SortParams(BaseModel):
sort_field: str = Field(default='create_time', description='排序字段')
sort_direction: SortDirection = Field(default=SortDirection.desc, description='排序方向,-1为倒序,1为正序')
class FilterTime(BaseModel):
start_time: Optional[int] = Field(None, description='查询开始时间')
end_time: Optional[int] = Field(None, description="查询结束时间")
def to_mongodb_query(self):
return {"$gte": self.start_time, "$lte": self.end_time}
class Response(GenericModel, Generic[DataT]): class Response(GenericModel, Generic[DataT]):
data: DataT | None data: DataT | None
message: str = 'success' message: str = 'success'
......
import datetime import datetime
from enum import Enum from enum import Enum
from typing import List, Optional from typing import Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from model import BaseCreateModel from model import BaseCreateModel
...@@ -18,6 +18,10 @@ class BillType(str, Enum): ...@@ -18,6 +18,10 @@ class BillType(str, Enum):
# 调整 # 调整
adjust = "adjust" adjust = "adjust"
@staticmethod
def all():
return list(map(lambda c: c.value, BillType))
# 接口传入模型 # 接口传入模型
......
...@@ -5,3 +5,7 @@ import pytz ...@@ -5,3 +5,7 @@ import pytz
def utc_now(): def utc_now():
return datetime.datetime.utcnow().replace(tzinfo=pytz.UTC) return datetime.datetime.utcnow().replace(tzinfo=pytz.UTC)
def timestamp_to_datetime(utc_timestamp):
return datetime.datetime.utcfromtimestamp(utc_timestamp).replace(tzinfo=pytz.UTC)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment