import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# -----------------------------
# Load CSV
# -----------------------------
df = pd.read_csv("results/results_RQ_1_windows_lotka.csv")

metrics = ["AC@1", "AC@3", "AC@5", "AC@10", "Avg@10"]

# Correct architecture names if main_model matches: rows whose main_model is
# FEDformer / iTransformer are relabelled with that name as their architecture.
for model_name in ("FEDformer", "iTransformer"):
    df.loc[df["main_model"] == model_name, "architecture"] = model_name

# Drop window sizes 3 and 15 and the deep_mlp_old runs in one filter; .copy()
# makes the result an independent frame so the column assignment below is safe.
df = df[~df["window_size"].isin([3, 15]) & (df["architecture"] != "deep_mlp_old")].copy()

# Parameter count as a string in millions with one decimal, e.g. 12345678 -> "12.3M"
df["Params"] = df["total_params"].apply(lambda x: f"{x/1e6:.1f}M")

# -----------------------------
# Aggregate metrics
# -----------------------------
# Mean/std of each metric over all CSV rows sharing the same architecture,
# dataset, parameter-count string and window size.
group_keys = ["architecture", "dataset_name", "Params", "window_size"]
agg_df = df.groupby(group_keys)[metrics].agg(["mean", "std"]).reset_index()
# Flatten the (metric, stat) MultiIndex produced by .agg into "<metric>_<stat>".
agg_df.columns = group_keys + [f"{m}_{stat}" for m in metrics for stat in ["mean", "std"]]

# Format mean ± std for LaTeX, e.g. "0.812\tiny{±0.013}" (pandas std uses
# ddof=1, so a single-row group prints "nan" for its std). The raw mean is kept
# in "<metric>_numeric" for ranking, sorting and the heatmap.
for m in metrics:
    agg_df[m] = [
        rf"{mean:.3f}\tiny{{±{std:.3f}}}"
        for mean, std in zip(agg_df[f"{m}_mean"], agg_df[f"{m}_std"])
    ]
    agg_df[f"{m}_numeric"] = agg_df[f"{m}_mean"]

# Bold/underline best and second best per dataset. Only the architectures that
# are printed in the LaTeX table compete; otherwise a model that is filtered
# out of the table could take the top spot and the printed table would show
# no bold entry for that metric.
ranked_architectures = [
    "TemporalGNN_Attention_crossattn",
    "deep_mlp",
    "FEDformer",
    "iTransformer",
    "rcd",
    "epsilon_diagnosis",
]
is_ranked = agg_df["architecture"].isin(ranked_architectures)
for dataset in agg_df["dataset_name"].unique():
    subset = agg_df[is_ranked & (agg_df["dataset_name"] == dataset)]
    for m in metrics:
        # nlargest skips NaN means; an empty result (no ranked rows or all NaN)
        # simply leaves this metric unhighlighted instead of raising.
        top_two = subset[f"{m}_numeric"].nlargest(2)
        if top_two.empty:
            continue
        best_idx = top_two.index[0]
        agg_df.loc[best_idx, m] = r"\textbf{" + agg_df.loc[best_idx, m] + "}"
        if len(top_two) > 1:
            second_idx = top_two.index[1]
            agg_df.loc[second_idx, m] = r"\underline{" + agg_df.loc[second_idx, m] + "}"

# Sort for display: by dataset, then by AC@1 ascending (best AC@1 last in each group)
display_df = agg_df.sort_values(by=["dataset_name", "AC@1_numeric"], ascending=[True, True])

