# -*- coding: utf-8 -*-
"""
@author: anonymous
"""
import os
# NOTE(review): `basic_parent` (grandparent of the launch directory) is not used in the
# visible part of this file.
basic_parent = os.path.dirname(os.path.dirname(os.getcwd()))

# Directory the script was launched from; restored after the first star import and
# recorded in the metadata.
cur_path_synth = os.getcwd()

# NOTE(review): star imports provide many helpers used below (plt, sns, save_fig,
# datetime2, create_colors, ...); their exact origin is not visible from here.
from MILCCI_basic_functions import *
# Return to the launch directory after the import (presumably the imported module may
# change the working directory — verify).
os.chdir(cur_path_synth)
print('called basic functions!!')


from MILCCI_main_functions import *

print('called main functions!!')

import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor

from sklearn.gaussian_process.kernels import RBF, WhiteKernel, ConstantKernel
def gp_shared_by_axis_labels(labels_tuples, axis_assignments, p, n_trials, length_scale=0.1, noise_level=1e-5, seed=9, sigma=0.15, T=500, signal_variance=2):
    """
    Sample Gaussian-process traces that are shared by trials with the same label.

    For each ensemble ``ens_num`` in ``range(p)`` with ``a = axis_assignments[ens_num]``:

    * ``a >= 0``: one GP-prior "template" trace is drawn per unique value of
      label axis ``a``. Every trial carrying that value gets its own variant,
      drawn from a multivariate normal centred on the template with covariance
      ``sigma**2 * K`` (``K`` = the template's kernel on the time grid).
    * ``a == -1``: every trial gets its own independent trace (each trial is
      treated as its own unique label).

    Parameters
    ----------
    labels_tuples : sequence of tuples
        One tuple per trial; entry ``j`` is the trial's label on axis ``j``.
    axis_assignments : sequence of int
        At least ``p`` entries. Entry ``ens_num`` is the axis (0..num_axes-1)
        that ensemble follows, or -1 for a trial-by-trial random ensemble.
    p : int
        Number of ensembles.
    n_trials : int
        Number of trials to build. Must not exceed ``len(labels_tuples)``.
    length_scale, noise_level : float
        Currently unused. The RBF length scale is drawn uniformly from
        [0.05, 0.2) and the white-noise level is fixed at 1e-8. They are
        kept so existing calls keep working.
    seed : int
        Seed for the global NumPy RNG. It fixes the per-template seeds and the
        kernel hyper-parameters.
    sigma : float
        Scale of the trial-to-trial deviation around a template.
    T : int
        Number of time points on the grid ``linspace(0, 1, T)``.
    signal_variance : float
        The ConstantKernel amplitude is drawn uniformly from
        ``[0.2, 0.2 + signal_variance / 1.5)``.

    Returns
    -------
    X : ndarray, shape (n_trials, T, p)
        ``traces`` with the trial axis moved first (a transposed view).
    traces : ndarray, shape (T, p, n_trials)
        ``traces[:, e, t]`` is the trace of ensemble ``e`` in trial ``t``.
    samples : dict
        Maps ``(axis, label, ens_num)`` to an array of shape
        ``(T, n_trials_with_that_label)`` holding that label's variants.

    Raises
    ------
    ValueError
        If ``axis_assignments`` has fewer than ``p`` entries, contains a value
        outside ``[-1, num_axes - 1]``, or ``n_trials > len(labels_tuples)``.

    Notes
    -----
    Reseeds the global ``np.random`` state with ``seed`` and, for each random
    ensemble, with ``ens_num + 1``.
    """
    labels_arr = np.vstack(labels_tuples)                       # (n_labelled_trials, num_axes)
    n_labelled_trials, num_axes = labels_arr.shape
    time_pts = np.linspace(0, 1, T)[:, None]                    # (T, 1)

    # --- validate inputs (previously: silent skip, then KeyError/IndexError later) ---
    if len(axis_assignments) < p:
        raise ValueError('axis_assignments has %d entries but p = %d' % (len(axis_assignments), p))
    assignments = [int(axis_assignments[ens_num]) for ens_num in range(p)]
    for ens_num, class_num in enumerate(assignments):
        if not -1 <= class_num < num_axes:
            raise ValueError('axis_assignments[%d] = %d is outside [-1, %d]' % (ens_num, class_num, num_axes - 1))
    if n_trials > n_labelled_trials:
        raise ValueError('n_trials = %d exceeds the %d labelled trials' % (n_trials, n_labelled_trials))

    # --- one distinct seed per template ---
    # An axis-bound ensemble needs one template per unique label value; a random
    # ensemble needs one per trial. The old fixed-size request could be too short
    # (few unique label tuples, many trials) or larger than the seed pool.
    n_seeds_needed = sum(
        n_labelled_trials if class_num == -1 else len(np.unique(labels_arr[:, class_num]))
        for class_num in assignments
    )
    np.random.seed(seed)
    # choice(..., replace=False) returns a prefix of a permutation of the pool, so with
    # the original 1..9999 pool the seeds drawn match the old fixed-size request.
    seed_pool = np.arange(1, max(10000, n_seeds_needed + 1))
    random_seeds = np.random.choice(seed_pool, size=n_seeds_needed, replace=False)

    # --- sample one template (plus its per-trial variants) per (axis, label, ensemble) ---
    samples = {}
    counter_full = 0
    for ens_num, class_num in enumerate(assignments):
        if class_num != -1:
            labels_current_class = labels_arr[:, class_num]
        else:
            # Random ensemble: each trial is its own label. The reseed + shuffle is kept
            # because it sets the global RNG state that the kernel draws below consume.
            np.random.seed(ens_num + 1)
            labels_current_class = np.arange(n_labelled_trials)
            np.random.shuffle(labels_current_class)

        # np.unique returns sorted values, so the shuffle does not change the label order.
        unique_labels_in_class, unique_label_counts = np.unique(labels_current_class, return_counts=True)

        for unique_label, unique_label_count in zip(unique_labels_in_class, unique_label_counts):
            random_state_now = random_seeds[counter_full]
            rng = np.random.default_rng(random_state_now)

            # Kernel hyper-parameters come from the global RNG (amplitude drawn first).
            amplitude = np.random.rand() * signal_variance / 1.5 + 0.2
            rbf_length_scale = np.random.rand() * 0.15 + 0.05
            kernel = ConstantKernel(constant_value=amplitude) * RBF(length_scale=rbf_length_scale) + WhiteKernel(1e-8)
            gp = GaussianProcessRegressor(kernel=kernel, random_state=random_state_now + 1)
            K = kernel(time_pts)                                # (T, T)

            # Template: one draw from the prior of the (unfitted) regressor -> (T,)
            template = gp.sample_y(time_pts, n_samples=1, random_state=random_state_now + 2).flatten()
            # Per-trial variants scattered around the template -> (T, unique_label_count)
            samples[(class_num, unique_label, ens_num)] = rng.multivariate_normal(
                mean=template, cov=sigma**2 * K, size=unique_label_count).T
            counter_full += 1

    # --- assemble traces: each trial takes the next unused variant of its label ---
    used_per_key = {}
    traces = np.zeros((T, p, n_trials))
    for ens_num, class_num in enumerate(assignments):
        for n_trial in range(n_trials):
            # Random ensembles use the trial index itself as the label. This replaces an
            # in-place shuffle of a leftover variable which, when the last sampled
            # ensemble was axis-bound, was a view of labels_arr and scrambled its labels.
            label_val = labels_arr[n_trial, class_num] if class_num != -1 else n_trial
            key = (class_num, label_val, ens_num)
            s_now = samples[key]
            counter_now = used_per_key.get(key, 0)
            assert s_now.shape[1] > counter_now, 's_now shape is %s, counter now is %d; class is %d'%(str(s_now.shape), counter_now, class_num)
            traces[:, ens_num, n_trial] = s_now[:, counter_now]
            used_per_key[key] = counter_now + 1

    # X is a transposed view of traces, shape (n_trials, T, p).
    X = traces.transpose((2, 0, 1))
    return X, traces, samples
    

