# rlhf_canonicalizer_demo.py
# -----------------------------------------------------------
# Minimal, runnable demo of the canonicalizer/tester described in the appendix.
# - Encodes 10 RLHF objectives as "ladders" (Add, Rew, Link)
# - Canonicalizes reducible objectives (hash) and emits witnesses for irreducible ones
# - Prints results tables; optionally saves CSVs
#
# Requires: Python 3.9+, pandas
#   pip install pandas
#
# Run:
#   python rlhf_canonicalizer_demo.py
#
# Notes:
# - No GPUs, datasets, or internet needed.
# - Deterministic serialization (ASCII JSON with a fixed float format) for the demo;
#   in production, use IEEE754 binary64 little-endian bytes as in the spec.

from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple, Union
import hashlib
import json
import itertools
import pandas as pd

# ------------------------------
# Toy domain (one instance, three candidates)
# ------------------------------
# One instance ("x1") whose candidate completions are "a", "b" and "c".
X_DOMAIN: List[str] = ["x1"]
Y_DOMAIN_BY_X: Dict[str, List[str]] = dict(x1=["a", "b", "c"])

def all_pairs_for_x(x: str) -> List[Tuple[str, str]]:
    """Return every ordered pair (y, z) of candidates of ``x`` with y != z.

    Pairs are listed with the first element varying slowest, following the
    order of ``Y_DOMAIN_BY_X[x]``.
    """
    candidates = Y_DOMAIN_BY_X[x]
    return [pair for pair in itertools.product(candidates, repeat=2) if pair[0] != pair[1]]

def all_triples_for_x(x: str) -> List[Tuple[str, str, str]]:
    """Return every ordered triple of pairwise-distinct candidates of ``x``.

    Candidates are taken in order of first appearance in ``Y_DOMAIN_BY_X[x]``
    (a repeated label counts once), and triples are listed with the first
    position varying slowest.
    """
    # dict.fromkeys drops repeated labels while keeping first-appearance order.
    distinct = dict.fromkeys(Y_DOMAIN_BY_X[x])
    return list(itertools.permutations(distinct, 3))

# ------------------------------
# Operators and ladder
# ------------------------------
@dataclass
class Add:
    """Additive operator.

    Provide either ``phi``, a table mapping (x, y) to phi(x, y), or
    ``delta_phi``, a callable (x, y, z) -> value of the difference on the
    pair (y, z). Do not provide both.
    """

    phi: Union[Dict[Tuple[str, str], float], None] = None
    delta_phi: Union[Callable[[str, str, str], float], None] = None

@dataclass
class Rew:
    """Weighting operator.

    Provide either ``s_of_x``, a weight that depends on the instance only, or
    ``omega``, a general weight omega(x, y, z) that may depend on the pair.
    """

    s_of_x: Union[Callable[[str], float], None] = None
    omega: Union[Callable[[str, str, str], float], None] = None

@dataclass
class Link:
    """Link operator: a named output function together with a scale ``beta``.

    ``monotone`` declares whether the link is strictly increasing.
    """

    g_name: str  # e.g. "logistic", "identity"
    beta: float = 1.0
    monotone: bool = True

@dataclass
class Ladder:
    """Ordered operator sequence; ``canonicalize`` folds ``ops`` front to back."""

    ops: List[Union[Add, Rew, Link]]

# ------------------------------
# Certificates and witnesses
# ------------------------------
@dataclass
class Witness:
    """Finite evidence that a ladder is irreducible.

    ``type`` names the kind of evidence ("weight_nonconstant",
    "cocycle_violation" or "link_nonmonotone"). The remaining fields are
    optional and filled only when relevant to that kind.
    """

    type: str
    x: Union[str, None] = None
    pairs: Union[List[Tuple[str, str]], None] = None
    triple: Union[Tuple[str, str, str], None] = None
    values: Union[Dict[str, float], None] = None
    message: Union[str, None] = None

@dataclass
class Certificate:
    """Outcome of canonicalizing a ladder.

    ``verdict`` is "reducible" or "irreducible". A reducible certificate
    carries the hash, rewrite ledger and hex serialization; an irreducible
    one carries a witness.
    """

    verdict: str
    canon_hash: Union[str, None] = None
    rewrite_ledger: Union[List[str], None] = None
    witness: Union[Witness, None] = None
    serialization_hex: Union[str, None] = None

