#!/usr/bin/env python3
"""Numerical and algebraic checks for paper_task_frontier_race/task_frontier_race.tex."""

from __future__ import annotations

import math


# Running count of successful checks; incremented by pass_msg and reported at the end of main().
PASSED = 0


def pass_msg(name: str, detail: str) -> None:
    """Record one passing check and print ``PASS <name>: <detail>``."""
    global PASSED
    PASSED = PASSED + 1
    print("PASS {}: {}".format(name, detail))


def assert_close(name: str, got: float, expected: float, tol: float) -> None:
    """Check that ``got`` is within ``tol`` of ``expected`` and record a pass.

    Args:
        name: Label printed with the result.
        got: Computed value.
        expected: Reference value.
        tol: Maximum allowed absolute difference.

    Raises:
        AssertionError: if ``|got - expected| > tol`` or the difference is NaN.
    """
    # ``not (diff <= tol)`` instead of ``diff > tol``: every comparison with NaN
    # is False, so the plain ``>`` form would let a NaN result pass silently.
    if not abs(got - expected) <= tol:
        raise AssertionError(f"{name}: got {got}, expected {expected}, tol {tol}")
    pass_msg(name, f"{got:.10g}")


def assert_condition(name: str, condition: bool, detail: str) -> None:
    """Record a pass when ``condition`` is truthy; otherwise fail loudly.

    Raises:
        AssertionError: with message ``"<name>: <detail>"`` when ``condition`` is falsy.
    """
    if condition:
        pass_msg(name, detail)
        return
    raise AssertionError(f"{name}: {detail}")


def integrate1d(fun, lo: float, hi: float, n: int = 800) -> float:
    """Integrate ``fun`` over ``[lo, hi]`` with the composite Simpson rule.

    Args:
        fun: Callable taking one float and returning a float.
        lo: Lower limit.
        hi: Upper limit.
        n: Number of subintervals. An odd value is rounded up to the next even
            number, because Simpson's rule works on pairs of panels.

    Returns:
        The Simpson approximation of the integral.

    Raises:
        ValueError: if ``n < 1`` (previously ``n == 0`` divided by zero and a
            negative ``n`` silently produced a meaningless result).
    """
    if n < 1:
        raise ValueError(f"integrate1d needs at least one subinterval, got n={n}")
    if n % 2:
        n += 1
    h = (hi - lo) / n
    total = fun(lo) + fun(hi)
    for i in range(1, n):
        # Simpson weights: 4 at odd interior nodes, 2 at even interior nodes.
        total += (4 if i % 2 else 2) * fun(lo + i * h)
    return total * h / 3.0


def integrate2d(
    fun,
    h_lo: float = -8.0,
    h_hi: float = 8.0,
    f_lo: float = -8.0,
    f_hi: float = 8.0,
    n_h: int = 96,
    n_f: int = 96,
) -> float:
    """Integrate ``fun(h, f)`` over a rectangle with the 2-D midpoint rule.

    The rectangle ``[h_lo, h_hi] x [f_lo, f_hi]`` is split into ``n_h * n_f``
    equal cells, and ``fun`` is evaluated once at each cell centre in
    row-major order (outer loop over ``h``).
    """
    step_h = (h_hi - h_lo) / n_h
    step_f = (f_hi - f_lo) / n_f
    # The fragility centres are the same for every row, so compute them once.
    f_centres = [f_lo + (j + 0.5) * step_f for j in range(n_f)]
    acc = 0.0
    for i in range(n_h):
        centre_h = h_lo + (i + 0.5) * step_h
        # Accumulate each row separately before folding it into the total
        # (explicit loop rather than sum(): float rounding order matters).
        inner = 0.0
        for centre_f in f_centres:
            inner += fun(centre_h, centre_f)
        acc += inner
    return acc * step_h * step_f


def G(s: float) -> float:
    """Logistic sigmoid ``1 / (1 + exp(-s))``, evaluated without overflow.

    Each branch only exponentiates a non-positive argument, so ``math.exp``
    never overflows for large ``|s|``.
    """
    if s < 0.0:
        e = math.exp(s)
        return e / (1.0 + e)
    return 1.0 / (1.0 + math.exp(-s))


def g(s: float) -> float:
    """Logistic density ``G'(s) = G(s) * (1 - G(s))``.

    Evaluated as ``e^{-|s|} / (1 + e^{-|s|})^2``. This is algebraically equal
    to ``G(s) * (1 - G(s))`` but avoids the catastrophic cancellation in
    ``1 - G(s)`` when ``G(s)`` rounds to 1 (e.g. the old form returned exactly
    0 for s = 40). It is also exactly symmetric, ``g(s) == g(-s)``, and cannot
    overflow because the exponent is never positive.
    """
    z = math.exp(-abs(s))
    return z / (1.0 + z) ** 2


def logit(prob: float) -> float:
    """Return the log-odds ``ln(prob / (1 - prob))``, the inverse of ``G``.

    Raises:
        ValueError: if ``prob`` is not strictly between 0 and 1 (NaN included).
            Previously 0 raised an opaque math domain error, 1 raised
            ZeroDivisionError and NaN was returned unchanged.
    """
    if not 0.0 < prob < 1.0:
        raise ValueError(f"logit requires 0 < prob < 1, got {prob!r}")
    return math.log(prob / (1.0 - prob))


def required_stage_success(stages: int, retries: int, epsilon: float) -> float:
    """Per-attempt success probability needed to hit an overall failure rate.

    With ``stages`` sequential stages, each allowed ``1 + retries`` attempts,
    this is the attempt success probability at which the whole task succeeds
    with probability ``1 - epsilon``.
    """
    # Success probability each stage must reach so that all stages together give 1 - epsilon.
    per_stage_target = (1.0 - epsilon) ** (1.0 / stages)
    # A stage fails only if all 1 + retries attempts fail; invert that for one attempt.
    attempt_failure = (1.0 - per_stage_target) ** (1.0 / (1.0 + retries))
    return 1.0 - attempt_failure