import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel
plt.close('all')  # close any figures left open from a previous run

#%% create parameters (to change)
names_classes = ['context', 'stimulus']  # label axes ("classes"); each trial gets one label per class
num_neurons = 80
seed = 64  # seeds the label draws; seed + 89 seeds the first A layer
n_ensembles_each = 2  # number of ensembles (columns of A) per class
num_trials = n_trials = 250
# Inclusive per-class label ranges: labels are drawn with randint(min, 1 + max).
min_each_class = [-2, 0]
max_each_class = [2, 1]
cont_axis = [0]  # stored in the metadata; NOTE(review): create_basis_patterns is called with a literal [0], not this variable
thres_0_percentile = 0.6  # fraction (< 1, converted to percent below) or percent; entries of A below this percentile are zeroed

T =  500  # number of time points per trial in the data Y
normalize_A_style = 'avg'  # forwarded to normalize_A_columns
traces_positive = True  # if any trace value is negative, shift all traces so the minimum is 0

# #%% parameters related to saving
# Take one timestamp so `full_date` and `ss` describe the same instant.
# NOTE(review): assumes `datetime2` (from a star import above) is `datetime.datetime`.
now_for_saving = datetime2.now()
full_date = str(now_for_saving)
# Microsecond part of the timestamp. Parsing it out of str(now) raised a ValueError
# whenever microsecond == 0, because str() then omits the fractional part.
ss = now_for_saving.microsecond
full_date = full_date.replace('-','_').replace(':', '_').replace('.','_')
today = full_date.split()[0]  # date part only