# ------------------------------
# Canonicalizer / tester core
# ------------------------------
def gauge_fix_phi(phi_map: Dict[Tuple[str, str], float]) -> Dict[Tuple[str, str], float]:
    """Return a copy of ``phi_map`` shifted to zero mean per instance.

    For each x in X_DOMAIN, the mean of phi(x, y) over the candidates of x is
    subtracted from each candidate's entry, and every candidate key is
    written. Missing entries count as 0.0. Entries for instances outside
    X_DOMAIN are copied unchanged. The input mapping is not modified.
    """
    fixed = dict(phi_map)
    for x in X_DOMAIN:
        candidates = Y_DOMAIN_BY_X[x]
        if not candidates:
            continue
        mean = sum(fixed.get((x, y), 0.0) for y in candidates) / len(candidates)
        for y in candidates:
            fixed[(x, y)] = fixed.get((x, y), 0.0) - mean
    return fixed

def serialize_canonical(phi_gauge: Dict[Tuple[str, str], float],
                        s_map: Dict[str, float],
                        g_name: str,
                        zero_tol: float = 1e-12) -> bytes:
    """Serialize a canonical form to deterministic ASCII JSON bytes.

    Demo format: JSON with sorted keys and compact separators, with every
    float rendered as ``.12g`` text. For production, replace this with
    binary IEEE754 little-endian.

    Near-zero potentials are snapped to exactly 0.0 before formatting.
    Gauge fixing leaves rounding residues such as -5.55e-17 where the exact
    value is 0, and ``.12g`` prints such residues in full, so two
    gauge-equivalent potentials could otherwise hash differently. Signed
    zeros are normalized (-0.0 -> 0.0) for both phi and s.

    Args:
        phi_gauge: gauge-fixed potential keyed by (x, y). Candidates missing
            from the map are written as 0.
        s_map: per-instance scale. Every key is written.
        g_name: link name, recorded verbatim.
        zero_tol: a potential with |phi| <= zero_tol * max(1, max |phi|) is
            written as 0.

    Returns:
        The serialized bytes.
    """
    def float_fmt(v: float, snap: float = 0.0) -> str:
        v = float(v)
        if abs(v) <= snap:
            # Collapses rounding residues; with snap == 0 it still maps -0.0 to 0.0.
            v = 0.0
        return f"{v:.12g}"

    phi_vals = {f"{x}:{y}": float(phi_gauge.get((x, y), 0.0))
                for x in sorted(Y_DOMAIN_BY_X.keys())
                for y in sorted(Y_DOMAIN_BY_X[x])}
    # Scale the snap threshold with the potential's magnitude, since residues scale with it.
    phi_snap = zero_tol * max([1.0] + [abs(v) for v in phi_vals.values()])

    serial = {
        "header": "GKPOv1-demo",
        "g_name": g_name,
        "beta": 1.0,
        "s": {x: float_fmt(s_map.get(x, 1.0)) for x in sorted(s_map.keys())},
        "phi": {key: float_fmt(v, phi_snap) for key, v in phi_vals.items()},
        "footer": "END"
    }
    return json.dumps(serial, sort_keys=True, separators=(",", ":")).encode("ascii")

def sha256_hex(b: bytes) -> str:
    """Return the SHA-256 digest of ``b`` as a lowercase hex string."""
    digest = hashlib.sha256(b)
    return digest.hexdigest()

def check_cocycle_from_phi(phi_map: Dict[Tuple[str, str], float],
                           tol: float = 1e-12) -> Optional[Witness]:
    """Check that differences of a potential close around every triangle.

    For each instance x and each ordered triple (a, b, c) of distinct
    candidates, this computes dphi(a, b) + dphi(b, c) + dphi(c, a), where
    dphi(y, z) = phi(x, y) - phi(x, z) and missing entries count as 0.0.

    The sum telescopes to exactly 0, so a nonzero result can only come from
    floating-point rounding, which grows with the magnitude of phi. The
    tolerance is therefore relative: tol * max(1, |phi_a|, |phi_b|, |phi_c|).
    An absolute 1e-12 would misreport large potentials (|phi| ~ 1e4) as
    violations.

    Returns:
        A ``cocycle_violation`` witness for the first failing triple, or None.
    """
    for x in X_DOMAIN:
        for a, b, c in all_triples_for_x(x):
            pa = phi_map.get((x, a), 0.0)
            pb = phi_map.get((x, b), 0.0)
            pc = phi_map.get((x, c), 0.0)
            cyc = (pa - pb) + (pb - pc) + (pc - pa)
            scale = max(1.0, abs(pa), abs(pb), abs(pc))
            if abs(cyc) > tol * scale:
                return Witness(
                    type="cocycle_violation", x=x, triple=(a, b, c),
                    values={"cycle": float(cyc)}, message="triangle cycle sum nonzero from phi"
                )
    return None