def effective_fragility(stages: int, retries: int, epsilon: float, local_burden: float) -> float:
    """Effective fragility threshold at which a sequential task meets its failure target.

    Adds ``local_burden`` to the log-odds of the per-attempt success
    probability returned by ``required_stage_success``.
    """
    attempt_target = required_stage_success(stages, retries, epsilon)
    return local_burden + logit(attempt_target)


def sequential_success(q_f: float, local_burden: float, stages: int, retries: int) -> float:
    """Probability that all ``stages`` sequential stages succeed.

    Each attempt succeeds with probability ``G(q_f - local_burden)``. A stage
    gets ``1 + retries`` independent attempts and fails only if all of them fail.
    """
    attempt_success = G(q_f - local_burden)
    attempt_failure = 1.0 - attempt_success
    stage_success = 1.0 - attempt_failure ** (1.0 + retries)
    return stage_success ** stages


def a(h: float, f: float) -> float:
    """Task density over horizon ``h`` and fragility ``f``.

    It is the sum of four Gaussian-shaped bumps, representing low/low,
    long-easy, short-fragile and long-fragile task clusters.
    """
    # (weight, decay, horizon centre, fragility centre, fragility scale)
    bumps = (
        (0.90, 0.18, 0.3, 0.2, 0.8),  # low horizon, low fragility
        (0.70, 0.23, 3.0, 0.3, 0.7),  # long horizon, easy
        (0.60, 0.32, 0.5, 3.0, 1.0),  # short horizon, fragile
        (0.55, 0.20, 2.8, 2.8, 1.0),  # long horizon, fragile
    )
    density = 0.0
    for weight, decay, h_centre, f_centre, f_scale in bumps:
        density += weight * math.exp(-decay * ((h - h_centre) ** 2 + f_scale * (f - f_centre) ** 2))
    return density


def p(q: tuple[float, float], h: float, f: float) -> float:
    """Success probability on task ``(h, f)`` for capability ``q = (q_H, q_F)``.

    It is the product of a horizon factor and a reliability factor.
    """
    horizon_ok = G(q[0] - h)
    reliability_ok = G(q[1] - f)
    return horizon_ok * reliability_ok


def W(q: tuple[float, float]) -> float:
    """Density-weighted success mass of capability ``q``.

    Integrates ``a(h, f) * p(q, h, f)`` over integrate2d's default grid.
    """
    def integrand(h: float, f: float) -> float:
        return a(h, f) * p(q, h, f)

    return integrate2d(integrand)


def BH(q: tuple[float, float]) -> float:
    """Horizon boundary mass: the derivative of ``W`` with respect to ``q_H``.

    Integrates ``a * g(q_H - h) * G(q_F - f)`` over integrate2d's default grid.
    """
    def integrand(h: float, f: float) -> float:
        return a(h, f) * g(q[0] - h) * G(q[1] - f)

    return integrate2d(integrand)


def BF(q: tuple[float, float]) -> float:
    """Reliability boundary mass: the derivative of ``W`` with respect to ``q_F``.

    Integrates ``a * G(q_H - h) * g(q_F - f)`` over integrate2d's default grid.
    """
    def integrand(h: float, f: float) -> float:
        return a(h, f) * G(q[0] - h) * g(q[1] - f)

    return integrate2d(integrand)


def C(q: tuple[float, float]) -> float:
    """Cross (complementarity) term: the mixed second derivative of ``W``.

    Integrates ``a * g(q_H - h) * g(q_F - f)`` over integrate2d's default grid.
    """
    def integrand(h: float, f: float) -> float:
        return a(h, f) * g(q[0] - h) * g(q[1] - f)

    return integrate2d(integrand)


def add(x: tuple[float, float], y: tuple[float, float]) -> tuple[float, float]:
    """Component-wise sum of two 2-vectors."""
    first = x[0] + y[0]
    second = x[1] + y[1]
    return (first, second)


def sub(x: tuple[float, float], y: tuple[float, float]) -> tuple[float, float]:
    """Component-wise difference ``x - y`` of two 2-vectors."""
    first = x[0] - y[0]
    second = x[1] - y[1]
    return (first, second)


def scale(c: float, x: tuple[float, float]) -> tuple[float, float]:
    """Multiply both components of the 2-vector ``x`` by the scalar ``c``."""
    first = c * x[0]
    second = c * x[1]
    return (first, second)


def dot(x: tuple[float, float], y: tuple[float, float]) -> float:
    """Inner product of two 2-vectors."""
    first_term = x[0] * y[0]
    second_term = x[1] * y[1]
    return first_term + second_term


def norm1(x: tuple[float, float]) -> float:
    """L1 norm (sum of absolute components) of a 2-vector."""
    first, second = x[0], x[1]
    return abs(first) + abs(second)


def A(q: tuple[float, float], delta: tuple[float, float]) -> float:
    """Finite-step frontier gain ``W(q + delta) - W(q)``."""
    shifted = add(q, delta)
    return W(shifted) - W(q)


def share(r: float) -> float:
    """Market share as a logistic function of the beta-weighted quality lead ``r``."""
    market_share = G(r)
    return market_share


def scaled_share(kappa: float, gap: float) -> float:
    """Logistic share for a gap whose sensitivity is scaled by ``kappa``."""
    effective_gap = kappa * gap
    return G(effective_gap)


def data_tipping_map(gap: float, rho: float, trace_productivity: float, mass: float, kappa: float) -> float:
    """One-period update of the leader's gap under a data/distillation flywheel.

    The gap decays at rate ``rho``. It is pushed by ``trace_productivity * mass``
    times the share imbalance ``2 * scaled_share(kappa, gap) - 1``.
    """
    persistence = (1.0 - rho) * gap
    share_imbalance = 2.0 * scaled_share(kappa, gap) - 1.0
    return persistence + trace_productivity * mass * share_imbalance


