#!/usr/bin/env python3

import os
import json
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


def list_models(results_root: str, mode: str) -> List[str]:
    """Return the sorted names of model result directories under ``<results_root>/<mode>``.

    A subdirectory counts as a model when its name does not start with
    ``plots_`` and it either contains ``combined_regeneration_statistics.json``
    or has at least one pair subdirectory (name containing ``_``, not starting
    with ``plots_``) that holds ``standard_metrics.json``.

    Returns an empty list when ``<results_root>/<mode>`` is not a directory.
    """
    base = os.path.join(results_root, mode)
    if not os.path.isdir(base):
        return []

    def is_pair_with_metrics(model_path: str, name: str) -> bool:
        # Pair directories are named "<family1>_<family2>"; plot output dirs are ignored.
        if name.startswith("plots_") or "_" not in name:
            return False
        pair_path = os.path.join(model_path, name)
        return os.path.isdir(pair_path) and os.path.exists(
            os.path.join(pair_path, "standard_metrics.json")
        )

    def looks_like_model(model_path: str) -> bool:
        if os.path.exists(os.path.join(model_path, "combined_regeneration_statistics.json")):
            return True
        return any(is_pair_with_metrics(model_path, name) for name in os.listdir(model_path))

    candidates = [
        name
        for name in os.listdir(base)
        if os.path.isdir(os.path.join(base, name)) and not name.startswith("plots_")
    ]
    return sorted(name for name in candidates if looks_like_model(os.path.join(base, name)))


def load_combined_baselines(results_root: str, mode: str, model_name: str):
    """Load ``<results_root>/<mode>/<model_name>/combined_regeneration_statistics.json``.

    Returns:
        The parsed JSON object, or ``None`` when the file does not exist or
        cannot be read/parsed. In the latter case a warning is printed, so
        callers fall back to empty baselines instead of aborting the run.
    """
    model_dir = os.path.join(results_root, mode, model_name)
    path = os.path.join(model_dir, "combined_regeneration_statistics.json")
    if not os.path.exists(path):
        return None
    try:
        # JSON text is UTF-8 (RFC 8259); don't rely on the platform's locale encoding.
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except (OSError, ValueError) as err:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Warning: could not load {path}: {err}")
        return None


def rhyme_index_map() -> Tuple[List[str], Dict[str, int]]:
    """Return the canonical rhyme-family order and a family -> position lookup.

    The order is fixed so that matrix rows/columns and plot axes line up
    across experiments.
    """
    ordered = "ing air ip oat ird ee ight ake ow it".split()
    position = dict(zip(ordered, range(len(ordered))))
    return ordered, position


def aggregate_pair_metrics(standard_metrics_path: str) -> Dict:
    """Load and return the parsed contents of one pair's ``standard_metrics.json``.

    Despite the name, no aggregation happens here; the JSON is returned as-is.
    Read/parse errors (OSError, json.JSONDecodeError) propagate to the caller.
    """
    # JSON text is UTF-8 (RFC 8259); don't rely on the platform's locale encoding.
    with open(standard_metrics_path, "r", encoding="utf-8") as f:
        return json.load(f)


def _mean_or_none(values):
    """Return the mean of a list of scores as a float, or None when it is empty/missing.

    Skipping empty lists avoids ``np.mean([])``, which yields NaN (with a
    RuntimeWarning) and would poison every sum it is added to.
    """
    arr = np.asarray(values if values is not None else [], dtype=float)
    if arr.size == 0:
        return None
    return float(arr.mean())


def _cell_means(sums: np.ndarray, counts: np.ndarray) -> np.ndarray:
    """Element-wise ``sums / counts``, leaving 0.0 wherever a cell received no data."""
    means = np.zeros_like(sums, dtype=float)
    np.divide(sums, counts, out=means, where=counts > 0)
    return means