# Outputs go to <cwd>/generated_synthetic_data/<today>; runs on the same day share it.
save_path = os.getcwd() + os.sep + 'generated_synthetic_data' + os.sep + '%s'%today 
print('save path is %s'%save_path)
os.makedirs(save_path, exist_ok = True)
save_name_metadata = 'metadata_%s'%today
to_plot = True
A_phi_same_scale = True  # rescale traces to the scale of A (see the rescaling cell)




#%% calculations based on parameters
# Accept the threshold as a fraction and convert it to the percent np.percentile expects.
if thres_0_percentile < 1:
    thres_0_percentile = thres_0_percentile*100

# Class name of every ensemble column of A, e.g. ['context', 'context', 'stimulus', 'stimulus'].
ensembles_names = np.repeat(names_classes, n_ensembles_each)
# class name -> array of its ensemble column indices
ensembles_names2index = {class_name : np.where(ensembles_names == class_name)[0] for class_name in names_classes}
index2ensemble_names = {ind : ens for ind, ens in enumerate(ensembles_names)}


num_axes = num_classes = len(names_classes)
np.random.seed(seed)
# Per class: one random integer label in [min, max] (inclusive) for every trial.
labels_each_class = {class_name : np.random.randint(min_each_class[j], 1+ max_each_class[j] , size = num_trials) 
                     for j, class_name in enumerate(names_classes)}
# Per trial: tuple of its labels, ordered like names_classes.
labels_tuples = [tuple((labels_each_class[class_name][trial] for class_name in names_classes)) for trial in range(num_trials)]
# NOTE(review): ordering of the unique label tuples is defined by the external helper make_labels_unique_order.
labels_unique_order = make_labels_unique_order(labels_tuples, make_array = False)
# 'trial_<k>' -> indices of that trial's label tuple in labels_unique_order
# (i.e. the layer of A_tensor_multi_MILCCI that trial uses)
trial_number2index_in_label_unique_order = {'trial_%d'%trial_num : find_indices_in_list(labels_unique_order, label_tup) for trial_num, label_tup in enumerate(labels_tuples)} # i.e. this is the layer we need to take from A


num_unique_labels_per_class = {class_name : len(np.unique(np.vstack(labels_tuples)[:, class_num])) for class_num, class_name in enumerate(names_classes)}
labels_tuples_unique = labels_unique_order #sorted(list(set(labels_tuples)))
# Sorted unique label values per class; layer j of a class's A tensor corresponds to value j here.
values_unique_labels_each_var = {class_name : np.sort(np.unique(np.vstack(labels_tuples_unique)[:,class_num])) for class_num, class_name in enumerate(names_classes)}

n_ensembles = num_axes * n_ensembles_each
num_unique_labels = len(labels_tuples_unique)

#%% regularization parameters (also can change, but depends on the above)

noise_sigma_between_As = 0.1 # std of the Gaussian noise added when deriving one A layer from others; higher -> ensembles change more
nu_basic = 1/noise_sigma_between_As*0.01  # = 0.01 / noise_sigma_between_As
nu = [nu_basic*n_ensembles]
# NOTE(review): these keys are interpreted by the external create_basis_patterns; the
# descriptions below are the author's and are not verifiable from this file.
params_basis_pattern = {'wind_size': 5, # wind size defines how many trials to take before and after (overall wind_size*2+1 trials)
                        'weight_func':'linear', # how the weight is distributed across trials
                        'weight_min': 0, #'value_nu_fixed', #'nu', 
                        'weight_max': nu_basic,
                        'one_or_two_sides': -1 # -1: connect to the previous trial; 2: both sides; 1: connect to the next trial
                        }