def race_pressure(
    qi: tuple[float, float],
    qj: tuple[float, float],
    di: tuple[float, float],
    dj: tuple[float, float],
    beta: tuple[float, float],
    rival_action: int,
) -> float:
    """Firm i's gain in share-weighted value from investing ``di``.

    The rival j either also steps by ``dj`` (``rival_action == 1``) or stays
    put (``rival_action == 0``).
    """
    rival_q = add(qj, scale(rival_action, dj))
    invested_q = add(qi, di)
    share_invest = share(dot(beta, sub(invested_q, rival_q)))
    share_hold = share(dot(beta, sub(qi, rival_q)))
    return share_invest * W(invested_q) - share_hold * W(qi)


def race_decomp(
    qi: tuple[float, float],
    qj: tuple[float, float],
    di: tuple[float, float],
    dj: tuple[float, float],
    beta: tuple[float, float],
    rival_action: int,
) -> float:
    """``race_pressure`` rewritten as frontier gain plus rent shift.

    The first term is the post-investment share times the frontier gain
    ``A(qi, di)``. The second is the change in share times the current value
    ``W(qi)``.
    """
    rival_q = add(qj, scale(rival_action, dj))
    invested_q = add(qi, di)
    share_invest = share(dot(beta, sub(invested_q, rival_q)))
    share_hold = share(dot(beta, sub(qi, rival_q)))
    own_gain = share_invest * A(qi, di)
    rent_shift = (share_invest - share_hold) * W(qi)
    return own_gain + rent_shift


def is_equilibrium(profile: tuple[int, int], phi_l: dict[int, float], phi_f: dict[int, float], cost: float) -> bool:
    """Check whether an action profile ``(a_L, a_F)`` is a pure Nash equilibrium.

    ``phi_l`` / ``phi_f`` map the *rival's* action to each player's race
    pressure. A player who acts needs pressure >= ``cost``; a player who
    abstains needs pressure <= ``cost``.
    """
    leader_acts, follower_acts = profile
    leader_pressure = phi_l[follower_acts]
    follower_pressure = phi_f[leader_acts]

    def best_response(acts: int, pressure: float) -> bool:
        if acts:
            return pressure >= cost
        return pressure <= cost

    leader_ok = best_response(leader_acts, leader_pressure)
    follower_ok = best_response(follower_acts, follower_pressure)
    return leader_ok and follower_ok


def region_value(q: tuple[float, float], bounds: tuple[float, float, float, float]) -> float:
    """Success mass of capability ``q`` restricted to a task rectangle.

    ``bounds`` is ``(h_lo, h_hi, f_lo, f_hi)``; the integral uses an 80 x 80 midpoint grid.
    """
    h_lo, h_hi, f_lo, f_hi = bounds

    def integrand(h: float, f: float) -> float:
        return a(h, f) * p(q, h, f)

    return integrate2d(integrand, h_lo=h_lo, h_hi=h_hi, f_lo=f_lo, f_hi=f_hi, n_h=80, n_f=80)


def region_mass(bounds: tuple[float, float, float, float]) -> float:
    """Total task density inside the rectangle ``(h_lo, h_hi, f_lo, f_hi)``.

    Integrated on an 80 x 80 midpoint grid.
    """
    h_lo, h_hi, f_lo, f_hi = bounds

    def integrand(h: float, f: float) -> float:
        return a(h, f)

    return integrate2d(integrand, h_lo=h_lo, h_hi=h_hi, f_lo=f_lo, f_hi=f_hi, n_h=80, n_f=80)


def region_BH(q: tuple[float, float], bounds: tuple[float, float, float, float]) -> float:
    """Horizon boundary mass ``BH`` restricted to the rectangle ``(h_lo, h_hi, f_lo, f_hi)``.

    Integrated on an 80 x 80 midpoint grid.
    """
    h_lo, h_hi, f_lo, f_hi = bounds

    def integrand(h: float, f: float) -> float:
        return a(h, f) * g(q[0] - h) * G(q[1] - f)

    return integrate2d(integrand, h_lo=h_lo, h_hi=h_hi, f_lo=f_lo, f_hi=f_hi, n_h=80, n_f=80)


def region_BF(q: tuple[float, float], bounds: tuple[float, float, float, float]) -> float:
    """Reliability boundary mass ``BF`` restricted to the rectangle ``(h_lo, h_hi, f_lo, f_hi)``.

    Integrated on an 80 x 80 midpoint grid.
    """
    h_lo, h_hi, f_lo, f_hi = bounds

    def integrand(h: float, f: float) -> float:
        return a(h, f) * G(q[0] - h) * g(q[1] - f)

    return integrate2d(integrand, h_lo=h_lo, h_hi=h_hi, f_lo=f_lo, f_hi=f_hi, n_h=80, n_f=80)


def social_frontier(q1: tuple[float, float], q2: tuple[float, float]) -> float:
    """Social value when each task is served by whichever firm succeeds more often.

    Integrated over integrate2d's default grid.
    """
    def integrand(h: float, f: float) -> float:
        density = a(h, f)
        best = max(p(q1, h, f), p(q2, h, f))
        return density * best

    return integrate2d(integrand)


def value(state: tuple[float, float, float, float]) -> float:
    """Continuation value of firm 1 in state ``(q1h, q1f, q2h, q2f)``.

    Linear in firm 1's own qualities, minus penalties for trailing firm 2 in
    each dimension.
    """
    q1h, q1f, q2h, q2f = state
    own = 0.4 * q1h + 0.5 * q1f
    horizon_lag = max(q2h - q1h, 0.0)
    reliability_lag = max(q2f - q1f, 0.0)
    return own - 0.15 * horizon_lag - 0.12 * reliability_lag


def transition(
    q1: tuple[float, float],
    q2: tuple[float, float],
    d1: tuple[float, float],
    d2: tuple[float, float],
) -> tuple[float, float, float, float]:
    """Next-period state ``(q1h, q1f, q2h, q2f)`` after both firms apply their steps."""
    h1, f1 = add(q1, d1)
    h2, f2 = add(q2, d2)
    return (h1, f1, h2, f2)


