
import argparse
import csv
import json
import math
from dataclasses import dataclass
from typing import List, Tuple

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F

# Device passed as `device=` to the tensor constructors in this file.
DEVICE="cpu"

# -----------------------------
# Utilities
# -----------------------------
def one_hot(n, j):
    """Return a length-``n`` vector on ``DEVICE`` holding 1.0 at index ``j`` and 0.0 elsewhere."""
    basis = torch.zeros(n, device=DEVICE)
    basis[j] = 1.0
    return basis

def build_grid(B: float, d: int, P: int):
    """Return ``P + 1`` evenly spaced grid points spanning ``[-d*B**2, d*B**2]`` on ``DEVICE``."""
    half_width = d * (B ** 2)
    return torch.linspace(-half_width, half_width, P + 1, device=DEVICE)

# -----------------------------
# Dataset with target attention (no training)
# -----------------------------
@dataclass
class Sample:
    """One dataset item: an input matrix bundled with its target attention tensors.

    Instances are produced by ``SimAttentionDataset.__getitem__``; Y, K, Q and V
    are the values returned by ``SimAttentionDataset.target_forward`` for X.
    """
    # Input matrix for this sample.
    X: torch.Tensor   # [n,d]
    # Target attention output, V @ softmax(K^T Q) with the softmax taken over dim 0.
    Y: torch.Tensor   # [n,n]
    # Target key matrix, WK @ X^T.
    K: torch.Tensor   # [n,n]
    # Target query matrix, WQ @ X^T.
    Q: torch.Tensor   # [n,n]
    # Target value matrix, WV @ X^T.
    V: torch.Tensor   # [n,n]

class SimAttentionDataset:
    """Synthetic attention dataset with fixed random K/Q/V weights (no training).

    Stores ``num_samples`` random inputs ``X`` drawn uniformly from [-1, 1) with
    shape [n, d], plus three standard-normal weight matrices ``WK``, ``WQ`` and
    ``WV`` of shape [n, d]. Target tensors are computed on access.
    """

    def __init__(self, num_samples: int, n: int, d: int, seed: int = 0):
        """Draw the inputs and the K/Q/V weights from a generator seeded with ``seed``."""
        # The generator must live on the same device as the tensors it fills,
        # otherwise sampling fails as soon as DEVICE is not the CPU.
        g = torch.Generator(device=DEVICE).manual_seed(seed)
        self.N = num_samples
        self.n = n
        self.d = d
        self.X = 2 * torch.rand(num_samples, n, d, generator=g, device=DEVICE) - 1.0
        self.WK = torch.randn(n, d, generator=g, device=DEVICE)
        self.WQ = torch.randn(n, d, generator=g, device=DEVICE)
        self.WV = torch.randn(n, d, generator=g, device=DEVICE)

    def target_forward(self, X: torch.Tensor):
        """Compute the target attention for one input ``X`` of shape [n, d].

        Returns:
            ``(Y, K, Q, V)``, each of shape [n, n], where ``K/Q/V = W* @ X^T``
            and ``Y = V @ softmax(K^T Q)`` with the softmax taken over dim 0.
        """
        K = self.WK @ X.T              # [n,n]
        Q = self.WQ @ X.T              # [n,n]
        V = self.WV @ X.T              # [n,n]
        W = torch.softmax(K.T @ Q, dim=0)  # [n,n] (column-softmax over keys)
        Y = V @ W                       # [n,n]
        return Y, K, Q, V

    def __len__(self) -> int:
        """Return the number of stored samples."""
        return self.N

    def __getitem__(self, idx: int) -> Sample:
        """Return sample ``idx`` together with its freshly computed targets."""
        X = self.X[idx]
        Y, K, Q, V = self.target_forward(X)
        return Sample(X=X, Y=Y, K=K, Q=Q, V=V)

