import pandas as pd
from io import StringIO

### Load into DataFrame
df = pd.read_csv("results/results_RQ2_spatial.csv")

metrics = ["AC@1", "AC@3", "AC@5", "AC@10", "Avg@10"]

# Parameter count in millions as display text (e.g. "12.3M"); this string is also a grouping key below.
df["Params"] = (df["total_params"] / 1e6).map("{:.1f}M".format)

# For the FEDformer / iTransformer baselines, the architecture label is the main model's name.
for baseline in ("FEDformer", "iTransformer"):
    df.loc[df["main_model"] == baseline, "architecture"] = baseline

# Mean and std of every metric per configuration, flattened to "<metric>_<stat>" columns
# (metric-major order: AC@1_mean, AC@1_std, AC@3_mean, ...).
group_keys = ["architecture", "dataset_name", "Params", "num_vars"]
agg_df = df.groupby(group_keys)[metrics].agg(["mean", "std"]).reset_index()
agg_df.columns = group_keys + [f"{metric}_{stat}" for metric in metrics for stat in ("mean", "std")]

# Format "mean ± std" table cells for LaTeX; the std part is typeset with \tiny.
# A configuration with a single run has an undefined (NaN) sample std. In that case
# only the mean is printed, rather than the meaningless "\tiny{±nan}".
for m in metrics:
    mean_col = agg_df[f"{m}_mean"]
    std_col = agg_df[f"{m}_std"]
    std_txt = std_col.map(lambda s: "" if pd.isna(s) else r"\tiny{" + f"±{s:.3f}" + "}")
    agg_df[m] = mean_col.map("{:.3f}".format) + std_txt
    # Numeric copy of the mean, used for ranking and plotting (the formatted column is text).
    agg_df[f"{m}_numeric"] = mean_col

# Bold the best and underline the second-best mean of each metric within each dataset
# (ranked across all num_vars rows of that dataset).
# Ranking is restricted to the architectures that the LaTeX table actually shows (the same
# list as the display filter). Otherwise the highlight could land on a hidden row, leaving
# no visible row bold.
TABLE_ARCHITECTURES = [
    "TemporalGNN_Attention_crossattn",
    "deep_mlp",
    "FEDformer",
    "iTransformer",
    "rcd",
    "epsilon_diagnosis",
]
table_rows = agg_df[agg_df["architecture"].isin(TABLE_ARCHITECTURES)]
for dataset in table_rows["dataset_name"].unique():
    subset = table_rows[table_rows["dataset_name"] == dataset]
    for m in metrics:
        # Top two means, NaNs dropped. On ties the earliest row wins, as with idxmax.
        top = subset[f"{m}_numeric"].dropna().nlargest(2)
        if top.empty:
            continue
        best_idx = top.index[0]
        agg_df.loc[best_idx, m] = r"\textbf{" + agg_df.loc[best_idx, m] + "}"
        if len(top) > 1:
            second_idx = top.index[1]
            agg_df.loc[second_idx, m] = r"\underline{" + agg_df.loc[second_idx, m] + "}"

# Sort by dataset and AC@1 numeric for display (ascending within each dataset).
display_df = agg_df.sort_values(by=["dataset_name", "AC@1_numeric"], ascending=True)

# Raw architecture name -> label printed in the LaTeX table.
# Only these architectures (CrGSTA, AERCA, FEDformer, iTransformer, RCD, Epsilon) are shown.
table_labels = {
    "TemporalGNN_Attention_crossattn": "CrGSTA",
    "deep_mlp": "AERCA",
    "FEDformer": "FEDformer",
    "iTransformer": "iTransformer",
    "rcd": "RCD",
    "epsilon_diagnosis": "Epsilon",
}
is_shown = display_df["architecture"].isin(list(table_labels))
# Explicit copy: a column is assigned on the filtered frame next.
display_df = display_df.loc[is_shown].copy()
display_df["architecture"] = display_df["architecture"].replace(table_labels)
# -----------------------------
# Build LaTeX table manually (with num_vars)
# -----------------------------
def _escape_latex(text):
    """Return *text* with the LaTeX special characters & % $ # _ { } backslash-escaped."""
    for ch in "&%$#_{}":
        text = text.replace(ch, "\\" + ch)
    return text


n_cols = 3 + len(metrics)  # scheme, Params, num_vars, then one column per metric
latex_lines = [
    r"\begin{tabular}{" + "l" * n_cols + "}",
    r"\toprule",
    # "num_vars" must be escaped: a bare "_" is an error in LaTeX text mode.
    r"scheme & Params & num\_vars & " + " & ".join(_escape_latex(m) for m in metrics) + r" \\",
]

prev_dataset = None
for _, row in display_df.iterrows():
    dataset_val = row["dataset_name"]
    if dataset_val != prev_dataset:
        # Each dataset section opens with a rule and a shaded full-width title row.
        # The header therefore carries no \midrule of its own. This avoids the doubled
        # rule, with a blank line between them, that the previous header produced.
        latex_lines.append(r"\midrule")
        latex_lines.append(
            r"\rowcolor{gray!20} \multicolumn{" + str(n_cols) + "}{c}{"
            + _escape_latex(str(dataset_val).upper()) + r"} \\"
        )
        prev_dataset = dataset_val

    cells = [f"{row['architecture']}", f"{row['Params']}", f"{row['num_vars']}"]
    cells += [row[m] for m in metrics]  # already formatted "mean ± std" strings
    latex_lines.append(" & ".join(cells) + r" \\")

