account_service.py 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. import time
  2. from typing import Optional
  3. from sqlalchemy.ext.asyncio import AsyncSession
  4. from sqlalchemy import select, delete
  5. from app.core.biz_exception import NotFoundError
  6. from app.models.account import Account
  7. from app.schemas.account import AccountCreate, LockRequest
  8. class AccountService:
  9. @staticmethod
  10. async def add_account(db: AsyncSession, payload: AccountCreate):
  11. """
  12. 添加新账号到数据库
  13. """
  14. # 检查是否存在
  15. stmt = select(Account).where(
  16. Account.pool_name == payload.pool_name,
  17. Account.username == payload.username,
  18. )
  19. obj = (await db.execute(stmt)).scalar_one_or_none()
  20. if obj:
  21. # 如果存在,更新密码,并重置为 active
  22. obj.password = payload.password
  23. obj.status = "active"
  24. if payload.extra_data:
  25. obj.extra_data = payload.extra_data
  26. await db.commit()
  27. await db.refresh(obj)
  28. return obj
  29. else:
  30. new_acc = Account(
  31. pool_name=payload.pool_name,
  32. username=payload.username,
  33. password=payload.password,
  34. extra_data=payload.extra_data
  35. )
  36. db.add(new_acc)
  37. await db.commit()
  38. await db.refresh(new_acc)
  39. return new_acc
  40. @staticmethod
  41. async def get_next_account(db: AsyncSession, pool_name: str, lock_duration: float) -> Account:
  42. now = time.time()
  43. stmt = select(Account).where(
  44. Account.pool_name == pool_name,
  45. Account.status == 'active',
  46. Account.lock_until < now
  47. ).order_by(
  48. Account.lock_until.asc()
  49. ).limit(1).with_for_update()
  50. obj = (await db.execute(stmt)).scalar_one_or_none()
  51. if not obj:
  52. raise NotFoundError('Account not found')
  53. new_lock_time = now + lock_duration
  54. obj.lock_until = new_lock_time
  55. await db.commit()
  56. await db.refresh(obj)
  57. return obj
  58. @staticmethod
  59. async def manual_lock(db: AsyncSession, payload: LockRequest):
  60. stmt = select(Account).where(
  61. Account.pool_name == payload.pool_name,
  62. Account.username == payload.username,
  63. )
  64. obj = (await db.execute(stmt)).scalar_one_or_none()
  65. if not obj:
  66. raise NotFoundError('Account not found')
  67. obj.lock_until = time.time() + payload.duration
  68. await db.commit()
  69. @staticmethod
  70. async def disable_account(db: AsyncSession, payload: LockRequest):
  71. stmt = select(Account).where(
  72. Account.pool_name == payload.pool_name,
  73. Account.username == payload.username,
  74. )
  75. obj = (await db.execute(stmt)).scalar_one_or_none()
  76. if not obj:
  77. raise NotFoundError('Account not found')
  78. obj.status = "disabled"
  79. await db.commit()