#!/usr/bin/env python3

import os
import sys
from datetime import datetime

from tqdm import tqdm

# Number of sequences pushed through the LM head (and softmaxed) per chunk when
# comparing steered vs. unsteered next-token distributions.
PROB_BATCH_SIZE = 50

# Append this script's own directory, then its parent, then its grandparent to
# sys.path (in that order) so `paper_experiments.shared_utils` can be imported.
# The entries are absolute paths derived from __file__.
_ancestor_dir = os.path.abspath(__file__)
for _level in range(3):
    _ancestor_dir = os.path.dirname(_ancestor_dir)
    sys.path.append(_ancestor_dir)
del _ancestor_dir, _level

from paper_experiments.shared_utils import *


def _compare_next_token_distributions(
    model, hidden_states_with_steering, hidden_states_no_steering, num_sequences
):
    """Compare steered vs. unsteered next-token distributions, chunk by chunk.

    Both hidden-state batches are projected through ``model.lm_head`` and
    softmaxed over the last axis, ``PROB_BATCH_SIZE`` rows (first-axis entries)
    at a time, so only one chunk of probability tensors is alive at once.

    Args:
        model: Model exposing an ``lm_head`` projection.
        hidden_states_with_steering: Last hidden states from the steered run.
        hidden_states_no_steering: Last hidden states from the unsteered run,
            aligned row-for-row with ``hidden_states_with_steering``.
        num_sequences: Number of rows (first-axis entries) to process.

    Returns:
        ``(kl_divergences, top_1_differences)``: the per-chunk outputs of
        ``calculate_kl_divergence`` and ``get_top_1_difference`` concatenated
        along axis 0.
    """
    kl_chunks = []
    top_1_chunks = []
    for start in range(0, num_sequences, PROB_BATCH_SIZE):
        stop = start + PROB_BATCH_SIZE
        with torch.no_grad():
            probs_with_steering = torch.nn.functional.softmax(
                model.lm_head(hidden_states_with_steering[start:stop]),
                dim=-1,
            )
            probs_no_steering = torch.nn.functional.softmax(
                model.lm_head(hidden_states_no_steering[start:stop]),
                dim=-1,
            )

        kl_chunks.append(
            calculate_kl_divergence(probs_with_steering, probs_no_steering)
        )
        top_1_chunks.append(
            get_top_1_difference(probs_with_steering, probs_no_steering)
        )
        # Release this chunk's probability tensors before allocating the next.
        del probs_with_steering, probs_no_steering

    return np.concatenate(kl_chunks, axis=0), np.concatenate(top_1_chunks, axis=0)