latex_lines.append(r"\bottomrule")
latex_lines.append(r"\end{tabular}")

latex_table = "\n".join(latex_lines)
print(latex_table)



#-----heatmap generation-----
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

metrics = ["AC@1", "Avg@10"]

# change scheme names for better display
agg_df["architecture"] = agg_df["architecture"].replace({
    "rcd": "RCD",
    "epsilon_diagnosis": "Epsilon",
    "deep_mlp": "AERCA",
    "FEDformer": "FEDformer",
    "iTransformer": "iTransformer",
    "TemporalGNN_Attention_crossattn": "CrGSTA",
})

# Row groups of every heatmap, listed top to bottom.
categories = {
    "Non-causal": ["FEDformer", "iTransformer"],
    "Causal": ["AERCA", "CrGSTA"],
}
# Heatmap row order: the categories' schemes concatenated in order.
scheme_order = [scheme for models in categories.values() for scheme in models]
n_rows = len(scheme_order)

# Cumulative row index at which each category ends; the inner ones get separator lines.
category_boundaries = []
counter = 0
for models in categories.values():
    counter += len(models)
    category_boundaries.append(counter)

# Iterate over datasets
for dataset in agg_df["dataset_name"].unique():
    dataset_df = agg_df[agg_df["dataset_name"] == dataset]
    # Keep only the schemes drawn on the heatmap. Filtering before pivot() also stops
    # duplicate rows of never-shown schemes from making pivot() fail.
    dataset_df = dataset_df[dataset_df["architecture"].isin(scheme_order)]
    if dataset_df.empty:
        continue

    # Iterate over metrics
    for metric in metrics:
        plt.figure(figsize=(10, 6))

        # Pivot: rows = scheme (in category order), columns = number of variables
        heatmap_data = dataset_df.pivot(
            index="architecture", columns="num_vars", values=f"{metric}_numeric"
        ).reindex(scheme_order)

        # Create heatmap
        ax = sns.heatmap(
            heatmap_data,
            annot=True,
            fmt=".3f",
            cmap="YlGnBu",
            annot_kws={"size": 17},       # numbers inside cells
            cbar_kws={"label": metric},   # colorbar label
            linewidths=0.5,
            linecolor="black",
        )

        # Outer category labels, vertically centred on each category's rows.
        # Axes coordinates run from 1 at the top row edge to 0 at the bottom.
        start = 0
        for cat, models in categories.items():
            y_center = 1 - (start + len(models) / 2) / n_rows
            ax.text(-0.3, y_center, cat, rotation=90, va="center", ha="center",
                    fontsize=16, transform=ax.transAxes)
            start += len(models)

        # Axis labels (the y label is hidden; category labels replace it)
        ax.set_xlabel("Num of Variables", fontsize=18)
        ax.set_ylabel("", fontsize=18)
        # Tick labels
        ax.set_xticklabels(ax.get_xticklabels(), fontsize=14, rotation=45, ha="right")
        ax.set_yticklabels(ax.get_yticklabels(), fontsize=14, rotation=0)

        # Colorbar label size
        cbar = ax.collections[0].colorbar
        cbar.ax.yaxis.label.set_size(16)   # colorbar label font size
        cbar.ax.tick_params(labelsize=14)  # colorbar ticks font size

        # Horizontal lines between categories, extended left (x=-7, data units) past the
        # cells so they also separate the category labels; clip_on=False lets them
        # draw outside the axes.
        for boundary in category_boundaries[:-1]:
            ax.plot(
                [-7, heatmap_data.shape[1]],
                [boundary, boundary],
                color="black",
                linewidth=2,
                clip_on=False,
            )

        plt.tight_layout()
        # NOTE(review): "sptaial" looks like a typo for "spatial". It is kept so existing
        # references to these files keep working.
        plt.savefig(f"Scripts/Data/CrGSTA/heatmap_sptaial_{dataset}_{metric}.pdf")
        plt.show()
        plt.close()  # release the figure so figures don't pile up across the loop




# Convert Params string back to numeric (in millions); e.g. "12.3M" -> 12.3
agg_df["Params_M"] = agg_df["Params"].str.replace("M", "").astype(float)

# only CrGSTA, AERCA, FEDformer, iTransformer, RCD, Epsilon
agg_df = agg_df[agg_df["architecture"].isin(["CrGSTA", "AERCA", "FEDformer", "iTransformer", "RCD", "Epsilon"])]

# One parameter-count plot per dataset, matching the per-dataset file name. Mixing
# datasets in one figure made each line jump between datasets' num_vars ranges.
for dataset, dataset_df in agg_df.groupby("dataset_name"):
    plt.figure(figsize=(8, 6))
    for scheme, group in dataset_df.groupby("architecture"):
        # Sort so each line runs left-to-right in num_vars instead of zig-zagging.
        group = group.sort_values("num_vars")
        plt.plot(group["num_vars"], group["Params_M"], marker="o", label=scheme)

    plt.grid(True, linestyle="--", alpha=0.7)
    # font sizes
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    plt.legend(fontsize=12)
    plt.xlabel("Number of Variables", fontsize=16)
    plt.ylabel("Number of Parameters (Millions)", fontsize=16)
    plt.tight_layout()
    plt.savefig(f"Scripts/Data/CrGSTA/params_vs_windows_spatial_{dataset}.pdf")
    plt.show()
    plt.close()  # release the figure before the next dataset
