"""
Plot combined dataset metrics like Kendall tau across different datasets.
"""

import argparse
import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import rc

# Render all figure text through LaTeX with a serif font. With usetex enabled,
# every string handed to matplotlib (titles, labels, annotations) is typeset
# by LaTeX, so it must be valid LaTeX markup.
rc("text", usetex=True)
rc("font", family="serif")


def load_results(result_dir):
    """Load aggregated results and the dataset name from a result directory.

    Args:
        result_dir: Path (``str`` or ``Path``) of a directory containing
            ``aggregated_results.json`` and ``metadata.json``.

    Returns:
        A ``(results, dataset_name)`` tuple: the parsed contents of
        ``aggregated_results.json`` and the value stored under the
        ``"dataset"`` key of ``metadata.json`` (``None`` if the key is absent).

    Raises:
        FileNotFoundError: If either JSON file does not exist.
        json.JSONDecodeError: If either file is not valid JSON.
    """
    result_dir = Path(result_dir)

    # Read as UTF-8 explicitly: the default text encoding of open() depends on
    # the platform/locale and can mis-decode non-ASCII content.
    with open(result_dir / "aggregated_results.json", encoding="utf-8") as f:
        results = json.load(f)

    with open(result_dir / "metadata.json", encoding="utf-8") as f:
        metadata = json.load(f)

    return results, metadata.get("dataset")


def extract_metric_values(results, metric_name):
    """Collect every ``x`` value recorded for one dataset metric.

    Looks up ``metric_name`` under ``results["dataset_metrics"]`` and
    concatenates, in order, the ``entry["value"]["x"]`` sequence of each
    entry found there. A metric that is not present yields an empty list.
    """
    entries = results["dataset_metrics"].get(metric_name, [])
    return [x for entry in entries for x in entry["value"]["x"]]


# Subplot titles for known dataset names (variants share one title).
_DATASET_DISPLAY_NAMES = {
    "Deterministic": "Deterministic",
    "Deterministic No Variability": "Deterministic",
    "Drift Diffusion": "Drift-Diffusion",
    "Drift Diffusion No Variability": "Drift-Diffusion",
    "Stochastic": "Stochastic",
    "Stochastic No Variability": "Stochastic",
}

# Axis labels for known metrics; these are already valid LaTeX markup.
_METRIC_DISPLAY_NAMES = {
    "test_abs_logit_diff_rt_kendall_tau": "Kendall $\\tau$ Correlation",
}

# Characters with a special meaning in LaTeX text mode and their escapes.
_LATEX_ESCAPES = str.maketrans(
    {
        "\\": r"\textbackslash{}",
        "&": r"\&",
        "%": r"\%",
        "$": r"\$",
        "#": r"\#",
        "_": r"\_",
        "{": r"\{",
        "}": r"\}",
        "~": r"\textasciitilde{}",
        "^": r"\textasciicircum{}",
    }
)


def _latex_escape(text):
    """Escape LaTeX special characters in a plain-text label."""
    return str(text).translate(_LATEX_ESCAPES)


def plot_histograms(
    data_dict, metric_name, output_dir, output_suffix, x_range=None, bins=25
):
    """Plot one histogram per dataset, side by side, and save them as a PDF.

    The figure is written to
    ``<output_dir>/<metric_name>_combined_histograms[_<output_suffix>].pdf``.
    Each subplot shows the histogram of one dataset's values together with a
    dashed line and a text box marking the mean. All subplots share the same
    x-range and the same y-range.

    Args:
        data_dict: Mapping of dataset name to a sequence of metric values.
        metric_name: Name of the metric; used for the shared x label and the
            output filename.
        output_dir: Directory (``str`` or ``Path``) the PDF is written to.
        output_suffix: Optional suffix appended to the output filename.
        x_range: ``(min, max)`` range for the histogram bins and x-axis. When
            ``None`` it spans the values of all non-empty datasets.
        bins: Number of histogram bins.

    Raises:
        ValueError: If ``data_dict`` is empty.
    """
    if not data_dict:
        raise ValueError("data_dict must contain at least one dataset")

    datasets = list(data_dict)
    n_datasets = len(datasets)

    if x_range is None:
        non_empty = [values for values in data_dict.values() if len(values) > 0]
        if non_empty:
            x_range = (
                min(min(values) for values in non_empty),
                max(max(values) for values in non_empty),
            )
        else:
            # Nothing to derive a range from; use a placeholder so empty
            # panels can still be drawn.
            x_range = (0.0, 1.0)

    # The y-axis is shared, so its limit must fit the largest dataset rather
    # than whichever dataset happens to be drawn last.
    y_max = max(max(len(values) for values in data_dict.values()), 1)

    # squeeze=False always returns a 2-D array of axes, so a single dataset
    # needs no special casing.
    fig, axes = plt.subplots(
        1, n_datasets, figsize=(5 * n_datasets, 5), sharey=True, squeeze=False
    )

    try:
        for i, (ax, dataset) in enumerate(zip(axes[0], datasets)):
            values = data_dict[dataset]

            ax.hist(values, bins=bins, range=x_range, alpha=0.7, color="#1D9A6C")

            # Mean indicator (line and text); undefined for an empty dataset.
            if len(values) > 0:
                mean_val = np.mean(values)
                ax.axvline(mean_val, color="black", linestyle="dashed", linewidth=2)
                ax.text(
                    0.05,
                    0.90,
                    f"Mean: {mean_val:.3f}",
                    transform=ax.transAxes,
                    ha="left",
                    va="top",
                    bbox=dict(facecolor="white", alpha=0.7, edgecolor="none"),
                )

            display_name = _DATASET_DISPLAY_NAMES.get(dataset)
            if display_name is None:
                # Unknown names are plain text and must be escaped, otherwise
                # characters such as "_" break LaTeX rendering.
                display_name = _latex_escape(str(dataset).capitalize())
            ax.set_title(r"\textbf{" + display_name + r"}", fontsize=14)

            # Only the first subplot carries the y label.
            if i == 0:
                ax.set_ylabel(r"\textbf{Frequency}", fontsize=14)

            # No per-axis x label; a single shared label is added below.
            ax.set_xlabel("")

            ax.set_ylim(0, y_max)
            ax.set_xlim(x_range)

            ax.grid(axis="y", linestyle="--", alpha=0.7)

        # Shared x label. Known metrics map to hand-written LaTeX; anything
        # else is a raw identifier and needs escaping.
        x_label = _METRIC_DISPLAY_NAMES.get(metric_name)
        if x_label is None:
            x_label = _latex_escape(metric_name)
        fig.text(
            0.5,
            0.02,
            r"\textbf{" + x_label + r"}",
            ha="center",
            fontsize=14,
        )
        # Leave room at the bottom for the shared x label.
        fig.tight_layout(rect=[0, 0.05, 1, 1])

        suffix = f"_{output_suffix}" if output_suffix else ""
        output_path = (
            Path(output_dir) / f"{metric_name}_combined_histograms{suffix}.pdf"
        )
        fig.savefig(output_path, bbox_inches="tight")

        print(f"Saved {metric_name} histograms to {output_path}")
    finally:
        # Release the figure even if rendering or saving fails.
        plt.close(fig)


