import argparse
from collections import defaultdict
from pathlib import Path

import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm

from utils import (
    get_word_probability,
    get_steering_vector,
    get_steered_word_probability,
    get_hidden_states,
)

# Relative directories, resolved against the current working directory when used:
# results are written below output_dir, stimuli are read from data_dir.
output_dir, data_dir = Path("output"), Path("data")


def find_diff_positions(tokenizer, text_a, text_b):
    """Find the token spans where two texts diverge.

    Both texts are tokenized with ``add_special_tokens=False``. The longest
    common token prefix and the longest common token suffix are stripped; what
    remains of each tokenization is that text's differing span.

    Args:
        tokenizer: Object whose ``encode(text, add_special_tokens=False)``
            returns a sequence of comparable token ids.
        text_a: First text.
        text_b: Second text.

    Returns:
        ``((start, end_a), (start, end_b))``: half-open index ranges into the
        tokenizations of ``text_a`` and ``text_b``. The prefix and suffix never
        overlap, so ``start <= end_a`` and ``start <= end_b`` always hold. A
        range is empty when that text contains no tokens the other lacks
        (e.g. the other text is a pure insertion, or the texts are identical).
    """
    tokens_a = tokenizer.encode(text_a, add_special_tokens=False)
    tokens_b = tokenizer.encode(text_b, add_special_tokens=False)
    shortest = min(len(tokens_a), len(tokens_b))

    # Common-prefix length. If no mismatch is found, one sequence is a prefix
    # of the other and the divergence begins right after the shorter one.
    start = shortest
    for i in range(shortest):
        if tokens_a[i] != tokens_b[i]:
            start = i
            break

    # Common-suffix length, capped so the suffix cannot reach back into the
    # prefix; without the cap, repeated tokens around an insertion would
    # yield end < start and an empty (NaN-producing) span.
    suffix = 0
    max_suffix = shortest - start
    while suffix < max_suffix and tokens_a[-1 - suffix] == tokens_b[-1 - suffix]:
        suffix += 1

    return (start, len(tokens_a) - suffix), (start, len(tokens_b) - suffix)


def _mean_over_token_span(hidden, tokenizer, text, start, end):
    """Mean of ``hidden[0, start:end, :]`` over the token axis, re-aligned to ``hidden``.

    ``start``/``end`` index the tokenization of ``text`` without special tokens
    (as produced by ``find_diff_positions``). If ``hidden`` covers more positions
    than that tokenization, the span is shifted right by the difference so it
    still lines up with the same tokens.
    NOTE(review): assumes any extra positions are *prepended* (e.g. a BOS token
    added by ``get_hidden_states``); confirm against ``utils.get_hidden_states``.
    """
    n_plain_tokens = len(tokenizer.encode(text, add_special_tokens=False))
    offset = max(hidden.shape[1] - n_plain_tokens, 0)
    return hidden[0, start + offset:end + offset, :].mean(dim=0)


def _compute_pair_diff(model, tokenizer, explicit_text, implicit_text, layer, pooling, method):
    """Compute the hidden-state difference (explicit minus implicit) for one passage pair.

    Args:
        model: The language model.
        tokenizer: The tokenizer matching ``model``.
        explicit_text: Passage with the explicit knowledge cue.
        implicit_text: Passage with the implicit knowledge cue.
        layer: Layer whose hidden states are compared.
        pooling: Pooling strategy forwarded to ``get_steering_vector``; used by
            "full"/"loo" only ("cue" always mean-pools over the cue span).
        method: "full"/"loo" delegate to ``get_steering_vector`` on the whole
            passages; "cue" mean-pools only the tokens where the passages differ.

    Returns:
        The vector returned by ``get_steering_vector`` ("full"/"loo"), or
        ``mean(explicit cue span) - mean(implicit cue span)`` ("cue").

    Raises:
        ValueError: for an unknown ``method``, or (for "cue") when either cue
            span is empty -- averaging zero tokens would silently give NaN and
            corrupt every steering vector averaged with it.
    """
    if method in ("full", "loo"):
        return get_steering_vector(
            model, tokenizer,
            positive_texts=[explicit_text],
            negative_texts=[implicit_text],
            layer=layer,
            pooling=pooling,
        )
    if method != "cue":
        raise ValueError(f"Unknown method: {method}. Use 'full', 'cue', or 'loo'.")

    (exp_start, exp_end), (imp_start, imp_end) = find_diff_positions(
        tokenizer, explicit_text, implicit_text
    )
    # Validate before running the model so a bad pair fails fast.
    for label, (span_start, span_end) in (
        ("explicit", (exp_start, exp_end)),
        ("implicit", (imp_start, imp_end)),
    ):
        if span_end <= span_start:
            raise ValueError(
                f"Empty cue span in {label} passage (tokens {span_start}:{span_end}); "
                "the passages may differ only by an insertion, or be identical."
            )

    exp_hidden = get_hidden_states(
        model, tokenizer, explicit_text, layer=layer, pooling=None
    )
    imp_hidden = get_hidden_states(
        model, tokenizer, implicit_text, layer=layer, pooling=None
    )
    exp_cue = _mean_over_token_span(exp_hidden, tokenizer, explicit_text, exp_start, exp_end)
    imp_cue = _mean_over_token_span(imp_hidden, tokenizer, implicit_text, imp_start, imp_end)
    return exp_cue - imp_cue


