account_service.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. import time
  2. from typing import Optional
  3. from sqlalchemy.ext.asyncio import AsyncSession
  4. from sqlalchemy import select, func, text, delete
  5. from datetime import datetime, timedelta
  6. from app.utils.search import apply_keyword_search_stmt
  7. from app.utils.pagination import paginate
  8. from app.core.biz_exception import NotFoundError
  9. from app.models.account import Account
  10. from app.schemas.account import AccountCreate, AccountUpdate
  11. class AccountService:
  12. @staticmethod
  13. async def add_account(db: AsyncSession, payload: AccountCreate):
  14. """
  15. 添加新账号到数据库
  16. """
  17. rec = Account(**payload.dict())
  18. db.add(rec)
  19. await db.commit()
  20. await db.refresh(rec)
  21. return rec
  22. @staticmethod
  23. async def update_account(db: AsyncSession, account_id: int, payload: AccountUpdate):
  24. stmt = select(Account).where(Account.id == account_id)
  25. rec = (await db.execute(stmt)).scalar_one_or_none()
  26. if not rec:
  27. raise NotFoundError("Account not exist")
  28. for k, v in payload.dict(exclude_unset=True).items():
  29. setattr(rec, k, v)
  30. await db.commit()
  31. await db.refresh(rec)
  32. return rec
  33. @staticmethod
  34. async def remove_account(db: AsyncSession, account_id: int):
  35. stmt = select(Account).where(Account.id == account_id)
  36. db_obj = (await db.execute(stmt)).scalar_one_or_none()
  37. if not db_obj:
  38. raise NotFoundError(f"Account not exist")
  39. await db.delete(db_obj)
  40. await db.commit()
  41. return True
  42. @staticmethod
  43. async def list_all(
  44. db: AsyncSession,
  45. page: int = 0,
  46. size: int = 10,
  47. keyword: Optional[str] = None
  48. ):
  49. stmt = select(Account)
  50. stmt = apply_keyword_search_stmt(
  51. stmt=stmt,
  52. model=Account,
  53. keyword=keyword,
  54. fields=["id", "pool_name", "username", "password", "extra_data", "status"],
  55. ).order_by(Account.id.desc())
  56. return await paginate(db, stmt, page, size)
  57. @staticmethod
  58. async def get_next_account(
  59. db: AsyncSession,
  60. pool_name: str,
  61. account_cd: int
  62. ) -> Account:
  63. stmt = (
  64. select(Account)
  65. .where(
  66. Account.pool_name == pool_name,
  67. Account.status == 'active',
  68. Account.next_use_time <= func.utc_timestamp()
  69. )
  70. .order_by(Account.next_use_time.asc())
  71. .limit(1)
  72. .with_for_update(skip_locked=True)
  73. )
  74. result = await db.execute(stmt)
  75. obj = result.scalar_one_or_none()
  76. if not obj:
  77. raise NotFoundError('Account not found')
  78. obj.next_use_time = func.utc_timestamp() + text(f"INTERVAL {account_cd} SECOND")
  79. await db.commit()
  80. await db.refresh(obj)
  81. return obj