#%% create first ensemble layer
np.random.seed(seed + 89)

#%% pay attention - this is the full A, not the one I expect from multi MILCCI output, but trial wise
# A_full_GT: (neurons, ensembles, trials). Only layer 0 is filled here, with values
# uniform in [0.5, 1); entries below the thres_0_percentile percentile are then zeroed.
A_full_GT = np.zeros((num_neurons, n_ensembles, n_trials)) #num_unique_labels))
A_full_GT[:, :, 0] = np.random.rand(*A_full_GT[:, :, 0].shape)*0.5+0.5
thres_0 = np.percentile(A_full_GT[:, :, 0], thres_0_percentile)
A_full_GT[:, :, 0][A_full_GT[:, :, 0] < thres_0] = 0 # this is A by trials


# now make the multi MILCCI A (less deep): one layer per unique label tuple,
# shape (neurons, ensembles, num_unique_labels); filled in the "create full A" cell.
A_tensor_multi_MILCCI = np.zeros((num_neurons, n_ensembles, num_unique_labels))


# Per class: that class's ensemble columns and its first num_unique_labels_per_class
# layers (one layer per unique label value of the class). Indexing with an index array
# makes these copies, not views, so only layer 0 is non-zero at this point.
A_dict_tensors = {class_name: A_full_GT[:, ensembles_names2index[class_name] , : num_unique_labels_per_class[class_name] ] for class_name in names_classes} # create a dict of subset of ensembles per A. 

if to_plot:
    fig, ax = plt.subplots()
    sns.heatmap(A_full_GT[:, :, 0], center = 0)  # no ax= given: draws on the current axes
    save_fig('A_one_layer',fig,save_path)
    
#%% create labels 
# Integer id <-> label tuple, with ids following labels_tuples_unique order.
numbers2tuples = {j : tup for j, tup in enumerate(labels_tuples_unique)}
tuples2numbers = {tup : j for j, tup in numbers2tuples.items()}
labels = np.array([tuples2numbers[tup] for  tup in labels_tuples])  # integer label id per trial


#%% Now change A with specific directions
## capture graph for variables. i.e. a graph of trial-trial similarity
# NOTE(review): create_basis_patterns is external; it also returns (and here rebinds)
# params_basis_pattern. cont_axis_list is a literal [0], not the cont_axis variable.
basis_pattern, params_basis_pattern, label_distance_to_basis_pattern_values = create_basis_patterns(labels, 
                          numbers2tuples, 
                          cont_labels = [], cont_axis_list = [0],                          
                          params_basis_pattern = params_basis_pattern, 
                          value_nu_fixed = 1, disable_assert = True)


# Per class: distance graph between that class's unique label values (not between trials).
trial_distance_graph = {class_name :  calculate_graph_similarity(label_class) for class_name, label_class in values_unique_labels_each_var.items()} # this will be used to define how much I will change A. 
# Map every distance to its weight via label_distance_to_basis_pattern_values, then
# 0.1 + 2*weight. np.vectorize(dict.get) yields None for missing keys, so this assumes
# every distance value is a key. The +0.1 offset presumably keeps weights non-zero,
# which the layer-update loop asserts.
trial_similarity_graph = {class_name :0.1 + 2*np.vectorize(label_distance_to_basis_pattern_values.get)(distance_graph) for class_name, distance_graph in trial_distance_graph.items()}