def compute_steering_vectors(model, tokenizer, fb_data, layer, pooling, method="full"):
    """Compute steering vectors from paired explicit/implicit passages.

    Rows are grouped by (item, condition, first_mention, recent_mention). Each
    group holding exactly one "Explicit" and one "Implicit" passage yields one
    difference vector (explicit minus implicit, via ``_compute_pair_diff``);
    groups with any other composition are skipped. ``"[MASK]."`` is removed
    from passages before encoding.

    Args:
        model: The language model
        tokenizer: The tokenizer
        fb_data: DataFrame with columns: item, condition, first_mention, recent_mention,
                 knowledge_cue, passage
        layer: Which layer to extract hidden states from
        pooling: Pooling strategy ("last", "mean", "first")
        method: "full" for full-passage pooling, "cue" for cue-targeted pooling,
                "loo" for leave-one-out (per item, averaged from other items)

    Returns:
        For "full"/"cue": dict {condition: steering_vector} -- the mean over all
            pairs in the condition (each pair weighted equally).
        For "loo": dict {(condition, item): steering_vector} -- the mean of the
            per-item mean diffs of every *other* item in the condition (each
            item weighted equally).

    Raises:
        ValueError: if ``method`` is unknown (raised on the first pair by
            ``_compute_pair_diff``), or if "loo" is requested for a condition
            that has fewer than two items (nothing would be left to average).
    """
    group_cols = ["item", "condition", "first_mention", "recent_mention"]
    # condition -> item -> [diff vectors], insertion-ordered
    pair_diffs = defaultdict(lambda: defaultdict(list))

    with torch.no_grad():
        for (item, condition, *_), group_df in tqdm(
            fb_data.groupby(group_cols), desc="Computing steering vectors"
        ):
            cues = group_df["knowledge_cue"]
            explicit_passages = group_df.loc[cues == "Explicit", "passage"]
            implicit_passages = group_df.loc[cues == "Implicit", "passage"]

            # Only well-formed pairs (exactly one passage per cue) contribute.
            if len(explicit_passages) != 1 or len(implicit_passages) != 1:
                continue

            explicit_text = explicit_passages.iloc[0].replace("[MASK].", "")
            implicit_text = implicit_passages.iloc[0].replace("[MASK].", "")

            diff = _compute_pair_diff(model, tokenizer, explicit_text, implicit_text,
                                      layer, pooling, method)
            pair_diffs[condition][item].append(diff)

    steering_vectors = {}
    if method == "loo":
        for condition, item_diffs in pair_diffs.items():
            # Collapse each item's diffs into one mean so items weigh equally.
            item_means = {it: torch.stack(diffs).mean(dim=0) for it, diffs in item_diffs.items()}
            if len(item_means) < 2:
                raise ValueError(
                    f"Leave-one-out needs at least two items per condition; "
                    f"condition {condition!r} has {len(item_means)}."
                )
            for held_out in item_means:
                others = [vec for it, vec in item_means.items() if it != held_out]
                steering_vectors[(condition, held_out)] = torch.stack(others).mean(dim=0)
    else:
        for condition, item_diffs in pair_diffs.items():
            all_diffs = [d for diffs in item_diffs.values() for d in diffs]
            steering_vectors[condition] = torch.stack(all_diffs).mean(dim=0)
    return steering_vectors