def build_regeneration_matrices_for_model(results_root: str, mode: str, model_name: str):
    """Build source-by-target regeneration-rate matrices for one model.

    Scans ``<results_root>/<mode>/<model_name>/<pair>/standard_metrics.json``
    for every pair directory (name contains ``_`` and does not start with
    ``plots_``). For each pair, metadata ``rhyme_family1`` selects the source
    row. The mean of ``last_word_regeneration_rhyme_family1`` goes to column
    ``rhyme_family1``, and the mean of ``..._family2`` goes to column
    ``rhyme_family2``.

    Each cell is the mean of the per-pair rates that actually landed in that
    cell. Per-cell normalization matters because a source's diagonal cell
    receives a value from every one of its pairs, while an off-diagonal cell
    only receives values from pairs with that specific target. A shared
    per-source divisor would shrink the off-diagonal cells.

    Steered rates are first averaged over layers (each layer weighted equally)
    and then over pairs. Cells without data are 0.0. Unreadable metrics files
    are skipped with a printed warning.

    Returns:
        ``None`` if the model directory does not exist. Otherwise a tuple
        ``(unsteered_matrix, steered_matrix, families)``, where both matrices
        have shape ``(len(families), len(families))`` in ``rhyme_index_map()``
        order.
    """
    families, idx = rhyme_index_map()
    size = len(families)

    model_dir = os.path.join(results_root, mode, model_name)
    if not os.path.isdir(model_dir):
        return None

    # Running sums of per-pair rates and the number of pairs contributing to each cell.
    sums_unsteered = np.zeros((size, size), dtype=float)
    n_unsteered = np.zeros((size, size), dtype=float)
    sums_steered = np.zeros((size, size), dtype=float)
    n_steered = np.zeros((size, size), dtype=float)

    score_keys = (
        "last_word_regeneration_rhyme_family1",
        "last_word_regeneration_rhyme_family2",
    )

    for entry in sorted(os.listdir(model_dir)):
        if entry.startswith("plots_") or "_" not in entry:
            continue
        pair_dir = os.path.join(model_dir, entry)
        metrics_path = os.path.join(pair_dir, "standard_metrics.json")
        if not os.path.isdir(pair_dir) or not os.path.exists(metrics_path):
            continue

        try:
            data = aggregate_pair_metrics(metrics_path)
        except (OSError, ValueError) as err:
            # Best effort: one unreadable/corrupt pair file should not abort the model.
            print(f"Warning: skipping {metrics_path}: {err}")
            continue
        if not isinstance(data, dict):
            continue

        meta = data.get("metadata") or {}
        if not isinstance(meta, dict):
            continue
        rf1 = meta.get("rhyme_family1")
        rf2 = meta.get("rhyme_family2")
        if rf1 not in idx or rf2 not in idx:
            continue
        source = idx[rf1]
        # Columns for the family1 scores (the source's own family) and the family2 scores.
        targets = (idx[rf1], idx[rf2])

        unsteered = data.get("unsteered_metrics") or {}
        for key, target in zip(score_keys, targets):
            rate = _mean_or_none(unsteered.get(key))
            if rate is not None:
                sums_unsteered[source, target] += rate
                n_unsteered[source, target] += 1.0

        steered = data.get("steered_metrics")
        if isinstance(steered, dict):
            layers = [layer for layer in steered.values() if isinstance(layer, dict)]
            for key, target in zip(score_keys, targets):
                layer_rates = [
                    rate
                    for rate in (_mean_or_none(layer.get(key)) for layer in layers)
                    if rate is not None
                ]
                if layer_rates:
                    sums_steered[source, target] += float(np.mean(layer_rates))
                    n_steered[source, target] += 1.0

    unsteered_matrix = _cell_means(sums_unsteered, n_unsteered)
    steered_matrix = _cell_means(sums_steered, n_steered)
    return unsteered_matrix, steered_matrix, families


