#!/usr/bin/env python3

import os
import sys
from datetime import datetime

import numpy as np
import torch
from tqdm import tqdm

# Number of hidden-state rows projected through ``model.lm_head`` (and then
# softmaxed / compared) per chunk in ``main``. The value matches the batch sizing
# behavior of the original metrics script.
PROB_BATCH_SIZE = 50

# Make this script's directory and its two parent directories importable, so that
# ``paper_experiments.shared_utils`` resolves when the script is run directly
# (e.g. ``python path/to/script.py``) rather than as a package module.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

from paper_experiments.shared_utils import *


def get_top_1_pairs_with_diff_GPU(probs_with_steering, probs_no_steering):
    """Pair up the top-1 token ids of two distributions and flag disagreements.

    Both arguments are probability (or logit) tensors whose last dimension is the
    vocabulary; they must share their leading shape. The argmax is taken on the
    tensors' own device and only the small result is moved to the CPU.

    Returns:
        A NumPy int64 array with the inputs' leading shape plus a trailing axis of
        size 3 holding ``(top1_with_steering, top1_no_steering, differs)``, where
        ``differs`` is 1 when the two top-1 ids disagree and 0 otherwise.
    """
    steered_idx = torch.argmax(probs_with_steering, dim=-1)
    baseline_idx = torch.argmax(probs_no_steering, dim=-1)
    mismatch = steered_idx.ne(baseline_idx).to(torch.long)

    triples = torch.stack((steered_idx, baseline_idx, mismatch), dim=-1)
    return triples.cpu().numpy()


def ensure_steering_vector(exp_dir, strip_newline, layer, model, tokenizer, negative_prompts, positive_prompts):
    """Return the steering vector for ``layer``, computing and caching it if needed.

    The cache lives at ``<exp_dir>/steering_vectors/steering_vector_<strip_newline>_<layer>.pkl``
    (the directory is created if missing). On a cache hit the file is read with
    ``torch.load``; on a miss the vector is built with ``get_steering_vector_fast``
    from the negative/positive prompts, written with ``torch.save``, and returned.
    """
    cache_dir = os.path.join(exp_dir, "steering_vectors")
    os.makedirs(cache_dir, exist_ok=True)
    cache_path = os.path.join(cache_dir, f"steering_vector_{strip_newline}_{layer}.pkl")

    if not os.path.exists(cache_path):
        vector = get_steering_vector_fast(
            model, tokenizer, negative_prompts, positive_prompts, layer=layer
        )
        torch.save(vector, cache_path)
        return vector

    return torch.load(cache_path)


def _batched_kl_and_top1(model, hidden_with, hidden_without, chunk_size=PROB_BATCH_SIZE):
    """Compare steered vs. unsteered next-token distributions chunk by chunk.

    Rows of both hidden-state tensors are projected through ``model.lm_head`` and
    softmaxed ``chunk_size`` rows at a time (to bound the vocab-sized
    intermediate), then compared with ``calculate_kl_divergence`` and
    ``get_top_1_pairs_with_diff_GPU``.

    Returns:
        ``(kl_divergences, top_1_triples)``: NumPy arrays, each the concatenation
        along axis 0 of the per-chunk results.
    """
    # Iterate over the rows actually present rather than a global batch-size
    # constant, so no rows are silently skipped and the result lines up with
    # per-text arrays such as ``token_to_steer``.
    num_rows = len(hidden_without)
    kl_chunks = []
    triple_chunks = []
    for start in range(0, num_rows, chunk_size):
        stop = start + chunk_size
        with torch.no_grad():
            probs_with = torch.nn.functional.softmax(
                model.lm_head(hidden_with[start:stop]), dim=-1
            )
            probs_without = torch.nn.functional.softmax(
                model.lm_head(hidden_without[start:stop]), dim=-1
            )
            kl_chunks.append(calculate_kl_divergence(probs_with, probs_without))
            triple_chunks.append(
                get_top_1_pairs_with_diff_GPU(probs_with, probs_without)
            )
        # Drop both vocab-sized tensors before allocating the next chunk.
        del probs_with, probs_without
    return np.concatenate(kl_chunks, axis=0), np.concatenate(triple_chunks, axis=0)