def check_cocycle_from_delta(delta_phi: Callable[[str, str, str], float],
                             tol: float = 1e-12) -> Optional[Witness]:
    """Check that a per-pair difference sums to zero around every triangle.

    For each instance x and each ordered triple (a, b, c) of distinct
    candidates, this computes delta(a, b) + delta(b, c) + delta(c, a), with
    delta(y, z) = delta_phi(x, y, z).

    The tolerance is relative, tol * max(1, |each term|), because the
    rounding error of the sum grows with the terms' magnitude. Only
    triangles are examined here; antisymmetry (delta(y, z) == -delta(z, y))
    is not checked.

    Returns:
        A ``cocycle_violation`` witness for the first failing triple, or None.
    """
    for x in X_DOMAIN:
        for a, b, c in all_triples_for_x(x):
            d_ab = float(delta_phi(x, a, b))
            d_bc = float(delta_phi(x, b, c))
            d_ca = float(delta_phi(x, c, a))
            cyc = d_ab + d_bc + d_ca
            scale = max(1.0, abs(d_ab), abs(d_bc), abs(d_ca))
            if abs(cyc) > tol * scale:
                return Witness(
                    type="cocycle_violation", x=x, triple=(a, b, c),
                    values={"cycle": float(cyc)}, message="triangle cycle sum nonzero from additive component"
                )
    return None

def _fold_rew(op: Rew, s_map: Dict[str, float],
              tol_w: float) -> Optional[Witness]:
    """Multiply a Rew's weight into ``s_map`` in place, or return a witness.

    A pair weight ``omega`` must be constant (within ``tol_w``) over every
    ordered candidate pair of each instance; it may still differ between
    instances. ``s_map`` is left untouched when a witness is returned.
    """
    if op.omega is not None:
        per_x: Dict[str, float] = {}
        for x in X_DOMAIN:
            pairs = all_pairs_for_x(x)
            if not pairs:
                continue  # fewer than two distinct candidates: no pair to weight
            ref_pair = pairs[0]
            w_ref = float(op.omega(x, *ref_pair))
            # Compare every pair, not just two: a weight that differs on one pair only
            # (e.g. (b, c)) must still be refuted.
            for pair in pairs[1:]:
                w = float(op.omega(x, *pair))
                if abs(w - w_ref) > tol_w:
                    return Witness(
                        type="weight_nonconstant",
                        x=x, pairs=[ref_pair, pair],
                        values={"omega1": w_ref, "omega2": w},
                        message="weight depends on pair"
                    )
            per_x[x] = w_ref
        for x, w in per_x.items():
            s_map[x] *= w
    elif op.s_of_x is not None:
        for x in X_DOMAIN:
            s_map[x] *= float(op.s_of_x(x))
    return None


def _check_antisymmetry(delta_phi: Callable[[str, str, str], float],
                        tol: float) -> Optional[Witness]:
    """Return a witness if delta_phi(x, y, z) + delta_phi(x, z, y) is nonzero for some pair.

    The triangle check does not cover this two-cycle condition. A per-pair
    difference can only come from a potential if the condition holds.
    """
    for x in X_DOMAIN:
        for y, z in all_pairs_for_x(x):
            fwd = float(delta_phi(x, y, z))
            back = float(delta_phi(x, z, y))
            two_cycle = fwd + back
            if abs(two_cycle) > tol * max(1.0, abs(fwd), abs(back)):
                return Witness(
                    type="cocycle_violation", x=x, pairs=[(y, z), (z, y)],
                    values={"cycle": two_cycle},
                    message="two-cycle sum nonzero from additive component"
                )
    return None