def compute_overall_unsteered_success_vs_baseline(combined: Dict, families: List[str]) -> Tuple[float, float]:
    """Summarize unsteered own-family regeneration against an off-family baseline.

    Reads ``combined["regeneration_rates_by_source"]``, a mapping from source
    family to a per-target rate mapping. For every family in ``families`` that
    appears as a source:

    - its "success" is the rate into its own family (0.0 if absent);
    - its "baseline" is the mean rate into all other listed targets. This is
      recorded only when at least one other target exists.

    Returns:
        ``(mean success, mean baseline)`` over those sources. Each is 0.0 when
        ``combined`` is falsy, has no rates, or no source in ``families`` is
        present.
    """
    by_source = combined.get("regeneration_rates_by_source", {}) if combined else {}
    if not by_source:
        return 0.0, 0.0

    present = [rf for rf in families if rf in by_source]
    if not present:
        return 0.0, 0.0

    own_rates: List[float] = []
    off_target_means: List[float] = []
    for source in present:
        rates = by_source[source]
        own_rates.append(float(rates.get(source, 0.0)))
        off_target = [float(rate) for target, rate in rates.items() if target != source]
        if off_target:
            off_target_means.append(float(np.mean(off_target)))

    def _avg(values: List[float]) -> float:
        return float(np.mean(values)) if values else 0.0

    return _avg(own_rates), _avg(off_target_means)


def compute_overall_steered_success_vs_baseline(
    combined: Dict,
    steered_matrix: np.ndarray,
    families: List[str],
) -> Tuple[float, float]:
    """Summarize steered regeneration success per target against an unsteered baseline.

    For each target family (column ``j`` of ``steered_matrix``):

    - success = mean of ``steered_matrix[i, j]`` over off-diagonal sources ``i``
      whose value is ``> 0``. Targets with no such source are skipped entirely.
    - baseline = mean, over off-diagonal sources whose value is not ``<= 0``,
      of each source's mean rate (from
      ``combined["regeneration_rates_by_source"]``) into every target other
      than this one. It is recorded only if some source has such rates.

    Returns:
        ``(mean of per-target successes, mean of per-target baselines)``, each
        0.0 when nothing contributed.

    NOTE(review): ``> 0`` acts as the "has data" test, so a genuine 0.0 steered
    rate is dropped like a missing pair, which can raise the success mean.
    NaN cells are excluded from success (``NaN > 0`` is False) but kept for the
    baseline (``NaN <= 0`` is also False). The baseline's "other targets" also
    include the source's own family. Confirm these are intended.
    """
    rates_by_source = combined.get("regeneration_rates_by_source", {}) if combined else {}
    n_families = len(families)
    position = {rf: i for i, rf in enumerate(families)}

    per_target_success: List[float] = []
    per_target_baseline: List[float] = []

    for target in families:
        col = position[target]
        success_rows = [r for r in range(n_families) if r != col and steered_matrix[r, col] > 0]
        if not success_rows:
            continue
        per_target_success.append(float(np.mean([steered_matrix[r, col] for r in success_rows])))

        # Written as "not <= 0" (not "> 0") to keep NaN cells in the baseline set.
        baseline_rows = [r for r in range(n_families) if r != col and not steered_matrix[r, col] <= 0]
        source_baselines: List[float] = []
        for r in baseline_rows:
            rates = rates_by_source.get(families[r], {})
            off_target = [float(v) for other, v in rates.items() if other != target]
            if off_target:
                source_baselines.append(float(np.mean(off_target)))
        if source_baselines:
            per_target_baseline.append(float(np.mean(source_baselines)))

    success_mean = float(np.mean(per_target_success)) if per_target_success else 0.0
    baseline_mean = float(np.mean(per_target_baseline)) if per_target_baseline else 0.0
    return success_mean, baseline_mean


