import pandas as pd
import os
import csv
import argparse
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.lines import Line2D
import matplotlib.cm as cm

from config.config import RESULTS_TEST_INTERNAL, RESULTS_TEST_EXTERNAL

# === Global matplotlib font sizes ===
# Applied once when this module is imported, so every figure created below
# (titles, axis labels, tick labels, legends) shares the same typography.
_FONT_SIZES = {
    'axes.titlesize': 16,
    'axes.labelsize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 12,
    'figure.titlesize': 16,
}
plt.rcParams.update(_FONT_SIZES)

def _load_results_csv(csv_path: str) -> pd.DataFrame:
    """Read the model-wise results CSV, auto-detecting its delimiter.

    The first 2048 characters are given to ``csv.Sniffer``. If the sniffer
    cannot decide (it raises ``csv.Error``, e.g. for a single-column file),
    a comma is assumed. Column names are stripped of surrounding whitespace.

    Raises:
        FileNotFoundError: if ``csv_path`` does not exist.
    """
    with open(csv_path, "r", encoding="utf-8-sig") as f:
        sample = f.read(2048)
        f.seek(0)
        try:
            # Limit candidates to common separators so the sniffer cannot pick
            # an unusual character that merely occurs regularly in the data.
            delimiter = csv.Sniffer().sniff(sample, delimiters=",;\t|").delimiter
        except csv.Error:
            delimiter = ","
        print(f"Detected delimiter: {repr(delimiter)}")
        df = pd.read_csv(f, delimiter=delimiter)
    df.columns = df.columns.str.strip()
    return df


