#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Analyze usage-aware steering effects from a generations CSV.

"""

import os
import argparse
import math
import numpy as np
import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def _normalize(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    return x / (x.norm() + eps)


def load_axis(path: str, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """
    Load a direction vector from a ``.npy`` file and return it with unit L2 norm.

    Singleton dimensions are squeezed away, so an axis saved as (1, d) or (d, 1)
    comes back as a 1-D (d,) tensor that can be dotted with a 1-D hidden state.
    Normalization is done in float32 and only the final unit vector is cast to
    ``dtype``, so low-precision dtypes (float16/bfloat16) do not lose accuracy
    in the sum of squares.

    Args:
        path: Path to the ``.npy`` file holding the axis.
        device: Device for the returned tensor.
        dtype: dtype of the returned tensor.

    Returns:
        The normalized axis on ``device`` with dtype ``dtype``.

    Raises:
        ValueError: if the stored axis has zero or non-finite norm (it would
            otherwise silently turn every projection into 0 or nan).
    """
    raw = np.squeeze(np.load(path))
    v = torch.as_tensor(raw, dtype=torch.float32)
    norm = v.norm()
    if not torch.isfinite(norm) or norm.item() == 0.0:
        raise ValueError(f"Axis in {path} has zero or non-finite norm; cannot normalize it.")
    return (v / norm).to(device=device, dtype=dtype)


@torch.inference_mode()
def encode_last_token(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    text: str,
    human_layer_index: int,
) -> torch.Tensor:
    """
    Return the L2-normalized hidden state of the last token of ``text`` at one layer.

    Layers use human counting on the ``output_hidden_states`` tuple:
    index 0 is the embedding output, index 1 is Layer 1, ..., index n is Layer n.

    Args:
        model: Causal LM that returns ``hidden_states`` when asked.
        tokenizer: Tokenizer matching ``model``.
        text: A single input string (encoded as a batch of one).
        human_layer_index: Layer to read, 0 (embeddings) .. number of layers.

    Returns:
        1-D tensor (hidden size) with unit norm, in the hidden-state dtype.

    Raises:
        ValueError: if ``text`` tokenizes to zero tokens (there is no last token).
        IndexError: if ``human_layer_index`` is outside ``0 .. len(hidden_states) - 1``.
            Negative values are rejected rather than silently counting from the end.
    """
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    if inputs["input_ids"].shape[-1] == 0:
        raise ValueError("Text produced no tokens; cannot take a last-token hidden state.")

    out = model(**inputs, output_hidden_states=True, use_cache=False)
    hs = out.hidden_states  # (embeddings, layer 1, ..., layer n)
    if not 0 <= human_layer_index < len(hs):
        raise IndexError(
            f"human_layer_index={human_layer_index} is out of range; "
            f"valid values are 0..{len(hs) - 1} (0 = embedding output)."
        )

    h = hs[human_layer_index][:, -1, :]  # last token of the single sequence
    return _normalize(h.squeeze(0))


def linreg_slope_x_y(x, y):
    """
    Ordinary least-squares fit of ``y`` on ``x`` (with an intercept).

    Returns:
        ``(slope, r, r2)`` where ``slope = cov(x, y) / var(x)``, ``r`` is the
        Pearson correlation and ``r2 = r ** 2``.

        * Fewer than two points, or ``x`` constant -> ``(nan, nan, nan)``.
          Constancy is tested exactly (max == min), so repeated float values
          such as ``[0.1, 0.1, 0.1]`` are caught even though their computed
          standard deviation is a tiny non-zero round-off value.
        * ``y`` constant -> ``(0.0, nan, nan)``: the fitted line is flat and the
          correlation is undefined (returned without numpy divide warnings).

    Raises:
        ValueError: if ``x`` and ``y`` have different lengths.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if len(x) != len(y):
        raise ValueError(f"x and y must have the same length (got {len(x)} and {len(y)})")
    if len(x) < 2 or x.max() == x.min():
        return np.nan, np.nan, np.nan
    if y.max() == y.min():
        return 0.0, np.nan, np.nan
    slope = np.cov(x, y, bias=True)[0, 1] / np.var(x)
    r = np.corrcoef(x, y)[0, 1]
    return slope, r, r * r


