#!/usr/bin/env python3

import os
import sys
from datetime import datetime
from typing import Dict, List, Optional

# Add this script's own directory (the one containing this file, not its parent) to sys.path before importing shared_utils
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from paper_experiments.shared_utils import *


def regenerate_last_words_only(model, tokenizer, texts, batch_size_small: int) -> List[str]:
    """Re-sample, without steering, the final word of each text's last line.

    For every text, the last line (after the final "\\n") has its last
    space-separated word removed. The remainder is prefixed with
    "A short sentence:\\n" and sent to ``generate_steered_output`` with ``None``
    as its first argument, so no steering is applied. At most 3 new tokens are
    generated per prompt.

    Args:
        model: Model passed through to ``generate_steered_output``.
        tokenizer: Tokenizer passed through to ``generate_steered_output``.
        texts: Sequence of generated texts (strings).
        batch_size_small: Forwarded as ``num_prompts_per_rollout``.

    Returns:
        One cleaned word per (output, prompt) pair, in input order. Returns the
        literal sentinel "wrongwongwongxxx" when the continuation contained no
        other word.
    """
    prompts = []
    for text in texts:
        final_line = text.split("\n")[-1]
        # Drop only the last space-separated word; a line without spaces is kept whole.
        line_stem = final_line.rsplit(" ", 1)[0]
        prompts.append("A short sentence:\n" + line_stem)

    completions = generate_steered_output(
        None,
        model,
        tokenizer,
        prompts,
        batch_size_per_prompt=1,
        num_prompts_per_rollout=batch_size_small,
        max_new_tokens=3,
    )

    words: List[str] = []
    for completion, prompt in zip(completions, prompts):
        # NOTE(review): assumes each completion string begins with its prompt — confirm
        # against generate_steered_output.
        continuation = completion[len(prompt):]
        # The sentinel guarantees .split() yields at least one token even when the
        # model produced nothing usable; it then becomes the returned "word".
        padded = continuation + " wrongwongwongxxx"
        first_token = remove_non_alphanumeric_characters_from_left(padded).split()[0]
        words.append(remove_non_alphanumeric_characters_from_right(first_token))
    return words


def find_any_generated_lines_file(base_dir: str, mode: str, model_name: str, rhyme_family1: str) -> Optional[str]:
    """Find a generated_lines.json for ``rhyme_family1`` among its pair directories.

    Searches ``base_dir/mode/model_name`` for subdirectories named
    ``"{rhyme_family1}_<other>"`` and returns the first one, in sorted name
    order, that contains a ``generated_lines.json`` file. Entries are sorted
    because ``os.listdir`` order is arbitrary. Without sorting, the file chosen
    (and therefore the texts reused) could differ between runs or filesystems.

    Args:
        base_dir: Root output directory.
        mode: Experiment mode (path component).
        model_name: Model identifier (path component).
        rhyme_family1: Source rhyme family; matched as a ``"{rhyme_family1}_"`` prefix.

    Returns:
        Path to the file. Returns None if the model directory does not exist
        or no matching pair directory contains the file.
    """
    model_dir = os.path.join(base_dir, mode, model_name)
    # isdir (not exists) so a stray file at this path does not make listdir raise.
    if not os.path.isdir(model_dir):
        return None

    prefix = f"{rhyme_family1}_"
    for entry in sorted(os.listdir(model_dir)):
        # Expect directories named like "{rhyme_family1}_{rhyme_family2}"
        if not entry.startswith(prefix):
            continue
        pair_dir = os.path.join(model_dir, entry)
        if not os.path.isdir(pair_dir):
            continue
        candidate = os.path.join(pair_dir, "generated_lines.json")
        if os.path.isfile(candidate):
            return candidate
    return None


def _compute_regeneration_rates(
    combined_results: Dict[str, dict], all_rhyme_families: List[str]
) -> Dict[str, Dict[str, float]]:
    """Return, per source family, the fraction of regenerated words correct for each target family.

    Targets absent from a source entry's ``per_family_correct`` are omitted
    for that source instead of raising ``KeyError``. This happens with a stale
    cached entry that was loaded from disk but not re-scored this run.
    """
    regeneration_rates: Dict[str, Dict[str, float]] = {}
    for rhyme_family1, result in combined_results.items():
        per_family_correct = result["per_family_correct"]
        regeneration_rates[rhyme_family1] = {
            rf: float(sum(per_family_correct[rf]) / max(1, len(per_family_correct[rf])))
            for rf in all_rhyme_families
            if rf in per_family_correct
        }
    return regeneration_rates


def _average_rates_excluding_source(
    regeneration_rates: Dict[str, Dict[str, float]], all_rhyme_families: List[str]
) -> Dict[str, float]:
    """Average each target family's rate over all source families other than itself.

    A target with no contributing source gets 0.0.
    """
    averages: Dict[str, float] = {}
    for target_family in all_rhyme_families:
        vals = [
            rates[target_family]
            for src, rates in regeneration_rates.items()
            if src != target_family and target_family in rates
        ]
        averages[target_family] = float(sum(vals) / len(vals)) if vals else 0.0
    return averages