#%% now start updating next layers based on the labels graph. 
# Iteratively smooth each class's A tensor across its layers (one layer per unique label
# value of that class). Every layer is replaced by a similarity-weighted average of
# layers plus Gaussian noise, made non-negative with abs(), and re-sparsified by zeroing
# entries below the thres_0_percentile percentile.
repeats = 10
for repeat in range(repeats):
    print('repeat %d'%repeat)
    for class_num, class_name in enumerate(names_classes):
        cur_A = A_dict_tensors[class_name]  # (neurons, ensembles of this class, layers); updated in place
        num_layers = cur_A.shape[2]

        for layer in range(num_layers):
            # Similarity of this layer to every layer. Copy the row: writing the
            # self-weight into a view mutated trial_similarity_graph, so the self-weight
            # doubled on every repeat (2**repeats growth) and the saved graph was corrupted.
            cur_graph = trial_similarity_graph[class_name][layer, :].copy()
            cur_graph[layer] = np.max(cur_graph)*2  # self-weight: twice the largest similarity

            assert (cur_graph != 0).all(), 'there should not be 0 in cur_Graph'
            assert len(cur_graph) == num_layers
            assert cur_graph.sum() > 0
            if layer == 0 and repeat == 0:
                # layer 0 was initialised randomly before this loop
                assert (cur_A[:,:,layer] != 0).any(), 'something is wrong, how you get 0?'
                next_A = cur_A[:,:,layer].copy()
            elif repeat == 0:
                # First pass: this layer is still all zeros, so build it from the
                # already-filled earlier layers only.
                assert (cur_A[:,:,layer] == 0).all(), "cur_A must be of 0s!"
                former_layers = np.arange(layer)
                other_As = cur_A[:, :, former_layers]
                other_As_weights = cur_graph[former_layers]

                assert other_As_weights.sum() > 0

                next_A = np.sum(other_As*(other_As_weights/other_As_weights.sum()).reshape((1,1,-1)), 2)
                # NOTE(review): seeds coincide for some (repeat, layer) pairs, e.g. (1, 0) and (0, 1).
                np.random.seed(repeat + 5*class_num + layer**2)
                next_A = np.abs(next_A + np.random.randn( *next_A.shape )*noise_sigma_between_As)
                assert not np.isnan(next_A).any(), 'nan in A1!'
            else:
                # Later passes: re-weigh using all layers, including ones already updated in this pass.
                next_A = np.sum(cur_A*(cur_graph/cur_graph.sum()).reshape((1,1,-1)), 2)
                np.random.seed(repeat + 5*class_num + layer**2)
                next_A = np.abs(next_A + np.random.randn( *next_A.shape )*noise_sigma_between_As)
                assert not np.isnan(next_A).any(), 'nan in A1!'

            # re-sparsify
            thres_0 = np.percentile(next_A, thres_0_percentile)
            next_A[next_A < thres_0] = 0 
            cur_A[:,:,layer] = next_A

        A_dict_tensors[class_name] = cur_A
       
            
if to_plot:
    # One row per class, one column per layer (label value). Size the grid by the largest
    # layer count so that no class can overflow it; squeeze=False keeps `ax` 2-D.
    max_layers = max(A_dict_tensors[class_name].shape[2] for class_name in names_classes)
    fig, ax = plt.subplots(num_classes, max_layers, figsize = (40,5*num_classes), squeeze = False)
    for class_num, class_name in enumerate(names_classes):
        for j in range(A_dict_tensors[class_name].shape[2]):
            sns.heatmap(A_dict_tensors[class_name][:,:,j], ax = ax[class_num, j])
    fig.tight_layout()
    save_fig('As_per_class_changes_over_trials', fig, save_path)

if to_plot:
    # One row per class, one column per ensemble: neurons x layers for that ensemble.
    fig, ax = plt.subplots(num_classes, n_ensembles_each, figsize = (20, 5*num_classes), squeeze = False)
    for class_num, class_name in enumerate(names_classes):
        for j in range(A_dict_tensors[class_name].shape[1]):
            sns.heatmap(A_dict_tensors[class_name][:,j,:], ax = ax[class_num, j], center = 0)
    fig.tight_layout()
    # Distinct name: reusing the name above overwrote the per-layer figure.
    save_fig('As_per_class_changes_over_trials_per_ensemble', fig, save_path)





# Call the function
# Assign each ensemble to a label axis. With the current settings np.repeat gives
# [0, 0, 1, 1], the swap makes it [0, 1, 0, 1], and the last ensemble becomes random (-1).
# NOTE: gp_shared_by_axis_labels reseeds the global RNG with its `seed` argument before
# drawing, so the seed set here does not influence the GP traces.
np.random.seed(544)
axis_assignments = np.repeat(np.arange(num_classes), n_ensembles_each)
axis_assignments[1], axis_assignments[2] = axis_assignments[2],axis_assignments[1]
axis_assignments[-1] = -1

# Pass T explicitly so the traces have the same number of time points as the data Y;
# otherwise the function silently falls back to its own default T.
# NOTE(review): length_scale and noise_level are accepted but not used by the function.
X, traces, samples = gp_shared_by_axis_labels(labels_tuples, axis_assignments, p = n_ensembles, n_trials = n_trials, length_scale=0.1, noise_level=1e-6, seed=2, T = T)

