import numpy as np
from pathlib import Path
import argparse

# import matplotlib # No longer needed for backend switching
import matplotlib.pyplot as plt
from collections import defaultdict
import json  # For reading metadata
from datetime import datetime

# Removed: from utility_cal import plot_combined_uc_ecdfs

# --- Configuration Globals ---
# Default value of the --main-title CLI option. The title is parsed and passed
# along, but is currently not rendered on the figure.
DEFAULT_MAIN_FIGURE_TITLE = "ECDF of Utility Calibration Errors"
# Font size for tick labels and for the "No data" placeholder text.
TICK_FONTSIZE = 24
# Font size for the shared legend and for the subplot titles.
LEGEND_FONTSIZE = 24

# Method names to plot. A log file contributes to the figure only if the method
# name parsed from its file stem (the stem with "_linear_errors" or
# "_rank_errors" removed) appears in this list. Comment an entry out to hide
# that method without deleting its log files.
METHODS_TO_PLOT = [
    "Uncalibrated",
    "TemperatureScaling",
    "VectorScaling",
    "MatrixScaling",
    # "DirichletODIR",
    "DirichletL2",
    "EnsembleTemperatureScaling",
    "IROvA",
    # "IROvATS",
    "IRM",
    # "IRMTS",
    # "MatrixScalingODIR",

    "PostHocUC_Union_iters125_step_fixed_0.01_sub500",
    # Add or remove methods as needed
]

# Mapping from method names (as listed in METHODS_TO_PLOT) to the label shown
# in the legend. A method without an entry here is labelled with its raw name.
# Entries for methods that are currently commented out above are harmless.
LEGEND_MAPPING = {
    "Uncalibrated": "Uncalibrated",
    "TemperatureScaling": "Temp. Scaling",
    "VectorScaling": "Vector Scaling",
    "MatrixScaling": "Matrix Scaling",
    "DirichletODIR": "Dirichlet (ODIR)",
    "DirichletL2": "Dirichlet",
    "EnsembleTemperatureScaling": "Ens. Temp. Scaling",
    "IROvA": "IROvA",
    "IROvATS": "IROvA-TS",
    "IRM": "IRM",
    "IRMTS": "IRM-TS",
    "MatrixScalingODIR": "Matrix Scaling (ODIR)",
    "PostHocUC_Union_iters125_step_fixed_0.01_sub500": "Patch",
}

# Color-blind friendly palette with distinct, high-contrast colors. Curves pick
# a color by index modulo the palette length.
COLORBLIND_PALETTE = [
    "#0072B2",  # blue
    "#D55E00",  # vermillion/red
    "#009E73",  # green
    "#CC79A7",  # pink
    "#F0E442",  # yellow
    "#56B4E9",  # light blue
    "#E69F00",  # orange
    "#000000",  # black
]

# Line styles cycled alongside the palette for additional differentiation.
# NOTE: the palette has 8 entries and this list has 4, so when both are indexed
# by the same counter the (color, line style) pair repeats every 8 entries;
# a 9th curve looks identical to the 1st.
LINE_STYLES = ["-", "--", "-.", ":"]