def compute_base_probabilities(model, tokenizer, fb_data, bow_token_ids):
    """Score both candidate words of every stimulus row without any steering.

    For each row, ``"[MASK]."`` is removed from ``passage`` and
    ``get_word_probability`` is called for the row's ``start`` word and then its
    ``end`` word.

    Returns:
        dict: {item_id: {"start_prob": float, "end_prob": float}}. A later row
        overwrites an earlier one that shares the same ``item_id``.
    """
    def score(passage_text, word):
        return get_word_probability(model, tokenizer, passage_text, word, bow_token_ids)

    probabilities = {}
    with torch.no_grad():
        progress = tqdm(fb_data.iterrows(), total=fb_data.shape[0],
                        desc="Computing base probabilities")
        for _, record in progress:
            passage_text = record["passage"].replace("[MASK].", "")
            # Dict values are evaluated in order: start word first, then end word.
            probabilities[record["item_id"]] = {
                "start_prob": score(passage_text, record["start"]),
                "end_prob": score(passage_text, record["end"]),
            }
    return probabilities


def evaluate_steering(model, tokenizer, fb_data, bow_token_ids, steering_vectors,
                      layer, scales, model_details, loo=False, base_probs=None):
    """Run evaluation with steered and unsteered probabilities.

    All "Implicit" rows are processed first, then all "Explicit" rows; within
    each, every scale in ``scales`` is applied to every row. For each row the
    ``start`` and ``end`` words are scored with and without the row's steering
    vector injected at ``layer``. The higher-probability word is the prediction
    (ties go to ``end``) and counts as correct when it equals ``critical_a``.
    A cue with no rows is skipped.

    Args:
        model: The language model
        tokenizer: The tokenizer
        fb_data: DataFrame with FB stimuli
        bow_token_ids: List of beginning-of-word token IDs
        steering_vectors: dict {condition: vector} or {(condition, item): vector} for LOO
        layer: Which layer to apply steering
        scales: List of scale values to evaluate (e.g., [0.1, -0.1, 1.0, -1.0])
        model_details: dict with model metadata
        loo: If True, look up steering vectors by (condition, item)
        base_probs: Pre-computed {item_id: {"start_prob", "end_prob"}} or None.
            When None, unsteered probabilities are computed here, once per row
            (they do not depend on the scale).

    Returns:
        list[dict]: Output rows with probabilities and predictions
    """
    def unsteered_prob(context, word):
        return get_word_probability(model, tokenizer, context, word, bow_token_ids)

    def steered_prob(context, word, steering_vec, scale):
        return get_steered_word_probability(
            model, tokenizer, context, word,
            steering_vec=steering_vec, layer=layer,
            scale=scale, bow_token_ids=bow_token_ids,
        )

    # (cue_label, row position) -> (start_prob, end_prob). Keyed by position,
    # not index label, so a non-unique DataFrame index cannot mix rows up.
    unsteered_cache = {}

    output = []
    with torch.no_grad():
        for cue_label in ("Implicit", "Explicit"):
            data = fb_data[fb_data["knowledge_cue"] == cue_label]
            if data.empty:
                # Skip a cue with no rows instead of failing on an empty frame.
                continue
            for scale in scales:
                rows = tqdm(data.iterrows(), total=data.shape[0],
                            desc=f"{cue_label} scale={scale}")
                for pos, (_, row) in enumerate(rows):
                    context = row["passage"].replace("[MASK].", "")
                    if loo:
                        steering_vec = steering_vectors[(row["condition"], row["item"])]
                    else:
                        steering_vec = steering_vectors[row["condition"]]

                    if base_probs is not None:
                        start_prob = base_probs[row["item_id"]]["start_prob"]
                        end_prob = base_probs[row["item_id"]]["end_prob"]
                    else:
                        key = (cue_label, pos)
                        if key not in unsteered_cache:
                            unsteered_cache[key] = (
                                unsteered_prob(context, row["start"]),
                                unsteered_prob(context, row["end"]),
                            )
                        start_prob, end_prob = unsteered_cache[key]

                    steered_start_prob = steered_prob(context, row["start"], steering_vec, scale)
                    steered_end_prob = steered_prob(context, row["end"], steering_vec, scale)

                    highest_prob_word = (
                        row["start"] if start_prob > end_prob else row["end"]
                    )
                    steered_highest_prob_word = (
                        row["start"] if steered_start_prob > steered_end_prob else row["end"]
                    )

                    output.append({
                        "item_id": row["item_id"],
                        "item": row["item"],
                        "condition": row["condition"],
                        "knowledge_cue": row["knowledge_cue"],
                        "scale": scale,
                        **model_details,
                        "context": context,
                        "token_c1": row["start"],
                        "start_prob": start_prob,
                        "end_prob": end_prob,
                        "steered_start_prob": steered_start_prob,
                        "steered_end_prob": steered_end_prob,
                        "token_c2": row["end"],
                        "prediction": highest_prob_word,
                        "steered_prediction": steered_highest_prob_word,
                        "correct": int(highest_prob_word == row["critical_a"]),
                        "steered_correct": int(steered_highest_prob_word == row["critical_a"]),
                    })

    return output