def H_cap(base: float, search: float, tools: float, verify: float, data: float, distill: float) -> float:
    """Horizon capability ``q_H`` produced by the six technology inputs.

    Linear in every input, plus complementarities between the base
    technology and search, tools and verification.
    """
    terms = (
        0.25 * base,
        0.55 * search,
        0.50 * tools,
        0.16 * verify,
        0.20 * data,
        0.22 * distill,
        0.07 * base * search,
        0.05 * base * tools,
        0.02 * base * verify,
    )
    # Accumulate left to right from the intercept, matching the original term order.
    horizon = 0.2
    for term in terms:
        horizon += term
    return horizon


def F_cap(base: float, search: float, tools: float, verify: float, data: float, distill: float) -> float:
    """Reliability capability ``q_F`` produced by the six technology inputs.

    Dominated by the base technology and verification, plus complementarities
    between the base technology and search, tools and verification.
    """
    terms = (
        0.90 * base,
        0.10 * search,
        0.14 * tools,
        0.55 * verify,
        0.10 * data,
        0.26 * distill,
        0.02 * base * search,
        0.03 * base * tools,
        0.08 * base * verify,
    )
    # Accumulate left to right from the intercept, matching the original term order.
    reliability = 0.1
    for term in terms:
        reliability += term
    return reliability


def H_b(base: float, search: float, tools: float, verify: float, data: float, distill: float) -> float:
    """Partial derivative of ``H_cap`` with respect to ``base``.

    Only ``search``, ``tools`` and ``verify`` enter, through the base
    complementarity terms. The other arguments are accepted so every
    derivative shares ``H_cap``'s signature.
    """
    return 0.25 + 0.07 * search + 0.05 * tools + 0.02 * verify


def H_s(base: float, search: float, tools: float, verify: float, data: float, distill: float) -> float:
    """Partial derivative of ``H_cap`` with respect to ``search``; depends only on ``base``."""
    return 0.55 + 0.07 * base


def H_tau(base: float, search: float, tools: float, verify: float, data: float, distill: float) -> float:
    """Partial derivative of ``H_cap`` with respect to ``tools``; depends only on ``base``."""
    return 0.50 + 0.05 * base


def H_v(base: float, search: float, tools: float, verify: float, data: float, distill: float) -> float:
    """Partial derivative of ``H_cap`` with respect to ``verify``; depends only on ``base``."""
    return 0.16 + 0.02 * base


def H_d(base: float, search: float, tools: float, verify: float, data: float, distill: float) -> float:
    """Partial derivative of ``H_cap`` with respect to ``data`` (a constant)."""
    return 0.20


def H_m(base: float, search: float, tools: float, verify: float, data: float, distill: float) -> float:
    """Partial derivative of ``H_cap`` with respect to ``distill`` (a constant)."""
    return 0.22


def F_b(base: float, search: float, tools: float, verify: float, data: float, distill: float) -> float:
    """Partial derivative of ``F_cap`` with respect to ``base``.

    Only ``search``, ``tools`` and ``verify`` enter, through the base
    complementarity terms. The other arguments are accepted so every
    derivative shares ``F_cap``'s signature.
    """
    return 0.90 + 0.02 * search + 0.03 * tools + 0.08 * verify


def F_s(base: float, search: float, tools: float, verify: float, data: float, distill: float) -> float:
    """Partial derivative of ``F_cap`` with respect to ``search``; depends only on ``base``."""
    return 0.10 + 0.02 * base


def F_tau(base: float, search: float, tools: float, verify: float, data: float, distill: float) -> float:
    """Partial derivative of ``F_cap`` with respect to ``tools``; depends only on ``base``."""
    return 0.14 + 0.03 * base


def F_v(base: float, search: float, tools: float, verify: float, data: float, distill: float) -> float:
    """Partial derivative of ``F_cap`` with respect to ``verify``; depends only on ``base``."""
    return 0.55 + 0.08 * base


def F_d(base: float, search: float, tools: float, verify: float, data: float, distill: float) -> float:
    """Partial derivative of ``F_cap`` with respect to ``data`` (a constant)."""
    return 0.10


def F_m(base: float, search: float, tools: float, verify: float, data: float, distill: float) -> float:
    """Partial derivative of ``F_cap`` with respect to ``distill`` (a constant)."""
    return 0.26


def q_from_tech(
    base: float,
    search: float,
    tools: float,
    verify: float,
    data: float,
    distill: float,
) -> tuple[float, float]:
    """Map the six technology inputs to the capability vector ``(q_H, q_F)``."""
    tech = (base, search, tools, verify, data, distill)
    return H_cap(*tech), F_cap(*tech)


def runtime_cost(distill: float) -> float:
    """Per-unit serving cost, decaying exponentially with distillation effort."""
    decay = math.exp(-0.5 * distill)
    return 0.8 * decay


def runtime_cost_prime(distill: float) -> float:
    """Derivative of ``runtime_cost`` with respect to ``distill``: ``-0.5 * 0.8 * exp(-0.5 * distill)``."""
    decay = math.exp(-0.5 * distill)
    return -0.4 * decay


