import numpy as np
from pathlib import Path
import argparse
import matplotlib.pyplot as plt
import json
import matplotlib.ticker as ticker

# --- Configuration Globals ---
# Font sizes in points. TICK_FONTSIZE is used for tick labels and in-axes
# text; LEGEND_FONTSIZE is used for subplot titles and the shared legend.
TICK_FONTSIZE = 24
LEGEND_FONTSIZE = 24

# Models to compare. Each name is the sub-directory of the logs root that
# load_model_errors reads, and both names are embedded in the output filename.
MODEL_LEFT = "ResNet56_CIFAR100"  # CIFAR model (left)
MODEL_RIGHT = "ViT_Base_P16_224_ImageNet1k"  # ImageNet model (right)

# Methods to load and plot. Error files for any method not listed here are
# ignored; commented-out entries are currently disabled.
METHODS_TO_PLOT = [
    "Uncalibrated",
    "TemperatureScaling",
    "VectorScaling",
    # "MatrixScaling",
    # "MatrixScalingODIR",
    "DirichletL2",
    "EnsembleTemperatureScaling",
    "IROvA",
    "IRM",
]

# Method name -> short label shown in the legend. A method without an entry
# is labelled with its own name.
# NOTE(review): "V.S" has no trailing period, unlike "T.S." and "Ens.T.S." --
# confirm whether that is intentional.
LEGEND_MAPPING = {
    "Uncalibrated": "Uncalibrated",
    "TemperatureScaling": "T.S.",
    "VectorScaling": "V.S",
    # "MatrixScalingODIR": "Matrix Scaling",
    "DirichletL2": "Dirichlet",
    "EnsembleTemperatureScaling": "Ens.T.S.",
    "IROvA": "IR(OvA)",
    "IRM": "IR",
}

# Color-blind friendly color palette. Indexed modulo its length, so colors
# repeat once more than eight curves are drawn.
COLORBLIND_PALETTE = [
    "#0072B2",  # blue
    "#D55E00",  # vermillion/red
    "#009E73",  # green
    "#CC79A7",  # pink
    "#F0E442",  # yellow
    "#56B4E9",  # light blue
    "#E69F00",  # orange
    "#000000",  # black
]

# Line styles for additional differentiation. Indexed modulo its length with
# the same index as the palette, so curves 0 and 4 share a style but not a
# color.
LINE_STYLES = ["-", "--", "-.", ":"]