def _get_bow_token_ids(tokenizer):
    """Return ids of vocabulary tokens that begin a new word (carry a leading space).

    Both the raw vocabulary string and the decoded text are checked. Byte-level
    BPE marks a leading space with "Ġ" and SentencePiece with "▁". ``decode``
    turns those markers into spaces, and for SentencePiece can also strip a lone
    token's leading space. A decoded-only check therefore misses such tokens.
    Bare markers/spaces of length 1 are not counted as word starts.
    """
    bow_token_ids = []
    for token_id in range(tokenizer.vocab_size):
        raw = tokenizer.convert_ids_to_tokens(token_id)
        decoded = tokenizer.decode([token_id])
        if (
            (isinstance(raw, str) and len(raw) > 1 and raw.startswith(("Ġ", "▁")))
            or decoded.startswith(("Ġ", "▁"))
            or (decoded.startswith(" ") and len(decoded) > 1)
        ):
            bow_token_ids.append(token_id)
    return bow_token_ids


def _print_summary(evaluation_output, extract_layer, inject_layer):
    """Print mean accuracy and probability-margin statistics per condition/cue/scale."""
    df = pd.DataFrame(evaluation_output)
    if df.empty:
        print(f"\nSummary (extract={extract_layer}, inject={inject_layer}): no rows evaluated.")
        return
    df["margin"] = (df["start_prob"] - df["end_prob"]).abs()
    df["steered_margin"] = (df["steered_start_prob"] - df["steered_end_prob"]).abs()
    df["prob_shift"] = df["steered_margin"] - df["margin"]
    # Positive when steering moves probability mass towards the start word.
    df["signed_shift"] = (
        (df["steered_start_prob"] - df["steered_end_prob"])
        - (df["start_prob"] - df["end_prob"])
    )
    summary = df.groupby(["condition", "knowledge_cue", "scale"])[
        ["correct", "steered_correct", "margin", "steered_margin", "prob_shift", "signed_shift"]
    ].mean()
    print(f"\nSummary (extract={extract_layer}, inject={inject_layer}):")
    print(summary)


