payment_service.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542
  1. # app/services/payment_service.py
  2. import time
  3. import stripe
  4. import random
  5. import uuid
  6. from typing import Dict, List, Optional
  7. from redis.asyncio import Redis
  8. from decimal import Decimal, ROUND_HALF_UP
  9. from datetime import datetime, timedelta
  10. from sqlalchemy.ext.asyncio import AsyncSession
  11. from sqlalchemy import select
  12. from app.utils.search import apply_keyword_search_stmt
  13. from app.utils.pagination import paginate
  14. from app.core.biz_exception import NotFoundError, BizLogicError
  15. from app.models.user import VasUser
  16. from app.models.order import VasOrder
  17. from app.models.vas_task import VasTask
  18. from app.models.payment import VasPayment
  19. from app.models.ticket import VasTicket
  20. from app.models.payment_event import VasPaymentEvent
  21. from app.models.product_routing import VasProductRouting
  22. from app.models.verification_token import VasVerificationToken
  23. from app.models.payment_provider import VasPaymentProvider
  24. from app.models.payment_qr import VasPaymentQR
  25. from app.models.payment_confirmation import VasPaymentConfirmation
  26. from app.schemas.payment import VasPaymentCreate
  27. from app.schemas.payment_confirmation import VasPaymentConfirmationCreate, VasPaymentConfirmationUpdate
  28. from app.services.notification_service import NotificationService
  29. class PaymentService:
  30. # --------------------------------------------------
  31. # 创建支付(统一入口)
  32. # --------------------------------------------------
  33. @staticmethod
  34. async def create_payment(
  35. db: AsyncSession,
  36. payload: VasPaymentCreate,
  37. rate_table: Dict,
  38. redis_client: Redis
  39. ) -> VasPayment:
  40. # ① 锁住订单(防并发)
  41. stmt = (
  42. select(VasOrder)
  43. .where(VasOrder.id == payload.order_id)
  44. .with_for_update()
  45. )
  46. order = (await db.execute(stmt)).scalar_one_or_none()
  47. if not order:
  48. raise NotFoundError("Order not found")
  49. # ② 是否已有 pending payment(幂等)
  50. stmt = select(VasPayment).where(
  51. VasPayment.order_id == order.id,
  52. VasPayment.status == "pending",
  53. )
  54. active_payment = (await db.execute(stmt)).scalar_one_or_none()
  55. if active_payment:
  56. if active_payment.provider == payload.provider:
  57. return active_payment
  58. else:
  59. active_payment.status = "failed"
  60. # ③ 根据 provider 创建
  61. if payload.provider in ("wechat", "alipay"):
  62. payment = await PaymentService.create_offline_payment(
  63. db=db,
  64. order=order,
  65. provider_name=payload.provider,
  66. rate_table=rate_table,
  67. )
  68. await db.commit()
  69. return payment
  70. if payload.provider == "stripe":
  71. payment = await PaymentService.create_stripe_payment(
  72. db=db,
  73. order=order,
  74. rate_table=rate_table,
  75. )
  76. await db.commit()
  77. return payment
  78. raise BizLogicError("Unsupported provider")
  79. @staticmethod
  80. async def confirm_by_user(
  81. db: AsyncSession,
  82. payload: VasPaymentConfirmationCreate,
  83. current_user: VasUser,
  84. redis_client: Redis
  85. ):
  86. """
  87. 用户点击“我已支付”
  88. """
  89. # 1️⃣ 查询是否存在对应 payment 确认记录
  90. result = await db.execute(
  91. select(VasPaymentConfirmation)
  92. .where(VasPaymentConfirmation.payment_id == payload.payment_id)
  93. .where(VasPaymentConfirmation.user_id == current_user.id)
  94. )
  95. record = result.scalar_one_or_none()
  96. if not record:
  97. # 没有则创建一条 pending -> confirmed 记录
  98. record = VasPaymentConfirmation(
  99. payment_id=payload.payment_id,
  100. amount=payload.amount,
  101. currency=payload.currency,
  102. random_offset=payload.random_offset,
  103. user_id=current_user.id,
  104. status="pending",
  105. confirmed_at=payload.confirmed_at
  106. )
  107. db.add(record)
  108. await db.commit()
  109. await db.refresh(record)
  110. # 2️⃣ 推送异步通知给管理员
  111. # await NotificationService.create(
  112. # redis_client=redis_client,
  113. # ntype="payment_user_confirmed",
  114. # user_id=current_user.id,
  115. # channels=["wechat"],
  116. # template_id="payment_user_confirmed",
  117. # payload={
  118. # "payment_id": payload.payment_id,
  119. # "user_id": current_user.id,
  120. # "confirmed_at": record.confirmed_at.isoformat()
  121. # }
  122. # )
  123. return record
  124. @staticmethod
  125. async def confirm_by_admin(
  126. db: AsyncSession,
  127. id: int,
  128. payload: VasPaymentConfirmationUpdate,
  129. current_user: VasUser
  130. ):
  131. """
  132. 管理员确认用户的支付
  133. """
  134. # 1️⃣ 查询对应确认记录
  135. result = await db.execute(
  136. select(VasPaymentConfirmation)
  137. .where(VasPaymentConfirmation.id == id)
  138. )
  139. record = result.scalar_one_or_none()
  140. if not record:
  141. raise NotFoundError("Payment confirmation record not found")
  142. # 3️⃣ 更新管理员确认状态
  143. record.admin_id = current_user.id
  144. record.admin_confirmed_at = datetime.utcnow()
  145. record.status = 'confirmed'
  146. await PaymentService._confirm_payment_action(db, record.payment_id)
  147. await db.commit()
  148. await db.refresh(record)
  149. return record
  150. @staticmethod
  151. async def list_payment_confirmation(
  152. db: AsyncSession,
  153. keyword: Optional[str] = None,
  154. page: int = 1,
  155. size: int = 20,
  156. ):
  157. stmt = select(VasPaymentConfirmation)
  158. stmt = apply_keyword_search_stmt(
  159. stmt=stmt,
  160. model=VasPaymentConfirmation,
  161. keyword=keyword,
  162. fields=["user_id"],
  163. )
  164. stmt = stmt.order_by(VasPaymentConfirmation.id.desc())
  165. return await paginate(db, stmt, page, size)
  166. @staticmethod
  167. async def _create_task_if_not_exists(
  168. db: AsyncSession,
  169. order: VasOrder,
  170. ) -> List[VasTask]:
  171. stmt = select(VasProductRouting).where(
  172. VasProductRouting.product_id == order.product_id,
  173. VasProductRouting.is_active == 1,
  174. )
  175. result = await db.execute(stmt)
  176. routings = result.scalars().all()
  177. if not routings:
  178. return []
  179. created_tasks: List[VasTask] = []
  180. for routing in routings:
  181. exists_stmt = select(VasTask).where(
  182. VasTask.order_id == order.id,
  183. VasTask.routing_key == routing.routing_key,
  184. VasTask.script_version == routing.script_version,
  185. )
  186. exists_result = await db.execute(exists_stmt)
  187. exists = exists_result.scalar_one_or_none()
  188. if exists:
  189. continue
  190. task = VasTask(
  191. order_id=order.id,
  192. routing_key=routing.routing_key,
  193. script_version=routing.script_version,
  194. priority=routing.priority,
  195. status="pending",
  196. user_inputs=order.user_inputs,
  197. config=routing.config,
  198. attempt_count=0,
  199. notify_count=0,
  200. expire_at=datetime.utcnow() + timedelta(days=60),
  201. created_at=datetime.utcnow(),
  202. )
  203. db.add(task)
  204. created_tasks.append(task)
  205. return created_tasks
  206. @staticmethod
  207. async def confirm_payment(
  208. db: AsyncSession,
  209. payment_id: int,
  210. token: str
  211. ):
  212. # 校验验证码
  213. stmt = select(VasVerificationToken).where(
  214. VasVerificationToken.token == token,
  215. VasVerificationToken.used == 0,
  216. )
  217. token_obj = (await db.execute(stmt)).scalar_one_or_none()
  218. if not token_obj:
  219. raise BizLogicError("Token invalid")
  220. if token_obj.expire_at < datetime.utcnow():
  221. raise BizLogicError("Token expired")
  222. payment = await PaymentService._confirm_payment_action(db, payment_id)
  223. token_obj.used = 1
  224. await db.commit()
  225. return payment
  226. @staticmethod
  227. async def _confirm_payment_action(db: AsyncSession, payment_id: int):
  228. # ---------- 查找 payment ----------
  229. pay_stmt = (
  230. select(VasPayment)
  231. .where(
  232. VasPayment.id == payment_id,
  233. VasPayment.status == "pending",
  234. )
  235. .order_by(VasPayment.created_at.desc())
  236. )
  237. pay_result = await db.execute(pay_stmt)
  238. payment = pay_result.scalar_one_or_none()
  239. if not payment:
  240. raise BizLogicError("Payment not found")
  241. event = VasPaymentEvent(
  242. provider=payment.provider,
  243. event_type="payment_received",
  244. title='confirm payment',
  245. content='confirm payment by admin',
  246. parsed_amount=payment.amount,
  247. parsed_currency=payment.currency,
  248. parsed_device='',
  249. status="received",
  250. )
  251. db.add(event)
  252. await db.commit()
  253. await db.refresh(event)
  254. if payment.status in ("succeeded", "late_paid"):
  255. event.status = "duplicate"
  256. event.matched_payment_id = payment.id
  257. event.matched_order_id = payment.order_id
  258. await db.commit()
  259. return None
  260. now = datetime.utcnow()
  261. payment.status = "late_paid" if payment.expire_at and now > payment.expire_at else "succeeded"
  262. payment.provider_payload = {
  263. "title": "confirm by admin",
  264. "received_at": now.isoformat(),
  265. }
  266. order_stmt = select(VasOrder).where(VasOrder.id == payment.order_id)
  267. order_result = await db.execute(order_stmt)
  268. order = order_result.scalar_one_or_none()
  269. if order and order.status != "paid":
  270. order.status = "paid"
  271. await PaymentService._create_task_if_not_exists(db, order)
  272. event.status = "applied"
  273. event.matched_payment_id = payment.id
  274. event.matched_order_id = payment.order_id
  275. await db.commit()
  276. await db.refresh(payment)
  277. return payment
  278. @staticmethod
  279. async def create_offline_payment(
  280. db: AsyncSession,
  281. order: VasOrder,
  282. provider_name: str,
  283. rate_table: Dict,
  284. ) -> VasPayment:
  285. payment = (
  286. await PaymentService._create_wechat_payment(db, order)
  287. if provider_name == "wechat"
  288. else await PaymentService._create_alipay_payment(db, order)
  289. )
  290. stmt = select(VasPaymentProvider).where(
  291. VasPaymentProvider.enabled == 1,
  292. VasPaymentProvider.name == provider_name,
  293. )
  294. provider = (await db.execute(stmt)).scalar_one_or_none()
  295. if not provider:
  296. raise BizLogicError("Payment provider not available")
  297. stmt = select(VasPaymentQR).where(
  298. VasPaymentQR.provider == provider_name,
  299. VasPaymentQR.is_active == 1,
  300. )
  301. qrs = (await db.execute(stmt)).scalars().all()
  302. if not qrs:
  303. raise BizLogicError("No payment QR available")
  304. qr = random.choice(qrs)
  305. payment.qr_id = qr.id
  306. rate_key = f"{order.base_currency}->{provider.currency}".upper()
  307. exchange_rate = Decimal(rate_table[rate_key])
  308. converted = (
  309. Decimal(payment.base_amount) * exchange_rate
  310. ).quantize(Decimal("1"), rounding=ROUND_HALF_UP)
  311. max_discount = min(99, int(converted * Decimal("0.01")))
  312. discount = random.randint(1, max_discount) if max_discount >= 1 else 0
  313. payment.exchange_rate = exchange_rate
  314. payment.amount = int(converted) - discount
  315. payment.currency = provider.currency
  316. payment.random_offset = discount
  317. return payment
  318. @staticmethod
  319. async def create_stripe_payment(
  320. db: AsyncSession,
  321. order: VasOrder,
  322. rate_table: Dict,
  323. ) -> VasPayment:
  324. payment = await PaymentService._create_stripe_payment(db, order)
  325. stmt = select(VasPaymentProvider).where(
  326. VasPaymentProvider.enabled == 1,
  327. VasPaymentProvider.name == "stripe",
  328. )
  329. provider = (await db.execute(stmt)).scalar_one_or_none()
  330. if not provider:
  331. raise BizLogicError("Stripe provider not enabled")
  332. rate_key = f"{order.base_currency}->{provider.currency}".upper()
  333. exchange_rate = Decimal(rate_table[rate_key])
  334. converted = (
  335. Decimal(payment.base_amount) * exchange_rate
  336. ).quantize(Decimal("1"), rounding=ROUND_HALF_UP)
  337. payment.exchange_rate = exchange_rate
  338. payment.amount = int(converted)
  339. payment.currency = provider.currency
  340. payment.random_offset = 0
  341. stripe_session = PaymentService.create_checkout_session(
  342. order=order,
  343. payment=payment,
  344. success_url="https://visafly.top/dashboard",
  345. cancel_url="https://visafly.top/dashboard",
  346. )
  347. payment.payment_intent_id = stripe_session.id
  348. payment.payment_url = stripe_session.url
  349. return payment
  350. @staticmethod
  351. def create_checkout_session(
  352. order: VasOrder,
  353. payment: VasPayment,
  354. success_url: str,
  355. cancel_url: str,
  356. ):
  357. expires_at = int(time.time()) + 30 * 60
  358. return stripe.checkout.Session.create(
  359. mode="payment",
  360. payment_method_types=["card"],
  361. line_items=[
  362. {
  363. "price_data": {
  364. "currency": payment.currency.lower(),
  365. "product_data": {
  366. "name": f"Visa Service Order {order.id}",
  367. },
  368. "unit_amount": payment.amount,
  369. },
  370. "quantity": 1,
  371. }
  372. ],
  373. metadata={
  374. "order_id": order.id,
  375. "payment_id": payment.id,
  376. "user_id": order.user_id,
  377. },
  378. success_url=success_url,
  379. cancel_url=cancel_url,
  380. expires_at=expires_at,
  381. )
  382. @staticmethod
  383. async def _create_wechat_payment(
  384. db: AsyncSession,
  385. order: VasOrder,
  386. ) -> VasPayment:
  387. payment = VasPayment(
  388. order_id=order.id,
  389. provider="wechat",
  390. channel="qr_static",
  391. base_amount=order.base_amount,
  392. base_currency=order.base_currency,
  393. amount=0,
  394. currency="CNY",
  395. random_offset=0,
  396. exchange_rate=0,
  397. status="pending",
  398. expire_at=datetime.utcnow() + timedelta(minutes=30),
  399. )
  400. db.add(payment)
  401. await db.flush()
  402. return payment
  403. @staticmethod
  404. async def _create_alipay_payment(
  405. db: AsyncSession,
  406. order: VasOrder,
  407. ) -> VasPayment:
  408. payment = VasPayment(
  409. order_id=order.id,
  410. provider="alipay",
  411. channel="qr_static",
  412. base_amount=order.base_amount,
  413. base_currency=order.base_currency,
  414. amount=0,
  415. currency="CNY",
  416. random_offset=0,
  417. exchange_rate=0,
  418. status="pending",
  419. expire_at=datetime.utcnow() + timedelta(minutes=30),
  420. )
  421. db.add(payment)
  422. await db.flush()
  423. return payment
  424. @staticmethod
  425. async def _create_stripe_payment(
  426. db: AsyncSession,
  427. order: VasOrder,
  428. ) -> VasPayment:
  429. payment = VasPayment(
  430. order_id=order.id,
  431. provider="stripe",
  432. channel="online_link",
  433. base_amount=order.base_amount,
  434. base_currency=order.base_currency,
  435. amount=0,
  436. currency="EUR",
  437. random_offset=0,
  438. exchange_rate=0,
  439. status="pending",
  440. expire_at=datetime.utcnow() + timedelta(minutes=30),
  441. )
  442. db.add(payment)
  443. await db.flush()
  444. return payment
  445. @staticmethod
  446. async def list_by_order(
  447. db: AsyncSession,
  448. order_id: int,
  449. ) -> List[VasPayment]:
  450. stmt = select(VasPayment).where(
  451. VasPayment.order_id == order_id
  452. )
  453. result = await db.execute(stmt)
  454. return result.scalars().all()
  455. @staticmethod
  456. async def get_by_id(
  457. db: AsyncSession,
  458. id: int,
  459. ) -> VasPayment:
  460. stmt = select(VasPayment).where(VasPayment.id == id)
  461. return (await db.execute(stmt)).scalar_one_or_none()