"""Exact Bernstein certificate for the q=3 binary-tree AF Potts zero-temperature inequality.

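It certifies the part of the domain not handled by the analytic argument for
    W(Psi(x,y)) <= (W(x)+W(y))/2,  x,y in Delta_3,
where Psi(x,y)_i = (1-x_i)(1-y_i) / sum_j (1-x_j)(1-y_j) and
W(p) = sum_i (p_i-1/3)^2 - (3/4) sum_i (p_i-1/3)^3.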

The small central box
    C0 = {p in Delta_3 : 1/4 <= p_i <= 5/12 for i=1,2,3}
is handled by the analytic Fourier/Taylor lemma in the replacement proof.
This script proves the inequality on the complement of C0 x C0 using exact
rational Bernstein coefficients.

All arithmetic used for certification is exact (SymPy rationals, Fraction and
Python int).  Floating point appears only in the elapsed-time report.
"""
from fractions import Fraction
from math import comb, lcm
import time
from functools import lru_cache
import numpy as np
import sympy as sp

# ---------------------------- polynomial ---------------------------------
x1, x2, y1, y2 = sp.symbols("x1 x2 y1 y2")
# Points of Delta_3 in the affine chart (p1, p2) -> (p1, p2, 1 - p1 - p2).
x = [x1, x2, 1 - x1 - x2]
y = [y1, y2, 1 - y1 - y2]
# Normalizer of Psi.  Expanding with sum_i x_i = sum_i y_i = 1 gives
# S = 1 + sum_i x_i*y_i >= 1 on Delta_3 x Delta_3, so multiplying the
# inequality by powers of S does not change its sign.
S = sum((1 - x[i]) * (1 - y[i]) for i in range(3))
# z = Psi(x, y): componentwise (1 - x_i)(1 - y_i), renormalized to sum to 1.
z = [(1 - x[i]) * (1 - y[i]) / S for i in range(3)]

def W(p):
    """Return sum_i (p_i - 1/3)^2 - (3/4) * sum_i (p_i - 1/3)^3 (SymPy arithmetic).

    ``p`` is iterated twice, once for each power sum.
    """
    third = sp.Rational(1, 3)
    quadratic_part = sum((pi - third) ** 2 for pi in p)
    cubic_part = sum((pi - third) ** 3 for pi in p)
    return quadratic_part - sp.Rational(3, 4) * cubic_part

G = sp.Rational(1, 2) * W(x) + sp.Rational(1, 2) * W(y) - W(z)
# G >= 0 is the target inequality.  Multiply by 8*S**3 to clear the
# denominators introduced by Psi, then split the combined fraction into
# numerator and denominator; P is the numerator polynomial.
P_expr, _P_denom = sp.together(8 * S**3 * G).as_numer_denom()
# Soundness guard: the certificate proves P >= 0 and concludes G >= 0, which
# requires the discarded denominator to be a positive constant (the S**3
# factor is harmless because S = 1 + sum_i x_i*y_i >= 1 on Delta_3 x Delta_3).
# Raised explicitly rather than via `assert` so it still runs under `python -O`.
if not (_P_denom.is_number and _P_denom.is_positive):
    raise AssertionError(
        ("denominator of 8*S**3*G is not a positive constant", _P_denom)
    )
P = sp.Poly(sp.expand(P_expr), x1, x2, y1, y2)

def sp_to_fraction(q):
    """Convert a SymPy rational (or anything ``sp.Rational`` accepts) to an exact Fraction."""
    numerator, denominator = sp.Rational(q).as_numer_denom()
    return Fraction(int(numerator), int(denominator))

terms = [((a, b, c, d), sp_to_fraction(coeff)) for (a, b, c, d), coeff in P.terms()]

# ----------------------- Bernstein conversion -----------------------------
# Bernstein degree used on every triangle (for both the x- and y-factor).
n = 6
# Multi-indices (i, j) with i + j <= n.  With the factors below, (i, j) is the
# basis polynomial n!/(i! j! (n-i-j)!) * u^i v^j (1-u-v)^(n-i-j).
basis = [(i, j) for i in range(n + 1) for j in range(n + 1 - i)]
nb = len(basis)
# bern_factors[(a, b)] lists (basis index, coefficient) pairs expressing
#   u^a v^b = sum_{i>=a, j>=b, i+j<=n} C(i,a) C(j,b) / multinomial(n; a, b, n-a-b) * B_(i,j).
bern_factors = {}
for a in range(n + 1):
    for b in range(n + 1 - a):
        den = comb(n, a + b) * comb(a + b, a)  # crucial multinomial factor
        vals = []
        for idx, (i, j) in enumerate(basis):
            if a <= i and b <= j:
                vals.append((idx, Fraction(comb(i, a) * comb(j, b), den)))
        bern_factors[(a, b)] = vals

# Every monomial of P must be representable in the degree-n basis on each
# side; otherwise the lookup in bern_factors would fail with a bare KeyError.
_too_high = [e for e, _ in terms if e[0] + e[1] > n or e[2] + e[3] > n]
if _too_high:
    raise AssertionError(("monomial degree exceeds Bernstein degree", n, _too_high[:5]))

x_exps = sorted({(a, b) for (a, b, c, d), _ in terms})
y_exps = sorted({(c, d) for (a, b, c, d), _ in terms})
x_index = {e: i for i, e in enumerate(x_exps)}
y_index = {e: i for i, e in enumerate(y_exps)}

# Coefficient matrix: P = sum_{r,s} C[r][s] * x^{x_exps[r]} * y^{y_exps[s]}.
C = [[Fraction(0) for _ in y_exps] for __ in x_exps]
for (a, b, c, d), coeff in terms:
    C[x_index[(a, b)]][y_index[(c, d)]] += coeff
# This check is load-bearing: int(v) below would silently truncate a
# non-integer coefficient.  Raise explicitly so it survives `python -O`.
if not all(v.denominator == 1 for row in C for v in row):
    raise AssertionError("P has a non-integer coefficient; int() would truncate it")
C_int = np.array([[int(v) for v in row] for row in C], dtype=object)

# -------------------------- exact utilities --------------------------------
def poly_mul(p, q):
    """Multiply two sparse bivariate polynomials.

    Polynomials are dicts mapping an exponent pair (i, j) to a coefficient.
    The product is returned in the same form, with zero coefficients dropped.
    """
    product = {}
    for (i1, j1), c1 in p.items():
        for (i2, j2), c2 in q.items():
            key = (i1 + i2, j1 + j2)
            product[key] = product.get(key, Fraction(0)) + c1 * c2
    return {key: val for key, val in product.items() if val}

def poly_pow_linear(a0, a1, a2, power):
    """Return (a0 + a1*u + a2*v)**power as a sparse dict {(i, j): coeff of u^i v^j}.

    Zero coefficients are omitted from the linear factor.  ``power == 0``
    yields the constant polynomial 1.
    """
    linear = {}
    for key, coef in (((0, 0), a0), ((1, 0), a1), ((0, 1), a2)):
        if coef:
            linear[key] = coef
    result = {(0, 0): Fraction(1)}
    for _ in range(power):
        result = poly_mul(result, linear)
    return result

def power_to_bernstein(poly):
    """Convert a sparse power-basis polynomial in (u, v) to degree-n Bernstein coefficients.

    The result is a list aligned with ``basis``.  Each monomial u^a v^b is
    expanded using the precomputed ``bern_factors[(a, b)]``.
    """
    coeffs = [Fraction(0)] * len(basis)
    for exponent, c in poly.items():
        for position, factor in bern_factors[exponent]:
            coeffs[position] += c * factor
    return coeffs

def monomial_pullback_bernstein(tri, exp):
    """Bernstein coefficients of p[0]**exp[0] * p[1]**exp[1] restricted to ``tri``.

    A point of the triangle (v0, v1, v2) is parametrized as
    v0 + u*(v1 - v0) + v*(v2 - v0).  Each coordinate is then affine in (u, v),
    so the monomial is a product of two powers of affine forms.  That product
    is converted to the degree-n Bernstein basis in (u, v).
    """
    base, first, second = tri
    coord0_power = poly_pow_linear(
        base[0], first[0] - base[0], second[0] - base[0], exp[0]
    )
    coord1_power = poly_pow_linear(
        base[1], first[1] - base[1], second[1] - base[1], exp[1]
    )
    return power_to_bernstein(poly_mul(coord0_power, coord1_power))

@lru_cache(maxsize=None)
def B_int_matrix_for_triangle_cached(tri, exps_tuple):
    """Integer-scaled monomial-to-Bernstein matrix for one triangle (memoized).

    Column k holds the Bernstein coefficients (rows follow ``basis``) of the
    monomial ``exps_tuple[k]`` restricted to ``tri``.  All entries are
    multiplied by one common positive integer (the lcm of their
    denominators), so the matrix is exact integers with the sign pattern of
    the rational one.
    """
    columns = [monomial_pullback_bernstein(tri, e) for e in exps_tuple]
    scale = 1
    for column in columns:
        for entry in column:
            scale = lcm(scale, entry.denominator)
    matrix = np.zeros((nb, len(columns)), dtype=object)
    for col, column in enumerate(columns):
        for row, entry in enumerate(column):
            matrix[row, col] = entry.numerator * (scale // entry.denominator)
    return matrix

def B_int_matrix_for_triangle(tri, exps):
    """Memoized wrapper; ``exps`` is turned into a tuple so it can serve as a cache key."""
    cache_key = tuple(exps)
    return B_int_matrix_for_triangle_cached(tri, cache_key)

def check_product_cell(tri_x, tri_y):
    """Scan P's tensor Bernstein coefficients on the product cell tri_x x tri_y.

    The coefficients are the entries of Bx @ C_int @ By.T.  Bx and By are the
    integer-scaled conversion matrices for the x- and y-triangles.  They are
    scaled by positive integers, so the signs are those of the exact
    coefficients.

    Returns:
        (number of negative coefficients, minimum scaled coefficient)
    """
    left = B_int_matrix_for_triangle(tri_x, x_exps) @ C_int
    right = B_int_matrix_for_triangle(tri_y, y_exps)
    values = list((left @ right.T).ravel())
    negatives = sum(1 for v in values if v < 0)
    return negatives, min(values)

# -------------------------- triangulations ---------------------------------
def simplex_tris(N):
    """Uniform triangulation of {(p1, p2) : p1, p2 >= 0, p1 + p2 <= 1} into N**2 triangles.

    Each grid square (i, j) gives its lower-left triangle.  It also gives its
    upper-right triangle when that triangle lies inside the simplex.  Vertices
    are exact Fractions.
    """
    def grid_point(i, j):
        return (Fraction(i, N), Fraction(j, N))

    pieces = []
    for i in range(N):
        for j in range(N - i):
            pieces.append((grid_point(i, j), grid_point(i + 1, j), grid_point(i, j + 1)))
            if i + j < N - 1:
                pieces.append(
                    (grid_point(i + 1, j), grid_point(i + 1, j + 1), grid_point(i, j + 1))
                )
    return pieces

def tri_inside_box(tri, lo, hi):
    """True iff each vertex (p1, p2) has all three barycentric coordinates
    (p1, p2, 1 - p1 - p2) in the closed interval [lo, hi]."""
    for vertex in tri:
        bary = (vertex[0], vertex[1], Fraction(1) - vertex[0] - vertex[1])
        if any(c < lo or c > hi for c in bary):
            return False
    return True

def split_tri4(tri):
    """Midpoint subdivision of a triangle into four.

    Returns the three corner triangles (at v0, v1, v2), followed by the
    central triangle formed by the edge midpoints.
    """
    a, b, c = tri
    ab = ((a[0] + b[0]) / 2, (a[1] + b[1]) / 2)
    bc = ((b[0] + c[0]) / 2, (b[1] + c[1]) / 2)
    ca = ((c[0] + a[0]) / 2, (c[1] + a[1]) / 2)
    corners = [(a, ab, ca), (ab, b, bc), (ca, bc, c)]
    return corners + [(ab, bc, ca)]

# ----------------------------- certificate ---------------------------------
def _require_nonnegative(nneg, context):
    """Raise ``AssertionError(context)`` unless ``nneg`` (a count of negative coefficients) is 0.

    This is used instead of a bare ``assert``.  Asserts are stripped under
    ``python -O``, which would let the script print "PASS" without enforcing
    any check.  The exception type and payload match the former asserts.
    """
    if nneg != 0:
        raise AssertionError(context)


def _running_min(current, value):
    """Return ``min(current, value)``, treating ``current is None`` as "no value yet"."""
    return value if current is None else min(current, value)


def main():
    """Run the three-level Bernstein certificate and print a summary.

    Level 1: on the N=6 triangulation, every product cell not contained in
    big x big (big = {1/6 <= p_i <= 1/2}) must have no negative scaled
    Bernstein coefficient.
    Level 2: each big triangle is split into four.  Product cells not
    contained in small x small (small = C0) are checked, and cells with
    negative coefficients are collected.
    Level 3: each collected cell is split 4 x 4.  Every sub-cell not
    contained in small x small must have no negative coefficient.

    Raises:
        AssertionError: if a level-1 or level-3 cell has a negative
            coefficient, i.e. the certificate fails.
    """
    t0 = time.perf_counter()
    big_lo, big_hi = Fraction(1, 6), Fraction(1, 2)
    small_lo, small_hi = Fraction(1, 4), Fraction(5, 12)

    # Level 1: N=6 triangulation.  Certify all cells outside the larger central box.
    tris6 = simplex_tris(6)
    big = [tri_inside_box(T, big_lo, big_hi) for T in tris6]
    # Left factors use x_exps and right factors use y_exps, matching the
    # row/column indexing of C_int (as check_product_cell does).  Reusing the
    # x-side matrices on the right is valid only when x_exps == y_exps.
    L6 = [B_int_matrix_for_triangle(T, x_exps) @ C_int for T in tris6]
    R6 = [B_int_matrix_for_triangle(T, y_exps) for T in tris6]

    outside_checked = 0
    outside_negative = 0
    outside_min = None
    for ix, L in enumerate(L6):
        for iy, By in enumerate(R6):
            if big[ix] and big[iy]:
                continue  # inside big x big: handled by levels 2 and 3
            flat = list((L @ By.T).ravel())
            mn = min(flat)
            nneg = sum(1 for v in flat if v < 0)
            outside_min = _running_min(outside_min, mn)
            outside_negative += nneg
            outside_checked += 1
            _require_nonnegative(
                nneg, ("outside big box failed", ix, iy, nneg, mn)
            )

    # Level 2: inside the larger central box, split its six triangles once.
    big_tris = [T for T, flag in zip(tris6, big) if flag]
    subtris = [piece for T in big_tris for piece in split_tri4(T)]
    small = [tri_inside_box(T, small_lo, small_hi) for T in subtris]

    annulus_checked = 0
    annulus_negative = 0
    annulus_min = None
    bad_annulus = []
    for ix, tx in enumerate(subtris):
        for iy, ty in enumerate(subtris):
            if small[ix] and small[iy]:
                continue  # analytic central box
            nneg, mn = check_product_cell(tx, ty)
            annulus_checked += 1
            annulus_min = _running_min(annulus_min, mn)
            if nneg:
                # Not a failure yet: level 3 refines these cells.
                annulus_negative += nneg
                bad_annulus.append((ix, iy, tx, ty, nneg, mn))

    # Level 3: split the finitely many bad annulus cells once more.
    split_checked = 0
    split_negative = 0
    split_min = None
    for ix, iy, tx, ty, _, _ in bad_annulus:
        for sx in split_tri4(tx):
            for sy in split_tri4(ty):
                if (tri_inside_box(sx, small_lo, small_hi)
                        and tri_inside_box(sy, small_lo, small_hi)):
                    continue  # analytic central box
                nneg, mn = check_product_cell(sx, sy)
                split_checked += 1
                split_min = _running_min(split_min, mn)
                split_negative += nneg
                _require_nonnegative(
                    nneg, ("split annulus failed", ix, iy, sx, sy, nneg, mn)
                )

    print("polynomial terms:", len(terms))
    print("Bernstein degree per simplex:", n)
    print("level-1 triangles:", len(tris6))
    print("large central triangles:", sum(big))
    print("small analytic triangles after split:", sum(small))
    print("outside-large-box product cells checked:", outside_checked)
    print("outside-large-box negative Bernstein coefficients:", outside_negative)
    print("outside-large-box minimum scaled coefficient:", outside_min)
    print("annulus product cells checked before final split:", annulus_checked)
    print("annulus bad product cells requiring final split:", len(bad_annulus))
    print("annulus negative Bernstein coefficients before final split:", annulus_negative)
    print("annulus minimum scaled coefficient before final split:", annulus_min)
    print("final split product cells checked:", split_checked)
    print("final split negative Bernstein coefficients:", split_negative)
    print("final split minimum scaled coefficient:", split_min)
    # Reached only if every _require_nonnegative check above passed.
    print("certificate status: PASS")
    print("elapsed seconds: %.2f" % (time.perf_counter() - t0))

# Run the certificate only when executed as a script.  Importing the module
# still builds P and the Bernstein tables defined at module level.
if __name__ == "__main__":
    main()