def _plot_subplot_ecdf(
    ax,
    errors_dict,
    subplot_title_prefix,
    methods_to_plot_config,
    legend_mapping_config,
    show_yticks=False,
    show_xticks=True,
    x_lim=None,
):
    """Draw one empirical-CDF panel on ``ax``, one curve per selected method.

    Args:
        ax: Matplotlib axes to draw on.
        errors_dict: Mapping of method name to an array-like of error values.
            Arrays of any shape are flattened before the ECDF is built.
        subplot_title_prefix: Text used as the subplot title and in the
            progress messages.
        methods_to_plot_config: Container of method names to draw. Entries of
            ``errors_dict`` that are not in it are skipped with a message.
        legend_mapping_config: Mapping of method name to legend label; a
            method without an entry is labelled with its own name.
        show_yticks: Whether the y tick labels are visible.
        show_xticks: Whether the x tick labels are visible.
        x_lim: Optional right x-limit. When ``None`` the right limit is 5%
            above the largest plotted error.

    Returns:
        Tuple ``(handles, labels, max_x_val, min_x_val)``: the plotted line
        objects, their legend labels, and the largest / smallest plotted
        error (after clamping to 1e-10). When nothing is plotted the lists
        are empty and the extrema are ``0.0`` and ``inf``.
    """
    ax.set_title(
        f"{subplot_title_prefix}",
        fontsize=LEGEND_FONTSIZE,
    )

    # Fixed y-tick positions so the horizontal grid lines agree across panels.
    ax.set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])

    ax.grid(True, linestyle="-", alpha=0.2, which="major")
    ax.grid(True, linestyle="-", alpha=0.1, which="minor")
    ax.set_ylim(0, 1.05)  # Keep Y limit fixed
    ax.tick_params(axis="both", which="major", labelsize=TICK_FONTSIZE)

    ax.set_xscale("log")

    # Minor ticks feed the fainter minor grid enabled above.
    ax.minorticks_on()

    max_x_val = 0.0
    min_x_val = float("inf")
    handles = []
    labels = []

    # The style index is the method's position among ALL sorted keys of
    # errors_dict (skipped ones included), so a method keeps the same colour
    # across panels only when the panels receive the same set of keys.
    for i, method_name in enumerate(sorted(errors_dict)):
        if method_name not in methods_to_plot_config:
            print(
                f"Skipping method '{method_name}' for subplot '{subplot_title_prefix}' as it's not in METHODS_TO_PLOT config."
            )
            continue

        # Flatten so that 0-d and multi-dimensional inputs yield one curve.
        errors = np.asarray(errors_dict[method_name]).ravel()
        if errors.size == 0:
            print(
                f"No errors to plot for '{subplot_title_prefix}' method: {method_name}."
            )
            continue

        # Zeros cannot be placed on a log axis; clamp them to a tiny value.
        x_ecdf = np.sort(np.maximum(errors, 1e-10))
        y_ecdf = np.arange(1, x_ecdf.size + 1) / x_ecdf.size

        legend_label = legend_mapping_config.get(method_name, method_name)
        (line,) = ax.plot(
            x_ecdf,
            y_ecdf,
            marker="o",
            linestyle=LINE_STYLES[i % len(LINE_STYLES)],
            label=legend_label,
            markersize=3,
            linewidth=3.0,
            solid_capstyle="round",
            alpha=0.8,
            color=COLORBLIND_PALETTE[i % len(COLORBLIND_PALETTE)],
        )
        handles.append(line)
        labels.append(legend_label)

        # x_ecdf is sorted, so its ends are this method's extrema.
        max_x_val = max(max_x_val, x_ecdf[-1])
        min_x_val = min(min_x_val, x_ecdf[0])

    if handles:
        # Clamping guarantees max_x_val > 0 here, so no fallback is needed.
        right_limit = x_lim if x_lim is not None else max_x_val * 1.05
        ax.set_xlim(left=min_x_val * 0.5, right=right_limit)

        # Keep the number of labelled decades small.
        ax.xaxis.set_major_locator(ticker.LogLocator(base=10, numticks=2))
        ax.xaxis.set_major_formatter(ticker.LogFormatterSciNotation(base=10))
    else:
        ax.text(
            0.5,
            0.5,
            "No data for selected methods",
            horizontalalignment="center",
            verticalalignment="center",
            transform=ax.transAxes,
            fontsize=TICK_FONTSIZE,
        )
        ax.set_xlim(left=1e-10, right=1.0)  # Default X limit if nothing plotted

    # Hide labels last and through tick visibility rather than empty label
    # lists: set_xscale() and set_major_formatter() both replace the
    # formatter, which silently undid an earlier set_xticklabels([]).
    if not show_yticks:
        ax.tick_params(axis="y", which="both", labelleft=False)
    if not show_xticks:
        ax.tick_params(axis="x", which="both", labelbottom=False)

    return handles, labels, max_x_val, min_x_val


