| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314 |
- # app/services/auth_service.py
import random
import secrets
import string
import uuid
from datetime import datetime, timedelta
from typing import Dict, Optional

import bcrypt
from redis.asyncio import Redis
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.biz_exception import (
    BizLogicError,
    NotFoundError,
    PermissionDeniedError,
)
from app.models.session import VasSession
from app.models.user import VasUser
from app.models.verification_token import VasVerificationToken
from app.schemas.auth import (
    AutoRegisterRequest,
    BindEmailRequest,
    LoginRequest,
    ResetPasswordRequest,
    SendBindCodeRequest,
    SendResetCodeRequest,
)
from app.services.notification_service import NotificationService
def _random_password(length: int = 16) -> str:
    """Return a random password made of ``length`` characters.

    Characters are drawn from ASCII letters, digits and ``!@#$%``.  The
    value is stored as an account password, so it is generated with the
    ``secrets`` module (OS-backed CSPRNG) rather than ``random``, whose
    output is predictable and not intended for security purposes.

    Args:
        length: Number of characters to generate; ``0`` yields ``""``.

    Returns:
        The generated password.
    """
    alphabet = string.ascii_letters + string.digits + "!@#$%"
    return "".join(secrets.choice(alphabet) for _ in range(length))
# Word pools for generated nicknames: one adjective and one noun are picked
# at random and concatenated, followed by a number.
ADJECTIVES = [
    "Silent",
    "Lucky",
    "Happy",
    "Brave",
    "Quick",
    "Bright",
    "Quiet",
    "Crazy",
    "Cool",
    "Swift",
    "Blue",
    "Red",
    "Golden",
]

NOUNS = [
    "Fox",
    "Wolf",
    "Tiger",
    "Eagle",
    "Lion",
    "Bear",
    "River",
    "Cloud",
    "Storm",
    "Nova",
    "Shadow",
]
def human_like_nickname(max_number: int = 999) -> str:
    """Build a nickname such as ``BraveFox42``.

    The result is a random entry of ``ADJECTIVES``, a random entry of
    ``NOUNS`` and a random integer in ``[1, max_number]``, concatenated
    without separators.

    Args:
        max_number: Inclusive upper bound of the numeric suffix.

    Returns:
        The generated nickname.
    """
    # Tuple elements are evaluated left to right, so the order of the
    # random draws is adjective, noun, number.
    pieces = (
        random.choice(ADJECTIVES),
        random.choice(NOUNS),
        str(random.randint(1, max_number)),
    )
    return "".join(pieces)