def main() -> None:
    """Run every numerical and algebraic check, one section per claim.

    Each check prints a ``PASS`` line through ``assert_close`` or
    ``assert_condition``. The first failing check raises ``AssertionError`` and
    stops the run, so reaching the final summary means every check passed.
    """
    # Reliability kernel.
    mass = integrate1d(g, -40.0, 40.0, 80_000)
    assert_close("logistic boundary kernel integrates to one", mass, 1.0, 1e-12)

    # Sufficient statistics.
    # Central finite differences (step eps) of W must match BH/BF, and finite
    # differences of BH/BF must match the cross term C.
    q = (1.1, 1.0)
    eps = 2e-3
    fd_h = (W((q[0] + eps, q[1])) - W((q[0] - eps, q[1]))) / (2.0 * eps)
    fd_f = (W((q[0], q[1] + eps)) - W((q[0], q[1] - eps))) / (2.0 * eps)
    assert_close("BH equals dW/dqH", BH(q), fd_h, 2e-5)
    assert_close("BF equals dW/dqF", BF(q), fd_f, 2e-5)
    c_val = C(q)
    cross_h = (BH((q[0], q[1] + eps)) - BH((q[0], q[1] - eps))) / (2.0 * eps)
    cross_f = (BF((q[0] + eps, q[1])) - BF((q[0] - eps, q[1]))) / (2.0 * eps)
    assert_close("C equals dBH/dqF", c_val, cross_h, 2e-5)
    assert_close("C equals dBF/dqH", c_val, cross_f, 2e-5)
    assert_condition("C nonnegative", c_val >= 0.0, f"C={c_val:.10g}")

    # Sequential microfoundation.
    stages = 8
    retries = 2
    failure_tolerance = 0.02
    local_burden = 0.4
    f_eff = effective_fragility(stages, retries, failure_tolerance, local_burden)
    assert_close("sequential threshold hits target failure rate", sequential_success(f_eff, local_burden, stages, retries), 1.0 - failure_tolerance, 1e-12)
    assert_condition(
        "fragility rises with horizon",
        effective_fragility(stages + 2, retries, failure_tolerance, local_burden) > f_eff,
        "f(n+2)>f(n)",
    )
    assert_condition(
        "fragility falls with retries",
        effective_fragility(stages, retries + 1, failure_tolerance, local_burden) < f_eff,
        "f(ell+1)<f(ell)",
    )
    assert_condition(
        "fragility rises when tolerated failure falls",
        effective_fragility(stages, retries, failure_tolerance / 2.0, local_burden) > f_eff,
        "f(eps/2)>f(eps)",
    )
    assert_close(
        "local reliability burden shifts fragility one-for-one",
        effective_fragility(stages, retries, failure_tolerance, local_burden + 0.25) - f_eff,
        0.25,
        1e-12,
    )

    # Finite-step paths and interaction.
    # A(q, delta) must equal the line integral of (BH, BF) along either axis-aligned
    # path, and the interaction term must equal the integral of C over the step rectangle.
    delta = (0.28, 0.22)
    direct = A(q, delta)
    path_hf = integrate1d(lambda u: BH((q[0] + u, q[1])), 0.0, delta[0], 80) + integrate1d(
        lambda v: BF((q[0] + delta[0], q[1] + v)), 0.0, delta[1], 80
    )
    path_fh = integrate1d(lambda v: BF((q[0], q[1] + v)), 0.0, delta[1], 80) + integrate1d(
        lambda u: BH((q[0] + u, q[1] + delta[1])), 0.0, delta[0], 80
    )
    assert_close("A equals H-then-F path", direct, path_hf, 8e-5)
    assert_close("A equals F-then-H path", direct, path_fh, 8e-5)
    interaction = direct - A(q, (delta[0], 0.0)) - A(q, (0.0, delta[1]))
    c_integral = integrate2d(lambda u, v: C((q[0] + u, q[1] + v)), 0.0, delta[0], 0.0, delta[1], 16, 16)
    assert_close("interaction equals integral of C", interaction, c_integral, 8e-5)
    assert_condition("interaction nonnegative", interaction >= 0.0, f"I={interaction:.10g}")

    small = (1e-4, 2e-4)
    local = small[0] * BH(q) + small[1] * BF(q)
    assert_close("local A equals boundary linearization", A(q, small), local, 3e-7)

    # Technology race values: base, search, tools, verification, deployment data, and distillation.
    # Chain-rule values v_x = BH * dq_H/dx + BF * dq_F/dx are compared with central
    # finite differences of W through q_from_tech; distillation also subtracts the
    # served-mass runtime cost.
    base = 1.2
    search = 0.8
    tools = 0.7
    verify = 0.5
    data = 0.6
    distill = 0.4
    served_mass = 3.0
    tech = (base, search, tools, verify, data, distill)
    qtech = q_from_tech(*tech)
    vb = BH(qtech) * H_b(*tech) + BF(qtech) * F_b(*tech)
    vs = BH(qtech) * H_s(*tech) + BF(qtech) * F_s(*tech)
    vtau = BH(qtech) * H_tau(*tech) + BF(qtech) * F_tau(*tech)
    vv = BH(qtech) * H_v(*tech) + BF(qtech) * F_v(*tech)
    vd = BH(qtech) * H_d(*tech) + BF(qtech) * F_d(*tech)
    vm = BH(qtech) * H_m(*tech) + BF(qtech) * F_m(*tech) - served_mass * runtime_cost_prime(distill)
    tech_eps = 1e-4
    fd_b = (W(q_from_tech(base + tech_eps, search, tools, verify, data, distill)) - W(q_from_tech(base - tech_eps, search, tools, verify, data, distill))) / (2.0 * tech_eps)
    fd_s = (W(q_from_tech(base, search + tech_eps, tools, verify, data, distill)) - W(q_from_tech(base, search - tech_eps, tools, verify, data, distill))) / (2.0 * tech_eps)
    fd_tau = (W(q_from_tech(base, search, tools + tech_eps, verify, data, distill)) - W(q_from_tech(base, search, tools - tech_eps, verify, data, distill))) / (2.0 * tech_eps)
    fd_v = (W(q_from_tech(base, search, tools, verify + tech_eps, data, distill)) - W(q_from_tech(base, search, tools, verify - tech_eps, data, distill))) / (2.0 * tech_eps)
    fd_d = (W(q_from_tech(base, search, tools, verify, data + tech_eps, distill)) - W(q_from_tech(base, search, tools, verify, data - tech_eps, distill))) / (2.0 * tech_eps)
    fd_m = (
        W(q_from_tech(base, search, tools, verify, data, distill + tech_eps))
        - served_mass * runtime_cost(distill + tech_eps)
        - W(q_from_tech(base, search, tools, verify, data, distill - tech_eps))
        + served_mass * runtime_cost(distill - tech_eps)
    ) / (2.0 * tech_eps)
    assert_close("base race value chain rule", vb, fd_b, 3e-5)
    assert_close("search race value chain rule", vs, fd_s, 3e-5)
    assert_close("tools race value chain rule", vtau, fd_tau, 3e-5)
    assert_close("verification race value chain rule", vv, fd_v, 3e-5)
    assert_close("deployment data value chain rule", vd, fd_d, 3e-5)
    assert_close("distillation value chain rule with cost savings", vm, fd_m, 3e-5)
    assert_condition("base technology mainly raises reliability", F_b(*tech) > H_b(*tech), "F_b > H_b")
    assert_condition("search mainly raises horizon", H_s(*tech) > F_s(*tech), "H_s > F_s")
    assert_condition("tools mainly raise horizon", H_tau(*tech) > F_tau(*tech), "H_tau > F_tau")
    assert_condition("verification mainly raises reliability", F_v(*tech) > H_v(*tech), "F_v > H_v")

    # Task-region race inversion and scalar failure.
    # Region bounds are (h_lo, h_hi, f_lo, f_hi); dy is horizon-intensive and dz is
    # reliability-intensive. The project ranking flips when the region's BH/BF ratio
    # crosses inversion_threshold.
    long_forgiving = (1.7, 4.4, -1.0, 0.9)
    short_fragile = (-0.5, 1.4, 1.7, 4.4)
    dy = (0.60, 0.10)
    dz = (0.10, 0.50)
    inversion_threshold = (dz[1] - dy[1]) / (dy[0] - dz[0])
    bh_long = region_BH(q, long_forgiving)
    bf_long = region_BF(q, long_forgiving)
    bh_fragile = region_BH(q, short_fragile)
    bf_fragile = region_BF(q, short_fragile)
    theta_long = bh_long / bf_long
    theta_fragile = bh_fragile / bf_fragile
    value_y_long = bh_long * dy[0] + bf_long * dy[1]
    value_z_long = bh_long * dz[0] + bf_long * dz[1]
    value_y_fragile = bh_fragile * dy[0] + bf_fragile * dy[1]
    value_z_fragile = bh_fragile * dz[0] + bf_fragile * dz[1]
    assert_condition("long-forgiving region above race-inversion threshold", theta_long > inversion_threshold, f"theta={theta_long:.10g}")
    assert_condition("short-fragile region below race-inversion threshold", theta_fragile < inversion_threshold, f"theta={theta_fragile:.10g}")
    assert_condition("horizon-intensive project wins long-forgiving region", value_y_long > value_z_long, "runtime/search direction wins")
    assert_condition("reliability-intensive project wins short-fragile region", value_z_fragile > value_y_fragile, "base/verification direction wins")
    # With a single scalar capability, every region multiplies the same gamma gap
    # by a positive boundary mass, so the sign of the ranking cannot flip.
    scalar_boundary_long = bh_long + bf_long
    scalar_boundary_fragile = bh_fragile + bf_fragile
    gamma_y = 0.30
    gamma_z = 0.20
    scalar_diff_long = scalar_boundary_long * (gamma_y - gamma_z)
    scalar_diff_fragile = scalar_boundary_fragile * (gamma_y - gamma_z)
    assert_condition("scalar same-cost project ranking cannot reverse", scalar_diff_long * scalar_diff_fragile > 0.0, "same sign across regions")
    # Private (rent-adjusted) version: add the share-derivative term g(r) * W * beta.d_gap
    # to the share-weighted frontier gain, evaluated at a neutral market position r = 0.
    d_gap = (dy[0] - dz[0], dy[1] - dz[1])
    beta_private = (0.6, 0.9)
    market_r = 0.0
    share_r = share(market_r)
    share_prime_r = g(market_r)
    w_long = region_value(q, long_forgiving)
    w_fragile = region_value(q, short_fragile)
    private_threshold_long = (-d_gap[1] / d_gap[0]) - (
        share_prime_r * w_long / (share_r * bf_long)
    ) * (dot(beta_private, d_gap) / d_gap[0])
    private_threshold_fragile = (-d_gap[1] / d_gap[0]) - (
        share_prime_r * w_fragile / (share_r * bf_fragile)
    ) * (dot(beta_private, d_gap) / d_gap[0])
    private_diff_long = share_r * (bh_long * d_gap[0] + bf_long * d_gap[1]) + share_prime_r * w_long * dot(beta_private, d_gap)
    private_diff_fragile = share_r * (bh_fragile * d_gap[0] + bf_fragile * d_gap[1]) + share_prime_r * w_fragile * dot(beta_private, d_gap)
    assert_condition("private long-forgiving threshold predicts horizon project", theta_long > private_threshold_long, f"theta={theta_long:.10g}")
    assert_condition("private short-fragile threshold predicts reliability project", theta_fragile < private_threshold_fragile, f"theta={theta_fragile:.10g}")
    assert_condition("private horizon project wins long-forgiving region", private_diff_long > 0.0, "rent-adjusted runtime/search wins")
    assert_condition("private reliability project wins short-fragile region", private_diff_fragile < 0.0, "rent-adjusted base/verification wins")

    # Task salience and classification.
    h_task, f_task = 1.4, 1.2
    exact_task_change = a(h_task, f_task) * (p(add(q, small), h_task, f_task) - p(q, h_task, f_task))
    salience = a(h_task, f_task) * (
        g(q[0] - h_task) * G(q[1] - f_task) * small[0]
        + G(q[0] - h_task) * g(q[1] - f_task) * small[1]
    )
    assert_close("task-level salience", exact_task_change, salience, 2e-8)

    # Quadrant labels use a cutoff of 2.0 on both horizon and fragility.
    examples = {
        "product recommendation": (0.4, 0.3, "low-low"),
        "coding": (3.2, 0.4, "high-low"),
        "self driving": (0.8, 3.1, "low-high"),
        "drug discovery": (3.3, 3.2, "high-high"),
    }
    for name, (h_ex, f_ex, label) in examples.items():
        got = ("high" if h_ex > 2.0 else "low") + "-" + ("high" if f_ex > 2.0 else "low")
        assert_condition(f"classification {name}", got == label, got)

    # One-dimensional reduction with separable density.
    # Marginal densities of a separable test density a(h, f) = a_h(h) * a_f(f).
    def a_h(h: float) -> float:
        return math.exp(-0.25 * (h - 1.0) ** 2)

    def a_f(f: float) -> float:
        return math.exp(-0.30 * (f - 0.5) ** 2)

    qh, qf = 0.8, 0.6
    product_w = integrate1d(lambda h: a_h(h) * G(qh - h), -8, 8, 2000) * integrate1d(
        lambda f: a_f(f) * G(qf - f), -8, 8, 2000
    )
    direct_w = integrate2d(lambda h, f: a_h(h) * a_f(f) * G(qh - h) * G(qf - f), -8, 8, -8, 8, 100, 100)
    assert_close("one-dimensional reduction for separable density", direct_w, product_w, 3e-4)

    # Cost-performance sorting.
    q_leader = (1.8, 1.8)
    q_follower = (1.8, 1.35)
    forgiving_region = (-1.0, 1.0, -2.0, 0.0)
    fragile_region = (0.0, 2.8, 1.1, 2.2)
    forgiving_loss = region_value(q_leader, forgiving_region) - region_value(q_follower, forgiving_region)
    fragile_loss = region_value(q_leader, fragile_region) - region_value(q_follower, fragile_region)
    assert_condition("performance loss nonnegative forgiving", forgiving_loss >= 0.0, f"loss={forgiving_loss:.10g}")
    assert_condition("performance loss nonnegative fragile", fragile_loss >= 0.0, f"loss={fragile_loss:.10g}")
    # NOTE(review): the next three checks hold for any nonnegative loss (the second
    # for any positive loss, the third for any positive region mass, up to rounding).
    # They illustrate the sorting rule with savings of 1.1x / 0.9x the loss rather than
    # testing the model itself -- confirm this is intended.
    assert_condition("follower wins if saving exceeds loss", 1.1 * forgiving_loss >= forgiving_loss, "saving above loss")
    assert_condition("frontier wins if saving below loss", 0.9 * fragile_loss < fragile_loss, "saving below loss")
    mass_r = region_mass(forgiving_region)
    assert_condition("uniform gap bound covers loss", (forgiving_loss / mass_r + 1e-6) * mass_r >= forgiving_loss, "bound ok")

    # Race decomposition and strategic interaction.
    qi = (1.0, 1.2)
    qj = (0.45, 0.65)
    di = (0.25, 0.12)
    dj = (0.10, 0.24)
    beta = (0.6, 0.9)
    for aj in (0, 1):
        assert_close(
            f"vector race decomposition a_j={aj}",
            race_pressure(qi, qj, di, dj, beta, aj),
            race_decomp(qi, qj, di, dj, beta, aj),
            1e-11,
        )

    p1 = race_pressure(qi, qj, di, dj, beta, 1)
    p0 = race_pressure(qi, qj, di, dj, beta, 0)
    r = dot(beta, sub(qi, qj))
    ai = dot(beta, di)
    # NOTE: rebinds the loop variable above; from here on aj is the rival's beta-weighted step.
    aj = dot(beta, dj)
    strategic_formula = (
        (share(r + ai - aj) - share(r + ai)) * A(qi, di)
        + (share(r + ai - aj) - share(r + ai) - share(r - aj) + share(r)) * W(qi)
    )
    assert_close("strategic interaction formula", p1 - p0, strategic_formula, 1e-11)

    # Dynamic pressure.
    # Current payoff gain (rival invests) plus the discounted continuation difference
    # must equal race_decomp + discounted continuation, minus the project cost.
    cost = 0.7
    discount = 0.92
    current_project = share(dot(beta, sub(add(qi, di), add(qj, dj)))) * W(add(qi, di)) - cost
    current_no = share(dot(beta, sub(qi, add(qj, dj)))) * W(qi)
    direct_dynamic = current_project - current_no + discount * (
        value(transition(qi, qj, di, dj)) - value(transition(qi, qj, (0.0, 0.0), dj))
    )
    phi = race_decomp(qi, qj, di, dj, beta, 1) + discount * (
        value(transition(qi, qj, di, dj)) - value(transition(qi, qj, (0.0, 0.0), dj))
    )
    assert_close("dynamic pressure equals payoff gain before cost", direct_dynamic, phi - cost, 1e-11)

    # Race regimes.
    # Each dict maps the rival's action (0 or 1) to that player's race pressure.
    K = 1.0
    assert_condition("no-race regime", is_equilibrium((0, 0), {0: 0.6, 1: 0.4}, {0: 0.7, 1: 0.5}, K), "(0,0)")
    assert_condition("leader-only regime", is_equilibrium((1, 0), {0: 1.2, 1: 0.9}, {0: 1.4, 1: 0.8}, K), "(1,0)")
    assert_condition("follower-only regime", is_equilibrium((0, 1), {0: 1.3, 1: 0.8}, {0: 1.2, 1: 0.9}, K), "(0,1)")
    assert_condition("mutual-race regime", is_equilibrium((1, 1), {0: 0.8, 1: 1.1}, {0: 0.7, 1: 1.2}, K), "(1,1)")

    # Data-distillation tipping threshold.
    rho_data = 0.20
    # NOTE: rebinds ``mass`` (previously the kernel integral) as the market mass.
    mass = 1.0
    kappa = 1.5
    lambda_d = 0.30
    lambda_m = 0.40
    trace_quality = 0.70
    q_d = (H_d(*tech), F_d(*tech))
    q_m = (H_m(*tech), F_m(*tech))
    chi_struct = lambda_d * dot(beta, q_d) + lambda_m * trace_quality * dot(beta, q_m)
    chi_more_trace = lambda_d * dot(beta, q_d) + lambda_m * (trace_quality + 0.10) * dot(beta, q_m)
    chi_more_distill = lambda_d * dot(beta, q_d) + (lambda_m + 0.10) * trace_quality * dot(beta, q_m)
    assert_condition("structural chi positive", chi_struct > 0.0, f"chi={chi_struct:.10g}")
    assert_condition("structural chi rises with trace usability", chi_more_trace > chi_struct, "d chi / d omega > 0")
    assert_condition("structural chi rises with distillation productivity", chi_more_distill > chi_struct, "d chi / d lambda_m > 0")
    zeta_long = lambda_d * (bh_long * q_d[0] + bf_long * q_d[1]) + lambda_m * trace_quality * (
        bh_long * q_m[0] + bf_long * q_m[1]
    )
    zeta_fragile = lambda_d * (bh_fragile * q_d[0] + bf_fragile * q_d[1]) + lambda_m * trace_quality * (
        bh_fragile * q_m[0] + bf_fragile * q_m[1]
    )
    zeta_more_trace = lambda_d * (bh_long * q_d[0] + bf_long * q_d[1]) + lambda_m * (trace_quality + 0.10) * (
        bh_long * q_m[0] + bf_long * q_m[1]
    )
    zeta_more_h_boundary = lambda_d * ((bh_long + 0.10) * q_d[0] + bf_long * q_d[1]) + lambda_m * trace_quality * (
        (bh_long + 0.10) * q_m[0] + bf_long * q_m[1]
    )
    assert_condition("productive data-distillation value positive", zeta_long > 0.0, f"zeta={zeta_long:.10g}")
    assert_condition("productive value rises with usable traces", zeta_more_trace > zeta_long, "d zeta / d omega > 0")
    assert_condition("productive value rises with boundary mass", zeta_more_h_boundary > zeta_long, "d zeta / d B_H > 0")
    # chi_low / chi_high sit 20% below / above the lock-in threshold 2 * rho / kappa (with mass = 1).
    threshold = 2.0 * rho_data / kappa
    chi_low = 0.80 * threshold
    chi_high = 1.20 * threshold
    lockin_low = chi_low * mass * kappa / 2.0 - rho_data
    lockin_high = chi_high * mass * kappa / 2.0 - rho_data
    learning_cost = 0.065
    assert_condition("region stable below lock-in index", lockin_low < 0.0, f"L={lockin_low:.10g}")
    assert_condition("region unstable above lock-in index", lockin_high > 0.0, f"L={lockin_high:.10g}")
    assert_condition("productive tipping classification", lockin_high > 0.0 and zeta_long > learning_cost, "lock-in and productive learning")
    assert_condition("pure share lock-in classification", lockin_high > 0.0 and zeta_fragile <= learning_cost, "lock-in without enough frontier value")
    assert_condition("productive learning without tipping classification", lockin_low < 0.0 and zeta_long > learning_cost, "learning without local lock-in")
    # Slope of data_tipping_map at gap 0, since G'(0) = 1/4 gives d/dgap [2 G(kappa gap) - 1] = kappa / 2.
    small_gap = 1e-5
    derivative_low = 1.0 - rho_data + chi_low * mass * kappa / 2.0
    derivative_high = 1.0 - rho_data + chi_high * mass * kappa / 2.0
    assert_condition("data flywheel stable below threshold", derivative_low < 1.0, f"derivative={derivative_low:.10g}")
    assert_condition("data flywheel unstable above threshold", derivative_high > 1.0, f"derivative={derivative_high:.10g}")
    assert_condition(
        "small gap shrinks below data threshold",
        data_tipping_map(small_gap, rho_data, chi_low, mass, kappa) / small_gap < 1.0,
        "gap contracts",
    )
    assert_condition(
        "small gap expands above data threshold",
        data_tipping_map(small_gap, rho_data, chi_high, mass, kappa) / small_gap > 1.0,
        "gap expands",
    )

    # Tipping as market-relevant vector gap drift.
    x = 2.5
    rho = 0.20
    alpha_l = 0.75
    alpha_f = 0.40
    p_l = 0.8
    p_f = 0.25
    drift = -rho * x + alpha_l * p_l - alpha_f * p_f
    direct_drift = ((1.0 - rho) * x + alpha_l * p_l - alpha_f * p_f) - x
    assert_close("vector gap drift identity", drift, direct_drift, 1e-14)
    assert_condition("leader-only tipping condition", alpha_l > rho * x, f"alpha_L={alpha_l:.4g}, rho*x={rho*x:.4g}")
    assert_close("mutual equal-step racing compresses gap", -rho * x, ((1 - rho) * x + 0.5 - 0.5) - x, 1e-14)

    # Stopping bound and private persistence.
    # bbar and wbar are BH+BF and W at qi, each padded by 0.1; l_sigma and l_v are
    # assumed Lipschitz-style constants for the share and continuation terms.
    tiny = (1e-5, 2e-5)
    bbar = BH(qi) + BF(qi) + 0.1
    wbar = W(qi) + 0.1
    l_sigma = 0.25
    beta_inf = max(abs(beta[0]), abs(beta[1]))
    l_v = 0.8
    bound = norm1(tiny) * (bbar + l_sigma * beta_inf * wbar + discount * l_v)
    actual = A(qi, tiny) + l_sigma * beta_inf * norm1(tiny) * W(qi) + discount * l_v * norm1(tiny)
    assert_condition("stopping pressure bound", actual <= bound, f"actual={actual:.10g}, bound={bound:.10g}")
    rent_only = (share(2.0 + 0.2) - share(2.0)) * 500.0
    assert_condition("private rent race can persist with tiny A", rent_only > 5.0, f"rent={rent_only:.10g}")

    # Welfare frontier.
    # qL dominates qF in both dimensions, so the social frontier tracks the leader alone.
    qL = (2.0, 2.0)
    qF = (1.2, 1.3)
    assert_close("dominated social frontier equals leader W", social_frontier(qL, qF), W(qL), 4e-5)
    leader_step = (0.15, 0.10)
    assert_close("leader social gain equals A", social_frontier(add(qL, leader_step), qF) - social_frontier(qL, qF), A(qL, leader_step), 5e-5)
    follower_step = (0.3, 0.2)
    assert_close("non-frontier follower social gain is zero", social_frontier(qL, add(qF, follower_step)) - social_frontier(qL, qF), 0.0, 4e-5)
    follower_private = race_decomp(qF, qL, follower_step, (0.0, 0.0), beta, 0)
    assert_condition("non-frontier follower private gain can be positive", follower_private > 0.0, f"private={follower_private:.10g}")

    # Any failed check raised above, so the fail count is necessarily zero here.
    print()
    print(f"TOTAL: {PASSED} pass, 0 fail")


if __name__ == "__main__":
    # Run the full check suite only when executed as a script; importing the module just defines the helpers.
    main()
