payment_service.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528
  1. # app/services/payment_service.py
  2. import time
  3. import stripe
  4. import random
  5. import uuid
  6. from typing import Dict, List, Optional
  7. from redis.asyncio import Redis
  8. from decimal import Decimal, ROUND_HALF_UP
  9. from datetime import datetime, timedelta
  10. from sqlalchemy.ext.asyncio import AsyncSession
  11. from sqlalchemy import select
  12. from app.utils.search import apply_keyword_search_stmt
  13. from app.utils.pagination import paginate
  14. from app.core.biz_exception import NotFoundError, BizLogicError
  15. from app.models.user import VasUser
  16. from app.models.order import VasOrder
  17. from app.models.vas_task import VasTask
  18. from app.models.payment import VasPayment
  19. from app.models.ticket import VasTicket
  20. from app.models.payment_event import VasPaymentEvent
  21. from app.models.product_routing import VasProductRouting
  22. from app.models.verification_token import VasVerificationToken
  23. from app.models.payment_provider import VasPaymentProvider
  24. from app.models.payment_qr import VasPaymentQR
  25. from app.models.payment_confirmation import VasPaymentConfirmation
  26. from app.schemas.payment import VasPaymentCreate, AdminUpdateStatusPayload
  27. from app.schemas.payment_confirmation import VasPaymentConfirmationCreate, VasPaymentConfirmationUpdate
  28. from app.services.notification_service import NotificationService
  29. from app.services.webhook_service import WebhookService
  30. class PaymentService:
  # --------------------------------------------------
  # Create a payment (single entry point for all providers)
  # --------------------------------------------------
  @staticmethod
  async def create_payment(
      db: AsyncSession,
      payload: VasPaymentCreate,
      rate_table: Dict,
      redis_client: Redis
  ) -> VasPayment:
      """Create, or reuse, the pending payment of an order.

      Args:
          db: Async session; committed here once a new payment is created.
          payload: Carries ``order_id`` and the requested ``provider``
              ("wechat", "alipay" or "stripe").
          rate_table: Exchange-rate mapping forwarded to the
              provider-specific creators.
          redis_client: Not used by this method; kept so callers that
              already pass it keep working.

      Returns:
          The existing pending payment when it already uses the requested
          provider, otherwise a newly created payment.

      Raises:
          BizLogicError: Unsupported provider, or the order is already paid.
          NotFoundError: The order does not exist.
      """
      # Reject unknown providers before taking the row lock or touching an
      # existing payment, so a bad request leaves the session unmodified.
      if payload.provider not in ("wechat", "alipay", "stripe"):
          raise BizLogicError("Unsupported provider")

      # 1) Lock the order row so concurrent requests for it are serialized.
      stmt = (
          select(VasOrder)
          .where(VasOrder.id == payload.order_id)
          .with_for_update()
      )
      order = (await db.execute(stmt)).scalar_one_or_none()
      if not order:
          raise NotFoundError("Order not found")

      # A paid order must not receive another payment for its full amount.
      if order.status == "paid":
          raise BizLogicError("Order already paid")

      # 2) Idempotency: reuse a pending payment of the same provider.
      stmt = select(VasPayment).where(
          VasPayment.order_id == order.id,
          VasPayment.status == "pending",
      )
      active_payment = (await db.execute(stmt)).scalar_one_or_none()
      if active_payment:
          if active_payment.provider == payload.provider:
              return active_payment
          # The user switched provider: retire the previous attempt.
          active_payment.status = "failed"
          # NOTE(review): an expired pending payment of the same provider is
          # still returned above - confirm whether expiry is handled elsewhere.

      # 3) Delegate to the provider-specific creator, then persist.
      if payload.provider == "stripe":
          payment = await PaymentService.create_stripe_payment(
              db=db,
              order=order,
              rate_table=rate_table,
          )
      else:
          payment = await PaymentService.create_offline_payment(
              db=db,
              order=order,
              provider_name=payload.provider,
              rate_table=rate_table,
          )
      await db.commit()
      return payment
  @staticmethod
  async def confirm_by_user(
      db: AsyncSession,
      payload: VasPaymentConfirmationCreate,
      current_user: VasUser,
      redis_client: Redis
  ):
      """Handle the user clicking "I have paid".

      Stores a confirmation record for (payment, user) if none exists yet
      and posts a ``payment_user_confirmed`` notification.

      Args:
          db: Async session; committed when a new record is inserted.
          payload: Payment id plus the amount, currency, random offset and
              confirmation time reported by the user.
          current_user: The user making the claim.
          redis_client: Passed through to ``NotificationService.post_wechat``.

      Returns:
          The existing or newly created ``VasPaymentConfirmation``.

      Raises:
          NotFoundError: ``payload.payment_id`` matches no payment.
      """
      # Resolve the payment first: the notification needs its order id and
      # provider, and no confirmation should be stored for a missing payment.
      stmt = select(VasPayment).where(
          VasPayment.id == payload.payment_id
      )
      payment = (await db.execute(stmt)).scalar_one_or_none()
      if not payment:
          raise NotFoundError("Payment not found")
      # NOTE(review): nothing here checks that the payment's order belongs to
      # current_user - confirm whether the route layer enforces that.

      # 1) Look for an existing confirmation by this user for this payment.
      result = await db.execute(
          select(VasPaymentConfirmation)
          .where(VasPaymentConfirmation.payment_id == payload.payment_id)
          .where(VasPaymentConfirmation.user_id == current_user.id)
      )
      record = result.scalar_one_or_none()
      if not record:
          # None yet: create one in "pending" state (an admin confirms it later).
          record = VasPaymentConfirmation(
              payment_id=payload.payment_id,
              amount=payload.amount,
              currency=payload.currency,
              random_offset=payload.random_offset,
              user_id=current_user.id,
              status="pending",
              confirmed_at=payload.confirmed_at
          )
          db.add(record)
          await db.commit()
          await db.refresh(record)

      # Fall back to the current time so a missing timestamp cannot crash
      # the formatting below.
      shown_at = payload.confirmed_at or datetime.utcnow()
      formatted_time = shown_at.strftime('%Y-%m-%d %H:%M') + " (UTC)"

      # 2) Notify the administrators.
      await NotificationService.post_wechat(
          redis_client=redis_client,
          template_id="payment_user_confirmed",
          payload={
              "order_id": payment.order_id,
              "payment_id": payload.payment_id,
              "user_email": current_user.email,
              "amount": payload.amount,
              "currency": payload.currency,
              "token": "",
              "confirmed_at": formatted_time,
              "provider": payment.provider
          }
      )
      return record
  @staticmethod
  async def confirm_by_admin(
      db: AsyncSession,
      id: int,
      payload: VasPaymentConfirmationUpdate,
      current_user: VasUser
  ):
      """Let an administrator approve a user's payment confirmation.

      Args:
          db: Async session; committed before returning.
          id: Primary key of the ``VasPaymentConfirmation`` to approve.
          payload: Not read by this method.
          current_user: The approving administrator.

      Returns:
          The refreshed confirmation record.

      Raises:
          NotFoundError: No confirmation record has this id.
          BizLogicError: Propagated from ``_confirm_payment_action``.
      """
      # 1) Load the confirmation record.
      lookup = select(VasPaymentConfirmation).where(
          VasPaymentConfirmation.id == id
      )
      confirmation = (await db.execute(lookup)).scalar_one_or_none()
      if not confirmation:
          raise NotFoundError("Payment confirmation record not found")

      # 2) Stamp the admin decision. This happens before the payment action
      #    on purpose: that helper commits the session internally.
      confirmation.status = 'confirmed'
      confirmation.admin_id = current_user.id
      confirmation.admin_confirmed_at = datetime.utcnow()

      # 3) Apply the confirmation to the payment itself.
      await PaymentService._confirm_payment_action(
          db,
          confirmation.payment_id,
          'confirmed by admin',
      )

      await db.commit()
      await db.refresh(confirmation)
      return confirmation
  @staticmethod
  async def admin_update_status(
      db: AsyncSession,
      payment_id: int,
      payload: AdminUpdateStatusPayload
  ):
      """Set a payment's outcome on behalf of an administrator.

      ``payload.status == "succeeded"`` runs the normal confirmation flow
      with ``payload.remark``. Any other status marks the payment "failed".

      Args:
          db: Async session; committed before returning.
          payment_id: Primary key of the payment to update.
          payload: Requested status and a free-text remark.

      Returns:
          The updated, refreshed payment.

      Raises:
          BizLogicError: The payment does not exist, or the confirmation
              flow rejects it.
      """
      if payload.status == "succeeded":
          # The helper commits and refreshes the payment before returning,
          # so no further refresh is needed here.
          return await PaymentService._confirm_payment_action(
              db, payment_id, payload.remark
          )

      # NOTE(review): every non-"succeeded" status ends up as "failed" -
      # confirm that the payload schema only allows those two values.
      pay_stmt = select(VasPayment).where(VasPayment.id == payment_id)
      payment = (await db.execute(pay_stmt)).scalar_one_or_none()
      if not payment:
          raise BizLogicError("Payment not found")
      payment.status = 'failed'
      await db.commit()
      await db.refresh(payment)
      return payment
  @staticmethod
  async def list_payment_confirmation(
      db: AsyncSession,
      keyword: Optional[str] = None,
      page: int = 1,
      size: int = 20,
  ):
      """Return one page of payment confirmations, highest id first.

      Args:
          db: Async session used by ``paginate``.
          keyword: Optional search term, matched against ``user_id`` by
              ``apply_keyword_search_stmt``.
          page: Page number forwarded to ``paginate``.
          size: Page size forwarded to ``paginate``.

      Returns:
          Whatever ``paginate`` returns for the built statement.
      """
      searched = apply_keyword_search_stmt(
          stmt=select(VasPaymentConfirmation),
          model=VasPaymentConfirmation,
          keyword=keyword,
          fields=["user_id"],
      )
      newest_first = searched.order_by(VasPaymentConfirmation.id.desc())
      page_result = await paginate(db, newest_first, page, size)
      return page_result
  @staticmethod
  async def confirm_payment(
      db: AsyncSession,
      payment_id: int,
      token: str
  ):
      """Confirm a payment when a valid verification token is presented.

      Args:
          db: Async session; committed before returning.
          payment_id: Primary key of the payment to confirm.
          token: Verification token value; must be unused and not expired.

      Returns:
          The confirmed payment.

      Raises:
          BizLogicError: The token is unknown, already used or expired, or
              the confirmation flow rejects the payment.
      """
      # Validate the verification token: it must exist and be unused.
      # NOTE(review): the lookup does not tie the token to payment_id -
      # confirm whether tokens are meant to be payment-specific.
      token_query = (
          select(VasVerificationToken)
          .where(VasVerificationToken.token == token)
          .where(VasVerificationToken.used == 0)
      )
      token_row = (await db.execute(token_query)).scalar_one_or_none()
      if not token_row:
          raise BizLogicError("Token invalid")
      if token_row.expire_at < datetime.utcnow():
          raise BizLogicError("Token expired")

      confirmed = await PaymentService._confirm_payment_action(
          db,
          payment_id,
          'confirmed by admin',
      )

      # The token is consumed only after the payment was confirmed.
      token_row.used = 1
      await db.commit()
      return confirmed
  @staticmethod
  async def _confirm_payment_action(db: AsyncSession, payment_id: int, remark:str):
      """Mark a payment as received and move its order to "paid".

      A ``VasPaymentEvent`` is always written as an audit record. Its final
      status is "duplicate" when the payment was already confirmed and
      "applied" otherwise.

      Args:
          db: Async session. This helper commits it more than once, so
              pending changes made by the caller are committed as well.
          payment_id: Primary key of the payment to confirm.
          remark: Text stored under ``title`` in ``provider_payload``.

      Returns:
          The refreshed payment, with status "succeeded", or "late_paid"
          when confirmed after ``expire_at``.

      Raises:
          BizLogicError: The payment does not exist or was already confirmed.
      """
      # ---------- Locate the payment ----------
      pay_stmt = select(VasPayment).where(VasPayment.id == payment_id)
      payment = (await db.execute(pay_stmt)).scalar_one_or_none()
      if not payment:
          raise BizLogicError("Payment not found")

      # Persist the audit event up front so it survives the early exit below.
      event = VasPaymentEvent(
          provider=payment.provider,
          event_type="payment_received",
          title='confirm payment',
          content='confirm payment by admin',
          parsed_amount=payment.amount,
          parsed_currency=payment.currency,
          parsed_device='',
          status="received",
      )
      db.add(event)
      await db.commit()
      await db.refresh(event)

      # Re-read the payment under a row lock. The commit above ended the
      # previous transaction, and without the lock two concurrent
      # confirmations could both pass the duplicate check below.
      await db.refresh(payment, with_for_update=True)

      if payment.status in ("succeeded", "late_paid"):
          event.status = "duplicate"
          event.matched_payment_id = payment.id
          event.matched_order_id = payment.order_id
          await db.commit()
          raise BizLogicError("Payment has been confirmed")

      now = datetime.utcnow()
      # A payment confirmed after its deadline is kept, flagged "late_paid".
      payment.status = "late_paid" if payment.expire_at and now > payment.expire_at else "succeeded"
      payment.provider_payload = {
          "title": remark,
          "received_at": now.isoformat(),
      }

      order_stmt = select(VasOrder).where(VasOrder.id == payment.order_id)
      order = (await db.execute(order_stmt)).scalar_one_or_none()
      if order and order.status != "paid":
          order.status = "paid"
          await WebhookService._create_task_if_not_exists(db, order)

      event.status = "applied"
      event.matched_payment_id = payment.id
      event.matched_order_id = payment.order_id
      await db.commit()
      await db.refresh(payment)
      return payment
  @staticmethod
  async def create_offline_payment(
      db: AsyncSession,
      order: VasOrder,
      provider_name: str,
      rate_table: Dict,
  ) -> VasPayment:
      """Create a static-QR (WeChat or Alipay) payment for an order.

      The converted amount is reduced by a small random discount, which is
      also stored in ``random_offset``.

      Args:
          db: Async session. The payment is flushed, not committed.
          order: Order to pay.
          provider_name: "wechat" selects WeChat; any other value takes the
              Alipay path.
          rate_table: Mapping keyed by ``"BASE->TARGET"`` (upper-case) to an
              exchange rate.

      Returns:
          The flushed pending payment, with QR code, rate, amount and
          currency filled in.

      Raises:
          BizLogicError: The provider is disabled, no active QR code exists,
              or no exchange rate is known for the currency pair.
      """
      # Validate everything before inserting, so a failure cannot leave an
      # orphan zero-amount pending payment in the session.
      stmt = select(VasPaymentProvider).where(
          VasPaymentProvider.enabled == 1,
          VasPaymentProvider.name == provider_name,
      )
      provider = (await db.execute(stmt)).scalar_one_or_none()
      if not provider:
          raise BizLogicError("Payment provider not available")

      stmt = select(VasPaymentQR).where(
          VasPaymentQR.provider == provider_name,
          VasPaymentQR.is_active == 1,
      )
      qrs = (await db.execute(stmt)).scalars().all()
      if not qrs:
          raise BizLogicError("No payment QR available")

      rate_key = f"{order.base_currency}->{provider.currency}".upper()
      raw_rate = rate_table.get(rate_key)
      if raw_rate is None:
          same_currency = (
              str(order.base_currency).upper() == str(provider.currency).upper()
          )
          if not same_currency:
              raise BizLogicError(f"Exchange rate not available: {rate_key}")
          # Converting a currency to itself needs no table entry.
          raw_rate = 1
      # Go through str() so a float rate does not carry binary float error.
      exchange_rate = Decimal(str(raw_rate))

      payment = (
          await PaymentService._create_wechat_payment(db, order)
          if provider_name == "wechat"
          else await PaymentService._create_alipay_payment(db, order)
      )

      # Spread load by picking one of the active QR codes at random.
      qr = random.choice(qrs)
      payment.qr_id = qr.id

      converted = (
          Decimal(str(payment.base_amount)) * exchange_rate
      ).quantize(Decimal("1"), rounding=ROUND_HALF_UP)

      # Random discount of at most 1% of the amount, capped at 99 units.
      max_discount = min(99, int(converted * Decimal("0.01")))
      discount = random.randint(1, max_discount) if max_discount >= 1 else 0

      payment.exchange_rate = exchange_rate
      payment.amount = int(converted) - discount
      payment.currency = provider.currency
      payment.random_offset = discount
      return payment
  @staticmethod
  async def create_stripe_payment(
      db: AsyncSession,
      order: VasOrder,
      rate_table: Dict,
  ) -> VasPayment:
      """Create a Stripe Checkout payment for an order.

      Args:
          db: Async session. The payment is flushed, not committed.
          order: Order to pay.
          rate_table: Mapping keyed by ``"BASE->TARGET"`` (upper-case) to an
              exchange rate.

      Returns:
          The flushed pending payment carrying the Checkout session id (in
          ``payment_intent_id``) and its URL.

      Raises:
          BizLogicError: The Stripe provider is disabled, or no exchange
              rate is known for the currency pair.
      """
      # Function-scope imports: only this method needs them.
      import asyncio
      import functools

      # Validate provider and rate before inserting, so a failure cannot
      # leave an orphan zero-amount pending payment in the session.
      stmt = select(VasPaymentProvider).where(
          VasPaymentProvider.enabled == 1,
          VasPaymentProvider.name == "stripe",
      )
      provider = (await db.execute(stmt)).scalar_one_or_none()
      if not provider:
          raise BizLogicError("Stripe provider not enabled")

      rate_key = f"{order.base_currency}->{provider.currency}".upper()
      raw_rate = rate_table.get(rate_key)
      if raw_rate is None:
          same_currency = (
              str(order.base_currency).upper() == str(provider.currency).upper()
          )
          if not same_currency:
              raise BizLogicError(f"Exchange rate not available: {rate_key}")
          # Converting a currency to itself needs no table entry.
          raw_rate = 1
      # Go through str() so a float rate does not carry binary float error.
      exchange_rate = Decimal(str(raw_rate))

      payment = await PaymentService._create_stripe_payment(db, order)

      converted = (
          Decimal(str(payment.base_amount)) * exchange_rate
      ).quantize(Decimal("1"), rounding=ROUND_HALF_UP)
      payment.exchange_rate = exchange_rate
      payment.amount = int(converted)
      payment.currency = provider.currency
      payment.random_offset = 0

      # The Stripe client is synchronous; run it in the default executor so
      # the HTTP round trip does not block the event loop.
      loop = asyncio.get_running_loop()
      stripe_session = await loop.run_in_executor(
          None,
          functools.partial(
              PaymentService.create_checkout_session,
              order=order,
              payment=payment,
              success_url="https://visafly.top/dashboard",
              cancel_url="https://visafly.top/dashboard",
          ),
      )
      payment.payment_intent_id = stripe_session.id
      payment.payment_url = stripe_session.url
      return payment
  @staticmethod
  def create_checkout_session(
      order: VasOrder,
      payment: VasPayment,
      success_url: str,
      cancel_url: str,
      expires_in_seconds: int = 30 * 60,
  ):
      """Create a one-off card Checkout session on Stripe for a payment.

      This is a synchronous call into the Stripe client; async callers
      should run it off the event loop.

      Args:
          order: Order being paid; its id and user id go into the metadata.
          payment: Payment supplying the currency and ``unit_amount``.
          success_url: Redirect target after a successful checkout.
          cancel_url: Redirect target when the user cancels.
          expires_in_seconds: Session lifetime counted from now. Defaults to
              30 minutes, the value that was previously hard-coded. Stripe
              documents an allowed window of 30 minutes to 24 hours.

      Returns:
          The object returned by ``stripe.checkout.Session.create``.
      """
      expires_at = int(time.time()) + expires_in_seconds
      return stripe.checkout.Session.create(
          mode="payment",
          payment_method_types=["card"],
          line_items=[
              {
                  "price_data": {
                      "currency": payment.currency.lower(),
                      "product_data": {
                          "name": f"Visa Service Order {order.id}",
                      },
                      # NOTE(review): Stripe expects the smallest currency
                      # unit here - confirm payment.amount is stored that way.
                      "unit_amount": payment.amount,
                  },
                  "quantity": 1,
              }
          ],
          # Lets the handler of Stripe events map the session back to our
          # records - presumably read by the webhook side; verify there.
          metadata={
              "order_id": order.id,
              "payment_id": payment.id,
              "user_id": order.user_id,
          },
          success_url=success_url,
          cancel_url=cancel_url,
          expires_at=expires_at,
      )
  @staticmethod
  async def _create_wechat_payment(
      db: AsyncSession,
      order: VasOrder,
  ) -> VasPayment:
      """Insert a placeholder pending WeChat payment for ``order``.

      Amount, rate and offset start at zero; the row is added to the
      session and flushed, not committed.

      Returns:
          The flushed ``VasPayment``.
      """
      # The payment window is 30 minutes from creation.
      deadline = datetime.utcnow() + timedelta(minutes=30)
      fields = {
          "order_id": order.id,
          "provider": "wechat",
          "channel": "qr_static",
          "base_amount": order.base_amount,
          "base_currency": order.base_currency,
          "amount": 0,
          "currency": "CNY",
          "random_offset": 0,
          "exchange_rate": 0,
          "status": "pending",
          "expire_at": deadline,
      }
      wechat_payment = VasPayment(**fields)
      db.add(wechat_payment)
      await db.flush()
      return wechat_payment
  @staticmethod
  async def _create_alipay_payment(
      db: AsyncSession,
      order: VasOrder,
  ) -> VasPayment:
      """Insert a placeholder pending Alipay payment for ``order``.

      Amount, rate and offset start at zero; the row is added to the
      session and flushed, not committed.

      Returns:
          The flushed ``VasPayment``.
      """
      # The payment window is 30 minutes from creation.
      deadline = datetime.utcnow() + timedelta(minutes=30)
      fields = {
          "order_id": order.id,
          "provider": "alipay",
          "channel": "qr_static",
          "base_amount": order.base_amount,
          "base_currency": order.base_currency,
          "amount": 0,
          "currency": "CNY",
          "random_offset": 0,
          "exchange_rate": 0,
          "status": "pending",
          "expire_at": deadline,
      }
      alipay_payment = VasPayment(**fields)
      db.add(alipay_payment)
      await db.flush()
      return alipay_payment
  @staticmethod
  async def _create_stripe_payment(
      db: AsyncSession,
      order: VasOrder,
  ) -> VasPayment:
      """Insert a placeholder pending Stripe payment for ``order``.

      Amount, rate and offset start at zero; the row is added to the
      session and flushed, not committed.

      Returns:
          The flushed ``VasPayment``.
      """
      # The payment window is 30 minutes from creation.
      deadline = datetime.utcnow() + timedelta(minutes=30)
      fields = {
          "order_id": order.id,
          "provider": "stripe",
          "channel": "online_link",
          "base_amount": order.base_amount,
          "base_currency": order.base_currency,
          "amount": 0,
          "currency": "EUR",
          "random_offset": 0,
          "exchange_rate": 0,
          "status": "pending",
          "expire_at": deadline,
      }
      stripe_payment = VasPayment(**fields)
      db.add(stripe_payment)
      await db.flush()
      return stripe_payment
  @staticmethod
  async def list_by_order(
      db: AsyncSession,
      order_id: int,
  ) -> List[VasPayment]:
      """Return every payment recorded for the given order.

      Args:
          db: Async session used for the query.
          order_id: Order whose payments are listed.

      Returns:
          All matching payments, in whatever order the database yields them.
      """
      by_order = select(VasPayment).where(VasPayment.order_id == order_id)
      rows = (await db.execute(by_order)).scalars()
      return rows.all()
  @staticmethod
  async def get_by_id(
      db: AsyncSession,
      id: int,
  ) -> Optional[VasPayment]:
      """Fetch a single payment by primary key.

      Args:
          db: Async session used for the query.
          id: Primary key of the payment.

      Returns:
          The payment, or ``None`` when no row has this id. Callers must
          handle the ``None`` case.
      """
      stmt = select(VasPayment).where(VasPayment.id == id)
      return (await db.execute(stmt)).scalar_one_or_none()