print("X shape:", X.shape)        # (n_trials, T, p)
print("traces shape:", traces.shape)  # (T, p, n_trials)


#%%
# Unique label values per class, and a colour per value (used for trace plots and legends).
unique_labels_from_tuples = {class_name  :  np.unique(np.vstack(labels_tuples)[:, class_num]) for class_num, class_name in enumerate(names_classes)}
optional_colors = ['r','g','b','orange']  # only stored in the metadata
cmaps = ['hsv', 'gist']  # NOTE(review): not used in the visible code
# One extra colour is requested and the last one dropped — presumably because 'hsv' is
# cyclic, so this keeps the first and last labels distinguishable (verify create_colors).
color_dict =  {class_name  :
               {el : col for el, col in zip(unique_labels_from_tuples[class_name], create_colors(len(unique_labels_from_tuples[class_name])+1 , cmap = 'hsv', style = 'cmap')[:len(unique_labels_from_tuples[class_name] )]) }
               for class_name in names_classes}

# Derive from names_classes instead of hard-coding 2, so changing the classes cannot desync it.
num_classes = len(names_classes)



#%% normalize A
# Normalise the columns of every class tensor with the external normalize_A_columns and
# keep the first element it returns (presumably the normalised tensor — verify).
A_dict_tensors = {class_name : normalize_A_columns(A_now, normalize_A_style = normalize_A_style, epsilon = 10**(-9) )[0] for class_name, A_now  in A_dict_tensors.items()}



#%%  create full A
# For every trial, build its A (\tilde{A}) by stacking, class by class, the layer of that
# class's tensor that matches the trial's label value. Store it per trial in A_full_GT
# and once per unique label tuple in A_tensor_multi_MILCCI.

for trial_num, label_tup in enumerate(labels_tuples):
    
    # build \tilde{A}
    tilde_A_now = []
    
    for class_num, class_name in enumerate(names_classes):
        class_val = label_tup[class_num] # this trial's label value for this class (e.g. the odor)
        cur_class_unique_values = values_unique_labels_each_var[class_name] # position of a value here = layer index in the class's A tensor
        index_current_class_val = np.where(cur_class_unique_values == class_val)[0]
        assert len(index_current_class_val) == 1
        cur_A_part =  A_dict_tensors[class_name][:,:,index_current_class_val[0]]
        tilde_A_now.append(cur_A_part)
    tilde_A_now = np.hstack(tilde_A_now)  # classes side by side -> (neurons, n_ensembles)
    assert tilde_A_now.ndim == 2
    assert tilde_A_now.shape[1] == n_ensembles
    A_full_GT[:,:,trial_num] = tilde_A_now
    
    # Find which layer (depth) of A_tensor_multi_MILCCI this label tuple maps to.
    label_depth_index_for_A = trial_number2index_in_label_unique_order['trial_%d'%trial_num]
    assert len(label_depth_index_for_A) == 1, 'something is wrong...'
    label_depth_index_for_A = label_depth_index_for_A[0]
    # A non-zero layer means an earlier trial with the same tuple already filled it; it
    # must match. NOTE: this prints once for every repeated tuple.
    if np.sum(A_tensor_multi_MILCCI[:, :, label_depth_index_for_A] ) != 0:
        print('found already')
        assert np.all(A_tensor_multi_MILCCI[:, :, label_depth_index_for_A] == tilde_A_now), 'something is off?'
    else:
        A_tensor_multi_MILCCI[:, :, label_depth_index_for_A] = tilde_A_now
  
    
# Optionally shift all traces by one global constant so that none are negative.
if traces_positive:
    if (traces.flatten() < 0).any():
        traces = traces - np.min(traces)
    
# Rescale the traces so their 98th percentile equals A_full_GT's 98th percentile, putting
# A and the traces (phi) on a comparable scale.
# NOTE(review): assumes np.percentile(traces, 98) is non-zero. `samples` returned by the
# GP call is neither shifted nor rescaled.
if A_phi_same_scale:
    A_full_GT_max = np.percentile(A_full_GT,98)
    traces = traces / np.percentile(traces,98)*A_full_GT_max
    
    
    