def oddness_score(betas, deltas):
    """
    Measure sign-symmetry of the response: compare mean delta at +b and -b.

    For every magnitude ``b > 0`` present with both signs, the mean delta at
    ``+b`` (dp) and at ``-b`` (dn) is computed. The score is
    ``1 - mean|dp + dn| / mean|dp - dn|``, clipped to ``[0, 1]``.

    * 1 means perfectly odd behavior (dn == -dp).
    * 0 means no odd component. This includes even responses (dp == dn != 0),
      which without clipping would give values around -1e12 because of the
      1e-12 epsilon and would swamp any averaged summary.

    An all-zero response (dp == dn == 0 for every pair) scores 1.0.

    Args:
        betas: Steering coefficients; zeros and NaNs are ignored.
        deltas: Responses aligned with ``betas``.

    Returns:
        Float in ``[0, 1]``, or nan when no magnitude appears with both signs
        (nan is also propagated if a paired delta is nan).
    """
    # Collect deltas per exact beta value in one pass (1 and 1.0 share a key).
    by_beta = {}
    for beta, delta in zip(betas, deltas):
        if beta != beta:  # NaN never equals +b or -b, so it can never be paired
            continue
        by_beta.setdefault(beta, []).append(delta)

    pairs = []
    for b in {abs(beta) for beta in by_beta if beta != 0}:
        if b in by_beta and -b in by_beta:
            pairs.append((np.mean(by_beta[b]), np.mean(by_beta[-b])))
    if not pairs:
        return np.nan

    num = np.mean([abs(dp + dn) for dp, dn in pairs])
    den = np.mean([abs(dp - dn) for dp, dn in pairs]) + 1e-12
    return float(np.clip(1.0 - num / den, 0.0, 1.0))


def ensure_sentiment_score(
    df: pd.DataFrame,
    model_name: str,
    layer_to_score: int,
    main_axis_path: str,
    dtype: torch.dtype,
    device_map: dict | None,
) -> pd.DataFrame:
    """
    Ensure column s_main exists. If missing, compute by:
    s_main = <last_token_hidden_state at layer> · <normalized main axis>.
    """
    if "s_main" in df.columns and df["s_main"].notna().any():
        return df

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=dtype, device_map=device_map
    )
    device = model.device

    if not os.path.exists(main_axis_path):
        raise FileNotFoundError(f"MAIN_AXIS_PATH not found: {main_axis_path}")

    u_main = load_axis(main_axis_path, device, dtype)

    scores = []
    for _, row in df.iterrows():
        text = str(row["output"])
        hL = encode_last_token(model, tokenizer, text, layer_to_score)
        scores.append(float((hL @ u_main).item()))

    out_df = df.copy()
    out_df["s_main"] = scores
    return out_df


def parse_dtype(dtype_str: str) -> torch.dtype:
    """
    Translate a command-line dtype name into the matching torch dtype.

    Raises:
        ValueError: if ``dtype_str`` is not float32, float16 or bfloat16.
    """
    supported = dict(
        float32=torch.float32,
        float16=torch.float16,
        bfloat16=torch.bfloat16,
    )
    resolved = supported.get(dtype_str)
    if resolved is None:
        raise ValueError(f"Unsupported dtype: {dtype_str}")
    return resolved


# Metric columns produced per group and averaged in the summary.
_PERGROUP_METRICS = ["slope_beta_to_sent", "spearman_rho", "pearson_r", "r2", "oddness"]


