# # %%
# import pandas as pd
# import numpy as np
# import matplotlib.pyplot as plt
# from scipy.stats import pearsonr
# from enzyme import PRJ_ROOT

# # Load data
# # Assuming you have a file 'Behavior_unified_curated.mat' converted to a CSV file
# T = pd.read_csv(PRJ_ROOT / f"Data/mice/T.csv")

# # Initialize variables
# mouse_num = T['mouse_number'].max()
# PGOs = T['PGo'].unique()
# PGO_N = len(PGOs)
# trial_N = 10

# all_trial_waits = np.zeros((mouse_num, PGO_N))
# wait_from_switch = np.zeros((mouse_num, trial_N, PGO_N))
# wait_from_switch_prev_pgo = np.zeros((mouse_num, trial_N, PGO_N))

# # Loop over PGOs and mice
# for p in range(PGO_N):
#     for m in range(1, mouse_num + 1):
#         mouse_inds = T['mouse_number'] == m
#         PGO_inds = T['PGo'] == PGOs[p]
#         inds = mouse_inds & PGO_inds
        
#         if inds.sum() > 0:
#             all_trial_waits[m-1, p] = T.loc[inds, 'wait_from_last_NoGo_duration'].mean(skipna=True)
            
#             for trial in range(1, trial_N + 1):
#                 curr_trial_inds = inds & (T['trial_number_from_switch'] == trial)
#                 wait_from_switch[m-1, trial-1, p] = T.loc[curr_trial_inds, 'wait_from_last_NoGo_duration'].mean(skipna=True)
                
#                 prev_pgo_inds = mouse_inds & (T['PGo'].shift(trial) == PGOs[p]) & (T['trial_number_from_switch'] == trial)
#                 wait_from_switch_prev_pgo[m-1, trial-1, p] = T.loc[prev_pgo_inds, 'wait_from_last_NoGo_duration'].mean(skipna=True)

# # More code for averages and plotting can be added similarly

# # For correlation calculations, you could use scipy's pearsonr function, something like:
# # r, _ = pearsonr(curr_PGOs, mouse_total_wait)


# # %%
# # Initialize arrays for mouse averages
# mouse_avg_waits = np.zeros(PGO_N)
# mouse_avg_rews = np.zeros(PGO_N)
# mouse_avg_trial_dur = np.zeros(PGO_N)
# mouse_avg_rew_rate = np.zeros(PGO_N)
# mouse_avg_wait_from_switch = np.zeros((trial_N, PGO_N))
# mouse_avg_wait_from_switch_prev_pgo = np.zeros((trial_N, PGO_N))

# # Compute mouse averages
# for p in range(PGO_N):
#     min_wait = 0.2
#     inds_non_impulsive = (T['PGo'] == PGOs[p]) & (T['wait_from_last_NoGo_duration'] > min_wait)
    
#     mouse_avg_waits[p] = T.loc[inds_non_impulsive, 'wait_from_last_NoGo_duration'].mean(skipna=True)
#     mouse_avg_rews[p] = T.loc[inds_non_impulsive, 'performance'].mean(skipna=True)
#     mouse_avg_trial_dur[p] = (T.loc[inds_non_impulsive, 'choice_time'] + 3).mean(skipna=True)
#     mouse_avg_rew_rate[p] = mouse_avg_rews[p] / mouse_avg_trial_dur[p]
    
#     for trial in range(1, trial_N + 1):
#         curr_trial_inds = (T['PGo'] == PGOs[p]) & (T['trial_number_from_switch'] == trial)
#         mouse_avg_wait_from_switch[trial-1, p] = T.loc[curr_trial_inds, 'wait_from_last_NoGo_duration'].mean(skipna=True)
        
#         prev_pgo_inds = (T['PGo'].shift(trial) == PGOs[p]) & (T['trial_number_from_switch'] == trial)
#         mouse_avg_wait_from_switch_prev_pgo[trial-1, p] = T.loc[prev_pgo_inds, 'wait_from_last_NoGo_duration'].mean(skipna=True)

# # Correlation calculations
# def safe_pearsonr(x, y):
#     """Calculate Pearson correlation, safely ignoring NaN values."""
#     mask = ~np.isnan(x) & ~np.isnan(y)
#     if np.any(mask):
#         return pearsonr(x[mask], y[mask])
#     else:
#         return np.nan, np.nan

