#!/usr/bin/env python3

import json
import os
import sys
from datetime import datetime

import numpy as np

# Add this script's own directory to sys.path before importing paper_experiments.shared_utils
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from paper_experiments.shared_utils import *


# (output-key template, source-metric key) pairs for the last-word metrics.
# The template placeholder is filled with "unsteered" or "steered".
_LAST_WORD_FIELDS = (
    ("last_word_fraction_{}_rhyme_family1", "last_word_correct_rhyme_family1"),
    ("last_word_fraction_{}_rhyme_family2", "last_word_correct_rhyme_family2"),
    (
        "last_word_regeneration_{}_rhyme_family1",
        "last_word_regeneration_rhyme_family1",
    ),
    (
        "last_word_regeneration_{}_rhyme_family2",
        "last_word_regeneration_rhyme_family2",
    ),
)

# Keys read from each layer's probability-based metrics.
_PROB_FIELDS = (
    "top_1_differences",
    "min_idxs",
    "max_idxs",
    "idx_of_first_top_1_difference",
    "high_kl_divergences",
    "idx_of_first_high_kl_divergence",
)


def _as_arrays(metrics, keys):
    """Convert ``metrics[key]`` to a NumPy array for every key in ``keys``."""
    return {key: np.array(metrics[key]) for key in keys}


def _slice_mean(values, indices):
    """Return ``values[indices].mean()`` as a built-in float."""
    return float(values[indices].mean())


def _last_word_summary(metric_arrays, indices, condition):
    """Average the last-word metrics over one prompt's samples.

    ``condition`` ("unsteered" or "steered") is inserted into the output keys.
    """
    return {
        template.format(condition): _slice_mean(metric_arrays[source_key], indices)
        for template, source_key in _LAST_WORD_FIELDS
    }


def _prob_summary(prob_arrays, indices):
    """Summarize one layer's probability-based metrics over one prompt's samples."""
    # NOTE(review): the slices below are views shared between both helper calls;
    # this assumes the shared_utils fraction helpers do not mutate their inputs.
    min_idxs = prob_arrays["min_idxs"][indices]
    max_idxs = prob_arrays["max_idxs"][indices]
    return {
        "fraction_top_1_difference": float(
            get_top_1_difference_fraction(
                prob_arrays["top_1_differences"][indices], min_idxs, max_idxs
            )
        ),
        "avg_idx_of_first_top_1_difference": _slice_mean(
            prob_arrays["idx_of_first_top_1_difference"], indices
        ),
        "fraction_high_kl_divergence": float(
            get_high_kl_divergence_fraction(
                prob_arrays["high_kl_divergences"][indices], min_idxs, max_idxs
            )
        ),
        "avg_idx_of_first_high_kl_divergence": _slice_mean(
            prob_arrays["idx_of_first_high_kl_divergence"], indices
        ),
        "avg_min_idx": float(min_idxs.mean()),
        "avg_max_idx": float(max_idxs.mean()),
    }