def main(mode, model_name, model, tokenizer, output_dir):
    """Run unsteered last-word regeneration for every rhyme family and aggregate statistics.

    For each family in ``rhyme_family_words``, the function:

    1. Loads ``unsteered_texts`` and ``batch_size_small`` from a
       ``generated_lines.json`` found under ``output_dir/mode/model_name/{family}_*``.
       Families without one are skipped.
    2. Regenerates the last word of each text's last line, or reuses cached
       words from a previous run.
    3. Scores those words against every rhyme family with ``get_word_correct``.

    Results are saved after every family to
    ``combined_unsteered_regeneration_results.json``. Summary rates are written
    to ``combined_regeneration_statistics.json``. Both files live in
    ``output_dir/mode/model_name``.

    Args:
        mode: Experiment mode; used as a path component and recorded in metadata.
        model_name: Model identifier; used as a path component and recorded in metadata.
        model: Model passed through to generation.
        tokenizer: Tokenizer passed through to generation.
        output_dir: Root directory for reading generated_lines.json and writing combined outputs.
    """
    print(f"Starting combined unsteered regeneration for {model_name}")

    # All rhyme families defined in shared_utils.
    all_rhyme_families = list(rhyme_family_words.keys())

    # Where to save combined outputs (single place, not per-family-pair)
    combined_dir = os.path.join(output_dir, mode, model_name)
    os.makedirs(combined_dir, exist_ok=True)

    combined_results_path = os.path.join(combined_dir, "combined_unsteered_regeneration_results.json")
    combined_stats_path = os.path.join(combined_dir, "combined_regeneration_statistics.json")

    # Load previously saved results, if any. The cache is per family: families
    # already present reuse their stored words below instead of regenerating.
    combined_results = {}
    if os.path.exists(combined_results_path):
        print(f"Found existing combined results: {combined_results_path}")
        combined_results = load_data(combined_results_path)

    # Accumulate regeneration words and correctness per rhyme_family1
    for rhyme_family1 in all_rhyme_families:
        print(f"Processing rhyme_family1 = {rhyme_family1} ...")

        # A generated_lines.json is required even when cached words exist; its
        # batch_size_small is recorded in the metadata below.
        input_file = find_any_generated_lines_file(output_dir, mode, model_name, rhyme_family1)
        if input_file is None:
            print(f"No generated_lines.json found for {rhyme_family1}. Please run stage_line_generation.py first for any pair starting with '{rhyme_family1}_'.")
            continue

        line_data = load_data(input_file)
        unsteered_texts = line_data["unsteered_texts"]
        batch_size_small = line_data["batch_size_small"]

        # If we already have words stored for this family, skip regeneration
        if rhyme_family1 in combined_results and "regenerated_last_words" in combined_results[rhyme_family1]:
            print(f"Skipping regeneration for {rhyme_family1}; using cached results.")
            generated_last_words = combined_results[rhyme_family1]["regenerated_last_words"]
        else:
            print("Regenerating last words (unsteered)...")
            generated_last_words = regenerate_last_words_only(
                model, tokenizer, unsteered_texts, batch_size_small
            )

        # Compute correctness against every rhyme family
        per_family_correct_arrays = {
            rf: get_word_correct(generated_last_words, rf).tolist() for rf in all_rhyme_families
        }

        # Store results per rhyme_family1
        combined_results[rhyme_family1] = {
            "regenerated_last_words": generated_last_words,
            "per_family_correct": per_family_correct_arrays,
            "metadata": {
                "model_name": model_name,
                "rhyme_family1": rhyme_family1,
                "timestamp": datetime.now().isoformat(),
                "batch_size_small": batch_size_small,
            },
        }

        # Save incremental progress
        save_data(combined_results, combined_results_path)
        print(f"Saved updated combined results to: {combined_results_path}")

    # Statistics: per (rhyme_family1, target_family) mean, and per-target averages excluding self.
    print("Computing combined regeneration statistics...")
    regeneration_rates = _compute_regeneration_rates(combined_results, all_rhyme_families)
    average_excluding_source = _average_rates_excluding_source(regeneration_rates, all_rhyme_families)

    combined_statistics = {
        "regeneration_rates_by_source": regeneration_rates,
        "average_target_rate_excluding_source": average_excluding_source,
        "metadata": {
            "model_name": model_name,
            "mode": mode,
            "timestamp": datetime.now().isoformat(),
            "note": "Unsteered regeneration only; aggregated across all rhyme families without per-pair directories",
        },
    }

    save_data(combined_statistics, combined_stats_path)
    print(f"Combined regeneration statistics saved to: {combined_stats_path}")


if __name__ == "__main__":
    # Parse the shared CLI options (mode, model_name, output_dir, ...).
    cli_args = get_common_args().parse_args()

    print("Loading model...")
    model, tokenizer = get_model(cli_args.model_name)

    # Always release the model and GPU memory, even if main() raises.
    try:
        main(
            mode=cli_args.mode,
            model_name=cli_args.model_name,
            model=model,
            tokenizer=tokenizer,
            output_dir=cli_args.output_dir,
        )
    finally:
        del model, tokenizer
        cleanup_gpu_memory()


