import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np


# Input/output locations.
# TODO: point these at the real directories. The current directory is only a
# placeholder so that the script parses and runs.
cache_folder = "."  # cache folder (where the input CSV files are read from)
plot_folder = "."  # plot folder (where the figure is written)

# Results table to plot. "CA_test_onlyTime_noNat.csv" is another file in the
# same folder; swap the name below to plot that one instead. Only one file is
# read, so the script no longer loads a table it immediately discards.
data_file = "CA_test_spaceAndTime_noNat.csv"
df = pd.read_csv(cache_folder + "/" + data_file)


# Column used for the x-range/bin computations below.
x_value = "perturbation_sensitivity"
# Alternative y column present in the data: "accuracy/CNN_diff_majority".
y_value = "accuracy"

k = 16  # number of bins
time_factors = [3, 4, 5]
spatial_factors = [3, 5, 7]

# Two colours per spatial factor: palette_dict maps each spatial factor to a
# colour from the first half of the palette, palette_dict2 maps the same
# spatial factors to the second half. The plotting loop uses palette_dict2
# for time factor 5 and palette_dict otherwise.
n_spatial = len(spatial_factors)
palette = sns.color_palette("Dark2", n_colors=n_spatial * 2)
palette_dict = dict(zip(spatial_factors, palette))
palette_dict2 = dict(zip(spatial_factors, palette[n_spatial:]))

# Bin edges and centres over the observed x-range.
# NOTE(review): k, bins, bin_centers, offset and n_colors are not used by the
# plotting code visible in this file; kept in case something else needs them.
x_min = df[x_value].min()
x_max = df[x_value].max()
bins = np.linspace(x_min, x_max, k + 1)
bin_centers = (bins[:-1] + bins[1:]) / 2
offset = 0.01 * (x_max - x_min)  # 1% of x-range, tweak if needed
n_colors = len(time_factors)
def _scatter_panel(ax, data, y_column, palette):
    """Scatter perturbation sensitivity against *y_column* on *ax*.

    Points are coloured by the "spatial_factor" column using *palette*
    (a mapping from spatial factor to colour); no legend is drawn.
    """
    sns.scatterplot(
        data=data,
        x="perturbation_sensitivity",
        y=y_column,
        hue="spatial_factor",
        palette=palette,
        alpha=0.8,
        ax=ax,
        legend=False,
    )


# Time factors shown side by side, each with the colour mapping it uses.
# Every time factor occupies one group of len(spatial_factors) columns.
palettes_by_time = {3: palette_dict, 5: palette_dict2}
n_group_cols = len(spatial_factors)

# Row 0: accuracy. Row 1: accuracy relative to the majority baseline column.
fig, axs = plt.subplots(
    nrows=2,
    ncols=n_group_cols * len(palettes_by_time),
    figsize=(12, 4),
    sharey='row',
)

for time_idx, (time_factor, pal) in enumerate(palettes_by_time.items()):
    # Filtering by time and darkening the palette do not depend on the
    # spatial factor, so do both once per time factor.
    df_time = df[df["time_factor"] == time_factor]
    # Bottom row uses the same hues scaled to 60% brightness.
    darker_palette = {
        key: (val[0] * 0.6, val[1] * 0.6, val[2] * 0.6)
        for key, val in pal.items()
    }

    for space_idx, spatial_factor in enumerate(spatial_factors):
        col = time_idx * n_group_cols + space_idx
        df_space = df_time[df_time["spatial_factor"] == spatial_factor]

        # Top row: accuracy; x ticks/label are hidden because the panel
        # directly below carries the x axis.
        top_ax = axs[0, col]
        _scatter_panel(top_ax, df_space, "accuracy", pal)
        top_ax.set_title(f'T={time_factor}, S={spatial_factor}', fontsize=12)
        top_ax.set_xticks([])
        top_ax.set_xlabel('')

        # Bottom row: accuracy/CNN_diff_majority in the darker colours.
        bottom_ax = axs[1, col]
        _scatter_panel(
            bottom_ax, df_space, "accuracy/CNN_diff_majority", darker_palette
        )
        bottom_ax.set_xlabel('PS', fontsize=10)

        # Only the left-most column gets explicit y-axis labels.
        if col == 0:
            top_ax.set_ylabel('Accuracy', fontsize=10)
            bottom_ax.set_ylabel('Accuracy-Majority', fontsize=10)

from matplotlib.lines import Line2D

# Fix the final spacing first: the divider below is placed from the axes'
# resulting positions.
plt.tight_layout()
plt.subplots_adjust(hspace=0.1, wspace=0.2)

# Dashed vertical divider halfway across the gap between the two column
# groups (last column of the first group / first column of the second).
# Computed from the axes positions instead of a hand-tuned constant so it
# stays centred if the layout changes.
n_group_cols = len(spatial_factors)
left_box = axs[0, n_group_cols - 1].get_position()
right_box = axs[0, n_group_cols].get_position()
line_x = (left_box.x1 + right_box.x0) / 2
# Vertical extent is in figure coordinates; adjust if needed.
line = Line2D([line_x, line_x], [0.08, 0.96], transform=fig.transFigure,
              color='black', linewidth=1, linestyle='--')
fig.add_artist(line)

plt.savefig(plot_folder + "/scatter_space_factor.png", dpi=300, bbox_inches='tight')
plt.show()