# ---------------------------------------------------------------
# run_func_error.py
# ---------------------------------------------------------------
import math, pathlib, csv
import torch
import torch.nn.functional as F
import pandas as pd


from SNN.spike_silu import SiLU4bitFromExp, Softmax8bitFromExp
def calc_K(n_iter: int = 12) -> float:
    """Return the accumulated CORDIC gain after ``n_iter`` micro-rotations.

    The gain is the product of ``sqrt(1 + 2**(-2*k))`` for
    ``k = 0 .. n_iter - 1``; with ``n_iter <= 0`` the empty product 1.0 is
    returned.
    """
    stage_gains = (math.sqrt(1 + 2 ** (-2 * k)) for k in range(n_iter))
    # start=1.0 keeps the result a float even when there are no stages.
    return math.prod(stage_gains, start=1.0)


# Gain for the 12-iteration case, computed once at import time.
_K12 = calc_K(12)


def cordic_hypot_pair_no_scale(x: torch.Tensor,
                               y: torch.Tensor,
                               n_iter: int = 12) -> torch.Tensor:
    """Element-wise vectoring-mode CORDIC magnitude, without gain correction.

    Runs ``n_iter`` shift-and-add micro-rotations that drive the ``y``
    component towards zero. The returned tensor approximates
    ``K * sqrt(x**2 + y**2)``, where ``K`` is the product of
    ``sqrt(1 + 2**(-2*k))`` over the iterations; the caller is expected to
    divide by ``K``.

    Args:
        x: First component; any sign.
        y: Second component; must broadcast against ``x``.
        n_iter: Number of CORDIC iterations.

    Returns:
        Tensor holding the un-normalised magnitude of each ``(x, y)`` pair.
    """
    # The magnitude does not depend on the signs of x and y. Folding both
    # into the first quadrant keeps the start angle within the CORDIC
    # convergence range; a negative x would otherwise give a wrong result.
    # abs() also returns fresh tensors, so the inputs are never modified.
    xi, yi = x.abs(), y.abs()

    one = torch.tensor(1.0, dtype=xi.dtype, device=xi.device)
    mone = torch.tensor(-1.0, dtype=xi.dtype, device=xi.device)

    for k in range(n_iter):
        # Rotation direction: always towards the x axis.
        di = torch.where(yi >= 0, one, mone)
        scale = 1 << k
        # Both updates use the values from the previous iteration.
        xi, yi = xi + di * (yi / scale), yi - di * (xi / scale)

    return xi.abs()