def main(
    mode,
    model_name,
    model,
    tokenizer,
    rhyme_family1,
    rhyme_family2,
    output_dir,
    strip=False,
):
    """Compute raw probability-based steering metrics for one pair and save them.

    Reads ``generated_lines.json`` from the experiment directory. For every layer
    listed there, it obtains a steering vector (cached under ``steering_vectors/``)
    and re-runs the unsteered texts with steering applied at ``token_to_steer``.
    It then compares the ``model.lm_head`` softmax outputs against the unsteered
    run, recording KL divergences and top-1 agreement. Results are written to
    ``prob_based_metrics_raw.json`` in the same directory.

    Args:
        mode: Experiment mode, forwarded to ``setup_output_directory`` and
            ``load_prompts``.
        model_name: Model identifier used for paths, prompt loading and metadata.
        model: Loaded model; must expose ``lm_head``.
        tokenizer: Tokenizer paired with ``model``.
        rhyme_family1: Its "train" prompts are used as the negative prompts.
        rhyme_family2: Its "train" prompts are used as the positive prompts.
        output_dir: Root output directory.
        strip: Currently ignored. Newline stripping is always enabled
            (``strip_newline = True``); the argument is kept for CLI compatibility.

    Returns:
        None. Returns early, after printing a message, when the input file is missing.
    """
    # A position counts as "high KL" above this value; also stored in the metadata.
    kl_threshold = 1.0

    print(
        f"Starting RAW probability-based metrics for {model_name}: {rhyme_family1} vs {rhyme_family2}"
    )

    exp_dir = setup_output_directory(
        mode, output_dir, model_name, rhyme_family1, rhyme_family2
    )
    input_file = os.path.join(exp_dir, "generated_lines.json")
    output_file = os.path.join(exp_dir, "prob_based_metrics_raw.json")

    if not os.path.exists(input_file):
        print(f"Input file not found: {input_file}")
        print("Please run stage_line_generation.py first")
        return

    print("Loading generated lines...")
    line_data = load_data(input_file)

    unsteered_texts = line_data["unsteered_texts"]
    layers = line_data["layers"]
    # NOTE: newline stripping is always on; the ``strip`` argument does not
    # currently switch it off.
    strip_newline = True
    batch_size_small = line_data["batch_size_small"]
    print(f"Using batch size: {batch_size_small}")

    layer_metrics = {}

    # Baseline (no steering vector) hidden states; computed once and reused for every layer.
    last_hidden_states_no_steering = generate_steered_output(
        None,
        model,
        tokenizer,
        unsteered_texts,
        batch_size_per_prompt=1,
        num_prompts_per_rollout=batch_size_small,
        return_type="last_hidden_state",
    )

    min_idxs, max_idxs = get_min_and_max_idxs(unsteered_texts, tokenizer)
    token_to_steer = min_idxs.copy()
    if strip_newline:
        # With newline stripping, the steered position is one token before min_idxs.
        token_to_steer -= 1

    # Build prompts for steering vectors consistent with other scripts.
    negative_prompts = load_prompts(mode, "train", rhyme_family1, model_name, strip=strip_newline)
    positive_prompts = load_prompts(mode, "train", rhyme_family2, model_name, strip=strip_newline)

    for layer in tqdm(layers, desc="Computing RAW probability-based metrics"):
        print(f"Processing layer {layer}...")

        steering_vector = ensure_steering_vector(
            exp_dir,
            strip_newline,
            layer,
            model,
            tokenizer,
            negative_prompts,
            positive_prompts,
        )

        print(f"Computing probabilities for layer {layer}...")
        last_hidden_states_with_steering = generate_steered_output(
            steering_vector,
            model,
            tokenizer,
            unsteered_texts,
            batch_size_per_prompt=1,
            num_prompts_per_rollout=batch_size_small,
            layer=layer,
            steering_multiplier=STEERING_MULTIPLIER,
            return_type="last_hidden_state",
            token_to_steer=token_to_steer.tolist(),
        )

        kl_divergences, top_1_triples = _batched_kl_and_top1(
            model, last_hidden_states_with_steering, last_hidden_states_no_steering
        )
        if strip_newline:
            # Exclude the steered position itself from the KL statistics
            # (presumably because a divergence there is expected by construction).
            kl_divergences[np.arange(kl_divergences.shape[0]), token_to_steer] = 0.0

        idx_of_first_high_kl_divergence = get_idx_of_first_high_kl_divergence(
            kl_divergences > kl_threshold, max_idxs
        )
        # The last axis of the triples is (top1_with, top1_no, differs); keep the flag.
        if top_1_triples.ndim == 3:
            top_1_differences = top_1_triples[:, :, 2].astype(bool)
        else:
            top_1_differences = top_1_triples[:, 2:3].astype(bool)
        idx_of_first_top_1_difference = get_idx_of_first_top_1_difference(
            top_1_differences, max_idxs
        )

        layer_metrics[layer] = {
            "kl_divergences": kl_divergences.tolist(),
            "top_1_triples": top_1_triples.tolist(),
            "idx_of_first_top_1_difference": idx_of_first_top_1_difference.tolist(),
            "idx_of_first_high_kl_divergence": idx_of_first_high_kl_divergence.tolist(),
            "min_idxs": min_idxs.tolist(),
            "max_idxs": max_idxs.tolist(),
        }

        # Release this layer's steered hidden states before the next layer's
        # generation, so two sets are never alive at the same time.
        del last_hidden_states_with_steering, kl_divergences
        cleanup_gpu_memory()

        print(f"Completed layer {layer}")
        print(
            f"  Avg first top-1 diff idx: {idx_of_first_top_1_difference.mean():.2f}"
        )
        print(
            f"  Avg first high KL divergence idx: {idx_of_first_high_kl_divergence.mean():.2f}"
        )
        print(f"  Avg min idx: {min_idxs.mean():.2f}")
        print(f"  Avg max idx: {max_idxs.mean():.2f}")

    output_data = {
        "layer_metrics": layer_metrics,
        "metadata": {
            "model_name": model_name,
            "rhyme_family1": rhyme_family1,
            "rhyme_family2": rhyme_family2,
            "timestamp": datetime.now().isoformat(),
            "batch_size_small": batch_size_small,
            "kl_threshold": kl_threshold,
        },
    }

    save_data(output_data, output_file)
    print(f"RAW probability-based metrics completed and saved to: {output_file}")


if __name__ == "__main__":
    parser = get_common_args()
    args = parser.parse_args()

    # Resolve which pairs to run *before* loading the model, so an unsupported
    # CLI combination fails fast instead of after an expensive model load.
    if args.rhyme_family1 is not None and args.rhyme_family2 is not None:
        # An explicit pair overrides the mode's default list.
        pairs = [(args.rhyme_family1, args.rhyme_family2)]
    elif args.mode == "rhyme_family_steering":
        pairs = RHYME_FAMILY_PAIRS
    elif args.mode == "specific_word_steering":
        pairs = SPECIFIC_WORD_PAIRS
    else:
        # Previously this fell through with ``pairs`` undefined (NameError).
        sys.exit(
            f"No default pairs for mode {args.mode!r}; "
            "pass both rhyme families explicitly."
        )

    print("Loading model...")
    model, tokenizer = get_model(args.model_name)

    try:
        for rhyme_family1, rhyme_family2 in pairs:
            main(
                args.mode,
                args.model_name,
                model,
                tokenizer,
                rhyme_family1,
                rhyme_family2,
                args.output_dir,
                strip=args.strip,
            )
    finally:
        # Free the model even if a pair fails part-way through.
        del model, tokenizer
        cleanup_gpu_memory()