# --- Local Plotting Function for a single subplot ---
def _plot_subplot_ecdf(
    ax,
    errors_dict,
    subplot_title_prefix,
    model_name_for_title,
    dataset_name_for_title,
    methods_to_plot_config,
    legend_mapping_config,
    remove_yticks=False,
):
    """Draw one ECDF curve per selected method onto ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        errors_dict: Mapping of method name to an array-like of error values.
            Values of any shape are accepted; they are flattened to a 1-D
            sample before the ECDF is computed.
        subplot_title_prefix: Text used as the subplot title and in log
            messages.
        model_name_for_title: Unused; kept so existing callers keep working.
        dataset_name_for_title: Unused; kept so existing callers keep working.
        methods_to_plot_config: Container of method names that may be plotted.
            Keys of ``errors_dict`` not in it are skipped with a message.
        legend_mapping_config: Mapping of method name to legend label. Methods
            without an entry are labelled with their raw name.
        remove_yticks: If true, remove the y-axis ticks from ``ax``.

    Returns:
        A ``(handles, labels)`` tuple of parallel lists: the plotted line
        artists and their legend labels, in plotting order. Both are empty
        when nothing was plotted.
    """
    ax.set_title(
        f"{subplot_title_prefix}",
        fontsize=LEGEND_FONTSIZE,
    )
    ax.grid(True, linestyle="--", alpha=0.7)
    ax.set_ylim(0, 1.05)  # ECDF values live in [0, 1]; leave a little headroom
    ax.tick_params(axis="both", which="major", labelsize=TICK_FONTSIZE)

    if remove_yticks:
        ax.set_yticks([])

    handles = []
    labels = []
    max_x_val = 0.0

    # The style index is the position among *all* sorted keys (skipped ones
    # included), so a method's color does not shift when another is filtered.
    # NOTE: with the current 8-color palette and 4 line styles the
    # (color, style) pair repeats after 8 keys.
    for i, method_name in enumerate(sorted(errors_dict)):
        if method_name not in methods_to_plot_config:
            print(
                f"Skipping method '{method_name}' for subplot '{subplot_title_prefix}' as it's not in METHODS_TO_PLOT config."
            )
            continue

        # Flatten so the ECDF is always computed over a 1-D sample. Without
        # this, an (n, 1) array is "sorted" along its length-1 last axis (i.e.
        # not sorted at all) and 2-D / 0-d inputs fail outright.
        errors = np.asarray(errors_dict[method_name]).ravel()
        if errors.size == 0:
            print(
                f"No errors to plot for '{subplot_title_prefix}' method: {method_name}."
            )
            continue

        legend_label = legend_mapping_config.get(method_name, method_name)
        x_ecdf = np.sort(errors)
        y_ecdf = np.arange(1, errors.size + 1) / errors.size

        (line,) = ax.plot(
            x_ecdf,
            y_ecdf,
            marker="o",
            linestyle=LINE_STYLES[i % len(LINE_STYLES)],
            label=legend_label,
            markersize=3,
            linewidth=3.0,
            solid_capstyle="round",
            alpha=0.8,
            color=COLORBLIND_PALETTE[i % len(COLORBLIND_PALETTE)],
        )
        handles.append(line)
        labels.append(legend_label)

        # x_ecdf is sorted, so its last element is this method's largest error.
        current_max_x = x_ecdf[-1]
        if current_max_x > max_x_val:
            max_x_val = current_max_x

    if handles:
        # Dynamic X limit: 5% padding past the largest error seen.
        ax.set_xlim(left=0, right=max_x_val * 1.05 if max_x_val > 0 else 1.0)
    else:
        ax.text(
            0.5,
            0.5,
            "No data for selected methods",
            horizontalalignment="center",
            verticalalignment="center",
            transform=ax.transAxes,
            fontsize=TICK_FONTSIZE,
        )
        ax.set_xlim(left=0, right=1.0)  # Default X limit if nothing plotted
    return handles, labels


def _load_selected_errors(log_dir):
    """Load linear/rank error arrays in ``log_dir`` for methods in METHODS_TO_PLOT.

    Returns ``(linear_errors, rank_errors, found_any)`` where the first two map
    method name to the loaded array and ``found_any`` tells whether the
    directory contained any ``.npy`` file at all.
    """
    linear_errors = {}
    rank_errors = {}
    found_any = False

    # Sorted so that the processing (and log) order is deterministic.
    for log_file in sorted(log_dir.glob("*.npy")):
        found_any = True
        stem = log_file.stem
        if "_linear_errors" in stem:
            kind, target = "linear", linear_errors
            method_name = stem.replace("_linear_errors", "")
        elif "_rank_errors" in stem:
            kind, target = "rank", rank_errors
            method_name = stem.replace("_rank_errors", "")
        else:
            continue
        if method_name not in METHODS_TO_PLOT:
            continue

        # Only read the file once we know it will actually be plotted.
        errors_data = np.load(log_file)
        target[method_name] = errors_data
        print(
            f"  Loaded {kind} errors for '{method_name}' (shape: {errors_data.shape})"
        )

    return linear_errors, rank_errors, found_any


def _read_dataset_name(metadata_path, fallback):
    """Return ``dataset_name`` from the optional metadata JSON, else ``fallback``."""
    if not metadata_path.exists():
        return fallback
    try:
        with open(metadata_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)
    except (OSError, ValueError) as exc:
        # Metadata is optional, so an unreadable or malformed file should not
        # abort plotting. json.JSONDecodeError is a ValueError subclass.
        print(
            f"  Warning: could not read metadata from {metadata_path} ({exc}). Using '{fallback}' as dataset name."
        )
        return fallback
    if not isinstance(metadata, dict):
        return fallback
    return metadata.get("dataset_name", fallback)