# curr_corr = np.zeros(trial_N)
# prev_corr = np.zeros(trial_N)
# curr_corr_wait = np.zeros(trial_N)
# prev_corr_wait = np.zeros(trial_N)

# curr_corr_p = np.zeros(trial_N)
# prev_corr_p = np.zeros(trial_N)

# rdm_corr_curr = np.zeros(trial_N)
# rdm_corr_prev = np.zeros(trial_N)




# # %%



# for trial in range(1, trial_N + 1):
#     trial_inds = (T['trial_number_from_switch'] == trial) & (~T['wait_from_last_NoGo_duration'].isna())
#     curr_PGOs = T.loc[trial_inds, 'PGo'].to_numpy()
#     curr_wait_opts = wait_opts_PGo[np.argmin(np.abs(curr_PGOs[:, None] - PGOs_sorted[None, :]), axis=1)]
#     rdm_wait_opts = wait_opts_PGo[np.random.randint(0, len(PGOs_sorted), len(curr_PGOs))]

#     prev_PGOs = T['PGo'].shift(trial)[trial_inds].to_numpy()
#     prev_wait_opts = wait_opts_PGo[np.argmin(np.abs(prev_PGOs[:, None] - PGOs_sorted[None, :]), axis=1)]

#     mouse_total_wait = T.loc[trial_inds, 'wait_from_last_NoGo_duration'].to_numpy()

#     r, _ = safe_pearsonr(curr_PGOs, mouse_total_wait)
#     curr_corr[trial-1] = r


#     r, _ = safe_pearsonr(prev_PGOs, mouse_total_wait)
#     prev_corr[trial-1] = r

#     r, p = safe_pearsonr(curr_wait_opts, mouse_total_wait)
#     curr_corr_wait[trial-1] = r
#     curr_corr_p[trial-1] = p

#     r, p = safe_pearsonr(prev_wait_opts, mouse_total_wait)
#     prev_corr_wait[trial-1] = r
#     prev_corr_p[trial-1] = p

#     r_rdm, _ = safe_pearsonr(curr_wait_opts, rdm_wait_opts)
#     rdm_corr_curr[trial-1] = r_rdm

#     r_rdm, _ = safe_pearsonr(prev_wait_opts, rdm_wait_opts)
#     rdm_corr_prev[trial-1] = r_rdm




# if __name__ == "__main__":
#     plt.plot(PGOs_sorted, wait_opts_PGo_, 'o')
#     plt.plot(PGOs_dense, f(PGOs_dense), '-')

#     # Plotting
#     fix, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
#     ax = ax1
#     twax = ax.twinx()
#     ax.plot(curr_corr, c="C0")
#     ax.plot(prev_corr, c="C1")

#     # twax.plot(curr_corr_p, c="C0", ls="--")
#     # twax.plot(prev_corr_p, c="C1", ls="--")

#     ax = ax2
#     twax = ax.twinx()
#     twax.set_yscale("log")
#     ax.plot(curr_corr_wait, c="C0")
#     ax.plot(prev_corr_wait, c="C1")

#     ax.plot(rdm_corr_curr, c="C0", ls=":")
#     ax.plot(rdm_corr_prev, c="C1", ls=":")

#     twax.plot(curr_corr_p, c="C0", ls="--")
#     twax.plot(prev_corr_p, c="C1", ls="--")
#     print(curr_corr_p)
#     print(prev_corr_p)
    


#     for ax in ax1, ax2:
#         ax.set_xticks(np.arange(0, trial_N, 2))
#         ax.set_xlabel("Trial from switch")
#         ax.set_ylabel("Correlation to PGO")
#         ax.legend(["Current PGO", "Previous PGO"])
#     plt.show()

#     # More plotting and saving data can be done similarly


#     # %%
#     # Additional plotting
#     def errorbar_plot(data, label, xlabel, ylabel, title):
#         MY = np.mean(data, axis=0)
#         SY = np.std(data, axis=0)
#         plt.errorbar(PGOs, MY, yerr=SY, fmt='o')
#         plt.title(title)
#         plt.xlabel(xlabel)
#         plt.ylabel(ylabel)
#         plt.xlim([-0.1, 1])
#         plt.show()

