import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

# Input/output locations. Both default to the current working directory.
# TODO: point these at the real cache and plot folders before running.
cache_folder = "."  # folder the cached CSV is read from
plot_folder = "."  # folder the rendered figure is written to

# Load the cached results table.
csv_path = cache_folder + "/CA_test_onlyTime_noNat.csv"
df = pd.read_csv(csv_path)

# Column names for the x-axis and the default y-axis.
x_value = "perturbation_sensitivity"
y_value = "accuracy"

# One colour per distinct time factor, keyed by the time-factor value.
time_factors = sorted(df["time_factor"].unique())
n_colors = len(time_factors)
palette = sns.color_palette("Dark2", n_colors=n_colors)
palette_dict = dict(zip(time_factors, palette))
# Same colours with every RGB channel multiplied by 0.6.
darker_palette = [tuple(0.6 * channel for channel in rgb) for rgb in palette]

# k equal-width bins spanning the observed range of the x column.
k = 16
x_series = df[x_value]
x_min, x_max = x_series.min(), x_series.max()
bins = np.linspace(x_min, x_max, k + 1)
bin_centers = 0.5 * (bins[:-1] + bins[1:])
offset = 0.01 * (x_max - x_min)  # 1% of the x-range; tweak if needed
# 2 x 6 grid of panels: one column per time factor (2..7), with the y-axis
# shared within each row.
fig, axs = plt.subplots(nrows=2, ncols=6, figsize=(12, 4), sharey='row')

# Top row: "accuracy" against "perturbation_sensitivity", one panel per
# time factor, coloured with the base palette.
palette_dict = dict(zip(time_factors, palette))
for col, time_factor in enumerate(range(2, 8)):
    ax = axs[0, col]
    # Rows belonging to this panel's time factor only.
    df_time = df[df["time_factor"] == time_factor]

    sns.scatterplot(
        data=df_time,
        x="perturbation_sensitivity",
        y="accuracy",
        hue="time_factor",
        palette=palette_dict,
        alpha=0.8,
        ax=ax,
        legend=False,
    )
    ax.set_title(f'T={time_factor}', fontsize=12)
    # The x ticks and label are hidden on this row, so the panel is read
    # against the axis of the panel below it. Pin the same fixed x-range the
    # bottom row uses; otherwise each top panel autoscales to its own data
    # and the two rows no longer line up.
    ax.set_xlim(-0.02, 0.55)
    ax.set_xticks([])
    ax.set_xlabel('')
    # Only the left-most panel carries the row's y-axis label.
    if col == 0:
        ax.set_ylabel('Accuracy', fontsize=10)



# Bottom row: the "accuracy/CNN_diff_majority" column against
# "perturbation_sensitivity", drawn with the darkened palette.
palette_dict = dict(zip(time_factors, darker_palette))
for column, ax in enumerate(axs[1]):
    # Columns 0..5 correspond to time factors 2..7.
    time_factor = column + 2
    subset = df[df["time_factor"] == time_factor]

    scatter_kwargs = dict(
        x="perturbation_sensitivity",
        y="accuracy/CNN_diff_majority",
        hue="time_factor",
        palette=palette_dict,
        alpha=0.8,
        legend=False,
    )
    sns.scatterplot(data=subset, ax=ax, **scatter_kwargs)

    # Fixed x-range for every panel in this row.
    ax.set_xlim(-0.02, 0.55)
    ax.set_xlabel('PS', fontsize=10)
    # Only the left-most panel carries the row's y-axis label.
    if column == 0:
        ax.set_ylabel('Accuracy-Majority', fontsize=10)

# Let tight_layout fit the panels first, then override the gaps between
# them with explicit spacing.
fig.tight_layout()
fig.subplots_adjust(wspace=0.2, hspace=0.1)

# Write the figure to disk, then display it.
output_path = plot_folder + "/scatter_time_factor.png"
plt.savefig(output_path, dpi=300, bbox_inches='tight')
plt.show()