def load_model_errors(model_name, logs_root_dir):
    """Load the per-method error arrays saved for one model.

    Reads ``<logs_root_dir>/<model_name>/cdf_plot_logs``. A file named
    ``<method>_linear_errors.npy`` or ``<method>_rank_errors.npy`` is loaded
    only when ``<method>`` is listed in ``METHODS_TO_PLOT``; every other
    ``.npy`` file is left unread.

    Args:
        model_name: Name of the model sub-directory under ``logs_root_dir``.
        logs_root_dir: Root directory holding one sub-directory per model.

    Returns:
        Tuple ``(linear_errors, rank_errors, dataset_name)``. The first two
        map method name to the loaded array. ``dataset_name`` is the
        ``"dataset_name"`` entry of ``_metadata.json`` when that file can be
        read, otherwise ``model_name``. All three are ``None`` when the log
        directory does not exist.
    """
    model_log_data_dir = Path(logs_root_dir) / model_name / "cdf_plot_logs"
    if not model_log_data_dir.is_dir():
        print(
            f"Log directory not found for model '{model_name}' at {model_log_data_dir}."
        )
        return None, None, None

    print(f"Reading ECDF data from: {model_log_data_dir}")
    linear_errors = {}
    rank_errors = {}
    # (file-stem suffix, word used in the progress message, destination dict)
    error_kinds = (
        ("_linear_errors", "linear", linear_errors),
        ("_rank_errors", "rank", rank_errors),
    )

    # Sorted so the progress output does not depend on directory order.
    for log_file in sorted(model_log_data_dir.glob("*.npy")):
        stem = log_file.stem
        for suffix, kind, destination in error_kinds:
            if not stem.endswith(suffix):
                continue
            method_name = stem[: -len(suffix)]
            # Filter before loading: arrays of disabled methods are never
            # used, and unrelated or corrupt files must not abort the run.
            if method_name in METHODS_TO_PLOT:
                errors_data = np.load(log_file)
                destination[method_name] = errors_data
                print(
                    f"  Loaded {kind} errors for '{method_name}' (shape: {errors_data.shape})"
                )
            break

    # The metadata is optional, so failing to read it only costs the nicer
    # dataset name.
    dataset_name = model_name
    metadata_path = model_log_data_dir / "_metadata.json"
    if metadata_path.exists():
        try:
            with open(metadata_path, "r", encoding="utf-8") as f:
                metadata = json.load(f)
        except (OSError, ValueError) as exc:  # JSONDecodeError is a ValueError
            print(f"  Could not read metadata from {metadata_path}: {exc}")
        else:
            if isinstance(metadata, dict):
                dataset_name = metadata.get("dataset_name", model_name)

    return linear_errors, rank_errors, dataset_name