#     # Plotting theta vs waiting time for different conditions
#     errorbar_plot(all_trial_waits, 'o', 'Theta', 'waiting time', 'theta vs waiting time (all trials)')
#     errorbar_plot(np.squeeze(wait_from_switch[:, 0, :]), 'o', 'Theta', 'waiting time', 'theta vs waiting time (first trial)')
#     errorbar_plot(np.squeeze(np.mean(wait_from_switch[:, 0:2, :], axis=1)), 'o', 'Theta', 'waiting time', 'theta vs waiting time (first 2 trials)')


#     fig, ax = plt.subplots()


# %%
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
from enzyme import PRJ_ROOT

# Load the mouse behaviour table (Data/mice/T.csv under the project root).
T = pd.read_csv(PRJ_ROOT / f"Data/mice/T.csv")
cues_per_s = 5

# Basic dimensions of the analysis.
# Mice are addressed by their 1-based 'mouse_number', so the largest number is
# used as the number of mice; PGOs holds the distinct 'PGo' values in order of
# first appearance in the table.
trial_N = 10
PGOs = T['PGo'].unique()
PGO_N = len(PGOs)
mouse_num = T['mouse_number'].max()

# Hard-coded wait values, one per entry of the sorted PGo values
# (presumably the optimal wait for each PGo -- TODO confirm the source).
# The literal has 10 entries, so the fit below requires exactly 10 distinct PGo values.
wait_opts_PGo_ = np.array([3, 3, 4, 4, 5, 5, 6, 8, 8, 11], dtype=float)
PGOs_sorted = np.sort(PGOs)

# Smooth the hard-coded values with a cubic polynomial in PGo and evaluate the
# fit at the observed PGo values; PGOs_dense is a 100-point grid over the same
# range (used for plotting the fitted curve).
p = np.polyfit(PGOs_sorted, wait_opts_PGo_, 3)
f = np.poly1d(p)
wait_opts_PGo = f(PGOs_sorted)
PGOs_dense = np.linspace(PGOs_sorted[0], PGOs_sorted[-1], 100)


# Result buffers, indexed [mouse, (trial from switch,) PGo] for the wait
# averages and [mouse, trial from switch] for the correlations.
wait_from_switch = np.zeros((mouse_num, trial_N, PGO_N))
wait_from_switch_prev_pgo = np.zeros_like(wait_from_switch)
all_trial_waits = np.zeros((mouse_num, PGO_N))
mouse_curr_corrs = np.zeros((mouse_num, trial_N))
mouse_prev_corrs = np.zeros_like(mouse_curr_corrs)

# Function to safely calculate Pearson correlation
def safe_pearsonr(x, y):
    """Return the Pearson correlation of ``x`` and ``y``, ignoring NaN pairs.

    Parameters
    ----------
    x, y : array-like
        One-dimensional numeric samples. If their lengths differ, only the
        common leading part ``[:min(len(x), len(y))]`` is compared (this is
        what the former NaN-padding amounted to). Note that this pairs
        elements purely by position; callers are responsible for alignment.

    Returns
    -------
    (r, p)
        The result of ``scipy.stats.pearsonr`` on the pairs where both values
        are non-NaN, or ``(nan, nan)`` when fewer than two such pairs exist
        (``pearsonr`` itself raises for fewer than two points).
    """
    # Work on float copies so NaN checks are valid for integer/object input.
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    # Restrict both samples to their common length.
    n_common = min(len(x), len(y))
    x = x[:n_common]
    y = y[:n_common]

    mask = ~(np.isnan(x) | np.isnan(y))
    if np.count_nonzero(mask) < 2:
        return np.nan, np.nan
    return pearsonr(x[mask], y[mask])

# Per-mouse analysis: mean waits per PGo / trial-from-switch, and per-trial
# correlations between waits and the fitted wait values of the current and
# the previous PGo.

# Column holding the waiting time analysed throughout this section.
WAIT_COL = 'wait_from_last_NoGo_duration'


def _nearest_wait_opt(pgo_values):
    """Map each PGo value to the fitted wait of the nearest entry in PGOs_sorted.

    NaN inputs yield NaN. Without this guard ``np.argmin`` returns index 0 for
    an all-NaN row, which would silently assign the first fitted wait value to
    trials whose PGo is undefined (e.g. the first rows after ``shift``).
    """
    pgo_values = np.asarray(pgo_values, dtype=float)
    out = np.full(pgo_values.shape, np.nan)
    valid = ~np.isnan(pgo_values)
    if valid.any():
        nearest = np.argmin(
            np.abs(pgo_values[valid, None] - PGOs_sorted[None, :]), axis=1
        )
        out[valid] = wait_opts_PGo[nearest]
    return out