# -----------------------------
# Explicit l_j for K/Q/V branches
# Each l_j outputs (2d+3) × (P+1).
# -----------------------------
def li_apply_for_rowvec(X: torch.Tensor, Wrow: torch.Tensor, n: int, d: int, P: int, L: torch.Tensor, j: int):
    """
    Construct l_j(Z) for a given row vector in R^d (k_j, q_j, or v_j).

    Args:
        X: input matrix of shape [n, d].
        Wrow: weight row of shape [d].
        n: number of rows of X; they occupy the first n output columns.
        d: feature dimension.
        P: the grid has P+1 points, so the output has P+1 columns.
        L: grid values of shape [P+1].
        j: row index; not used by the construction, kept for interface compatibility.

    Returns:
        (2d+3) x (P+1) matrix:
            [ X^T | 0 ]                                  (d x (P+1))
            [ 2 L_r * Wrow (as columns) ]                (d x (P+1))
            [ 1 ... 1  0 ... 0 ]                         (1 x (P+1))
            [ L_0 ... L_P ]                              (1 x (P+1))
            [ -L_0^2 ... -L_P^2 ]                        (1 x (P+1))

    Raises:
        ValueError: if n > P+1, since X^T would not fit into P+1 columns.
    """
    if n > P + 1:
        raise ValueError(f"need P + 1 >= n to embed X, got n={n}, P={P}")

    # Allocate on the input's device so the pieces concatenated below agree.
    device = X.device

    # Top block: X placed in first n columns
    top = torch.zeros(d, P + 1, device=device)
    top[:, :n] = X.T  # [d,n]

    # Middle block: column r is 2 L_r * Wrow, built as one outer product
    mid = torch.outer(Wrow, 2.0 * L)  # [d, P+1]

    # Tail constants
    ones_row = torch.zeros(1, P + 1, device=device)
    ones_row[:, :n] = 1.0                                          # ones in first n cols
    grid_rows = torch.stack([L, -(L**2)], dim=0)                   # [2, P+1]
    return torch.cat([top, mid, ones_row, grid_rows], dim=0)       # (2d+3) x (P+1)

def build_Wproj_mats(d: int):
    """
    Build the two fixed (d+1) x (2d+3) selection matrices shared by every head.

    Returns ``(WK, WQ)``:
      - WK copies rows d..2d-1 of its input (the 2 L_r * Wrow block) and the
        last row (-L^2).
      - WQ copies rows 0..d-1 of its input (the X block) and row 2d (the
        "ones" row).
    """
    width = 2 * d + 3
    identity = torch.eye(d, device=DEVICE)

    key_proj = torch.zeros(d + 1, width, device=DEVICE)
    key_proj[:d, d:2 * d] = identity      # pick the 2 L_r * Wrow block
    key_proj[d, width - 1] = 1.0          # pick the -L^2 row (index 2d+2)

    query_proj = torch.zeros(d + 1, width, device=DEVICE)
    query_proj[:d, :d] = identity         # pick the X block
    query_proj[d, 2 * d] = 1.0            # pick the ones row
    return key_proj, query_proj

def run_branch_emulation(X: torch.Tensor, W: torch.Tensor, n: int, d: int, P: int, L: torch.Tensor, beta: float, block_offset: int):
    """
    Emulate one branch (K, Q or V) with one softmax head per row of ``W``.

    For each j in range(n):
      - build the token matrix T_j = l_j(X, W[j]) of shape (2d+3) x (P+1)
      - project it with the fixed key/query selectors
      - take a softmax over dim 0 of beta * (keys^T queries)
      - multiply by a (3n x (P+1)) selector whose only non-zero row,
        ``block_offset + j``, holds the grid L
      - add the first n columns of the result to a (3n x n) accumulator

    Returns rows ``block_offset : block_offset + n`` of the accumulator (n x n).
    """
    key_proj, query_proj = build_Wproj_mats(d)
    total = torch.zeros(3 * n, n, device=DEVICE)
    for j in range(n):
        tokens = li_apply_for_rowvec(X, W[j, :], n, d, P, L, j)   # (2d+3) x (P+1)
        keys = key_proj @ tokens                                  # (d+1) x (P+1)
        queries = query_proj @ tokens                             # (d+1) x (P+1)
        scores = beta * (keys.T @ queries)                        # (P+1) x (P+1)
        attn = torch.softmax(scores, dim=0)                       # normalised down each column
        selector = torch.zeros(3 * n, P + 1, device=DEVICE)
        selector[block_offset + j, :] = L                         # grid values in this head's row
        total += (selector @ attn)[:, :n]                         # only the first n columns matter
    return total[block_offset:block_offset + n, :]                # n x n

