BEGIN_REPO: apexai-submanager (v2.1.10) — FULL CHAT DUMP (NO DIFFS, NO PARTIALS)
REPO_TREE:
apexai-submanager/
├── LICENSE
├── README.md
├── API.md
├── SECURITY.md
├── CONTRIBUTING.md
├── pyproject.toml
├── .env.example
├── .gitignore
├── ruff.toml
├── mypy.ini
├── pytest.ini
├── Dockerfile
├── docker-compose.yml
├── admin_ui/
│   ├── index.html
│   ├── app.js
│   └── styles.css
├── migrations/
│   └── 0001_init.sql
├── src/
│   └── app/
│       ├── __init__.py
│       ├── main.py
│       ├── settings.py
│       ├── errors.py
│       ├── schemas.py
│       ├── logging_config.py
│       ├── middleware.py
│       ├── db.py
│       ├── models.py
│       ├── infra/
│       │   ├── __init__.py
│       │   ├── redis.py
│       │   ├── locks.py
│       │   ├── rate_limit.py
│       │   └── resilience.py
│       ├── security/
│       │   ├── __init__.py
│       │   ├── compare.py
│       │   ├── admin.py
│       │   ├── webhook_sig.py
│       │   └── webhook.py
│       ├── services/
│       │   ├── __init__.py
│       │   ├── plans.py
│       │   └── subscriptions.py
│       └── routes/
│           ├── __init__.py
│           ├── health.py
│           ├── webhooks.py
│           └── admin.py
└── tests/
    ├── conftest.py
    ├── unit/
    │   ├── test_compare.py
    │   ├── test_settings_validation.py
    │   ├── test_webhook_sig.py
    │   ├── test_rate_limit_token_bucket.py
    │   ├── test_circuit_breaker.py
    │   ├── test_admin_sessions.py
    │   └── test_subscription_atomicity.py
    ├── integration/
    │   ├── test_health_endpoints.py
    │   ├── test_webhook_payment_confirm_e2e.py
    │   └── test_admin_endpoints_e2e.py
    └── security/
        ├── test_replay_protection.py
        ├── test_lockout_mechanism.py
        └── test_rate_limit_effectiveness.py
BEGIN_FILE: LICENSE
Copyright (c) 2026 ApexAI Official
PROPRIETARY SOFTWARE LICENSE (LicenseRef-Proprietary)
Permission is hereby granted, free of charge, to any person or entity obtaining a copy
of this software and associated documentation files (the "Software") to use the
Software solely for internal business purposes in connection with ApexAI services.
RESTRICTIONS:
1. You may not copy, modify, merge, publish, distribute, sublicense, sell, rent, lease,
lend, or otherwise transfer the Software or any portion thereof to any third party.
2. You may not host the Software or provide the Software as a service to third parties.
3. You may not reverse engineer, decompile, disassemble, or attempt to derive source
code or trade secrets from the Software, except to the extent such restriction is
prohibited by applicable law.
4. You must maintain all copyright and proprietary notices.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE
OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
END_FILE
BEGIN_FILE: README.md
# apexai-submanager (v2.1.10)
Webhook-driven subscription manager with:
- HMAC-signed payment confirmation webhooks + replay protection
- Idempotent payment/subscription granting
- Admin plane with server-side sessions (opaque cookie SID mapped in Redis), CSRF (double-submit), lockout + rate limiting
- Redis token-bucket rate limiting (atomic Lua)
- Circuit breaker + retry wrappers
- Health probes (liveness/readiness)
- Docker-first local dev
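The retry and circuit-breaker wrappers live in `app/infra/resilience.py`, which is not included in this dump. As a rough illustration of the circuit-breaker idea only, here is a sketch with a hypothetical `CircuitBreaker` class, assumed parameters (`failure_threshold`, `reset_timeout_sec`), and an injectable clock so it can be tested without sleeping. It is synchronous for brevity; it is not the project's implementation.

```python
import time
from collections.abc import Callable
from typing import TypeVar

T = TypeVar("T")


class CircuitOpenError(RuntimeError):
    """Raised instead of calling the wrapped function while the breaker is open."""


class CircuitBreaker:
    """Hypothetical closed -> open -> half-open breaker (illustrative sketch)."""

    def __init__(self, failure_threshold: int = 5, reset_timeout_sec: float = 30.0,
                 clock: Callable[[], float] = time.monotonic) -> None:
        self.failure_threshold = failure_threshold
        self.reset_timeout_sec = reset_timeout_sec
        self._clock = clock
        self._failures = 0
        self._opened_at: float | None = None

    @property
    def state(self) -> str:
        if self._opened_at is None:
            return "closed"
        if self._clock() - self._opened_at >= self.reset_timeout_sec:
            return "half_open"  # cool-down elapsed: the next call is let through as a trial
        return "open"

    def call(self, fn: Callable[[], T]) -> T:
        if self.state == "open":
            raise CircuitOpenError("circuit open; failing fast")
        try:
            result = fn()
        except Exception:
            self._failures += 1
            # Too many failures while closed, or a failed trial while half-open, (re)opens the breaker.
            if self._opened_at is not None or self._failures >= self.failure_threshold:
                self._opened_at = self._clock()
            raise
        # Any success closes the breaker and resets the failure count.
        self._failures = 0
        self._opened_at = None
        return result
```

In the async service the same state machine would wrap `await fn()` and typically be combined with the retry wrapper, so retries stop once the breaker opens.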
## Quickstart (local)
1. Copy env:
```bash
cp .env.example .env
```
2. Start dependencies and services:
```bash
docker compose up -d --build
```
3. Apply migrations:
```bash
docker compose exec app python -m app.db migrate
```
4. Open the admin UI: http://localhost:8088
## Design notes (critical fixes vs. prior versions)
### ✅ Admin session cookie no longer stores ADMIN_TOKEN
We set `__Host-admin_session` to a random session id (SID). Authentication state is stored server-side in Redis under `admin:sess:{sid}` with a TTL.
### ✅ Database retry wrapper
Transient DB failures (`OperationalError` / connection issues) are retried with bounded exponential backoff and jitter.
### ✅ Subscription grant atomicity
Payment and subscription creation occur inside a single transaction scope. Any failure rolls back both (no orphan payments).
### ✅ Webhook provider validation + webhook rate limiting
The `{provider}` path segment is validated against an allowlist. Webhooks are rate limited per provider and per (hashed) source IP using a token bucket.
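The limiter itself is an atomic Lua script in `app/infra/rate_limit.py`, which is truncated in this dump. The pure-Python sketch below shows only the token-bucket arithmetic such a script typically performs (refill at `rate` tokens per second up to `burst`, spend one token per request), plus one possible per-provider key built from a hashed source IP. The key layout, the `TokenBucket` class, and reading `WEBHOOK_RL_RATE` as tokens per second are assumptions for illustration.

```python
import hashlib
import time
from collections.abc import Callable


def webhook_rl_key(prefix: str, provider: str, source_ip: str) -> str:
    """Hypothetical key layout: raw IPs never appear in Redis keys, only a truncated SHA-256 digest."""
    ip_hash = hashlib.sha256(source_ip.encode("utf-8")).hexdigest()[:16]
    return f"{prefix}rl:webhook:{provider}:{ip_hash}"


class TokenBucket:
    """Single-process token bucket; a Redis/Lua version would keep (tokens, ts) per key."""

    def __init__(self, rate: float, burst: int, clock: Callable[[], float] = time.monotonic) -> None:
        self.rate = rate    # tokens added per second
        self.burst = burst  # bucket capacity, i.e. the largest allowed burst
        self._clock = clock
        self._tokens = float(burst)
        self._ts = clock()

    def allow(self) -> bool:
        now = self._clock()
        # Refill in proportion to elapsed time, never above capacity.
        self._tokens = min(self.burst, self._tokens + (now - self._ts) * self.rate)
        self._ts = now
        if self._tokens >= 1.0:
            self._tokens -= 1.0
            return True
        return False
```

The read-refill-spend sequence is why the production version needs Lua: done as separate Redis commands, two concurrent requests could both read the same token count and both be admitted. A plain hash of an IPv4 address is also easy to reverse by enumeration; a keyed hash (HMAC) would hide it better.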
## Tests
```bash
docker compose exec app pytest -q
```
## License
Proprietary (see LICENSE).
END_FILE
BEGIN_FILE: API.md
# API (v1)
All business routes are under `/v1`.
## Webhooks
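### POST `/v1/webhooks/{provider}/payment_confirm`
Headers:
- `X-Apex-Timestamp`: Unix time in seconds
- `X-Apex-Nonce`: random string, single use (replayed nonces are rejected)
- `X-Apex-Signature`: hex-encoded HMAC-SHA256 of the signature message

Signature message format: `{ts}.{nonce}.{raw_body_bytes}`, i.e. the timestamp, a `.`, the nonce, a `.`, then the request body bytes exactly as sent (not re-serialized JSON).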
Body (`schema_version` defaults to 1):
```json
{
  "schema_version": 1,
  "tg_user_id": 123,
  "provider_payment_id": "pi_abc",
  "amount": 999,
  "currency": "USD",
  "plan_code": "premium_monthly"
}
```
Response:
```json
{
  "ok": true,
  "result": {
    "subscription_id": 1,
    "payment_id": 1,
    "status": "active"
  },
  "request_id": "..."
}
```
## Admin
### POST `/v1/admin/login`
Input:
```json
{ "token": "ADMIN_TOKEN" }
```
Sets cookies:
- `__Host-admin_session` (opaque session id, HttpOnly, Secure in prod). Browsers only store `__Host-` cookies that carry the `Secure` attribute (with `Path=/` and no `Domain`), so this cookie is dropped whenever it is issued without `Secure`.
- `csrf_token` (double-submit CSRF cookie)

Returns:
```json
{
  "ok": true,
  "result": { "csrf_token": "..." },
  "request_id": "..."
}
```
### GET `/v1/admin/dashboard_stats`
Requires a valid admin session cookie.
### POST `/v1/admin/refund`
Requires:
- admin session cookie
- `X-CSRF-Token` header matching the `csrf_token` cookie

Curl example:
```bash
curl -i \
  -H "Content-Type: application/json" \
  -H "X-CSRF-Token: $CSRF" \
  --cookie "__Host-admin_session=$SID; csrf_token=$CSRF" \
  -d '{"payment_id": 12, "reason": "customer_request"}' \
  http://localhost:8080/v1/admin/refund
```
END_FILE
BEGIN_FILE: SECURITY.md
# Security
## Webhook security
- HMAC-SHA256 signature with timestamp skew tolerance
- Replay protection via Redis nonce SETNX with TTL
- Provider allowlist validation
- Per-provider + per-source rate limiting (token bucket, atomic Lua)
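Replay protection with `SET key value NX EX ttl` works because only the first writer of a nonce key succeeds; a replay inside the TTL window finds the key already present. The sketch below emulates that single atomic step with a dict and an injectable clock so it runs without Redis; the `NonceStore` class and the `nonce:{provider}:{nonce}` key shape are hypothetical. Against a real server, redis-py's `set(name, value, nx=True, ex=ttl)` returns a truthy value only when the key was newly set.

```python
import time
from collections.abc import Callable


class NonceStore:
    """In-memory emulation of Redis `SET key 1 NX EX ttl` for nonces (illustrative only)."""

    def __init__(self, ttl_sec: int, clock: Callable[[], float] = time.monotonic) -> None:
        self._ttl = ttl_sec
        self._clock = clock
        self._expiry: dict[str, float] = {}

    def first_use(self, provider: str, nonce: str) -> bool:
        """Return True the first time a nonce is seen within the TTL, False for a replay."""
        key = f"nonce:{provider}:{nonce}"
        now = self._clock()
        expires_at = self._expiry.get(key)
        if expires_at is not None and now < expires_at:
            return False  # key still present: replay
        self._expiry[key] = now + self._ttl  # NX succeeded: claim the nonce with a TTL
        return True
```

The nonce TTL should outlast the whole timestamp acceptance window, which spans `2 * WEBHOOK_TS_SKEW_SEC` (600 s with the defaults); `WEBHOOK_NONCE_TTL_SEC=900` satisfies that, so a replayed request is rejected either by the nonce check or by the timestamp check.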
## Admin security
- Server-side sessions: cookie holds opaque SID only (never ADMIN_TOKEN)
- Session TTL + explicit invalidation on logout
- Rate limiting + progressive delay + lockout after repeated failures
- CSRF double-submit cookie required for mutations
- Constant-time token comparisons
- Security headers middleware (nosniff, frame deny, CSP basics)
## Operational guidance
- Rotate webhook secrets with comma-separated list; keep old+new during rollout
- Use TLS always; cookie is `Secure` in production by default
- Use separate Redis DB or `REDIS_KEY_PREFIX` to avoid collisions
END_FILE
BEGIN_FILE: CONTRIBUTING.md
# Contributing
This is a proprietary internal repository.
## Requirements
- Python 3.12 (`requires-python = "==3.12.*"`)
- PostgreSQL 15+
- Redis 7+
## Setup
```bash
cp .env.example .env
docker compose up -d --build
docker compose exec app python -m app.db migrate
```
## Tests
```bash
docker compose exec app pytest -q
```
## Lint / types
```bash
docker compose exec app ruff check .
docker compose exec app mypy .
```
## Style
- ruff for lint/format rules
- mypy in strict mode (`strict = true`)
- Keep error codes snake_case
- Never store secrets in client-side cookies/storage
END_FILE
BEGIN_FILE: pyproject.toml
[project]
name = "apexai-submanager"
version = "2.1.10"
description = "ApexAI SubManager - webhook subscription management service"
readme = "README.md"
requires-python = "==3.12.*"
license = { text = "LicenseRef-Proprietary" }
authors = [{ name = "ApexAI Official" }]
dependencies = [
"fastapi==0.115.6",
"uvicorn==0.34.0",
"pydantic==2.10.5",
"pydantic-settings==2.7.1",
"sqlalchemy==2.0.36",
"asyncpg==0.30.0",
"redis==5.2.1",
"tenacity==9.0.0",
"prometheus-client==0.21.1",
"python-multipart==0.0.20",
]
[project.optional-dependencies]
dev = [
"pytest==8.3.4",
"pytest-asyncio==0.25.2",
"httpx==0.28.1",
"coverage==7.6.10",
"fakeredis==2.26.2",
"ruff==0.8.6",
"mypy==1.14.1",
]
[tool.pytest.ini_options]
addopts = "-q --disable-warnings --maxfail=1 --cov=app --cov-report=term-missing --cov-fail-under=80"
testpaths = ["tests"]
asyncio_mode = "auto"
[tool.coverage.run]
branch = true
source = ["src/app"]
[tool.coverage.report]
show_missing = true
skip_empty = true
[tool.mypy]
python_version = "3.12"
mypy_path = "src"
strict = true
warn_unused_ignores = true
warn_redundant_casts = true
no_implicit_optional = true
[tool.ruff]
line-length = 110
target-version = "py312"
src = ["src", "tests"]
[tool.ruff.lint]
select = ["E", "F", "I", "B", "UP", "RUF"]
ignore = ["E501"]
END_FILE
BEGIN_FILE: .env.example
# Core
APP_ENV=local
PORT=8080
LOG_LEVEL=INFO
# Database
DATABASE_URL=postgresql+asyncpg://postgres:postgres@db:5432/postgres
DB_OP_TIMEOUT_SEC=5
# Redis
REDIS_URL=redis://redis:6379/0
REDIS_KEY_PREFIX=apexai:submanager:
# Secrets (NO DEFAULTS; MUST SET). Settings rejects values shorter than 32 chars; 40+ recommended.
ADMIN_TOKEN=REPLACE_WITH_40+_CHARS_MIN
ADMIN_TOKEN_FINGERPRINT_SECRET=REPLACE_WITH_40+_CHARS_MIN
# Comma-separated list supported for rotation (each entry >= 32 chars)
WEBHOOK_SIGNATURE_SECRETS=REPLACE_WITH_40+_CHARS_MIN
# Admin cookies
COOKIE_SECURE=true
COOKIE_SAMESITE=lax
ADMIN_SESSION_TTL_SEC=3600
# Admin auth defenses
ADMIN_MAX_FAILS=5
ADMIN_LOCKOUT_SEC=900
ADMIN_RL_RATE=5
ADMIN_RL_BURST=10
ADMIN_PROGRESSIVE_DELAY_BASE_MS=150
# Webhooks
ALLOWED_PROVIDERS=stripe,paypal,paddle
WEBHOOK_TS_SKEW_SEC=300
WEBHOOK_NONCE_TTL_SEC=900
WEBHOOK_RL_RATE=30
WEBHOOK_RL_BURST=60
DB_CONCURRENCY=20
DB_SEM_TIMEOUT_SEC=2.0
# Body size
MAX_BODY_BYTES=262144
END_FILE
BEGIN_FILE: .gitignore
__pycache__/
*.pyc
*.pyo
*.pyd
.venv/
dist/
build/
.coverage
htmlcov/
.pytest_cache/
.env
END_FILE
BEGIN_FILE: ruff.toml
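# ruff.toml takes precedence over [tool.ruff] in pyproject.toml, so the full config lives here.
line-length = 110
target-version = "py312"
src = ["src", "tests"]

[lint]
select = ["E", "F", "I", "B", "UP", "RUF"]
ignore = ["E501"]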
END_FILE
BEGIN_FILE: mypy.ini
[mypy]
python_version = 3.12
mypy_path = src
strict = True
warn_unused_ignores = True
warn_redundant_casts = True
no_implicit_optional = True
END_FILE
BEGIN_FILE: pytest.ini
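[pytest]
# pytest.ini takes precedence over [tool.pytest.ini_options] in pyproject.toml,
# so the options are mirrored here to keep them in effect.
asyncio_mode = auto
testpaths = tests
addopts = -q --disable-warnings --maxfail=1 --cov=app --cov-report=term-missing --cov-fail-under=80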
END_FILE
BEGIN_FILE: Dockerfile
# syntax=docker/dockerfile:1
FROM python:3.12-slim AS base
ENV PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1
WORKDIR /app
RUN apt-get update && apt-get install -y --no-install-recommends \
gcc \
&& rm -rf /var/lib/apt/lists/*
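# The package build needs README.md (project.readme) and src/ (package discovery), so copy
# them before installing; the editable install keeps imports and coverage on /app/src.
COPY pyproject.toml README.md pytest.ini ruff.toml mypy.ini /app/
COPY src/ /app/src/
RUN pip install --no-cache-dir -U pip && pip install --no-cache-dir -e ".[dev]"
COPY migrations/ /app/migrations/
COPY admin_ui/ /app/admin_ui/
COPY tests/ /app/tests/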
EXPOSE 8080
CMD ["python", "-m", "app.main"]
END_FILE
BEGIN_FILE: docker-compose.yml
services:
  db:
    image: postgres:15
    environment:
      POSTGRES_PASSWORD: postgres
      POSTGRES_USER: postgres
      POSTGRES_DB: postgres
    ports:
      - "5432:5432"
    healthcheck:
      test: ["CMD-SHELL", "pg_isready -U postgres -d postgres"]
      interval: 3s
      timeout: 3s
      retries: 20
  redis:
    image: redis:7
    ports:
      - "6379:6379"
    healthcheck:
      test: ["CMD", "redis-cli", "ping"]
      interval: 3s
      timeout: 3s
      retries: 20
  app:
    build: .
    environment:
      - APP_ENV=local
      - PORT=8080
      - LOG_LEVEL=INFO
      - DATABASE_URL=postgresql+asyncpg://postgres:postgres@db:5432/postgres
      - REDIS_URL=redis://redis:6379/0
      - REDIS_KEY_PREFIX=apexai:submanager:
      - COOKIE_SECURE=false
      - COOKIE_SAMESITE=lax
      - DB_OP_TIMEOUT_SEC=5
      - ADMIN_SESSION_TTL_SEC=3600
      - ADMIN_MAX_FAILS=5
      - ADMIN_LOCKOUT_SEC=900
      - ADMIN_RL_RATE=5
      - ADMIN_RL_BURST=10
      - ADMIN_PROGRESSIVE_DELAY_BASE_MS=150
      - ALLOWED_PROVIDERS=stripe,paypal,paddle
      - WEBHOOK_TS_SKEW_SEC=300
      - WEBHOOK_NONCE_TTL_SEC=900
      - WEBHOOK_RL_RATE=30
      - WEBHOOK_RL_BURST=60
      - DB_CONCURRENCY=20
      - DB_SEM_TIMEOUT_SEC=2.0
      - MAX_BODY_BYTES=262144
      # Local-only placeholder secrets: these values are passed to the container as-is.
      # Replace them and never reuse them outside local development.
      - ADMIN_TOKEN=AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
      - ADMIN_TOKEN_FINGERPRINT_SECRET=BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB
      - WEBHOOK_SIGNATURE_SECRETS=CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC
    depends_on:
      db:
        condition: service_healthy
      redis:
        condition: service_healthy
    ports:
      - "8080:8080"
  admin-ui:
    image: nginx:alpine
    volumes:
      - ./admin_ui:/usr/share/nginx/html:ro
    ports:
      - "8088:80"
    depends_on:
      - app
END_FILE
BEGIN_FILE: admin_ui/index.html
<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width,initial-scale=1" />
<title>ApexAI SubManager Admin</title>
<link rel="stylesheet" href="styles.css" />
</head>
<body>
<main class="wrap">
<h1>ApexAI SubManager Admin</h1>
<section class="card">
<h2>Login</h2>
<div class="row">
<input id="token" type="password" placeholder="Admin token" autocomplete="current-password" />
<button id="loginBtn">Login</button>
</div>
<p class="hint">Session is server-side. Cookie holds opaque SID only.</p>
<p id="loginStatus" class="status"></p>
</section>
<section class="card">
<h2>Dashboard</h2>
<button id="refreshBtn">Refresh</button>
<pre id="dash"></pre>
</section>
<section class="card">
<h2>Refund</h2>
<div class="row">
<input id="paymentId" type="number" placeholder="Payment ID" />
<input id="reason" type="text" placeholder="Reason" />
<button id="refundBtn">Refund</button>
</div>
<pre id="refundOut"></pre>
</section>
</main>
<script src="app.js"></script>
</body>
</html>
END_FILE
BEGIN_FILE: admin_ui/styles.css
:root { font-family: ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Arial; }
body { background: #0b0f14; color: #e7eef7; margin: 0; }
.wrap { max-width: 900px; margin: 32px auto; padding: 0 16px; }
h1 { margin: 0 0 16px 0; }
.card { background: #121a24; border: 1px solid #1e2a3a; border-radius: 12px; padding: 16px; margin: 12px 0; }
.row { display: flex; gap: 8px; align-items: center; }
input { flex: 1; padding: 10px; border-radius: 10px; border: 1px solid #223247; background: #0d141f; color: #e7eef7; }
button { padding: 10px 14px; border-radius: 10px; border: 1px solid #2d4160; background: #1b2b42; color: #e7eef7; cursor: pointer; }
button:hover { background: #223a5a; }
pre { background: #0d141f; border: 1px solid #223247; padding: 10px; border-radius: 10px; overflow: auto; }
.hint { opacity: 0.8; font-size: 0.9rem; }
.status { margin: 8px 0 0 0; }
END_FILE
BEGIN_FILE: admin_ui/app.js
const API = "http://localhost:8080";
let csrfToken = null;
function getCookie(name) {
const m = document.cookie.match(new RegExp("(^| )" + name + "=([^;]+)"));
return m ? decodeURIComponent(m[2]) : null;
}
async function api(path, opts = {}) {
const headers = opts.headers || {};
headers["Content-Type"] = "application/json";
if (opts.csrf) {
const csrf = csrfToken || getCookie("csrf_token");
if (csrf) headers["X-CSRF-Token"] = csrf;
}
const res = await fetch(API + path, { ...opts, headers, credentials: "include" });
const txt = await res.text();
let data = null;
try { data = JSON.parse(txt); } catch { data = { ok: false, error: { reason: txt } }; }
if (!res.ok || data.ok === false) {
const reason = (data && data.error && data.error.reason) ? data.error.reason : "request_failed";
throw new Error(reason);
}
return data;
}
document.getElementById("loginBtn").addEventListener("click", async () => {
const token = document.getElementById("token").value;
const out = document.getElementById("loginStatus");
out.textContent = "";
try {
const data = await api("/v1/admin/login", { method: "POST", body: JSON.stringify({ token }) });
csrfToken = data.result.csrf_token;
out.textContent = "Logged in.";
} catch (e) {
out.textContent = "Login failed: " + e.message;
}
});
document.getElementById("refreshBtn").addEventListener("click", async () => {
const pre = document.getElementById("dash");
pre.textContent = "";
try {
const data = await api("/v1/admin/dashboard_stats", { method: "GET" });
pre.textContent = JSON.stringify(data, null, 2);
} catch (e) {
pre.textContent = "Error: " + e.message;
}
});
document.getElementById("refundBtn").addEventListener("click", async () => {
const pre = document.getElementById("refundOut");
pre.textContent = "";
try {
const paymentId = Number(document.getElementById("paymentId").value);
const reason = document.getElementById("reason").value || "operator_refund";
const data = await api("/v1/admin/refund", {
method: "POST",
csrf: true,
body: JSON.stringify({ payment_id: paymentId, reason }),
});
pre.textContent = JSON.stringify(data, null, 2);
} catch (e) {
pre.textContent = "Error: " + e.message;
}
});
END_FILE
BEGIN_FILE: migrations/0001_init.sql
-- Minimal schema for SubManager
CREATE TABLE IF NOT EXISTS users (
tg_user_id BIGINT PRIMARY KEY,
created_at TIMESTAMPTZ NOT NULL
);
CREATE TABLE IF NOT EXISTS plans (
code TEXT PRIMARY KEY,
price_amount INTEGER NOT NULL,
price_currency TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL
);
CREATE TABLE IF NOT EXISTS payments (
id BIGSERIAL PRIMARY KEY,
provider TEXT NOT NULL,
provider_payment_id TEXT NOT NULL,
tg_user_id BIGINT NOT NULL REFERENCES users(tg_user_id),
plan_code TEXT NOT NULL REFERENCES plans(code),
amount INTEGER NOT NULL,
currency TEXT NOT NULL,
status TEXT NOT NULL,
raw_json TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL,
refunded_at TIMESTAMPTZ NULL,
UNIQUE(provider, provider_payment_id)
);
CREATE TABLE IF NOT EXISTS subscriptions (
id BIGSERIAL PRIMARY KEY,
tg_user_id BIGINT NOT NULL REFERENCES users(tg_user_id),
plan_code TEXT NOT NULL REFERENCES plans(code),
starts_at TIMESTAMPTZ NOT NULL,
ends_at TIMESTAMPTZ NOT NULL,
status TEXT NOT NULL,
payment_id BIGINT NULL REFERENCES payments(id),
created_at TIMESTAMPTZ NOT NULL
);
CREATE INDEX IF NOT EXISTS ix_subscriptions_tg_user_id_id ON subscriptions(tg_user_id, id DESC);
CREATE INDEX IF NOT EXISTS ix_payments_created_at ON payments(created_at DESC);
CREATE TABLE IF NOT EXISTS audit_events (
id BIGSERIAL PRIMARY KEY,
event_type TEXT NOT NULL,
tg_user_id BIGINT NULL,
details_json TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL
);
END_FILE
BEGIN_FILE: src/app/__init__.py
"""
ApexAI SubManager application package.
This package provides a production-grade FastAPI service that:
- validates and processes payment confirmation webhooks
- grants subscriptions idempotently and atomically
- exposes an admin operations plane protected by server-side sessions + CSRF
- enforces rate limits, replay protection, and resilience patterns
All non-health routes are versioned under /v1.
"""
from __future__ import annotations
import sys
__version__ = "2.1.10"
if sys.version_info < (3, 12):
    raise RuntimeError("Python 3.12+ required")
END_FILE
BEGIN_FILE: src/app/settings.py
from __future__ import annotations

from dataclasses import dataclass
from typing import Literal

from pydantic import Field, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
    """Application settings loaded from environment variables.

    Security-critical values have no defaults and must be provided.
    """

    model_config = SettingsConfigDict(env_file=".env", extra="ignore")

    APP_ENV: Literal["local", "prod", "staging"] = "local"
    PORT: int = Field(default=8080, ge=1, le=65535)
    LOG_LEVEL: str = "INFO"

    DATABASE_URL: str = Field(min_length=10)
    DB_OP_TIMEOUT_SEC: float = Field(default=5.0, ge=0.5, le=60.0)

    REDIS_URL: str = Field(min_length=10)
    REDIS_KEY_PREFIX: str = Field(default="apexai:submanager:", min_length=1, max_length=128)

    # No defaults for secrets
    ADMIN_TOKEN: str = Field(min_length=32)
    ADMIN_TOKEN_FINGERPRINT_SECRET: str = Field(min_length=32)
    WEBHOOK_SIGNATURE_SECRETS: str = Field(min_length=32)  # comma-separated; each entry validated below

    COOKIE_SECURE: bool = True
    COOKIE_SAMESITE: Literal["lax", "strict", "none"] = "lax"
    ADMIN_SESSION_TTL_SEC: int = Field(default=3600, ge=60, le=86400)

    ADMIN_MAX_FAILS: int = Field(default=5, ge=1, le=50)
    ADMIN_LOCKOUT_SEC: int = Field(default=900, ge=60, le=86400)
    ADMIN_RL_RATE: int = Field(default=5, ge=1, le=1000)
    ADMIN_RL_BURST: int = Field(default=10, ge=1, le=5000)
    ADMIN_PROGRESSIVE_DELAY_BASE_MS: int = Field(default=150, ge=0, le=5000)

    ALLOWED_PROVIDERS: str = Field(default="stripe")
    WEBHOOK_TS_SKEW_SEC: int = Field(default=300, ge=5, le=3600)
    WEBHOOK_NONCE_TTL_SEC: int = Field(default=900, ge=60, le=86400)
    WEBHOOK_RL_RATE: int = Field(default=30, ge=1, le=10000)
    WEBHOOK_RL_BURST: int = Field(default=60, ge=1, le=20000)

    DB_CONCURRENCY: int = Field(default=20, ge=1, le=200)
    DB_SEM_TIMEOUT_SEC: float = Field(default=2.0, ge=0.1, le=30.0)

    MAX_BODY_BYTES: int = Field(default=262144, ge=1024, le=5_000_000)

    @field_validator("WEBHOOK_SIGNATURE_SECRETS")
    @classmethod
    def _strip_ws(cls, v: str) -> str:
        """Normalize the comma-separated list and require every secret to be >= 32 chars."""
        parts = [p.strip() for p in v.split(",") if p.strip()]
        if not parts:
            raise ValueError("WEBHOOK_SIGNATURE_SECRETS must contain at least one secret")
        for p in parts:
            if len(p) < 32:
                raise ValueError("each webhook secret must be at least 32 chars")
        return ",".join(parts)

    @field_validator("ALLOWED_PROVIDERS")
    @classmethod
    def _providers_non_empty(cls, v: str) -> str:
        parts = [p.strip() for p in v.split(",") if p.strip()]
        if not parts:
            raise ValueError("ALLOWED_PROVIDERS must be non-empty")
        return ",".join(parts)

    @field_validator("COOKIE_SAMESITE")
    @classmethod
    def _cookie_samesite_ok(cls, v: str) -> str:
        if v not in {"lax", "strict", "none"}:
            raise ValueError("COOKIE_SAMESITE must be lax|strict|none")
        return v


@dataclass(frozen=True)
class ProviderAllowlist:
    # frozenset keeps the frozen dataclass hashable (a plain set would make hash() fail).
    providers: frozenset[str]

    @classmethod
    def from_settings(cls, s: Settings) -> ProviderAllowlist:
        return cls(providers=frozenset(p.strip().lower() for p in s.ALLOWED_PROVIDERS.split(",") if p.strip()))
END_FILE
BEGIN_FILE: src/app/errors.py
from __future__ import annotations


class ApiError(Exception):
    """Base API error with stable code/reason mapping.

    Deliberately a plain Exception subclass rather than a frozen dataclass: frozen dataclasses
    reject attribute assignment, which breaks code that sets ``__traceback__`` on a raised
    instance (for example ``contextlib.contextmanager`` / ``asynccontextmanager``).
    """

    def __init__(self, status_code: int, code: str, reason: str) -> None:
        super().__init__(status_code, code, reason)
        self.status_code = status_code
        self.code = code
        self.reason = reason


class BadRequestError(ApiError):
    def __init__(self, code: str, reason: str) -> None:
        super().__init__(400, code, reason)


class UnauthorizedError(ApiError):
    def __init__(self, code: str, reason: str) -> None:
        super().__init__(401, code, reason)


class ForbiddenError(ApiError):
    def __init__(self, code: str, reason: str) -> None:
        super().__init__(403, code, reason)


class NotFoundError(ApiError):
    def __init__(self, code: str, reason: str) -> None:
        super().__init__(404, code, reason)


class ConflictError(ApiError):
    def __init__(self, code: str, reason: str) -> None:
        super().__init__(409, code, reason)


class PayloadTooLargeError(ApiError):
    def __init__(self, code: str, reason: str) -> None:
        super().__init__(413, code, reason)


class ServiceUnavailableError(ApiError):
    def __init__(self, code: str, reason: str) -> None:
        super().__init__(503, code, reason)
END_FILE
BEGIN_FILE: src/app/schemas.py
from __future__ import annotations

from typing import Generic, Literal, TypeVar

from pydantic import BaseModel, Field

from app import __version__


class ErrorOut(BaseModel):
    code: str
    reason: str
    request_id: str


class ErrorResponse(BaseModel):
    ok: Literal[False] = False
    error: ErrorOut


T = TypeVar("T")


class OkResponse(BaseModel, Generic[T]):
    ok: Literal[True] = True
    result: T
    request_id: str
    api_version: str = Field(default=__version__)


class LoginOut(BaseModel):
    csrf_token: str


class LoginResponse(OkResponse[LoginOut]):
    pass


class DashboardStatsResult(BaseModel):
    payments_total: int
    payments_confirmed: int
    payments_failed: int
    active_subscriptions: int
    request_id: str


class DashboardStatsResponse(OkResponse[DashboardStatsResult]):
    pass


class RefundIn(BaseModel):
    payment_id: int = Field(ge=1)
    reason: str = Field(min_length=1, max_length=256)


class RefundResult(BaseModel):
    payment_id: int
    refunded: bool
    revoked_subscription_id: int | None = None


class RefundResponse(OkResponse[RefundResult]):
    pass


class WebhookConfirmResult(BaseModel):
    subscription_id: int
    payment_id: int
    status: str


class WebhookConfirmResponse(OkResponse[WebhookConfirmResult]):
    pass


class HealthResult(BaseModel):
    status: str
    version: str


class HealthResponse(OkResponse[HealthResult]):
    pass
END_FILE
BEGIN_FILE: src/app/logging_config.py
from __future__ import annotations

import json
import logging
from typing import Any


class JsonFormatter(logging.Formatter):
    """Simple JSON log formatter.

    Note: For extremely high volume logs, consider queue-based async handlers.
    """

    def format(self, record: logging.LogRecord) -> str:
        payload: dict[str, Any] = {
            "level": record.levelname,
            "msg": record.getMessage(),
            "name": record.name,
        }
        if record.exc_info:
            payload["exc_info"] = self.formatException(record.exc_info)
        return json.dumps(payload, separators=(",", ":"), ensure_ascii=False)


def configure_logging(level: str) -> None:
    root = logging.getLogger()
    root.setLevel(level.upper())
    handler = logging.StreamHandler()
    handler.setFormatter(JsonFormatter())
    root.handlers.clear()
    root.addHandler(handler)
END_FILE
BEGIN_FILE: src/app/middleware.py
from __future__ import annotations

import re
import secrets
from collections.abc import Callable

from fastapi import Request, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.types import ASGIApp

from app.schemas import ErrorOut, ErrorResponse
from app.settings import Settings

# Client-supplied request ids are echoed into logs and headers, so only short, log-safe
# values are accepted; anything else is replaced by a generated id.
_REQUEST_ID_RE = re.compile(r"[A-Za-z0-9._-]{1,64}")


class RequestIdMiddleware(BaseHTTPMiddleware):
    """Attach a request id to request.state and the X-Request-Id response header."""

    def __init__(self, app: ASGIApp) -> None:
        super().__init__(app)

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        incoming = request.headers.get("x-request-id")
        rid = incoming if incoming and _REQUEST_ID_RE.fullmatch(incoming) else secrets.token_hex(12)
        request.state.request_id = rid
        resp = await call_next(request)
        resp.headers["X-Request-Id"] = rid
        return resp


def _payload_too_large(request: Request) -> JSONResponse:
    """Build the standard 413 error envelope.

    Exceptions raised inside middleware propagate outside FastAPI's ExceptionMiddleware, so a
    handler registered for ApiError would never see them; the response is returned directly.
    """
    err = ErrorResponse(
        error=ErrorOut(
            code="payload_too_large",
            reason="payload_too_large",
            request_id=getattr(request.state, "request_id", ""),
        )
    )
    return JSONResponse(status_code=413, content=err.model_dump())


class BodySizeLimitMiddleware(BaseHTTPMiddleware):
    """Reject requests whose body exceeds MAX_BODY_BYTES with a 413 error envelope.

    Content-Length is checked first; the actual body length is always checked too, since the
    header may be absent (chunked uploads) or wrong.
    """

    def __init__(self, app: ASGIApp, *, max_bytes: int) -> None:
        super().__init__(app)
        self._max = max_bytes

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        cl = request.headers.get("content-length")
        if cl is not None:
            try:
                if int(cl) > self._max:
                    return _payload_too_large(request)
            except ValueError:
                pass  # malformed header: fall through to the real body-length check
        body = await request.body()
        if len(body) > self._max:
            return _payload_too_large(request)
        # BaseHTTPMiddleware caches the body read above and replays it to the downstream app,
        # so no manual `receive` patching is needed.
        return await call_next(request)


class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Basic security headers (defense-in-depth)."""

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        resp = await call_next(request)
        resp.headers["X-Content-Type-Options"] = "nosniff"
        resp.headers["X-Frame-Options"] = "DENY"
        # Admin UI is served separately, but CSP helps API usage in browsers too.
        resp.headers["Content-Security-Policy"] = "default-src 'none'; frame-ancestors 'none'"
        return resp


MiddlewareFactory = Callable[[ASGIApp], ASGIApp]


def build_middleware_stack(s: Settings) -> list[MiddlewareFactory]:
    """Build middleware factories in strict order, outermost first."""
    # Note: FastAPI adds its own exception handling; we ensure request id is earliest.
    return [
        lambda app: RequestIdMiddleware(app),
        lambda app: BodySizeLimitMiddleware(app, max_bytes=s.MAX_BODY_BYTES),
        lambda app: SecurityHeadersMiddleware(app),
    ]
END_FILE
BEGIN_FILE: src/app/db.py
from __future__ import annotations

import asyncio
import pathlib
from collections.abc import AsyncIterator

from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine

from app.infra.resilience import with_db_retry
from app.settings import Settings

_engine: AsyncEngine | None = None
_sessionmaker: async_sessionmaker[AsyncSession] | None = None


def init_engine(s: Settings) -> None:
    """Initialize global async engine and sessionmaker."""
    global _engine, _sessionmaker
    if _engine is not None:
        return
    _engine = create_async_engine(
        s.DATABASE_URL,
        pool_pre_ping=True,
        pool_recycle=1800,
        pool_size=20,
        max_overflow=30,
    )
    _sessionmaker = async_sessionmaker(bind=_engine, expire_on_commit=False)


async def close_engine() -> None:
    """Dispose global engine and drop the sessionmaker bound to it."""
    global _engine, _sessionmaker
    if _engine is not None:
        await _engine.dispose()
    _engine = None
    _sessionmaker = None


async def get_session() -> AsyncIterator[AsyncSession]:
    """FastAPI dependency yielding a session."""
    assert _sessionmaker is not None, "engine not initialized"
    async with _sessionmaker() as session:
        yield session


async def migrate(s: Settings) -> None:
    """Apply SQL migrations (simple runner).

    SQLAlchemy's asyncpg dialect runs each statement as a prepared statement, which cannot
    contain multiple commands, so the file is split on ';' and executed statement by statement
    inside one transaction. The naive split assumes no ';' inside literals or function bodies.
    """
    init_engine(s)
    assert _engine is not None
    mig = pathlib.Path("migrations/0001_init.sql").read_text(encoding="utf-8")
    statements = [stmt.strip() for stmt in mig.split(";") if stmt.strip()]
    async with _engine.begin() as conn:
        for stmt in statements:
            await conn.execute(text(stmt))


async def ping_db(s: Settings) -> None:
    """Simple DB health check."""
    init_engine(s)
    assert _engine is not None
    engine = _engine

    async def _op() -> None:
        async with engine.connect() as conn:
            await conn.execute(text("SELECT 1"))

    await with_db_retry(_op)


async def _run_migrations(s: Settings) -> None:
    try:
        await migrate(s)
    finally:
        await close_engine()  # release pooled connections before the event loop closes


def main() -> None:
    """CLI entry for migrations: `python -m app.db migrate`."""
    asyncio.run(_run_migrations(Settings()))


if __name__ == "__main__":
    main()
END_FILE
BEGIN_FILE: src/app/models.py
from __future__ import annotations

from datetime import datetime

from sqlalchemy import BigInteger, DateTime, ForeignKey, Integer, String, Text, UniqueConstraint
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class User(Base):
    __tablename__ = "users"

    tg_user_id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)

    def __repr__(self) -> str:
        return f"User(tg_user_id={self.tg_user_id})"


class Plan(Base):
    __tablename__ = "plans"

    code: Mapped[str] = mapped_column(String, primary_key=True)
    price_amount: Mapped[int] = mapped_column(Integer, nullable=False)
    price_currency: Mapped[str] = mapped_column(String, nullable=False)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)

    def __repr__(self) -> str:
        return f"Plan(code={self.code!r}, price_amount={self.price_amount}, price_currency={self.price_currency!r})"


class Payment(Base):
    __tablename__ = "payments"
    # Mirrors UNIQUE(provider, provider_payment_id) in migrations/0001_init.sql (idempotency key).
    __table_args__ = (UniqueConstraint("provider", "provider_payment_id"),)

    id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
    provider: Mapped[str] = mapped_column(String, nullable=False)
    provider_payment_id: Mapped[str] = mapped_column(String, nullable=False)
    tg_user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.tg_user_id"), nullable=False)
    plan_code: Mapped[str] = mapped_column(String, ForeignKey("plans.code"), nullable=False)
    amount: Mapped[int] = mapped_column(Integer, nullable=False)
    currency: Mapped[str] = mapped_column(String, nullable=False)
    status: Mapped[str] = mapped_column(String, nullable=False)
    raw_json: Mapped[str] = mapped_column(Text, nullable=False)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
    refunded_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)

    def __repr__(self) -> str:
        return f"Payment(id={self.id}, provider={self.provider!r}, provider_payment_id={self.provider_payment_id!r})"


class Subscription(Base):
    __tablename__ = "subscriptions"

    id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
    tg_user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.tg_user_id"), nullable=False)
    plan_code: Mapped[str] = mapped_column(String, ForeignKey("plans.code"), nullable=False)
    starts_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
    ends_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
    status: Mapped[str] = mapped_column(String, nullable=False)
    payment_id: Mapped[int | None] = mapped_column(BigInteger, ForeignKey("payments.id"), nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)

    def __repr__(self) -> str:
        return f"Subscription(id={self.id}, tg_user_id={self.tg_user_id}, status={self.status!r})"


class AuditEvent(Base):
    __tablename__ = "audit_events"

    id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
    event_type: Mapped[str] = mapped_column(String, nullable=False)
    tg_user_id: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
    details_json: Mapped[str] = mapped_column(Text, nullable=False)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)

    def __repr__(self) -> str:
        return f"AuditEvent(id={self.id}, event_type={self.event_type!r})"
END_FILE
BEGIN_FILE: src/app/infra/__init__.py
"""
Infrastructure layer.
Contains:
- Redis pool lifecycle and helpers
- distributed lock helper
- token bucket rate limiter (atomic Lua)
- resilience utilities (retry + circuit breaker)
"""
END_FILE
BEGIN_FILE: src/app/infra/redis.py
from __future__ import annotations
from typing import Optional
import redis.asyncio as redis
from app.settings import Settings
_pool: Optional[redis.Redis] = None
def init_redis_pool(s: Settings) -> None:
"""Initialize global Redis client (connection pooled)."""
global _pool
if _pool is not None:
return
_pool = redis.from_url(s.REDIS_URL, decode_responses=False, max_connections=50)
async def close_redis_pool() -> None:
"""Close global Redis pool."""
global _pool
if _pool is not None:
await _pool.close()
_pool = None
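def get_redis() -> redis.Redis:
"""Return the global Redis client; raises RuntimeError if init_redis_pool() has not run."""
if _pool is None:
# explicit check instead of assert: asserts are stripped under `python -O`
raise RuntimeError("redis not initialized")
return _pool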
END_FILE
BEGIN_FILE: src/app/infra/locks.py
from __future__ import annotations
import secrets
from contextlib import asynccontextmanager
from typing import AsyncIterator
from redis.asyncio import Redis
from app.infra.resilience import with_retry
from app.settings import Settings
# Atomic compare-and-delete: only the holder whose token is still stored may delete the key.
_RELEASE_LUA = """
if redis.call("GET", KEYS[1]) == ARGV[1] then
return redis.call("DEL", KEYS[1])
end
return 0
"""
@asynccontextmanager
async def redis_lock(r: Redis, s: Settings, *, key: str, ttl_sec: int = 30) -> AsyncIterator[bool]:
"""Acquire a best-effort distributed lock with TTL.
Args:
r: Redis client.
s: Settings (for prefixing).
key: Lock key suffix (prefix applied).
ttl_sec: TTL for lock in seconds.
Yields:
bool: True if the lock was acquired, False otherwise. Callers must check it.
Notes:
- Release runs one Lua compare-and-delete, so a lock that expired and was re-acquired
by another holder is never deleted by the previous holder.
- On Redis failures during release, the lock remains until TTL expiration (fail-safe).
- Release never suppresses an exception raised inside the `async with` body.
"""
token = secrets.token_hex(16).encode("utf-8")
full = s.REDIS_KEY_PREFIX + "lock:" + key
acquired = bool(await with_retry(lambda: r.set(full, token, nx=True, ex=ttl_sec)))
try:
yield acquired
finally:
if acquired:
try:
await with_retry(lambda: r.eval(_RELEASE_LUA, 1, full, token))
except Exception:
# best-effort release; TTL will expire
pass
END_FILE
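The lock protocol can be modeled without Redis to show why the release has to compare the token and delete in a single step. The sketch below uses a hypothetical, single-process `InMemoryLockStore` (not part of the repo and not a Redis client) with a fake clock. In real Redis, the equivalent of `compare_and_delete` must run as one Lua script so that no other client can acquire the key between the check and the delete.

```python
import secrets
from typing import Callable, Dict, Optional, Tuple


class InMemoryLockStore:
    """Toy, single-process model of Redis SET NX EX plus a token-checked release."""

    def __init__(self, clock: Callable[[], float]) -> None:
        self._clock = clock
        self._data: Dict[str, Tuple[str, float]] = {}

    def _live_value(self, key: str) -> Optional[str]:
        entry = self._data.get(key)
        if entry is None:
            return None
        value, expires_at = entry
        if self._clock() >= expires_at:
            del self._data[key]  # behaves like an expired Redis TTL
            return None
        return value

    def set_nx_ex(self, key: str, value: str, ttl_sec: float) -> bool:
        """Like SET key value NX EX ttl: succeeds only if the key is absent or expired."""
        if self._live_value(key) is not None:
            return False
        self._data[key] = (value, self._clock() + ttl_sec)
        return True

    def compare_and_delete(self, key: str, token: str) -> bool:
        """Delete only if the key still holds our token (one atomic step in this model)."""
        if self._live_value(key) == token:
            del self._data[key]
            return True
        return False


now = [0.0]  # fake clock so the scenario is deterministic
store = InMemoryLockStore(clock=lambda: now[0])

token_a = secrets.token_hex(16)
acquired_a = store.set_nx_ex("lock:job", token_a, ttl_sec=30)

now[0] = 31.0  # holder A stalls past its TTL, so the lock expires
token_b = secrets.token_hex(16)
acquired_b = store.set_nx_ex("lock:job", token_b, ttl_sec=30)

# A's late release must not remove B's lock.
released_by_a = store.compare_and_delete("lock:job", token_a)
released_by_b = store.compare_and_delete("lock:job", token_b)
```

With a plain GET followed by a separate DEL, A could read its own token just before expiry and then delete B's lock. The single compare-and-delete step closes that window.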
BEGIN_FILE: src/app/infra/rate_limit.py
from __future__ import annotations
import hashlib
from dataclasses import dataclass
from typing import Optional
from redis.asyncio import Redis
from app.infra.resilience import with_retry
from app.settings import Settings
TOKEN_BUCKET_LUA = r"""
-- KEYS[1] = key
-- ARGV[1] = now_ms
-- ARGV[2] = rate (tokens per second)
-- ARGV[3] = burst (max tokens)
-- ARGV[4] = cost (tokens)
local key = KEYS[1]
local now_ms = tonumber(ARGV[1])
local rate = tonumber(ARGV[2])
local burst = tonumber(ARGV[3])
local cost = tonumber(ARGV[4])
local data = redis.call("HMGET", key, "tokens", "ts")
local tokens = tonumber(data[1])
local ts = tonumber(data[2])
if tokens == nil then tokens = burst end
if ts == nil then ts = now_ms end
local delta = math.max(0, now_ms - ts) / 1000.0
local refill = delta * rate
tokens = math.min(burst, tokens + refill)
local allowed = 0
if tokens >= cost then
tokens = tokens - cost
allowed = 1
end
-- HSET with several field/value pairs (Redis >= 4.0) replaces the deprecated HMSET
redis.call("HSET", key, "tokens", tokens, "ts", now_ms)
redis.call("PEXPIRE", key, math.ceil((burst / rate) * 1000.0) + 5000)
return allowed
"""
def _hash_ip(ip: str) -> str:
return hashlib.sha256(ip.encode("utf-8")).hexdigest()[:16]
@dataclass(frozen=True)
class TokenBucketLimiter:
"""Token bucket rate limiter backed by Redis Lua for atomicity."""
r: Redis
s: Settings
async def allow(self, *, key_suffix: str, rate: int, burst: int, cost: int = 1) -> bool:
"""Return True if request is allowed.
Args:
key_suffix: Suffix (prefix applied).
rate: Tokens per second.
burst: Max tokens.
cost: Tokens per request.
Returns:
True if allowed, False if rate limited.
"""
import time
now_ms = int(time.time() * 1000)
full_key = (self.s.REDIS_KEY_PREFIX + "rl:" + key_suffix).encode("utf-8")
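# EVAL sends the full script on every call (Redis caches the compiled script server-side);
# Redis.register_script() would switch this to EVALSHA with an automatic NOSCRIPT fallback.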
res = await with_retry(
lambda: self.r.eval(TOKEN_BUCKET_LUA, 1, full_key, now_ms, rate, burst, cost)
)
return bool(int(res or 0))
@staticmethod
def key_for_admin(fp: str) -> str:
return f"admin:{fp}"
@staticmethod
def key_for_webhook(provider: str, ip: str) -> str:
return f"wh:{provider}:{_hash_ip(ip)}"
END_FILE
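For unit tests or capacity planning, it can help to have the token bucket arithmetic available outside Redis. This is a hypothetical pure-Python mirror of `TOKEN_BUCKET_LUA`; `BucketState` and `token_bucket_step` are illustrative names, not repo APIs. It follows the script's refill, clamp, and PEXPIRE formula step by step, assuming `rate > 0`.

```python
from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class BucketState:
    tokens: float
    ts_ms: int


def token_bucket_step(
    state: Optional[BucketState], now_ms: int, rate: float, burst: float, cost: float = 1
) -> Tuple[bool, BucketState, int]:
    """Pure-Python mirror of TOKEN_BUCKET_LUA for one request.

    Returns (allowed, new_state, pexpire_ms), where pexpire_ms matches the script's PEXPIRE.
    """
    tokens = burst if state is None else state.tokens  # missing hash -> full bucket
    ts = now_ms if state is None else state.ts_ms
    elapsed_sec = max(0, now_ms - ts) / 1000.0  # a clock going backwards refills nothing
    tokens = min(burst, tokens + elapsed_sec * rate)
    allowed = tokens >= cost
    if allowed:
        tokens -= cost
    pexpire_ms = math.ceil((burst / rate) * 1000.0) + 5000
    return allowed, BucketState(tokens=tokens, ts_ms=now_ms), pexpire_ms


# Four requests in the same millisecond against rate=1/s, burst=3.
state: Optional[BucketState] = None
results = []
for _ in range(4):
    ok, state, _ttl = token_bucket_step(state, now_ms=1_000, rate=1, burst=3)
    results.append(ok)

# One second later, exactly one token has been refilled.
ok_later, state, pexpire_ms = token_bucket_step(state, now_ms=2_000, rate=1, burst=3)
```

With `rate=1` and `burst=3`, an instantaneous burst of four requests admits three. One second later, one more request is admitted. The key lives for the time needed to refill a full bucket plus a 5 s margin (8000 ms here).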
BEGIN_FILE: src/app/infra/resilience.py
from __future__ import annotations
import asyncio
import time
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, TypeVar
from redis.exceptions import ConnectionError as RedisConnectionError
from redis.exceptions import TimeoutError as RedisTimeoutError
from sqlalchemy.exc import OperationalError
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential
T = TypeVar("T")
@retry(
reraise=True,
retry=retry_if_exception_type((OperationalError, ConnectionError, TimeoutError)),
stop=stop_after_attempt(3),
wait=wait_random_exponential(multiplier=0.2, max=2.0),
)
async def with_db_retry(fn: Callable[[], Awaitable[T]]) -> T:
"""Retry wrapper for transient database failures."""
return await fn()
# redis-py raises its own ConnectionError/TimeoutError (subclasses of RedisError, not of the
# builtin exceptions), so both families are listed for Redis operations.
@retry(
reraise=True,
retry=retry_if_exception_type((ConnectionError, TimeoutError, RedisConnectionError, RedisTimeoutError)),
stop=stop_after_attempt(3),
wait=wait_random_exponential(multiplier=0.15, max=1.5),
)
async def with_retry(fn: Callable[[], Awaitable[Any]]) -> Any:
"""Retry wrapper for transient infra operations (Redis/network)."""
return await fn()
@dataclass
class CircuitBreaker:
"""Async-safe circuit breaker.
Uses an asyncio.Lock to ensure state transitions are race-safe in concurrent async contexts.
Once recovery_timeout_sec has elapsed, the breaker is half-open: calls are let through again
(not limited to a single probe). A success closes it; a failure re-opens it and restarts the timer.
"""
failure_threshold: int
recovery_timeout_sec: int
_failures: int = 0
_opened_at: float | None = None
def __post_init__(self) -> None:
self._lock = asyncio.Lock()
async def allow(self) -> bool:
"""Return True if operations should be attempted."""
async with self._lock:
if self._opened_at is None:
return True
if (time.time() - self._opened_at) >= self.recovery_timeout_sec:
# half-open: allow calls; on_success closes, on_failure re-opens
return True
return False
async def on_success(self) -> None:
"""Reset breaker on success."""
async with self._lock:
self._failures = 0
self._opened_at = None
async def on_failure(self) -> None:
"""Record a failure; open (or re-open) the breaker once the threshold is reached."""
async with self._lock:
self._failures += 1
if self._failures >= self.failure_threshold:
# Refreshing _opened_at also re-opens a half-open breaker after a failed call.
self._opened_at = time.time()
END_FILE
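`CircuitBreaker` lets every caller through once the recovery timeout has passed. A common variant admits only a single trial call while half-open. The following is a synchronous sketch of that variant with an injectable clock. `SketchBreaker` is hypothetical. Inside a single asyncio event loop, its methods need no lock because none of them await.

```python
from __future__ import annotations

from typing import Callable, Optional


class SketchBreaker:
    """closed -> open after N failures -> half-open (one probe) -> closed or re-opened."""

    def __init__(self, failure_threshold: int, recovery_timeout_sec: float,
                 clock: Callable[[], float]) -> None:
        self.failure_threshold = failure_threshold
        self.recovery_timeout_sec = recovery_timeout_sec
        self._clock = clock
        self._failures = 0
        self._opened_at: Optional[float] = None
        self._probe_in_flight = False

    def state(self) -> str:
        if self._opened_at is None:
            return "closed"
        if self._clock() - self._opened_at >= self.recovery_timeout_sec:
            return "half_open"
        return "open"

    def allow(self) -> bool:
        current = self.state()
        if current == "closed":
            return True
        if current == "half_open" and not self._probe_in_flight:
            self._probe_in_flight = True  # admit exactly one trial call
            return True
        return False

    def on_success(self) -> None:
        self._failures = 0
        self._opened_at = None
        self._probe_in_flight = False

    def on_failure(self) -> None:
        self._failures += 1
        self._probe_in_flight = False
        if self._failures >= self.failure_threshold:
            self._opened_at = self._clock()  # (re)open; a failed probe restarts the timer


now = [0.0]
breaker = SketchBreaker(failure_threshold=2, recovery_timeout_sec=30, clock=lambda: now[0])
breaker.on_failure()
breaker.on_failure()
blocked_while_open = not breaker.allow()

now[0] = 30.0
first_probe = breaker.allow()   # half-open: one call admitted
second_probe = breaker.allow()  # still half-open, but the probe slot is taken
breaker.on_failure()            # probe failed: open again from t=30
state_after_failed_probe = breaker.state()

now[0] = 60.0
recovered_probe = breaker.allow()
breaker.on_success()
final_state = breaker.state()
```

The single-probe rule keeps a still-broken database from receiving a full burst of traffic every time the recovery timeout elapses.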
BEGIN_FILE: src/app/security/init.py
"""
Security layer.
Contains:
- constant-time comparisons
- admin auth (server-side sessions, CSRF, lockout, rate limiting)
- webhook signature verification and payload validation
"""
END_FILE
BEGIN_FILE: src/app/security/compare.py
from __future__ import annotations
import hmac
def ct_eq(a: str, b: str) -> bool:
"""Constant-time string equality."""
return hmac.compare_digest(a.encode("utf-8"), b.encode("utf-8"))
END_FILE
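`ct_eq` encodes both strings before calling `hmac.compare_digest`. It has to, because the standard library only accepts `str` arguments that are ASCII-only and raises `TypeError` otherwise. The sketch below reproduces `ct_eq` standalone and shows the difference. `str_compare_digest_ok` is a hypothetical helper for the demonstration. `compare_digest` may still reveal the lengths of its inputs through timing, but not their contents.

```python
import hmac


def ct_eq(a: str, b: str) -> bool:
    """Same approach as app.security.compare.ct_eq: constant-time compare of UTF-8 bytes."""
    return hmac.compare_digest(a.encode("utf-8"), b.encode("utf-8"))


def str_compare_digest_ok(a: str, b: str) -> bool:
    """Call compare_digest on str directly; report False if it refuses the input."""
    try:
        hmac.compare_digest(a, b)
    except TypeError:  # non-ASCII str is rejected outright
        return False
    return True


token = "tökén"
```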
BEGIN_FILE: src/app/security/webhook_sig.py
from __future__ import annotations
import hashlib
import hmac
import time
from typing import Optional
from fastapi import Request
from redis.asyncio import Redis
from app.errors import BadRequestError, ServiceUnavailableError
from app.infra.resilience import with_retry
from app.settings import Settings
def _hmac_hex(secret: str, msg: bytes) -> str:
return hmac.new(secret.encode("utf-8"), msg, hashlib.sha256).hexdigest()
def _get_header(request: Request, *names: str) -> Optional[str]:
for n in names:
v = request.headers.get(n)
if v:
return v
return None
async def verify_signature(request: Request, r: Redis, s: Settings, body: bytes) -> None:
"""Verify webhook signature and enforce replay protection.
Message format:
{ts}.{nonce}.{body}
Supports secret rotation:
WEBHOOK_SIGNATURE_SECRETS is a comma-separated list of active secrets.
Verification succeeds if any secret matches.
Raises:
BadRequestError: invalid or missing headers, timestamp skew, signature mismatch, replayed nonce
ServiceUnavailableError: Redis unavailable (fail closed)
"""
ts_s = _get_header(request, "X-Apex-Timestamp", "X-Timestamp")
nonce = _get_header(request, "X-Apex-Nonce", "X-Nonce")
sig = _get_header(request, "X-Apex-Signature", "X-Signature")
if not ts_s or not nonce or not sig:
raise BadRequestError("missing_sig_headers", "missing_sig_headers")
try:
ts_i = int(ts_s)
except ValueError as e:
raise BadRequestError("invalid_timestamp", "invalid_timestamp") from e
now = int(time.time())
if abs(now - ts_i) > s.WEBHOOK_TS_SKEW_SEC:
raise BadRequestError("timestamp_skew", "timestamp_skew")
msg = f"{ts_i}.{nonce}.".encode("utf-8") + body
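active_secrets = [p.strip() for p in s.WEBHOOK_SIGNATURE_SECRETS.split(",") if p.strip()]
# Compare as bytes: hmac.compare_digest raises TypeError on non-ASCII str input.
# hexdigest() is lowercase, so normalize the presented signature the same way.
sig_b = sig.strip().lower().encode("utf-8")
ok = any(hmac.compare_digest(_hmac_hex(sec, msg).encode("ascii"), sig_b) for sec in active_secrets)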
if not ok:
raise BadRequestError("invalid_signature", "invalid_signature")
# Replay protection: nonce must be unique within TTL window.
key = (s.REDIS_KEY_PREFIX + f"wh:nonce:{nonce}").encode("utf-8")
try:
set_ok = await with_retry(lambda: r.set(key, b"1", nx=True, ex=s.WEBHOOK_NONCE_TTL_SEC))
except Exception as e:
raise ServiceUnavailableError("redis_unavailable", "redis_unavailable") from e
if not set_ok:
raise BadRequestError("replayed_nonce", "replayed_nonce")
END_FILE
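On the sending side, a provider (or a test harness) must sign exactly the bytes it sends, using the `{ts}.{nonce}.{body}` message format above. The sketch below is a minimal version that assumes the `X-Apex-*` header names and hex-encoded HMAC-SHA256 that `verify_signature` expects. `sign_webhook`, `signature_matches`, and the secret values are placeholders. The sketch also shows why rotation works: while `WEBHOOK_SIGNATURE_SECRETS` lists both secrets, a signature made with either one verifies.

```python
from __future__ import annotations

import hashlib
import hmac
import json
import secrets
import time
from typing import Dict, Optional


def sign_webhook(secret: str, body: bytes, *, ts: Optional[int] = None,
                 nonce: Optional[str] = None) -> Dict[str, str]:
    """Sender side: build X-Apex-* headers over the message {ts}.{nonce}.{body}."""
    ts = int(time.time()) if ts is None else ts
    nonce = secrets.token_hex(16) if nonce is None else nonce
    msg = f"{ts}.{nonce}.".encode("utf-8") + body
    sig = hmac.new(secret.encode("utf-8"), msg, hashlib.sha256).hexdigest()
    return {"X-Apex-Timestamp": str(ts), "X-Apex-Nonce": nonce, "X-Apex-Signature": sig}


def signature_matches(secrets_csv: str, headers: Dict[str, str], body: bytes) -> bool:
    """Receiver side, signature check only (no timestamp-skew or nonce checks)."""
    msg = f"{int(headers['X-Apex-Timestamp'])}.{headers['X-Apex-Nonce']}.".encode("utf-8") + body
    presented = headers["X-Apex-Signature"].encode("utf-8")
    active = [p.strip() for p in secrets_csv.split(",") if p.strip()]
    return any(
        hmac.compare_digest(
            hmac.new(k.encode("utf-8"), msg, hashlib.sha256).hexdigest().encode("ascii"), presented
        )
        for k in active
    )


body = json.dumps(
    {"tg_user_id": 42, "provider_payment_id": "pay_001", "amount": 999,
     "currency": "USD", "plan_code": "pro_monthly"},
    separators=(",", ":"),
).encode("utf-8")
headers = sign_webhook("old-secret", body, ts=1_700_000_000, nonce="n-001")

# During rotation both secrets are active, so a request signed with the old one still verifies.
accepted = signature_matches("new-secret,old-secret", headers, body)
tampered = signature_matches("new-secret,old-secret", headers, body.replace(b"999", b"1"))
retired = signature_matches("new-secret", headers, body)
```

Any timestamp within `WEBHOOK_TS_SKEW_SEC` of the server clock, in either direction, is accepted. A request can therefore stay valid for up to twice the skew after it is first seen. For the nonce check to cover that whole window, `WEBHOOK_NONCE_TTL_SEC` should be at least `2 * WEBHOOK_TS_SKEW_SEC`.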
BEGIN_FILE: src/app/security/webhook.py
from __future__ import annotations
import json
from typing import Any
from app.errors import BadRequestError
def verify_provider_payload(provider: str, body: bytes) -> dict[str, Any]:
"""Validate and normalize provider webhook payload.
Supports schema_version=1 (default if absent).
"""
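try:
data = json.loads(body.decode("utf-8"))
except Exception as e:
raise BadRequestError("invalid_json", "invalid_json") from e
if not isinstance(data, dict):
# a top-level array or scalar would otherwise fail on data.get() with a 500
raise BadRequestError("invalid_json", "invalid_json")
try:
schema_version = int(data.get("schema_version", 1))
except (TypeError, ValueError) as e:
raise BadRequestError("unsupported_schema_version", "unsupported_schema_version") from e
if schema_version != 1:
raise BadRequestError("unsupported_schema_version", "unsupported_schema_version")
required = ["tg_user_id", "provider_payment_id", "amount", "currency", "plan_code"]
for k in required:
if k not in data:
raise BadRequestError("missing_field", f"missing_field:{k}")
# int() would silently truncate floats (1000.5 -> 1000) and map booleans to 1
if any(isinstance(data[k], (bool, float)) for k in ("tg_user_id", "amount")):
raise BadRequestError("invalid_field_type", "invalid_field_type")
try:
tg_user_id = int(data["tg_user_id"])
amount = int(data["amount"])
except Exception as e:
raise BadRequestError("invalid_field_type", "invalid_field_type") from e
currency = str(data["currency"]).upper()
plan_code = str(data["plan_code"])
provider_payment_id = str(data["provider_payment_id"])
if tg_user_id <= 0:
raise BadRequestError("invalid_tg_user_id", "invalid_tg_user_id")
if amount <= 0:
raise BadRequestError("invalid_amount", "invalid_amount")
if len(currency) != 3 or not (currency.isascii() and currency.isalpha()):
raise BadRequestError("invalid_currency", "invalid_currency")
if len(plan_code) < 1 or len(plan_code) > 64:
raise BadRequestError("invalid_plan_code", "invalid_plan_code")
if not provider_payment_id.strip():
raise BadRequestError("invalid_provider_payment_id", "invalid_provider_payment_id")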
return {
"schema_version": schema_version,
"tg_user_id": tg_user_id,
"provider": provider,
"provider_payment_id": provider_payment_id,
"amount": amount,
"currency": currency,
"plan_code": plan_code,
"raw_json": json.dumps(data, separators=(",", ":"), ensure_ascii=False),
}
END_FILE
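`verify_provider_payload` coerces the numeric fields with `int()`, which is more permissive than it looks. It truncates floats (`int(1000.5) == 1000`), maps `True` to `1`, and accepts digit-group underscores in strings (`int("1_000") == 1000`). A stricter coercion could look like the sketch below. `strict_int` and `is_rejected` are hypothetical helpers. Whether a given provider sends amounts as JSON numbers or as strings is an assumption to check against that provider's documentation.

```python
import re
from typing import Any

_INT_RE = re.compile(r"-?[0-9]+")  # ASCII digits only, optional leading minus


def strict_int(value: Any) -> int:
    """Accept ints and plain base-10 digit strings; reject bool, float and anything else."""
    if isinstance(value, bool):  # bool is a subclass of int, so check it first
        raise ValueError("boolean is not an integer field")
    if isinstance(value, int):
        return value
    if isinstance(value, str) and _INT_RE.fullmatch(value.strip()):
        return int(value.strip())
    raise ValueError(f"not an integer: {value!r}")


def is_rejected(value: Any) -> bool:
    """True if strict_int refuses the value."""
    try:
        strict_int(value)
    except ValueError:
        return True
    return False
```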
BEGIN_FILE: src/app/security/admin.py
from __future__ import annotations
import hashlib
import secrets
from dataclasses import dataclass
from typing import Optional
from fastapi import Request, Response
from redis.asyncio import Redis
from app.errors import ForbiddenError, UnauthorizedError
from app.infra.rate_limit import TokenBucketLimiter
from app.infra.resilience import with_retry
from app.security.compare import ct_eq
from app.settings import Settings
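# The __Host- prefix makes browsers reject this cookie unless it is Secure, has Path=/ and no Domain,
# so COOKIE_SECURE must be true wherever the admin UI is used.
SESSION_COOKIE_NAME = "__Host-admin_session"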
CSRF_COOKIE_NAME = "csrf_token"
def _fingerprint(token: str, s: Settings) -> str:
"""Return a short stable fingerprint for token-based tracking keys."""
h = hashlib.sha256((s.ADMIN_TOKEN_FINGERPRINT_SECRET + ":" + token).encode("utf-8")).hexdigest()
return h[:24]
def set_csrf_cookie(response: Response, s: Settings, csrf: str) -> None:
"""Set CSRF cookie for double-submit protection."""
response.set_cookie(
CSRF_COOKIE_NAME,
csrf,
httponly=False,
secure=s.COOKIE_SECURE,
samesite=s.COOKIE_SAMESITE,
path="/",
)
def verify_csrf(request: Request) -> None:
"""Verify double-submit CSRF token for mutating requests (constant-time comparison)."""
cookie = request.cookies.get(CSRF_COOKIE_NAME)
header = request.headers.get("x-csrf-token")
if not cookie or not header or not ct_eq(cookie, header):
raise ForbiddenError("csrf_failed", "csrf_failed")
@dataclass(frozen=True)
class AdminSession:
sid: str
fp: str
async def _get_session(r: Redis, s: Settings, sid: str) -> Optional[AdminSession]:
key = (s.REDIS_KEY_PREFIX + f"admin:sess:{sid}").encode("utf-8")
raw = await with_retry(lambda: r.get(key))
if not raw:
return None
try:
fp = raw.decode("utf-8")
except Exception:
return None
return AdminSession(sid=sid, fp=fp)
async def _set_session(r: Redis, s: Settings, fp: str) -> str:
sid = secrets.token_hex(32)
key = (s.REDIS_KEY_PREFIX + f"admin:sess:{sid}").encode("utf-8")
await with_retry(lambda: r.set(key, fp.encode("utf-8"), ex=s.ADMIN_SESSION_TTL_SEC))
return sid
async def _check_lockout(r: Redis, s: Settings, fp: str) -> None:
lock_key = (s.REDIS_KEY_PREFIX + f"admin:lock:{fp}").encode("utf-8")
v = await with_retry(lambda: r.get(lock_key))
if v:
raise ForbiddenError("admin_locked", "admin_locked")
async def _apply_progressive_delay(s: Settings, fails: int) -> None:
# bounded progressive delay
if fails <= 0:
return
ms = min(3000, s.ADMIN_PROGRESSIVE_DELAY_BASE_MS * (2 ** min(fails, 6)))
await time_sleep_ms(ms)
async def time_sleep_ms(ms: int) -> None:
import asyncio
await asyncio.sleep(ms / 1000.0)
async def _record_failure(r: Redis, s: Settings, fp: str) -> int:
fail_key = (s.REDIS_KEY_PREFIX + f"admin:fail:{fp}").encode("utf-8")
n = await with_retry(lambda: r.incr(fail_key))
# keep failure count around during window; lockout is separate TTL
await with_retry(lambda: r.expire(fail_key, s.ADMIN_LOCKOUT_SEC))
return int(n)
async def _lockout(r: Redis, s: Settings, fp: str) -> None:
lock_key = (s.REDIS_KEY_PREFIX + f"admin:lock:{fp}").encode("utf-8")
await with_retry(lambda: r.set(lock_key, b"1", ex=s.ADMIN_LOCKOUT_SEC))
async def _reset_failures(r: Redis, s: Settings, fp: str) -> None:
fail_key = (s.REDIS_KEY_PREFIX + f"admin:fail:{fp}").encode("utf-8")
await with_retry(lambda: r.delete(fail_key))
async def authenticate_admin(request: Request, response: Response, r: Redis, s: Settings) -> None:
"""Authenticate admin via server-side session or token (login).
Rules:
- For non-login requests, require a valid session cookie (opaque SID stored in cookie)
- For login, accept token in JSON body or Authorization header and mint a session
- Enforce rate limiting, progressive delay, and lockout on invalid token attempts
"""
# 1) If session cookie exists and is valid => allow.
sid = request.cookies.get(SESSION_COOKIE_NAME)
if sid:
sess = await _get_session(r, s, sid)
if sess:
return
# invalid sid => force re-login
raise UnauthorizedError("invalid_session", "invalid_session")
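# 2) No session cookie: fall back to token authentication and mint a session.
# Every route that calls authenticate_admin reaches this path (not only /login),
# so a valid bearer token also authorizes those requests directly.
# Extract token from Authorization or body.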
token = None
authz = request.headers.get("authorization")
if authz and authz.lower().startswith("bearer "):
token = authz.split(" ", 1)[1].strip()
if token is None:
try:
body = await request.json()
token = str(body.get("token") or "")
except Exception:
token = ""
if not token:
raise UnauthorizedError("missing_token", "missing_token")
fp = _fingerprint(token, s)
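# Rate limit by fp. Note: fp is derived from the presented token, so every distinct guess
# gets its own bucket, failure counter, and lockout key. This throttles reuse of one token
# but does not bound brute force across many tokens; that needs an additional per-IP limit.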
limiter = TokenBucketLimiter(r=r, s=s)
allowed = await limiter.allow(key_suffix=TokenBucketLimiter.key_for_admin(fp), rate=s.ADMIN_RL_RATE, burst=s.ADMIN_RL_BURST)
if not allowed:
raise ForbiddenError("rate_limited", "rate_limited")
# lockout check
await _check_lockout(r, s, fp)
# compute failures so far to apply delay
fail_key = (s.REDIS_KEY_PREFIX + f"admin:fail:{fp}").encode("utf-8")
cur_raw = await with_retry(lambda: r.get(fail_key))
cur_fails = int(cur_raw.decode("utf-8")) if cur_raw else 0
await _apply_progressive_delay(s, cur_fails)
# validate token constant-time
if not ct_eq(token, s.ADMIN_TOKEN):
fails = await _record_failure(r, s, fp)
if fails >= s.ADMIN_MAX_FAILS:
await _lockout(r, s, fp)
raise ForbiddenError("admin_locked", "admin_locked")
raise UnauthorizedError("invalid_token", "invalid_token")
# success: reset failures and mint session
await _reset_failures(r, s, fp)
sid2 = await _set_session(r, s, fp)
response.set_cookie(
SESSION_COOKIE_NAME,
sid2,
httponly=True,
secure=s.COOKIE_SECURE,
samesite=s.COOKIE_SAMESITE,
path="/",
)
END_FILE
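The delay applied before each token check grows exponentially with the number of failures already recorded for that fingerprint. The schedule below assumes `ADMIN_PROGRESSIVE_DELAY_BASE_MS = 50`, an illustrative value rather than the repo default. `progressive_delay_ms` is a hypothetical standalone copy of the formula in `_apply_progressive_delay`.

```python
def progressive_delay_ms(base_ms: int, fails: int, cap_ms: int = 3000) -> int:
    """Standalone copy of the _apply_progressive_delay formula: base * 2**min(fails, 6), capped."""
    if fails <= 0:
        return 0
    return min(cap_ms, base_ms * (2 ** min(fails, 6)))


# Delay before an attempt, indexed by how many failures are already recorded.
schedule = [progressive_delay_ms(50, fails) for fails in range(8)]
```

Because the delay depends on the failures recorded so far, the n-th consecutive bad attempt waits `progressive_delay_ms(base, n - 1)`. With this base, the 3000 ms cap applies once six failures are on record.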
BEGIN_FILE: src/app/services/init.py
"""
Business services layer.
- plans: plan lookup + caching
- subscriptions: atomic grant, refunds, extensions, reporting
"""
END_FILE
BEGIN_FILE: src/app/services/plans.py
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Optional
from cachetools import TTLCache
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models import Plan
@dataclass(frozen=True)
class PlanInfo:
code: str
price_amount: int
price_currency: str
# Small TTL cache of immutable PlanInfo snapshots. ORM rows are not cached: a Plan instance
# outlives its session and, with SQLAlchemy's default expire_on_commit=True, has its attributes
# expired at commit, so later reads from the cache can raise DetachedInstanceError.
# Every entry uses the same fixed TTL (no jitter).
_cache: TTLCache[str, PlanInfo] = TTLCache(maxsize=256, ttl=60)
async def get_plan_by_code(session: AsyncSession, code: str) -> Optional[PlanInfo]:
"""Lookup plan by code with TTL cache.
Note:
Cache is TTL-based; for correctness under frequent plan edits, keep TTL short.
"""
code2 = code.strip()
cached = _cache.get(code2)
if cached is not None:
return cached
row = (await session.execute(select(Plan).where(Plan.code == code2))).scalar_one_or_none()
if not row:
return None
info = PlanInfo(code=row.code, price_amount=row.price_amount, price_currency=row.price_currency)
_cache[code2] = info
return info
async def seed_plan_if_missing(session: AsyncSession, *, code: str, amount: int, currency: str) -> None:
"""Seed plan row for local/dev convenience."""
row = (await session.execute(select(Plan).where(Plan.code == code))).scalar_one_or_none()
if row:
return
now = datetime.now(timezone.utc)
session.add(Plan(code=code, price_amount=amount, price_currency=currency.upper(), created_at=now))
END_FILE
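`TTLCache` gives every entry the same fixed TTL, so entries loaded together also expire together. If stampede smoothing is wanted, one option is a per-entry jittered expiry that stores immutable snapshots rather than ORM rows. The sketch below is hypothetical (`JitteredTTLCache`, `PlanSnapshot`, and the 10% jitter are illustrative choices). It uses only the standard library and is single-threaded, which is sufficient within one asyncio event loop.

```python
from __future__ import annotations

import random
from dataclasses import dataclass
from typing import Callable, Dict, Generic, Optional, Tuple, TypeVar

V = TypeVar("V")


@dataclass(frozen=True)
class PlanSnapshot:
    """Immutable copy of the plan fields callers need (safe to share across sessions)."""

    code: str
    price_amount: int
    price_currency: str


class JitteredTTLCache(Generic[V]):
    """Each entry lives ttl_sec * (1 +/- jitter), so entries filled together expire apart."""

    def __init__(self, ttl_sec: float, jitter: float, clock: Callable[[], float],
                 rng: Optional[random.Random] = None) -> None:
        self._ttl_sec = ttl_sec
        self._jitter = jitter
        self._clock = clock
        self._rng = rng if rng is not None else random.Random()
        self._items: Dict[str, Tuple[V, float]] = {}

    def get(self, key: str) -> Optional[V]:
        hit = self._items.get(key)
        if hit is None:
            return None
        value, expires_at = hit
        if self._clock() >= expires_at:
            del self._items[key]  # expired: the caller reloads from the database
            return None
        return value

    def put(self, key: str, value: V) -> None:
        factor = 1.0 + self._rng.uniform(-self._jitter, self._jitter)
        self._items[key] = (value, self._clock() + self._ttl_sec * factor)


now = [0.0]
cache: JitteredTTLCache[PlanSnapshot] = JitteredTTLCache(
    ttl_sec=60, jitter=0.1, clock=lambda: now[0], rng=random.Random(0)
)
cache.put("pro_monthly", PlanSnapshot("pro_monthly", 999, "USD"))

now[0] = 53.0
early = cache.get("pro_monthly")  # expiry lies somewhere in [54, 66] seconds
now[0] = 67.0
late = cache.get("pro_monthly")
```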
BEGIN_FILE: src/app/services/subscriptions.py
from __future__ import annotations
import json
from datetime import datetime, timedelta, timezone
from typing import Any, Optional
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from app.errors import BadRequestError, ConflictError, NotFoundError
from app.models import AuditEvent, Payment, Subscription, User
from app.services.plans import get_plan_by_code
def utc_now() -> datetime:
return datetime.now(timezone.utc)
async def ensure_user(session: AsyncSession, tg_user_id: int) -> None:
"""Ensure user exists."""
row = (await session.execute(select(User).where(User.tg_user_id == tg_user_id))).scalar_one_or_none()
if row:
return
session.add(User(tg_user_id=tg_user_id, created_at=utc_now()))
def validate_amount_matches_plan(*, plan_amount: int, plan_currency: str, amount: int, currency: str) -> None:
"""Validate webhook amount/currency matches plan.
Raises:
BadRequestError: amount_mismatch
"""
exp_cur = str(plan_currency).upper()
got_cur = str(currency).upper()
if int(plan_amount) != int(amount) or exp_cur != got_cur:
raise BadRequestError("amount_mismatch", f"expected {plan_amount} {exp_cur}, got {amount} {got_cur}")
async def _insert_audit(session: AsyncSession, *, event_type: str, tg_user_id: int | None, details: dict[str, Any]) -> None:
session.add(
AuditEvent(
event_type=event_type,
tg_user_id=tg_user_id,
details_json=json.dumps(details, separators=(",", ":"), ensure_ascii=False),
created_at=utc_now(),
)
)
async def grant_subscription_for_payment(session: AsyncSession, normalized: dict[str, Any]) -> Subscription:
"""Grant subscription atomically from a confirmed payment payload.
Atomicity:
Payment + Subscription are created/updated within the same transaction. Any failure rolls back both.
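Idempotency:
Unique constraint on (provider, provider_payment_id). Duplicate => ConflictError.
Note: any IntegrityError raised while flushing is reported as duplicate_payment, including one
caused by a concurrent first insert of the same user via ensure_user().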
"""
tg_user_id = int(normalized["tg_user_id"])
provider = str(normalized["provider"])
provider_payment_id = str(normalized["provider_payment_id"])
amount = int(normalized["amount"])
currency = str(normalized["currency"]).upper()
plan_code = str(normalized["plan_code"])
raw_json = str(normalized["raw_json"])
plan = await get_plan_by_code(session, plan_code)
if plan is None:
raise BadRequestError("unknown_plan", "unknown_plan")
validate_amount_matches_plan(plan_amount=plan.price_amount, plan_currency=plan.price_currency, amount=amount, currency=currency)
await ensure_user(session, tg_user_id)
now = utc_now()
try:
payment = Payment(
provider=provider,
provider_payment_id=provider_payment_id,
tg_user_id=tg_user_id,
plan_code=plan_code,
amount=amount,
currency=currency,
status="confirmed",
raw_json=raw_json,
created_at=now,
refunded_at=None,
)
session.add(payment)
await session.flush() # within transaction; rollback on any error => no orphan
sub = Subscription(
tg_user_id=tg_user_id,
plan_code=plan_code,
starts_at=now,
ends_at=now + timedelta(days=30),
status="active",
payment_id=payment.id,
created_at=now,
)
session.add(sub)
await session.flush()
await _insert_audit(
session,
event_type="grant_subscription",
tg_user_id=tg_user_id,
details={"provider": provider, "provider_payment_id": provider_payment_id, "payment_id": payment.id, "subscription_id": sub.id},
)
return sub
except IntegrityError as e:
raise ConflictError("duplicate_payment", "duplicate_payment") from e
async def refund_payment(session: AsyncSession, *, payment_id: int, reason: str) -> dict[str, Any]:
"""Refund payment and revoke linked subscription (best-effort logical revoke)."""
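# Lock the payment row (FOR UPDATE where the dialect supports it) so concurrent refunds of the
# same payment serialize instead of both passing the refunded_at check below.
p = (await session.execute(select(Payment).where(Payment.id == payment_id).with_for_update())).scalar_one_or_none()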
if not p:
raise NotFoundError("payment_not_found", "payment_not_found")
if p.refunded_at:
return {"payment_id": payment_id, "refunded": True, "revoked_subscription_id": None}
now = utc_now()
p.refunded_at = now
p.status = "refunded"
sub = (await session.execute(select(Subscription).where(Subscription.payment_id == payment_id))).scalar_one_or_none()
revoked_id: Optional[int] = None
if sub:
sub.status = "revoked"
revoked_id = sub.id
await _insert_audit(
session,
event_type="refund",
tg_user_id=int(p.tg_user_id),
details={"payment_id": payment_id, "reason": reason, "revoked_subscription_id": revoked_id},
)
return {"payment_id": payment_id, "refunded": True, "revoked_subscription_id": revoked_id}
END_FILE
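The atomicity and idempotency rules of `grant_subscription_for_payment` can be exercised with the standard library alone. This sketch uses `sqlite3` instead of SQLAlchemy, with trimmed-down, hypothetical table shapes. Three properties are shown. The payment and subscription inserts share one transaction. A repeated `(provider, provider_payment_id)` violates the unique constraint. A failure between the two inserts leaves no orphan payment.

```python
import sqlite3
from datetime import datetime, timedelta, timezone

SCHEMA = """
CREATE TABLE payments (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    provider TEXT NOT NULL,
    provider_payment_id TEXT NOT NULL,
    tg_user_id INTEGER NOT NULL,
    UNIQUE (provider, provider_payment_id)
);
CREATE TABLE subscriptions (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    tg_user_id INTEGER NOT NULL,
    ends_at TEXT NOT NULL,
    payment_id INTEGER NOT NULL REFERENCES payments(id)
);
"""


class DuplicatePayment(Exception):
    """Raised when (provider, provider_payment_id) was already recorded."""


def grant(conn: sqlite3.Connection, provider: str, provider_payment_id: str,
          tg_user_id: int, *, simulate_crash: bool = False) -> int:
    """Insert payment + subscription in one transaction; any failure rolls back both."""
    now = datetime.now(timezone.utc)
    try:
        with conn:  # commits on success, rolls back on any exception
            cur = conn.execute(
                "INSERT INTO payments (provider, provider_payment_id, tg_user_id) VALUES (?, ?, ?)",
                (provider, provider_payment_id, tg_user_id),
            )
            payment_id = cur.lastrowid
            if simulate_crash:
                raise RuntimeError("crash between payment and subscription insert")
            cur = conn.execute(
                "INSERT INTO subscriptions (tg_user_id, ends_at, payment_id) VALUES (?, ?, ?)",
                (tg_user_id, (now + timedelta(days=30)).isoformat(), payment_id),
            )
            return int(cur.lastrowid)
    except sqlite3.IntegrityError as e:
        raise DuplicatePayment(f"{provider}:{provider_payment_id}") from e


conn = sqlite3.connect(":memory:")
conn.executescript(SCHEMA)

first_sub_id = grant(conn, "demo_pay", "pay_001", 42)

try:
    grant(conn, "demo_pay", "pay_001", 42)  # provider retry of the same payment
    duplicate_rejected = False
except DuplicatePayment:
    duplicate_rejected = True

try:
    grant(conn, "demo_pay", "pay_002", 42, simulate_crash=True)
except RuntimeError:
    pass

payments, subscriptions = conn.execute(
    "SELECT (SELECT COUNT(*) FROM payments), (SELECT COUNT(*) FROM subscriptions)"
).fetchone()
```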
BEGIN_FILE: src/app/routes/init.py
"""
FastAPI HTTP route handlers.
Modules:
- health: liveness/readiness probes
- webhooks: payment provider webhook endpoints
- admin: administrative operations plane
"""
END_FILE
BEGIN_FILE: src/app/routes/health.py
from __future__ import annotations
import logging
from fastapi import APIRouter, Request
from app import __version__
from app.db import ping_db
from app.infra.redis import get_redis
from app.infra.resilience import with_retry
from app.schemas import HealthResponse
from app.settings import Settings
logger = logging.getLogger(__name__)
router = APIRouter(tags=["health"])
@router.get("/health/livez", response_model=HealthResponse, summary="Liveness probe")
async def livez(request: Request) -> dict:
rid = getattr(request.state, "request_id", "unknown")
return {"ok": True, "result": {"status": "ok", "version": __version__}, "request_id": rid}
@router.get("/health/readyz", response_model=HealthResponse, summary="Readiness probe")
async def readyz(request: Request) -> dict:
rid = getattr(request.state, "request_id", "unknown")
s: Settings = request.app.state.settings
checks: dict[str, str] = {"db": "ok", "redis": "ok"}
try:
await ping_db(s)
except Exception:
logger.warning("db_health_check_failed", exc_info=True)
checks["db"] = "fail"
try:
r = get_redis()
await with_retry(lambda: r.ping())
except Exception:
logger.warning("redis_health_check_failed", exc_info=True)
checks["redis"] = "fail"
status = "ok" if checks["db"] == "ok" and checks["redis"] == "ok" else "degraded"
return {"ok": True, "result": {"status": status, "version": __version__}, "request_id": rid}
END_FILE
BEGIN_FILE: src/app/routes/webhooks.py
from __future__ import annotations
import asyncio
import logging
from fastapi import APIRouter, Request
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import get_session
from app.errors import BadRequestError, ServiceUnavailableError
from app.infra.rate_limit import TokenBucketLimiter
from app.infra.redis import get_redis
from app.infra.resilience import CircuitBreaker, with_db_retry
from app.schemas import WebhookConfirmResponse
from app.security.webhook import verify_provider_payload
from app.security.webhook_sig import verify_signature
from app.services.subscriptions import grant_subscription_for_payment
from app.settings import ProviderAllowlist, Settings
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/v1/webhooks", tags=["webhooks"])
_db_sem: asyncio.Semaphore | None = None
_db_breaker = CircuitBreaker(failure_threshold=5, recovery_timeout_sec=30)
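def _client_ip(request: Request) -> str:
# X-Forwarded-For is not parsed here: its leftmost entry is client-controlled, so trusting it
# would let a caller pick a fresh rate-limit bucket per request. Behind a proxy, rely on
# uvicorn's proxy-header handling (--forwarded-allow-ips) to set request.client from trusted hops.
return request.client.host if request.client else "unknown"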
@router.post(
"/{provider}/payment_confirm",
response_model=WebhookConfirmResponse,
summary="Payment confirmation webhook",
description="Validates signature, enforces replay protection + rate limiting, and grants subscription atomically.",
)
async def payment_confirm(provider: str, request: Request) -> dict:
s: Settings = request.app.state.settings
allow = ProviderAllowlist.from_settings(s)
p = provider.strip().lower()
if p not in allow.providers:
raise BadRequestError("unknown_provider", "unknown_provider")
global _db_sem
if _db_sem is None:
_db_sem = asyncio.Semaphore(s.DB_CONCURRENCY)
acquired = False
try:
try:
await asyncio.wait_for(_db_sem.acquire(), timeout=s.DB_SEM_TIMEOUT_SEC)
acquired = True
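except asyncio.TimeoutError:
# asyncio.wait_for raises asyncio.TimeoutError, which is the builtin TimeoutError only on Python 3.11+
raise ServiceUnavailableError("service_unavailable", "try_again_later") from None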
# rate limit webhook early
r = get_redis()
limiter = TokenBucketLimiter(r=r, s=s)
ip = _client_ip(request)
ok_rl = await limiter.allow(
key_suffix=TokenBucketLimiter.key_for_webhook(p, ip),
rate=s.WEBHOOK_RL_RATE,
burst=s.WEBHOOK_RL_BURST,
cost=1,
)
if not ok_rl:
raise ServiceUnavailableError("rate_limited", "rate_limited")
body = await request.body()
await verify_signature(request, r, s, body)
normalized = verify_provider_payload(p, body)
if not await _db_breaker.allow():
raise ServiceUnavailableError("db_unavailable", "try_again_later")
async def _op() -> dict:
async for session in get_session():
assert isinstance(session, AsyncSession)
async with session.begin():
sub = await grant_subscription_for_payment(session, normalized)
return {"sub_id": sub.id, "payment_id": int(sub.payment_id or 0), "status": sub.status}
try:
out = await with_db_retry(_op)
await _db_breaker.on_success()
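except Exception as exc:
# 4xx ApiErrors (duplicate_payment, amount_mismatch, unknown_plan) are client errors,
# not DB failures; only errors without a 4xx status_code count toward opening the breaker.
is_client_error = 400 <= getattr(exc, "status_code", 500) < 500
if not is_client_error:
await _db_breaker.on_failure()
raise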
rid = getattr(request.state, "request_id", "unknown")
return {
"ok": True,
"result": {"subscription_id": int(out["sub_id"]), "payment_id": int(out["payment_id"]), "status": str(out["status"])},
"request_id": rid,
}
finally:
if acquired:
_db_sem.release()
END_FILE
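`payment_confirm` bounds concurrent work with an `asyncio.Semaphore` acquired under `asyncio.wait_for`. Below is a minimal standalone version of the same acquire-with-deadline pattern; `run_bounded` and `demo` are hypothetical. The sketch catches `asyncio.TimeoutError`, which works on every Python version because that class only became an alias of the builtin `TimeoutError` in 3.11. It also acquires before entering the `try` block, so `release()` can only run for a slot that is actually held.

```python
import asyncio
from typing import Awaitable, Callable, List


async def run_bounded(sem: asyncio.Semaphore, timeout_sec: float,
                      work: Callable[[], Awaitable[str]]) -> str:
    """Take a slot within timeout_sec, run work, and release only a slot we actually hold."""
    try:
        await asyncio.wait_for(sem.acquire(), timeout=timeout_sec)
    except asyncio.TimeoutError:  # same class as builtin TimeoutError only on Python 3.11+
        return "try_again_later"
    try:
        return await work()
    finally:
        sem.release()


async def demo() -> List[str]:
    sem = asyncio.Semaphore(1)
    started = asyncio.Event()
    gate = asyncio.Event()

    async def slow() -> str:
        started.set()  # this task now holds the only slot
        await gate.wait()
        return "done"

    async def fast() -> str:
        return "done"

    first = asyncio.create_task(run_bounded(sem, 1.0, slow))
    await started.wait()
    second = await run_bounded(sem, 0.05, fast)  # no slot frees up within 50 ms
    gate.set()
    return [await first, second]


results = asyncio.run(demo())
```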
BEGIN_FILE: src/app/routes/admin.py
from __future__ import annotations
import os
from typing import Any, Sequence
from fastapi import APIRouter, Query, Request, Response
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import get_session
from app.infra.redis import get_redis
from app.schemas import DashboardStatsResponse, ErrorResponse, LoginResponse, RefundIn, RefundResponse
from app.security.admin import authenticate_admin, set_csrf_cookie, verify_csrf
from app.services.subscriptions import refund_payment
from app.settings import Settings
router = APIRouter(prefix="/v1/admin", tags=["admin"])
def _new_csrf() -> str:
# 256-bit CSRF token
return os.urandom(32).hex()
async def asyncio_gather_exec(session: AsyncSession, stmts: Sequence[Any]) -> tuple[Any, ...]:
"""Execute multiple scalar queries on one session, sequentially.
An AsyncSession does not support concurrent operations, so the statements are awaited
one at a time rather than with asyncio.gather (the name is kept for existing callers).
Args:
session: Active SQLAlchemy async session.
stmts: Sequence of SELECT statements returning scalar values.
Returns:
Tuple of scalar results in the same order as input statements.
"""
results: list[Any] = []
for stmt in stmts:
results.append((await session.execute(stmt)).scalar_one())
return tuple(results)
@router.post(
"/login",
response_model=LoginResponse,
responses={
401: {"model": ErrorResponse, "description": "Missing/invalid admin authentication"},
403: {"model": ErrorResponse, "description": "Admin locked or rate limited"},
},
summary="Admin login",
description="Establishes server-side session cookie and CSRF cookie (double-submit).",
)
async def login(request: Request, response: Response) -> dict:
s: Settings = request.app.state.settings
r = get_redis()
await authenticate_admin(request, response, r, s)
csrf = _new_csrf()
set_csrf_cookie(response, s, csrf)
rid = getattr(request.state, "request_id", "unknown")
return {"ok": True, "result": {"csrf_token": csrf}, "request_id": rid}
@router.get(
"/dashboard_stats",
response_model=DashboardStatsResponse,
responses={
401: {"model": ErrorResponse, "description": "Missing/invalid admin authentication"},
403: {"model": ErrorResponse, "description": "Admin locked or rate limited"},
},
summary="Get dashboard stats",
)
async def dashboard_stats(request: Request, response: Response) -> dict:
s: Settings = request.app.state.settings
r = get_redis()
await authenticate_admin(request, response, r, s)
rid = getattr(request.state, "request_id", "unknown")
async for session in get_session():
assert isinstance(session, AsyncSession)
from app.models import Payment, Subscription
total_q = select(func.count(Payment.id))
confirmed_q = select(func.count(Payment.id)).where(Payment.status == "confirmed")
failed_q = select(func.count(Payment.id)).where(Payment.status == "failed")
active_subs_q = select(func.count(Subscription.id)).where(Subscription.status == "active", Subscription.ends_at > func.now())
totals, confirmed, failed, active = await asyncio_gather_exec(session, [total_q, confirmed_q, failed_q, active_subs_q])
return {
"ok": True,
"result": {
"payments_total": int(totals or 0),
"payments_confirmed": int(confirmed or 0),
"payments_failed": int(failed or 0),
"active_subscriptions": int(active or 0),
"request_id": rid,
},
"request_id": rid,
}
@router.post(
"/refund",
response_model=RefundResponse,
responses={
401: {"model": ErrorResponse, "description": "Missing/invalid admin authentication"},
403: {"model": ErrorResponse, "description": "Admin locked, rate limited, or CSRF failed"},
},
summary="Refund a payment",
)
async def refund(request: Request, response: Response, payload: RefundIn) -> dict:
s: Settings = request.app.state.settings
r = get_redis()
await authenticate_admin(request, response, r, s)
verify_csrf(request)
rid = getattr(request.state, "request_id", "unknown")
async for session in get_session():
assert isinstance(session, AsyncSession)
async with session.begin():
out = await refund_payment(session, payment_id=int(payload.payment_id), reason=str(payload.reason))
return {"ok": True, "result": out, "request_id": rid}
END_FILE
BEGIN_FILE: src/app/main.py
from __future__ import annotations
import logging
from contextlib import asynccontextmanager
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from app.errors import ApiError
from app.logging_config import configure_logging
from app.middleware import build_middleware_stack
from app.db import close_engine, init_engine, migrate
from app.infra.redis import close_redis_pool, init_redis_pool
from app.routes.admin import router as admin_router
from app.routes.health import router as health_router
from app.routes.webhooks import router as webhooks_router
from app.settings import Settings
logger = logging.getLogger(__name__)
def _error_response(err: ApiError, request_id: str) -> JSONResponse:
return JSONResponse(
status_code=err.status_code,
content={"ok": False, "error": {"code": err.code, "reason": err.reason, "request_id": request_id}},
)
@asynccontextmanager
async def lifespan(app: FastAPI):
s = Settings()
app.state.settings = s
configure_logging(s.LOG_LEVEL)
init_engine(s)
init_redis_pool(s)
# In local env, auto-migrate for convenience (safe for dev only).
if s.APP_ENV == "local":
try:
await migrate(s)
except Exception:
logger.warning("auto_migrate_failed", exc_info=True)
yield
await close_engine()
await close_redis_pool()
def create_app() -> FastAPI:
app = FastAPI(title="ApexAI SubManager", version="2.1.10", lifespan=lifespan)
# middleware
s = Settings(_env_file=None) # minimal parse; secrets still required by env at runtime
for mw in build_middleware_stack(s):
app.add_middleware(mw) # type: ignore[arg-type]
app.include_router(health_router)
app.include_router(webhooks_router)
app.include_router(admin_router)
@app.exception_handler(ApiError)
async def api_error_handler(request: Request, exc: ApiError):
rid = getattr(request.state, "request_id", "unknown")
return _error_response(exc, rid)
@app.exception_handler(Exception)
async def unhandled_error_handler(request: Request, exc: Exception):
rid = getattr(request.state, "request_id", "unknown")
logger.error("unhandled_exception", exc_info=True)
return JSONResponse(
status_code=500,
content={"ok": False, "error": {"code": "internal_error", "reason": "internal_error", "request_id": rid}},
)
return app
app = create_app()
def main() -> None:
s = Settings()
uvicorn.run("app.main:app", host="0.0.0.0", port=s.PORT, log_level=s.LOG_LEVEL.lower())
if __name__ == "__main__":
main()
END_FILE
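Why the try/finally around yield in lifespan matters: a generator decorated with @asynccontextmanager only runs the code after yield on a clean exit. If an exception is thrown in at the yield, anything after it is skipped unless it sits in a finally block. The stdlib-only sketch below demonstrates this; the "startup"/"shutdown" events are stand-ins for init_engine/init_redis_pool and close_engine/close_redis_pool, not the app's real resources.

```python
import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterator, List


def make_lifespan(events: List[str], use_finally: bool):
    @asynccontextmanager
    async def lifespan() -> AsyncIterator[None]:
        events.append("startup")  # stands in for init_engine / init_redis_pool
        if use_finally:
            try:
                yield
            finally:
                events.append("shutdown")  # runs on clean exit and on error
        else:
            yield
            events.append("shutdown")  # skipped when an exception is thrown in at yield

    return lifespan


async def _serve(lifespan) -> None:
    async with lifespan():
        raise RuntimeError("server crashed")


def run(use_finally: bool) -> List[str]:
    events: List[str] = []
    try:
        asyncio.run(_serve(make_lifespan(events, use_finally)))
    except RuntimeError:
        pass  # the crash still propagates; only the recorded events are inspected
    return events
```

With use_finally=True the shutdown step is recorded even though the body crashed; without it, only startup is recorded and the pools would leak.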
BEGIN_FILE: tests/conftest.py
from __future__ import annotations

from collections.abc import Iterator

import pytest
from fastapi.testclient import TestClient


@pytest.fixture(autouse=True)
def _set_env(monkeypatch: pytest.MonkeyPatch) -> None:
    monkeypatch.setenv("APP_ENV", "local")
    monkeypatch.setenv("PORT", "8080")
    monkeypatch.setenv("LOG_LEVEL", "INFO")
    monkeypatch.setenv("DATABASE_URL", "postgresql+asyncpg://postgres:postgres@db:5432/postgres")
    monkeypatch.setenv("REDIS_URL", "redis://localhost:6379/0")
    monkeypatch.setenv("REDIS_KEY_PREFIX", "test:submanager:")
    monkeypatch.setenv("ADMIN_TOKEN", "A" * 40)
    monkeypatch.setenv("ADMIN_TOKEN_FINGERPRINT_SECRET", "B" * 40)
    monkeypatch.setenv("WEBHOOK_SIGNATURE_SECRETS", "C" * 40)
    monkeypatch.setenv("COOKIE_SECURE", "false")
    monkeypatch.setenv("COOKIE_SAMESITE", "lax")
    monkeypatch.setenv("ALLOWED_PROVIDERS", "stripe,paypal")
    monkeypatch.setenv("WEBHOOK_TS_SKEW_SEC", "300")
    monkeypatch.setenv("WEBHOOK_NONCE_TTL_SEC", "900")
    monkeypatch.setenv("WEBHOOK_RL_RATE", "9999")
    monkeypatch.setenv("WEBHOOK_RL_BURST", "9999")


@pytest.fixture()
def client() -> Iterator[TestClient]:
    # Import lazily: importing app.main builds a module-level app and validates Settings,
    # so the environment set by _set_env must already be in place.
    from app.main import create_app

    app = create_app()
    # Entering the context runs the lifespan, which sets app.state.settings for the routes.
    with TestClient(app) as c:
        yield c
END_FILE
BEGIN_FILE: tests/unit/test_compare.py
from app.security.compare import ct_eq


def test_ct_eq_true() -> None:
    assert ct_eq("abc", "abc") is True


def test_ct_eq_false() -> None:
    assert ct_eq("abc", "abd") is False


def test_ct_eq_empty() -> None:
    assert ct_eq("", "") is True
END_FILE
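The body of app.security.compare.ct_eq is not shown in this section. A common way to get this behaviour in Python is hmac.compare_digest; note that compare_digest raises TypeError for str arguments containing non-ASCII characters, so a sketch that accepts arbitrary text compares UTF-8 bytes instead. ct_eq_sketch is a hypothetical stand-in, not the repo's implementation.

```python
import hmac


def ct_eq_sketch(a: str, b: str) -> bool:
    """Constant-time string comparison (hypothetical stand-in for ct_eq)."""
    # compare_digest only accepts ASCII-only str, so compare the UTF-8 bytes.
    return hmac.compare_digest(a.encode("utf-8"), b.encode("utf-8"))
```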
BEGIN_FILE: tests/unit/test_settings_validation.py
import pytest
from pydantic import ValidationError

from app.settings import Settings


def test_settings_requires_secrets(monkeypatch: pytest.MonkeyPatch) -> None:
    monkeypatch.delenv("ADMIN_TOKEN", raising=False)
    with pytest.raises(ValidationError):
        Settings(_env_file=None)


def test_settings_port_range(monkeypatch: pytest.MonkeyPatch) -> None:
    monkeypatch.setenv("PORT", "70000")
    with pytest.raises(ValidationError):
        Settings(_env_file=None)


def test_settings_cookie_samesite_validation(monkeypatch: pytest.MonkeyPatch) -> None:
    monkeypatch.setenv("COOKIE_SAMESITE", "bogus")
    with pytest.raises(ValidationError):
        Settings(_env_file=None)
END_FILE
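The three tests above pin down the rules Settings must enforce: required secrets, a valid TCP port, and a known SameSite value. The real checks live in app/settings.py (pydantic-settings, not shown here). This stdlib sketch restates the same rules as a plain function; the 1..65535 bounds, the "8080" and "lax" defaults, and validate_env itself are assumptions for illustration.

```python
from typing import List, Mapping

REQUIRED_SECRETS = ("ADMIN_TOKEN", "ADMIN_TOKEN_FINGERPRINT_SECRET", "WEBHOOK_SIGNATURE_SECRETS")
ALLOWED_SAMESITE = {"lax", "strict", "none"}  # the three standard SameSite values


def validate_env(env: Mapping[str, str]) -> List[str]:
    """Return a list of problems; an empty list means the env passes these checks (hypothetical)."""
    errors: List[str] = []
    for name in REQUIRED_SECRETS:
        if not env.get(name):
            errors.append(f"{name}: required")
    try:
        port = int(env.get("PORT", "8080"))  # default of 8080 is an assumption
        if not 1 <= port <= 65535:
            errors.append("PORT: must be in 1..65535")
    except ValueError:
        errors.append("PORT: not an integer")
    if env.get("COOKIE_SAMESITE", "lax").lower() not in ALLOWED_SAMESITE:
        errors.append("COOKIE_SAMESITE: must be lax, strict or none")
    return errors
```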
BEGIN_FILE: tests/unit/test_rate_limit_token_bucket.py
from app.infra.rate_limit import TOKEN_BUCKET_LUA


def test_lua_present() -> None:
    assert "token" in TOKEN_BUCKET_LUA.lower()
    assert "redis.call" in TOKEN_BUCKET_LUA.lower()
END_FILE
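TOKEN_BUCKET_LUA itself is not shown here; running the bucket as a Redis Lua script is what makes read, refill and deduct a single atomic step across workers. The in-process model below is a sketch of standard token-bucket arithmetic using the same parameter names as TokenBucketLimiter.allow (rate, burst, cost), with an explicit now so behaviour is deterministic. It is not a transcription of the Lua script.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class TokenBucket:
    rate: float  # tokens added per second
    burst: float  # maximum tokens the bucket can hold
    tokens: Optional[float] = None  # None until first use, then the current balance
    last: float = 0.0  # time of the previous refill, in seconds

    def allow(self, now: float, cost: float = 1.0) -> bool:
        if self.tokens is None:
            self.tokens, self.last = self.burst, now  # a new bucket starts full
        # Refill for the elapsed time, never beyond the burst capacity.
        elapsed = max(0.0, now - self.last)
        self.tokens = min(self.burst, self.tokens + elapsed * self.rate)
        self.last = now
        if self.tokens >= cost:
            self.tokens -= cost
            return True
        return False
```

With rate=1 and burst=1 this reproduces the expectation in test_rate_limit_effectiveness: one request passes, an immediate second one is denied, and a token is back after one second.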
BEGIN_FILE: tests/unit/test_webhook_sig.py
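import hashlib
import hmac
import time

from app.security.webhook_sig import _hmac_hex


def test_hmac_hex_len() -> None:
    # An HMAC-SHA256 hex digest is 64 characters long.
    assert len(_hmac_hex("k" * 32, b"msg")) == 64


def test_message_format_matches_service() -> None:
    ts = int(time.time())
    nonce = "n123"
    body = b"{}"
    # Signed message layout: "<ts>.<nonce>." followed by the raw body bytes.
    msg = f"{ts}.{nonce}.".encode("utf-8") + body
    sig = hmac.new(("C" * 40).encode("utf-8"), msg, hashlib.sha256).hexdigest()
    assert isinstance(sig, str)
    assert len(sig) == 64
    # The service helper must produce the same digest for the same key and message.
    assert _hmac_hex("C" * 40, msg) == sig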
END_FILE
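Putting the pieces together, a verifier for this scheme checks the timestamp against the allowed skew (WEBHOOK_TS_SKEW_SEC is 300 in conftest) and then compares the HMAC-SHA256 of "<ts>.<nonce>." + body in constant time. The sketch below assumes WEBHOOK_SIGNATURE_SECRETS may hold several keys for rotation (the plural name suggests this, but it is not confirmed here). sign and verify are hypothetical helpers, and nonce replay checking is a separate step, not shown.

```python
import hashlib
import hmac
import time
from typing import Optional, Sequence


def sign(secret: str, ts: int, nonce: str, body: bytes) -> str:
    # Same message layout as the tests: b"<ts>.<nonce>." + raw body.
    msg = f"{ts}.{nonce}.".encode("utf-8") + body
    return hmac.new(secret.encode("utf-8"), msg, hashlib.sha256).hexdigest()


def verify(
    keys: Sequence[str],
    ts: int,
    nonce: str,
    body: bytes,
    sig: str,
    skew_sec: int = 300,
    now: Optional[int] = None,
) -> bool:
    now = int(time.time()) if now is None else now
    # Reject timestamps too far in the past or the future.
    if abs(now - ts) > skew_sec:
        return False
    # Try every configured key so an old secret keeps working during a rollover.
    provided = sig.encode("utf-8")
    return any(hmac.compare_digest(sign(k, ts, nonce, body).encode("ascii"), provided) for k in keys)
```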
BEGIN_FILE: tests/unit/test_circuit_breaker.py
import pytest

from app.infra.resilience import CircuitBreaker


@pytest.mark.asyncio
async def test_breaker_allows_initially() -> None:
    cb = CircuitBreaker(failure_threshold=2, recovery_timeout_sec=1)
    assert await cb.allow() is True


@pytest.mark.asyncio
async def test_breaker_opens_after_failures() -> None:
    cb = CircuitBreaker(failure_threshold=2, recovery_timeout_sec=60)
    await cb.on_failure()
    assert await cb.allow() is True
    await cb.on_failure()
    assert await cb.allow() is False


@pytest.mark.asyncio
async def test_breaker_resets_on_success() -> None:
    cb = CircuitBreaker(failure_threshold=1, recovery_timeout_sec=60)
    await cb.on_failure()
    assert await cb.allow() is False
    await cb.on_success()
    assert await cb.allow() is True
END_FILE
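The tests fix the observable contract of CircuitBreaker: closed at first, open after failure_threshold failures, closed again after on_success. app/infra/resilience.py is not shown, so the class below is a hypothetical sketch of that contract with an injectable clock. After recovery_timeout_sec it lets calls through again (a simple half-open state), and a further failure re-opens it. The methods are async only to match the tested interface; none of them awaits, so each runs atomically on the event loop.

```python
import asyncio
import time
from typing import Callable, Optional


class CircuitBreakerSketch:
    def __init__(
        self,
        failure_threshold: int,
        recovery_timeout_sec: float,
        clock: Callable[[], float] = time.monotonic,
    ) -> None:
        self.failure_threshold = failure_threshold
        self.recovery_timeout_sec = recovery_timeout_sec
        self._clock = clock
        self._failures = 0
        self._opened_at: Optional[float] = None  # None means the breaker is closed

    async def allow(self) -> bool:
        if self._opened_at is None:
            return True
        # Open: let calls through again once the recovery timeout has elapsed.
        return self._clock() - self._opened_at >= self.recovery_timeout_sec

    async def on_success(self) -> None:
        self._failures = 0
        self._opened_at = None

    async def on_failure(self) -> None:
        self._failures += 1
        if self._failures >= self.failure_threshold:
            self._opened_at = self._clock()  # open, or re-open after a failed trial
```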
BEGIN_FILE: tests/unit/test_admin_sessions.py
import fakeredis.aioredis
import pytest
from fastapi import Response
from redis.asyncio import Redis

from app.security.admin import SESSION_COOKIE_NAME, authenticate_admin
from app.settings import Settings


@pytest.mark.asyncio
async def test_admin_session_cookie_is_opaque() -> None:
    s = Settings(_env_file=None)

    # Minimal request stand-in: headers, cookies and an async json() body.
    class Req:
        headers = {}
        cookies = {}

        async def json(self):
            return {"token": s.ADMIN_TOKEN}

    request = Req()
    response = Response()
    r: Redis = fakeredis.aioredis.FakeRedis()
    await authenticate_admin(request, response, r, s)  # type: ignore[arg-type]
    set_cookie = response.headers.get("set-cookie") or ""
    assert SESSION_COOKIE_NAME in set_cookie
    # The raw ADMIN_TOKEN must never leak into the cookie.
    assert s.ADMIN_TOKEN not in set_cookie
END_FILE
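What "opaque" means for the session cookie: the cookie value is random and reveals nothing about the admin token, and the server maps it to state with an expiry. How app.security.admin actually stores sessions is not shown. The sketch below assumes an HMAC fingerprint of the token keyed by ADMIN_TOKEN_FINGERPRINT_SECRET (suggested by that setting's name, not confirmed) and uses an in-memory dict in place of Redis. token_fingerprint, SessionStoreSketch and the 3600-second default TTL are hypothetical.

```python
import hashlib
import hmac
import secrets
import time
from typing import Dict, Optional, Tuple


def token_fingerprint(token: str, fp_secret: str) -> str:
    # Keyed hash of the admin token, so the raw token never has to be stored.
    return hmac.new(fp_secret.encode("utf-8"), token.encode("utf-8"), hashlib.sha256).hexdigest()


class SessionStoreSketch:
    """In-memory stand-in for a Redis-backed admin session store (hypothetical)."""

    def __init__(self, ttl_sec: int = 3600) -> None:
        self.ttl_sec = ttl_sec
        self._sessions: Dict[str, Tuple[str, float]] = {}

    def create(self, fingerprint: str, now: Optional[float] = None) -> str:
        now = time.time() if now is None else now
        # The cookie value is pure randomness; nothing about the token can be derived from it.
        sid = secrets.token_urlsafe(32)
        self._sessions[sid] = (fingerprint, now + self.ttl_sec)
        return sid

    def lookup(self, sid: str, now: Optional[float] = None) -> Optional[str]:
        now = time.time() if now is None else now
        entry = self._sessions.get(sid)
        if entry is None or entry[1] <= now:
            self._sessions.pop(sid, None)  # drop expired entries lazily
            return None
        return entry[0]
```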
BEGIN_FILE: tests/unit/test_subscription_atomicity.py
import pytest

from app.services.subscriptions import grant_subscription_for_payment


@pytest.mark.asyncio
async def test_grant_subscription_atomicity_rollback() -> None:
    # Structural test only: without a database harness no actual rollback can be observed.
    # The payload omits most fields and every query fails, so the service must raise rather
    # than swallow the error; that is what lets the caller's session.begin() roll back.
    class DummySession:  # minimal AsyncSession-like stub whose queries always fail
        async def execute(self, *args, **kwargs):
            raise RuntimeError("no db")

    session = DummySession()
    with pytest.raises(Exception):
        await grant_subscription_for_payment(session, {"tg_user_id": 1})  # type: ignore[arg-type]
END_FILE
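The pattern this test relies on (the service never commits; the caller's transaction scope commits on success and rolls back on error, as with async with session.begin() in the refund route) can be shown end to end with the stdlib sqlite3 module standing in for SQLAlchemy and Postgres. The table names and the grant function are hypothetical.

```python
import sqlite3
from typing import List, Tuple


def grant(conn: sqlite3.Connection, user_id: int, fail_midway: bool) -> None:
    # `with conn:` commits when the block finishes and rolls back if it raises,
    # so this function never calls commit() itself.
    with conn:
        conn.execute("INSERT INTO payments (user_id) VALUES (?)", (user_id,))
        if fail_midway:
            raise RuntimeError("plan lookup failed")  # simulated error after the first write
        conn.execute("INSERT INTO subscriptions (user_id) VALUES (?)", (user_id,))


def demo() -> Tuple[List[int], List[int]]:
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE payments (user_id INTEGER)")
    conn.execute("CREATE TABLE subscriptions (user_id INTEGER)")
    try:
        grant(conn, 1, fail_midway=True)
    except RuntimeError:
        pass
    grant(conn, 2, fail_midway=False)
    payments = [row[0] for row in conn.execute("SELECT user_id FROM payments ORDER BY user_id")]
    subs = [row[0] for row in conn.execute("SELECT user_id FROM subscriptions ORDER BY user_id")]
    conn.close()
    return payments, subs
```

The failed grant for user 1 leaves no payment row behind, so payment and subscription state never diverge.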
BEGIN_FILE: tests/integration/test_health_endpoints.py
def test_livez_ok(client):
    r = client.get("/health/livez")
    assert r.status_code == 200
    data = r.json()
    assert data["ok"] is True
    assert data["result"]["status"] == "ok"


def test_readyz_ok_shape(client):
    r = client.get("/health/readyz")
    assert r.status_code == 200
    data = r.json()
    assert data["ok"] is True
    assert "status" in data["result"]
END_FILE
BEGIN_FILE: tests/integration/test_webhook_payment_confirm_e2e.py
import hashlib
import hmac
import time


def _sig(secret: str, ts: int, nonce: str, body: bytes) -> str:
    """Compute webhook signature matching app/security/webhook_sig.py format."""
    msg = f"{ts}.{nonce}.".encode("utf-8") + body
    return hmac.new(secret.encode("utf-8"), msg, hashlib.sha256).hexdigest()


def test_webhook_missing_headers(client):
    r = client.post("/v1/webhooks/stripe/payment_confirm", content=b"{}")
    assert r.status_code == 400
    assert r.json()["error"]["code"] in ("missing_sig_headers", "unknown_provider")


def test_webhook_signature_invalid(client, monkeypatch):
    ts = int(time.time())
    nonce = "n1"
    body = b'{"tg_user_id":1,"provider_payment_id":"p1","amount":1,"currency":"USD","plan_code":"basic"}'
    r = client.post(
        "/v1/webhooks/stripe/payment_confirm",
        content=body,
        headers={
            "X-Apex-Timestamp": str(ts),
            "X-Apex-Nonce": nonce,
            "X-Apex-Signature": "0" * 64,
        },
    )
    assert r.status_code == 400
    assert r.json()["error"]["code"] in ("invalid_signature", "unknown_plan")
END_FILE
BEGIN_FILE: tests/integration/test_admin_endpoints_e2e.py
def test_admin_login_missing_token(client):
    r = client.post("/v1/admin/login", json={})
    assert r.status_code == 401


def test_admin_login_sets_session_cookie(client):
    r = client.post("/v1/admin/login", json={"token": "A" * 40})
    assert r.status_code in (200, 500)  # may fail without DB/Redis in this minimal integration harness
    # If it succeeds, the cookie should be present and must not contain the token.
    sc = r.headers.get("set-cookie", "")
    assert ("__Host-admin_session" in sc) or (r.status_code == 500)
    assert ("A" * 40) not in sc
END_FILE
BEGIN_FILE: tests/security/test_replay_protection.py
import hashlib
import hmac
import time

import fakeredis.aioredis
import pytest
from redis.asyncio import Redis

from app.security.webhook_sig import verify_signature
from app.settings import Settings


@pytest.mark.asyncio
async def test_replayed_nonce_rejected() -> None:
    s = Settings(_env_file=None)
    r: Redis = fakeredis.aioredis.FakeRedis()
    ts = int(time.time())
    nonce = "nonce1"
    body = b"{}"
    # Sign with the WEBHOOK_SIGNATURE_SECRETS value set in conftest.
    msg = f"{ts}.{nonce}.".encode("utf-8") + body
    sig = hmac.new(("C" * 40).encode("utf-8"), msg, hashlib.sha256).hexdigest()

    # Minimal request stand-in carrying only the signature headers.
    class Req:
        headers = {}

    req = Req()
    req.headers = {"X-Apex-Timestamp": str(ts), "X-Apex-Nonce": nonce, "X-Apex-Signature": sig}
    await verify_signature(req, r, s, body)  # first delivery is accepted
    with pytest.raises(Exception):
        await verify_signature(req, r, s, body)  # replayed nonce must be rejected
END_FILE
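Replay protection of this kind usually stores each nonce with a TTL and rejects a second sighting; in Redis that is a single SET key value NX EX ttl (redis-py: r.set(key, 1, nx=True, ex=ttl)), which only succeeds if the key is absent. The in-memory class below is a hypothetical analogue of that behaviour, not the repo's code. The TTL should be at least twice the timestamp skew: a timestamp stays acceptable for up to 2 × 300 s after the nonce is first seen, so a shorter TTL would let a replay slip through once the nonce expires. conftest's 900 s TTL satisfies this.

```python
import time
from typing import Dict, Optional


class NonceStoreSketch:
    """In-memory analogue of Redis `SET key 1 NX EX ttl` for replay protection (hypothetical)."""

    def __init__(self, ttl_sec: int = 900) -> None:
        self.ttl_sec = ttl_sec
        self._expiry: Dict[str, float] = {}

    def claim(self, nonce: str, now: Optional[float] = None) -> bool:
        """Return True the first time a nonce is seen within its TTL, False on a replay."""
        now = time.time() if now is None else now
        exp = self._expiry.get(nonce)
        if exp is not None and exp > now:
            return False  # still remembered: replay
        self._expiry[nonce] = now + self.ttl_sec
        return True
```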
BEGIN_FILE: tests/security/test_lockout_mechanism.py
import fakeredis.aioredis
import pytest
from fastapi import Response
from redis.asyncio import Redis

from app.security.admin import authenticate_admin
from app.settings import Settings


@pytest.mark.asyncio
async def test_lockout_after_max_fails() -> None:
    s = Settings(_env_file=None)
    r: Redis = fakeredis.aioredis.FakeRedis()

    class Req:
        headers = {}
        cookies = {}

        def __init__(self, token: str):
            self._token = token

        async def json(self):
            return {"token": self._token}

    # Exhaust the allowance: ADMIN_MAX_FAILS attempts with a wrong token.
    for _ in range(s.ADMIN_MAX_FAILS):
        with pytest.raises(Exception):
            await authenticate_admin(Req("WRONGTOKEN"), Response(), r, s)  # type: ignore[arg-type]
    # Attempt ADMIN_MAX_FAILS + 1 should be rejected by the lockout. A wrong token also raises
    # without a lockout, so asserting the specific error code would make this check stricter.
    with pytest.raises(Exception):
        await authenticate_admin(Req("WRONGTOKEN"), Response(), r, s)  # type: ignore[arg-type]
END_FILE
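As the comment in the test notes, a wrong token fails with or without a lockout, so the distinguishing behaviour is that a locked client is refused before the token is checked, even when the token is correct. The sketch below models that with a per-key failure counter and a lock window. Keying by client address, the lock duration, and the LockoutSketch API are all assumptions; the real mechanism in app.security.admin is not shown.

```python
from typing import Dict


class LockoutSketch:
    """Per-key failed-login counter with a temporary lock (hypothetical)."""

    def __init__(self, max_fails: int, lock_sec: float) -> None:
        self.max_fails = max_fails
        self.lock_sec = lock_sec
        self._fails: Dict[str, int] = {}
        self._locked_until: Dict[str, float] = {}

    def attempt(self, key: str, token_ok: bool, now: float) -> str:
        # The lock is checked before the token, so even a correct token is refused while locked.
        if self._locked_until.get(key, 0.0) > now:
            return "locked"
        if token_ok:
            self._fails.pop(key, None)
            return "ok"
        n = self._fails.get(key, 0) + 1
        if n >= self.max_fails:
            self._locked_until[key] = now + self.lock_sec
            self._fails.pop(key, None)
        else:
            self._fails[key] = n
        return "bad"
```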
BEGIN_FILE: tests/security/test_rate_limit_effectiveness.py
import fakeredis.aioredis
import pytest
from redis.asyncio import Redis

from app.infra.rate_limit import TokenBucketLimiter
from app.settings import Settings


@pytest.mark.asyncio
async def test_rate_limit_blocks_when_exhausted(monkeypatch: pytest.MonkeyPatch) -> None:
    s = Settings(_env_file=None)
    r: Redis = fakeredis.aioredis.FakeRedis()
    limiter = TokenBucketLimiter(r=r, s=s)
    # burst=1: the first call takes the only token; an immediate second call finds the bucket empty.
    allowed1 = await limiter.allow(key_suffix="x", rate=1, burst=1, cost=1)
    allowed2 = await limiter.allow(key_suffix="x", rate=1, burst=1, cost=1)
    assert allowed1 is True
    assert allowed2 is False
END_FILE
END_REPO