def _parse_args() -> argparse.Namespace:
    """Build and parse the command-line interface."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--csv", required=True, help="Input generations CSV.")
    parser.add_argument("--out_prefix", default="sentiment_usage", help="Output file prefix.")
    parser.add_argument("--blocks", default="usage_only,main_plus_usage",
                        help="exp_mode values to analyze, comma-separated.")
    parser.add_argument("--model_name", default="meta-llama/Meta-Llama-3-8B-Instruct")
    parser.add_argument("--layer", type=int, default=14, help="Human layer index for scoring.")
    parser.add_argument("--dtype", default="float16", choices=["float32", "float16", "bfloat16"])
    parser.add_argument("--main_axis_path", required=True,
                        help="Path to main sentiment axis .npy (no personal paths).")
    parser.add_argument("--device", default="0",
                        help="CUDA device id (e.g., '0'). Use 'cpu' for CPU.")
    return parser.parse_args()


def _spearman_rho(x, y) -> float:
    """Spearman rank correlation (Pearson r of average ranks); nan if n < 2 or either input is constant."""
    xr = pd.Series(np.asarray(x, dtype=float)).rank().to_numpy()
    yr = pd.Series(np.asarray(y, dtype=float)).rank().to_numpy()
    if len(xr) < 2 or xr.max() == xr.min() or yr.max() == yr.min():
        return np.nan
    return float(np.corrcoef(xr, yr)[0, 1])


def _attach_baseline_delta(data: pd.DataFrame, keys: list[str]) -> pd.DataFrame:
    """Add s0 (mean s_main at beta_usage == 0 per key group) and delta_s = s_main - s0."""
    # Average replicate baseline rows first. Merging them raw would be
    # many-to-many and duplicate every row of the group once per baseline row.
    base = (
        data.loc[data["beta_usage"] == 0.0, keys + ["s_main"]]
        .groupby(keys, as_index=False, dropna=False)["s_main"]
        .mean()
        .rename(columns={"s_main": "s0"})
    )
    out = data.merge(base, on=keys, how="left", validate="many_to_one")
    out["delta_s"] = out["s_main"] - out["s0"]
    return out


def _per_group_metrics(data: pd.DataFrame, keys: list[str]) -> pd.DataFrame:
    """Compute slope / correlations / oddness of delta_s vs beta_usage for each key group."""
    rows = []
    for gkeys, g in data.groupby(keys):
        g = g.dropna(subset=["beta_usage", "delta_s"]).sort_values("beta_usage")
        if g["beta_usage"].nunique() < 2:  # also covers empty groups
            continue
        slope, r, r2 = linreg_slope_x_y(g["beta_usage"], g["delta_s"])
        rows.append({
            **dict(zip(keys, gkeys)),
            "slope_beta_to_sent": slope,
            "spearman_rho": _spearman_rho(g["beta_usage"], g["delta_s"]),
            "pearson_r": r,
            "r2": r2,
            "oddness": oddness_score(g["beta_usage"].tolist(), g["delta_s"].tolist()),
            "n_points": len(g),
        })
    # Explicit columns keep the schema (and the summary groupby) valid with zero rows.
    return pd.DataFrame(rows, columns=keys + _PERGROUP_METRICS + ["n_points"])


def _summarize(pergroup: pd.DataFrame) -> pd.DataFrame:
    """Average per-group metrics over (usage, ortho, exp_mode)."""
    cols = ["usage", "ortho", "exp_mode"]
    if pergroup.empty:
        return pd.DataFrame(columns=cols + _PERGROUP_METRICS)
    return pergroup.groupby(cols)[_PERGROUP_METRICS].mean().reset_index()


def _fit_models(data: pd.DataFrame) -> None:
    """Print a MixedLM fit (random intercept per prompt), falling back to cluster-robust OLS."""
    try:
        import statsmodels.formula.api as smf
    except ImportError as e:
        print("[WARN] statsmodels not available; skipping model fits:", repr(e))
        return

    # Drop NaNs in every formula/grouping column up front so patsy does not drop
    # rows on its own and leave the cluster groups misaligned with the design.
    used = ["delta_s", "beta_usage", "usage", "ortho", "alpha_main", "prompt_id"]
    mdf = data.dropna(subset=used).copy()
    ortho_term = "ortho"
    try:
        mdf["ortho"] = mdf["ortho"].astype(int)
    except (TypeError, ValueError):
        # Non-numeric labels (e.g. an aliased 'mode' column): model as categorical.
        ortho_term = "C(ortho)"
    formula = f"delta_s ~ beta_usage * C(usage) + {ortho_term} + alpha_main"

    try:
        fit = smf.mixedlm(formula, groups="prompt_id", data=mdf).fit(
            reml=False, method="lbfgs", maxiter=200
        )
    except Exception as e:  # best-effort: any fitting failure triggers the fallback
        print("[WARN] MixedLM failed; falling back to OLS with cluster-robust SE:", repr(e))
    else:
        print("\n[MixedLM] delta_s ~ beta * usage + ortho + alpha")
        print(fit.summary())
        return

    try:
        ols = smf.ols(formula, data=mdf).fit(
            cov_type="cluster", cov_kwds={"groups": mdf["prompt_id"]}
        )
    except Exception as e:  # outputs are already written; report instead of crashing
        print("[WARN] OLS fallback failed as well:", repr(e))
        return
    print("\n[OLS-cluster] delta_s ~ beta * usage + ortho + alpha (cluster by prompt)")
    print(ols.summary())


def _ortho_label(value) -> str:
    """Render an ortho value as an int when possible (0/1, True/False), else as-is."""
    try:
        return str(int(value))
    except (TypeError, ValueError):
        return str(value)


def _print_quick_read(summ: pd.DataFrame) -> None:
    """Print one compact line per summary row."""
    if not len(summ):
        return
    print("\n=== Quick Read ===")
    for _, r in summ.iterrows():
        tag = f"{r['exp_mode']}/{r['usage']}/ortho={_ortho_label(r['ortho'])}"
        print(f"{tag:<35s} slope={r['slope_beta_to_sent']:+.4f}  "
              f"rho={r['spearman_rho']:+.3f}  odd={r['oddness']:+.3f}")


def main():
    """
    Run the analysis end to end.

    Steps: load the CSV, filter to the requested exp_mode blocks, score s_main
    if needed, baseline-correct against beta_usage == 0, write per-group and
    summary CSVs, fit the pooled regression, and print a quick read.

    In the per-group CSV, "spearman_rho" is Spearman's rank correlation and
    "pearson_r" is the Pearson correlation; "r2" is pearson_r squared.
    """
    args = _parse_args()

    device_map = {"": "cpu"} if args.device.lower() == "cpu" else {"": int(args.device)}

    df = pd.read_csv(args.csv)

    # Some exports use 'beta' for beta_usage, or 'mode' for ortho.
    alias = {"beta": "beta_usage", "mode": "ortho"}
    for src, dst in alias.items():
        if src in df.columns and dst not in df.columns:
            df[dst] = df[src]

    required = ["prompt_id", "usage", "ortho", "alpha_main", "beta_usage", "exp_mode", "output"]
    missing = [c for c in required if c not in df.columns]
    if missing:
        raise ValueError(f"Missing required column(s): {missing}")

    # Filter before scoring: every scored row costs one model forward pass.
    blocks = [b.strip() for b in args.blocks.split(",") if b.strip()]
    data = df[df["exp_mode"].isin(blocks)].copy()
    if data.empty:
        raise ValueError(f"No rows with exp_mode in {blocks}")

    data = ensure_sentiment_score(
        df=data,
        model_name=args.model_name,
        layer_to_score=args.layer,
        main_axis_path=args.main_axis_path,
        dtype=parse_dtype(args.dtype),
        device_map=device_map,
    )

    # Baseline correction within (prompt, usage, ortho, alpha_main, exp_mode) at beta_usage=0
    keys = ["prompt_id", "usage", "ortho", "alpha_main", "exp_mode"]
    data["beta_usage"] = pd.to_numeric(data["beta_usage"], errors="coerce")
    data = _attach_baseline_delta(data, keys)

    pergroup = _per_group_metrics(data, keys)
    per_out = f"{args.out_prefix}_pergroup.csv"
    pergroup.to_csv(per_out, index=False)
    print(f"[OK] per-group metrics -> {per_out}  (slope/rho/oddness per group)")
    if pergroup.empty:
        print("[WARN] No group had >= 2 distinct beta_usage values with a baseline; "
              "summary will be empty.")

    summ = _summarize(pergroup)
    sum_out = f"{args.out_prefix}_summary.csv"
    summ.to_csv(sum_out, index=False)
    print(f"[OK] summary -> {sum_out}")

    _fit_models(data)
    _print_quick_read(summ)


if __name__ == "__main__":
    main()
