payment_service.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527
  1. # app/services/payment_service.py
  2. import time
  3. import stripe
  4. import random
  5. import uuid
  6. from typing import Dict, List, Optional
  7. from redis.asyncio import Redis
  8. from decimal import Decimal, ROUND_HALF_UP
  9. from datetime import datetime, timedelta
  10. from sqlalchemy.ext.asyncio import AsyncSession
  11. from sqlalchemy import select
  12. from app.utils.search import apply_keyword_search_stmt
  13. from app.utils.pagination import paginate
  14. from app.core.biz_exception import NotFoundError, BizLogicError
  15. from app.models.user import VasUser
  16. from app.models.order import VasOrder
  17. from app.models.payment import VasPayment
  18. from app.models.ticket import VasTicket
  19. from app.models.payment_event import VasPaymentEvent
  20. from app.models.product_routing import VasProductRouting
  21. from app.models.verification_token import VasVerificationToken
  22. from app.models.payment_provider import VasPaymentProvider
  23. from app.models.payment_qr import VasPaymentQR
  24. from app.models.payment_confirmation import VasPaymentConfirmation
  25. from app.schemas.payment import VasPaymentCreate, AdminUpdateStatusPayload
  26. from app.schemas.payment_confirmation import VasPaymentConfirmationCreate, VasPaymentConfirmationUpdate
  27. from app.services.notification_service import NotificationService
  28. from app.services.webhook_service import WebhookService
  29. class PaymentService:
  30. # --------------------------------------------------
  31. # 创建支付(统一入口)
  32. # --------------------------------------------------
  33. @staticmethod
  34. async def create_payment(
  35. db: AsyncSession,
  36. payload: VasPaymentCreate,
  37. rate_table: Dict,
  38. redis_client: Redis
  39. ) -> VasPayment:
  40. # ① 锁住订单(防并发)
  41. stmt = (
  42. select(VasOrder)
  43. .where(VasOrder.id == payload.order_id)
  44. .with_for_update()
  45. )
  46. order = (await db.execute(stmt)).scalar_one_or_none()
  47. if not order:
  48. raise NotFoundError("Order not found")
  49. # ② 是否已有 pending payment(幂等)
  50. stmt = select(VasPayment).where(
  51. VasPayment.order_id == order.id,
  52. VasPayment.status == "pending",
  53. )
  54. active_payment = (await db.execute(stmt)).scalar_one_or_none()
  55. if active_payment:
  56. if active_payment.provider == payload.provider:
  57. return active_payment
  58. else:
  59. active_payment.status = "failed"
  60. # ③ 根据 provider 创建
  61. if payload.provider in ("wechat", "alipay"):
  62. payment = await PaymentService.create_offline_payment(
  63. db=db,
  64. order=order,
  65. provider_name=payload.provider,
  66. rate_table=rate_table,
  67. )
  68. await db.commit()
  69. return payment
  70. if payload.provider == "stripe":
  71. payment = await PaymentService.create_stripe_payment(
  72. db=db,
  73. order=order,
  74. rate_table=rate_table,
  75. )
  76. await db.commit()
  77. return payment
  78. raise BizLogicError("Unsupported provider")
  79. @staticmethod
  80. async def confirm_by_user(
  81. db: AsyncSession,
  82. payload: VasPaymentConfirmationCreate,
  83. current_user: VasUser,
  84. redis_client: Redis
  85. ):
  86. """
  87. 用户点击“我已支付”
  88. """
  89. # 1️⃣ 查询是否存在对应 payment 确认记录
  90. result = await db.execute(
  91. select(VasPaymentConfirmation)
  92. .where(VasPaymentConfirmation.payment_id == payload.payment_id)
  93. .where(VasPaymentConfirmation.user_id == current_user.id)
  94. )
  95. record = result.scalar_one_or_none()
  96. if not record:
  97. # 没有则创建一条 pending -> confirmed 记录
  98. record = VasPaymentConfirmation(
  99. payment_id=payload.payment_id,
  100. amount=payload.amount,
  101. currency=payload.currency,
  102. random_offset=payload.random_offset,
  103. user_id=current_user.id,
  104. status="pending",
  105. confirmed_at=payload.confirmed_at
  106. )
  107. db.add(record)
  108. await db.commit()
  109. await db.refresh(record)
  110. stmt = select(VasPayment).where(
  111. VasPayment.id == payload.payment_id
  112. )
  113. payment = (await db.execute(stmt)).scalar_one_or_none()
  114. formatted_time = payload.confirmed_at.strftime('%Y-%m-%d %H:%M') + " (UTC)"
  115. # 2️⃣ 推送异步通知给管理员
  116. await NotificationService.post_wechat(
  117. redis_client=redis_client,
  118. template_id="payment_user_confirmed",
  119. payload={
  120. "order_id": payment.order_id,
  121. "payment_id": payload.payment_id,
  122. "user_email": current_user.email,
  123. "amount": payload.amount,
  124. "currency": payload.currency,
  125. "token": "",
  126. "confirmed_at": formatted_time,
  127. "provider": payment.provider
  128. }
  129. )
  130. return record
  131. @staticmethod
  132. async def confirm_by_admin(
  133. db: AsyncSession,
  134. id: int,
  135. payload: VasPaymentConfirmationUpdate,
  136. current_user: VasUser
  137. ):
  138. """
  139. 管理员确认用户的支付
  140. """
  141. # 1️⃣ 查询对应确认记录
  142. result = await db.execute(
  143. select(VasPaymentConfirmation)
  144. .where(VasPaymentConfirmation.id == id)
  145. )
  146. record = result.scalar_one_or_none()
  147. if not record:
  148. raise NotFoundError("Payment confirmation record not found")
  149. # 3️⃣ 更新管理员确认状态
  150. record.admin_id = current_user.id
  151. record.admin_confirmed_at = datetime.utcnow()
  152. record.status = 'confirmed'
  153. await PaymentService._confirm_payment_action(db, record.payment_id, 'confirmed by admin')
  154. await db.commit()
  155. await db.refresh(record)
  156. return record
  157. @staticmethod
  158. async def admin_update_status(
  159. db: AsyncSession,
  160. payment_id: int,
  161. payload: AdminUpdateStatusPayload
  162. ):
  163. """
  164. 管理员确认用户的支付
  165. """
  166. if payload.status == "succeeded":
  167. payment = await PaymentService._confirm_payment_action(db, payment_id, payload.remark)
  168. else:
  169. pay_stmt = (
  170. select(VasPayment)
  171. .where(VasPayment.id == payment_id)
  172. .order_by(VasPayment.created_at.desc())
  173. )
  174. pay_result = await db.execute(pay_stmt)
  175. payment = pay_result.scalar_one_or_none()
  176. if not payment:
  177. raise BizLogicError("Payment not found")
  178. payment.status = 'failed'
  179. await db.commit()
  180. await db.refresh(payment)
  181. await db.refresh(payment)
  182. return payment
  183. @staticmethod
  184. async def list_payment_confirmation(
  185. db: AsyncSession,
  186. keyword: Optional[str] = None,
  187. page: int = 1,
  188. size: int = 20,
  189. ):
  190. stmt = select(VasPaymentConfirmation)
  191. stmt = apply_keyword_search_stmt(
  192. stmt=stmt,
  193. model=VasPaymentConfirmation,
  194. keyword=keyword,
  195. fields=["user_id"],
  196. )
  197. stmt = stmt.order_by(VasPaymentConfirmation.id.desc())
  198. return await paginate(db, stmt, page, size)
  199. @staticmethod
  200. async def confirm_payment(
  201. db: AsyncSession,
  202. payment_id: int,
  203. token: str
  204. ):
  205. # 校验验证码
  206. stmt = select(VasVerificationToken).where(
  207. VasVerificationToken.token == token,
  208. VasVerificationToken.used == 0,
  209. )
  210. token_obj = (await db.execute(stmt)).scalar_one_or_none()
  211. if not token_obj:
  212. raise BizLogicError("Token invalid")
  213. if token_obj.expire_at < datetime.utcnow():
  214. raise BizLogicError("Token expired")
  215. payment = await PaymentService._confirm_payment_action(db, payment_id, 'confirmed by admin')
  216. token_obj.used = 1
  217. await db.commit()
  218. return payment
  219. @staticmethod
  220. async def _confirm_payment_action(db: AsyncSession, payment_id: int, remark:str):
  221. # ---------- 查找 payment ----------
  222. pay_stmt = (
  223. select(VasPayment)
  224. .where(VasPayment.id == payment_id)
  225. .order_by(VasPayment.created_at.desc())
  226. )
  227. pay_result = await db.execute(pay_stmt)
  228. payment = pay_result.scalar_one_or_none()
  229. if not payment:
  230. raise BizLogicError("Payment not found")
  231. event = VasPaymentEvent(
  232. provider=payment.provider,
  233. event_type="payment_received",
  234. title='confirm payment',
  235. content='confirm payment by admin',
  236. parsed_amount=payment.amount,
  237. parsed_currency=payment.currency,
  238. parsed_device='',
  239. status="received",
  240. )
  241. db.add(event)
  242. await db.commit()
  243. await db.refresh(event)
  244. if payment.status in ("succeeded", "late_paid"):
  245. event.status = "duplicate"
  246. event.matched_payment_id = payment.id
  247. event.matched_order_id = payment.order_id
  248. await db.commit()
  249. raise BizLogicError("Payment has been confirmed")
  250. now = datetime.utcnow()
  251. payment.status = "late_paid" if payment.expire_at and now > payment.expire_at else "succeeded"
  252. payment.provider_payload = {
  253. "title": remark,
  254. "received_at": now.isoformat(),
  255. }
  256. order_stmt = select(VasOrder).where(VasOrder.id == payment.order_id)
  257. order_result = await db.execute(order_stmt)
  258. order = order_result.scalar_one_or_none()
  259. if order and order.status != "paid":
  260. order.status = "paid"
  261. await WebhookService._create_task_if_not_exists(db, order)
  262. event.status = "applied"
  263. event.matched_payment_id = payment.id
  264. event.matched_order_id = payment.order_id
  265. await db.commit()
  266. await db.refresh(payment)
  267. return payment
  268. @staticmethod
  269. async def create_offline_payment(
  270. db: AsyncSession,
  271. order: VasOrder,
  272. provider_name: str,
  273. rate_table: Dict,
  274. ) -> VasPayment:
  275. payment = (
  276. await PaymentService._create_wechat_payment(db, order)
  277. if provider_name == "wechat"
  278. else await PaymentService._create_alipay_payment(db, order)
  279. )
  280. stmt = select(VasPaymentProvider).where(
  281. VasPaymentProvider.enabled == 1,
  282. VasPaymentProvider.name == provider_name,
  283. )
  284. provider = (await db.execute(stmt)).scalar_one_or_none()
  285. if not provider:
  286. raise BizLogicError("Payment provider not available")
  287. stmt = select(VasPaymentQR).where(
  288. VasPaymentQR.provider == provider_name,
  289. VasPaymentQR.is_active == 1,
  290. )
  291. qrs = (await db.execute(stmt)).scalars().all()
  292. if not qrs:
  293. raise BizLogicError("No payment QR available")
  294. qr = random.choice(qrs)
  295. payment.qr_id = qr.id
  296. rate_key = f"{order.base_currency}->{provider.currency}".upper()
  297. exchange_rate = Decimal(rate_table[rate_key])
  298. converted = (
  299. Decimal(payment.base_amount) * exchange_rate
  300. ).quantize(Decimal("1"), rounding=ROUND_HALF_UP)
  301. max_discount = min(99, int(converted * Decimal("0.01")))
  302. discount = random.randint(1, max_discount) if max_discount >= 1 else 0
  303. payment.exchange_rate = exchange_rate
  304. payment.amount = int(converted) - discount
  305. payment.currency = provider.currency
  306. payment.random_offset = discount
  307. return payment
  308. @staticmethod
  309. async def create_stripe_payment(
  310. db: AsyncSession,
  311. order: VasOrder,
  312. rate_table: Dict,
  313. ) -> VasPayment:
  314. payment = await PaymentService._create_stripe_payment(db, order)
  315. stmt = select(VasPaymentProvider).where(
  316. VasPaymentProvider.enabled == 1,
  317. VasPaymentProvider.name == "stripe",
  318. )
  319. provider = (await db.execute(stmt)).scalar_one_or_none()
  320. if not provider:
  321. raise BizLogicError("Stripe provider not enabled")
  322. rate_key = f"{order.base_currency}->{provider.currency}".upper()
  323. exchange_rate = Decimal(rate_table[rate_key])
  324. converted = (
  325. Decimal(payment.base_amount) * exchange_rate
  326. ).quantize(Decimal("1"), rounding=ROUND_HALF_UP)
  327. payment.exchange_rate = exchange_rate
  328. payment.amount = int(converted)
  329. payment.currency = provider.currency
  330. payment.random_offset = 0
  331. stripe_session = PaymentService.create_checkout_session(
  332. order=order,
  333. payment=payment,
  334. success_url="https://visafly.top/dashboard",
  335. cancel_url="https://visafly.top/dashboard",
  336. )
  337. payment.payment_intent_id = stripe_session.id
  338. payment.payment_url = stripe_session.url
  339. return payment
  340. @staticmethod
  341. def create_checkout_session(
  342. order: VasOrder,
  343. payment: VasPayment,
  344. success_url: str,
  345. cancel_url: str,
  346. ):
  347. expires_at = int(time.time()) + 30 * 60
  348. return stripe.checkout.Session.create(
  349. mode="payment",
  350. payment_method_types=["card"],
  351. line_items=[
  352. {
  353. "price_data": {
  354. "currency": payment.currency.lower(),
  355. "product_data": {
  356. "name": f"Visa Service Order {order.id}",
  357. },
  358. "unit_amount": payment.amount,
  359. },
  360. "quantity": 1,
  361. }
  362. ],
  363. metadata={
  364. "order_id": order.id,
  365. "payment_id": payment.id,
  366. "user_id": order.user_id,
  367. },
  368. success_url=success_url,
  369. cancel_url=cancel_url,
  370. expires_at=expires_at,
  371. )
  372. @staticmethod
  373. async def _create_wechat_payment(
  374. db: AsyncSession,
  375. order: VasOrder,
  376. ) -> VasPayment:
  377. payment = VasPayment(
  378. order_id=order.id,
  379. provider="wechat",
  380. channel="qr_static",
  381. base_amount=order.base_amount,
  382. base_currency=order.base_currency,
  383. amount=0,
  384. currency="CNY",
  385. random_offset=0,
  386. exchange_rate=0,
  387. status="pending",
  388. expire_at=datetime.utcnow() + timedelta(minutes=30),
  389. )
  390. db.add(payment)
  391. await db.flush()
  392. return payment
  393. @staticmethod
  394. async def _create_alipay_payment(
  395. db: AsyncSession,
  396. order: VasOrder,
  397. ) -> VasPayment:
  398. payment = VasPayment(
  399. order_id=order.id,
  400. provider="alipay",
  401. channel="qr_static",
  402. base_amount=order.base_amount,
  403. base_currency=order.base_currency,
  404. amount=0,
  405. currency="CNY",
  406. random_offset=0,
  407. exchange_rate=0,
  408. status="pending",
  409. expire_at=datetime.utcnow() + timedelta(minutes=30),
  410. )
  411. db.add(payment)
  412. await db.flush()
  413. return payment
  414. @staticmethod
  415. async def _create_stripe_payment(
  416. db: AsyncSession,
  417. order: VasOrder,
  418. ) -> VasPayment:
  419. payment = VasPayment(
  420. order_id=order.id,
  421. provider="stripe",
  422. channel="online_link",
  423. base_amount=order.base_amount,
  424. base_currency=order.base_currency,
  425. amount=0,
  426. currency="EUR",
  427. random_offset=0,
  428. exchange_rate=0,
  429. status="pending",
  430. expire_at=datetime.utcnow() + timedelta(minutes=30),
  431. )
  432. db.add(payment)
  433. await db.flush()
  434. return payment
  435. @staticmethod
  436. async def list_by_order(
  437. db: AsyncSession,
  438. order_id: int,
  439. ) -> List[VasPayment]:
  440. stmt = select(VasPayment).where(
  441. VasPayment.order_id == order_id
  442. )
  443. result = await db.execute(stmt)
  444. return result.scalars().all()
  445. @staticmethod
  446. async def get_by_id(
  447. db: AsyncSession,
  448. id: int,
  449. ) -> VasPayment:
  450. stmt = select(VasPayment).where(VasPayment.id == id)
  451. return (await db.execute(stmt)).scalar_one_or_none()