# Loop-invariant pieces, computed once instead of once per mouse/PGo/trial.
nan_inds = T[WAIT_COL].isna()
trial_masks = {
    trial: T['trial_number_from_switch'] == trial
    for trial in range(1, trial_N + 1)
}
# PGo value `trial` rows earlier in the table.
# NOTE(review): the shift runs over the whole table, so near the start of a
# mouse's rows the "previous" PGo may belong to another mouse -- confirm that
# this is acceptable (or shift within each mouse instead).
shifted_PGOs = {
    trial: T['PGo'].shift(trial) for trial in range(1, trial_N + 1)
}

# Loop over mice and PGOs
for m in range(1, mouse_num + 1):
    mouse_inds = T['mouse_number'] == m

    for p in range(PGO_N):
        inds = mouse_inds & (T['PGo'] == PGOs[p])
        # Mean wait over all trials of this mouse at this PGo (NaN if none).
        all_trial_waits[m-1, p] = T.loc[inds, WAIT_COL].mean(skipna=True)

        for trial in range(1, trial_N + 1):
            # Mean wait on the `trial`-th trial after a switch, grouped by the
            # current PGo ...
            curr_trial_inds = inds & trial_masks[trial]
            wait_from_switch[m-1, trial-1, p] = T.loc[curr_trial_inds, WAIT_COL].mean(skipna=True)

            # ... and grouped by the PGo that was in effect `trial` rows earlier.
            prev_pgo_inds = mouse_inds & (shifted_PGOs[trial] == PGOs[p]) & trial_masks[trial]
            wait_from_switch_prev_pgo[m-1, trial-1, p] = T.loc[prev_pgo_inds, WAIT_COL].mean(skipna=True)

    for trial in range(1, trial_N + 1):
        # Rows of this mouse at this trial-from-switch with a defined wait.
        trial_inds = mouse_inds & trial_masks[trial] & ~nan_inds
        # Mask of the rows immediately *preceding* the selected rows.
        # fill_value keeps the mask boolean (shift + fillna would go through
        # an object-dtype Series).
        prev_trial_inds = trial_inds.shift(-1, fill_value=False)

        curr_PGO = T.loc[trial_inds, 'PGo'].to_numpy()
        prev_PGO = shifted_PGOs[trial][trial_inds].to_numpy()

        curr_wait_opts = _nearest_wait_opt(curr_PGO)
        prev_wait_opts = _nearest_wait_opt(prev_PGO)

        # NOTE(review): the waits correlated below are taken from the row
        # before each selected trial (prev_trial_inds), not from the selected
        # trials themselves, and are not NaN-filtered. If the very first table
        # row is selected, this array is one element shorter than the
        # *_wait_opts arrays and safe_pearsonr pairs the values by position,
        # which shifts the pairing by one -- confirm this is intended.
        prev_row_waits = T.loc[prev_trial_inds, WAIT_COL].to_numpy()

        r, _ = safe_pearsonr(prev_row_waits, curr_wait_opts)
        mouse_curr_corrs[m-1, trial-1] = r

        r, _ = safe_pearsonr(prev_row_waits, prev_wait_opts)
        mouse_prev_corrs[m-1, trial-1] = r

# 0-based trial-from-switch positions used as the x axis when plotting.
xax = np.arange(trial_N)


if __name__ == "__main__":
    # Plot the across-mouse mean +/- standard deviation of the per-mouse
    # correlations for each trial-from-switch position.
    # nanmean/nanstd are used because safe_pearsonr returns NaN for mice with
    # too few valid trials; plain mean/std would turn the whole point into NaN.
    plt.figure()
    for corrs, label in (
        (mouse_curr_corrs, 'Current PGO'),
        (mouse_prev_corrs, 'Previous PGO'),
    ):
        MY = np.nanmean(corrs, axis=0)
        SY = np.nanstd(corrs, axis=0)
        plt.errorbar(xax, MY, yerr=SY, fmt='-o', label=label)

    plt.xlabel('Trial from switch')
    plt.ylabel('correlation to theta')
    plt.legend()
    plt.show()

# %%
