"""
Offline eigenscore + per-response contribution scorer.

Input CSV columns:
    - prompt: str
    - objective: "expand" or "constrain"

For each prompt:
    - Generate K=10 responses with output_hidden_states=True (single call)
    - Compute:
        * eigenscore via getEigenIndicator_v2 (UNCHANGED per user)
        * LLO_raw         (leave-one-out delta of eigenscore)
        * LLO_reward      (normalized within K; higher is better for expand, lower for constrain)
        * variance_raw    (precision-diag; -log10(diag(G^{-1})))
        * variance_reward (normalized within K, flipped for constrain)

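Output CSV columns (depend on --diversity_metric):
    LOO: prompt, objective, response, eigenscore, LLO_raw, LLO_reward, thinking_content, thinking_token_len
    NLL: prompt, objective, response, NLL, NLL_reward, tokens_scored, thinking_content, thinking_token_len
    lex: all input columns plus lex_cos_to_mean, lex_dist_to_mean, lex_reward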
"""

import argparse
import gc
import math
import os
from typing import List, Tuple

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    GenerationConfig
)
import models
import transformers
from sentence_transformers import SentenceTransformer

# -----------------------------
# === UNCHANGED EIGENSCORE ===
# (exactly as provided)
# -----------------------------
import numpy as np
import torch

def getEigenIndicator_v2(hidden_states, num_tokens):
    """
    Compute the layer-averaged eigenscore for one batch of sampled sequences.

    For every layer index of ``hidden_states[0]`` the function:
        1. sums ``hidden_states[t + 1][layer][seq, 0, :]`` over the steps ``t``
           with ``t <= num_tokens[seq] - 1`` and divides by ``num_tokens[seq] - 1``;
        2. takes ``torch.cov`` of the resulting matrix (rows are treated as
           variables, so the result is square in the number of sequences);
        3. adds ``alpha * I`` and averages ``log10`` of the singular values.
    The per-layer values at indices ``[20:-2]`` are then averaged.

    Args:
        hidden_states: indexed as ``hidden_states[step][layer][seq, 0, :]``;
            presumably the per-step tuple returned by ``generate`` with
            ``output_hidden_states=True`` -- TODO confirm against the caller.
        num_tokens: per-sequence token counts, indexable by sequence index.

    Returns:
        ``0`` when ``hidden_states`` has fewer than two entries,
        ``float("nan")`` when the covariance is non-finite or the SVD fails,
        otherwise ``np.mean`` of the per-layer scores in ``[20:-2]``.

    NOTE(review): the working buffer is allocated with ``device="cuda"``, so
    this function needs a CUDA device. With 22 or fewer layers the final slice
    is empty and ``np.mean`` of an empty array is returned.
    """
    alpha = 1e-3  # ridge term added to the covariance diagonal before the SVD
    LayerEigens = []
    if len(hidden_states) < 2:
        return 0

    last_s = None  # assigned inside the loop but never read or returned in this version

    for layer_ind in range(len(hidden_states[0])):
        # Build one averaged embedding per sequence for this layer.
        # torch.zeros without dtype uses the default dtype, so the accumulation
        # below is not done in the (possibly half-precision) hidden-state dtype.
        last_embeddings = torch.zeros(
            hidden_states[1][-1].shape[0],
            hidden_states[1][-1].shape[2],
            device="cuda",
        )
        for seq_ind in range(hidden_states[1][-1].shape[0]):
            for token_ind in range(len(hidden_states) - 1):
                # skip generation steps beyond this sequence's token count
                if token_ind > num_tokens[seq_ind] - 1:
                    continue
                last_embeddings[seq_ind, :] += hidden_states[token_ind + 1][layer_ind][seq_ind, 0, :]
            # NOTE: keep original denominator (may be zero/negative if inputs are off).
            # Up to num_tokens[seq_ind] steps can be summed above, while the divisor
            # is num_tokens[seq_ind] - 1.
            last_embeddings[seq_ind, :] = last_embeddings[seq_ind, :] / (num_tokens[seq_ind] - 1)

        # Covariance with sequences as variables, moved to NumPy as float64.
        CovMatrix = torch.cov(last_embeddings)
        Cov_np = CovMatrix.detach().cpu().numpy().astype(float)

        # quick sanity check: if the matrix is numerically bad, return NaN instead of crashing
        if not np.isfinite(Cov_np).all():
            return float("nan")

        # SVD with defense: return NaN if it doesn't converge
        try:
            u, s, vT = np.linalg.svd(Cov_np + alpha * np.eye(Cov_np.shape[0]), full_matrices=False)
        except np.linalg.LinAlgError:
            return float("nan")
        except Exception:
            # catch-all to avoid unexpected hard crashes
            return float("nan")

        last_s = s
        # layer score: mean log10 of the regularized covariance's singular values
        eigenIndicator = np.mean(np.log10(s))
        LayerEigens.append(eigenIndicator)

    LayerEigens = np.array(LayerEigens)
    # Average the per-layer scores from index 20 up to, but excluding, the last
    # two entries (the slice [20:-2]).
    return np.mean(LayerEigens[20:-2])


def getEigenIndicator_v0(hidden_states, num_tokens):
    """
    Single-layer eigenscore taken at the middle layer.

    For each sequence the hidden state at step ``num_tokens[seq] - 2`` of the
    middle layer (``int(len(hidden_states[0]) / 2)``) is used as its embedding.
    The score is the mean ``log10`` of the singular values of the
    ``alpha``-regularized covariance of those embeddings.

    Returns:
        ``(0, "None")`` when ``hidden_states`` has fewer than two entries,
        otherwise ``(score, singular_values)``.

    The embedding buffer is moved to ``"cuda"``, so the full path needs a GPU.
    """
    alpha = 1e-3
    selected_layer = int(len(hidden_states[0]) / 2)
    if len(hidden_states) < 2:
        return 0, "None"

    reference = hidden_states[1][-1]
    n_seq, hidden_dim = reference.shape[0], reference.shape[2]
    last_embeddings = torch.zeros(n_seq, hidden_dim).to("cuda")
    for seq in range(n_seq):
        step = num_tokens[seq] - 2
        last_embeddings[seq, :] = hidden_states[step][selected_layer][seq, 0, :]

    cov = torch.cov(last_embeddings).cpu().numpy().astype(float)
    regularized = cov + alpha * np.eye(cov.shape[0])
    _, s, _ = np.linalg.svd(regularized)
    return np.mean(np.log10(s)), s




# -----------------------------
# Helpers consistent with v2
# -----------------------------
def get_num_tokens(generation_ids: torch.Tensor) -> List[int]:
    """
    Count tokens per generated sequence.

    generation_ids: Tensor [K, T] (token ids for generated segment)
    For every row, returns (number of ids strictly greater than 2) + 1.
    NOTE(review): the ``> 2`` threshold presumably targets vocabularies whose
    special tokens occupy ids 0-2 -- confirm for the tokenizer in use.
    """
    return [
        sum(1 for tid in row.tolist() if int(tid) > 2) + 1
        for row in generation_ids
    ]

def _build_last_embeddings_v2(hidden_states, num_tokens, layer_ind) -> np.ndarray:
    """
    Recreate per-sequence averaged embedding at a given layer as in v2.
    Returns numpy array [K, D].
    """
    K = hidden_states[1][-1].shape[0]
    D = hidden_states[1][-1].shape[2]
    device = hidden_states[1][-1].device
    dtype  = hidden_states[1][-1].dtype

    last_embeddings = torch.zeros(K, D, device=device, dtype=dtype)
    T = len(hidden_states) - 1  # number of generation steps iterated in v2

    for seq_ind in range(K):
        n_tok = max(1, int(num_tokens[seq_ind]) - 1)
        acc = torch.zeros(D, device=device, dtype=dtype)
        for token_ind in range(T):
            if token_ind > num_tokens[seq_ind] - 1:
                continue
            acc += hidden_states[token_ind + 1][layer_ind][seq_ind, 0, :]
        last_embeddings[seq_ind, :] = acc / n_tok

    return last_embeddings.detach().float().cpu().numpy()  # [K, D] on CPU

def _cov_plus_alpha(last_embeddings: np.ndarray, alpha: float = 1e-3) -> Tuple[np.ndarray, np.ndarray]:
    """
    Covariance across K sequences + alpha I -> G (KxK) and its singular values.
    """
    Cov = np.cov(last_embeddings, bias=False)  # variables in rows
    G = Cov + alpha * np.eye(Cov.shape[0], dtype=Cov.dtype)
    s = np.linalg.svd(G, compute_uv=False)
    return G, s

def _layer_score_from_G(G: np.ndarray) -> float:
    """
    Layer score = mean(log10(singular values of G)).
    """
    s = np.linalg.svd(G, compute_uv=False)
    return float(np.mean(np.log10(s + 1e-12)))

def _loo_contrib_layer(last_embeddings: np.ndarray, alpha: float = 1e-3) -> np.ndarray:
    """
    Leave-one-out contribution of each response at one layer.

    The layer score is computed once on all K rows and once per row with that
    row removed; the contribution of row i is ``E_full - E_minus_i``.
    Returns a float64 array of length K.
    """
    n_rows = last_embeddings.shape[0]
    full_G, _ = _cov_plus_alpha(last_embeddings, alpha)
    full_score = _layer_score_from_G(full_G)

    deltas = np.zeros(n_rows, dtype=np.float64)
    for held_out in range(n_rows):
        # every row except the one being held out, in original order
        kept = [j for j in range(n_rows) if j != held_out]
        reduced_G, _ = _cov_plus_alpha(last_embeddings[kept], alpha)
        deltas[held_out] = full_score - _layer_score_from_G(reduced_G)
    return deltas

def _precision_diag_contrib_layer(last_embeddings: np.ndarray, alpha: float = 1e-3, eps: float = 1e-12) -> np.ndarray:
    """
    Precision-diagonal contribution of each response at one layer.

    Builds the regularized covariance G, inverts it (falling back to the
    pseudo-inverse when ``np.linalg.inv`` raises ``LinAlgError``), clips the
    diagonal of the inverse from below at ``eps`` and returns ``-log10`` of it.
    """
    G, _ = _cov_plus_alpha(last_embeddings, alpha)
    try:
        precision = np.linalg.inv(G)
    except np.linalg.LinAlgError:
        precision = np.linalg.pinv(G, rcond=1e-8)
    # clip keeps the log argument strictly positive
    precision_diag = np.diag(precision)
    clipped = np.clip(precision_diag, eps, None)
    return -np.log10(clipped)

def _normalize_per_prompt(x: np.ndarray, expand: bool) -> np.ndarray:
    """
    Min-max normalize within prompt; expand keeps as-is (higher=better),
    constrain flips (lower=better).
    """
    xmin, xmax = float(np.min(x)), float(np.max(x))
    denom = xmax - xmin if xmax > xmin else 1.0
    z = (x - xmin) / denom
    return z if expand else (1.0 - z)

# -----------------------------
# NLL computation functions (from NLL.py)
# -----------------------------
def _compute_online_nll_per_token(gen_ids: torch.Tensor, scores: list, eos_id: int, pad_id: int) -> np.ndarray:
    """
    gen_ids: [K, T] generated token ids (after input prefix)
    scores : list of length T, each tensor [K, V] logits for the token at that step
    Returns: per-seq length-normalized NLL (float64) and token counts (int)
    """
    K, T = gen_ids.size(0), gen_ids.size(1)
    # stack per-step log-probs gathered at generated ids -> [K, T]
    logps = []
    for t in range(T):
        # scores[t]: [K, V] logits for step t
        lp_t = torch.log_softmax(scores[t].to(torch.float32), dim=-1)
        idx_t = gen_ids[:, t].unsqueeze(-1)  # [K,1]
        lp_tok_t = lp_t.gather(dim=-1, index=idx_t).squeeze(-1)  # [K]
        logps.append(lp_tok_t)
    logps = torch.stack(logps, dim=1)  # [K, T]

    # Build valid mask: tokens strictly BEFORE first EOS; also exclude pads
    valid = torch.ones((K, T), dtype=torch.bool, device=gen_ids.device)
    if eos_id is not None:
        eos_mask = (gen_ids == eos_id)
        # find first eos per row
        first_eos = torch.where(eos_mask.any(dim=1),
                                eos_mask.float().argmax(dim=1),
                                torch.full((K,), T, device=gen_ids.device))  # if no eos, T
        # zero-out at and after eos
        for k in range(K):
            if first_eos[k] < T:
                valid[k, first_eos[k]:] = False
    if pad_id is not None:
        valid &= (gen_ids != pad_id)

    # sum NLL over valid positions
    n_tokens = valid.sum(dim=1)  # [K]
    nll_sum = (-logps.masked_fill(~valid, 0.0).sum(dim=1))  # [K]
    # length-normalized
    nll_per_tok = nll_sum / torch.clamp(n_tokens, min=1)

    return nll_per_tok.detach().cpu().numpy().astype(np.float64), n_tokens.detach().cpu().numpy().astype(np.int64)

# -----------------------------
# Lexical diversity functions (from lex.py)
# -----------------------------
def embed_texts(texts: List[str], encoder, batch_size: int = 64) -> np.ndarray:
    """
    Embed ``texts`` with a Sentence-Transformers style encoder.

    Args:
        texts: strings to embed.
        encoder: object exposing ``encode(texts, batch_size=...,
            convert_to_numpy=True, normalize_embeddings=True)``.
        batch_size: forwarded to ``encoder.encode``.

    Returns:
        The encoder output cast to float32, expected to be [N, D] and
        L2-normalized (normalization is requested from the encoder).

    Raises:
        RuntimeError: if ``encoder`` has no ``encode`` method. No fallback
            for plain Hugging Face models is implemented.
    """
    if not hasattr(encoder, "encode"):
        # Raise directly: no import sits in front of this error, so a missing
        # optional package cannot replace it with an ImportError.
        raise RuntimeError("Provided encoder does not support .encode().")

    embs = encoder.encode(
        texts, batch_size=batch_size, convert_to_numpy=True, normalize_embeddings=True
    )
    return embs.astype(np.float32)

def cosine_to_vector(E: np.ndarray, v: np.ndarray) -> np.ndarray:
    """
    Dot product of every row of ``E`` with ``v``.

    E: [K, D] (assumed L2-normalized)
    v: [D]    (assumed L2-normalized)
    Under those assumptions the result is cos(E_i, v) for all i. ``v`` is
    cast to the dtype of ``E`` before the product.
    """
    query = v.astype(E.dtype)
    similarities = E @ query
    return similarities

# -----------------------------
# Main scoring routine
# -----------------------------
def main():
    """
    Command-line entry point: score responses for diversity and write a CSV.

    Depending on ``--diversity_metric``:
        * ``lex``: reads existing responses from the input CSV (columns
          prompt, objective, response), embeds them with a sentence
          transformer and scores each response by cosine distance to the
          per-prompt mean embedding.
        * ``LOO``: samples K responses per prompt, computes the eigenscore and
          per-response leave-one-out contributions from the hidden states.
        * ``NLL``: samples K responses per prompt and scores each one by its
          length-normalized negative log-likelihood.
    Raw scores are min-max normalized within each prompt; the direction is
    flipped for the "constrain" objective.

    Raises:
        ValueError: on missing input columns, an unknown objective, or (lex
            only) a prompt with mixed objectives.
    """

    def _str_to_bool(value):
        """Parse a command-line boolean; ``type=bool`` would treat "False" as True."""
        if isinstance(value, bool):
            return value
        lowered = str(value).strip().lower()
        if lowered in {"1", "true", "t", "yes", "y", "on"}:
            return True
        if lowered in {"0", "false", "f", "no", "n", "off"}:
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got: {value!r}")

    ap = argparse.ArgumentParser()
    ap.add_argument("--input_csv", required=True, help="Path to CSV with columns: prompt, objective")
    ap.add_argument("--output_csv", required=True, help="Where to write long-form per-response CSV")
    ap.add_argument('--model_name', type=str, default='llama-8b-instruct')
    ap.add_argument("--max_new_tokens", type=int, default=256)
    ap.add_argument("--temperature", type=float, default=1.0)
    ap.add_argument("--top_p", type=float, default=0.99)
    ap.add_argument("--top_k", type=int, default=10)
    ap.add_argument("--num_return_sequences", type=int, default=10)  # K=10
    ap.add_argument("--device", default="cuda")
    ap.add_argument("--thinking", type=_str_to_bool, default=True, help="Enable thinking mode for Qwen models")
    ap.add_argument("--diversity_metric", type=str, default="LOO", choices=["LOO", "NLL", "lex"], 
                    help="Diversity metric to compute: LOO (leave-one-out), NLL (negative log-likelihood), or lex (lexical diversity)")
    args = ap.parse_args()

    # Load data
    df = pd.read_csv(args.input_csv)

    # Different required columns based on diversity metric
    if args.diversity_metric == "lex":
        required_cols = {"prompt", "objective", "response"}
    else:
        required_cols = {"prompt", "objective"}

    missing = required_cols - set(df.columns)
    if missing:
        raise ValueError(f"Input CSV missing columns: {missing}")

    # Load the generation model only if not using lexical diversity
    if args.diversity_metric != "lex":
        model, tokenizer = models.load_model_and_tokenizer(args.model_name, args.device)
        model.eval()

        # Detect model type from the model name
        is_qwen = "qwen" in args.model_name.lower()
        is_llama = "llama" in args.model_name.lower()

        if not (is_qwen or is_llama):
            print(f"Warning: Model {args.model_name} not recognized as Qwen or Llama. Defaulting to Llama behavior.")
    else:
        # For lexical diversity, load sentence transformer
        try:
            from sentence_transformers import SentenceTransformer
        except Exception as e:
            raise RuntimeError(
                "Please install sentence-transformers: pip install sentence-transformers"
            ) from e
        encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
        model = None
        tokenizer = None
        is_qwen = False
        is_llama = False

    rows_out = []
    K = args.num_return_sequences

    if args.diversity_metric == "lex":
        # Handle lexical diversity case - process by prompt groups
        df_out = df.copy()
        df_out["lex_cos_to_mean"] = np.nan
        df_out["lex_dist_to_mean"] = np.nan
        df_out["lex_reward"] = np.nan

        # Group by prompt (normalize within each group of K responses)
        grouped = df_out.groupby("prompt", sort=False, group_keys=False)

        for prompt, g in tqdm(grouped, desc="Scoring by prompt"):
            idx = g.index.to_list()
            # strip + lower, matching how the generation branch parses objectives
            objective_vals = g["objective"].astype(str).str.strip().str.lower().unique()
            if len(objective_vals) != 1:
                raise ValueError(f"Prompt has mixed objectives: {objective_vals} for prompt: {prompt}")
            objective = objective_vals[0]
            if objective not in {"expand", "constrain"}:
                raise ValueError(f"objective must be 'expand' or 'constrain', got: {objective}")

            responses = g["response"].astype(str).tolist()
            # --- Lexical diversity via embedding distance from mean ---
            E = embed_texts(responses, encoder)     # [K, D], already L2-normalized
            mu = E.mean(axis=0)
            mu_norm = np.linalg.norm(mu)
            if mu_norm <= 1e-12:
                # degenerate case: the mean embedding has (near-)zero norm
                mu = np.zeros_like(mu)
                cos = np.zeros(len(E), dtype=np.float32)
            else:
                mu = (mu / mu_norm).astype(np.float32)
                cos = cosine_to_vector(E, mu)      # [-1, 1]

            dist = 1.0 - cos                        # higher => farther from mean => more "diverse"
            lex_reward = _normalize_per_prompt(dist, objective == "expand")

            df_out.loc[idx, "lex_cos_to_mean"] = cos
            df_out.loc[idx, "lex_dist_to_mean"] = dist
            df_out.loc[idx, "lex_reward"] = lex_reward

        out_df = df_out
    else:
        # Handle LOO and NLL cases - generate responses
        for _, row in tqdm(df.iterrows(), total=len(df), desc="Scoring"):
            prompt = str(row["prompt"])
            objective = str(row["objective"]).strip().lower()
            if objective not in {"expand", "constrain"}:
                raise ValueError(f"objective must be 'expand' or 'constrain', got: {objective}")

            # Prepare input ids based on model type
            messages = [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt},
            ]

            if is_qwen:
                # Qwen-specific chat template with thinking support
                text = tokenizer.apply_chat_template(
                    messages,
                    enable_thinking=args.thinking,
                    tokenize=False,
                    add_generation_prompt=True
                )
                inputs = tokenizer([text], return_tensors="pt").to(model.device)
            else:
                # Llama-specific chat template
                inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
                terminators = [
                    tokenizer.eos_token_id,
                    tokenizer.convert_tokens_to_ids("<|eot_id|>")
                ]

            # Determine input_ids and input_length robustly for both Tensor and BatchEncoding
            if isinstance(inputs, torch.Tensor):
                input_ids = inputs.to(model.device)
                input_length = input_ids.shape[1]
                generate_inputs = {"input_ids": input_ids}
            else:
                # BatchEncoding
                input_ids = inputs["input_ids"].to(model.device)
                input_length = input_ids.shape[1]
                generate_inputs = {k: v.to(model.device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}

            if is_qwen:
                generation_config = transformers.GenerationConfig(
                    max_new_tokens=args.max_new_tokens,
                    pad_token_id=tokenizer.eos_token_id
                )
            else:
                generation_config = transformers.GenerationConfig(
                    max_new_tokens=args.max_new_tokens,
                    pad_token_id=tokenizer.eos_token_id,
                    eos_token_id=terminators
                )

            # Generate K samples with appropriate outputs based on diversity metric
            output_hidden_states = (args.diversity_metric == "LOO")
            output_scores = (args.diversity_metric == "NLL")

            with torch.no_grad():
                dict_outputs = model.generate(
                    **generate_inputs,
                    num_beams=1,
                    num_return_sequences=K,
                    do_sample=True,
                    top_p   = args.top_p,
                    top_k   = args.top_k,
                    temperature=args.temperature,
                    generation_config=generation_config,
                    output_hidden_states=output_hidden_states,
                    return_dict_in_generate=True,
                    output_scores=output_scores,
                )

            # Extract responses (generated segment only, prompt prefix removed)
            generation = dict_outputs.sequences[:, input_length:].cpu()
            responses = [tokenizer.decode(generation[i], skip_special_tokens=True) for i in range(generation.size(0))]

            # Extract thinking content for Qwen models
            thinking_contents = []
            thinking_token_lens = []
            final_responses = []

            if is_qwen and args.thinking:
                for i in range(generation.size(0)):
                    output_ids = generation[i].tolist()
                    try:
                        # rindex finding 151668 (</think>)
                        index = len(output_ids) - output_ids[::-1].index(151668)
                    except ValueError:
                        # no closing tag found: treat the whole output as the response
                        index = 0

                    thinking_content = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n")
                    content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")
                    thinking_token_ids = output_ids[:index]
                    thinking_token_len = len(thinking_token_ids)

                    thinking_contents.append(thinking_content)
                    thinking_token_lens.append(thinking_token_len)
                    final_responses.append(content)  # Use content as the actual response
            else:
                # For Llama or Qwen without thinking, use original responses
                thinking_contents = [""] * generation.size(0)
                thinking_token_lens = [0] * generation.size(0)
                final_responses = responses

            # Process based on diversity metric
            if args.diversity_metric == "LOO":
                # LOO processing
                num_tokens = get_num_tokens(generation)
                hidden_states = dict_outputs.hidden_states

                # Global eigenscore (same for all responses in this prompt)
                E_global = getEigenIndicator_v2(hidden_states, num_tokens)

                # Build per-layer last_embeddings consistent with v2 and derive contributions
                L = len(dict_outputs.hidden_states[0])  # number of layers
                layer_start = 20
                layer_end = max(0, L - 2)  # end exclusive; matches v2's 20:-2
                n_layers = max(1, layer_end - layer_start)

                LLO_accum = None

                for layer_ind in range(layer_start, layer_end):
                    last_emb = _build_last_embeddings_v2(hidden_states, num_tokens, layer_ind)  # [K, D]
                    llo_layer = _loo_contrib_layer(last_emb, alpha=1e-3)                  # [K]
                    LLO_accum = llo_layer if LLO_accum is None else (LLO_accum + llo_layer)

                if LLO_accum is None:
                    # No layer falls in [20, L-2) for this model, so there is
                    # nothing to average; report NaN instead of crashing on
                    # `None / n_layers`.
                    LLO_raw = np.full(len(final_responses), np.nan, dtype=np.float64)
                else:
                    LLO_raw = LLO_accum / n_layers

                # Normalize to rewards within this prompt, per objective
                expand = (objective == "expand")
                LLO_reward = _normalize_per_prompt(LLO_raw, expand=expand)

                # One row per response
                for i, response in enumerate(final_responses):
                    row_data = {
                        "prompt": prompt,
                        "objective": objective,
                        "response": response,
                        "eigenscore": float(E_global),
                        "LLO_raw": float(LLO_raw[i]),
                        "LLO_reward": float(LLO_reward[i]),
                        "thinking_content": thinking_contents[i],
                        "thinking_token_len": thinking_token_lens[i],
                    }
                    rows_out.append(row_data)

            elif args.diversity_metric == "NLL":
                # NLL processing
                scores_list = [s.to(model.device) for s in dict_outputs.scores]
                eos_id = tokenizer.eos_token_id
                pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
                nll_per_tok, n_tokens_valid = _compute_online_nll_per_token(
                    gen_ids=generation.to(model.device),
                    scores=scores_list,
                    eos_id=eos_id,
                    pad_id=pad_id,
                )
                expand = (objective == "expand")
                NLL_reward = _normalize_per_prompt(nll_per_tok, expand=expand)

                # One row per response
                for i, response in enumerate(final_responses):
                    row_data = {
                        "prompt": prompt,
                        "objective": objective,
                        "response": response,
                        "NLL": float(nll_per_tok[i]),
                        "NLL_reward": float(NLL_reward[i]),
                        "tokens_scored": int(n_tokens_valid[i]),
                        "thinking_content": thinking_contents[i],
                        "thinking_token_len": thinking_token_lens[i],
                    }
                    rows_out.append(row_data)

            # cleanup: drop the large generation outputs before the next prompt
            del dict_outputs, generation
            if args.diversity_metric == "LOO":
                del hidden_states
            torch.cuda.empty_cache(); gc.collect()

        out_df = pd.DataFrame(rows_out)

    out_df.to_csv(args.output_csv, index=False)
    print(f"Wrote {len(out_df)} rows to {args.output_csv}")

if __name__ == "__main__":
    main()