def main(args):
    """Evaluate steering vectors on the False Belief stimuli for one model.

    Loads ``data/fb_stimuli.csv`` and the model/tokenizer named by ``args`` and
    computes unsteered probabilities once. Then, for every extraction layer, it
    computes and saves steering vectors and evaluates them at every injection
    layer and scale. All rows are written to one CSV under
    ``output/<experiment_name>/steering_results/``.
    """
    fb_data = pd.read_csv(data_dir / "fb_stimuli.csv", delimiter=",")

    # Half precision when --model_size exceeds 16 (billions of parameters).
    model_precision = torch.float16 if args.model_size > 16 else torch.float32

    model = AutoModelForCausalLM.from_pretrained(
        args.model_id,
        revision=args.model_stage if args.model_stage != "main" else None,
        dtype=model_precision,
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True)
    bow_token_ids = _get_bow_token_ids(tokenizer)

    model.eval()
    print(f"Loaded model: {args.model_id} (stage: {args.model_stage})")

    scales = [float(s) for s in args.scales.split(",")]
    extract_layers = [int(x) for x in args.layers.split(",")]
    inject_layers = [int(x) for x in args.inject_layers.split(",")] if args.inject_layers else None
    loo = args.method == "loo"

    model_details = {
        "model_id": args.model_id,
        "model_stage": args.model_stage,
    }

    results_dir = output_dir / args.experiment_name / "steering_results"
    results_dir.mkdir(parents=True, exist_ok=True)

    # Unsteered probabilities do not depend on layer or scale: compute them once.
    print("Computing base (unsteered) probabilities...")
    base_probs = compute_base_probabilities(model, tokenizer, fb_data, bow_token_ids)

    all_results = []

    for extract_layer in extract_layers:
        steering_vectors = compute_steering_vectors(
            model, tokenizer, fb_data, layer=extract_layer, pooling=args.pooling, method=args.method,
        )

        vec_path = results_dir / f"steering_vectors_layer{extract_layer}_{args.model_stage}.pt"
        torch.save(steering_vectors, vec_path)
        print(f"Saved steering vectors to {vec_path}")

        # Inject at --inject_layers if given; otherwise at the extraction layer.
        apply_layers = inject_layers if inject_layers else [extract_layer]

        for inject_layer in apply_layers:
            print(f"\n--- Extract: layer {extract_layer}, Inject: layer {inject_layer}, "
                  f"method={args.method}, pooling={args.pooling} ---")

            evaluation_output = evaluate_steering(
                model, tokenizer, fb_data, bow_token_ids, steering_vectors,
                layer=inject_layer, scales=scales,
                model_details={**model_details, "extract_layer": extract_layer, "inject_layer": inject_layer},
                loo=loo,
                base_probs=base_probs,
            )
            all_results.extend(evaluation_output)
            _print_summary(evaluation_output, extract_layer, inject_layer)

    combined_df = pd.DataFrame(all_results)
    csv_path = results_dir / f"steering_eval_{args.model_stage}.csv"
    combined_df.to_csv(csv_path, index=False)
    print(f"\nSaved combined results to {csv_path}")


def _build_arg_parser():
    """Return the command-line parser for the steering evaluation script.

    Comma-separated list options that start with a minus sign must be passed in
    ``--opt=value`` form (e.g. ``--scales=-0.1,1.0``). argparse only accepts a
    bare negative *number* as a positional value, and these strings are not
    plain numbers.
    """
    parser = argparse.ArgumentParser(description="Steering vector evaluation for False Belief task")
    add = parser.add_argument

    # Model selection
    add("--model_id", type=str, default="allenai/OLMo-2-1124-7B-Instruct",
        help="HuggingFace model ID")
    add("--model_stage", type=str, default="main",
        help="Model revision/branch to load")
    add("--model_size", type=float, default=7.0,
        help="Model size in billions (for dtype selection)")

    # Where to extract / inject, and how strongly
    add("--layers", type=str, default="-1",
        help="Comma-separated layers to extract steering vectors from (e.g., '12,16,-1')")
    add("--inject_layers", type=str, default=None,
        help="Comma-separated layers to inject steering at. If not set, injects at same layer as extraction.")
    add("--scales", type=str, default="0.1,-0.1,1.0,-1.0",
        help="Comma-separated scale values (e.g., '0.1,-0.1,1.0,-1.0')")

    # Steering-vector construction
    add("--pooling", type=str, default="last", choices=["last", "mean", "first"],
        help="Pooling strategy for steering vector extraction")
    add("--method", type=str, default="full", choices=["full", "cue", "loo"],
        help="'full' for full-passage, 'cue' for cue-targeted, 'loo' for leave-one-out")

    # Output location
    add("--experiment_name", type=str, default="steering",
        help="Experiment name (output subdirectory)")
    return parser


if __name__ == "__main__":
    main(_build_arg_parser().parse_args())