def _plot_model(model_name, model_df: pd.DataFrame, color: tuple,
                dataset_type: str, output_dir: str) -> str:
    """Draw one model's accuracy-vs-coverage figure and save it as a PNG.

    Each row of ``model_df`` becomes a ±1 std rectangle (translucent fill plus
    dashed outline) centred on its (mean coverage, mean accuracy) point.
    Returns the path of the written file. The figure is always closed, even
    if drawing or saving fails.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        x = model_df["mean_subset_coverage"]
        y = model_df["mean_accuracy"]
        # A missing std (NaN) is drawn as a zero-size interval instead of
        # propagating NaN into the patch geometry and the axis limits.
        x_err = model_df["std_subset_coverage"].fillna(0.0)
        y_err = model_df["std_accuracy"].fillna(0.0)

        for cx, cy, dx, dy in zip(x, y, x_err, y_err):
            corner = (cx - dx, cy - dy)
            width, height = 2 * dx, 2 * dy

            # Translucent fill. The edge is 'none' explicitly (None would defer
            # to rcParams, e.g. patch.force_edgecolor could add a solid border).
            ax.add_patch(Rectangle(
                corner, width, height,
                facecolor=color, edgecolor='none', alpha=0.3, zorder=1,
            ))
            # Dashed outline drawn on top of the fill.
            ax.add_patch(Rectangle(
                corner, width, height,
                facecolor='none', edgecolor=color, linestyle='--',
                linewidth=1, zorder=2,
            ))
            # Centre point (mean coverage, mean accuracy).
            ax.plot(cx, cy, 'o', color=color, markersize=4, zorder=3)

        ax.set_xlabel("Coverage")
        ax.set_ylabel("Mean Accuracy")
        ax.set_title(f"Accuracy vs Coverage: {model_name}")
        # Pad from the rectangles' left edges (not just the centres) so no
        # interval is clipped; coverage cannot exceed 1, so cap the right at 1.
        ax.set_xlim(left=(x - x_err).min() - 0.02, right=1.0)
        # Keep the customary 0.55 floor, but lower it when an interval reaches
        # below it (written so an all-NaN column still falls back to 0.55).
        y_low = (y - y_err).min()
        ax.set_ylim(bottom=y_low - 0.02 if y_low < 0.55 else 0.55)
        ax.grid(True, linestyle=":", linewidth=0.8)
        ax.tick_params(axis='both', which='major', labelsize=12)

        # === Custom Legend ===
        # The marker uses the model colour for face and edge, matching the
        # plotted dots; linestyle='none' suppresses the legend line segment.
        legend_elements = [
            Line2D([0], [0], marker='o', color=color, linestyle='none',
                   label='Mean coverage/accuracy',
                   markerfacecolor=color, markersize=6),
            Rectangle((0, 0), 1, 1, facecolor=color, alpha=0.3,
                      edgecolor=color, linestyle='--', linewidth=1,
                      label='±1 std interval'),
        ]
        ax.legend(handles=legend_elements, fontsize=12)

        fig.tight_layout()

        # === Save with dataset suffix ===
        safe_name = f"{model_name}_accuracy_vs_coverage_{dataset_type}".replace("/", "_")
        save_path = os.path.join(output_dir, f"{safe_name}.png")
        fig.savefig(save_path, dpi=300)
    finally:
        plt.close(fig)
    return save_path


def main(dataset_type: str):
    """Plot per-model accuracy-vs-coverage rectangles for one test set.

    Args:
        dataset_type: "EXTERNAL" selects ``RESULTS_TEST_EXTERNAL``; any other
            value uses ``RESULTS_TEST_INTERNAL``. The value is also appended
            to every output file name.

    Reads ``<results_root>/<combo_dir>/final_modelwise_coverage_accuracy_results.csv``
    and writes one PNG per distinct ``model`` value into
    ``<results_root>/plot_combinations_rectangles_colored`` (created if needed).

    Raises:
        FileNotFoundError: if the results CSV does not exist.
        KeyError: if an expected column (model, mean/std coverage/accuracy)
            is missing.
    """
    # === Choose roots based on dataset type ===
    if dataset_type == "EXTERNAL":
        results_root = RESULTS_TEST_EXTERNAL
        combo_dirname = "combination_external_test"
    else:  # INTERNAL (default)
        results_root = RESULTS_TEST_INTERNAL
        combo_dirname = "combination_internal_test"

    target_dir = os.path.join(results_root, combo_dirname)
    os.makedirs(target_dir, exist_ok=True)

    # === Define paths ===
    csv_path = os.path.join(target_dir, "final_modelwise_coverage_accuracy_results.csv")
    output_dir = os.path.join(results_root, "plot_combinations_rectangles_colored")
    os.makedirs(output_dir, exist_ok=True)

    # === Load, clean and sort ===
    df = _load_results_csv(csv_path)
    df = df.sort_values(by="mean_subset_coverage")
    print(df)
    print("Columns:", df.columns.tolist())

    # === Assign distinct colors using tab20c colormap ===
    unique_models = df["model"].unique()
    # pyplot.get_cmap is used because matplotlib.cm.get_cmap was removed in
    # Matplotlib 3.9. lut is clamped to >= 1 so an empty CSV does not request
    # a zero-entry colormap.
    cmap = plt.get_cmap("tab20c", max(len(unique_models), 1))
    model_colors = {model: cmap(i) for i, model in enumerate(unique_models)}

    # === Plotting ===
    for model_name, model_df in df.groupby("model"):
        save_path = _plot_model(
            model_name, model_df, model_colors[model_name], dataset_type, output_dir
        )
        print(f"Saved: {save_path}")

if __name__ == "__main__":
    # Command-line entry point: parse the optional dataset type and run main().
    parser = argparse.ArgumentParser(
        description="Plot model-wise accuracy vs coverage rectangles for INTERNAL/EXTERNAL datasets."
    )
    parser.add_argument(
        "dataset_type",
        nargs="?",
        default="INTERNAL",
        # Upper-case the value before the choices check so "external" or
        # "Internal" are accepted; argparse also applies this to the default.
        type=str.upper,
        choices=["INTERNAL", "EXTERNAL"],
        help="Dataset type: INTERNAL or EXTERNAL, case-insensitive (default: INTERNAL)"
    )
    args = parser.parse_args()
    main(args.dataset_type)