# -----------------------------
# LaTeX table
# -----------------------------
# Table header. The underscore in "window\_size" must be escaped: a bare "_"
# outside math mode is a LaTeX compile error.
header = r"""\begin{tabular}{llllllll}
\toprule
scheme & Params & window\_size & AC@1 & AC@3 & AC@5 & AC@10 & Avg@10 \\
\midrule
"""
# only view CrGSTA, AERCA, FEDformer, iTransformer, RCD, Epsilon in latex table.
# .copy() so the rename below modifies an independent frame, not a filtered view.
display_df = display_df[display_df["architecture"].isin([
    "TemporalGNN_Attention_crossattn",
    "deep_mlp",
    "FEDformer",
    "iTransformer",
    "rcd",
    "epsilon_diagnosis",
])].copy()
# rename for better display
display_df["architecture"] = display_df["architecture"].replace({
    "rcd": "RCD",
    "epsilon_diagnosis": "Epsilon",
    "deep_mlp": "AERCA",
    "FEDformer": "FEDformer",
    "iTransformer": "iTransformer",
    "TemporalGNN_Attention_crossattn": "CrGSTA",
})

latex_lines = [header]
prev_dataset = None
for _, row in display_df.iterrows():
    dataset_val = row["dataset_name"]
    if dataset_val != prev_dataset:
        # The header already ends with \midrule, so a separating rule is only
        # needed between consecutive dataset groups (avoids a doubled rule).
        if prev_dataset is not None:
            latex_lines.append(r"\midrule")
        dataset_label = str(dataset_val).upper().replace("_", r"\_")
        latex_lines.append(r"\rowcolor{gray!20} \multicolumn{8}{c}{" + dataset_label + r"} \\")
        prev_dataset = dataset_val
    scheme = row["architecture"]
    params = row["Params"]
    window_size = row["window_size"]
    values = [row[m] for m in metrics]
    latex_lines.append(f"{scheme} & {params} & {window_size} & " + " & ".join(values) + r" \\")
latex_lines.append(r"\bottomrule")
latex_lines.append(r"\end{tabular}")
latex_table = "\n".join(latex_lines)

print(latex_table)

# -----------------------------
# Heatmap Generation
# -----------------------------
# Same display names as used in the LaTeX table.
agg_df["architecture"] = agg_df["architecture"].replace({
    "rcd": "RCD",
    "epsilon_diagnosis": "Epsilon",
    "deep_mlp": "AERCA",
    "FEDformer": "FEDformer",
    "iTransformer": "iTransformer",
    "TemporalGNN_Attention_crossattn": "CrGSTA",
})

# "Params" holds strings like "12.3M"; recover the number of millions as float.
agg_df["Params_M"] = agg_df["Params"].str.replace("M", "").astype(float)

categories = {
    "Statistical": ["RCD", "Epsilon"],
    "Non-causal": ["FEDformer", "iTransformer"],
    "Causal": ["AERCA", "CrGSTA"]
}

# Loop-invariant layout: heatmap row order (models grouped by category, in dict
# order) and the cumulative row index where each category ends.
ordered_rows = [name for models in categories.values() for name in models]
category_boundaries = []
counter = 0
for models in categories.values():
    counter += len(models)
    category_boundaries.append(counter)

# Cells whose model has more than this many million parameters are hatched
# and left without a numeric annotation.
LARGE_PARAMS_M = 100