if to_plot:
    # One panel per (class, ensemble) showing every trial's trace, coloured by the
    # trial's label value for that class.
    fig, axs = plt.subplots(num_classes, n_ensembles, figsize = (30, 10), squeeze = False)
    for class_num , class_name in enumerate(names_classes):
        create_legend(color_dict[class_name], save_path = save_path, save_addi = 'legend_ensembles_%s'%class_name, params_leg = {'title': class_name})
        # colour of each trial for this class (the same for every ensemble panel)
        trial_colors = [color_dict[class_name][labels_tuples[trial][class_num]] for trial in range(n_trials)]
        for ens in range(n_ensembles):
            ens_ax = axs[class_num , ens]
            for trial, color_now in enumerate(trial_colors):
                ens_ax.plot(traces[:, ens, trial], color = color_now )
            # Set labels once per panel (was re-applied for every trial); assumes
            # add_labels only sets labels/title, i.e. is idempotent — verify.
            add_labels(ens_ax, xlabel = 'Time', ylabel = 'Traces (colored by %s)'%class_name, title = '$A_{%d}$; %s'%(ens, class_name))
    for panel_ax in axs.flatten():
        remove_edges(panel_ax)
    fig.tight_layout()
    save_fig('traces_colored_by_Class', fig, save_path, formats = ['png'])
    

if to_plot:
    # Histogram of pairwise correlations between ensemble traces, pooled over trials.
    # Only the strict upper triangle of each correlation matrix is used: the diagonal is
    # trivially 1 (previously zeroed and counted, piling spurious mass at 0) and the
    # lower triangle duplicates the upper one.
    pair_rows, pair_cols = np.triu_indices(traces.shape[1], k=1)
    corr_between_ensembles_list = []
    for trial in range(traces.shape[2]):
        corr_between_ensembles = np.corrcoef(traces[:,:, trial].T)  # (ensembles, ensembles)
        corr_between_ensembles_list.extend(corr_between_ensembles[pair_rows, pair_cols].tolist())
    
    fig, ax = plt.subplots()
    sns.histplot(corr_between_ensembles_list, ax = ax, alpha = 0.5) 
    add_labels(ax, xlabel = 'Correlation between ensembles', ylabel ='# of correlations', title = 'Histogram of correlation of ensemble traces, should be lower than 1!')
    save_fig('correlations_between_ensemble_traces', fig, save_path)
    
    
    
#%% Calculate Sparsity and Similarity Level
sparsity_thres = 0.001  # |entry| below this counts as zero
# NOTE(review): per-layer sparsity is a COUNT of near-zero entries, while the overall
# value is a FRACTION; both are saved in the metadata as-is.
real_sparsity_per_layer =  [(np.abs(A_full_GT[:, :, layer]) < sparsity_thres).sum() for layer in range(A_full_GT.shape[2])]
real_sparsity_overall =  (np.abs(A_full_GT) < sparsity_thres).mean()
# (trials x trials) mean squared difference between per-trial A matrices
# (named "l2 distance", but it is the mean of squared differences, without a sqrt).
real_l2_distance_between_layers = np.vstack([[ np.mean((A_full_GT[:, :, layer] - A_full_GT[:, :, layer2])**2) for layer in range(A_full_GT.shape[2])]
                                     for layer2 in range(A_full_GT.shape[2])])  # this is a graph of how similar layers are
# Same measure, per class, between the layers of that class's A tensor.
real_l2_distance_layer_and_condition =  {key: np.vstack([[ np.mean(  (val[:,:,layer] - val[:,:,layer2])**2) for layer in range(val.shape[2]) ]
                                               for layer2 in range(val.shape[2]) ])                                   
                                               for key, val in A_dict_tensors.items()}


#%% create data
# Y[:, :, trial] = A_trial (neurons x ensembles) @ traces_trial.T (ensembles x T)
Y = np.zeros((num_neurons, T, n_trials))
for trial_num, label_tup in enumerate(labels_tuples):
    Y[:,:,trial_num] = A_full_GT[:,:, trial_num] @ traces[:, :, trial_num].T

if to_plot:
    num_trials_plot = min(20, n_trials)  # never index past the last trial
    num_rows = 4
    num_cols = int(np.ceil(num_trials_plot/num_rows))

    fig, axs = plt.subplots(num_rows, num_cols, figsize = (30,30), sharex = True, sharey = True, squeeze = False)
    axs = axs.flatten()
    for plot_idx in range(num_trials_plot):
        sns.heatmap(Y[:,:,plot_idx], ax = axs[plot_idx], center = 0 )
    save_fig('data', fig, save_path, formats = ['png'])