def cordic_l2_pairwise(v: torch.Tensor,
                          eps: float = 0.0,
                          n_iter: int = 12) -> torch.Tensor:
    """Approximate the L2 norm over the last dimension with a CORDIC tree.

    Magnitudes are sorted in ascending order along the last dimension, then
    adjacent pairs are repeatedly merged with
    ``cordic_hypot_pair_no_scale`` (each merge divided by the CORDIC gain)
    until a single value per row remains.

    Args:
        v: Input tensor; the reduction runs over ``dim=-1``.
        eps: When positive, ``sqrt(eps)`` is prepended as an extra element,
            so the result approximates ``sqrt(sum(v**2) + eps)``.
        n_iter: CORDIC iterations per pairwise merge.

    Returns:
        The reduced tensor with the last dimension squeezed away.
    """
    # Reuse the precomputed gain for the common 12-iteration setting.
    gain = v.new_tensor(_K12 if n_iter == 12 else calc_K(n_iter))

    if eps > 0.0:
        eps_column = torch.full_like(v[..., :1], math.sqrt(eps))
        v = torch.cat([eps_column, v], dim=-1)

    mags = torch.sort(v.abs(), dim=-1).values
    while mags.size(-1) > 1:
        width = mags.size(-1)
        has_leftover = width % 2 == 1
        paired = 2 * (width // 2)

        evens = mags[..., 0:paired:2]
        odds = mags[..., 1:paired:2]
        merged = cordic_hypot_pair_no_scale(evens, odds, n_iter) / gain

        # With an odd width the last element has no partner this round;
        # carry it over unchanged to the next one.
        if has_leftover:
            merged = torch.cat([merged, mags[..., -1:]], dim=-1)
        mags = merged
    return mags.squeeze(-1)



from SNN.spile_layer_sor import ptsoftmax as sorbet_softmax 


def eval_silu_curve(device="cpu", dtype=torch.float32, n_pts=20001, q_bits=8):
    """Evaluate SiLU approximations point by point on a grid over [-10, 10].

    The reference is ``F.silu`` floored onto a grid of ``2**-q_bits``. Two
    candidates are compared against it: the output of ``SiLU4bitFromExp``
    ("ours") and plain ``torch.relu`` (stored under the "sorbet" name).

    Returns:
        ``(df, stat)``: ``df`` is a DataFrame with one row per grid point
        (input, reference, both candidates, both absolute errors); ``stat``
        is a dict with the max and mean absolute error of each candidate.
    """
    grid = torch.linspace(-10, 10, n_pts, device=device, dtype=dtype)

    # Quantised reference: floor to a multiple of 1 / 2**q_bits.
    levels = float(1 << q_bits)
    silu_q = torch.floor(F.silu(grid) * levels) / levels

    approx = SiLU4bitFromExp().to(device).to(dtype)(grid)
    relu_out = torch.relu(grid)

    err_approx = (approx - silu_q).abs()
    err_relu = (relu_out - silu_q).abs()

    def as_numpy(t):
        # Detach first so tensors that track gradients can be exported.
        return t.detach().cpu().numpy()

    df = pd.DataFrame({
        "x": as_numpy(grid),
        "ref_q": as_numpy(silu_q),
        "ours": as_numpy(approx),
        "sorbet": as_numpy(relu_out),
        "abs_err_ours": as_numpy(err_approx),
        "abs_err_sorbet": as_numpy(err_relu),
    })

    stat = dict(
        max_abs_err_ours=err_approx.max().item(),
        mean_abs_err_ours=err_approx.mean().item(),
        max_abs_err_sorbet=err_relu.max().item(),
        mean_abs_err_sorbet=err_relu.mean().item(),
    )
    return df, stat

def eval_silu(device="cpu", dtype=torch.float32, n_pts=20001, q_bits=8):
    """Summarise SiLU approximation error on a grid over [-10, 10].

    The reference is ``F.silu`` floored onto a grid of ``2**-q_bits``. Two
    candidates are compared against it: the output of ``SiLU4bitFromExp``
    ("ours") and plain ``torch.relu`` (reported under the "sorbet" name).

    Args:
        device: Device on which the grid and the model are placed.
        dtype: Floating-point dtype for the grid and the model.
        n_pts: Number of evenly spaced evaluation points.
        q_bits: Fractional bits of the quantised reference. The default of
            8 gives the step of 256 that was previously hard-coded, and
            matches the parameter of ``eval_silu_curve``.

    Returns:
        Dict with the max and mean absolute error of each candidate.
    """
    x = torch.linspace(-10, 10, n_pts, device=device, dtype=dtype)
    step = float(1 << q_bits)

    ref = torch.floor(F.silu(x) * step) / step
    ours = SiLU4bitFromExp().to(device).to(dtype)(x)
    sorbet = torch.relu(x)

    err_ours = (ours - ref).abs()
    err_sorbet = (sorbet - ref).abs()
    return {
        "max_abs_err_ours"   : err_ours.max().item(),
        "mean_abs_err_ours"  : err_ours.mean().item(),
        "max_abs_err_sorbet" : err_sorbet.max().item(),
        "mean_abs_err_sorbet": err_sorbet.mean().item(),
    }


def eval_softmax(dims, device="cpu", dtype=torch.float32, num_vec=4096):
    """Measure softmax approximation error for each width in ``dims``.

    For every width, ``num_vec`` random logit vectors (standard normal
    scaled by 7) are drawn after seeding the global RNG with 0. The
    reference is ``torch.softmax`` floored to multiples of 1/256. Errors
    are reported for ``Softmax8bitFromExp`` ("ours") and ``sorbet_softmax``
    ("sorbet").

    Returns:
        DataFrame with one row per width: ``dim`` plus the max and mean
        absolute error of each candidate.
    """
    approx_softmax = Softmax8bitFromExp(dim=-1).to(device).to(dtype)
    torch.manual_seed(0)

    def summarize(width):
        scores = torch.randn(num_vec, width, device=device, dtype=dtype) * 7
        exact_q = torch.floor(torch.softmax(scores, dim=-1) * 256) / 256

        err_ours = (approx_softmax(scores) - exact_q).abs()
        err_sorbet = (sorbet_softmax(scores, dim=-1).to(dtype) - exact_q).abs()
        return {
            "dim": width,
            "abs_max_ours"   : err_ours.max().item(),
            "abs_mean_ours"  : err_ours.mean().item(),
            "abs_max_sorbet" : err_sorbet.max().item(),
            "abs_mean_sorbet": err_sorbet.mean().item(),
        }

    return pd.DataFrame([summarize(d) for d in dims])


def eval_rmsnorm(dims,
                 eps: float = 1e-5,
                 n_iter: int = 8,
                 q_bits: int = 8,
                 num_batch: int = 4096,
                 device: str = "cpu",
                 dtype: torch.dtype = torch.float32):
    """Measure RMS-normalisation error for each feature width in ``dims``.

    For every width ``D`` a ``(num_batch, D)`` standard-normal batch is
    drawn after seeding the global RNG with 0. Two approximations are
    compared with the exact ``X * rsqrt(mean(X**2) + eps)``:

    * "ours": the norm comes from ``cordic_l2_pairwise`` and the normalised
      output is floored to multiples of ``2**-q_bits``.
    * "sorbet": ``X`` is scaled by a power of two derived from the mean
      absolute value of each row.

    Returns:
        DataFrame with one row per width: ``dim`` plus the max and mean
        absolute error of each candidate against the unquantised reference.
    """
    torch.manual_seed(0)
    records = []

    for D in dims:
        samples = torch.randn(num_batch, D, device=device, dtype=dtype)  # (N, D)

        mean_sq = samples.pow(2).mean(-1, keepdim=True)
        exact = samples * torch.rsqrt(mean_sq + eps)

        # sqrt(D) / sqrt(sum(x^2) + eps*D) equals 1 / sqrt(mean(x^2) + eps),
        # hence eps is scaled by D before it enters the norm.
        l2 = cordic_l2_pairwise(samples, eps * D, n_iter=n_iter)
        approx = samples * (math.sqrt(D) / l2).unsqueeze(-1)

        quant = float(1 << q_bits)
        approx_q = torch.floor(approx * quant) / quant

        # Baseline: 2 ** -ceil(log2(mean |x|)) as the per-row scale.
        l1_mean = samples.abs().mean(-1, keepdim=True) + eps
        pow2_scale = torch.pow(2.0, -torch.ceil(torch.log2(l1_mean)))
        baseline = samples * pow2_scale

        err_ours = (approx_q - exact).abs()
        err_sorbet = (baseline - exact).abs()

        records.append({
            "dim": D,
            "abs_max_ours"   : err_ours.max().item(),
            "abs_mean_ours"  : err_ours.mean().item(),
            "abs_max_sorbet" : err_sorbet.max().item(),
            "abs_mean_sorbet": err_sorbet.mean().item(),
        })

    return pd.DataFrame(records)

if __name__ == "__main__":
    from pathlib import Path

    def dump_kv_csv(path: Path, mapping: dict):
        """Write ``mapping`` to ``path`` as a two-column metric/value CSV."""
        with path.open("w", newline="") as f:
            w = csv.writer(f)
            w.writerow(["metric", "value"])
            # Numbers are written in scientific notation; anything else as-is.
            w.writerows(
                [k, f"{v:.8e}" if isinstance(v, (float, int)) else v]
                for k, v in mapping.items()
            )

    def dump_rms_like_silu(df_rms: pd.DataFrame, path: Path):
        """Flatten the per-dim RMSNorm table into metric/value rows."""
        metric_cols = ("abs_max_ours", "abs_mean_ours",
                       "abs_max_sorbet", "abs_mean_sorbet")
        with path.open("w", newline="") as f:
            w = csv.writer(f)
            w.writerow(["metric", "value"])
            for _, r in df_rms.iterrows():
                d = int(r["dim"])
                w.writerows(
                    [f"d{d}/{k}", f"{r[k]:.8e}"] for k in metric_cols
                )

    # ---- run the three evaluations -------------------------------------
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[INFO] device = {device}")

    df_silu_curve, silu_stat = eval_silu_curve(
        device=device, dtype=torch.float32, n_pts=20001, q_bits=8
    )

    dims = [8, 16, 32, 64, 128, 256]
    df_soft = eval_softmax(dims, device=device)
    df_rms = eval_rmsnorm(dims, device=device)

    # ---- console report --------------------------------------------------
    print("\n=== SiLU 4-bit Error ===")
    for k, v in silu_stat.items():
        print(f"{k:24s}: {v:.4e}")

    for title, table in (
        ("\n=== Softmax 8-bit Error (With dim) ===", df_soft),
        ("\n=== RMSNorm Error (With dim) ===", df_rms),
    ):
        print(title)
        print(table.to_string(index=False))

    # ---- CSV export ------------------------------------------------------
    out_dir = Path("./experiments/results").resolve()
    out_dir.mkdir(parents=True, exist_ok=True)

    soft_csv = out_dir / "softmax_error.csv"
    rms_csv = out_dir / "rms_error.csv"
    silu_curve_csv = out_dir / "silu_curve.csv"
    silu_stat_csv = out_dir / "silu_error.csv"
    rms_silu_style = out_dir / "rms_error_silu_style.csv"

    for table, target in (
        (df_soft, soft_csv),
        (df_rms, rms_csv),
        (df_silu_curve, silu_curve_csv),
    ):
        table.to_csv(target, index=False)

    dump_kv_csv(silu_stat_csv, silu_stat)
    dump_rms_like_silu(df_rms, rms_silu_style)

    print(f"\n[INFO] results in: {out_dir}")
    for written in (soft_csv, rms_csv, silu_curve_csv,
                    silu_stat_csv, rms_silu_style):
        print(f"       - {written.name}")
