account_service.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. import time
  2. from typing import Optional
  3. from sqlalchemy.ext.asyncio import AsyncSession
  4. from sqlalchemy import select, delete
  5. from app.utils.search import apply_keyword_search_stmt
  6. from app.utils.pagination import paginate
  7. from app.core.biz_exception import NotFoundError
  8. from app.models.account import Account
  9. from app.schemas.account import AccountCreate, LockRequest
  class AccountService:
      """Database operations for pooled ``Account`` records.

      Every mutating method commits the session it is given.
      """

      @staticmethod
      async def _get_by_pool_and_username(
          db: AsyncSession, pool_name: str, username: str
      ) -> Optional[Account]:
          """Return the account identified by (pool_name, username), or None.

          ``scalar_one_or_none`` raises if more than one row matches.
          """
          stmt = select(Account).where(
              Account.pool_name == pool_name,
              Account.username == username,
          )
          return (await db.execute(stmt)).scalar_one_or_none()

      @staticmethod
      async def _require_by_pool_and_username(
          db: AsyncSession, pool_name: str, username: str
      ) -> Account:
          """Like ``_get_by_pool_and_username`` but raise NotFoundError when missing."""
          obj = await AccountService._get_by_pool_and_username(db, pool_name, username)
          if obj is None:
              raise NotFoundError('Account not found')
          return obj

      @staticmethod
      async def list_all(
          db: AsyncSession,
          page: int = 0,
          size: int = 10,
          keyword: Optional[str] = None
      ):
          """Return a paginated list of accounts, highest ``id`` first.

          Args:
              db: Async database session.
              page: Page index forwarded to ``paginate``.
              size: Page size forwarded to ``paginate``.
              keyword: Optional search term applied by ``apply_keyword_search_stmt``
                  across the listed fields.

          Returns:
              Whatever ``paginate`` returns for the built statement.
          """
          stmt = select(Account)
          # NOTE(review): "password" is a searchable field, so a keyword query can
          # match on password contents — confirm this is intended.
          stmt = apply_keyword_search_stmt(
              stmt=stmt,
              model=Account,
              keyword=keyword,
              fields=["id", "pool_name", "username", "password", "extra_data", "status"],
          ).order_by(Account.id.desc())
          return await paginate(db, stmt, page, size)

      @staticmethod
      async def add_account(db: AsyncSession, payload: AccountCreate):
          """Add a new account to the database, or refresh an existing one.

          If an account with the same ``pool_name`` and ``username`` already
          exists, its password is overwritten, its status is reset to
          ``"active"`` and, when ``payload.extra_data`` is truthy, its
          ``extra_data`` is replaced. Otherwise a new row is inserted.

          Returns:
              The committed and refreshed ``Account`` instance.
          """
          obj = await AccountService._get_by_pool_and_username(
              db, payload.pool_name, payload.username
          )
          if obj is not None:
              # Already exists: update the password and reset the status to active.
              obj.password = payload.password
              obj.status = "active"
              # Truthiness check: empty/None extra_data keeps the stored value.
              if payload.extra_data:
                  obj.extra_data = payload.extra_data
          else:
              obj = Account(
                  pool_name=payload.pool_name,
                  username=payload.username,
                  password=payload.password,
                  extra_data=payload.extra_data,
              )
              db.add(obj)
          await db.commit()
          await db.refresh(obj)
          return obj

      @staticmethod
      async def get_next_account(db: AsyncSession, pool_name: str, lock_duration: float) -> Account:
          """Lease one available active account from ``pool_name``.

          Picks the ``"active"`` account whose ``lock_until`` is earliest among
          those already in the past, locks the row with ``SELECT ... FOR UPDATE``,
          sets ``lock_until`` to ``time.time() + lock_duration`` and commits.

          Args:
              db: Async database session.
              pool_name: Pool to lease from.
              lock_duration: Seconds to add to the current epoch time.

          Returns:
              The leased, refreshed ``Account``.

          Raises:
              NotFoundError: If no account in the pool is currently available.
          """
          now = time.time()
          # NOTE(review): a NULL lock_until never satisfies "< now"; confirm the
          # column is non-nullable with a past default, or new accounts are never leased.
          # NOTE(review): with a plain FOR UPDATE, concurrent callers wait on the same
          # head row; with_for_update(skip_locked=True) may be preferable if the
          # database supports SKIP LOCKED.
          stmt = select(Account).where(
              Account.pool_name == pool_name,
              Account.status == 'active',
              Account.lock_until < now
          ).order_by(
              Account.lock_until.asc()
          ).limit(1).with_for_update()
          obj = (await db.execute(stmt)).scalar_one_or_none()
          if obj is None:
              raise NotFoundError('Account not found')
          obj.lock_until = now + lock_duration
          await db.commit()
          await db.refresh(obj)
          return obj

      @staticmethod
      async def manual_lock(db: AsyncSession, payload: LockRequest):
          """Lock an account until ``time.time() + payload.duration`` and commit.

          Raises:
              NotFoundError: If no account matches ``payload.pool_name`` / ``payload.username``.
          """
          obj = await AccountService._require_by_pool_and_username(
              db, payload.pool_name, payload.username
          )
          obj.lock_until = time.time() + payload.duration
          await db.commit()

      @staticmethod
      async def disable_account(db: AsyncSession, payload: LockRequest):
          """Set an account's status to ``"disabled"`` and commit.

          Only ``pool_name`` and ``username`` from ``payload`` are used.

          Raises:
              NotFoundError: If no account matches ``payload.pool_name`` / ``payload.username``.
          """
          obj = await AccountService._require_by_pool_and_username(
              db, payload.pool_name, payload.username
          )
          obj.status = "disabled"
          await db.commit()