class AuthService:
    """Authentication workflows: guest sign-up, e-mail binding, reset, login.

    Every public entry point is a ``@staticmethod`` that works on the
    ``AsyncSession`` passed in by the caller and commits its own changes.
    """

    # Session lifetimes, in days.
    _DEFAULT_SESSION_DAYS = 7
    _BIND_SESSION_DAYS = 30
    # Verification-code lifetime; also rendered into the e-mail text so the
    # stored expiry and the message shown to the user cannot drift apart.
    _CODE_TTL_MINUTES = 10

    # =========================
    # Private helpers
    # =========================
    @staticmethod
    def _new_session(
        user_id: str,
        days: int,
        ip: Optional[str],
        user_agent: Optional[str],
    ) -> VasSession:
        """Build (without persisting) a session whose id is a new ``tok_`` token."""
        return VasSession(
            id="tok_" + uuid.uuid4().hex,
            user_id=user_id,
            user_agent=user_agent or "",
            ip=ip or "",
            expire_at=datetime.utcnow() + timedelta(days=days),
        )

    @staticmethod
    def _hash_password(plain: str) -> str:
        """Return the bcrypt hash of ``plain`` as text, using a fresh salt."""
        return bcrypt.hashpw(plain.encode(), bcrypt.gensalt()).decode()

    @staticmethod
    async def _issue_code(db: AsyncSession) -> str:
        """Create, commit and return a new 6-character verification code."""
        code = uuid.uuid4().hex[:6]
        # NOTE(review): only the code and its expiry are stored here, and the
        # lookup in _load_valid_code filters on nothing else, so a code is not
        # tied to the address it was mailed to or to its purpose (bind vs.
        # reset). Binding it to the recipient needs the model's columns --
        # TODO confirm and tighten.
        db.add(
            VasVerificationToken(
                token=code,
                expire_at=datetime.utcnow()
                + timedelta(minutes=AuthService._CODE_TTL_MINUTES),
            )
        )
        await db.commit()
        return code

    @staticmethod
    async def _send_code_email(
        db: AsyncSession,
        template_id: str,
        receiver: str,
        code: str,
    ) -> None:
        """Send ``code`` to ``receiver`` through the e-mail channel."""
        await NotificationService.post_message(
            db=db,
            channel="email",
            payload={
                "template_id": template_id,
                "receiver": receiver,
                "payload": {
                    "app_name": "Visafly",
                    "code": code,
                    "expiration_time": f"{AuthService._CODE_TTL_MINUTES} minutes",
                },
            },
        )

    @staticmethod
    async def _load_valid_code(db: AsyncSession, code: str) -> VasVerificationToken:
        """Return the unused, unexpired token row matching ``code``.

        The caller is responsible for setting ``used = 1`` on the returned
        row and committing.

        Raises:
            BizLogicError: ``"Token invalid"`` when no unused row matches,
                ``"Token expired"`` when the match is past its expiry.
        """
        # Codes are only 6 hex characters and nothing in this class removes
        # old rows, so several unused rows may share a code. Take the one
        # expiring last instead of failing on multiple matches.
        stmt = (
            select(VasVerificationToken)
            .where(
                VasVerificationToken.token == code,
                VasVerificationToken.used == 0,
            )
            .order_by(VasVerificationToken.expire_at.desc())
            .limit(1)
        )
        record = (await db.execute(stmt)).scalars().first()
        if record is None:
            raise BizLogicError("Token invalid")
        if record.expire_at < datetime.utcnow():
            raise BizLogicError("Token expired")
        return record

    @staticmethod
    async def _find_verified_user(db: AsyncSession, email: str) -> Optional[VasUser]:
        """Return the user owning ``email`` with ``email_verified == 1``, if any."""
        stmt = select(VasUser).where(
            VasUser.email == email,
            VasUser.email_verified == 1,
        )
        return (await db.execute(stmt)).scalar_one_or_none()

    # =========================
    # Auto-register (guest)
    # =========================
    @staticmethod
    async def auto_register(
        db: AsyncSession,
        req: AutoRegisterRequest,
        ip: Optional[str] = None,
        user_agent: Optional[str] = None,
    ) -> Dict:
        """Create a guest user plus a 7-day session and commit both.

        Args:
            db: Session used to persist the new rows.
            req: Registration request; currently not read.
            ip: Client IP, stored as ``""`` when missing.
            user_agent: Client user agent, stored as ``""`` when missing.

        Returns:
            ``{"user": <VasUser>, "token": <session id>}``.
        """
        uid = f"usr-{uuid.uuid4().hex[:8]}"
        user = VasUser(
            id=uid,
            role="user",
            nickname=human_like_nickname(),
            preferred_language="en",
            timezone="Asia/Shanghai",
            register_ip=ip or "",
        )
        db.add(user)

        session = AuthService._new_session(
            uid, AuthService._DEFAULT_SESSION_DAYS, ip, user_agent
        )
        db.add(session)

        await db.commit()
        await db.refresh(user)
        return {"user": user, "token": session.id}

    # =========================
    # Send bind-email verification code
    # =========================
    @staticmethod
    async def send_bind_code(
        db: AsyncSession,
        payload: SendBindCodeRequest,
        auth_user: VasUser,
        redis_client: Redis,
    ):
        """Store a new verification code and e-mail it to ``payload.email``.

        Args:
            db: Session used to persist the code.
            payload: Request carrying the address to bind.
            auth_user: Authenticated caller; currently not read.
            redis_client: Redis handle; currently not read.
        """
        code = await AuthService._issue_code(db)
        await AuthService._send_code_email(
            db, "email_verification_for_bind", payload.email, code
        )

    # =========================
    # Send reset-password verification code
    # =========================
    @staticmethod
    async def send_reset_code(
        db: AsyncSession,
        payload: SendResetCodeRequest,
        redis_client: Redis,
    ):
        """Store a new verification code and e-mail it to a verified user.

        Args:
            db: Session used for the lookup and to persist the code.
            payload: Request carrying the account e-mail.
            redis_client: Redis handle; currently not read.

        Raises:
            BizLogicError: No user with this verified e-mail exists.
        """
        if not await AuthService._find_verified_user(db, payload.email):
            raise BizLogicError("User not exist")

        code = await AuthService._issue_code(db)
        await AuthService._send_code_email(
            db, "email_verification_for_reset", payload.email, code
        )

    # =========================
    # Bind e-mail
    # =========================
    @staticmethod
    async def bind_email(
        db: AsyncSession,
        payload: BindEmailRequest,
        auth_user: VasUser,
        redis_client: Redis,
        ip: Optional[str] = None,
        user_agent: Optional[str] = None,
    ) -> Dict:
        """Attach a verified e-mail and a generated password to the caller.

        On success the verification code is marked used, a 30-day session is
        created and the generated password is mailed to the new address.

        Args:
            db: Session used for all reads and writes.
            payload: Request carrying ``email`` and the verification ``code``.
            auth_user: Authenticated caller whose account is updated.
            redis_client: Redis handle; currently not read.
            ip: Client IP, stored as ``""`` when missing.
            user_agent: Client user agent, stored as ``""`` when missing.

        Returns:
            ``{"user": <VasUser>, "token": <session id>}``.

        Raises:
            BizLogicError: The e-mail is already bound, or the code is
                invalid or expired.
            NotFoundError: The caller's user row no longer exists.
        """
        # Reject addresses that already belong to a verified account.
        if await AuthService._find_verified_user(db, payload.email):
            raise BizLogicError("Email already bound")

        # Validate the verification code.
        record = await AuthService._load_valid_code(db, payload.code)

        user = await db.get(VasUser, auth_user.id)
        if user is None:
            # db.get returns None for a missing row; fail explicitly instead
            # of crashing on the attribute assignments below.
            raise NotFoundError("User not found")

        plain_pwd = _random_password()
        user.email = payload.email
        user.password_hash = AuthService._hash_password(plain_pwd)
        user.email_verified = 1
        record.used = 1

        session = AuthService._new_session(
            user.id, AuthService._BIND_SESSION_DAYS, ip, user_agent
        )
        db.add(session)

        await db.commit()
        await db.refresh(user)

        # The credentials mail goes out only after the commit, so it is never
        # sent for a binding that failed to persist.
        await NotificationService.post_message(
            db=db,
            channel="email",
            payload={
                "template_id": "login_credentials",
                "receiver": payload.email,
                "payload": {
                    "app_name": "Visafly",
                    "username": payload.email,
                    "password": plain_pwd,
                    "login_url": "https://visafly.top/login",
                },
            },
        )
        return {"user": user, "token": session.id}

    # =========================
    # Reset password
    # =========================
    @staticmethod
    async def reset_password(
        db: AsyncSession,
        payload: ResetPasswordRequest,
    ) -> bool:
        """Replace a verified user's password after checking the code.

        Args:
            db: Session used for all reads and writes.
            payload: Request carrying ``email``, ``code`` and ``new_password``.

        Returns:
            ``True`` once the new hash is committed.

        Raises:
            BizLogicError: Unknown verified e-mail, or invalid/expired code.
        """
        user = await AuthService._find_verified_user(db, payload.email)
        if not user:
            raise BizLogicError("User not exist")

        record = await AuthService._load_valid_code(db, payload.code)

        user.password_hash = AuthService._hash_password(payload.new_password)
        record.used = 1
        await db.commit()
        return True

    # =========================
    # Login
    # =========================
    @staticmethod
    async def login(
        db: AsyncSession,
        req: LoginRequest,
        ip: Optional[str] = None,
        user_agent: Optional[str] = None,
    ) -> Dict:
        """Check e-mail/password credentials and open a 7-day session.

        Args:
            db: Session used for the lookup and to persist the session.
            req: Request carrying ``email`` and ``password``.
            ip: Client IP, stored as ``""`` when missing.
            user_agent: Client user agent, stored as ``""`` when missing.

        Returns:
            ``{"user": <VasUser>, "token": <session id>}``.

        Raises:
            NotFoundError: No user has this e-mail.
            PermissionDeniedError: The password does not match, or the
                account has no password set.
        """
        stmt = select(VasUser).where(VasUser.email == req.email)
        user = (await db.execute(stmt)).scalar_one_or_none()
        if not user:
            raise NotFoundError("User not found")

        # An account without a stored hash cannot match any password; treat
        # it as a failed login rather than crashing on ``None.encode()``.
        stored_hash = user.password_hash
        if not stored_hash or not bcrypt.checkpw(
            req.password.encode(),
            stored_hash.encode(),
        ):
            raise PermissionDeniedError("Password incorrect")

        session = AuthService._new_session(
            user.id, AuthService._DEFAULT_SESSION_DAYS, ip, user_agent
        )
        db.add(session)
        await db.commit()
        return {"user": user, "token": session.id}
|