def plot_models_comparison(logs_root_dir, output_plot_dir):
    """Create and save the 1x4 side-by-side ECDF comparison of the two models.

    Panels, left to right: rank-based and linear errors of ``MODEL_LEFT``,
    then rank-based and linear errors of ``MODEL_RIGHT``. All panels share
    the same x-limits and a single legend below the axes. The figure is
    written as a PDF into ``output_plot_dir``, which is created if needed.

    Args:
        logs_root_dir: Root directory passed to ``load_model_errors``.
        output_plot_dir: Directory that receives the PDF.

    Returns:
        None. Returns early, after printing a message, when the log
        directory of either model is missing.
    """
    # The dataset names returned by the loader are not used: the panel
    # titles are fixed strings to save space.
    left_linear, left_rank, _ = load_model_errors(MODEL_LEFT, logs_root_dir)
    right_linear, right_rank, _ = load_model_errors(MODEL_RIGHT, logs_root_dir)

    if left_rank is None or left_linear is None:
        print(f"No data found for {MODEL_LEFT}. Exiting.")
        return

    if right_rank is None or right_linear is None:
        print(f"No data found for {MODEL_RIGHT}. Exiting.")
        return

    panels = [
        ("Rank-Based (CIFAR)", left_rank),
        ("Linear (CIFAR)", left_linear),
        ("Rank-Based (ImageNet)", right_rank),
        ("Linear (ImageNet)", right_linear),
    ]

    # _plot_subplot_ecdf picks colour and line style from a method's position
    # among the sorted keys of the dict it receives. Give every panel the same
    # key set (empty data for absent methods) so a method looks the same in
    # all panels and therefore matches the single shared legend.
    known_methods = set()
    for _, errors in panels:
        known_methods.update(errors)
    panels = [
        (title, {**dict.fromkeys(known_methods, ()), **errors})
        for title, errors in panels
    ]

    fig, axes = plt.subplots(1, 4, figsize=(24, 7))
    try:
        # Tight horizontal spacing; the bottom margin leaves room for the
        # legend.
        fig.subplots_adjust(wspace=0.05, hspace=0.05, top=0.85, bottom=0.25)

        first_line_by_label = {}
        max_x_vals = []
        min_x_vals = []
        for column, (ax, (title, errors)) in enumerate(zip(axes, panels)):
            handles, labels, panel_max_x, panel_min_x = _plot_subplot_ecdf(
                ax,
                errors,
                title,
                METHODS_TO_PLOT,
                LEGEND_MAPPING,
                show_yticks=(column == 0),  # y labels on the leftmost panel only
                show_xticks=True,
            )
            max_x_vals.append(panel_max_x)
            min_x_vals.append(panel_min_x)
            for handle, label in zip(handles, labels):
                first_line_by_label.setdefault(label, handle)

        # Share BOTH x-limits across panels. Panels without data report
        # (0.0, inf), which max()/min() ignore as long as one panel has data.
        global_max_x = max(max_x_vals) * 1.05
        global_min_x = min(min_x_vals) * 0.5
        if np.isfinite(global_min_x) and global_max_x > global_min_x:
            for ax in axes:
                ax.set_xlim(left=global_min_x, right=global_max_x)

        # Set custom x-ticks for all subplots (only 10^-3 and 10^-2)
        for ax in axes:
            ax.set_xticks([1e-3, 1e-2])
            ax.set_xticklabels([r"$10^{-3}$", r"$10^{-2}$"])

        # Legend entries are thicker copies of the plotted lines. Colour and
        # line style are read from the plotted line itself, so the legend
        # cannot drift from the curves.
        custom_handles = []
        custom_labels = []
        for label in sorted(first_line_by_label):
            plotted_line = first_line_by_label[label]
            custom_handles.append(
                plt.Line2D(
                    [0],
                    [0],
                    color=plotted_line.get_color(),
                    linewidth=4.0,
                    linestyle=plotted_line.get_linestyle(),
                    marker="o",
                    markersize=6,
                )
            )
            custom_labels.append(label)

        # Skip the legend when nothing was plotted (it would need ncol=0).
        if custom_handles:
            fig.legend(
                custom_handles,
                custom_labels,
                loc="lower center",
                ncol=len(custom_handles),  # All items on one line
                bbox_to_anchor=(0.5, 0.005),
                fontsize=LEGEND_FONTSIZE,
                frameon=False,
                handlelength=2.5,
                borderpad=0.3,
            )

        output_path = Path(output_plot_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        plot_filename = (
            output_path / f"model_comparison_{MODEL_LEFT}_vs_{MODEL_RIGHT}.pdf"
        )
        fig.savefig(plot_filename, bbox_inches="tight", format="pdf")
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)
    print(f"Saved comparison plot to: {plot_filename}")


def main():
    """Parse the command line, apply the font settings and build the plot."""
    cli = argparse.ArgumentParser(
        description="Generate side-by-side comparison plot of CIFAR and ImageNet models."
    )
    # (flag, default value, help text) for the two directory options.
    directory_options = (
        (
            "--logs-root-dir",
            "./experiment_cdf_data",
            "Path to the root directory containing model-specific ECDF data subdirectories.",
        ),
        (
            "--output-plot-dir",
            "./ecdf_plots_comparison",
            "Directory to save the comparison plot.",
        ),
    )
    for flag, default_value, help_text in directory_options:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)

    options = cli.parse_args()

    # Global matplotlib defaults so every text element uses the configured
    # font sizes.
    font_settings = {
        "font.size": TICK_FONTSIZE,
        "axes.titlesize": LEGEND_FONTSIZE,
        "axes.labelsize": TICK_FONTSIZE,
        "xtick.labelsize": TICK_FONTSIZE,
        "ytick.labelsize": TICK_FONTSIZE,
        "legend.fontsize": LEGEND_FONTSIZE,
    }
    plt.rcParams.update(font_settings)

    plot_models_comparison(options.logs_root_dir, options.output_plot_dir)
    print("Comparison plot generation completed.")


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()
