import os
import json
import csv
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import cm
import numpy as np
from config.config import RESULTS_TRAIN
# Matplotlib settings: large fonts for titles, axis labels and tick labels.
# The legend font is deliberately much smaller than the rest (it was 20).
_FONT_SIZES = {
    'axes.titlesize': 26,
    'axes.labelsize': 22,
    'xtick.labelsize': 20,
    'ytick.labelsize': 20,
    'legend.fontsize': 10,
    'figure.titlesize': 26,
}
for _rc_key, _rc_size in _FONT_SIZES.items():
    plt.rcParams[_rc_key] = _rc_size

# Define paths: the input CSV lives under <RESULTS_TRAIN>/integrals, and the
# generated figures go to <RESULTS_TRAIN>/images (created if missing).
integrals_dir = os.path.join(RESULTS_TRAIN, "integrals")
csv_path = os.path.join(integrals_dir, "integrals_by_model_and_embedding.csv")
output_dir = os.path.join(RESULTS_TRAIN, "images")
os.makedirs(output_dir, exist_ok=True)

# Auto-detect delimiter and load CSV.
# The sniffer is restricted to common delimiter characters. The sample is a
# fixed 2048-byte prefix (its last line may be truncated), and the values
# contain underscores (e.g. "mobilenet_v2"). An unrestricted csv.Sniffer can
# therefore pick an arbitrary character or fail outright. When sniffing
# raises csv.Error, fall back to a comma instead of aborting.
# 'utf-8-sig' strips a leading BOM so the first header name is not polluted.
with open(csv_path, "r", encoding='utf-8-sig') as f:
    sample = f.read(2048)
    f.seek(0)
    try:
        delimiter = csv.Sniffer().sniff(sample, delimiters=",;\t|").delimiter
    except csv.Error:
        delimiter = ","
    print(f"Detected delimiter: {repr(delimiter)}")
    df = pd.read_csv(f, delimiter=delimiter)

# Clean column names: strip surrounding whitespace from every header, then
# echo the column list and the full table as a quick sanity check.
df.columns = df.columns.map(str.strip)
column_names = df.columns.tolist()
print("Columns:", column_names)
print(df)
# Working embedding-model keys -> the labels used in the LaTeX tables.
name_map = {
    "mobilenet_v2": "MobileNet-V2",
    "dinov1_s": "DINO-V1 ViT-S/16",
    "dinov1_b": "DINO-V1 ViT-B/16",
    "dinov1_b8": "DINO-V1 ViT-B/8",
    "dinov2_s": "DINO-V2 ViT-S/14",
    "dinov2_b": "DINO-V2 ViT-B/14",
}

# Swap each known key for its display label; values not in the map
# (including missing values) are left unchanged.
df["emb_model"] = df["emb_model"].map(lambda key: name_map.get(key, key))

# Create color palette: one colour per embedding model, assigned in sorted
# name order so a given embedding model has the same colour in every figure.
unique_emb_models = sorted(df["emb_model"].unique())
n_colors = len(unique_emb_models)
# matplotlib.cm.get_cmap was deprecated in 3.7 and removed in 3.9;
# pyplot.get_cmap is the supported equivalent. The colormap is resampled to
# n_colors entries (at least 1: a zero-entry colormap cannot be built when
# the data is empty), so integer index i selects exactly the i-th entry,
# which is the same colour the previous evenly-spaced float lookup returned.
cmap = plt.get_cmap("tab20c", max(n_colors, 1))
color_dict = {
    emb_model: cmap(i)
    for i, emb_model in enumerate(unique_emb_models)
}

# Legend entries shared by all figures: one handle per embedding model.
# Entries are collected across *every* model. The old code only recorded the
# first model's lines, which dropped any embedding model that did not appear
# in the first model's data.
legend_handles = []
legend_labels = []

# Plotting: one figure per model, one line per embedding model (sorted by N).
for model_name, model_df in df.groupby("model"):
    fig, ax = plt.subplots(figsize=(10, 6))

    for emb_model, emb_df in model_df.groupby("emb_model"):
        emb_df_sorted = emb_df.sort_values("N")
        line, = ax.plot(
            emb_df_sorted["N"],
            emb_df_sorted["normalized_adjusted"],
            label=emb_model,
            color=color_dict[emb_model],
            linewidth=2.0,
        )

        # Record each embedding model once, the first time it is drawn.
        if emb_model not in legend_labels:
            legend_handles.append(line)
            legend_labels.append(emb_model)

    ax.set_xlabel("N")
    ax.set_ylabel("Normalized Adjusted Integral")
    ax.grid(True, linestyle="--", linewidth=0.5)
    ax.tick_params(axis='both', which='both', direction='in')
    ax.set_xlim(0, 750)
    fig.tight_layout()

    # A "/" in the model name would otherwise be treated as a subdirectory.
    filename = f"{model_name}_normalized_adjusted_integral.png".replace("/", "_")
    save_path = os.path.join(output_dir, filename)
    fig.savefig(save_path, dpi=300)
    plt.close(fig)

# Save the legend shared by all per-model plots as its own image, with a
# smaller (10 pt) font.
fig_legend = plt.figure(figsize=(12, 1))
fig_legend.legend(
    handles=legend_handles,
    labels=legend_labels,
    loc='center',
    ncol=3,
    fontsize=10,
    frameon=False,
)
legend_path = os.path.join(output_dir, "shared_legend.png")
# bbox_inches='tight' crops the output to the legend itself. The figure has
# no Axes, so a tight_layout() call here had nothing to adjust and is omitted.
fig_legend.savefig(legend_path, dpi=300, bbox_inches='tight')
plt.close(fig_legend)

print(f"Plots saved to: {output_dir}")
print(f"Legend saved as: {legend_path}")