def _potential_from_delta(delta_phi: Callable[[str, str, str], float]
                          ) -> Dict[Tuple[str, str], float]:
    """Integrate an exact delta_phi into a potential.

    The result satisfies phi(x, y) - phi(x, z) == delta_phi(x, y, z) up to the
    checking tolerance. The first candidate of each instance is anchored at
    0; the per-instance constant this introduces is removed later by the
    zero-mean gauge fix.
    """
    phi: Dict[Tuple[str, str], float] = {}
    for x in X_DOMAIN:
        ys = Y_DOMAIN_BY_X[x]
        if not ys:
            continue
        anchor = ys[0]
        for y in ys:
            phi[(x, y)] = 0.0 if y == anchor else float(delta_phi(x, y, anchor))
    return phi


def _fold_add(op: Add, phi_acc: Dict[Tuple[str, str], float],
              tol_c: float) -> Optional[Witness]:
    """Merge an Add into ``phi_acc`` in place, or return a witness if it is not exact.

    Raises:
        ValueError: if both ``phi`` and ``delta_phi`` are set.
    """
    if op.phi is not None and op.delta_phi is not None:
        raise ValueError("Add takes either phi or delta_phi, not both")
    if op.delta_phi is not None:
        witness = check_cocycle_from_delta(op.delta_phi, tol=tol_c)
        if witness is None:
            witness = _check_antisymmetry(op.delta_phi, tol_c)
        if witness is not None:
            return witness
        # An exact delta must still contribute to the canonical form, so
        # integrate it into a potential instead of dropping it.
        increments = _potential_from_delta(op.delta_phi)
    elif op.phi is not None:
        increments = op.phi
    else:
        return None  # an empty Add contributes nothing
    for (x, y), v in increments.items():
        phi_acc[(x, y)] = phi_acc.get((x, y), 0.0) + float(v)
    return check_cocycle_from_phi(phi_acc, tol=tol_c)


def canonicalize(ladder: Ladder,
                 tol_w: float = 1e-9,
                 tol_c: float = 1e-12) -> Certificate:
    """Reduce a ladder to a canonical (scale, potential, link) form, or refute it.

    Operators are folded in order:

    * ``Rew`` multiplies the per-instance scale s(x). A pair weight ``omega``
      must be constant over all candidate pairs of each instance (within
      ``tol_w``); otherwise a ``weight_nonconstant`` witness is returned.
    * ``Add`` accumulates a potential phi(x, y). A ``delta_phi`` must pass the
      triangle and two-cycle checks (within ``tol_c``, scaled by term
      magnitude) and is then integrated into a potential; otherwise a
      ``cocycle_violation`` witness is returned.
    * ``Link`` must be flagged monotone. Its ``beta`` is multiplied into s(x),
      and its name becomes the canonical link (the last Link wins).

    The accumulated potential is then gauge-fixed to zero mean per instance,
    serialized and hashed.

    Args:
        ladder: operators to fold.
        tol_w: absolute tolerance when comparing pair weights.
        tol_c: relative tolerance for cycle sums.

    Returns:
        A "reducible" certificate carrying the hash, rewrite ledger and hex
        serialization, or an "irreducible" certificate carrying the first
        witness found.

    Raises:
        ValueError: on an unknown operator type, or an ``Add`` that sets both
            ``phi`` and ``delta_phi``.
    """
    phi_acc: Dict[Tuple[str, str], float] = {}
    s_map: Dict[str, float] = {x: 1.0 for x in X_DOMAIN}
    g_name: str = "identity"
    ledger: List[str] = []

    for op in ladder.ops:
        if isinstance(op, Rew):
            witness = _fold_rew(op, s_map, tol_w)
            if witness is not None:
                return Certificate(verdict="irreducible", witness=witness)
            ledger.append("merge_rew")

        elif isinstance(op, Add):
            witness = _fold_add(op, phi_acc, tol_c)
            if witness is not None:
                return Certificate(verdict="irreducible", witness=witness)
            if op.phi is not None or op.delta_phi is not None:
                ledger.append("merge_add")

        elif isinstance(op, Link):
            if not op.monotone:
                return Certificate(
                    verdict="irreducible",
                    witness=Witness(type="link_nonmonotone", message="link is not strictly monotone")
                )
            g_name = op.g_name
            for x0 in X_DOMAIN:
                s_map[x0] *= float(op.beta)
            ledger.append("absorb_scale")
        else:
            raise ValueError("unknown operator type")

    # Gauge fix, serialize, hash
    phi_gauge = gauge_fix_phi(phi_acc)
    ledger.append("gauge_zero_mean")
    ser = serialize_canonical(phi_gauge, s_map, g_name=g_name)
    h = sha256_hex(ser)

    return Certificate(
        verdict="reducible",
        canon_hash=h,
        rewrite_ledger=ledger,
        witness=None,
        serialization_hex=ser.hex(),
    )

