#!/usr/bin/env python3

import os
import sys
from datetime import datetime

from tqdm import tqdm

# Add this script's own directory (not its parent) to sys.path to import shared_utils
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from paper_experiments.shared_utils import *


def _compute_metrics(
    model, tokenizer, texts, rhyme_family1, rhyme_family2, batch_size
):
    """Compute the four last-word metrics for one set of generated texts.

    Args:
        model, tokenizer: Passed through to ``get_last_word_regeneration_correct``.
        texts: Generated lines to score (the unsteered texts, or the steered
            texts for a single layer).
        rhyme_family1, rhyme_family2: The two families being compared.
        batch_size: Batch size forwarded to the regeneration metric.

    Returns:
        Dict mapping the metric names used in ``standard_metrics.json`` to
        plain lists (each helper result is converted with ``.tolist()`` so it
        can be serialized).
    """
    # Word-based metric: inspects the text only, no model needed.
    word_family1, word_family2 = get_last_word_correct(
        texts, [rhyme_family1, rhyme_family2], num_words=1
    )
    # Regeneration metric: needs the model and tokenizer.
    regen_family1, regen_family2 = get_last_word_regeneration_correct(
        model,
        tokenizer,
        texts,
        [rhyme_family1, rhyme_family2],
        batch_size,
    )
    return {
        "last_word_correct_rhyme_family1": word_family1.tolist(),
        "last_word_correct_rhyme_family2": word_family2.tolist(),
        "last_word_regeneration_rhyme_family1": regen_family1.tolist(),
        "last_word_regeneration_rhyme_family2": regen_family2.tolist(),
    }


def main(mode, model_name, model, tokenizer, rhyme_family1, rhyme_family2, output_dir):
    """Compute standard last-word metrics for one rhyme-family pair.

    Reads ``generated_lines.json`` (produced by stage_line_generation.py)
    from the experiment directory. Scores the unsteered texts and the
    steered texts for every recorded layer, then writes
    ``standard_metrics.json`` alongside it.

    Returns early, without computing anything, if the output file already
    exists or the input file is missing.

    Args:
        mode: Experiment mode; forwarded to ``setup_output_directory``.
        model_name: Model identifier, used for the directory and metadata.
        model, tokenizer: Already-loaded model used by the regeneration metric.
        rhyme_family1, rhyme_family2: The pair of families being compared.
        output_dir: Root output directory.
    """
    print(
        f"Starting standard metrics for {model_name}: {rhyme_family1} vs {rhyme_family2}"
    )

    exp_dir = setup_output_directory(
        mode, output_dir, model_name, rhyme_family1, rhyme_family2
    )
    input_file = os.path.join(exp_dir, "generated_lines.json")
    output_file = os.path.join(exp_dir, "standard_metrics.json")

    # Skip this pair if it has already been scored.
    if os.path.exists(output_file):
        print(f"Standard metrics already completed: {output_file}")
        return

    if not os.path.exists(input_file):
        print(f"Input file not found: {input_file}")
        print("Please run stage_line_generation.py first")
        return

    print("Loading generated lines...")
    line_data = load_data(input_file)
    steered_texts = line_data["steered_texts"]

    # Reuse the batch size chosen during line generation.
    batch_size_small = line_data["batch_size_small"]
    print(f"Using batch size: {batch_size_small}")

    print("Computing unsteered metrics...")
    unsteered_metrics = _compute_metrics(
        model,
        tokenizer,
        line_data["unsteered_texts"],
        rhyme_family1,
        rhyme_family2,
        batch_size_small,
    )
    cleanup_gpu_memory()

    steered_metrics = {}
    for layer in tqdm(line_data["layers"], desc="Computing steered metrics"):
        # tqdm.write keeps log lines from corrupting the progress bar.
        tqdm.write(f"Processing layer {layer}...")
        # Steered texts are looked up by str(layer); presumably the input JSON
        # stores layer keys as strings.
        steered_metrics[layer] = _compute_metrics(
            model,
            tokenizer,
            steered_texts[str(layer)],
            rhyme_family1,
            rhyme_family2,
            batch_size_small,
        )
        cleanup_gpu_memory()
        tqdm.write(f"Completed layer {layer}")

    output_data = {
        "unsteered_metrics": unsteered_metrics,
        "steered_metrics": steered_metrics,
        "metadata": {
            "model_name": model_name,
            "rhyme_family1": rhyme_family1,
            "rhyme_family2": rhyme_family2,
            "timestamp": datetime.now().isoformat(),
            "batch_size_small": batch_size_small,
        },
    }

    save_data(output_data, output_file)
    print(f"Standard metrics completed and saved to: {output_file}")


if __name__ == "__main__":
    parser = get_common_args()
    args = parser.parse_args()

    # Load model once
    print("Loading model...")
    model, tokenizer = get_model(args.model_name)

    if args.mode == "rhyme_family_steering":
        pairs = RHYME_FAMILY_PAIRS
    if args.rhyme_family1 is not None and args.rhyme_family2 is not None:
        pairs = [(args.rhyme_family1, args.rhyme_family2)]
    elif args.mode == "specific_word_steering":
        pairs = SPECIFIC_WORD_PAIRS

    try:
        for rhyme_family1, rhyme_family2 in pairs:
            main(
                args.mode,
                args.model_name,
                model,
                tokenizer,
                rhyme_family1,
                rhyme_family2,
                args.output_dir,
            )
    finally:
        # Clean up model
        del model, tokenizer
        cleanup_gpu_memory()