def _draw_empty_subplot(ax, title, remove_yticks=False):
    """Render a titled placeholder subplot saying there is no data."""
    ax.set_title(title, fontsize=LEGEND_FONTSIZE)
    ax.text(
        0.5,
        0.5,
        "No data for selected methods",
        horizontalalignment="center",
        verticalalignment="center",
        transform=ax.transAxes,
        fontsize=TICK_FONTSIZE,
    )
    ax.set_xlim(left=0, right=1.0)  # Default if no data
    ax.set_ylim(0, 1.05)
    ax.tick_params(axis="both", which="major", labelsize=TICK_FONTSIZE)
    if remove_yticks:
        ax.set_yticks([])


def generate_plots_for_single_model(
    model_name, logs_root_dir, output_base_plot_dir, main_title_config
):
    """Generate the combined ECDF figure for one model from pre-computed logs.

    Reads ``<logs_root_dir>/<model_name>/cdf_plot_logs/*.npy``, draws the
    rank-based errors on the left subplot and the linear errors on the right,
    adds one shared legend below both, and saves the figure as
    ``<output_base_plot_dir>/<model_name>/<dataset>_<model_name>_combined_uc_ecdf.pdf``.
    The dataset name comes from the optional ``_metadata.json`` in the log
    directory and falls back to ``model_name``.

    Args:
        model_name: Name of the model directory under ``logs_root_dir``.
        logs_root_dir: Root directory holding one subdirectory per model.
        output_base_plot_dir: Base directory under which the plot is saved.
        main_title_config: Unused; the figure currently has no main title.
            Kept so existing callers keep working.

    Returns:
        None. Progress and skip reasons are printed. Nothing is saved when the
        log directory is missing, holds no ``.npy`` files, or holds none for
        the methods in METHODS_TO_PLOT.
    """
    model_log_data_dir = Path(logs_root_dir) / model_name / "cdf_plot_logs"
    if not model_log_data_dir.is_dir():
        print(
            f"Log directory not found for model '{model_name}' at {model_log_data_dir}. Skipping."
        )
        return

    print(f"\nProcessing model: {model_name}")
    print(f"Reading ECDF data from: {model_log_data_dir}")

    linear_errors, rank_errors, found_data_files = _load_selected_errors(
        model_log_data_dir
    )

    if not found_data_files:
        print(
            f"No .npy ECDF data files found in {model_log_data_dir}. Skipping plot generation for {model_name}."
        )
        return
    if not linear_errors and not rank_errors:
        print(
            f"No data loaded for methods specified in METHODS_TO_PLOT for model {model_name}. Skipping plot generation."
        )
        return

    dataset_name_from_meta = _read_dataset_name(
        model_log_data_dir / "_metadata.json", model_name
    )

    model_output_plot_dir = Path(output_base_plot_dir) / model_name
    model_output_plot_dir.mkdir(parents=True, exist_ok=True)
    plot_filename = (
        model_output_plot_dir
        / f"{dataset_name_from_meta}_{model_name}_combined_uc_ecdf.pdf"
    )

    # 1 row, 2 columns; wide enough to leave room for the shared legend.
    fig, axes = plt.subplots(1, 2, figsize=(20, 8))
    try:
        ax_rank, ax_linear = axes[0], axes[1]

        # label -> first plotted line with that label (rank subplot wins),
        # which keeps the shared legend free of duplicates.
        legend_entries = {}

        subplot_specs = (
            (ax_rank, rank_errors, "Rank-Based Utility Errors", False),
            # The right plot shares the ECDF y-scale, so its ticks are removed.
            (ax_linear, linear_errors, "Linear Utility Errors", True),
        )
        for ax, errors, title, hide_yticks in subplot_specs:
            if not errors:
                _draw_empty_subplot(ax, title, remove_yticks=hide_yticks)
                continue
            handles, labels = _plot_subplot_ecdf(
                ax,
                errors,
                title,
                model_name,
                dataset_name_from_meta,
                METHODS_TO_PLOT,
                LEGEND_MAPPING,
                remove_yticks=hide_yticks,
            )
            for handle, label in zip(handles, labels):
                legend_entries.setdefault(label, handle)

        if legend_entries:
            # Sort by label for a consistent display order in the legend.
            custom_labels = sorted(legend_entries)
            # Build thicker proxy lines for legibility, but copy color and
            # line style from the line that was actually plotted so the legend
            # can never disagree with the curves.
            custom_handles = [
                plt.Line2D(
                    [0],
                    [0],
                    color=legend_entries[label].get_color(),
                    linewidth=4.0,  # Extra thick for legend
                    linestyle=legend_entries[label].get_linestyle(),
                    marker="o",
                    markersize=6,  # Larger markers in legend
                )
                for label in custom_labels
            ]

            fig.legend(
                custom_handles,
                custom_labels,
                loc="lower center",
                ncol=min(len(custom_handles), 4),
                bbox_to_anchor=(0.5, -0.08),
                fontsize=LEGEND_FONTSIZE,
                frameon=True,
                fancybox=True,
                handlelength=3.0,  # Not too long
                borderpad=0.4,
            )

        # Reserve the bottom 5% of the figure for the shared legend.
        fig.tight_layout(rect=[0, 0.05, 1, 1])

        # Save this figure explicitly rather than whatever pyplot considers
        # the current one.
        fig.savefig(plot_filename, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)
    print(f"  Saved combined ECDF plot to: {plot_filename}")