# ------------------------------
# Helpers to build method ladders
# ------------------------------
def phi_from_values(vals: Dict[str, float]) -> Dict[Tuple[str, str], float]:
    """Expand per-candidate values into a potential table for instance "x1".

    Candidates of "x1" absent from ``vals`` get 0.0. Keys of ``vals`` that
    are not candidates of "x1" are ignored.
    """
    instance = "x1"
    table: Dict[Tuple[str, str], float] = {}
    for candidate in Y_DOMAIN_BY_X[instance]:
        table[(instance, candidate)] = float(vals.get(candidate, 0.0))
    return table

def ladder_DPO(beta: float = 0.5) -> Ladder:
    """Build the demo ladder labelled DPO: a fixed potential, then a logistic link.

    Args:
        beta: scale carried by the logistic link. The canonicalizer multiplies
            it into s(x). It was previously ignored (hard-coded 1.0), so every
            ``beta`` produced the same ladder.
    """
    phi = phi_from_values({"a": 0.2, "b": 0.1, "c": -0.1})
    return Ladder([Add(phi=phi), Link(g_name="logistic", beta=beta, monotone=True)])

def ladder_IPO() -> Ladder:
    """Build the demo ladder labelled IPO: potential a=0.2, b=0.1, c=-0.1, then a tanh link."""
    potential = phi_from_values({"a": 0.2, "b": 0.1, "c": -0.1})
    ops = [
        Add(phi=potential),
        Link(g_name="tanh", beta=1.0, monotone=True),
    ]
    return Ladder(ops)

def ladder_SimPO() -> Ladder:
    """Build the demo ladder labelled SimPO: a single logistic link with unit scale."""
    link = Link(g_name="logistic", beta=1.0, monotone=True)
    return Ladder(ops=[link])

def ladder_fDPO() -> Ladder:
    """Build the demo ladder labelled f-DPO: potential a=0.2, b=0.1, c=-0.1, then an "expit" link."""
    potential = phi_from_values({"a": 0.2, "b": 0.1, "c": -0.1})
    ops = [
        Add(phi=potential),
        Link(g_name="expit", beta=1.0, monotone=True),
    ]
    return Ladder(ops)

def ladder_ORPO() -> Ladder:
    """Build the demo ladder labelled ORPO: a single logistic link with unit scale."""
    link = Link(g_name="logistic", beta=1.0, monotone=True)
    return Ladder(ops=[link])

def ladder_BT_hinge() -> Ladder:
    """Build the demo ladder labelled BT-hinge: a single identity link with unit scale."""
    link = Link(g_name="identity", beta=1.0, monotone=True)
    return Ladder(ops=[link])

def ladder_RRHF_gated() -> Ladder:
    """Build the demo ladder labelled RRHF: a per-pair additive term, then a logistic link.

    Intended to be irreducible by a cocycle violation: the triangle
    a -> b -> c -> a sums to 0.1 + 0.1 - 0.1 = 0.1, which is nonzero.
    """
    # (c, a) must not be -0.2: with that value the cycle sums to exactly 0,
    # and the ladder would not be refuted.
    edges = {("a", "b"): 0.1, ("b", "c"): 0.1, ("c", "a"): -0.1}

    def delta_phi(x, y, z):
        # Antisymmetric extension: delta(z, y) = -delta(y, z); pairs not listed give -0.0.
        return edges.get((y, z), -edges.get((z, y), 0.0))

    return Ladder([Add(delta_phi=delta_phi), Link(g_name="logistic", beta=1.0, monotone=True)])

def ladder_SLiC_HF() -> Ladder:
    """Build the demo ladder labelled SLiC-HF: a gated per-pair additive term, then a logistic link.

    Edge (a, b) is gated to 0.0, so the triangle a -> b -> c -> a sums to
    0.0 + 0.1 - 0.2 = -0.1 instead of 0.
    """
    edges = {("a", "b"): 0.0, ("b", "c"): 0.1, ("c", "a"): -0.2}

    def delta_phi(x, y, z):
        if (y, z) in edges:
            return edges[(y, z)]
        # Reverse orientation flips the sign; pairs not listed give -0.0.
        return -edges.get((z, y), 0.0)

    additive = Add(delta_phi=delta_phi)
    link = Link(g_name="logistic", beta=1.0, monotone=True)
    return Ladder([additive, link])

