import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import matplotlib.patches as patches

# -----------------------------
# Load CSV
# -----------------------------
df = pd.read_csv("results/results_RQ_3_ablations_lotka32.csv")

# Keep only the runs with window_size = 7.
df = df[df["window_size"] == 7]
# Drop rows where time_freq_representation == "mag_phase" AND
# combine_method == "freq_only": for the frequency-only branch only the
# magnitude variant is reported.
df = df[~((df["time_freq_representation"] == "mag_phase") & (df["combine_method"] == "freq_only"))]

# Metric to display
metric = "Avg@10"
param_col = "total_params"

# -----------------------------
# Aggregate (mean over all remaining rows of each configuration, e.g. seeds)
# -----------------------------
agg_df = (
    df.groupby(["architecture", "time_freq_representation", "combine_method"])[[metric, param_col]]
    .mean()
    .reset_index()
)
# Shorten the "attention" combine method to "attn" for display.
agg_df.loc[agg_df["combine_method"] == "attention", "combine_method"] = "attn"


# -----------------------------
# Map x-axis
# -----------------------------
def x_axis_map(row):
    """Return the heatmap/table column label for one aggregated result row.

    ``row`` only needs to support ``row['architecture']`` and
    ``row['time_freq_representation']`` lookups (e.g. a DataFrame row passed
    by ``DataFrame.apply(..., axis=1)`` or a plain dict).

    The ``TemporalGNN_Attention`` architecture always maps to ``"T"``.  Any
    other row is labelled by its time/frequency representation; a
    representation without a known label is returned unchanged.
    """
    if row['architecture'] == "TemporalGNN_Attention":
        return "T"

    representation = row['time_freq_representation']
    display_labels = {
        "normal": "T-F (Mag)",
        "learnable_filter": "Learnable",
        "mag_phase": "T-F (Mag-Phase)",
        "mag_phase_learnable_filter": "Mag-Phase-Learnable",
        "Freq": "F (Mag)",
    }
    return display_labels.get(representation, representation)
# Frequency-only runs are relabelled "Freq" so x_axis_map() files them under "F (Mag)".
freq_only_rows = agg_df["combine_method"].eq("freq_only")
agg_df.loc[freq_only_rows, "time_freq_representation"] = "Freq"
agg_df["x_axis"] = agg_df.apply(x_axis_map, axis=1)

# Heatmap rows are the combine method, except for the overrides below.
agg_df["y_axis"] = agg_df["combine_method"]
crossattn_rows = agg_df["architecture"].str.contains("TemporalGNN_Attention_crossattn")
agg_df.loc[crossattn_rows, "y_axis"] = "attn"
# The "T" and "F (Mag)" columns get a blank row label.
agg_df.loc[agg_df["x_axis"].isin(["T", "F (Mag)"]), "y_axis"] = ""
# -----------------------------
# Pivot for heatmap
# -----------------------------
# One cell per (row label, column label); DataFrame.pivot raises if a pair repeats.
heatmap_data = agg_df.pivot(index="y_axis", columns="x_axis", values=metric)

# Fixed display order of columns and rows; combinations absent from the data become NaN.
x_order = ["T","F (Mag)", "T-F (Mag)", "T-F (Mag-Phase)"]
y_order = ["","sum","gated","concat","attn"]
heatmap_data = heatmap_data.reindex(index=y_order, columns=x_order)

# -----------------------------
# Plot
# -----------------------------
plt.figure(figsize=(10, 4))

heatmap_style = dict(
    annot=True,
    fmt=".3f",
    cmap="YlGnBu",
    annot_kws={"size": 16},
    cbar_kws={'label': metric},
    linewidths=1.2,      # grid line thickness between cells
    linecolor="black",   # grid line colour
)
ax = sns.heatmap(heatmap_data, **heatmap_style)

# Thick outer frame: starts at (0, 0) and spans #columns x #rows in the
# heatmap's data coordinates.
n_rows, n_cols = heatmap_data.shape
outer_frame = patches.Rectangle((0, 0), n_cols, n_rows, fill=False, edgecolor="black", lw=2.5)
ax.add_patch(outer_frame)

# No axis titles; enlarged tick labels instead.
for axis_name, label_size in (("x", 20), ("y", 13)):
    ax.tick_params(axis=axis_name, labelsize=label_size)
ax.set_xlabel("", fontsize=14)
ax.set_ylabel("", fontsize=14)
plt.tight_layout()

plt.savefig("Scripts/Data/CrGSTA/RQ3_ablations/heatmap_ablations_lotka.pdf")


# -----------------------------
# Plot 2: Bar chart (Params per combine method)
# -----------------------------
plt.figure(figsize=(10, 6))
palette = sns.color_palette("Set2")

order = ["T","F (Mag)", "T-F (Mag)", "T-F (Mag-Phase)"]
# Express parameter counts in millions (note: rescales agg_df in place).
agg_df[param_col] = agg_df[param_col] / 1e6
ax = sns.barplot(
    data=agg_df,
    x="x_axis",
    y=param_col,
    hue="y_axis",
    palette=palette,
    order=order,
)