for dataset in agg_df["dataset_name"].unique():
    dataset_df = agg_df[agg_df["dataset_name"] == dataset]
    for metric in ["Avg@10"]:
        # NOTE(review): DataFrame.pivot raises if an architecture has several rows
        # for one window size (e.g. two Params configurations) — confirm that
        # cannot happen in the input CSV.
        heatmap_data = dataset_df.pivot(
            index="architecture", columns="window_size", values=f"{metric}_numeric"
        ).reindex(ordered_rows)
        heatmap_params = dataset_df.pivot(
            index="architecture", columns="window_size", values="Params_M"
        ).reindex(ordered_rows)

        # Mask large-parameter cells: pre-format every value to 3 decimals, then
        # blank the annotation wherever the model is too large (NaN params
        # compare False and stay annotated; NaN metrics get no text).
        mask_special = heatmap_params > LARGE_PARAMS_M
        annot_data = heatmap_data.apply(
            lambda col: col.map(lambda v: "" if pd.isna(v) else f"{v:.3f}")
        ).mask(mask_special, "")

        fig, ax = plt.subplots(figsize=(10, 6))
        sns.heatmap(
            heatmap_data,
            annot=annot_data.to_numpy(),
            fmt="",  # annotations are already formatted strings
            cmap="YlGnBu",
            annot_kws={"size": 17},
            cbar_kws={"label": metric},
            linewidths=0.5,
            linecolor="gray",
            ax=ax,
        )

        # Hatch overlay on large-parameter cells. facecolor and edgecolor are both
        # lightgray: this matches what the former color= argument produced
        # (it overrode edgecolor='none'), and the hatch takes the edge color.
        for i, row_mask in enumerate(mask_special.to_numpy()):
            for j, is_large in enumerate(row_mask):
                if is_large:
                    ax.add_patch(plt.Rectangle(
                        (j, i), 1, 1,
                        fill=True,
                        facecolor="lightgray",
                        edgecolor="lightgray",
                        alpha=0.6,
                        hatch="///",
                    ))

        # Outer category labels, centred on each category's rows. Text uses axes
        # coordinates, where row r (0 at the top) spans
        # [1 - (r + 1) / n_rows, 1 - r / n_rows], so its middle is at
        # 1 - (r + 0.5) / n_rows.
        n_rows = len(ordered_rows)
        start = 0
        for cat, models in categories.items():
            y_center = start + (len(models) - 1) / 2
            ax.text(-0.3, 1 - (y_center + 0.5) / n_rows, cat, rotation=90,
                    va="center", ha="center", fontsize=16, transform=ax.transAxes)
            start += len(models)

        # -----------------------------
        # Horizontal separators between categories, drawn in data coordinates
        # (x = column units, y = row index). They start 7 column-widths left of
        # the heatmap so they also span the label gutter; clip_on=False lets
        # them extend outside the axes.
        # -----------------------------
        for boundary in category_boundaries[:-1]:
            ax.plot(
                [-7, heatmap_data.shape[1]],
                [boundary, boundary],
                color="black",
                linewidth=2,
                clip_on=False,
            )

        # Axis labels and ticks (horizontal labels are centred under their ticks)
        ax.set_xlabel("Window Size", fontsize=16)
        ax.set_ylabel("")
        ax.set_xticklabels(ax.get_xticklabels(), fontsize=14, rotation=0, ha="center")
        ax.set_yticklabels(ax.get_yticklabels(), fontsize=14, rotation=0)

        cbar = ax.collections[0].colorbar
        cbar.ax.yaxis.label.set_size(16)
        cbar.ax.tick_params(labelsize=14)

        fig.tight_layout()
        fig.savefig(f"Scripts/Data/CrGSTA/heatmap_{dataset}_{metric}.pdf")
        plt.show()
        plt.close(fig)  # release the figure; one is created per dataset/metric

# -----------------------------
# Params vs Window Size Plot
# -----------------------------
# only CrGSTA, AERCA, FEDformer, iTransformer, RCD, Epsilon
agg_df = agg_df[agg_df["architecture"].isin(["CrGSTA", "AERCA", "FEDformer", "iTransformer", "RCD", "Epsilon"])]

plt.figure(figsize=(8, 6))
for scheme, group in agg_df.groupby("architecture"):
    # Rows arrive in groupby order, which sorts on the *string* "Params" column
    # ("12.0M" < "3.4M"); sort by window size so each line is drawn left to right.
    # NOTE(review): if several datasets are present, their points share one line.
    group = group.sort_values("window_size")
    plt.plot(group["window_size"], group["Params_M"], marker="o", label=scheme)

plt.grid(True, linestyle="--", alpha=0.7)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.legend(fontsize=12)
plt.xlabel("Window Size", fontsize=16)
plt.ylabel("Number of Parameters (Millions)", fontsize=16)
plt.tight_layout()
plt.savefig("Scripts/Data/CrGSTA/params_vs_windows_Lotka.pdf")
plt.show()
plt.close()
