Commit f819bb2b authored by confusion's avatar confusion

修改查询接口

parent c903b6cd
import datetime
from typing import Union, List, Any
from loguru import logger
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, Query
from motor.core import AgnosticCollection
from pymongo import ReturnDocument
from pymongo.operations import UpdateOne
from dependencies import get_current_user, get_fund_collect, get_bill_collect
from exception.db import NotFundError
from model import Response, Page, PageResponse
from model.bill import PCFBill, ExchangeBill, BillType, CreatePCFBill, CreateExchangeBill, StakingBill, CreateStaking, \
from model import Response, Page, PageResponse, SortParams, FilterTime
from model.bill import PCFBill, ExchangeBill, BillType, CreatePCFBill, CreateExchangeBill, StakingBill, \
AdjustBill, CreateAdjustBill, UpdatePCFBill, UpdateExchangeBill, UpdateStakingBill, UpdateAdjustBill
from service.bill import update_bill
from tools.jwt_tools import User
......@@ -17,6 +17,7 @@ router = APIRouter()
@router.post('/pcf/',
response_model=Response[PCFBill],
tags=['新增'],
summary='添加申购赎回账目',
description='添加申购赎回账目')
async def create_pcf(
......@@ -46,6 +47,7 @@ async def create_pcf(
@router.post('/exchange/',
response_model=Response[ExchangeBill],
tags=['新增'],
summary='添加置换币账目',
description='添加置换币账目')
async def create_exchange(
......@@ -83,6 +85,7 @@ async def create_exchange(
@router.post('/adjust/',
response_model=Response[AdjustBill],
tags=['新增'],
summary='添加调整账目',
description='添加调整账目')
async def create_adjust(
......@@ -108,83 +111,65 @@ async def create_adjust(
return response
@router.get('/exchange/{fund_id}/',
response_model=PageResponse[ExchangeBill],
summary='查询置换记录',
description='')
async def query_exchange_bill(
fund_id: str,
page: Page = Depends(Page),
user: User = Depends(get_current_user),
bill_collect: AgnosticCollection = Depends(get_bill_collect),
):
skip = (page.page - 1) * page.page_size
cursor = bill_collect.find(
{"fund_id": fund_id, "user_id": user.id, "bill_type": BillType.exchange})
cursor = cursor.skip(skip).sort([('create_time', -1)]).limit(page.page_size)
result = await cursor.to_list(length=None)
response = PageResponse[ExchangeBill](data=result, **page.dict(), total=len(result))
return response
# @router.get('/exchange/{fund_id}/',
# response_model=PageResponse[ExchangeBill],
# summary='查询置换记录',
# description='')
# async def query_exchange(
# fund_id: str,
# page: Page = Depends(Page),
# user: User = Depends(get_current_user),
# bill_collect: AgnosticCollection = Depends(get_bill_collect),
# ):
# skip = (page.page - 1) * page.page_size
# cursor = bill_collect.find(
# {"fund_id": fund_id, "user_id": user.id, "bill_type": BillType.exchange})
# cursor = cursor.skip(skip).sort([('create_time', -1)]).limit(page.page_size)
# result = await cursor.to_list(length=None)
# response = PageResponse[ExchangeBill](data=result, **page.dict(), total=len(result))
# return response
@router.get('/exchange/{fund_id}/',
response_model=PageResponse[ExchangeBill],
summary='查询置换记录',
description='')
async def query_exchange(
fund_id: str,
page: Page = Depends(Page),
user: User = Depends(get_current_user),
bill_collect: AgnosticCollection = Depends(get_bill_collect),
):
skip = (page.page - 1) * page.page_size
cursor = bill_collect.find(
{"fund_id": fund_id, "user_id": user.id, "bill_type": BillType.exchange})
cursor = cursor.skip(skip).sort([('create_time', -1)]).limit(page.page_size)
result = await cursor.to_list(length=None)
response = PageResponse[ExchangeBill](data=result, **page.dict(), total=len(result))
return response
@router.get('/staking/{fund_id}/',
response_model=PageResponse[StakingBill],
summary='查询质押记录',
description='')
async def query_staking(
fund_id: str,
page: Page = Depends(Page),
user: User = Depends(get_current_user),
bill_collect: AgnosticCollection = Depends(get_bill_collect),
):
skip = (page.page - 1) * page.page_size
cursor = bill_collect.find(
{"fund_id": fund_id, "user_id": user.id, "bill_type": BillType.staking})
cursor = cursor.skip(skip).sort([('create_time', -1)]).limit(page.page_size)
result = await cursor.to_list(length=None)
response = PageResponse[StakingBill](data=result, **page.dict(), total=len(result))
return response
# @router.get('/staking/{fund_id}/',
# response_model=PageResponse[StakingBill],
# summary='查询质押记录',
# description='')
# async def query_staking(
# fund_id: str,
# page: Page = Depends(Page),
# user: User = Depends(get_current_user),
# bill_collect: AgnosticCollection = Depends(get_bill_collect),
# ):
# skip = (page.page - 1) * page.page_size
# cursor = bill_collect.find(
# {"fund_id": fund_id, "user_id": user.id, "bill_type": BillType.staking})
# cursor = cursor.skip(skip).sort([('create_time', -1)]).limit(page.page_size)
# result = await cursor.to_list(length=None)
# response = PageResponse[StakingBill](data=result, **page.dict(), total=len(result))
# return response
@router.get('/adjust/{fund_id}/',
response_model=PageResponse[AdjustBill],
summary='查询调整记录',
description='')
async def query_adjust(
fund_id: str,
page: Page = Depends(Page),
user: User = Depends(get_current_user),
bill_collect: AgnosticCollection = Depends(get_bill_collect),
):
skip = (page.page - 1) * page.page_size
cursor = bill_collect.find(
{"fund_id": fund_id, "user_id": user.id, "bill_type": BillType.adjust})
cursor = cursor.skip(skip).sort([('create_time', -1)]).limit(page.page_size)
result = await cursor.to_list(length=None)
response = PageResponse[AdjustBill](data=result, **page.dict(), total=len(result))
return response
# @router.get('/adjust/{fund_id}/',
# response_model=PageResponse[AdjustBill],
# summary='查询调整记录',
# description='')
# async def query_adjust(
# fund_id: str,
# page: Page = Depends(Page),
# user: User = Depends(get_current_user),
# bill_collect: AgnosticCollection = Depends(get_bill_collect),
# ):
# skip = (page.page - 1) * page.page_size
# cursor = bill_collect.find(
# {"fund_id": fund_id, "user_id": user.id, "bill_type": BillType.adjust})
# cursor = cursor.skip(skip).sort([('create_time', -1)]).limit(page.page_size)
# result = await cursor.to_list(length=None)
# response = PageResponse[AdjustBill](data=result, **page.dict(), total=len(result))
# return response
@router.put('/{fund_id}/pcf/',
@router.put('/pcf/{fund_id}/',
tags=['更新'],
response_model=Response[PCFBill],
summary='更新申购赎回记录',
description='')
......@@ -208,7 +193,8 @@ async def update_pcf_bill(
return response
@router.put('/{fund_id}/exchange/',
@router.put('/exchange/{fund_id}/',
tags=['更新'],
response_model=Response[ExchangeBill],
summary='更新置换记录',
description='')
......@@ -232,7 +218,8 @@ async def update_exchange_bill(
return response
@router.put('/{fund_id}/staking/',
@router.put('/staking/{fund_id}/',
tags=['更新'],
response_model=Response[StakingBill],
summary='更新申购赎回调整记录',
description='')
......@@ -256,7 +243,8 @@ async def update_staking_bill(
return response
@router.put('/{fund_id}/adjust/',
@router.put('/adjust/{fund_id}/',
tags=['更新'],
response_model=Response[AdjustBill],
summary='更新调整记录',
description='')
......@@ -278,3 +266,29 @@ async def update_adjust_bill(
res_model=AdjustBill
)
return response
@router.get('/{fund_id}/',
tags=["查询"],
response_model=PageResponse[Union[PCFBill, ExchangeBill, StakingBill, AdjustBill]],
summary='查询账单记录',
description='查询账单记录')
async def query_bill(
fund_id: str,
sort_by: SortParams = Depends(SortParams),
filter_time: FilterTime = Depends(FilterTime),
query: List[BillType] = Query(default=BillType.all(), description='账单类型'),
page: Page = Depends(Page),
user: User = Depends(get_current_user),
bill_collect: AgnosticCollection = Depends(get_bill_collect),
):
query = {"fund_id": fund_id, "user_id": user.id, "bill_type": {'$in': query}}
if filter_time.start_time and filter_time.end_time:
query.update({'create_time': filter_time.to_mongodb_query()})
skip = (page.page - 1) * page.page_size
cursor = bill_collect.find(query)
cursor = cursor.skip(skip).sort([(sort_by.sort_field, sort_by.sort_direction)]).limit(page.page_size)
result = await cursor.to_list(length=None)
response = PageResponse[Any](data=result, **page.dict(), total=len(result))
return response
......@@ -33,12 +33,12 @@ async def create(
create_model = fund_type_map[create_fund.fund_type](**create_fund.dict(), nodes=[], **user.db_save())
create_model.nav = create_model.base_nav
response = Response[fund_type_map[create_fund.fund_type]](data=create_model.dict())
insert_data = create_model.dict()
data = create_model.dict()
response_model = fund_type_map[data['fund_type']]
await fund_collect.insert_one(insert_data)
await calculate_nav_task(create_model.id, scheduler, fund_collect, user.id)
return response
await fund_collect.insert_one(data)
await calculate_nav_task(data['id'], scheduler, fund_collect, user.id)
return Response[response_model](data=response_model(**data))
@router.put('/{fund_id}/', response_model=Union[Response[StakingFund], Response[NormalFund]], summary='更新基金',
......@@ -56,10 +56,9 @@ async def update(
data = await fund_collect.find_one_and_update({'id': fund_id, 'user_id': user.id}, {'$set': db_update_data},
return_document=ReturnDocument.AFTER)
assert data, NotFundError()
if data['fund_type'] == FundType.staking:
return Response[StakingFund](data=StakingFund(**data))
else:
return Response[NormalFund](data=NormalFund(**data))
# return fund_type_map[data['fund_type']](data=fund_type_map[data['fund_type']](**data))
response_model = fund_type_map[data['fund_type']]
return Response[response_model](data=response_model(**data))
@router.get('/{fund_id}/', response_model=Union[Response[StakingFund], Response[NormalFund]], summary='查询基金',
......@@ -71,11 +70,9 @@ async def get(
):
data = await fund_collect.find_one({'id': fund_id, 'user_id': user.id})
assert data, NotFundError()
if data['fund_type'] == FundType.staking:
response = Response[StakingFund](data=StakingFund(**data).dict())
else:
response = Response[NormalFund](data=NormalFund(**data).dict())
return response
response_model = fund_type_map[data['fund_type']]
return Response[response_model](data=response_model(**data))
@router.get('/',
......
import traceback
import pytz
import uvicorn as uvicorn
import uvicorn
from apscheduler.triggers import interval
from fastapi import FastAPI
from fastapi.exceptions import RequestValidationError
......@@ -83,6 +83,8 @@ async def startup():
misfire_grace_time=600 * 3
)
if settings.env == 'LOCAL':
return
app.state.scheduler.start()
app.state.scheduler.print_jobs()
......
import datetime
import uuid
from enum import IntEnum
from typing import Any, TypeVar, Generic, Optional, List
from pydantic import BaseModel, Field
......@@ -14,6 +15,24 @@ class Page(BaseModel):
page_size: int = Field(default=10, description="每页数据条数")
class SortDirection(IntEnum):
desc = -1
asc = 1
class SortParams(BaseModel):
sort_field: str = Field(default='create_time', description='排序字段')
sort_direction: SortDirection = Field(default=SortDirection.desc, description='排序方向,-1为倒序,1为正序')
class FilterTime(BaseModel):
start_time: Optional[int] = Field(None, description='查询开始时间')
end_time: Optional[int] = Field(None, description="查询结束时间")
def to_mongodb_query(self):
return {"$gte": self.start_time, "$lte": self.end_time}
class Response(GenericModel, Generic[DataT]):
data: DataT | None
message: str = 'success'
......
import datetime
from enum import Enum
from typing import List, Optional
from typing import Optional
from pydantic import BaseModel, Field
from model import BaseCreateModel
......@@ -18,6 +18,10 @@ class BillType(str, Enum):
# 调整
adjust = "adjust"
@staticmethod
def all():
return list(map(lambda c: c.value, BillType))
# 接口传入模型
......
......@@ -5,3 +5,7 @@ import pytz
def utc_now():
return datetime.datetime.utcnow().replace(tzinfo=pytz.UTC)
def timestamp_to_datetime(utc_timestamp):
return datetime.datetime.utcfromtimestamp(utc_timestamp).replace(tzinfo=pytz.UTC)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment