## Imports

# %matplotlib widget # uncomment for interactive plots
from nlb_tools.nwb_interface import NWBDataset
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV
from mpl_toolkits.mplot3d import Axes3D
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable
import random
import os
import tensorflow as tf

# To process the data yourself, download it from https://dandiarchive.org/dandiset/000127
# and place it in the directory referenced by the "Load data" step below.

def set_seed(seed: int = 42) -> None:
    """Seed the Python, NumPy and TensorFlow random number generators.

    Besides seeding the three generators, this exports the environment
    variables that request deterministic TensorFlow/CuDNN ops and a fixed
    Python hash seed.

    Args:
        seed: Value handed to every generator and written (as a string) to
            ``PYTHONHASHSEED``.
    """
    # Environment first: fixed hash seed plus the two flags required when
    # running on the CuDNN backend.
    # NOTE(review): PYTHONHASHSEED is normally read at interpreter start-up,
    # so setting it here presumably only reaches child processes -- confirm.
    os.environ.update({
        "PYTHONHASHSEED": str(seed),
        'TF_CUDNN_DETERMINISTIC': '1',
        'TF_DETERMINISTIC_OPS': '1',
    })

    # Then the generators themselves.
    random.seed(seed)
    np.random.seed(seed)
    # tf.experimental.numpy.random.seed(seed)
    # tf.compat.v2.numpy.random.seed(seed) # this is used for an older version of tensorflow
    tf.random.set_seed(seed)


# Seed all generators up front so the NumPy-based shuffling in later cells is
# repeatable from run to run.
set_seed(42)
# %%
## Load data
# The path uses forward slashes only: the previous "000127\sub-Han" relied on
# "\s" being an unrecognised escape (a warning in current Python) and, with a
# literal backslash in the name, could not resolve outside Windows.
dataset = NWBDataset("Area2_Bump/000127/sub-Han", "*train", split_heldout=False)
# Smooth spikes with 40 ms std Gaussian kernel
dataset.smooth_spk(40, name='smth_40')
# %%
# Choose lag value
lag = 40
align_field = 'move_onset_time'
align_range = (-100, 500)
neur_num = 1001
# The 8 conditions with ctr_hold_bump == False, one per cond_dir value
# (0 to 315 in steps of 45), in the format (ctr_hold_bump, cond_dir)
unique_conditions = [(False, 0.0), (False, 45.0), (False, 90.0), (False, 135.0),
                     (False, 180.0), (False, 225.0), (False, 270.0), (False, 315.0)]
# unique_conditions = [(False, 0.0), (False, 90.0), (False, 180.0), (False, 270.0)]
data = []
label = []

# Number of samples each aligned trial contributes. Derived from the window
# instead of being hard-coded; assumes `align_range` and `dataset.bin_width`
# are both in milliseconds (gives the previous fixed value of 600 for 1 ms
# bins) -- TODO confirm if the data is ever resampled.
n_time = int((align_range[1] - align_range[0]) // dataset.bin_width)

# Loop through conditions
for idx, cond in enumerate(unique_conditions):
    # Filter out invalid trials (labeled 'none') and trials in other conditions
    cond_mask = (np.all(dataset.trial_info[['ctr_hold_bump', 'cond_dir']] == cond, axis=1)) & \
                (dataset.trial_info.split != 'none')
    # Extract relevant portion of selected trials, using the alignment
    # settings defined above rather than repeating their literal values
    trial_data = dataset.make_trial_data(
        align_field=align_field, align_range=align_range, ignored_trials=~cond_mask)
    spikes = trial_data['spikes_smth_40'].to_numpy()
    # Fold the stacked rows into (trials, samples per trial, columns); the
    # column count comes from the extracted array itself
    cond_data = spikes.reshape(-1, n_time, spikes.shape[1])
    # One float label per trial: the index of the condition in the list above
    cond_label = idx * np.ones(cond_data.shape[0])
    print(f"condition {idx} {cond}: {cond_data.shape[0]} trials")
    data.append(cond_data)
    label.append(cond_label)
data = np.concatenate(data, axis=0)
label = np.concatenate(label, axis=0)
# %%
from sklearn.utils import shuffle
import numpy as np

# Size of the rarest class: every class is trimmed down to this many trials
_, counts = np.unique(label, return_counts=True)
min_samples = np.min(counts)

# Per-class pieces are collected here and stitched together after the loop.
# "balanced" holds the first `min_samples` shuffled trials of each class,
# "leftover" holds whatever remains.
balanced_data, balanced_label = [], []
leftover_data, leftover_label = [], []
for class_value in np.unique(label):
    members = np.flatnonzero(label == class_value)
    # In-place shuffle using NumPy's global generator
    np.random.shuffle(members)
    # Cut the shuffled indices at `min_samples`: head is kept, tail is spare
    kept_idx, spare_idx = np.split(members, [min_samples])
    balanced_data.append(data[kept_idx])
    balanced_label.append(label[kept_idx])
    leftover_data.append(data[spare_idx])
    leftover_label.append(label[spare_idx])

# Stitch the per-class pieces into single arrays along the trial axis
balanced_data = np.concatenate(balanced_data)
balanced_label = np.concatenate(balanced_label)
leftover_data = np.concatenate(leftover_data)
leftover_label = np.concatenate(leftover_label)

# Mix the classes; data and labels are permuted together with a fixed seed
balanced_data, balanced_label = shuffle(balanced_data, balanced_label, random_state=42)
leftover_data, leftover_label = shuffle(leftover_data, leftover_label, random_state=42)
# %%
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical

# Hold out 20% of the balanced trials for testing (fixed seed)
X_train, X_test, y_train, y_test = train_test_split(
    balanced_data,
    balanced_label,
    test_size=0.2,
    random_state=42,
)

# Trials that were dropped during balancing are appended to the test set only
X_test = np.concatenate([X_test, leftover_data])
y_test = np.concatenate([y_test, leftover_label])

# Write the four arrays to disk, in the same order as before
for out_path, array in (
    ('Area2_Bump/trainX.npy', X_train),
    ('Area2_Bump/testX.npy', X_test),
    ('Area2_Bump/trainy.npy', y_train),
    ('Area2_Bump/testy.npy', y_test),
):
    np.save(out_path, array)