def plot_heatmap(matrix: np.ndarray, families: List[str], title: str, save_path: str):
    """Render ``matrix`` as an annotated heatmap and save it to ``save_path``.

    Rows are labelled "Source rhyme family" and columns "Target rhyme family",
    both using ``families`` as tick labels. Values are annotated with two
    decimals. The figure is closed after saving.
    """
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        matrix,
        ax=ax,
        annot=True,
        fmt=".2f",
        cmap="viridis",
        xticklabels=families,
        yticklabels=families,
    )
    ax.set_xlabel("Target rhyme family")
    ax.set_ylabel("Source rhyme family")
    ax.set_title(title)
    fig.tight_layout()
    fig.savefig(save_path, dpi=200, bbox_inches="tight")
    plt.close(fig)


def plot_baseline_vs_observed(families: List[str], baseline_avg_excl_source: Dict[str, float], unsteered_matrix: np.ndarray, steered_matrix: np.ndarray, save_path: str):
    """Save a grouped bar chart comparing baseline and observed rates per target family.

    For each family (used as the target column), three bars are drawn:

    - the value from ``baseline_avg_excl_source`` (0.0 if missing);
    - the mean of the unsteered matrix column over all other sources (diagonal
      excluded);
    - the same mean for the steered matrix.

    NOTE(review): every off-diagonal cell enters these means, so cells without
    data (if the matrix builder stores them as 0.0) pull the bars down.
    """
    n_families = len(families)
    position = {rf: i for i, rf in enumerate(families)}

    def off_diagonal_mean(matrix: np.ndarray, col: int) -> float:
        column = [matrix[row, col] for row in range(n_families) if row != col]
        return float(np.mean(column)) if column else 0.0

    cols = [position[rf] for rf in families]
    baseline = [float(baseline_avg_excl_source.get(rf, 0.0)) for rf in families]
    mean_unsteered_excl = [off_diagonal_mean(unsteered_matrix, c) for c in cols]
    mean_steered_excl = [off_diagonal_mean(steered_matrix, c) for c in cols]

    positions = np.arange(n_families)
    bar_width = 0.28
    series = (
        (-bar_width, baseline, "Baseline avg (excl. source)"),
        (0.0, mean_unsteered_excl, "Unsteered "),
        (bar_width, mean_steered_excl, "Steered"),
    )

    plt.figure(figsize=(12, 6))
    for offset, heights, label in series:
        plt.bar(positions + offset, heights, width=bar_width, label=label)
    plt.xticks(positions, families, rotation=45)
    plt.ylabel("Regeneration rate")
    plt.title("Baseline vs observed regeneration rates by target family")
    plt.legend()
    plt.tight_layout()
    plt.savefig(save_path, dpi=200, bbox_inches="tight")
    plt.close()


def _plot_aggregated_bars(
    model_names: List[str],
    baseline_values: List[float],
    observed_values: List[float],
    baseline_label: str,
    observed_label: str,
    title: str,
    save_path: str,
):
    """Save a grouped bar chart with one baseline bar and one observed bar per model."""
    x = np.arange(len(model_names))
    width = 0.35
    plt.figure(figsize=(max(8, 1 + 1.2 * len(model_names)), 5))
    plt.bar(x - width / 2, baseline_values, width=width, label=baseline_label)
    plt.bar(x + width / 2, observed_values, width=width, label=observed_label)
    plt.xticks(x, model_names, rotation=30, ha="right")
    plt.ylabel("Regeneration rate")
    plt.title(title)
    plt.legend()
    plt.tight_layout()
    plt.savefig(save_path, dpi=200, bbox_inches="tight")
    plt.close()