# Axis titles
ax.set_xlabel("Architecture / Representation", fontsize=15)
ax.set_ylabel("Num Params (Millions)", fontsize=15)

# Values are in millions, so keep decimals: an int() cast would truncate
# fractional ticks (e.g. 0.5 -> "0") and produce wrong/duplicate labels.
ax.yaxis.set_major_formatter(mtick.FuncFormatter(lambda x, _: f"{x:,.2f}"))
ax.tick_params(axis='x', labelsize=15)
ax.tick_params(axis='y', labelsize=15)

# Legend formatting
plt.legend(title="Combine Method", title_fontsize=11, fontsize=15, loc="upper left")

# Clean style
sns.despine(right=True, top=True)
ax.grid(axis="y", linestyle="--", alpha=0.6)

# tight_layout has to run before savefig to affect the written file.
plt.tight_layout()
plt.savefig("Scripts/Data/CrGSTA/RQ3_ablations/bar_params_ablations_lotka.pdf", bbox_inches='tight')
plt.show()




# -----------------------------
# Build LaTeX table (row = combination, cols = all metrics)
# -----------------------------
metrics = ["AC@1", "AC@3", "AC@5", "AC@10", "Avg@10"]

# Average across seeds
agg_df = (
    df.groupby(["architecture", "time_freq_representation", "combine_method"])[metrics + [param_col]]
    .mean()
    .reset_index()
)

# Same labelling as the heatmap. Freq-only runs must be relabelled "Freq" so
# x_axis_map() returns "F (Mag)"; an unknown label such as "Freq (Mag)" would
# be passed through unchanged and the "None" override below would miss it.
agg_df.loc[agg_df["combine_method"] == "attention", "combine_method"] = "attn"
agg_df.loc[agg_df["combine_method"] == "freq_only", "time_freq_representation"] = "Freq"
agg_df["x_axis"] = agg_df.apply(x_axis_map, axis=1)
agg_df["y_axis"] = agg_df["combine_method"]
# The "T" and "F (Mag)" columns have no combination method.
agg_df.loc[agg_df["x_axis"].isin(["T", "F (Mag)"]), "y_axis"] = "None"
agg_df.loc[agg_df["architecture"].str.contains("TemporalGNN_Attention_crossattn"), "y_axis"] = "attn"


# Remove x_axis = Learnable / Mag-Phase-Learnable and y_axis = attention2.
agg_df = agg_df[~agg_df["x_axis"].isin(["Learnable", "Mag-Phase-Learnable"])]
agg_df = agg_df[~agg_df["y_axis"].isin(["attention2"])]
# Sort by Avg@10 ascending, so the best configuration ends up in the last row.
agg_df = agg_df.sort_values(by="Avg@10", ascending=True).reset_index(drop=True)

# Params column (M)
agg_df["Params"] = agg_df[param_col].apply(lambda x: f"{x/1e6:.1f}M")


def _highlight_top2(values):
    """Format *values* with 3 decimals; bold the largest, underline the second largest.

    NaN entries are never highlighted. Ties go to the earlier row (the order
    used by ``Series.nlargest(keep="first")``). Returns a new string Series
    with the same index.
    """
    formatted = values.map(lambda v: f"{v:.3f}")
    top = values.nlargest(2).index
    if len(top) >= 1:
        formatted.loc[top[0]] = r"\textbf{" + formatted.loc[top[0]] + "}"
    if len(top) >= 2:
        formatted.loc[top[1]] = r"\underline{" + formatted.loc[top[1]] + "}"
    return formatted


# Bold + underline per metric (global best/second best). Replacing whole
# columns avoids writing strings into float columns cell by cell, which
# pandas deprecates.
for m in metrics:
    agg_df[m] = _highlight_top2(agg_df[m])

# -----------------------------
# Build LaTeX
# -----------------------------
def _latex_escape(text):
    """Escape characters that are special in LaTeX text mode (e.g. ``_`` in ``freq_only``)."""
    specials = {"&": r"\&", "%": r"\%", "$": r"\$", "#": r"\#", "_": r"\_"}
    return "".join(specials.get(ch, ch) for ch in str(text))


# Three left-aligned text columns plus one centred column per metric; the
# spec must declare exactly as many columns as each row has cells.
column_spec = "lll" + "c" * len(metrics)
header_cells = ["Architecture", "Combination Method", "Params"] + metrics
latex_lines = [
    r"\begin{tabular}{" + column_spec + "}",
    r"\toprule",
    " & ".join(_latex_escape(cell) for cell in header_cells) + r" \\",
    r"\midrule",
]

for _, row in agg_df.iterrows():
    text_cells = [_latex_escape(row[col]) for col in ("x_axis", "y_axis", "Params")]
    # Metric cells may already carry LaTeX markup (\textbf / \underline): do not escape.
    values = [row[m] for m in metrics]
    latex_lines.append(" & ".join(text_cells + values) + r" \\")

latex_lines.append(r"\bottomrule")
latex_lines.append(r"\end{tabular}")

latex_table = "\n".join(latex_lines)

print(latex_table)