terms = ['neuron_%d'%el for el in range(num_neurons)]

#%% save data
# Everything needed to reproduce/inspect the generation run.
metadata_dict = {'path_directory': os.getcwd(), 
    'cur_path_synth': cur_path_synth, 'save_path': save_path,
    'names_classes': names_classes, 'num_neurons': num_neurons, 'seed': seed,
    'n_ensembles_each': n_ensembles_each, 'num_trials': num_trials, 'n_trials': n_trials,
    'min_each_class': min_each_class, 'max_each_class': max_each_class, 'cont_axis': cont_axis,
    'thres_0_percentile': thres_0_percentile, 'T': T, 'traces_positive': traces_positive,
    'full_date': full_date, 'ss': ss, 'today': today, 'save_name_metadata': save_name_metadata,
    'to_plot': to_plot, 'ensembles_names': ensembles_names,
    'ensembles_names2index': ensembles_names2index, 'index2ensemble_names': index2ensemble_names,
    'labels_each_class': {k: v.tolist() for k,v in labels_each_class.items()}, 'labels_tuples': labels_tuples,
    'num_unique_labels_per_class': num_unique_labels_per_class, 'labels_tuples_unique': labels_tuples_unique,
    'values_unique_labels_each_var': {k: v.tolist() for k,v in values_unique_labels_each_var.items()},
    # NOTE(review): hard-coded literals, not the GP call's arguments (the call passes
    # noise_level=1e-6; the GP function ignores length_scale/noise_level, and its
    # seed=2 is not recorded).
    'length_scale': 0.1, 'noise_level': 1e-5, 'sigma': 0.15,
    'nu_basic': nu_basic, 'nu': nu, 'params_basis_pattern': params_basis_pattern, 'noise_sigma_between_As': noise_sigma_between_As,
    'axis_assignments': axis_assignments,
    'n_ensembles': n_ensembles, 'num_axes': num_axes, 'num_classes': num_classes, 'num_unique_labels': num_unique_labels,
    'A_full_GT_shape': A_full_GT.shape,
    'basis_pattern': basis_pattern ,
    'label_distance_to_basis_pattern_values': label_distance_to_basis_pattern_values,
    'trial_distance_graph': trial_distance_graph, 'trial_similarity_graph': trial_similarity_graph,
    'X_shape': X.shape , 'traces_shape': traces.shape, 'Y_shape': Y.shape,
    'unique_labels_from_tuples': unique_labels_from_tuples,
    'optional_colors': optional_colors, 'color_dict': color_dict, 'A_dict_tensors':A_dict_tensors, 'traces':traces, 'full_A_GT': A_full_GT,
    'trial_number2index_in_label_unique_order':trial_number2index_in_label_unique_order, 
    'real_sparsity_per_layer': real_sparsity_per_layer, 
    'real_sparsity_overall': real_sparsity_overall,
    'real_l2_distance_between_layers': real_l2_distance_between_layers , # this is a graph of how similar layers are
    'real_l2_distance_layer_and_condition': real_l2_distance_layer_and_condition,
    'sparsity_thres':sparsity_thres, 'A_phi_same_scale':A_phi_same_scale,
    'normalize_A_style': normalize_A_style
    }

# The generated dataset: Y is (neurons, T, trials), full_phi (traces) is (T, ensembles,
# trials), full_A is (neurons, ensembles, trials).
data_dict = {'data':Y, 'full_phi':traces, 'full_A': A_full_GT, 'labels':labels, 'numbers2tuples':numbers2tuples, 'tuples2numbers':tuples2numbers, 
             'labels_tuples':labels_tuples, 'terms':terms, 'n_ensembles_each':n_ensembles_each, 
             'labels_unique_order':labels_unique_order   ,          # these are the labels for A!
             'A_tensor_multi_MILCCI': A_tensor_multi_MILCCI # this is N_neurons X ensembles X num_unique_labels
             }

data_full_dict_to_save = {'data': data_dict, 'metadata':metadata_dict}


# np.save stores this dict as a pickled 0-d object array; load it back with
# np.load(path, allow_pickle=True).item(). The file name contains only the date, so a
# second run on the same day overwrites it.
np.save(save_path + os.sep + 'data_and_metadata_%s.npy'%today , data_full_dict_to_save)
print('saved synthetic data in %s'%(save_path + os.sep + 'data_and_metadata_%s.npy'%today))