def main():
    """Command-line entry point: build ECDF plots for one model or for all.

    Parses the CLI options, applies the global matplotlib font sizes, and then
    either processes the single model given by ``--model-name`` or, when that
    option is empty, every subdirectory of ``--logs-root-dir`` that contains a
    ``cdf_plot_logs`` directory. Returns None; problems are reported by
    printing a message.
    """
    parser = argparse.ArgumentParser(
        description="Generate combined ECDF plots from pre-computed utility calibration error logs."
    )
    parser.add_argument(
        "--logs-root-dir",
        type=str,
        default="./experiment_cdf_data",
        help="Path to the root directory containing model-specific ECDF data subdirectories. (Default: ./experiment_cdf_data).",
    )
    parser.add_argument(
        "--model-name",
        type=str,
        default="",
        help="Specific model name (directory name under logs-root-dir) to generate plots for. Set to empty string to process all models.",
    )
    parser.add_argument(
        "--output-plot-dir",
        type=str,
        default=None,
        help="Base directory to save generated plots. (Default: ./ecdf_plots_from_logs).",
    )
    parser.add_argument(
        "--main-title",
        type=str,
        default=DEFAULT_MAIN_FIGURE_TITLE,
        help=f'Main title for the combined plot figure (currently unused in plot output). (Default: "{DEFAULT_MAIN_FIGURE_TITLE}")',
    )

    args = parser.parse_args()

    # Validate the input location first so that a bad path does not leave an
    # empty output directory behind.
    logs_root_path = Path(args.logs_root_dir)
    if not logs_root_path.is_dir():
        print(f"Error: Logs root directory not found: {logs_root_path}")
        return

    # Set matplotlib parameters for consistent font sizes
    plt.rcParams.update(
        {
            "font.size": TICK_FONTSIZE,
            "axes.titlesize": LEGEND_FONTSIZE,
            "axes.labelsize": TICK_FONTSIZE,
            "xtick.labelsize": TICK_FONTSIZE,
            "ytick.labelsize": TICK_FONTSIZE,
            "legend.fontsize": LEGEND_FONTSIZE,
        }
    )

    if args.output_plot_dir is None:
        output_plot_dir_resolved = Path("./ecdf_plots_from_logs")
    else:
        output_plot_dir_resolved = Path(args.output_plot_dir)

    output_plot_dir_resolved.mkdir(parents=True, exist_ok=True)
    print(f"Saving plots to subdirectories within: {output_plot_dir_resolved}")

    # Use the stripped name for the lookup too: surrounding whitespace would
    # otherwise pass the emptiness check but never match a directory.
    model_name = (args.model_name or "").strip()
    if model_name:
        generate_plots_for_single_model(
            model_name, logs_root_path, output_plot_dir_resolved, args.main_title
        )
    else:
        print(
            f"No specific model name provided or model name is empty. Scanning for all models in {logs_root_path}..."
        )
        processed_any = False
        # Sorted so that models are processed in a deterministic order.
        for model_dir_item in sorted(logs_root_path.iterdir()):
            if not (
                model_dir_item.is_dir() and (model_dir_item / "cdf_plot_logs").is_dir()
            ):
                continue
            generate_plots_for_single_model(
                model_dir_item.name,
                logs_root_path,
                output_plot_dir_resolved,
                args.main_title,
            )
            processed_any = True
        if not processed_any:
            print("No model directories with 'cdf_plot_logs' found to process.")

    print("\nECDF plot generation from logs process completed.")


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