def main(mode, model_name, rhyme_family1, rhyme_family2, output_dir):
    """Combine the per-stage result files for one pair into ``combined_results.json``.

    Reads ``generated_lines.json``, ``standard_metrics.json`` and
    ``prob_based_metrics.json`` from the experiment directory returned by
    ``setup_output_directory`` and writes one result record per
    (layer, prompt) combination, plus a metadata section.

    Args:
        mode: Passed through to ``setup_output_directory``.
        model_name: Model identifier; recorded in every result and the metadata.
        rhyme_family1: First element of the pair; recorded in the output.
        rhyme_family2: Second element of the pair; recorded in the output.
        output_dir: Passed through to ``setup_output_directory``.

    Returns:
        None. Returns early (after printing a message) when the output file
        already exists or when any of the three input files is missing.

    Raises:
        KeyError: If a loaded file lacks one of the expected keys.
    """
    print(f"Starting combination for {model_name}: {rhyme_family1} vs {rhyme_family2}")

    # Setup directories
    exp_dir = setup_output_directory(
        mode, output_dir, model_name, rhyme_family1, rhyme_family2
    )

    # Input files
    lines_file = os.path.join(exp_dir, "generated_lines.json")
    standard_file = os.path.join(exp_dir, "standard_metrics.json")
    prob_file = os.path.join(exp_dir, "prob_based_metrics.json")

    # Output file
    output_file = os.path.join(exp_dir, "combined_results.json")

    # The output file's existence is the completion marker for this stage.
    if os.path.exists(output_file):
        print(f"Combination already completed: {output_file}")
        return

    # Check if all inputs exist
    missing_files = [
        (file_path, stage_name)
        for file_path, stage_name in (
            (lines_file, "line generation"),
            (standard_file, "standard metrics"),
            (prob_file, "probability-based metrics"),
        )
        if not os.path.exists(file_path)
    ]

    if missing_files:
        print("Missing input files:")
        for file_path, stage_name in missing_files:
            print(f"  {file_path} (from {stage_name})")
        print("Please run the previous stages first")
        return

    # Load all data
    print("Loading data from all stages...")
    line_data = load_data(lines_file)
    standard_data = load_data(standard_file)
    prob_data = load_data(prob_file)

    # Extract data
    generation_prompts = line_data["generation_prompts"]
    batch_size_per_prompt = line_data["batch_size_per_prompt"]
    layers = line_data["layers"]

    unsteered_metrics = standard_data["unsteered_metrics"]
    steered_metrics = standard_data["steered_metrics"]
    layer_metrics = prob_data["layer_metrics"]

    print("Combining results...")

    # Prompt i owns the contiguous samples [i * batch, (i + 1) * batch).
    prompt_slices = [
        slice(i * batch_size_per_prompt, (i + 1) * batch_size_per_prompt)
        for i in range(len(generation_prompts))
    ]

    last_word_keys = [source_key for _, source_key in _LAST_WORD_FIELDS]

    # Unsteered metrics do not depend on the layer, so summarize them once per
    # prompt instead of once per (layer, prompt).
    unsteered_arrays = _as_arrays(unsteered_metrics, last_word_keys)
    unsteered_summaries = [
        _last_word_summary(unsteered_arrays, indices, "unsteered")
        for indices in prompt_slices
    ]

    results_list = []
    for layer in layers:
        # Metrics are keyed by the layer's string form in the loaded JSON;
        # convert each layer's lists to arrays once, not once per prompt.
        steered_arrays = _as_arrays(steered_metrics[str(layer)], last_word_keys)
        prob_arrays = _as_arrays(layer_metrics[str(layer)], _PROB_FIELDS)

        for i, prompt in enumerate(generation_prompts):
            indices = prompt_slices[i]
            results_list.append(
                {
                    "model_name": model_name,
                    "rhyme_family1": rhyme_family1,
                    "rhyme_family2": rhyme_family2,
                    "layer": layer,
                    "steering_multiplier": STEERING_MULTIPLIER,
                    "prompt": prompt,
                    # Unsteered metrics (same for all layers, sliced per prompt)
                    **unsteered_summaries[i],
                    # Steered metrics (layer-specific, sliced per prompt)
                    **_last_word_summary(steered_arrays, indices, "steered"),
                    # Probability-based metrics (layer-specific, sliced per prompt)
                    **_prob_summary(prob_arrays, indices),
                }
            )

    # Create metadata
    metadata = {
        "timestamp": datetime.now().isoformat(),
        "model_name": model_name,
        "rhyme_family1": rhyme_family1,
        "rhyme_family2": rhyme_family2,
        "num_layers": len(layers),
        "num_prompts": len(generation_prompts),
        "batch_size_per_prompt": batch_size_per_prompt,
        "steering_multiplier": STEERING_MULTIPLIER,
        "source_files": {
            "lines": lines_file,
            "standard": standard_file,
            "prob": prob_file,
        },
    }

    output_data = {"results": results_list, "metadata": metadata}

    # Write to a temporary file and rename it into place: because the mere
    # existence of output_file marks this stage as done, a partially written
    # file must never be visible under the final name.
    tmp_file = f"{output_file}.tmp"
    with open(tmp_file, "w", encoding="utf-8") as f:
        json.dump(output_data, f, indent=2)
    os.replace(tmp_file, output_file)

    print(f"Combination completed and saved to: {output_file}")
    print(f"Total results: {len(results_list)}")
    print(f"Results per layer: {len(generation_prompts)}")
    print(f"Total layers: {len(layers)}")


if __name__ == "__main__":
    # Script entry point: pick the pairs to process and run main() on each.
    parser = get_common_args()
    args = parser.parse_args()

    # Pair selection, in order of precedence:
    #   1. an explicitly supplied (rhyme_family1, rhyme_family2) pair,
    #   2. the default pair list for the selected mode.
    if args.rhyme_family1 is not None and args.rhyme_family2 is not None:
        pairs = [(args.rhyme_family1, args.rhyme_family2)]
    else:
        if args.rhyme_family1 is not None or args.rhyme_family2 is not None:
            # Only one half of the pair was given; it cannot be used alone.
            print(
                "Warning: rhyme_family1 and rhyme_family2 must both be set to "
                "select a single pair; falling back to the mode's default pairs",
                file=sys.stderr,
            )
        if args.mode == "rhyme_family_steering":
            pairs = RHYME_FAMILY_PAIRS
        elif args.mode == "specific_word_steering":
            pairs = SPECIFIC_WORD_PAIRS
        else:
            # Without this branch `pairs` would be unbound and the loop below
            # would fail with an opaque NameError.
            sys.exit(
                f"Unknown mode {args.mode!r}: no default pairs are defined and "
                "no explicit pair was given"
            )

    for rhyme_family1, rhyme_family2 in pairs:
        main(
            args.mode,
            args.model_name,
            rhyme_family1,
            rhyme_family2,
            args.output_dir,
        )