def main(
    mode,
    model_name,
    model,
    tokenizer,
    rhyme_family1,
    rhyme_family2,
    output_dir,
    strip=False,
):
    """Compute probability-based steering metrics for one pair and save them.

    Reads ``generated_lines.json`` from the experiment directory. The unsteered
    texts are run through the model once without steering. They are then run
    once per layer with that layer's saved steering vector applied. The
    resulting next-token distributions are compared by KL divergence and top-1
    disagreement, and the per-layer results are written to
    ``prob_based_metrics.json``.

    Returns early, without writing anything, if the output file already exists
    or the input file is missing.

    Args:
        mode: Experiment mode, forwarded to ``setup_output_directory``.
        model_name: Model name, used for the output directory and metadata.
        model: Loaded model; its ``lm_head`` maps hidden states to logits.
        tokenizer: Tokenizer matching ``model``.
        rhyme_family1: First element of the pair being compared.
        rhyme_family2: Second element of the pair being compared.
        output_dir: Root output directory.
        strip: Currently unused. Whether the newline was stripped is read from
            ``strip_newline`` in ``generated_lines.json`` instead.

    Raises:
        ValueError: If a steered run returns a different number of sequences
            than the unsteered run.
    """
    print(
        f"Starting probability-based metrics for {model_name}: {rhyme_family1} vs {rhyme_family2}"
    )

    # Setup directories
    exp_dir = setup_output_directory(
        mode, output_dir, model_name, rhyme_family1, rhyme_family2
    )
    input_file = os.path.join(exp_dir, "generated_lines.json")
    output_file = os.path.join(exp_dir, "prob_based_metrics.json")

    # Skip work that has already been done.
    if os.path.exists(output_file):
        print(f"Probability-based metrics already completed: {output_file}")
        return

    # This stage consumes the output of the line-generation stage.
    if not os.path.exists(input_file):
        print(f"Input file not found: {input_file}")
        print("Please run stage_line_generation.py first")
        return

    print("Loading generated lines...")
    line_data = load_data(input_file)

    unsteered_texts = line_data["unsteered_texts"]
    layers = line_data["layers"]
    strip_newline = line_data["strip_newline"]
    # Reuse the batch size chosen by the line-generation stage.
    batch_size_small = line_data["batch_size_small"]
    print(f"Using batch size: {batch_size_small}")

    layer_metrics = {}

    # Reference (unsteered) last hidden states: computed once, reused per layer.
    last_hidden_states_no_steering = generate_steered_output(
        None,
        model,
        tokenizer,
        unsteered_texts,
        batch_size_per_prompt=1,
        num_prompts_per_rollout=batch_size_small,
        return_type="last_hidden_state",
    )
    # Chunk over the sequences actually produced, not the global BATCH_SIZE,
    # which need not match the number of generated lines.
    num_sequences = len(last_hidden_states_no_steering)

    min_idxs, max_idxs = get_min_and_max_idxs(unsteered_texts, tokenizer)
    token_to_steer = min_idxs.copy()
    if strip_newline:
        # Steer the token one position before min_idxs.
        token_to_steer -= 1
    token_to_steer_list = token_to_steer.tolist()

    steering_vector_dir = os.path.join(exp_dir, "steering_vectors")

    for layer in tqdm(layers, desc="Computing probability-based metrics"):
        print(f"Processing layer {layer}...")

        steering_vector_file = os.path.join(
            steering_vector_dir, f"steering_vector_{strip_newline}_{layer}.pkl"
        )
        steering_vector = torch.load(steering_vector_file)

        print(f"Computing probabilities for layer {layer}...")
        last_hidden_states_with_steering = generate_steered_output(
            steering_vector,
            model,
            tokenizer,
            unsteered_texts,
            batch_size_per_prompt=1,
            num_prompts_per_rollout=batch_size_small,
            layer=layer,
            steering_multiplier=STEERING_MULTIPLIER,
            return_type="last_hidden_state",
            token_to_steer=token_to_steer_list,
        )
        if len(last_hidden_states_with_steering) != num_sequences:
            raise ValueError(
                f"Layer {layer}: steered run returned "
                f"{len(last_hidden_states_with_steering)} sequences, "
                f"expected {num_sequences}"
            )

        kl_divergences, top_1_differences = _compare_next_token_distributions(
            model,
            last_hidden_states_with_steering,
            last_hidden_states_no_steering,
            num_sequences,
        )
        # Only needed for this layer; free them before the next forward pass.
        del last_hidden_states_with_steering, steering_vector

        if strip_newline:
            # Mask out the steered position itself so it is not counted.
            kl_divergences[np.arange(kl_divergences.shape[0]), token_to_steer] = 0.0
            top_1_differences[np.arange(top_1_differences.shape[0]), token_to_steer] = False

        # Calculate metrics
        high_kl_divergences = get_kl_above_threshold(kl_divergences)
        idx_of_first_high_kl_divergence = get_idx_of_first_high_kl_divergence(
            high_kl_divergences, max_idxs
        )
        idx_of_first_top_1_difference = get_idx_of_first_top_1_difference(
            top_1_differences, max_idxs
        )

        layer_metrics[layer] = {
            "high_kl_divergences": high_kl_divergences.tolist(),
            "top_1_differences": top_1_differences.tolist(),
            "idx_of_first_top_1_difference": idx_of_first_top_1_difference.tolist(),
            "idx_of_first_high_kl_divergence": idx_of_first_high_kl_divergence.tolist(),
            "min_idxs": min_idxs.tolist(),
            "max_idxs": max_idxs.tolist(),
        }

        # Clean up memory after each layer
        del kl_divergences
        cleanup_gpu_memory()

        print(f"Completed layer {layer}")
        print(
            f"  High KL divergences: {get_high_kl_divergence_fraction(high_kl_divergences, min_idxs, max_idxs):.4f}"
        )
        print(
            f"  Top-1 differences: {get_top_1_difference_fraction(top_1_differences, min_idxs, max_idxs):.4f}"
        )
        print(f"  Avg first top-1 diff idx: {idx_of_first_top_1_difference.mean():.2f}")
        print(
            f"  Avg first high KL divergence idx: {idx_of_first_high_kl_divergence.mean():.2f}"
        )
        print(f"  Avg min idx: {min_idxs.mean():.2f}")
        print(f"  Avg max idx: {max_idxs.mean():.2f}")

    # Prepare output data
    output_data = {
        "layer_metrics": layer_metrics,
        "metadata": {
            "model_name": model_name,
            "rhyme_family1": rhyme_family1,
            "rhyme_family2": rhyme_family2,
            "timestamp": datetime.now().isoformat(),
            "batch_size_small": batch_size_small,
            # NOTE(review): hard-coded; presumably mirrors the default
            # threshold of get_kl_above_threshold. Confirm they stay in sync.
            "kl_threshold": 1.0,
        },
    }

    # Save results
    save_data(output_data, output_file)
    print(f"Probability-based metrics completed and saved to: {output_file}")


def _resolve_pairs(mode, rhyme_family1, rhyme_family2):
    """Return the list of ``(first, second)`` pairs to process.

    An explicitly supplied pair (both elements given) takes precedence over the
    mode's default pair list; if only one element is given it is ignored.

    Raises:
        ValueError: If no explicit pair is given and ``mode`` has no default
            pair list.
    """
    if rhyme_family1 is not None and rhyme_family2 is not None:
        return [(rhyme_family1, rhyme_family2)]
    if mode == "rhyme_family_steering":
        return RHYME_FAMILY_PAIRS
    if mode == "specific_word_steering":
        return SPECIFIC_WORD_PAIRS
    raise ValueError(
        f"Unsupported mode {mode!r}: pass both rhyme families explicitly or use "
        "'rhyme_family_steering' / 'specific_word_steering'"
    )


if __name__ == "__main__":
    parser = get_common_args()
    args = parser.parse_args()

    # Resolve the pairs before loading the model so an unsupported mode fails
    # fast instead of after an expensive model load.
    try:
        pairs = _resolve_pairs(args.mode, args.rhyme_family1, args.rhyme_family2)
    except ValueError as err:
        sys.exit(str(err))

    # Load model once
    print("Loading model...")
    model, tokenizer = get_model(args.model_name)

    try:
        for rhyme_family1, rhyme_family2 in pairs:
            main(
                args.mode,
                args.model_name,
                model,
                tokenizer,
                rhyme_family1,
                rhyme_family2,
                args.output_dir,
                strip=args.strip,
            )
    finally:
        # Clean up model
        del model, tokenizer
        cleanup_gpu_memory()