def ladder_KTO_pair_reduction() -> Ladder:
    """Build the demo ladder labelled KTO-pair-red: a pair-dependent weight, then a logistic link.

    The weight is 1.0 on the pair (a, b) and 1.3 on every other pair.
    """
    def omega(x, y, z):
        if (y, z) == ("a", "b"):
            return 1.0
        return 1.3

    weighting = Rew(omega=omega)
    link = Link(g_name="logistic", beta=1.0, monotone=True)
    return Ladder([weighting, link])

def ladder_PPO_KL_pair_reduction() -> Ladder:
    """Build the demo ladder labelled PPO-KL-pair-red: a pair-dependent weight, then a logistic link.

    The weight is 0.9 on the pair (b, c) and 1.1 on every other pair.
    """
    def omega(x, y, z):
        if (y, z) == ("b", "c"):
            return 0.9
        return 1.1

    weighting = Rew(omega=omega)
    link = Link(g_name="logistic", beta=1.0, monotone=True)
    return Ladder([weighting, link])

# ------------------------------
# Main: run and show tables
# ------------------------------
def _describe_witness(w: Optional[Witness]) -> str:
    """Format a witness as the one-line summary shown in the results table."""
    if w is None:
        return "n/a"
    if w.type == "weight_nonconstant":
        return f"weight_nonconstant pairs={w.pairs} omegas={w.values}"
    if w.type == "cocycle_violation":
        return f"cocycle_violation triple={w.triple} cycle={w.values}"
    if w.type == "link_nonmonotone":
        return "link_nonmonotone"
    return w.type


def run_demo(save_csv: bool = False) -> None:
    """Canonicalize each demo ladder and print the results and hash-group tables.

    Args:
        save_csv: when True, also write both tables to CSV files in the
            current working directory.
    """
    methods = [
        ("DPO", ladder_DPO()),
        ("IPO", ladder_IPO()),
        ("SimPO", ladder_SimPO()),
        ("f-DPO", ladder_fDPO()),
        ("ORPO", ladder_ORPO()),
        ("BT-hinge", ladder_BT_hinge()),
        ("RRHF", ladder_RRHF_gated()),
        ("SLiC-HF", ladder_SLiC_HF()),
        ("KTO-pair-red", ladder_KTO_pair_reduction()),
        ("PPO-KL-pair-red", ladder_PPO_KL_pair_reduction()),
    ]
    columns = ["method", "verdict", "canon_hash_prefix", "witness", "note"]

    table_rows = []
    methods_by_hash: Dict[str, List[str]] = {}
    for method_name, method_ladder in methods:
        cert = canonicalize(method_ladder)
        if cert.verdict != "reducible":
            table_rows.append({
                "method": method_name,
                "verdict": "irreducible",
                "canon_hash_prefix": "",
                "witness": _describe_witness(cert.witness),
                "note": "emits finite witness",
            })
            continue
        digest = cert.canon_hash or ""
        methods_by_hash.setdefault(digest, []).append(method_name)
        table_rows.append({
            "method": method_name,
            "verdict": "reducible",
            "canon_hash_prefix": digest[:12],
            "witness": "",
            "note": "collapses to canonical margin",
        })

    results = pd.DataFrame(table_rows, columns=columns)
    print("\nRLHF Canonicalizer Results")
    print(results.to_string(index=False))

    group_rows = [
        {"canon_hash_prefix": digest[:12], "methods": ", ".join(sorted(names))}
        for digest, names in methods_by_hash.items()
    ]
    groups = pd.DataFrame(group_rows, columns=["canon_hash_prefix", "methods"])
    groups = groups.sort_values("methods")
    print("\nEqual-hash groups (methods that collapse to same canonical form)")
    print("(none)" if len(groups) == 0 else groups.to_string(index=False))

    if save_csv:
        results.to_csv("rlhf_canonicalizer_results.csv", index=False)
        groups.to_csv("rlhf_equal_hash_groups.csv", index=False)
        print("\nSaved: rlhf_canonicalizer_results.csv, rlhf_equal_hash_groups.csv")

if __name__ == "__main__":
    # Print the tables only; call run_demo(save_csv=True) to also write CSV files.
    run_demo()

