auth_service.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. # app/services/auth_service.py
  2. import uuid
  3. import bcrypt
  4. import random
  5. import string
  6. from datetime import datetime, timedelta
  7. from typing import Dict
  8. from sqlalchemy import select
  9. from sqlalchemy.ext.asyncio import AsyncSession
  10. from redis.asyncio import Redis
  11. from app.core.biz_exception import (
  12. NotFoundError,
  13. PermissionDeniedError,
  14. BizLogicError,
  15. )
  16. from app.models.user import VasUser
  17. from app.models.session import VasSession
  18. from app.models.verification_token import VasVerificationToken
  19. from app.schemas.auth import (
  20. AutoRegisterRequest,
  21. SendBindCodeRequest,
  22. SendResetCodeRequest,
  23. BindEmailRequest,
  24. ResetPasswordRequest,
  25. LoginRequest,
  26. )
  27. from app.services.notification_service import NotificationService
  28. def _random_password(length: int = 16) -> str:
  29. return "".join(
  30. random.choices(
  31. string.ascii_letters + string.digits + "!@#$%",
  32. k=length,
  33. )
  34. )
  35. ADJECTIVES = [
  36. "Silent", "Lucky", "Happy", "Brave", "Quick", "Bright",
  37. "Quiet", "Crazy", "Cool", "Swift", "Blue", "Red", "Golden"
  38. ]
  39. NOUNS = [
  40. "Fox", "Wolf", "Tiger", "Eagle", "Lion", "Bear",
  41. "River", "Cloud", "Storm", "Nova", "Shadow"
  42. ]
  43. def human_like_nickname(max_number: int = 999) -> str:
  44. adj = random.choice(ADJECTIVES)
  45. noun = random.choice(NOUNS)
  46. number = random.randint(1, max_number)
  47. return f"{adj}{noun}{number}"
  48. class AuthService:
  49. # =========================
  50. # 自动注册(游客)
  51. # =========================
  52. @staticmethod
  53. async def auto_register(
  54. db: AsyncSession,
  55. req: AutoRegisterRequest,
  56. ip: str = None,
  57. user_agent = None
  58. ) -> Dict:
  59. uid = f"usr-{uuid.uuid4().hex[:8]}"
  60. user = VasUser(
  61. id=uid,
  62. role="user",
  63. nickname=human_like_nickname(),
  64. preferred_language="en",
  65. timezone="Asia/Shanghai",
  66. register_ip=ip or '',
  67. )
  68. db.add(user)
  69. token = "tok_" + uuid.uuid4().hex
  70. session = VasSession(
  71. id=token,
  72. user_id=uid,
  73. user_agent=user_agent or '',
  74. ip=ip or '',
  75. expire_at=datetime.utcnow() + timedelta(days=7),
  76. )
  77. db.add(session)
  78. await db.commit()
  79. await db.refresh(user)
  80. return {"user": user, "token": token}
  81. # =========================
  82. # 发送绑定邮箱验证码
  83. # =========================
  84. @staticmethod
  85. async def send_bind_code(
  86. db: AsyncSession,
  87. payload: SendBindCodeRequest,
  88. auth_user: VasUser,
  89. redis_client: Redis,
  90. ):
  91. token = uuid.uuid4().hex[:6]
  92. expiration_time = datetime.utcnow() + timedelta(minutes=10)
  93. record = VasVerificationToken(
  94. token=token,
  95. expire_at=expiration_time,
  96. )
  97. db.add(record)
  98. await db.commit()
  99. await NotificationService.post_message(
  100. db=db,
  101. channel="email",
  102. payload={
  103. "template_id": "email_verification_for_bind",
  104. "receiver": payload.email,
  105. "payload": {
  106. "app_name": "Visafly",
  107. "code": token,
  108. "expiration_time": "10 minutes"
  109. },
  110. },
  111. )
  112. # =========================
  113. # 发送重置密码验证码
  114. # =========================
  115. @staticmethod
  116. async def send_reset_code(
  117. db: AsyncSession,
  118. payload: SendResetCodeRequest,
  119. redis_client: Redis,
  120. ):
  121. stmt = select(VasUser).where(
  122. VasUser.email == payload.email,
  123. VasUser.email_verified == 1,
  124. )
  125. user = (await db.execute(stmt)).scalar_one_or_none()
  126. if not user:
  127. raise BizLogicError("User not exist")
  128. expiration_time = datetime.utcnow() + timedelta(minutes=10)
  129. token = uuid.uuid4().hex[:6]
  130. record = VasVerificationToken(
  131. token=token,
  132. expire_at=expiration_time,
  133. )
  134. db.add(record)
  135. await db.commit()
  136. await NotificationService.post_message(
  137. db=db,
  138. channel="email",
  139. payload={
  140. "template_id": "email_verification_for_reset",
  141. "receiver": payload.email,
  142. "payload": {
  143. "app_name": "Visafly",
  144. "code": token,
  145. "expiration_time": "10 minutes"
  146. },
  147. },
  148. )
  149. # =========================
  150. # 绑定邮箱
  151. # =========================
  152. @staticmethod
  153. async def bind_email(
  154. db: AsyncSession,
  155. payload: BindEmailRequest,
  156. auth_user: VasUser,
  157. redis_client: Redis,
  158. ip: str = None,
  159. user_agent = None
  160. ) -> Dict:
  161. # 邮箱是否已被绑定
  162. stmt = select(VasUser).where(
  163. VasUser.email == payload.email,
  164. VasUser.email_verified == 1,
  165. )
  166. if (await db.execute(stmt)).scalar_one_or_none():
  167. raise BizLogicError("Email already bound")
  168. # 校验验证码
  169. stmt = select(VasVerificationToken).where(
  170. VasVerificationToken.token == payload.code,
  171. VasVerificationToken.used == 0,
  172. )
  173. record = (await db.execute(stmt)).scalar_one_or_none()
  174. if not record:
  175. raise BizLogicError("Token invalid")
  176. if record.expire_at < datetime.utcnow():
  177. raise BizLogicError("Token expired")
  178. user = await db.get(VasUser, auth_user.id)
  179. plain_pwd = _random_password()
  180. hashed_pwd = bcrypt.hashpw(
  181. plain_pwd.encode(),
  182. bcrypt.gensalt(),
  183. ).decode()
  184. user.email = payload.email
  185. user.password_hash = hashed_pwd
  186. user.email_verified = 1
  187. record.used = 1
  188. token = "tok_" + uuid.uuid4().hex
  189. session = VasSession(
  190. id=token,
  191. user_id=user.id,
  192. ip=ip or '',
  193. user_agent=user_agent or '',
  194. expire_at=datetime.utcnow() + timedelta(days=30),
  195. )
  196. db.add(session)
  197. await db.commit()
  198. await db.refresh(user)
  199. await NotificationService.post_message(
  200. db=db,
  201. channel="email",
  202. payload={
  203. "template_id": "login_credentials",
  204. "receiver": payload.email,
  205. "payload": {
  206. "app_name": "Visafly",
  207. "username": payload.email,
  208. "password": plain_pwd,
  209. "login_url": "https://visafly.top/login"
  210. },
  211. },
  212. )
  213. return {"user": user, "token": token}
  214. # =========================
  215. # 重置密码
  216. # =========================
  217. @staticmethod
  218. async def reset_password(
  219. db: AsyncSession,
  220. payload: ResetPasswordRequest,
  221. ) -> bool:
  222. stmt = select(VasUser).where(
  223. VasUser.email == payload.email,
  224. VasUser.email_verified == 1,
  225. )
  226. user = (await db.execute(stmt)).scalar_one_or_none()
  227. if not user:
  228. raise BizLogicError("User not exist")
  229. stmt = select(VasVerificationToken).where(
  230. VasVerificationToken.token == payload.code,
  231. VasVerificationToken.used == 0,
  232. )
  233. record = (await db.execute(stmt)).scalar_one_or_none()
  234. if not record:
  235. raise BizLogicError("Token invalid")
  236. if record.expire_at < datetime.utcnow():
  237. raise BizLogicError("Token expired")
  238. user.password_hash = bcrypt.hashpw(
  239. payload.new_password.encode(),
  240. bcrypt.gensalt(),
  241. ).decode()
  242. record.used = 1
  243. await db.commit()
  244. return True
  245. # =========================
  246. # 登录
  247. # =========================
  248. @staticmethod
  249. async def login(
  250. db: AsyncSession,
  251. req: LoginRequest,
  252. ip: str = None,
  253. user_agent: str = None
  254. ) -> Dict:
  255. stmt = select(VasUser).where(VasUser.email == req.email)
  256. user = (await db.execute(stmt)).scalar_one_or_none()
  257. if not user:
  258. raise NotFoundError("User not found")
  259. if not bcrypt.checkpw(
  260. req.password.encode(),
  261. user.password_hash.encode(),
  262. ):
  263. raise PermissionDeniedError("Password incorrect")
  264. token = "tok_" + uuid.uuid4().hex
  265. session = VasSession(
  266. id=token,
  267. user_id=user.id,
  268. user_agent=user_agent or "",
  269. ip=ip or "",
  270. expire_at=datetime.utcnow() + timedelta(days=7),
  271. )
  272. db.add(session)
  273. await db.commit()
  274. return {"user": user, "token": token}