def _resolve_x_range(data_dict, x_range_min, x_range_max):
    """Build the x-axis range, deriving a bound left as None from the data."""
    if x_range_min is None and x_range_max is None:
        return None

    if x_range_min is None or x_range_max is None:
        non_empty = [values for values in data_dict.values() if len(values) > 0]
        if not non_empty:
            # No data to derive the missing bound from; fall back to auto.
            return None
        if x_range_min is None:
            x_range_min = min(min(values) for values in non_empty)
        else:
            x_range_max = max(max(values) for values in non_empty)

    if x_range_min > x_range_max:
        raise ValueError(
            f"x-axis range minimum ({x_range_min}) exceeds maximum ({x_range_max})"
        )
    return (x_range_min, x_range_max)


def run_analysis(
    result_dirs, metrics, output_dir, output_suffix, x_range_min=None, x_range_max=None
):
    """Run the combined dataset metrics analysis.

    Every result directory is loaded once; then, for each metric, the values
    of all datasets are extracted and plotted as side-by-side histograms.

    Args:
        result_dirs: List of directories containing experiment results
        metrics: List of metrics to analyze
        output_dir: Directory to save output files (created if missing)
        output_suffix: Suffix for output filenames
        x_range_min: Minimum value for x-axis range (None for auto)
        x_range_max: Maximum value for x-axis range (None for auto)

    Raises:
        ValueError: If the resulting x-axis minimum exceeds the maximum.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)

    # Load each directory once up front instead of once per metric.
    loaded = {}
    for result_dir in result_dirs:
        results, dataset_name = load_results(result_dir)
        if dataset_name is None:
            # Metadata has no dataset name; label by directory instead.
            dataset_name = Path(result_dir).resolve().name or str(result_dir)

        # Keep directories that share a dataset name apart instead of letting
        # the later one silently replace the earlier one.
        base_name, count = dataset_name, 1
        while dataset_name in loaded:
            count += 1
            dataset_name = f"{base_name} ({count})"
        loaded[dataset_name] = results

    for metric_name in metrics:
        data_dict = {}
        for dataset_name, results in loaded.items():
            values = extract_metric_values(results, metric_name)
            data_dict[dataset_name] = values
            print(f"Extracted {len(values)} {metric_name} values from {dataset_name}")

        # A bound given on its own is honored; the other comes from the data.
        x_range = _resolve_x_range(data_dict, x_range_min, x_range_max)
        plot_histograms(data_dict, metric_name, output_dir, output_suffix, x_range)


def main(argv=None):
    """Parse command-line arguments and run the histogram analysis.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read the process command line.
    """
    parser = argparse.ArgumentParser(
        description="Plot dataset metric histograms across datasets"
    )
    # "--result-dirs" is accepted as an alias so the flag matches the
    # hyphenated style of the other options; "--result_dirs" keeps working.
    parser.add_argument(
        "--result_dirs",
        "--result-dirs",
        dest="result_dirs",
        nargs="+",
        required=True,
        help="Paths to result directories containing aggregated_results.json",
    )
    parser.add_argument(
        "--metrics",
        nargs="+",
        default=["test_abs_logit_diff_rt_kendall_tau"],
        help="Dataset metrics to visualize (default: test_abs_logit_diff_rt_kendall_tau)",
    )
    parser.add_argument(
        "--output",
        default="./",
        help="Output directory for histograms",
    )
    parser.add_argument(
        "--output-suffix",
        default="",
        help="Suffix for output filenames",
    )
    parser.add_argument(
        "--x-range-min",
        type=float,
        default=None,
        help="Minimum value for x-axis range (default: auto)",
    )
    parser.add_argument(
        "--x-range-max",
        type=float,
        default=None,
        help="Maximum value for x-axis range (default: auto)",
    )

    args = parser.parse_args(argv)

    run_analysis(
        result_dirs=args.result_dirs,
        metrics=args.metrics,
        output_dir=args.output,
        output_suffix=args.output_suffix,
        x_range_min=args.x_range_min,
        x_range_max=args.x_range_max,
    )


# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