def main(results_root: str = "results", mode: str = "rhyme_family_steering"):
    """Render per-model regeneration plots and cross-model aggregate charts.

    For every model reported by ``list_models`` under ``<results_root>/<mode>``,
    writes two heatmaps and a baseline-vs-observed chart into
    ``<results_root>/<mode>/<model>/plots_regeneration``. It then writes an
    unsteered-vs-baseline and a steered-vs-baseline comparison across models
    into ``<results_root>/<mode>/plots_aggregated``.
    """
    models = list_models(results_root, mode)
    if not models:
        print(f"No models found under {os.path.join(results_root, mode)}")
        return

    # Per-model summary numbers for the cross-model plots.
    agg_model_names: List[str] = []
    agg_unsteered_baselines: List[float] = []
    agg_unsteered_means: List[float] = []
    agg_steered_baselines: List[float] = []
    agg_steered_means: List[float] = []

    for model_name in models:
        print(f"Processing model: {model_name}")
        combined = load_combined_baselines(results_root, mode, model_name)
        built = build_regeneration_matrices_for_model(results_root, mode, model_name)
        # The builder returns a bare None (not a 3-tuple) when the model directory
        # is missing, so check before unpacking.
        if built is None:
            print(f"Skipping {model_name}: no data")
            continue
        unsteered_matrix, steered_matrix, families = built

        out_dir = os.path.join(results_root, mode, model_name, "plots_regeneration")
        os.makedirs(out_dir, exist_ok=True)

        plot_heatmap(unsteered_matrix, families, f"Unsteered regeneration ({model_name})", os.path.join(out_dir, "unsteered_regeneration_heatmap.png"))
        plot_heatmap(steered_matrix, families, f"Steered regeneration (avg over layers) ({model_name})", os.path.join(out_dir, "steered_regeneration_heatmap.png"))

        baseline_avg = combined.get("average_target_rate_excluding_source", {}) if combined else {}
        plot_baseline_vs_observed(
            families,
            baseline_avg,
            unsteered_matrix,
            steered_matrix,
            os.path.join(out_dir, "baseline_vs_observed.png"),
        )

        print(f"Saved plots to: {out_dir}")

        # Unsteered: mean over sources of the own-family rate vs. the mean rate into other families.
        un_success_mean, un_baseline_mean = compute_overall_unsteered_success_vs_baseline(combined, families)
        # Steered: mean over targets of steered success vs. a baseline computed from the same sources.
        st_success_mean, st_baseline_mean = compute_overall_steered_success_vs_baseline(combined, steered_matrix, families)

        agg_model_names.append(model_name)
        agg_unsteered_baselines.append(un_baseline_mean)
        agg_unsteered_means.append(un_success_mean)
        agg_steered_baselines.append(st_baseline_mean)
        agg_steered_means.append(st_success_mean)

    if not agg_model_names:
        return

    agg_dir = os.path.join(results_root, mode, "plots_aggregated")
    os.makedirs(agg_dir, exist_ok=True)

    _plot_aggregated_bars(
        agg_model_names,
        agg_unsteered_baselines,
        agg_unsteered_means,
        "Baseline avg (excl. source)",
        "Unsteered avg",
        "Aggregated: Baseline vs Unsteered Regeneration Across Models",
        os.path.join(agg_dir, "aggregated_unsteered_vs_baseline.png"),
    )
    # The steered baseline is built from the same sources that contribute to each
    # target's steered success, so it is the comparable reference for steered bars.
    _plot_aggregated_bars(
        agg_model_names,
        agg_steered_baselines,
        agg_steered_means,
        "Baseline avg (unsteered)",
        "Steered avg (targets)",
        "Aggregated: Baseline vs Steered Regeneration Across Models",
        os.path.join(agg_dir, "aggregated_steered_vs_baseline.png"),
    )


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Plot regeneration heatmaps and baseline comparisons for rhyme-family steering results."
    )
    parser.add_argument(
        "--output_dir",
        default="results",
        help="Results root directory: model results are read from, and plots written under, "
        "<output_dir>/<mode>. Default: %(default)s",
    )
    parser.add_argument(
        "--mode",
        default="rhyme_family_steering",
        help="Experiment subdirectory under --output_dir. Default: %(default)s",
    )
    args = parser.parse_args()
    main(results_root=args.output_dir, mode=args.mode)