# -----------------------------
# Final frozen single-head Attn_s
# -----------------------------
def attn_s_combine(K_hat: torch.Tensor, Q_hat: torch.Tensor, V_hat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Combine emulated K/Q/V blocks with a single softmax attention step.

    Args:
        K_hat, Q_hat, V_hat: the emulated key, query and value matrices.

    Returns:
        ``(Y_hat, W_hat)`` where ``W_hat = softmax(K_hat^T Q_hat)`` taken over
        dim 0 and ``Y_hat = V_hat @ W_hat``.
    """
    W_hat = torch.softmax(K_hat.T @ Q_hat, dim=0)    # [n,n]
    Y_hat = V_hat @ W_hat                             # [n,n]
    return Y_hat, W_hat

# -----------------------------
# End-to-end evaluation for one configuration
# -----------------------------
def evaluate_config(n: int, d: int, P: int, beta: float, samples: int, seed: int) -> Tuple[float, List[dict], float]:
    """Run the K/Q/V emulation on a fresh synthetic dataset and collect error metrics.

    Args:
        n, d: dataset dimensions (inputs are [n, d]).
        P: the interpolation grid has P+1 points.
        beta: scale applied to the logits inside each emulation head.
        samples: number of dataset samples to evaluate; must be at least 1.
        seed: seed for both the global torch RNG and the dataset generator.

    Returns:
        ``(mean_mse, all_metrics, std_mse)`` where ``all_metrics`` holds one
        dict per sample and ``mean_mse`` / ``std_mse`` are the mean and the
        population standard deviation of the per-sample ``MSE(Y',Y)``.

    Raises:
        ValueError: if ``samples`` is smaller than 1 (nothing to average).
    """
    if samples < 1:
        raise ValueError(f"samples must be >= 1, got {samples}")

    torch.manual_seed(seed)
    ds = SimAttentionDataset(num_samples=samples, n=n, d=d, seed=seed)

    # The weights are shared by all samples, so their magnitude bound is loop-invariant.
    weight_bound = max(ds.WK.abs().max().item(), ds.WQ.abs().max().item(), ds.WV.abs().max().item())

    all_metrics = []
    for i in range(len(ds)):
        # Index once: every ds[i] access recomputes the target attention.
        sample = ds[i]
        X, Y_t, K_t, Q_t, V_t = sample.X, sample.Y, sample.K, sample.Q, sample.V
        # Build grid for this sample
        B = max(X.abs().max().item(), weight_bound)
        L = build_grid(B, d, P)
        # Emulate
        K_hat = run_branch_emulation(X, ds.WK, n, d, P, L, beta, block_offset=0)
        Q_hat = run_branch_emulation(X, ds.WQ, n, d, P, L, beta, block_offset=n)
        V_hat = run_branch_emulation(X, ds.WV, n, d, P, L, beta, block_offset=2*n)
        Y_hat, W_hat = attn_s_combine(K_hat, Q_hat, V_hat)
        W_t = torch.softmax(K_t.T @ Q_t, dim=0)
        m = {
            "P": P, "beta": beta, "n": n, "d": d, "samples": samples,
            "||K'-K||_inf": float((K_hat - K_t).abs().max()),
            "||Q'-Q||_inf": float((Q_hat - Q_t).abs().max()),
            "||V'-V||_inf": float((V_hat - V_t).abs().max()),
            "||W'-W||_inf": float((W_hat - W_t).abs().max()),
            "||Y'-Y||_inf": float((Y_hat - Y_t).abs().max()),
            "MSE(Y',Y)": float(F.mse_loss(Y_hat, Y_t))
        }
        all_metrics.append(m)

    mse_values = [m["MSE(Y',Y)"] for m in all_metrics]
    mean_mse = sum(mse_values) / len(mse_values)
    std_mse = math.sqrt(sum((x - mean_mse)**2 for x in mse_values) / len(mse_values))
    return mean_mse, all_metrics, std_mse

# -----------------------------
# Sweeps & Plots
# -----------------------------
def sweep_d_beta_samples(n: int, P: int, d_list: List[int], beta_list: List[float], samples_list: List[int], seed: int, out_csv: str):
    """Grid-search over (d, beta, samples) with n and P fixed; write the results to CSV.

    Each trial is printed as JSON as soon as it finishes. The CSV at
    ``out_csv`` gets one row per trial with the columns
    n, P, d, beta, samples, mean_MSE. If any of the lists is empty no trial
    runs and the CSV contains only the header.
    """
    # Fixed header so that an empty sweep still produces a valid CSV instead of
    # failing on rows[0].
    fieldnames = ["n", "P", "d", "beta", "samples", "mean_MSE"]
    rows = []
    for d in d_list:
        for beta in beta_list:
            for samples in samples_list:
                mean_mse, _, _ = evaluate_config(n=n, d=d, P=P, beta=beta, samples=samples, seed=seed)
                rows.append({"n": n, "P": P, "d": d, "beta": beta, "samples": samples, "mean_MSE": mean_mse})
                print(json.dumps({"tune_trial": rows[-1]}, indent=2))
    # write CSV
    with open(out_csv, "w", newline="") as f:
        w = csv.DictWriter(f, fieldnames=fieldnames)
        w.writeheader()
        w.writerows(rows)

def plot_mse_vs_P(n: int, d: int, beta: float, samples: int, P_list: List[int], seed: int, out_png: str, out_csv: str):
    """Evaluate every P in ``P_list`` and save mean MSE (with a +/- std band) as CSV and figure.

    Each trial is printed as JSON. The CSV at ``out_csv`` has the columns
    P, mean_MSE, std_MSE; the figure is written to ``out_png``.
    """
    results = []
    for P in P_list:
        mean_mse, _, std_mse = evaluate_config(n=n, d=d, P=P, beta=beta, samples=samples, seed=seed)
        results.append((P, mean_mse, std_mse))
        print(json.dumps({"plotP_trial": {"P": P, "mean_MSE": mean_mse, "std_MSE": std_mse}}, indent=2))

    # Persist the raw numbers next to the figure
    with open(out_csv, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["P", "mean_MSE", "std_MSE"])
        writer.writerows(results)

    # Split the result tuples into plot columns
    Ps = [row[0] for row in results]
    means = [row[1] for row in results]
    spreads = [row[2] for row in results]
    lower = [mean - spread for mean, spread in zip(means, spreads)]
    upper = [mean + spread for mean, spread in zip(means, spreads)]

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(Ps, means, color='darkgreen', marker='o', linewidth=3)

    # Shaded band of one standard deviation around the mean
    ax.fill_between(Ps, lower, upper, color='gray', alpha=0.2)

    ax.set_title("Number of Interpolation Points Plot", fontsize=26, pad=12)
    ax.set_xlabel("Number of Interpolation Points", fontsize=26, labelpad=10)
    ax.set_ylabel("Mean Loss", fontsize=26, labelpad=10)

    # One tick per evaluated P
    ax.set_xticks(Ps)
    ax.set_xticklabels(Ps, fontsize=22)
    ax.tick_params(axis='both', labelsize=22)

    fig.savefig(out_png, bbox_inches="tight")
    plt.close(fig)


def plot_mse_vs_n(P: int, d: int, beta: float, samples: int, n_list: List[int], seed: int, out_png: str, out_csv: str):
    """Evaluate every n in ``n_list`` and save mean MSE versus n as CSV and figure.

    Each trial is printed as JSON. The CSV at ``out_csv`` has the columns
    n, mean_MSE; the figure is written to ``out_png``.
    """
    points = []
    for n in n_list:
        # evaluate_config returns (mean_mse, per-sample metrics, std_mse);
        # only the mean is used here.
        mean_mse, _, _ = evaluate_config(n=n, d=d, P=P, beta=beta, samples=samples, seed=seed)
        points.append((n, mean_mse))
        print(json.dumps({"plotN_trial": {"n": n, "mean_MSE": mean_mse}}, indent=2))
    # Save CSV
    with open(out_csv, "w", newline="") as f:
        w = csv.writer(f)
        w.writerow(["n", "mean_MSE"])
        w.writerows(points)
    # Plot
    Ns = [x for x, _ in points]
    mses = [m for _, m in points]
    plt.figure()
    plt.plot(Ns, mses, marker="o")
    plt.xlabel("n")
    plt.ylabel("mean MSE(Y',Y)")
    plt.title(f"MSE vs n (P={P}, d={d}, beta={beta}, samples={samples})")
    plt.grid(True)
    plt.savefig(out_png, bbox_inches="tight")
    plt.close()

def parse_list_int(s: str) -> List[int]:
    """Parse a comma-separated string into ints, ignoring blank entries."""
    values: List[int] = []
    for piece in s.split(","):
        token = piece.strip()
        if token:
            values.append(int(token))
    return values

def parse_list_float(s: str) -> List[float]:
    """Parse a comma-separated string into floats, ignoring blank entries."""
    tokens = (piece.strip() for piece in s.split(","))
    # filter(None, ...) drops the empty strings left by blank entries
    return list(map(float, filter(None, tokens)))

# -----------------------------
# CLI
# -----------------------------
def main():
    """Command-line entry point.

    Runs any combination of the hyperparameter sweep (``--tune``), the
    MSE-vs-P plot (``--plotP``) and the MSE-vs-n plot (``--plotN``). When none
    of these flags is given, a single configuration is evaluated and its
    per-sample and aggregate metrics are printed.
    """
    p = argparse.ArgumentParser()
    # base args (used as fixed values for plots / sweeps)
    p.add_argument("--n", type=int, default=12)
    p.add_argument("--d", type=int, default=8)
    p.add_argument("--P", type=int, default=31)
    p.add_argument("--beta", type=float, default=6.0)
    p.add_argument("--seed", type=int, default=0)
    p.add_argument("--samples", type=int, default=4)

    # hyperparameter search over d, beta, samples (with n,P fixed)
    p.add_argument("--tune", action="store_true", help="Run grid search over d, beta, samples with fixed n and P.")
    p.add_argument("--tune_d_list", type=str, default="4,8,12,16")
    p.add_argument("--tune_beta_list", type=str, default="2.0,4.0,8.0,16.0")
    p.add_argument("--tune_samples_list", type=str, default="2,4,8,16")
    p.add_argument("--tune_out_csv", type=str, default="tune_d_beta_samples.csv")

    # plotting: MSE vs P
    p.add_argument("--plotP", action="store_true", help="Plot mean MSE vs P with fixed n,d,beta,samples.")
    p.add_argument("--P_list", type=str, default="11,21,31,41,61,81,101")
    p.add_argument("--plotP_pdf", type=str, default="mse_vs_P.pdf")
    p.add_argument("--plotP_csv", type=str, default="mse_vs_P.csv")

    # plotting: MSE vs n
    p.add_argument("--plotN", action="store_true", help="Plot mean MSE vs n with fixed P,d,beta,samples.")
    p.add_argument("--n_list", type=str, default="6,8,10,12,16,20")
    p.add_argument("--plotN_pdf", type=str, default="mse_vs_n.pdf")
    p.add_argument("--plotN_csv", type=str, default="mse_vs_n.csv")

    args = p.parse_args()

    if args.tune:
        d_list = parse_list_int(args.tune_d_list)
        beta_list = parse_list_float(args.tune_beta_list)
        samples_list = parse_list_int(args.tune_samples_list)
        sweep_d_beta_samples(n=args.n, P=args.P, d_list=d_list, beta_list=beta_list, samples_list=samples_list,
                             seed=args.seed, out_csv=args.tune_out_csv)
        print(json.dumps({"tune_csv_saved": args.tune_out_csv}, indent=2))

    if args.plotP:
        P_list = parse_list_int(args.P_list)
        plot_mse_vs_P(n=args.n, d=args.d, beta=args.beta, samples=args.samples, P_list=P_list,
                      seed=args.seed, out_png=args.plotP_pdf, out_csv=args.plotP_csv)
        print(json.dumps({"plotP_pdf": args.plotP_pdf, "plotP_csv": args.plotP_csv}, indent=2))

    if args.plotN:
        n_list = parse_list_int(args.n_list)
        # The flag is --plotN_pdf, so the parsed attribute is plotN_pdf
        # (there is no plotN_png attribute on args).
        plot_mse_vs_n(P=args.P, d=args.d, beta=args.beta, samples=args.samples, n_list=n_list,
                      seed=args.seed, out_png=args.plotN_pdf, out_csv=args.plotN_csv)
        print(json.dumps({"plotN_pdf": args.plotN_pdf, "plotN_csv": args.plotN_csv}, indent=2))

    # Default behavior: run a single evaluation with provided n,d,P,beta,samples
    if (not args.tune) and (not args.plotP) and (not args.plotN):
        # evaluate_config returns (mean, per-sample metrics, std); std is not reported here.
        mean_mse, all_metrics, _ = evaluate_config(n=args.n, d=args.d, P=args.P, beta=args.beta, samples=args.samples, seed=args.seed)
        for i, m in enumerate(all_metrics):
            print(json.dumps({"sample": i, **m}, indent=2))
            print(i, m)
        print(json.dumps({"aggregate": {"n": args.n, "d": args.d, "P": args.P, "beta": args.beta, "samples": args.samples,
                                        "mean_MSE(Y',Y)": mean_mse}}, indent=2))

if __name__ == "__main__":
    main()
