auth_service.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. # app/services/auth_service.py
import asyncio
import random
import secrets
import string
import uuid
from datetime import datetime, timedelta
from typing import Dict, Optional

import bcrypt
from redis.asyncio import Redis
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.biz_exception import (
    NotFoundError,
    PermissionDeniedError,
    BizLogicError,
)
from app.models.user import VasUser
from app.models.session import VasSession
from app.models.verification_token import VasVerificationToken
from app.schemas.auth import (
    AutoRegisterRequest,
    SendBindCodeRequest,
    SendResetCodeRequest,
    BindEmailRequest,
    ResetPasswordRequest,
    LoginRequest,
)
from app.services.notification_service import NotificationService
  28. def _random_password(length: int = 16) -> str:
  29. return "".join(
  30. random.choices(
  31. string.ascii_letters + string.digits + "!@#$%",
  32. k=length,
  33. )
  34. )
  35. class AuthService:
  36. # =========================
  37. # 自动注册(游客)
  38. # =========================
  39. @staticmethod
  40. async def auto_register(
  41. db: AsyncSession,
  42. req: AutoRegisterRequest,
  43. ip: str = None,
  44. user_agent = None
  45. ) -> Dict:
  46. uid = f"usr-{uuid.uuid4().hex[:8]}"
  47. user = VasUser(
  48. id=uid,
  49. role="user",
  50. nickname="anonymous visitor",
  51. preferred_language="en",
  52. timezone="Asia/Shanghai",
  53. register_ip=ip or '',
  54. )
  55. db.add(user)
  56. token = "tok_" + uuid.uuid4().hex
  57. session = VasSession(
  58. id=token,
  59. user_id=uid,
  60. user_agent=user_agent or '',
  61. ip=ip or '',
  62. expire_at=datetime.utcnow() + timedelta(days=7),
  63. )
  64. db.add(session)
  65. await db.commit()
  66. await db.refresh(user)
  67. return {"user": user, "token": token}
  68. # =========================
  69. # 发送绑定邮箱验证码
  70. # =========================
  71. @staticmethod
  72. async def send_bind_code(
  73. db: AsyncSession,
  74. payload: SendBindCodeRequest,
  75. auth_user: VasUser,
  76. redis_client: Redis,
  77. ):
  78. token = uuid.uuid4().hex[:6]
  79. record = VasVerificationToken(
  80. token=token,
  81. expire_at=datetime.utcnow() + timedelta(minutes=30),
  82. )
  83. db.add(record)
  84. await db.commit()
  85. await NotificationService.create(
  86. redis_client=redis_client,
  87. ntype="email_verification",
  88. user_id=auth_user.id,
  89. channels=["email"],
  90. template_id="email_verification_for_bind",
  91. payload={"token": token},
  92. )
  93. # =========================
  94. # 发送重置密码验证码
  95. # =========================
  96. @staticmethod
  97. async def send_reset_code(
  98. db: AsyncSession,
  99. payload: SendResetCodeRequest,
  100. redis_client: Redis,
  101. ):
  102. stmt = select(VasUser).where(
  103. VasUser.email == payload.email,
  104. VasUser.email_verified == 1,
  105. )
  106. user = (await db.execute(stmt)).scalar_one_or_none()
  107. if not user:
  108. raise BizLogicError("User not exist")
  109. token = uuid.uuid4().hex[:6]
  110. record = VasVerificationToken(
  111. token=token,
  112. expire_at=datetime.utcnow() + timedelta(minutes=30),
  113. )
  114. db.add(record)
  115. await db.commit()
  116. await NotificationService.create(
  117. redis_client=redis_client,
  118. ntype="email_verification",
  119. user_id=user.id,
  120. channels=["email"],
  121. template_id="email_verification_for_reset",
  122. payload={"token": token},
  123. )
  124. # =========================
  125. # 绑定邮箱
  126. # =========================
  127. @staticmethod
  128. async def bind_email(
  129. db: AsyncSession,
  130. payload: BindEmailRequest,
  131. auth_user: VasUser,
  132. redis_client: Redis,
  133. ip: str = None,
  134. user_agent = None
  135. ) -> Dict:
  136. # 邮箱是否已被绑定
  137. stmt = select(VasUser).where(
  138. VasUser.email == payload.email,
  139. VasUser.email_verified == 1,
  140. )
  141. if (await db.execute(stmt)).scalar_one_or_none():
  142. raise BizLogicError("Email already bound")
  143. # 校验验证码
  144. stmt = select(VasVerificationToken).where(
  145. VasVerificationToken.token == payload.code,
  146. VasVerificationToken.used == 0,
  147. )
  148. record = (await db.execute(stmt)).scalar_one_or_none()
  149. if not record:
  150. raise BizLogicError("Token invalid")
  151. if record.expire_at < datetime.utcnow():
  152. raise BizLogicError("Token expired")
  153. user = await db.get(VasUser, auth_user.id)
  154. plain_pwd = _random_password()
  155. hashed_pwd = bcrypt.hashpw(
  156. plain_pwd.encode(),
  157. bcrypt.gensalt(),
  158. ).decode()
  159. user.email = payload.email
  160. user.password_hash = hashed_pwd
  161. user.email_verified = 1
  162. record.used = 1
  163. token = "tok_" + uuid.uuid4().hex
  164. session = VasSession(
  165. id=token,
  166. user_id=user.id,
  167. ip=ip or '',
  168. user_agent=user_agent or '',
  169. expire_at=datetime.utcnow() + timedelta(days=30),
  170. )
  171. db.add(session)
  172. await db.commit()
  173. await db.refresh(user)
  174. await NotificationService.create(
  175. redis_client=redis_client,
  176. ntype="login_credentials",
  177. user_id=user.id,
  178. channels=["email"],
  179. template_id="login_credentials",
  180. payload={
  181. "username": payload.email,
  182. "password": plain_pwd,
  183. },
  184. )
  185. return {"user": user, "token": token}
  186. # =========================
  187. # 重置密码
  188. # =========================
  189. @staticmethod
  190. async def reset_password(
  191. db: AsyncSession,
  192. payload: ResetPasswordRequest,
  193. ) -> bool:
  194. stmt = select(VasUser).where(
  195. VasUser.email == payload.email,
  196. VasUser.email_verified == 1,
  197. )
  198. user = (await db.execute(stmt)).scalar_one_or_none()
  199. if not user:
  200. raise BizLogicError("User not exist")
  201. stmt = select(VasVerificationToken).where(
  202. VasVerificationToken.token == payload.code,
  203. VasVerificationToken.used == 0,
  204. )
  205. record = (await db.execute(stmt)).scalar_one_or_none()
  206. if not record:
  207. raise BizLogicError("Token invalid")
  208. if record.expire_at < datetime.utcnow():
  209. raise BizLogicError("Token expired")
  210. user.password_hash = bcrypt.hashpw(
  211. payload.new_password.encode(),
  212. bcrypt.gensalt(),
  213. ).decode()
  214. record.used = 1
  215. await db.commit()
  216. return True
  217. # =========================
  218. # 登录
  219. # =========================
  220. @staticmethod
  221. async def login(
  222. db: AsyncSession,
  223. req: LoginRequest,
  224. ip: str = None,
  225. user_agent: str = None
  226. ) -> Dict:
  227. stmt = select(VasUser).where(VasUser.email == req.email)
  228. user = (await db.execute(stmt)).scalar_one_or_none()
  229. if not user:
  230. raise NotFoundError("User not found")
  231. if not bcrypt.checkpw(
  232. req.password.encode(),
  233. user.password_hash.encode(),
  234. ):
  235. raise PermissionDeniedError("Password incorrect")
  236. token = "tok_" + uuid.uuid4().hex
  237. session = VasSession(
  238. id=token,
  239. user_id=user.id,
  240. user_agent=user_agent or "",
  241. ip=ip or "",
  242. expire_at=datetime.utcnow() + timedelta(days=7),
  243. )
  244. db.add(session)
  245. await db.commit()
  246. return {"user": user, "token": token}