# -*- coding: utf-8 -*-
"""
Options for the grannet kernel, selected with "kernel_grannet_type":
    - "one_kernel"
    - "averaged" - average with the other kernels (this requires an additional parameter, 'params_weights')
    - "combination" - a combination of a shared kernel and individual kernels for the different states
    - "ind" - independent kernels
    Code that needs to be changed for these options:
        - wherever mkDataGraph appears for grannet
        - the lambda calculation
Options for "type_kernel":
    - 'shared_flex' = the same kernel for different trials but different nets (A varies across trials)
    - 'shared' = the same kernel and the same nets: one kernel for all trials of a state, and exactly the same A
    

"""


#%% Imports
global ask_selected
from scipy import ndimage
try:
    from skimage.morphology import square
    from skimage import  io
except ImportError:
    import os
    if os.path.expanduser("~") + os.sep in os.getcwd():
        print('Ignored ImportError for from skimage.morphology import square (home directory)')
    else:
        raise
try:
    from skimage.morphology import erosion, opening, dilation
except ImportError:
    import os
    if os.path.expanduser("~") + os.sep in os.getcwd():
        print('Ignored ImportError for from skimage.morphology import erosion, opening, dilation (home directory)')
    else:
        raise
from PIL import Image
from scipy.signal import convolve2d
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import jaccard_score
import os
try:
    import mat73
except ImportError:
    if os.path.expanduser("~") + os.sep in os.getcwd():
        print('Ignored ImportError for import mat73 (home directory)')
    else:
        raise
import scipy.io as sio
from scipy.optimize import nnls

from sklearn.decomposition import PCA
from scipy.optimize import nnls

try:
    from qpsolvers import solve_qp #https://scaron.info/doc/qpsolvers/quadratic-programming.html#qpsolvers.solve_qp https://pypi.org/project/qpsolvers/
except ImportError:
    if os.path.expanduser("~") + os.sep in os.getcwd():
        print('Ignored ImportError for from qpsolvers import solve_qp (home directory)')
    else:
        raise

import matplotlib
import numpy as np
from scipy import linalg
import matplotlib.pyplot as plt
import itertools
import pandas as pd
import seaborn as sns
import random
from datetime import date
import os.path
import warnings
from scipy.optimize import nnls
import numbers
from sklearn import linear_model
import sys
from sklearn.decomposition import DictionaryLearning, sparse_encode


######################################################################
change_path = False
    
from MILCCI_basic_functions import *
#import basic_functions  
#basic_functions = reload(basic_functions)
#from basic_functions import *

print('called basic functions!')
######################################################################
    

ask_selected = False
in_local = True
try:
    import pylops
except Exception:
    # pylops is optional: keep loading this module without it. Catch Exception
    # (not a bare except) so KeyboardInterrupt / SystemExit still propagate.
    print('did not load pylops')
#from PIL import Image


import networkx as nx
from datetime import datetime as datetime2


    
    
# Seed numpy's global RNG from the current wall-clock microsecond (0..999999).
# `.microsecond` equals the value the previous `int(str(now).split('.')[-1])`
# parse produced, but it does not crash when the microsecond is exactly 0
# (str(datetime) then omits the '.ffffff' part entirely).
ss = datetime2.now().microsecond
print(ss)
seed = ss

print(seed)

np.random.seed(seed)


from sklearn.model_selection import train_test_split, cross_val_score, cross_val_predict, StratifiedKFold
from sklearn.preprocessing import LabelEncoder
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
import numpy as np
import random


from sklearn.utils import resample

from scipy.stats import chi2_contingency
from sklearn.metrics import confusion_matrix


from scipy.optimize import linear_sum_assignment
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, confusion_matrix
import numpy as np
from sklearn.metrics import balanced_accuracy_score

def transpose_dict(dict_in_dict):
    """
    Swap the two key levels of a dict of dicts.

    The result satisfies ``out[inner_key][outer_key] == dict_in_dict[outer_key][inner_key]``.
    Inner keys appear in the output in the order they are first encountered.

    Raises AssertionError if any value of ``dict_in_dict`` is not a dict.
    """
    swapped = {}
    for outer_key, inner in dict_in_dict.items():
        assert isinstance(inner, dict), "inner component must be a dictnionary"
        for inner_key, leaf in inner.items():
            swapped.setdefault(inner_key, {})[outer_key] = leaf
    return swapped



from scipy.spatial.distance import cdist
def best_column_alignment_score_for_multi_MILCCI(A_real, phi_real,
                                                 A_hat, phi_hat, metric='cosine',
                                                 single_level=True,
                                                 ensembles_names=[], max_cost=10**9,
                                                 labels_tuples=[]
                                                 ):
    """
    Reorder the ensembles (axis 1) of ``A_hat``/``phi_hat`` to best match ``A_real``/``phi_real``.

    Ensembles are compared through their rank-one reconstructions
    ``outer(A[:, ens, trial], phi[:, ens, trial])``. The cost is the mean squared
    difference ('l2') or the scipy cosine distance of the flattened matrices
    ('cosine'). The matching is solved with the Hungarian algorithm.

    Parameters
    ----------
    A_real, A_hat : ndarray
        Indexed as ``A[:, ens, trial]``. If the last axis does not match the
        corresponding phi's last axis, it is expanded with
        ``make_A_hat_full(A, labels_tuples)``.
    phi_real, phi_hat : ndarray
        Indexed as ``phi[:, ens, trial]``; the last axis is the trial axis.
    metric : str
        'cosine' or 'l2' (case-insensitive).
    single_level : bool
        True: a separate assignment is solved for every trial.
        False: one assignment is solved on costs summed over trials.
    ensembles_names : sequence, optional
        One group name per ensemble. Only ensembles with equal names may be
        matched; other pairs cost ``max_cost``. Empty (default) allows every
        pair to be matched.
    max_cost : float
        Cost assigned to forbidden (different-name) pairs.
    labels_tuples : sequence
        Trial labels forwarded to ``make_A_hat_full`` when A is not full.

    Returns
    -------
    (A_hat_reordered, phi_hat_reordered, col_ind(s), cost_matrix, new_score)
        For ``single_level=True``, ``col_ind`` is a list with one assignment per
        trial and ``cost_matrix`` is (n_ens, n_ens, n_trials). Otherwise it is a
        single assignment and ``cost_matrix`` is (n_ens, n_ens). ``col_ind[i]``
        is the hat ensemble matched to real ensemble ``i``.

    Raises
    ------
    ValueError
        If ``metric`` is not 'l2' or 'cosine'.
    """
    # ---- estimated factors: make sure A has one layer per trial
    assert len(A_hat) > 0 and len(phi_hat) > 0, "calculate_reco_multi_MILCCI_given_labels_tuples(A, phi, labels_tuples, unique_labels_tuples = [])"
    assert A_hat.shape[1] == phi_hat.shape[1], 'ensemble structure must match but %d vs %d' % (A_hat.shape[1], phi_hat.shape[1])
    if A_hat.shape[2] != phi_hat.shape[2]:
        A_hat = make_A_hat_full(A_hat, labels_tuples)
    assert A_hat.shape[1] == phi_hat.shape[1]

    # ---- ground-truth factors: same treatment
    assert len(A_real) > 0 and len(phi_real) > 0, "calculate_reco_multi_MILCCI_given_labels_tuples(A, phi, labels_tuples, unique_labels_tuples = [])"
    assert A_real.shape[1] == phi_real.shape[1], 'ensemble structure must match but %d vs %d' % (A_real.shape[1], phi_real.shape[1])
    if A_real.shape[2] != phi_real.shape[2]:
        A_real = make_A_hat_full(A_real, labels_tuples)
    assert A_real.shape[2] == phi_real.shape[2]
    assert A_real.shape[1] == phi_real.shape[1]

    assert A_hat.shape[0] == A_real.shape[0], 'n_neurons mismatch %d vs %d' % (A_hat.shape[0], A_real.shape[0])

    metric_name = metric.lower()
    if metric_name not in ('l2', 'cosine'):
        raise ValueError('undefined distance metric')

    n_trials = phi_real.shape[2]
    n_ensembles = phi_real.shape[1]
    if len(ensembles_names) == 0:
        # no grouping supplied: every ensemble may be matched with every other one
        ensembles_names = [None] * n_ensembles

    def _distance(Y_a, Y_b):
        # distance between two reconstructions, according to metric_name
        if metric_name == 'l2':
            return np.mean((Y_a - Y_b) ** 2)
        # cdist needs 2-D inputs (one row per vector) and an explicit metric
        return cdist(Y_a.reshape((1, -1)), Y_b.reshape((1, -1)), metric='cosine')[0, 0]

    if single_level:
        cost_matrix = np.zeros((n_ensembles, n_ensembles, n_trials))
    else:
        cost_matrix = np.zeros((n_ensembles, n_ensembles))

    # hat reconstructions are reused for every real ensemble, so cache them
    memo_Y_hat = {}
    for ens in range(n_ensembles):
        for trial in range(n_trials):
            Y_real_ens = np.outer(A_real[:, ens, trial], phi_real[:, ens, trial])
            for ens_change in range(n_ensembles):
                if ensembles_names[ens] != ensembles_names[ens_change]:
                    # different groups must never be matched (all trials at once when 3-D)
                    cost_matrix[ens, ens_change] = max_cost
                    continue
                key = (trial, ens_change)
                if key not in memo_Y_hat:
                    memo_Y_hat[key] = np.outer(A_hat[:, ens_change, trial], phi_hat[:, ens_change, trial])
                distance = _distance(memo_Y_hat[key], Y_real_ens)
                if single_level:
                    cost_matrix[ens, ens_change, trial] += distance
                else:
                    cost_matrix[ens, ens_change] += distance

    if single_level:
        col_ind_list = []
        A_hat_full_changed = []
        phi_hat_full_changed = []
        new_score = 0
        for trial in range(n_trials):
            trial_cost = cost_matrix[:, :, trial]
            row_ind, col_ind = linear_sum_assignment(trial_cost)
            A_hat_full_changed.append(A_hat[:, col_ind, trial])
            phi_hat_full_changed.append(phi_hat[:, col_ind, trial])
            col_ind_list.append(col_ind)
            new_score += trial_cost[row_ind, col_ind].sum()
        return (np.dstack(A_hat_full_changed), np.dstack(phi_hat_full_changed),
                col_ind_list, cost_matrix, new_score)

    row_ind, col_ind = linear_sum_assignment(cost_matrix)
    A_hat_full_changed = A_hat[:, col_ind, :]
    phi_hat_full_changed = phi_hat[:, col_ind, :]
    new_score = cost_matrix[row_ind, col_ind].sum()
    return A_hat_full_changed, phi_hat_full_changed, col_ind, cost_matrix, new_score
        
            
    
    
    
    

def best_column_alignment_score(A, B, metric='cosine', n_shuffles = 100, what_to_reorder = 'rows' ):
    """
    Find the row order of B (or column order, see ``what_to_reorder``) that best matches A.

    The result is scored against random shuffles of the optimal assignment.

    Parameters
    ----------
    A, B : 2-D array-like or pandas.DataFrame
        Rows (or columns) are the items to match. Both must have the same
        feature length.
    metric : str
        'cosine' (scipy cosine distance), 'l2' (mean squared difference) or
        'corr' (1 - spec_corr).
    n_shuffles : int
        Number of random permutations of the optimal assignment used as a null
        distribution. Permutation k comes from a private ``RandomState(k)``, so
        results are reproducible and numpy's global RNG is left untouched.
    what_to_reorder : str
        'rows' (default), or 'cols'/'columns' to match columns instead. Note
        that this function reorders rows, not columns, unless 'cols' is
        requested.

    Returns
    -------
    col_ind : ndarray
        The row of B matched to each row of A, as returned by
        ``linear_sum_assignment``.
    matched_cost : float
        Mean cost of the optimal matching.
    shuffled_costs_for_comparison : list of float
        Mean cost of each shuffled matching.
    pval_permutations : float
        One-sided p-value ``norm.cdf(z)`` of ``matched_cost`` against the
        shuffled costs (lower cost is better).
    sorted_mat : array or DataFrame
        B reordered by ``col_ind`` (transposed back when reordering columns).
    """
    # local import: `norm` is not imported at the top of this file
    from scipy.stats import norm

    assert metric in ['cosine', 'l2', 'corr'], "metric undefined!"
    if what_to_reorder in ['cols', 'columns']:
        A = A.T
        B = B.T

    # work on plain arrays: iterating a DataFrame yields column labels, not rows
    A_values = np.asarray(A)
    B_values = np.asarray(B)

    # cost[i, j] = dissimilarity between row i of A and row j of B
    if metric == 'cosine':
        cost = cdist(A_values, B_values, metric=metric)
    elif metric == 'l2':
        cost = np.vstack([((B_values - row_A.reshape((1, -1))) ** 2).mean(1).reshape((1, -1))
                          for row_A in A_values])
    elif metric == 'corr':
        cost = 1 - np.vstack([[spec_corr(row_B, row_A, to_abs=False) for row_B in B_values]
                              for row_A in A_values])
    else:
        raise ValueError('not recgnonized: %s' % metric)

    # Hungarian algorithm to find the best matching (minimum total cost)
    row_ind, col_ind = linear_sum_assignment(cost)

    shuffled_costs_for_comparison = []
    for n_shuffle in range(n_shuffles):
        # same permutation as np.random.seed(n_shuffle); np.random.permutation(...),
        # without resetting the module-wide random state
        shuffled_col_ind = np.random.RandomState(n_shuffle).permutation(col_ind)
        shuffled_costs_for_comparison.append(cost[row_ind, shuffled_col_ind].mean())

    matched_cost = cost[row_ind, col_ind].mean()
    mean_shuffled = np.mean(shuffled_costs_for_comparison)
    std_shuffled = np.std(shuffled_costs_for_comparison)

    z = (matched_cost - mean_shuffled) / std_shuffled  # lower is better -> negative z is better
    pval_permutations = norm.cdf(z)  # one-sided p-value

    sorted_mat = B.iloc[np.array(col_ind)] if isinstance(B, pd.DataFrame) else B[np.array(col_ind)]
    if what_to_reorder in ['cols', 'columns']:
        sorted_mat = sorted_mat.T
    return col_ind, matched_cost, shuffled_costs_for_comparison, pval_permutations, sorted_mat




def update_locals_with_dict(d: dict) -> None:
    """
    Intended to turn each ``key: value`` pair of ``d`` into a variable named ``key``.

    NOTE(review): this has no lasting effect on any namespace. Each ``exec`` below
    runs without explicit globals/locals, so its assignment goes into the
    function's default locals mapping. Python documents that mapping as
    "should not be modified"; the assignment is discarded when this function
    returns, and nothing here reads it back. The only observable effects are:
      * evaluating ``repr(value)`` as Python source, including any side effects;
      * exceptions, e.g. a key that is not a valid identifier, or a repr that is
        not valid source. The ndarray branch yields ``np.array(array([...]))``,
        which presumably raises NameError unless some name ``array`` is visible
        in this module's globals — TODO confirm.
    In this file, cal_continous_trials leaves its call to this function commented
    out and writes into ``globals()`` instead.

    SECURITY: ``exec`` compiles code built from ``repr()`` of the dict contents;
    never pass untrusted data here.
    """
    for key, value in d.items():
        if isinstance(value, list):  # Handle lists or arrays as needed
            exec(f"{key} = {repr(value)}")  # assignment lands in exec's throwaway locals
        elif isinstance(value, np.ndarray):  # For numpy arrays
            exec(f"{key} = np.array({repr(value)})")  # Ensure numpy is used
        else:
            exec(f"{key} = {repr(value)}")  # For other types


def project_data_to_As(data_known, full_A_hat_tensor_known, labels = [], labels_unique_order = [], lambda_l2 = 0.01):
    """
    Solve ``data ≈ A @ phi`` for phi, with ridge regularization, given a known A.

    phi is the least-squares solution of the stacked system
    ``[data; 0] ≈ [A; lambda_l2 * I] @ phi`` computed with a pseudo-inverse.
    This minimizes ``||data - A phi||^2 + lambda_l2^2 ||phi||^2``.

    Parameters
    ----------
    data_known : ndarray
        (n_neurons, T), or (n_neurons, T, n_trials).
    full_A_hat_tensor_known : ndarray
        A single A of shape (n_neurons, n_ensembles) or (n_neurons, n_ensembles, 1).
        Alternatively, one A per label: (n_neurons, n_ensembles, n_unique_labels),
        with layer k belonging to ``labels_unique_order[k]``.
    labels : sequence
        Label of every data trial (length n_trials). Used only with a multi-layer A.
    labels_unique_order : sequence
        Label of every A layer. Used only with a multi-layer A.
    lambda_l2 : float
        Ridge weight.

    Returns
    -------
    ndarray
        (n_ensembles, T) when the data has a single layer, otherwise
        (n_ensembles, T, n_trials).
    """
    A_is_single = full_A_hat_tensor_known.ndim == 2 or full_A_hat_tensor_known.shape[2] == 1
    if A_is_single:
        if data_known.ndim == 2 or data_known.shape[2] == 1:
            # OPTION 1: one A, one data matrix -> direct ridge solve
            if full_A_hat_tensor_known.ndim > 2:
                A_2d = full_A_hat_tensor_known[:, :, 0]
            else:
                A_2d = full_A_hat_tensor_known
            data_2d = data_known[:, :, 0] if data_known.ndim > 2 else data_known
            assert data_2d.shape[0] == A_2d.shape[0]
            T = data_2d.shape[1]
            p = A_2d.shape[1]
            left = np.vstack([data_2d, np.zeros((p, T))])
            right = np.vstack([A_2d, lambda_l2 * np.eye(p)])
            return np.linalg.pinv(right) @ left

        # OPTION 2: several data layers sharing one A -> solve each layer
        return np.dstack([
            project_data_to_As(data_known[:, :, layer], full_A_hat_tensor_known,
                               labels=labels, labels_unique_order=labels_unique_order,
                               lambda_l2=lambda_l2)
            for layer in range(data_known.shape[2])
        ])

    # OPTION 3: one A per unique label -> project each trial with its label's A
    assert len(labels) > 0
    assert len(labels_unique_order) > 0
    assert data_known.shape[0] == full_A_hat_tensor_known.shape[0]
    # arrays so that plain Python lists are compared element-wise below
    labels_arr = np.asarray(labels)
    unique_arr = np.asarray(labels_unique_order)
    T = data_known.shape[1]
    n_ensembles = full_A_hat_tensor_known.shape[1]
    return_phis = np.zeros((n_ensembles, T, data_known.shape[2]))
    for unique_label in labels_unique_order:
        ind_A = np.where(unique_arr == unique_label)[0]
        cur_A = full_A_hat_tensor_known[:, :, ind_A]

        ind_data = np.where(labels_arr == unique_label)[0]
        cur_data = data_known[:, :, ind_data]

        assert cur_data.shape[2] >= cur_A.shape[2]
        assert cur_A.shape[2] == 1

        cur_phis = project_data_to_As(cur_data, cur_A, labels, labels_unique_order, lambda_l2)
        if cur_phis.ndim == 2:
            # a single trial comes back 2-D; restore the trial axis
            cur_phis = np.expand_dims(cur_phis, 2)
        return_phis[:, :, ind_data] = cur_phis

    return return_phis
    
    
def keep_one_last_key_option(some_dict, to_include=[], to_match=[]):
    """
    Prune a nested dict down to the branches whose second-to-last key is selected.

    A "second-to-last" key is one whose value is a dict containing no dicts.
    Such a key is kept when:
      * it is in ``to_match`` (exact match), or
      * some string in ``to_include`` is a substring of it, or
      * both ``to_match`` and ``to_include`` are empty (keep everything).
    Either option may be a single string or a list of strings. Intermediate
    levels that end up empty (or otherwise falsy) after pruning are dropped.
    """
    def _as_list(option):
        # a non-empty string becomes a one-element list; anything falsy becomes []
        if isinstance(option, str) and option:
            return [option]
        if not option:
            return []
        return option

    match_list = _as_list(to_match)
    include_list = _as_list(to_include)

    def _is_penultimate(value):
        return isinstance(value, dict) and not any(isinstance(sub, dict) for sub in value.values())

    def _key_selected(key):
        if match_list and key in match_list:
            return True
        if include_list and any(fragment in key for fragment in include_list):
            return True
        return not match_list and not include_list

    def _prune(node):
        if not isinstance(node, dict):
            return node
        kept = {}
        for key, value in node.items():
            if _is_penultimate(value):
                if _key_selected(key):
                    kept[key] = value
                continue
            pruned = _prune(value)
            if pruned:
                kept[key] = pruned
        return kept

    return _prune(some_dict)



    
    
    
    
    

def identify_predictive_power(traces, labels, t_start = None, t_end = None, num_pieces_for_predict = 5, prior_name = '', 
                              prior_info = [], solver = 'logreg', pieces_borders = [],  with_or_ablation = 'with', include_all = False):
    """
    Cross-validated decoding of trial labels from ensemble traces, one ensemble at a time.

    The window ``[t_start, t_end)`` is split into ``num_pieces_for_predict``
    consecutive pieces. Each ensemble's trace is averaged within each piece,
    giving ``n_ensembles * num_pieces_for_predict`` features per trial.

    Parameters
    ----------
    traces : ndarray
        (time, n_ensembles, n_trials).
    labels : sequence
        One label per trial (length n_trials).
    t_start, t_end : int or None
        Time window; the defaults cover the whole trace.
    num_pieces_for_predict : int
        Number of time pieces averaged into features.
    prior_name : str
        If it contains 'with' (but not 'without'), ``prior_info`` is appended as
        one extra feature in every decoder.
    prior_info : sequence
        One value per trial. Required when the prior is used.
    solver : str
        'logreg' or 'rf' (case-insensitive). Forwarded to
        predict_traces_predictive_power.
    pieces_borders : sequence of int, optional
        ``num_pieces_for_predict + 1`` borders relative to ``t_start``. The
        default is evenly spaced borders.
    with_or_ablation : str
        'with' uses only the ensemble's features; 'ablation' uses all features
        except that ensemble's.
    include_all : bool
        Also decode from all features together, stored under key 'ALL'.

    Returns
    -------
    dict
        Maps '$A_<ens>$' (and 'ALL' when ``include_all``) to the dict returned
        by predict_traces_predictive_power.
    """
    if 'without' in prior_name:
        prior_name = ''

    prior_info = np.array(prior_info)
    assert 'with' not in prior_name or len(prior_info) > 0
    use_prior = 'with' in prior_name
    n_trials = len(labels)
    assert traces.shape[2] == len(labels), "number of trials from traces need to match labels duration. But from traces we have %d trials vs %d in labels"%(traces.shape[2], len(labels))

    if t_start is None:
        t_start = 0
    if t_end is None:
        t_end = traces.shape[0]
    dur = t_end - t_start
    traces_now = traces[t_start: t_end, :, :]

    # ---- piece borders (relative to t_start)
    if len(pieces_borders) == 0:
        pieces_borders = np.linspace(0, dur, num_pieces_for_predict + 1).astype(int)
    else:
        assert len(pieces_borders) == num_pieces_for_predict + 1, "len(pieces_borders) != num_pieces_for_predict + 1"
    n_ensembles = traces_now.shape[1]

    # ---- features x trials; row = piece * n_ensembles + ensemble
    data_traces = np.vstack([traces_now[piece_minus: piece_plus, :, :].mean(0)
                             for piece_minus, piece_plus in zip(pieces_borders[:-1], pieces_borders[1:])])
    assert data_traces.shape[0] == n_ensembles * num_pieces_for_predict
    assert data_traces.shape[1] == n_trials

    if use_prior:
        assert len(prior_info) == n_trials
        # the prior becomes the last feature row
        data_traces = np.vstack([data_traces, prior_info.reshape((1, -1))])

    # ---- (ensemble, piece) identity of every feature row
    ensembles_list = np.tile(np.arange(n_ensembles), num_pieces_for_predict)
    t_meaning = np.repeat(np.arange(num_pieces_for_predict), n_ensembles)
    features_names = np.vstack([(ens, t) for ens, t in zip(ensembles_list, t_meaning)])
    if not use_prior:
        assert data_traces.shape[0] == len(ensembles_list)
    else:
        features_names = np.vstack([features_names, np.array(['prior', 'prior']).reshape((1, -1))])
        assert data_traces.shape[0] == len(ensembles_list) + 1

    prediction_scores = {}
    start_ens = -1 if include_all else 0
    for ens in range(start_ens, n_ensembles):
        if ens == -1:
            # 'ALL': every ensemble/piece feature row. These are row indices, not
            # the ensemble ids stored in ensembles_list.
            indices_ens = np.arange(len(ensembles_list))
        elif with_or_ablation == 'with':
            indices_ens = np.where(ensembles_list == ens)[0]
        elif with_or_ablation == 'ablation':
            indices_ens = np.where(ensembles_list != ens)[0]
        else:
            raise ValueError('undefined %s' % with_or_ablation)

        if use_prior:
            # -1 selects the prior row appended above
            indices_ens = np.array(list(indices_ens) + [-1])

        data_traces_now = data_traces[indices_ens]
        features_names_now = features_names[indices_ens]

        name_key = '$A_%d$' % ens if ens != -1 else 'ALL'
        prediction_scores[name_key] = predict_traces_predictive_power(
            data_traces_now, labels, features_names=features_names_now, solver=solver.lower())

    return prediction_scores



def predict_traces_predictive_power(data_traces, labels, cv=5, max_iter=1000, features_names=[], solver='logreg'):
    """
    Stratified k-fold cross-validated classification of ``labels`` from ``data_traces``.

    Parameters
    ----------
    data_traces : ndarray
        (n_features, n_samples). Samples (trials) are the columns.
    labels : sequence
        One true label per sample.
    cv : int
        Number of stratified folds (shuffled, random_state=42).
    max_iter : int
        Iteration limit for the logistic-regression solver.
    features_names : any
        Passed through to the output unchanged.
    solver : str
        'logreg': class-balanced LogisticRegression.
        'rf': class-balanced RandomForestClassifier (100 trees, random_state=42).

    Returns
    -------
    dict
        Keys: 'y_real', 'y_pred' (out-of-fold predictions), 'balanced_acc',
        'accuracy', 'conf_mat', 'xlabel_confmat', 'ylabel_confmat',
        'features_names', and 'feature_importance'. The last is the mean over
        folds of ``coef_[0]`` (logreg) or ``feature_importances_`` (rf).

    Raises
    ------
    ValueError
        If ``solver`` is unknown.
    """
    model_factories = {
        'logreg': lambda: LogisticRegression(max_iter=max_iter, class_weight='balanced'),
        'rf': lambda: RandomForestClassifier(class_weight='balanced', n_estimators=100, random_state=42),
    }
    if solver not in model_factories:
        raise ValueError('undefined solver')
    model = model_factories[solver]()

    splitter = StratifiedKFold(n_splits=cv, shuffle=True, random_state=42)
    label_values = np.array(labels)
    y_pred = np.zeros_like(labels)
    per_fold_importance = []

    for train_idx, test_idx in splitter.split(data_traces.T, labels):
        model.fit(data_traces[:, train_idx].T, label_values[train_idx])
        y_pred[test_idx] = model.predict(data_traces[:, test_idx].T)
        # NOTE(review): for logreg only the first row of coef_ is kept; for a
        # multiclass problem that is presumably one class's weights only — confirm.
        per_fold_importance.append(model.coef_[0] if solver == 'logreg' else model.feature_importances_)

    return {
        'y_real': labels,
        'y_pred': y_pred,
        'balanced_acc': balanced_accuracy_score(labels, y_pred),
        'accuracy': accuracy_score(labels, y_pred),
        'conf_mat': confusion_matrix(labels, y_pred),
        'xlabel_confmat': 'true',
        'ylabel_confmat': 'predicted',
        'features_names': features_names,
        'feature_importance': np.mean(per_fold_importance, axis=0),
    }



def calculate_reco_multi_MILCCI_given_labels_tuples(A, phi, labels_tuples, unique_labels_tuples = [], 
                                                       Y_true = [], return_scores =  True):
    """
    Reconstruct ``Y[:, :, trial] = A_trial @ phi[:, :, trial].T`` for every trial.

    Parameters
    ----------
    A : ndarray
        Either "full", with one layer per trial: (n_neurons, n_ensembles, n_trials).
        Or one layer per unique label tuple: (n_neurons, n_ensembles,
        n_unique_labels), ordered as ``unique_labels_tuples``.
    phi : ndarray
        (T, n_ensembles, n_trials).
    labels_tuples : sequence
        Label tuple of each trial.
    unique_labels_tuples : sequence, optional
        Order of A's layers when A is not full. If empty, it is computed with
        ``make_labels_unique_order(labels_tuples, make_array=False)``; if given,
        it is checked against that function's output.
    Y_true : ndarray
        (n_neurons, T, n_trials). Required when ``return_scores`` is True.
    return_scores : bool
        Also return the reconstruction scores.

    Returns
    -------
    Y_reco, or (Y_reco, l2_score, corr_score) when ``return_scores``.
        Y_reco is (n_neurons, T, n_trials). l2_score is the mean squared error
        against Y_true; corr_score is ``spec_corr`` of the flattened arrays.

    Raises
    ------
    ValueError
        If A's last axis matches neither the number of trials nor the number of
        unique label tuples.
    """
    assert (not return_scores) or len(Y_true) > 0, 'if return_scores, you must provide Y_true'
    if return_scores:
        assert Y_true.shape[2] == phi.shape[2]
        assert Y_true.shape[1] == phi.shape[0]
        assert Y_true.shape[0] == A.shape[0]

    A_is_full = A.shape[2] == phi.shape[2]
    if not A_is_full:
        if len(unique_labels_tuples) > 0:
            assert np.array([el == el2 for el, el2 in zip(unique_labels_tuples, make_labels_unique_order( labels_tuples))]).all(), 'something is wrong in %s vs %s'%(str(unique_labels_tuples), str(make_labels_unique_order( labels_tuples)))
        else:
            unique_labels_tuples = make_labels_unique_order(labels_tuples, make_array=False)

        if A.shape[2] != len(unique_labels_tuples):
            raise ValueError('full A dimension does not make sense: %s' % str(A.shape))

    Y_reco = []
    for trial_num, label_tuple in enumerate(labels_tuples):
        if A_is_full:
            cur_A = A[:, :, trial_num]
        else:
            index_label_tuple = find_indices_in_list(unique_labels_tuples, label_tuple)
            assert len(index_label_tuple) == 1
            # take the single matching layer as a 2-D matrix; indexing with the
            # list would keep a trailing length-1 axis and break the matmul
            cur_A = A[:, :, index_label_tuple[0]]
        Y_reco.append(cur_A @ phi[:, :, trial_num].T)

    Y_reco = np.dstack(Y_reco)
    if return_scores:
        l2_score = ((Y_true - Y_reco) ** 2).mean()
        corr_score = spec_corr(Y_true.flatten(), Y_reco.flatten())
        return Y_reco, l2_score, corr_score
    return Y_reco

def make_A_hat_full(A, labels_tuples, unique_labels_tuples = []):
    """
    Expand an A with one layer per unique label tuple into one layer per trial.

    Parameters
    ----------
    A : ndarray
        (n_neurons, n_ensembles, n_unique_labels), or already
        (n_neurons, n_ensembles, n_trials). In the latter case A is returned
        unchanged (same object).
    labels_tuples : sequence
        Label tuple of each trial.
    unique_labels_tuples : sequence, optional
        Order of A's layers. If empty, it is computed with
        ``make_labels_unique_order(labels_tuples, make_array=False)``; if given,
        it is checked against that function's output.

    Returns
    -------
    ndarray
        (n_neurons, n_ensembles, n_trials), where layer t is A's layer for
        ``labels_tuples[t]``.
    """
    if A.shape[2] == len(labels_tuples):
        print('A is already full.')
        return A
    if len(unique_labels_tuples) > 0:
        assert np.array([el == el2 for el, el2 in zip(unique_labels_tuples, make_labels_unique_order( labels_tuples))]).all(), 'something is wrong in %s vs %s'%(str(unique_labels_tuples), str(make_labels_unique_order( labels_tuples)))
    else:
        unique_labels_tuples = make_labels_unique_order(labels_tuples, make_array=False)

    # report A's own shape (the result is not built yet at this point)
    assert A.shape[2] == len(unique_labels_tuples), 'full A dimension does not make sense: %s' % str(A.shape)

    full_A = []
    for label_tup in labels_tuples:
        inds = find_indices_in_list(unique_labels_tuples, label_tup)
        assert len(inds) == 1
        full_A.append(A[:, :, inds[0]])
    return np.dstack(full_A)
    
    
    
    
    

# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%    
    
def cal_continous_trials(labels, full_phi_hat_3d, full_A_hat, data, cont_labels = [], cont_axis_list = [1], nu_full_each_axes_dict = {}, 
                         dict_params_cont = {}, is_cont = True, numbers2tuples = {}, axis_label_once = True):
    # axis_label_once - if True, it means that we want to update each label only once within a repeat. e.g. if True, and we have label = 1 for (session 1, trial 1) and label = 2 for (session 2, trial 2), we will update the ensembles for the session only once.
    # this is to solve for MILCCI with continous trial structure
    # labels: numbers of labels
    # full_phi_hat_3d: full phi
    # full_A_hat: full A
    # cont_labels: 1d or 2d array of the labels in tuples. i.e. vstack of the tuples
    # cont_axis_list : which axis is the one we need to sort by?
    # dict_params_cont_trial:
    #                         dict_params_cont = {}, is_cont = True, params_basis_pattern = {}
    # dict_params_cont_trial is like      
    #     {'durs': None,'verbose': None,    'lambda_similarity': None,
    #     'nu': None,'n_ensembles': None,'solve_Lasso_style': None,'l1': None,'params_lasso_solver': None,
    #     'seed': None,'with_graph': None,'decor_A': None,'factor_A': None,'func_normalize_A': None,
    #      'identity': None,       'num_unique_conditions': None,'for_each_trial_or_condition': None,'with_nu': None,
    #       'num_repeats': None,'cur_nu_mat_whole': None,'n_neurons': None,'another_update_for_A': None,'cur_A': None,
    #       'l2_phi': None,'T': None , 'numbers2tuples':None   } 
    # The split_A flag controls whether a separate A (ensemble-to-neuron weight matrix) is inferred per condition (e.g. session or trial) or a shared A across all conditions is used.
    
    ############### asserts
    #update_locals_with_dict(dict_params_cont)

    for k, v in dict_params_cont.items():
        globals()[k] = v  
    verbose = dict_params_cont['verbose']
    assert len(numbers2tuples) > 0
    tuples2numbers = {v:k for k,v in numbers2tuples.items()}
    if checkEmptyList(cont_labels):
        cont_labels = np.vstack([numbers2tuples[lab] for lab in labels])
    axes = np.arange(cont_labels.shape[1])
    if not split_A or len(cont_axis_list) > 0:
        assert (is_cont and len(cont_labels) == len(labels)) or (not is_cont), "is not continous?"
        #assert (cont_labels[:, cont_axis_list[0]] == np.sort(cont_labels[:,cont_axis_list[0]])).all(), "continous variable must be sorted!" 
        
    # in the general scenario, the following array will just be = labels. However if we have a different axis (i.e t is not just trial number of some e.g. speed, maybe it will differ if cont_labels is pre-defined)
    cont_labels_in_label_num = np.array([tuples2numbers[tuple(cont_labels[row])] for row in range(len( cont_labels ))]) 
    #####################################    
    
    
    #if is_cont:
    #    labels_unique_order =  cont_labels_in_label_num

    labels_unique_order = make_labels_unique_order(labels)
        
    ########################################################
    if 'num_unique_labels' not in locals():
        num_unique_labels = len(labels_unique_order)
    if split_A: # this means that we split the A by the axis of change
        axes_to_change = axes
        #n_neurons = A_fixed.shape[0]
        full_A_hat_individual = np.repeat(np.expand_dims(full_A_hat, 2), num_unique_labels, axis = 2) #np.zeros((n_neurons, n_ensembles, n_trials)) # TODO change n_trials to more compact
        
        
    else:
        full_A_hat_individual = []
        axes_to_change = ['all']
                        
    ##########################################################

    if axis_label_once:
        axis_label_visited = []
    
        
        
    for c, label in enumerate(labels_unique_order):
        if verbose:
            print(label)

        labels_loc = np.where(labels == label)[0]
        assert len(labels_loc) > 0, "how labels_loc is short? labels is %s and label is %s"%(str(labels), str(label))
        # make the phi as small as possible given the durations. i.e. not need to solve for all since maybe duration is smaller for that label. 
        
        if not checkEmptyList(durs):
            max_dur_for_label = np.max(durs[labels_loc])
            phi_for_label = full_phi_hat_3d[:,:max_dur_for_lab, labels_loc]
        else:
            phi_for_label = full_phi_hat_3d[:,:, labels_loc]
        
        if not split_A:
            phi_for_label_2d = np.vstack([phi_for_label[:,:,trial] for trial in range(phi_for_label.shape[2])]) # this is total T X p (T x num_trials for the condition) X p
            edges = np.linspace(0, phi_for_label_2d.shape[0], len(labels_loc) + 1)
        else:
            phi_for_label_2d = full_phi_hat_3d.copy()
            if not checkEmptyList(durs):
                phi_for_label_2d =  phi_for_label_2d[:,:max_dur_for_lab]
            
            
        #######################################################################################
        if not split_A:
            data_for_label = data[:,:, labels_loc]    
            data_for_label_2d = np.hstack([data[:,:,trial]
                                           for trial in range(data_for_label.shape[2])]) # this is N X total T
        else:
            data_for_label_2d = data.copy()
        
        """
        infer ind A - only if no P
        data_for_label_2d : neurons X (time X trials)
        phi_for_label_2d : (trials X time) X ensembles
        """
        if not with_graph:
            if params_lasso_solver['solver'] == 'inv' or params_lasso_solver['solver'] == 'nnls':
                # assume \hat{A} = \tilde{A} \cdot \nu
                # now look for tilde{A}
                # addi means for individual A
                # we want to solve: addi = arg min \| \|



                for axis_to_update_num, axis_to_update in enumerate(axes_to_change): # which axis to change

                    if split_A:
                        label_axis = numbers2tuples[label][axis_to_update] # what is the label of the current axis. e.g. if axis_to_update is 0, -> session 1? (axis_to_update is a number, label is a number). this gives us for instance 'session 1'
                        if (axis_label_once and label_axis not in  axis_label_visited) or (not axis_label_once):
                            axis_label_visited.append(axis_label_once)
                            to_update = True
                        else:
                            to_update = False # do you want to update the baseline A to accelerate inference? (drawbeck - may create a bias towards initial trials)
                        if to_update:
                            #  y = A phi = A[:,:i] @ phi[:i,:] + A[:,i:] @ phi[i:,:] (where one term is know and the other is unknown) 
                            """
                            IMPORTANT
                            
                            fixed - meaning the ones that we UPDATE NOW. i.e. they are fixed within the condition we are updating. If the indices we are updating now are for session 1, then we want these ensembles to be fixed for all trials of session 1.
                            changing - the indices we are actually NOT updating now. They change in the sense of changing for the current condition (due to other labels)
                            
                            explanation of indices:
                                1) A_fixed_indices - these are the ensembles indices that we update now. e.g. if I want to update the ensembles for session 1, I need to recognize which ensemble indices are SHARED across all trials of that session 1.
                                so the indices here just means which ensemble indices are now fixed and need to be found.
                                2) A_changing_indices - these are the indices that are changing withint trials of that session 1 due to other axes. We are not interested in them but must consider them for inferring the other ensembles
                            
                            """
                            A_fixed_indices = axes2ensembles[axis_to_update] #np.setdiff1d(np.arange(n_ensembles), axes2ensembles[axis_to_update])
                            n_ensembles_fixed = len(A_fixed_indices) # how many ensembles are we  updating now?
                            A_changing_indices = np.setdiff1d(np.arange(n_ensembles), A_fixed_indices)   # which ensembles we DO NOT want to update. these are the ones that are SHARED among all label_axis.
                            n_ensembles_changing = len(A_changing_indices) # how many ensembles are we NOT updating now? (but using)
                                                    
                            
                            
                            all_trials_of_axis_label_to_update_indices = np.array([trial_index 
                                                                                   for trial_index, lab_tuple in enumerate(cont_labels) 
                                                                                   if tuple(lab_tuple)[axis_to_update] == label_axis]) # all the trials of the axis label we need to update
                            
                            all_trials_of_labels_that_differ_from_axis_label = np.setdiff1d(np.arange(n_trials),  all_trials_of_axis_label_to_update_indices)
                            
                            all_condition_index_of_axis_label_to_update_indices = np.array([label_index  
                                                                                   for label_index, lab in enumerate(labels_unique_order) 
                                                                                   if tuple(numbers2tuples[lab])[axis_to_update] == label_axis])
                            all_condition_index_that_differ_from_axis_label =  np.setdiff1d(np.arange(num_unique_labels), all_condition_index_of_axis_label_to_update_indices)
                            if len(all_condition_index_that_differ_from_axis_label) == 0:
                                # this means that there is only 1 unique value for the current axes
                                assert len(np.unique(cont_labels[:, axis_to_update])) == 1, "what is happening? axis to update is %d"%axis_to_update
                                axis_has_unique_val = True
                            else:
                                axis_has_unique_val = False
                            
                            if num_unique_conditions == n_trials:
                                # check that these are the same
                                assert tuple(all_condition_index_of_axis_label_to_update_indices) == tuple(all_trials_of_axis_label_to_update_indices), "something does not make sense"
                                
                            else:
                                # check that they differ
                                assert set(all_condition_index_of_axis_label_to_update_indices) != set(all_trials_of_axis_label_to_update_indices), "something does not make sense"
                                

                            
                            right = []
                            left = []
                            
                            for counter_trial_to_update, (label_index_to_update, trial_to_update) in enumerate(zip(all_condition_index_of_axis_label_to_update_indices, all_trials_of_axis_label_to_update_indices)):
                                #print('phi_for_label_2d.shape %s'%str(phi_for_label_2d.shape))
                                right_now = phi_for_label_2d[:,A_fixed_indices, trial_to_update]
                                # this is the phi part of the unknown term. i.e. the ensemnbles we want to update
                                right.append(right_now)
                                # TODO not c. but concat all trials of to-update axes? 
                                assert full_A_hat_individual[:,A_changing_indices,label_index_to_update].shape[0] == data_for_label_2d.shape[0]
                                extra_on_top_now = full_A_hat_individual[:,A_changing_indices,label_index_to_update] @ phi_for_label_2d[:,  A_changing_indices, trial_to_update].T  # this is the A_known @ phi_known part we know and want to deduct. we need the deduction only for the top part. we are NOT going to update these ensembles here
                                # extra_on_top is of dim n_neurons X T
                                assert tuple(extra_on_top_now.shape) == tuple(data_for_label_2d.shape[:2]), 'shape mismatch: %s %s'%(str(tuple(extra_on_top_now.shape)) ,str(tuple(data_for_label_2d.shape[:2])))
                                left_now = data_for_label_2d[:,:,trial_to_update].T - extra_on_top_now.T
                                # this is the original data part without reduction. the minus extra on top is to acount for the known term
                                left.append(left_now)
                            if not axis_has_unique_val:
                                right = np.vstack(right + [lambda_similarity * np.eye(n_ensembles_fixed)*nu[A_fixed_indices].reshape((1,-1))])
                            else:
                                right = np.vstack(right) # + [lambda_similarity * np.eye(n_ensembles_fixed)*nu[A_fixed_indices].reshape((1,-1))])
                            left = np.vstack(left)
                            if verbose: print('left shape %s'%str(left.shape))
                            
                            if not axis_has_unique_val:
                                # the following snippet builds the ensembles from other conditions of the current axis (e.g. all sessions except for session 1) only for the ensembles indices we now update
                                full_A_hat_individual_in_axis_differ = []
                                for ind in all_condition_index_that_differ_from_axis_label:
                                    full_A_hat_individual_in_axis_differ.append(full_A_hat_individual[:,A_fixed_indices,ind])
                                full_A_hat_individual_in_axis_differ = np.dstack(full_A_hat_individual_in_axis_differ)
                                
                                
                                if verbose: print('full_A_hat_individual[:,A_fixed_indices,:][:,:,all_condition_index_that_differ_from_axis_label].mean(2) %s'%str(full_A_hat_individual_in_axis_differ.mean(2).shape))
                                left_addition = lambda_similarity * (full_A_hat_individual_in_axis_differ.mean(2)* nu[A_fixed_indices].reshape((1,-1))).T 
                                if verbose: print('left_addition.shap %s'%str(left_addition.shape))
                                
                                left = np.vstack([left , left_addition])

                    else:                        
                        n_ensembles_fixed = n_ensembles
                        right = np.vstack([phi_for_label_2d, 
                                    lambda_similarity * np.eye(n_ensembles)*nu.reshape((1,-1))])
                        left = np.vstack([data_for_label_2d.T, 
                                    lambda_similarity * (full_A_hat * nu.reshape((1,-1))).T ])
                    assert decor_A >= 0, 'decor must be non-negative'
                    
                    if decor_A  > 0:
                        # here also adjust to trial concat case
                        right = np.vstack([right, decor_A*(np.ones((n_ensembles_fixed , n_ensembles_fixed)) - np.eye(n_ensembles_fixed))])
                        left = np.vstack([left, np.zeros((n_ensembles_fixed , n_neurons))])
                    
                    left_not_0 = np.where((left.sum(1) != 0) & (right.sum(1) != 0)  )[0] 
                    left = left[left_not_0]
                    right = right[left_not_0]
                    
                    

                    addi = solve_Lasso_style(                        
                    right,
                    left,            
                    l1 = l1, x0 = [], params = params_lasso_solver,
                                                        lasso_params = {},random_state = seed).T                        
                    
                    if split_A:
                        
                        assert len(addi.shape) == 2 
                        assert addi.shape[0] == full_A_hat_individual.shape[0], "shapes in direction 0 do not match, addi.shape[0] %s != full_A_hat_individual.shape[0] %s"%(str(addi.shape[0], str(full_A_hat_individual.shape[0])))
                        assert addi.shape[1] == n_ensembles_fixed, "shapes in direction 0 do not match, addi.shape[1] %s != n_ensembles_fixed %s"%(str(addi.shape[1], n_ensembles_fixed ))
                                                                                                                                                                    
                        

                        for ind in all_condition_index_of_axis_label_to_update_indices:
                            full_A_hat_individual[:, A_fixed_indices,ind] = addi
                        
                        
               

                
                
                
            else:
                # again find tilde A. this is none fixed A case
                raise ValueError('change some things....')
                addi = np.hstack([solve_Lasso_style(
                np.vstack([phi_for_label_2d*nu_inv.reshape((1,-1)) , 
                           lambda_similarity * np.eye(n_ensembles)]), 
    
                np.vstack([data_for_label_2d[n,:].reshape((-1,1)), 
                            lambda_similarity * full_A_hat[n].reshape((-1,1)) * nu.reshape((-1,1)) ]),        
                l1 = l1, x0 = [], params = params_lasso_solver,
                                                   lasso_params = {},random_state = seed).reshape((-1,1)) for n in range(n_neurons)])                 
                
                addi = addi * nu_inv.reshape((-1,1))                  
                full_A_hat_individual.append(addi)
                
        else:
            
            raise ValueError('TODO !')
            # TODO
    if split_A:
        full_A_hat_individual = [full_A_hat_individual[:,:,layer].T for layer in range(full_A_hat_individual.shape[2])]
    
    full_A_hat_individual_tensor = np.dstack(full_A_hat_individual).transpose([1,0,2]) # neurons by ensembles by trials
    if verbose: print('mmax val of ensembles is %d'%np.nanmax(full_A_hat_individual_tensor))
    if func_normalize_A != identity and (not enable_regular_MILCCI_global):
        summations_1d = func_normalize_A(np.abs(full_A_hat_individual_tensor), axis = 0) 
        
        assert np.array(summations_1d  > 0.0000000001).all(), "pay attention, some ensembles are completely 0. Check data normalizationand values. maybe allow less ensembles?"        
        
        summations = (np.expand_dims(summations_1d, 0) + 10**-18 ) / factor_A
        if verbose: 
            print('summations')
            print(summations)
        full_A_hat_individual_tensor = full_A_hat_individual_tensor / summations
        assert np.array(np.abs(full_A_hat_individual_tensor).sum(0) < 2*factor_A).all(), "how did you get such high_values of full_A_hat? %s"%np.abs(full_A_hat_individual_tensor).sum(0)
        if verbose:
            print('mmax val of ensembles2 is %d'%np.nanmax(full_A_hat_individual_tensor))
        
    """
    now if multi-condition - further change
    
    """
    if is_cont:
        cond_array = np.arange(len(cont_labels.copy())) # these are the indices of the labels. the all labels unique tuples (i.e. not of one axis)
        array_to_iterate_on = cont_labels_in_label_num.copy() # these are the numbers of unique labels in tuples. i.e. 1->(box,odor), 2-> (box2, odor 2)
    else:
        cond_array = np.arange(num_unique_conditions)
        array_to_iterate_on = labels_unique_order
        
    if for_each_trial_or_condition in ['m_condition', 'cont_trial'] and another_update_for_A:
        """
        the following block 
        """
        if not with_nu:
            raise ValueError('you must provide nu if m_condition!')

        
        for repeat in range(num_repeats):
            for c, label in enumerate(array_to_iterate_on): 
                """
                # how similar each condition is to each other condsidenring the axes label. e.g. (box,odor), (box2,odor) considering box is 0
                # this (below) is a dictionary of {label : (tuple)}. Each value is a matrix of # ensembles X unique labels,
                # such that if val = nu_full_each_axes_dict[0], then val[j, u] means weather ensemble j  is forced to be similar under labels j and u (e.g. does ensemble 4 forced to be similar under (box,odor), (box2,odor) )?
                """
                assert len(nu_full_each_axes_dict) > 0, "nu_full_each_axes_dict must not be empty"
                cur_nu_mat_whole = nu_full_each_axes_dict[label] 

                """
                build As of all other matrices 
                """
               
                if not split_A:
                    if not is_cont: # i.e. discrete labels  
                        non_c = cond_array[cond_array != c]      
                        if len(non_c) != num_unique_conditions - 1: 
                            raise ValueError('number of non_c must be equal to num_unique_conditions - 1 but non_c of duration %d vs num_unique_conditions-1 = %d'%( len(non_c), num_unique_conditions - 1))            
                        cur_nu_mat_non_c = cur_nu_mat_whole[:,non_c] # this is a matrix of # ensembles X all labels that are not the current one
                        full_A_not_in_label = full_A_hat_individual_tensor[:,:,non_c]
                        raise ValueError('how?! should be cont!')
                    else:                    
                        non_c = []
                        for ax in axes: 
                            if ax in cont_axis_list:
                                #non_c_now = cond_array[(np.abs(cond_array - c) <=  params_basis_pattern['wind_size']) & (cond_array != c) & ]
                                # cond_array is array of indices
                                # c is an index
                                # non_c_now means which label-indices should be taken now?!
                                non_c_now = cond_array[(np.abs(cont_labels[:,ax] - cont_labels[c,ax]) <=  params_basis_pattern['wind_size']) & (cond_array != c) & (cont_labels[:,ax] != cont_labels[c,ax])]  # change to by values, not by index
                                non_c.append(non_c_now)
                            else:
                                non_c = cond_array[cond_array != c] 
                            assert len(non_c) > 0 
                            
                        non_c = np.sort(list(set.intersection(*map(set, non_c))))
                        cur_nu_mat_non_c = cur_nu_mat_whole[:,non_c] # this is a matrix of # ensembles X all labels that are not the current one
                        full_A_not_in_label = full_A_hat_individual_tensor[:,:,non_c]




                    assert full_A_not_in_label.shape[2] == num_unique_conditions - 1 , 'num unique conditions is %d'%(num_unique_conditions - 1)
                    full_A_not_in_label_vstack = np.vstack([full_A_not_in_label[:,:,layer].T     for layer in range(num_unique_conditions - 1)]) # this is (p * (num_conds - 1)) X N                
                    assert  full_A_not_in_label_vstack.shape[1] == n_neurons,   'something went wrong! full_A_not_in_label_vstack.shape[1] = %d, while n_ensembles = %d'%(full_A_not_in_label_vstack.shape[1], n_neurons)
                    nus_list = [cur_nu_mat_non_c[:,col].reshape((-1,1)) for col in range(num_unique_conditions - 1)]
                    nus_list_vstack = np.vstack([nu_i.reshape((-1,1)) for nu_i in nus_list]).reshape((-1,1))
                    left_nus = np.vstack([np.diag(nu_i.flatten()) for nu_i in nus_list]) 
                    # this is (pX(num conditions-1)) X p. e.g. the first (top, number 0) pXp matrix is a daigonal matrix whoch j,j entry is whether the j-th ensemble need to be similar between current label and label 0
                    assert left_nus.shape[0] == nus_list_vstack.shape[0]
                    assert full_A_not_in_label_vstack.shape[0] == nus_list_vstack.shape[0]
                    # full_A_not_in_label_vstack is (p * (num_conds - 1)) X N
                    if left_nus.shape[1] != n_ensembles:
                        raise ValueError('left_nus.shape[1] = %d != n_ensembles = %d'%(left_nus.shape[1] , n_ensembles))
                    

                #######################################################################################
                phi_for_label = full_phi_hat_3d[:,:, labels_loc]
                if not split_A:
                    phi_for_label_2d = np.vstack([phi_for_label[:,:,trial] for trial in range(phi_for_label.shape[2])]) # this is total T X p (T x num_trials for the condition) X p
                    
                else:
                    phi_for_label_2d = full_phi_hat_3d.copy()
                    if not checkEmptyList(durs):
                        phi_for_label_2d =  phi_for_label_2d[:,:max_dur_for_lab]
                        
                edges = np.linspace(0, phi_for_label_2d.shape[0], len(labels_loc) + 1)
                

                
                
                #######################################################################################
                if not split_A:
                    labels_loc = np.where(labels == label)[0]
                    data_for_label = data[:,:, labels_loc]    
                    data_for_label_2d = np.hstack([data[:,:,trial]
                                                   for trial in range(data_for_label.shape[2])]) # this is N X total T
                else:
                    data_for_label_2d = data.copy()
                
                
                for axis_to_update_num, axis_to_update in enumerate(axes_to_change): # which axis to change
                    if split_A:
                        label_axis = numbers2tuples[label][axis_to_update] # what is the label of the current axis. e.g. if axis_to_update is 0, -> session 1?
                        if (axis_label_once and label_axis not in  axis_label_visited) or (not axis_label_once):
                            axis_label_visited.append(axis_label_once)
                            to_update = True
                        else:
                            to_update = False # do you want to update the baseline A to accelerate inference? (drawbeck - may create a bias towards initial trials)
                        if to_update:
                            #  y = A phi = A[:,:i] @ phi[:i,:] + A[:,i:] @ phi[i:,:] (where one term is know and the other is unknown) 
                            """
                            explanation of indices:
                                1) A_fixed_indices - these are the ensembles indices that we update now. e.g. if I want to update the ensembles for session 1, I need to recognize which ensemble indices are SHARED across all trials of that session 1.
                                so the indices here just means which ensemble indices are now fixed and need to be found.
                                2) A_changing_indices - these are the indices that are changing withint trials of that session 1 due to other axes. We are not interested in them but must consider them for inferring the other ensembles
                            
                            """
                            A_fixed_indices = axes2ensembles[axis_to_update] # what are the indices of the ensembles that we are currently after? e.g. if (session1, trial1) and we want to update the axis of session now, and given 2 ensembles in each axis, the indices will be (0,1)
                            n_ensembles_fixed = len(A_fixed_indices)
                            A_changing_indices = np.setdiff1d(np.arange(n_ensembles), A_fixed_indices) # which ensembles we DO NOT want to update. these are the ones that are SHARED among all label_axis.
                            n_ensembles_changing = len(A_changing_indices) # how many ensembles are we updating now?                                                    
                                                        
                            all_trials_of_axis_label_to_update_indices = np.array([trial_index 
                                                                                   for trial_index, lab_tuple in enumerate(cont_labels) 
                                                                                   if tuple(lab_tuple)[axis_to_update] == label_axis]) # all the trials of the axis label we need to update
                            
                            all_trials_of_labels_that_differ_from_axis_label = np.setdiff1d(np.arange(n_trials),  all_trials_of_axis_label_to_update_indices) 
                            
                            
                            all_condition_index_of_axis_label_to_update_indices = np.array([label_index  
                                                                                   for label_index, lab in enumerate(labels_unique_order) 
                                                                                   if tuple(numbers2tuples[lab])[axis_to_update] == label_axis])
                            all_condition_index_that_differ_from_axis_label =  np.setdiff1d(np.arange(num_unique_labels), all_condition_index_of_axis_label_to_update_indices)
                            if num_unique_conditions == n_trials:
                                # check that these are the same
                                assert tuple(all_condition_index_of_axis_label_to_update_indices) == tuple(all_trials_of_axis_label_to_update_indices), "something does not make sense"
                                
                            else:
                                # check that they differ
                                assert set(all_condition_index_of_axis_label_to_update_indices) != set(all_trials_of_axis_label_to_update_indices), "something does not make sense"
                                
                                

                            
                            right = []
                            left = []
                            
                            for counter_trial_to_update, (label_index_to_update, trial_to_update) in enumerate(zip(all_condition_index_of_axis_label_to_update_indices, all_trials_of_axis_label_to_update_indices)):
                                #print('phi_for_label_2d.shape %s'%str(phi_for_label_2d.shape))
                                right_now = phi_for_label_2d[:,A_fixed_indices, trial_to_update]
                                # this is the phi part of the unknown term. i.e. the ensemnbles we want to update
                                right.append(right_now)
                                # TODO not c. but concat all trials of to-update axes? 
                                print('phi_for_label_2d.shape %s'%str(phi_for_label_2d.shape))
                                print('full_A_hat_individual_tensor.shape %s'%str(full_A_hat_individual_tensor.shape))
                                assert full_A_hat_individual_tensor[:,A_changing_indices,label_index_to_update].shape[1] == phi_for_label_2d[:,  A_changing_indices, trial_to_update].shape[1], "A and phi should match in shape 1."
                                
                                extra_on_top_now = full_A_hat_individual_tensor[:,A_changing_indices,label_index_to_update] @ phi_for_label_2d[:,  A_changing_indices, trial_to_update].T  # this is the A_known @ phi_known part we know and want to deduct. we need the deduction only for the top part. we are NOT going to update these ensembles here
                                # extra_on_top is of dim n_neurons X T
                                assert tuple(extra_on_top_now.shape) == tuple(data_for_label_2d.shape[:2]), 'shape mismatch: %s %s'%(str(tuple(extra_on_top_now.shape)) ,str(tuple(data_for_label_2d.shape[:2])))
                                left_now = data_for_label_2d[:,:,trial_to_update].T - extra_on_top_now.T
                                # this is the original data part without reduction. the minus extra on top is to acount for the known term
                                left.append(left_now)
                                
                            right = np.vstack(right + [lambda_similarity * np.eye(n_ensembles_fixed)*nu[A_fixed_indices].reshape((1,-1))]) # regularization over trials in A
                            
                            left = np.vstack(left)
                            print('left shape %s'%str(left.shape))
                            full_A_hat_individual_in_axis_differ = []
                            for ind in all_condition_index_that_differ_from_axis_label:
                                full_A_hat_individual_in_axis_differ.append(full_A_hat_individual_tensor[:,A_fixed_indices,ind])
                            full_A_hat_individual_in_axis_differ = np.dstack(full_A_hat_individual_in_axis_differ)
                            print('full_A_hat_individual[:,A_fixed_indices,:][:,:,all_condition_index_that_differ_from_axis_label].mean(2) %s'%str(full_A_hat_individual_in_axis_differ.mean(2).shape))
                            left_addition = lambda_similarity * (full_A_hat_individual_in_axis_differ.mean(2)* nu[A_fixed_indices].reshape((1,-1))).T 
                            print('left_addition.shap %s'%str(left_addition.shape))
                            
                            left = np.vstack([left , left_addition])
                        
                        # now this is left original but deductable
                    
                        assert decor_A >= 0, 'decor must be non-negative'
                        
                        if decor_A  > 0:
                            #heft here also adjust to trial concat case
                            right = np.vstack([right, decor_A*(np.ones((n_ensembles_fixed , n_ensembles_fixed)) - np.eye(n_ensembles_fixed))])
                            left = np.vstack([left, np.zeros((n_ensembles_fixed , n_neurons))])
                        
                        left_not_0 = np.where((left.sum(1) != 0) & (right.sum(1) != 0)  )[0] 
                        left = left[left_not_0]
                        right = right[left_not_0]
                        
                        
                        # update A
                        addi = solve_Lasso_style(                        
                        right,
                        left,            
                        l1 = l1, x0 = [], params = params_lasso_solver,
                                                            lasso_params = {},random_state = seed).T                        
                        
                  
                        # TODO hythyt remove 0 from split a
                        # in this case we got a very narrow A, which means we want to get the full one now.
                        
                        assert len(addi.shape) == 2 
                        assert addi.shape[0] == full_A_hat_individual_tensor.shape[0], "shapes in direction 0 do not match, addi.shape[0] %s != full_A_hat_individual.shape[0] %s"%(str(addi.shape[0], str(full_A_hat_individual_tensor.shape[0])))
                        assert addi.shape[1] == n_ensembles_fixed, "shapes in direction 0 do not match, addi.shape[1] %s != n_ensembles_fixed %s"%(str(addi.shape[1], n_ensembles_fixed ))
                                                                                                                                                                    
                        

                        for ind in all_condition_index_of_axis_label_to_update_indices:
                            full_A_hat_individual_tensor[:, A_fixed_indices,ind] = addi
                        
                        
                
                        
                    else:                    
                        
                        if params_lasso_solver['solver'] == 'inv':              
                            print( 'full_A_not_in_label_vstack.shape: %s'%str(full_A_not_in_label_vstack.shape)) # should be of shape number of  (ensembles * not current state) Xneurons
                            print( 'left_nus.shape: %s'%str(left_nus.shape))    
                            left = np.vstack([phi_for_label_2d, 
                                        lambda_similarity * left_nus])                        
                            right = np.vstack([data_for_label_2d.T, 
                                        lambda_similarity * full_A_not_in_label_vstack*nus_list_vstack  ])  # nus_list_vstack is a column vector of length (ensembles * (num_unique_conditions - 1))
                                                                              
                            ##################### do not take time points that are completely zero    
                            left_not_0 = np.where((left.sum(1) != 0) & (right.sum(1) != 0)  )[0] 
                            left = left[left_not_0]
                            right = right[left_not_0]
                            
                            addi = solve_Lasso_style(                        
                            left, right, l1 = l1, x0 = [], params = params_lasso_solver,
                                                                lasso_params = {},random_state = seed)
                            
                            assert np.all(np.abs(addi).sum(0) < 2*factor_A), "values of full_A_hat_individual_tensor will be too large!"
                            
                            full_A_hat_individual_tensor[:,:,c] = addi.T
                  
        
                            
                        else: # this is the 2nd time we update A
                            print('other update here!')
                            # limit a
                            addi = []
                            for n in range(n_neurons):
                                left = np.vstack([phi_for_label_2d, 
                                            lambda_similarity * left_nus])
                                
                                right = np.vstack([data_for_label_2d[n,:].reshape((-1,1)) , 
                                            lambda_similarity * full_A_not_in_label_vstack[:,n].reshape((-1,1))  ])
                                
                                ##################### do not take time points that are completely zero
        
                                left_not_0 = np.where((left.sum(1) != 0) & (right.sum(1) != 0)  )[0] 
                                left = left[left_not_0]
                                right = right[left_not_0]
                                
                                
                                
                                addi_cur = solve_Lasso_style(
                                left,                 
                                right,           
                                l1 = l1, x0 = [], params = params_lasso_solver,
                                                                   lasso_params = {},random_state = seed).reshape((-1,1))
                                assert np.all(np.abs(addi).sum(0) < 2*factor_A), "values of full_A_hat_individual will be too large!"
                                
                                
                                
                                
                                addi.append(addi_cur)  
                            addi = np.hstack(addi)     # addi is p X N
                        
                    
                    
                            full_A_hat_individual_tensor[:,:,c] = addi.T 
                
        
    print('max val of ensembles is %d'%np.nanmax(full_A_hat_individual_tensor)) # h the problem is 
    """
    now update phi 
    """
    if another_update_for_phi:
        if l2_phi > 0:
            # data is neurons X time X trials. 
            print('data shape is %s'%str(data.shape))
            data_extend = np.vstack([data, np.zeros((n_ensembles,T,data.shape[2]))])
            l2_phi_comp = l2_phi*np.repeat(np.expand_dims(np.eye(n_ensembles),2), full_A_hat_individual_tensor.shape[2], axis = 2)
            assert l2_phi_comp.shape == (n_ensembles,n_ensembles,full_A_hat_individual_tensor.shape[2])
            print('full_A_hat_individual_tensor.shape is %s'%str(full_A_hat_individual_tensor.shape))
            full_A_hat_individual_tensor_extend = np.vstack([full_A_hat_individual_tensor, l2_phi_comp  ])

            
        else:
            data_extend = data.copy()
            full_A_hat_individual_tensor_extend = full_A_hat_individual_tensor.copy()
            print('full_A_hat_individual_tensor.shape is %s'%str(full_A_hat_individual_tensor.shape))
            
        full_phi_hat_3d_new = []    
        for c, label in enumerate(labels_unique_order):
            cur_A = full_A_hat_individual_tensor_extend[:, :, c]

            
            where_label = np.where(labels == label)[0]
            full_phi_hat_3d_label = np.dstack([np.vstack([solve_Lasso_style(cur_A, data_extend[:,t,ind].flatten(),
                                                                          **params_update_phi).reshape((1,-1)) for t in range(data.shape[1])])
                                 for ind in where_label])
            full_phi_hat_3d_new.append(full_phi_hat_3d_label)
            
        full_phi_hat_3d_updated = np.dstack(full_phi_hat_3d_new).copy()    
    return full_A_hat_individual_tensor, full_phi_hat_3d_updated, full_A_hat_individual


    
    
def plot_confusion_matrix(preds, y_test, title='Confusion Matrix', fig=[], ax=[], plot_params= {'edgecolor':'black', 'lw':3, 'linecolor':'black', 
                                                                                               }, num_to_string = {},plot_normalized = True, with_p_val = True,
                          label_size = 30, cmap = 'cool' ):
    """
    Plot a confusion matrix coloured by row-normalized values (or raw counts),
    annotated with both the normalized value and the raw count of every cell,
    and titled with accuracy, chance level and (optionally) a chi-squared p-value.

    Parameters
    ----------
    preds : array-like
        Predicted labels.
    y_test : array-like
        True labels.
    title : str
        Base title; accuracy / chance (and optionally the p-value) are appended.
    fig, ax : matplotlib Figure / Axes, optional
        Target axes. When `ax` is empty a new square figure is created.
    plot_params : dict or None
        Extra keyword arguments for ``sns.heatmap``. The dict is copied, so
        neither the caller's dict nor the shared default is mutated.
    num_to_string : dict
        Optional mapping from label value to display string for tick labels.
    plot_normalized : bool
        Colour cells by row-normalized values (True) or raw counts (False).
    with_p_val : bool
        Whether to include the chi-squared p-value in the title.
    label_size : int
        Font size of the tick labels.
    cmap : str
        Colormap passed to ``sns.heatmap``.

    Returns
    -------
    fig, ax, acc, chance, cm
        Figure/axes, overall accuracy (trace / total), chance level
        (1 / number of classes) and the raw confusion matrix.
    """
    cm = confusion_matrix(y_test, preds)

    # Row-normalize. A row without true samples gives 0/0 = NaN, which is
    # handled explicitly when annotating, so the warning is silenced here.
    with np.errstate(invalid='ignore', divide='ignore'):
        cm_normalized = cm.astype('float') / cm.sum(axis=1).reshape((-1, 1))

    # Create figure and axis if they're empty
    if checkEmptyList(ax):
        fig, ax = plt.subplots(figsize=(cm.shape[0]*2, cm.shape[0]*2))

    # Copy so that setdefault never mutates the caller's dict / the default dict
    plot_params = {} if plot_params is None else dict(plot_params)
    plot_params.setdefault('edgecolor', 'black')

    # confusion_matrix orders rows/columns by the sorted union of the labels in
    # y_test and preds; the tick labels must follow exactly that order (an
    # unordered set of y_test alone can mislabel or miscount the classes).
    class_labels = np.unique(np.concatenate([np.asarray(y_test).ravel(), np.asarray(preds).ravel()]))
    if len(num_to_string) == 0:
        ticklabels = list(class_labels)
    else:
        ticklabels = [num_to_string[num] for num in class_labels]

    cm_plot = cm_normalized.copy() if plot_normalized else cm.copy()
    sns.heatmap(cm_plot, annot=False, fmt='.2f', cmap=cmap, xticklabels=ticklabels,
                yticklabels=ticklabels, ax=ax, robust = False, vmin = 0, square = True, **plot_params)

    ax.set_xticklabels(ticklabels, fontsize = label_size, rotation = 90)
    ax.set_yticklabels(ticklabels, fontsize = label_size, rotation = 0)

    # Annotate each cell: normalized value on top, raw count in brackets below
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            if not np.isnan(cm_normalized[i][j]):
                norm_text = '%.2f'%cm_normalized[i][j]
            else:
                norm_text = '0'
            ax.text(j + 0.5, i + 0.5, norm_text + '\n'+ '(%s)' % str(cm[i][j]), ha='center', va='center', color='black', fontsize=30)

    # The tiny offset avoids zero expected frequencies inside chi2_contingency
    chi2, p, dof, expected = chi2_contingency(cm+ 1e-10)

    # p-value tells whether the model's performance is statistically significant
    print("Chi-squared: %.3e \n p-value: %.3e" % (chi2, p))

    chance = 1/cm.shape[0]
    acc = np.nansum(np.diag(cm))/np.nansum(cm)
    if with_p_val:
        ax.set_title(title + '\n\n' + 'acc. %.2f \n (chance: %.2f \n p-value: %.3e)'%(acc,chance, p) , fontsize=20)
    else:
        ax.set_title(title + '\n\n' + 'acc. %.2f \n (chance: %.2f)'%(acc,chance) , fontsize=20)
    ax.set_xlabel('Predicted',fontsize=30)
    ax.set_ylabel('True',fontsize=30)
    return fig, ax, acc , chance   , cm

    








def plot_2d(mat, params_fig = {}, fig = [], ax = [], params_plot = {}, type_plot = 'plot'):
    """
    Plot the first row of `mat` against its second row.

    Parameters
    ----------
    mat : array-like
        mat[0] is used as x and mat[1] as y.
    params_fig : dict
        Keyword arguments for ``plt.subplots`` (only used when `ax` is empty).
    fig, ax : matplotlib Figure / Axes, optional
        Target axes; a new 1x1 figure is created when `ax` is empty.
    params_plot : dict
        Keyword arguments forwarded to ``ax.plot`` / ``ax.scatter``.
    type_plot : str
        'plot' draws a line; any other value draws a scatter plot.

    Returns
    -------
    fig, ax
        The figure (the `fig` argument when `ax` was supplied) and the axes used.
    """
    if checkEmptyList(ax):
        fig, ax = plt.subplots(1,1, **params_fig)
    if type_plot == 'plot':
        ax.plot(mat[0], mat[1], **params_plot)
    else:
        ax.scatter(mat[0], mat[1], **params_plot)
    return fig, ax
        
def create_3d_ax(num_rows, num_cols, figsize = (), params = {}):
    """
    Create a num_rows x num_cols grid of 3-D axes.

    Parameters
    ----------
    num_rows, num_cols : int
        Grid shape passed to ``plt.subplots``.
    figsize : tuple
        Used only when non-empty and `params` does not already hold 'figsize'.
    params : dict
        Extra keyword arguments for ``plt.subplots``. It is copied, never mutated.

    Returns
    -------
    fig, ax
        The figure and the axes (array of axes when the grid has >1 cell).
    """
    # Copy first: writing into the shared default dict would make a figsize
    # from an earlier call stick to every later call.
    subplot_kwargs = dict(params)
    if 'figsize' not in subplot_kwargs and len(figsize) > 0:
        subplot_kwargs['figsize'] = figsize
    fig, ax = plt.subplots(num_rows, num_cols, subplot_kw = {'projection': '3d'}, **subplot_kwargs)
    return  fig, ax

def str2bool(str_to_change):
    """
    Interpret a truthy string as a boolean; leave non-strings untouched.

    Strings 'true', 'yes', 't', 'y' (any case) map to True, every other string
    maps to False, and non-string inputs are returned unchanged.

    Example:
        str2bool('true') - > True
    """
    if not isinstance(str_to_change, str):
        return str_to_change
    return str_to_change.lower() in ('true', 'yes', 't', 'y')

#%% Default Parameters 
# Module-level defaults shared by functions in this file.
# NOTE: a `global` statement at module scope has no effect in Python; it only
# signals names intended as shared module state. `params_default` is declared
# here but not assigned in this section.
global epsilon,params_default, instruct_per_selected
epsilon = 1e-5  # small module-level constant; its consumers are outside this view
sep = os.sep  # platform-specific path separator



# Human-readable description/hint for each user-settable parameter key.
# The values are descriptive strings only; presumably shown when the user is
# asked to choose parameters — confirm against callers.
instruct_per_selected = {'epsilon': 'Default tau values to be spatially varying, "tau" in MATLAB = 1'
                         ,'step_s': 'Default step to reduce the step size over time, 0.5' , 'p': 'Number of temporal profiles',
                         'nonneg':'Should be true or false',
                         'step_decay': 'should be around 0.99',
                         'reduceDim':'whether to apply pca before',
                         'solver': 'can be inv, spgl1, omp, IRLS, ista, fista, lasso, for solving Phi',
                         'norm_by_lambdas_vec':'Should be true to consider the weighted lasso',
                         'likely_from': 'poisson or gaussian',
                         'l1':'l1 regularization (rec around 0.01)', 
                         'l4': 'correlation between temporal activity'} 

# Note kept as a bare string literal (no runtime effect): mapping between the
# regularization parameter names used in code (l1..l5) and the paper's symbols.
"""
link of regularization to paper:
    code - paper
    l1 - lambda
    l2 - gamma 1 (frob norm)
    l3 - gamma 3 continuation
    l4 - gamma 2 (diag - correlations)
    l5 - gamma 4 time continues

"""


"""

important - to set the parameters choose type_answer = 1
"""


#%%  g_MILCCI Functions

# Default configuration dict. Keys such as 'alg' ('ball_tree') and 'n_neighbors'
# suggest nearest-neighbour graph construction — confirm against callers.
# NOTE: 'euclidian' is a runtime value read by consumers; do not "fix" its
# spelling without checking them.
global params_config
params_config = {'self_tune':7, 'dist_type': 'euclidian', 'alg':'ball_tree',
                       'n_neighbors':49, 'reduce_dim':False}

    
    
    
    
    
def inverse_dict(dictionary):
    """
    Return a new dict with keys and values swapped.

    When several keys share a value, the key seen last (in iteration order) wins.
    """
    inverted = {}
    for original_key, original_value in dictionary.items():
        inverted[original_value] = original_key
    return inverted
        
    
    
    
def similarity_based_on_identity(vec):
    """
    Calculate a similarity matrix based on the identity of elements in a one-dimensional input vector.
    
    Parameters:
    ----------
    vec : array-like
        A one-dimensional input vector (list, tuple or numpy array). Entries that
        can be converted to float are compared as floats; any other hashable
        entries (e.g. strings) are compared by equality.
    
    Returns:
    -------
    similarity_matrix : numpy.ndarray
        A binary (int) similarity matrix where 1 indicates identical elements
        in the input vector, and 0 indicates non-identical elements.
    numbers_to_vec : dict
        Maps the numeric code used for each entry back to the original value:
        the float value for numeric input, or an integer code (assigned in order
        of first appearance) for non-numeric input.
    
    Raises:
    -------
    ValueError
        If the input vector `vec` is not one-dimensional (as decided by `is_1d`).
    
    Example:
    --------
    >>> similarity_matrix, _ = similarity_based_on_identity(['a', 'b', 'a'])
    >>> print(similarity_matrix)
    [[1 0 1]
     [0 1 0]
     [1 0 1]]
    """
    if not is_1d(vec):
        raise ValueError('vec must be 1d')
    # asarray lets plain lists/tuples take the numeric path as well
    vec_arr = np.asarray(vec).ravel()
    try:
        vec_in_numbers = vec_arr.astype(float)
    except (TypeError, ValueError):
        # Non-numeric labels: equal values get the same integer code. (Using each
        # element's position as its code would make every pair look different.)
        codes = {}
        vec_in_numbers = np.array([codes.setdefault(val, len(codes)) for val in vec_arr])
        numbers_to_vec = {code: val for val, code in codes.items()}
    else:
        numbers_to_vec = {vec_in_numbers[i]: vec_i for i, vec_i in enumerate(vec_arr)}
        print('made float')

    # 1 where two entries are equal (pairwise difference == 0), else 0
    return 1*((vec_in_numbers.reshape((-1,1)) - vec_in_numbers.reshape((1,-1))) == 0), numbers_to_vec

def create_kernel_EEG_by_labels_sim(labels = [],  distance_graph = 2, addi_similarity = 0.2):  
    """
    Build a condition-graph kernel for EEG labels.

    Pairs of entries with identical labels get `distance_graph + addi_similarity`;
    all other pairs get `addi_similarity`. When `labels` is empty, labels are
    loaded from 'labels_EEG_patient_10_xmin_0_xmax_n_ymin_0_ymax_n.npy' in the
    current working directory.
    """
    if checkEmptyList(labels):
        source_labels = np.load('labels_EEG_patient_10_xmin_0_xmax_n_ymin_0_ymax_n.npy')
    else:
        source_labels = labels
    same_label_mask = similarity_based_on_identity(source_labels)[0]
    return addi_similarity + distance_graph * same_label_mask

    
    
def create_kernel_kathy_by_labels_sim(combs = [], same_task = 0.5, same_ipsi = 0.5, only_unique = True, type_grannet = 'kathy3d'):  #create_data_kathy_by_labels_sim(labels = [], same_task = 0.5, same_ipsi = 0.5):
    """
    Build the condition-graph kernel for Kathy's data from '<task>_<ipsi>...' labels.

    Each label is split on '_'; its first field is read as the task id and its
    second as the ipsi/contra id (both cast to int). The kernel entry for a pair
    is `same_task` if the tasks match plus `same_ipsi` if the ipsi ids match.

    Parameters
    ----------
    combs : array-like of str
        Condition labels. If empty, they are loaded from
        'labels_full_<type_grannet>_xmin_0_xmax_n_ymin_0_ymax_n.npy'.
    same_task, same_ipsi : float
        Weights added for a shared task / shared ipsi id.
    only_unique : bool
        If True, apply np.unique to the labels first (sorted unique labels).
    type_grannet : str
        Dataset tag used in the fallback file name.
    """
    if checkEmptyList(combs):
        combs = np.load('labels_full_%s_xmin_0_xmax_n_ymin_0_ymax_n.npy'%type_grannet)

    if only_unique:
        combs = np.unique(combs)
    label_fields = [lab.split('_') for lab in combs]
    task_type = np.array([fields[0] for fields in label_fields]).astype(int)
    ipsi_type = np.array([fields[1] for fields in label_fields]).astype(int)
    task_kernel = same_task * similarity_based_on_identity(task_type)[0]
    ipsi_kernel = same_ipsi * similarity_based_on_identity(ipsi_type)[0]
    return task_kernel + ipsi_kernel
    
    
    
def kathy_from_labels_with_trials_to_numbers(labels_with_trials  = [], to_save = True, type_new = 'str', type_grannet = 'kathy3d')  :
    """
    Strip the trial suffix from per-trial labels, keeping the first two '_' fields.

    Parameters
    ----------
    labels_with_trials : array-like of str
        Labels such as '<task>_<ipsi>_<trial>'. If empty, they are loaded from
        'labels_<type_grannet>_include_trial_xmin_0_xmax_n_ymin_0_ymax_n.npy'.
    to_save : bool
        If True, save the result to both 'labels_full_<type_grannet>_...npy' and
        'labels_<type_grannet>_...npy' in the current working directory.
    type_new : str
        'str' -> '<task>_<ipsi>' strings; anything else -> int('<task><ipsi>').
    type_grannet : str
        Dataset tag used in the file names.

    Returns
    -------
    np.ndarray
        One condition label per input label.
    """
    if checkEmptyList(labels_with_trials):
        labels_with_trials = np.load('labels_%s_include_trial_xmin_0_xmax_n_ymin_0_ymax_n.npy'%type_grannet)
    split_labels = [val.split('_') for val in labels_with_trials]
    if type_new == 'str':
        combs = np.array([fields[0] + '_' + fields[1] for fields in split_labels])
    else:
        combs = np.array([int(fields[0] + fields[1]) for fields in split_labels])
    if to_save:
        for template in ('labels_full_%s_xmin_0_xmax_n_ymin_0_ymax_n.npy', 'labels_%s_xmin_0_xmax_n_ymin_0_ymax_n.npy'):
            np.save(template%type_grannet, combs)
    return combs

    
    
    
    
    
    
    
    
def check_error_stuck(last_errors, params_stuck)    :
    """
    Check whether the optimisation is stuck, judged by how little the error changes.

    Only the last `params_stuck['in_a_row']` errors are inspected. The run is
    considered stuck when every consecutive change among them is smaller (in
    absolute value) than `params_stuck['max_change_ratio'] * last_errors[-1]`.

    Args:
        last_errors (sequence of float): Previous errors, oldest first
            (presumably; the most recent is taken as last_errors[-1]).
        params_stuck (dict): Must contain 'in_a_row' and 'max_change_ratio'.

    Returns:
        bool: True if the algorithm is stuck, False otherwise. With fewer than
        two errors there is no change to measure, so False is returned.
    """
    # Without at least two errors no change can be measured (and [] would
    # otherwise raise IndexError on last_errors[-1]).
    if len(last_errors) < 2:
        return False
    if len(last_errors) > params_stuck['in_a_row']:
        last_errors =  last_errors[-params_stuck['in_a_row']:]
    max_thres = params_stuck['max_change_ratio']*last_errors[-1]
    return bool((np.abs(np.diff(last_errors)) < max_thres).all())
    
def add_noise_if_stuck(last_errors, A, phi, step_GD,  params_stuck) :
    """
    Perturb A and phi with Gaussian noise and rescale the step when optimisation is stuck.

    Parameters
    ----------
    last_errors : sequence of float
        Error history, passed to `check_error_stuck`.
    A, phi : np.ndarray
        Current estimates; returned unchanged when not stuck.
    step_GD : float
        Gradient-descent step; multiplied by params_stuck['change_step'] when stuck.
    params_stuck : dict
        Must contain the keys used by `check_error_stuck` plus 'std_noise_A'
        and 'change_step'.

    Returns
    -------
    A, phi, step_GD
        Possibly perturbed estimates and step size.

    Notes
    -----
    When stuck, the global NumPy RNG is reseeded from the current microsecond.
    """
    if check_error_stuck(last_errors, params_stuck) :
        # Read the microsecond directly; parsing str(now()).split('.') raised
        # ValueError whenever the microsecond was exactly 0 (no '.' is printed).
        # Assumes datetime2 is datetime.datetime (or compatible) — TODO confirm.
        ss = datetime2.now().microsecond
        np.random.seed(ss)
        A = A + np.random.randn(*A.shape)*params_stuck['std_noise_A']
        # NOTE(review): phi also uses 'std_noise_A' as its noise scale — confirm intended
        phi = phi + np.random.randn(*phi.shape)*params_stuck['std_noise_A']
        step_GD *= params_stuck['change_step']
    return A,phi, step_GD
  


def labels_to_nums(labels):
    """
    Number the labels by position.

    Returns
    -------
    dict_nums_labels : dict
        {position: label} for every label, in order.
    list_nums : list of int
        The positions 0 .. len(labels) - 1.
    """
    dict_nums_labels = dict(enumerate(labels))
    return dict_nums_labels, list(dict_nums_labels)


def fine_pos_from_angle(angles):
    """
    Map angles (radians) to points on the unit circle.

    Returns a 2 x len(angles) array whose first row holds the sines and whose
    second row holds the cosines of the angles.
    """
    return np.vstack([list(map(np.sin, angles)), list(map(np.cos, angles))])




def labels2proximity(labels, distance_metric = 'Euclidean', distance2proximity_trans = 'exp', rounded = False , 
                     rounded_max = 360, params = {}): 
    """  
    Build pairwise distance and proximity matrices between labels (trials).

    Parameters
    ----------
    labels : array-like
        For 'distance': either a vector of numeric labels, or a 2-D array whose
        rows are coordinates (one column per trial). For 'kathy': condition
        strings. For 'identity_boolean': any hashable labels.
    distance_metric : {'Euclidean', 'abs'}
        'Euclidean' squares the base distance; 'abs' takes its absolute value.
    distance2proximity_trans : {'exp', 'inv', 'keep'}
        'exp'  -> exp(-distance / sigma), sigma read from
                  params['grannet_params']['sigma_distance2proximity_tran'];
        'inv'  -> 1 / distance;
        'keep' -> the distance itself.
    rounded, rounded_max :
        Not used by this function (kept for interface compatibility).
    params : dict
        Must contain 'labels2distance_type' ('distance', 'kathy' or
        'identity_boolean') and, for 'exp', the sigma entry above.

    Returns
    -------
    distance : np.array of #labels X #labels (#trials X #trials)
        i-th row - what is the distance to the i-th label
        j-th col - the labels count
    proximity :  np.array of #labels X #labels (#trials X #trials)
        i-th row - what is the Proximity to the i-th label
        j-th col - the labels count

    Raises
    ------
    ValueError
        For an unknown labels2distance_type, distance_metric or proximity transform.
    """
    distance_type = params['labels2distance_type']
    if distance_type == 'distance':
        # asarray so plain lists work with flatten/reshape below
        labels_arr = np.asarray(labels)
        if np.max(np.shape(labels_arr)) == len(labels_arr.flatten()):
            # effectively 1-D labels: absolute pairwise differences
            distance_base = np.abs(labels_arr.reshape((1,-1)) - labels_arr.reshape((-1,1)))
        else: # ASSUMING DIFFERENT ROWS OF LABELS HAVE DIFFERENT VALUES LIKE DIFFERENT COORDINATES
            distance_base = np.sqrt(np.sum( np.dstack([(labels_arr[row,:].reshape((1,-1)) - labels_arr[row,:].reshape((-1,1)))**2 for row in np.arange(labels_arr.shape[0])]),2))

        # Replace the zero self-distance by half of the matrix minimum taken after
        # the diagonal was pushed up by 1e3 (normally the smallest off-diagonal value).
        distance_base =  distance_base +10**3*(np.eye( distance_base.shape[0]))
        distance_base =  distance_base - np.diag(np.diag( distance_base)) + 0.5*np.eye( distance_base.shape[0])*np.min( distance_base)
    elif distance_type == 'kathy':
        distance_base = create_kernel_kathy_by_labels_sim(labels , same_task = 0.5, same_ipsi = 0.5, only_unique = True)
    elif distance_type == 'identity_boolean':
        # 1 for identical labels, 2 for different labels
        labels_numbers_dict = {label:num for num, label in enumerate(np.unique(labels))}
        labels_numbers_list = np.array([labels_numbers_dict[label] for label in labels])
        distance_base = 1*( (labels_numbers_list.reshape((1,-1)) - labels_numbers_list.reshape((-1,1))) != 0) + 1
    else:
        raise ValueError("params['labels2distance_type'] must be 'identity_boolean', 'kathy' OR 'distance', but %s"%distance_type)

    if distance_metric == 'Euclidean':
        distance = distance_base ** 2
    elif distance_metric == 'abs':
        # absolute value of the base distances (not of the metric name)
        distance = np.abs(distance_base)
    else:
        raise ValueError('Unknown Distance Metric!')

    if distance2proximity_trans == 'exp':
        proximity = np.exp(-distance/params['grannet_params']['sigma_distance2proximity_tran'])
    elif  distance2proximity_trans == 'inv':
        proximity = 1/distance
    elif distance2proximity_trans == 'keep':
        proximity = distance
    else:
        raise ValueError('Unknown Proximity Metric!')

    return distance, proximity


def lists2list(xss)    :
    """Flatten one level of nesting: concatenate the items of each inner iterable into one list."""
    flat = []
    for inner in xss:
        flat.extend(inner)
    return flat




def createDefaultParams(params = {}):
    """
    Return the default parameter dict overridden/extended by `params`.

    Defaults: step_s=1, learn_eps=0.01, epsilon=2, numreps=2. Entries in
    `params` win over the defaults; neither input dict is modified.
    """
    defaults = dict(step_s=1, learn_eps=0.01, epsilon=2, numreps=2)
    return addKeyToDict(defaults, params)

def createLmabdasMat(epsilonVal, shapeMat):
    """
    Build a matrix of shape `shapeMat` filled from a scalar or a vector of values.

    Parameters
    ----------
    epsilonVal : scalar, or list/tuple/np.ndarray
        A scalar (or a length-1 sequence) fills the whole matrix. A vector whose
        length equals shapeMat[1] is repeated on every row; otherwise a vector
        whose length equals shapeMat[0] is repeated on every column.
    shapeMat : tuple of two ints
        (n_rows, n_cols) of the output.

    Returns
    -------
    np.ndarray
        The matrix of shape shapeMat.

    Raises
    ------
    ValueError
        If a vector's length matches neither dimension of shapeMat.
    """
    is_sequence = isinstance(epsilonVal, (list, tuple, np.ndarray))
    # Unwrap length-1 sequences (0-d arrays have no len(), handled as scalars)
    if is_sequence and np.ndim(epsilonVal) > 0 and len(epsilonVal) == 1:
        epsilonVal = epsilonVal[0]
        is_sequence = isinstance(epsilonVal, (list, tuple, np.ndarray))
    if not is_sequence or np.ndim(epsilonVal) == 0:
        return epsilonVal * np.ones(shapeMat)

    epsilonVal = np.array(epsilonVal)
    if len(epsilonVal) == shapeMat[1]:
        # every row is epsilonVal
        return np.outer(np.ones(shapeMat[0]), epsilonVal)
    if len(epsilonVal) == shapeMat[0]:
        # every column is epsilonVal
        return np.outer(epsilonVal, np.ones(shapeMat[1]))
    raise ValueError('epsilonVal must be either a number or a list/tupe/np.array with the a number of elements equal to one of the shapeMat dimensions')






def addKeyToDict(dictionaryVals,dictionaryPut):
    """Return a new dict with `dictionaryVals` updated by `dictionaryPut` (the latter wins on clashes)."""
    merged = dict(dictionaryVals)
    merged.update(dictionaryPut)
    return merged


def validate_inputs(params):
    """
    Coerce the user-supplied parameters to their expected types.

    The dictionary is converted in place (and also returned). Keys are
    processed in a fixed order; a missing key raises KeyError and a value that
    cannot be converted raises the converter's error (e.g. ValueError).

    Parameters
    ----------
    params : dict;         A dictionary of parameters to be validated.

    Returns
    -------
    dict;         The same dictionary with converted values.
    """
    converters = (
        ('epsilon', float),
        ('step_s', float),
        ('l1', float),
        ('l4', float),
        ('p', int),
        ('nonneg', str2bool),
        ('reduceDim', str2bool),
        ('solver', str),
        ('step_decay', float),
        ('norm_by_lambdas_vec', str2bool),
        ('likely_from', str),
    )
    for key, convert in converters:
        params[key] = convert(params[key])
    return params

def plot_nets_side_by_size(A1,A2, real_axis = 1, ax = [], linewidth = None, linecolor = None,cmap = None, cbar = None):
    """
    Heatmap two sets of networks interleaved side by side for visual comparison.

    Each network (a column of A1/A2 when real_axis == 1, otherwise a row) is
    passed through `norm_vec_min_max` (presumably min-max scaling — confirm)
    and placed next to its counterpart from the other matrix:
    A1[0], A2[0], A1[1], A2[1], ...

    Parameters
    ----------
    A1, A2 : np.ndarray
        Matrices with the same number of networks along `real_axis`.
    real_axis : int
        1 -> networks are columns; any other value -> networks are rows.
    ax : matplotlib Axes, optional
        Target axes; a new figure is created when empty.
    linewidth, linecolor, cmap, cbar :
        Forwarded to ``sns.heatmap``.
    """
    if checkEmptyList(ax): fig, ax = plt.subplots()
    interleaved = []
    if real_axis == 1:
        for net in range(A1.shape[1]):
            interleaved.append(norm_vec_min_max(A1[:, net]).reshape((-1, 1)))
            interleaved.append(norm_vec_min_max(A2[:, net]).reshape((-1, 1)))
        A1_A2 = np.hstack(interleaved)
    else:
        for net in range(A1.shape[0]):
            interleaved.append(norm_vec_min_max(A1[net, :]).reshape((1, -1)))
            interleaved.append(norm_vec_min_max(A2[net, :]).reshape((1, -1)))
        A1_A2 = np.vstack(interleaved)
    sns.heatmap(A1_A2, ax = ax, robust = True, linewidth = linewidth, linecolor = linecolor,cmap =cmap , cbar = cbar )
        
        

        
def create_data_name(data_name = '', xmin = '0', xmax = 'n',ymin = '0',ymax = 'n', type_name = 'data'):
    """
    Compose a standardized .npy file name from a data name and axis limits.

    Every argument is converted with str(), so numbers are accepted as well.

    Parameters:
    data_name (str, optional): The name of the data. Default value is an empty string.
    xmin (str, optional): The lower limit of the x axis. Default value is '0'.
    xmax (str, optional): The upper limit of the x axis. Default value is 'n'.
    ymin (str, optional): The lower limit of the y axis. Default value is '0'.
    ymax (str, optional): The upper limit of the y axis. Default value is 'n'.
    type_name (str, optional): The type of data (file prefix). Default value is 'data'.

    Returns:
    str: '<type_name>_<data_name>_xmin_<xmin>_xmax_<xmax>_ymin_<ymin>_ymax_<ymax>.npy'.

    Example:
    >>> create_data_name('data_sample', '-5', '5', '-10', '10', 'experiment')
    'experiment_data_sample_xmin_-5_xmax_5_ymin_-10_ymax_10.npy'
    """
    return (f'{type_name!s}_{data_name!s}_xmin_{xmin!s}_xmax_{xmax!s}'
            f'_ymin_{ymin!s}_ymax_{ymax!s}.npy')
    

def split_stacked_data(data, T = 0, k = 0):
    """
    Split horizontally stacked blocks into a 3-D stack.

    data is Neurons X (Time X conditions), i.e. [block 1 (N X T), block 2 (N X T), ...].
    PAY ATTENTION: returns N x TIME x CONDITION (if applying for phi in grannet -
    transpose axis 0,1). For 3-D input each data[:, :, i] is split and the
    results are concatenated along axis 2.

    Parameters
    ----------
    data : np.ndarray
        2-D (or 3-D) array whose axis 1 holds k consecutive blocks of length T.
    T : int
        Block length (0 = infer from k).
    k : int
        Number of blocks (0 = infer from T).

    Returns
    -------
    np.ndarray
        Array of shape (N, T, k) (or (N, T, k * data.shape[2]) for 3-D input),
        where out[:, :, i] == data[:, i*T:(i+1)*T].

    Raises
    ------
    ValueError
        If neither T nor k is given, if T*k != data.shape[1], or if the
        inferred T / k is not an integer.
    """
    if len(data.shape) == 3:
        return np.dstack([split_stacked_data(data[:,:,k_global], T, k) for k_global in range(data.shape[2])] )
    if T == 0 and k == 0:
        raise ValueError('you must provide either k or T!')
    n_cols = data.shape[1]
    if T != 0 and k != 0:
        if k*T != n_cols:
            raise ValueError('T*k must be equal to 2nd dim of data (data.shape[1]), but t*k = %d and data.shape[1] = %d'%(T*k, n_cols))
    elif T == 0:
        T = n_cols / k
    else:
        k = n_cols / T
    if k != int(k) or T != int(T):
        raise ValueError('T and k must be ints. but T and k are %s'%str((T,k)))
    k = int(k)
    T = int(T)
    # block i occupies columns [i*T, (i+1)*T)
    return np.dstack([data[:, block*T:(block + 1)*T] for block in range(k)])
        
        
def  create_proximity_matching(data, labels,  increase_diag = True) :
    """
    Estimate a state-by-state dissimilarity matrix by aligning, neuron by neuron,
    the trials of every pair of states with an orthogonal (Procrustes) transform.

    For each neuron and each pair of unique labels, with Y_1 / Y_2 the
    trials X time matrices of the two states, psi = U @ Vh where
    U, s, Vh = svd(Y_2 @ Y_1.T); the RMS of psi @ Y_1 - Y_2 is recorded.
    The residuals are then averaged over neurons.

    Parameters
    ----------
    data : np.ndarray
        The original 3-D data, neurons X time X trials.
    labels : array-like
        Full labels, one per trial (len(labels) == data.shape[2]).
    increase_diag : bool
        If True, add 1e3 to the diagonal and then add the minimum entry of the
        resulting matrix to the diagonal.

    Returns
    -------
    np.ndarray
        (#unique labels) X (#unique labels) matrix ordered like np.unique(labels).

    Raises
    ------
    ValueError
        If len(labels) != data.shape[2].
    """
    # asarray so that `labels == label` below is an element-wise mask for lists too
    labels = np.asarray(labels)
    un_labels = np.unique(labels)
    n_states = len(un_labels)
    n_neurons = data.shape[0]
    sim_mat = np.zeros((n_states, n_states, n_neurons))
    if len(labels) != data.shape[2]:
        raise ValueError('data nd labels have different durations! please provide full labels info')

    # Per state: trials X time X neurons. This does not depend on the neuron,
    # so slice `data` once instead of once per neuron.
    list_datas = [np.transpose(data[:, :, labels == label], (2, 1, 0)) for label in un_labels]

    for neuron in range(n_neurons):
        for first in range(n_states):
            for second in range(first + 1, n_states):
                # try to solve psi Y_1 - Y_2 (orthogonal Procrustes alignment)
                Y_1 = list_datas[first][:, :, neuron]
                Y_2 = list_datas[second][:, :, neuron]
                U, s, Vh = np.linalg.svd(Y_2 @ Y_1.T, full_matrices=False)
                psi = U @ Vh
                dist = np.sqrt(np.mean((psi @ Y_1 - Y_2)**2))
                sim_mat[first, second, neuron] = dist
                sim_mat[second, first, neuron] = dist
    neuron_avg = np.mean(sim_mat, 2)

    if increase_diag:
        neuron_avg = neuron_avg +  np.eye(neuron_avg.shape[0])*10**3
        neuron_avg += np.eye(neuron_avg.shape[0])*np.min(neuron_avg)
    return neuron_avg
            
            
            
    
    
   
    
    
    

def order_A_results(full_A, full_phi):
    """
    Sequentially re-order each layer of ``full_A`` to match the previously
    re-ordered layer, and permute the matching block of ``full_phi`` likewise.

    Layer 0 is the reference. For every later layer k, ``match_times`` is called
    with the previous (already re-ordered) layer, layer k, and the first phi
    slice of block k. Its returned ordering is applied to axis 1 of every slice
    in ``full_phi[:, :, k*m:(k+1)*m]``, where
    ``m = int(full_phi.shape[2] / full_A.shape[2])``.

    Parameters
    ----------
    full_A : np.ndarray
        3-D array with layers along axis 2.
    full_phi : np.ndarray
        3-D array whose axis 2 holds ``m`` consecutive slices per layer of ``full_A``.

    Returns
    -------
    tuple(np.ndarray, np.ndarray)
        The re-ordered A layers and phi blocks, each stacked along axis 2.

    Notes
    -----
    ``match_times`` is defined elsewhere in this file; its exact matching
    semantics are not visible here. NOTE(review): only ``full_phi[:, :, k*m]``
    (the first slice of each block) is passed to it. Confirm that this is intended.
    """
    n_layers = full_A.shape[2]
    phi_per_layer = int(full_phi.shape[2] / n_layers)

    reference_A = full_A[:, :, 0]
    ordered_A_layers = [reference_A]
    ordered_phi_blocks = [full_phi[:, :, :phi_per_layer]]

    for layer in range(1, n_layers):
        start = layer * phi_per_layer
        stop = phi_per_layer * (layer + 1)
        matched_A_T, _, new_order = match_times(reference_A.T, full_A[:, :, layer].T,
                                                full_phi[:, :, start],
                                                enable_different_size = False, add_low_corr = False )
        phi_block = full_phi[:, :, start:stop]
        reordered_phi = np.dstack([phi_block[:, new_order, slice_idx]
                                   for slice_idx in range(phi_block.shape[2])])
        # the freshly aligned layer becomes the reference for the next one
        reference_A = matched_A_T.T
        ordered_A_layers.append(reference_A)
        ordered_phi_blocks.append(reordered_phi)

    return np.dstack(ordered_A_layers), np.dstack(ordered_phi_blocks)


def init_mat(shape, dist_init, multi = 1, params_dist_int = None):
    """
    This function initializes a matrix with a specified distribution.

    Parameters:
    shape (tuple): The shape of the matrix to be initialized.
    dist_init (str): The type of distribution to use for initialization.
                    Options: 'zeros', 'rand', 'randn', 'uniform', 'normal'
    multi (float, optional): A multiplier for the matrix values. Default value is 1.
                             (Has no effect for 'zeros'.)
    params_dist_int (dict, optional): Additional parameters for the distribution.
                                      Only relevant for 'normal' distribution.
                                      Keys 'loc' and 'scale' default to 0 and 1.0;
                                      values given by the caller take precedence.

    Returns:
    ndarray: The initialized matrix.

    Raises:
    ValueError: if dist_init is not one of the supported options.

    Example:
    >>> init_mat((3,3), 'zeros', 2)
    array([[0., 0., 0.],
           [0., 0., 0.],
           [0., 0., 0.]])
    >>> init_mat((3,3), 'normal', 2, {'loc':0, 'scale':1.0})  # doctest: +SKIP
    array([[ 1.70956002, -0.13946417,  1.49056311],
           [ 0.60136805, -1.00341437,  2.572688  ],
           [-0.25894951,  0.47257538,  1.52980708]])
    """
    # Defaults first, caller values last, so a user-supplied 'loc'/'scale' is not overwritten.
    params_dist_int = {'loc': 0, 'scale': 1.0, **(params_dist_int or {})}
    if dist_init == 'zeros':
        A = np.zeros(shape)
    elif dist_init == 'rand':
        A = np.random.rand(*shape) * multi
    elif dist_init == 'randn':
        A = np.random.randn(*shape) * multi
    elif dist_init == 'uniform':
        A = np.random.uniform(0, 1, size = shape) * multi
    elif dist_init == 'normal':
        A = np.random.normal(params_dist_int['loc'], params_dist_int['scale'], size = shape) * multi
    else:
        raise ValueError("unknown dist_init %r; expected one of 'zeros', 'rand', 'randn', 'uniform', 'normal'" % (dist_init,))

    return A

def plot_mid(full_A, full_phi, reco, data):
    """
    Plot every layer of full_A, full_phi and reco as heatmaps (one figure per array).

    Each figure is a single row with one heatmap per slice ``mat[:, :, i]``
    along axis 2, created with ``plt.subplots`` and ``sns.heatmap``.

    Parameters
    ----------
    full_A : np.ndarray
        3-D array; one heatmap per ``full_A[:, :, i]``.
    full_phi : np.ndarray
        3-D array; one heatmap per ``full_phi[:, :, i]``.
    reco : np.ndarray
        3-D array; one heatmap per ``reco[:, :, i]``.
    data : any
        Unused; kept for call compatibility.

    Returns
    -------
    None
    """
    for mat in (full_A, full_phi, reco):
        n_layers = mat.shape[2]
        # squeeze=False keeps `axs` 2-D even when there is a single layer, so it is always iterable
        _, axs = plt.subplots(1, n_layers, figsize = (5*n_layers, n_layers), squeeze = False)
        for i, ax in enumerate(axs[0]):
            sns.heatmap(mat[:, :, i], ax = ax)
    
    
 
    
    
    
    
    
def identity(mat):
    """Return ``mat`` itself, unchanged (the no-op default used for ``func_normalize_A``)."""
    return mat
    
    
    
    
    
    
    
def norm_to_plot(mat_2d, epsilon = 0.01):
    """
    Min-max normalize every column of a 2-D array independently to [0, 1].

    Parameters
    ----------
    mat_2d : np.ndarray
        2-D array; each column ``mat_2d[:, t]`` is rescaled separately.
    epsilon : float, optional
        Currently unused. A constant column therefore yields NaN (0 / 0).

    Returns
    -------
    np.ndarray
        Array of the same shape with every column mapped to [0, 1].
    """
    normalized_columns = []
    for col_idx in range(mat_2d.shape[1]):
        column = mat_2d[:, col_idx]
        col_min = np.min(column)
        scaled = (column - col_min) / (column.max() - col_min)
        normalized_columns.append(scaled[:, None])
    return np.hstack(normalized_columns)


def norm_vec_min_max(vec)  :
    """
    Linearly rescale ``vec`` so its minimum maps to 0 and its maximum to 1.

    No guard is applied: a constant ``vec`` yields NaN (0 / 0).
    """
    lowest = np.min(vec)
    value_range = vec.max() - lowest
    return (vec - lowest) / value_range
    
def normalize_A_columns(full_A, normalize_A_style = 'avg', epsilon = 10**(-9) ):
    """
    Divide every column of A by a per-column magnitude statistic.

    The statistic (mean / sum / max, chosen by ``normalize_A_style``) is taken
    over the ABSOLUTE values of each column, for both 2-D and 3-D input, so
    entries of opposite sign cannot cancel. For 3-D input
    (rows X columns X layers) all layers are stacked vertically first. Each
    column therefore gets one value that is shared by all layers.

    Parameters
    ----------
    full_A : np.ndarray
        2-D or 3-D array.
    normalize_A_style : str, optional
        One of 'avg', 'sum', 'max'. Default 'avg'.
    epsilon : float, optional
        Added to the statistic before dividing, to avoid division by zero.

    Returns
    -------
    normalized : np.ndarray
        ``full_A / (values + epsilon)``, same shape as ``full_A``.
    values : np.ndarray
        Per-column statistic, shape (1, n_cols) for 2-D input or (1, n_cols, 1) for 3-D input.

    Raises
    ------
    ValueError
        If ``normalize_A_style`` is unknown or ``full_A`` is not 2-D / 3-D.
    """
    style_funcs = {'avg': np.mean, 'sum': np.sum, 'max': np.max}
    if normalize_A_style not in style_funcs:
        raise ValueError("normalize_A_style must be one of %s, got %r" % (list(style_funcs), normalize_A_style))
    func = style_funcs[normalize_A_style]

    abs_A = np.abs(full_A)
    if full_A.ndim == 2:
        long_A = abs_A
        shape = (1, -1)
    elif full_A.ndim == 3:
        long_A = np.vstack([abs_A[:, :, layer] for layer in range(full_A.shape[2])])
        shape = (1, -1, 1)
    else:
        raise ValueError('A shape does not make sense!')
    normalize_A_values = func(long_A, axis = 0).reshape(shape)
    return full_A/(normalize_A_values + epsilon), normalize_A_values


def create_basis_patterns(labels,  
                          numbers2tuples, 
                          cont_labels = [], cont_axis_list = [1],                          
                         params_basis_pattern = {}, 
                         value_nu_fixed = 1, 
                         disable_assert = False ):
    """
    THIS FUNCTION JUST PROVIDES A TEMPLATE "I.E. BASIS PATTERN" WE CAN USE AROUND THE MIDDLE TRIAL TO DEFINE HOW MUCH TRIALS NEED TO BE SIMILAR, such that the regularization is proportional to it.
    Create a basis/window pattern to re-weight continuous variables. Not needed if no continuous variables (e.g., split_A only).

    Parameters:
    labels : list or array, Numeric labels corresponding to trials.
    numbers2tuples : dict, Mapping from numeric label to tuple of values.
    cont_labels : array-like, optional, Continuous variable values for each label. Default is [] (built from numbers2tuples).
                  Only validated (its length must equal len(labels)); not used further.
    cont_axis_list : list of int, optional, Indices of axes that are continuous variables. Default is [1]. Not used further.
    params_basis_pattern : dict, optional, Parameters:
        'wind_size' : int or 'all', Number of trials before/after (total window = wind_size*2+1).
                      'all' -> floor((len(labels)-1)/2); a window longer than len(labels) is clipped to that value.
        'weight_func' : str, ['linear','log','exp'], How weights are distributed.
        'weight_min' : float, Minimum weight value. If 0, derived as weight_max / wind_size.
        'weight_max' : float, Maximum weight value. If 0, derived as weight_min * wind_size.
                       The default 'nu' is a placeholder that the caller must replace with a number.
        'one_or_two_sides' : int, [-1,1,2], -1 connect to former, 1 to future, 2 two-sided.
        'to_input_into_function' : dict, optional, extra keyword arguments for np.exp / np.log.
    value_nu_fixed : float, optional, Fixed reference value for nu, must be > weight_min.
    disable_assert : bool, optional, Currently unused.

    Returns:
    basis_pattern : np.array, Array of weights for the basis/window pattern (length wind_size*2+1).
    params_basis_pattern : dict, Updated dictionary of parameters including full window length ('wind_size_full_len').
    label_distance_to_basis_pattern_values : dict, Mapping from signed label distance to corresponding weight value.

    Raises:
    ValueError : if 'weight_min' / 'weight_max' is still a string (e.g. 'nu'), or 'one_or_two_sides' is not in [-1, 1, 2].
    AssertionError : for the remaining invalid parameter combinations checked below.
    """
    # this function just creates the window / template to re-weigh continuous variables by.
    # It is not needed if there are no cont. variables (e.g. for split_A only).
    params_basis_pattern = {**{'wind_size': 5, # how many trials to take before and after (overall wind_size*2+1 trials)
                               'weight_func': 'linear', # how the weight is distributed across trials
                               'weight_min': 0,
                               'weight_max': 'nu',
                               'one_or_two_sides': -1 # -1: connect to the former one; 2: two sides; 1: connect to future
                               },
                            **params_basis_pattern}

    # the input labels here are numbers, not tuples!
    if checkEmptyList(cont_labels):
        cont_labels = np.vstack([numbers2tuples[lab] for lab in labels])
    cont_labels = np.asarray(cont_labels)
    assert  len(cont_labels) == len(labels), "is not continous?"
    # cont_labels / cont_axis_list are only validated and normalized here; nothing below reads them.
    if (cont_labels.ndim == 1) or (np.max(cont_labels.shape) == cont_labels.size):
        cont_labels = cont_labels.reshape((-1, 1))
        cont_axis_list = [0]

    assert params_basis_pattern['weight_func']  in ['linear','log','exp'], "undefined params_basis_pattern['weight_func']  option! %s"%params_basis_pattern['weight_func']

    # Resolve wind_size == 'all' BEFORE wind_size is used to derive weight_min / weight_max
    # (arithmetic with the string 'all' would raise a TypeError).
    if isinstance(params_basis_pattern['wind_size'], str):
        assert params_basis_pattern['wind_size'] == 'all', "params_basis_pattern['wind_size'] must be all if string!"
        params_basis_pattern['wind_size'] = int(np.floor((len(labels)-1)/2))

    # Placeholders such as 'nu' cannot be used numerically here; the caller must resolve them.
    for weight_key in ('weight_min', 'weight_max'):
        if isinstance(params_basis_pattern[weight_key], str):
            raise ValueError("params_basis_pattern['%s'] is %r; resolve it to a number (e.g. np.mean(nu)) before calling create_basis_patterns"
                             % (weight_key, params_basis_pattern[weight_key]))

    assert not (params_basis_pattern['weight_min'] == 0 and params_basis_pattern['weight_max'] == 0), "either params_basis_pattern['weight_max'] or params_basis_pattern['weight_min'] must be != 0!"
    if params_basis_pattern['weight_min'] == 0:
        params_basis_pattern['weight_min'] = params_basis_pattern['weight_max'] / params_basis_pattern['wind_size']
    if params_basis_pattern['weight_max'] == 0:
        params_basis_pattern['weight_max'] = params_basis_pattern['weight_min'] * params_basis_pattern['wind_size']

    # CHECK WINDOW SIZE: clip a numeric window longer than the number of labels.
    # ('all' was already resolved to a size that never triggers this clipping.)
    # NOTE(review): weight_min / weight_max above were derived from the un-clipped wind_size.
    if params_basis_pattern['wind_size']*2+1 > len(labels):
        print('changing wind_size')
        params_basis_pattern['wind_size'] = int(np.floor((len(labels)-1)/2))

    # DEFINE FULL WINDOW LENGTH
    params_basis_pattern['wind_size_full_len'] = params_basis_pattern['wind_size']*2+ 1

    assert value_nu_fixed > params_basis_pattern['weight_min'], "value_nu_fixed must be > params_basis_pattern['weight_min']"
    # one-sided ramp of wind_size weights from weight_min (first entry) to weight_max (last entry)
    basis_pattern = np.linspace( params_basis_pattern['weight_min'], params_basis_pattern['weight_max'], params_basis_pattern['wind_size'])

    if params_basis_pattern['weight_func'] != 'linear':
        # specific parameters for the exp / log (missing or None -> no extra kwargs)
        func_kwargs = params_basis_pattern.get('to_input_into_function') or {}
        if params_basis_pattern['weight_func']  == 'exp':
            basis_pattern = np.exp(basis_pattern, **func_kwargs)
        elif params_basis_pattern['weight_func']  == 'log':
            basis_pattern = np.log(basis_pattern, **func_kwargs)
        else:
            raise ValueError('how?!')
        basis_pattern = (basis_pattern - params_basis_pattern['weight_min'])/basis_pattern.max()*params_basis_pattern['weight_max']

    # signed label distances -wind_size ... +wind_size (identical for every side option)
    basis_pattern_arange = np.linspace(-params_basis_pattern['wind_size'],
                                       params_basis_pattern['wind_size'], params_basis_pattern['wind_size_full_len']).astype(int)
    if params_basis_pattern['one_or_two_sides'] == 2: # this means symmetric
        basis_pattern = np.hstack([basis_pattern, np.array([0]), basis_pattern[::-1] ])
    elif params_basis_pattern['one_or_two_sides'] == 1: # this means former trials are nullified
        basis_pattern = np.hstack([np.zeros(len(basis_pattern)), np.array([0]), basis_pattern[::-1] ])
    elif params_basis_pattern['one_or_two_sides'] == -1: # this means future trials are nullified
        basis_pattern = np.hstack([basis_pattern, np.array([0]), np.zeros(len(basis_pattern)) ])
    else:
        raise ValueError('?!')
    assert len(basis_pattern_arange) == len(basis_pattern), "len(basis_pattern_arange) == len(basis_pattern) must be of same length, but %d vs %d"%(len(basis_pattern_arange), len(basis_pattern) )

    print('basis_pattern is %s; vs. basis_pattern_arange is %s'%(str(basis_pattern), str(basis_pattern_arange )))

    # how differences in labels (e.g. trials) translate to similarity weights,
    # e.g. {3: 6} means that trial r vs trial r+3 will get similarity weight 6
    label_distance_to_basis_pattern_values = {dist:basis_val for dist, basis_val in zip(basis_pattern_arange, basis_pattern)}
    return basis_pattern, params_basis_pattern, label_distance_to_basis_pattern_values
    
    
    
def solve_ls_with_LS(data, A, lambda_l2 = 0.1):
    """
    Ridge-regularized least squares, solved through an augmented linear system.

    Stacks ``[data; 0] ~ [A; lambda_l2 * I] @ phi`` and applies the
    pseudo-inverse of the stacked design. This minimizes
    ``||data - A @ phi||_F^2 + lambda_l2**2 * ||phi||_F^2``. The penalty weight
    is lambda_l2 SQUARED, because lambda_l2 scales the identity rows.

    Parameters
    ----------
    data : np.ndarray
        2-D target, rows matching ``A.shape[0]``.
    A : np.ndarray
        2-D design matrix (rows X n_components).
    lambda_l2 : float, optional
        Scale of the identity block appended to A.

    Returns
    -------
    np.ndarray
        ``phi`` of shape (n_components, data.shape[1]).
    """
    n_cols = A.shape[1]
    stacked_design = np.concatenate((A, lambda_l2 * np.eye(n_cols)), axis = 0)
    stacked_target = np.concatenate((data, np.zeros((n_cols, data.shape[1]))), axis = 0)
    phi = np.linalg.pinv(stacked_design) @ stacked_target
    assert phi.shape[1] == data.shape[1]
    return phi
    
    
    
def fast_MILCCI(data, labels, params_init_A = {}, max_trials_each = 25, n_ensembles = 15,
                   params_lasso_solver = {'solver': 'inv', 'num_iters': 10}, 
                   seed = 5,
                   num_trials = 1,
                   l2_phi = 0,
                   lambda_similarity = 0.05/10,
                   for_each_trial_or_condition = 'condition', # 'cont_trial'
                   provide_chuncks = True,
                   another_update_for_phi = True, return_data_recos = True, verbose = False, P = [], 
                   with_graph = False, nu = [], with_nu = False, 
                   factor_A = 1, func_normalize_A = identity, num_repeats = 3, 
                   n_ensembles_each = [], num_axes = 0,
                   numbers2tuples = {}, value_nu_fixed = 1, decor_A = 0, durs = [],
                   params_basis_pattern = {},
                   another_update_for_A = True, #nu_trials = None,
                   cont_axis_list = [], sparse_A_or_sparse_phi = 'A',
                   split_A = False, enable_regular_MILCCI = False): 
   

    print('started!')
    is_cont = for_each_trial_or_condition in [ 'cont_trial']
    global enable_regular_MILCCI_global
    enable_regular_MILCCI_global = enable_regular_MILCCI 
    
    #%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    params_basis_pattern = {**{'wind_size': 5, # wind size defines how many trials to take before and after (overall wind_size*2+1 trials)
                           'weight_func':'linear', # how the weight is distributed across trials
                           'weight_min': 'nu', #'value_nu_fixed', #'nu', 
                           'weight_max': 0 # if zeros that means that we limit if based on weight_min & wind_size. e.g. if weight_min is 0.1, wind_size is 5, then weight_max would be 0.5
                           }, 
                        **params_basis_pattern}
    discrete_axis_list = np.setdiff1d(np.arange(num_axes), np.array(cont_axis_list))
    if not checkEmptyList(durs) and len(durs) != data.shape[2]: # durs is duration of trials 
        raise ValueError('num trials does not match')
    if isinstance(durs, list) and len(durs) > 0:
        durs = np.array(durs)
    
    if len(labels) != data.shape[2]:
        raise ValueError('data and labels shapes do not match')
        
    if with_graph and checkEmptyList(P):
        raise ValueError('please provide P if with graph!')
        
    if with_nu:
        if checkEmptyList(nu):
            raise ValueError('please provide nu if with nu!')
        if len(nu) != n_ensembles:
            raise ValueError('len nu is %d but you required %d ensembles'%(len(nu), n_ensembles))
    else:
        nu = np.ones(n_ensembles)
    if not isinstance(nu, np.ndarray):
        nu = np.array(nu)
    if (nu < 0).any():
        raise ValueError('all values in \nu must be positive')
        
    if np.array([len(el) != num_axes for el in list(numbers2tuples.values()) ]).any():
        raise ValueError('num_axes does not make sense. num_axes need to match the len of tuples in number2tuples')

    for cont_ax in cont_axis_list:
        cur_labels = [numbers2tuples[label][cont_ax] for label in labels]
        if len(np.unique(cur_labels)) < len(cur_labels):

            print('pay attention, the labels of this continous var are not unique. Are you sure you want to proceed?')
    assert not (for_each_trial_or_condition == 'cont_trial' and len(cont_axis_list) == 0), 'if for_each_trial_or_condition == cont_trial, you must provide a continous variable'
    for discrete_ax in discrete_axis_list:
        cur_labels = [numbers2tuples[label][discrete_ax] for label in labels]
        if len(np.unique(cur_labels)) == len(cur_labels):
            print('discrete axis is %d'%discrete_ax)
            print('unique labels are %s'%str(np.unique(cur_labels)))
            if not  enable_regular_MILCCI_global:
                assert str2bool(input('the labels of this DISCRETE var are unique. Are you sure you want to proceed with it as it were discrete?' + str(len(np.unique(cur_labels))) + ' vs ' + str(len(cur_labels))))
        
    #nu = nu / np.max(nu)
    nu_inv = 1/nu
    
    labels_unique_order = make_labels_unique_order(labels)
    
    # asserts the labels_unique_order is correct
    if len(labels_unique_order) == len(labels):
        # this means that you have a continues unique variable like trial as a state
        # in this case we want to make sure that labels is the same as labels unique order
        if not (np.array(labels) == np.array(labels_unique_order)).all():
            if 'home' in os.getcwd():
                print('IMPORTANT WARNING! PAY ATTENTION THAT THE ORDER OF A IS NOT AS THE ORDER OF PHI. FOR THE ORDER OF A, PLEASE REFER TO LABELS_UNIQUE_ORDER WITHIN PARAMS SAVE')
            else:
                input('IMPORTANT WARNING! PAY ATTENTION THAT THE ORDER OF A IS NOT AS THE ORDER OF PHI. FOR THE ORDER OF A, PLEASE REFER TO LABELS_UNIQUE_ORDER WITHIN PARAMS SAVE')
                
                
                
    num_unique_conditions = len(labels_unique_order)
    n_neurons = data.shape[0]
    
    """
    now check ensembles each
    """
    
    if with_nu and  (for_each_trial_or_condition in ['m_condition','cont_trial'] or split_A):
        if np.any([el not in  set(list(numbers2tuples.keys())) for el in set(labels)]):
            raise ValueError('all labels must be assigned to tuple!')
        # check that all tuples are of the same length
        lens_tuples = np.array([len(value) for value in list(numbers2tuples.values())])
        axes = lens_tuples[0]
        if np.any(lens_tuples != lens_tuples[0]):
            raise ValueError('not all tuples are of the same length: %s'%str(lens_tuples))
            
        if not isinstance(n_ensembles_each, (list, tuple, np.ndarray)):
            n_ensembles_each = [n_ensembles_each]*num_axes
        
        
        if num_axes == 0:
            num_axes = len(labels_unique_order[0])
            
        if len(n_ensembles_each) == 0 :
            n_ensembles_each = [len(el) for el in np.array_split(np.arange(n_ensembles), num_axes) ]
        n_ensembles_each_cumsum = np.cumsum(np.array([0] + list(n_ensembles_each)))
        axes2ensembles = {ax : np.arange(el, el2) for ax, (el, el2) in enumerate(zip(n_ensembles_each_cumsum[:-1], n_ensembles_each_cumsum[1:]))}

        if np.sum(n_ensembles_each) != n_ensembles:
            print('n_ensembles')
            print(n_ensembles)
            print(n_ensembles_each)
            raise ValueError('n_ensembles_each must much total n_ensembles')
        """
        FOR CONTINOUS VARIABLES. CREATE THE BASIS PATTERNS
        """
        if len(cont_axis_list) > 0: # check if we have any continous variable
        
            # make sure that if the weight min is nu, we set it
            params_basis_pattern['weight_min'] = params_basis_pattern.get('weight_min', np.mean(nu))
            params_basis_pattern['weight_max'] = params_basis_pattern.get('weight_max', np.mean(nu))
            
            assert (not isinstance(params_basis_pattern['weight_min'], str)) or params_basis_pattern['weight_min'] in ['nu'], "params_basis_pattern['weight_min'] must be nu but %s"%str(params_basis_pattern['weight_min'])
            if params_basis_pattern['weight_min'] in ['nu']:   params_basis_pattern['weight_min'] = np.mean(nu)
            assert (not isinstance(params_basis_pattern['weight_max'], str)) or params_basis_pattern['weight_max'] in ['nu'], "params_basis_pattern['weight_max'] must be nu but %s"%str(params_basis_pattern['weight_min'])
            if params_basis_pattern['weight_max'] in ['nu']:   params_basis_pattern['weight_max'] = np.mean(nu)
            assert params_basis_pattern['weight_max'] == 0 or params_basis_pattern['weight_max'] > params_basis_pattern['weight_min'], "weight max (must be) > weight min, but %s <= %s"%(str(params_basis_pattern['weight_max']), str(params_basis_pattern['weight_min']))
            
            basis_pattern, params_basis_pattern, label_distance_to_basis_pattern_values = create_basis_patterns(labels,
                                  numbers2tuples, 
                                  #nu_trials=nu_trials, # nu trials refer to nu by which we normalize the basis pattern
                                  cont_labels = np.vstack([numbers2tuples[lab] for lab in labels]), 
                                  cont_axis_list = cont_axis_list,
                                 params_basis_pattern =  params_basis_pattern,
                                 value_nu_fixed = value_nu_fixed) 
        
        
        """
        create the nus set
        """
        # "nu_each_axes_list_2d" is a list:
        # this is a list of what is the nu for each unique axes (i.e. each unique class). i.e. 
        # the first element is the nu associated with the axes in the first tuple element (e.g. box).
        # for instance- if numbers2tuples is {1: ('box_1', 'banana')}, 
        # then the first element in nu_split corresponds to box and remain constant for that.
        # each element is a matrix of (n_ensembles_per_that class X 2) (e.g. shape [n_ensembles_each, 2] ) 
        # first column if the nu values (regularization values) that is not for the class (i.e. the regular nu), all identical
        # the 2nd col is the nu fixed value
        nu_each_axes_list_2d = [np.hstack([nu[el1:el2].reshape((-1,1)), value_nu_fixed*np.ones(el2 - el1).reshape((-1,1))]) 
                                for el1, el2 in zip(n_ensembles_each_cumsum[:-1], n_ensembles_each_cumsum[1:])]
        
        # now create a nu foor each pair of conditions. it is a dict with {label : matrix of (n_ensembles X num_unique_labels)}
        # one of the columns of each matrix in the values of labels_unique_order is zeros. the zeros corresponds to the current state. we do not want to use it. 
        nu_full_each_axes_dict = {}
        
        for label_count, label in enumerate(labels_unique_order):
            cur_mat = np.zeros((n_ensembles, num_unique_conditions))
            
            # now start filling cur_mat
            for label_count2, label2 in enumerate(labels_unique_order):
                if label_count2 != label_count:
                    """
                    check what axes label and label2 share
                    """
                    tup1 = numbers2tuples[label]
                    tup2 = numbers2tuples[label2]
                    cur_full_nu = []
                    # PAY_ATTENTION - HERE IS THE REGULARIZATION OVER LABELS. 
                    # important - if the first element in the tuple is box, box will be **fixed** for the first set of ensembles!
                    for ax in np.arange(axes):
                        if ax in discrete_axis_list:                        
                            if tup1[ax] == tup2[ax]: # this means they share the same axes and we want nu to be very high (i.e. 1 ) - in this case taking the 2nd col of each nu_each_axes_list_2d
                                ind_desired = 1
                                #print('using one!')
                            else:
                                ind_desired = 0
                            cur_nu_part = nu_each_axes_list_2d[ax][:,ind_desired] # this is a list of number of ensembles that change with this axis
                        else: #i.e continous
                            if np.abs(tup1[ax] - tup2[ax]) <= params_basis_pattern['wind_size']: # if the trials are close enough to each other
                                # todo : may be faster to do it via graph. (e.g. use the function calculate_graph_simularity)
                                diff_tup = tup1[ax] - tup2[ax] # these are 2 tuples of different trials. We calculate their distance in calss ax
                                assert diff_tup  in label_distance_to_basis_pattern_values, "label_distance_to_basis_pattern_values is %s, but diff_tup (the different in labels across the labels wrt class %d) = %d not inside "%(str(label_distance_to_basis_pattern_values), ax, diff_tup)
                                
                                # the row below gives us the regularization parametrs between trials within the window. 
                                cur_nu_part =  np.array(label_distance_to_basis_pattern_values[diff_tup])*np.ones((n_ensembles_each[ax],1))
                                if verbose:
                                    print('cur_nu_part %s, diff_tup %s, ax %s\n\n'%(str(cur_nu_part), str(diff_tup), str(ax)))
                                
                            else:  # if trials are not close to each other enough, they are not regularized together
                                cur_nu_part = np.array([0]*n_ensembles_each[ax]) #nu[n_ensembles_each_cumsum[ax]:n_ensembles_each_cumsum[ax+1]]
                                
                            

                        cur_full_nu.append(cur_nu_part.reshape((-1,1)))
                    cur_full_nu = np.vstack(cur_full_nu)    # vector of number of ensembles. 
                    if len(cur_full_nu.flatten()) != n_ensembles:
                        raise ValueError('what happened?!  len(cur_full_nu.flatten()) != n_ensembles: (%d, %d)'%( len(cur_full_nu.flatten()), n_ensembles))
                    
                    
                    cur_mat[:, label_count2] = cur_full_nu.flatten() # ensembles X unique_labels
            nu_full_each_axes_dict[label] = cur_mat # i.e. for each label we havea key in the dict. the value is a matrix of n_ensembles X unique labels [the label is in number but corresponds to tupl]
        
    
    n_trials = data.shape[2]
    
    ##############################################
    #if sparse_A_or_sparse_phi  == 'A': - this means that we need to first learn the diciotnary of times using sample of the data -> learn neural weights -> define it as a new dict -> learn phi weights on it.
    # i.e. take coeffs define them as dict then do lasso
    ##############################################
    
    trials_consider = lists2list([list(np.where(labels == label)[0][:max_trials_each]) for label in labels_unique_order])
    
    params_init_A = { **{'max_points_in_dim_reduction': 80, 'transform_algorithm': 'lasso_lars', 
                         'transform_alpha':0.01/2,'random_state':42}, **params_init_A}
    params_save = locals()
    
    
    params_init_A['max_points_in_dim_reduction'] = np.min([data.shape[1], params_init_A['max_points_in_dim_reduction']])
    
    # learn samples X features. find X = UV with U being sparse (not mistake!)
    # data is in form of neurons X time X trials. If we want the neurons ensembles to be saprse, no transpose. Otherwise transpose
    dict_learner = DictionaryLearning(
                        n_components=n_ensembles, transform_algorithm= params_init_A['transform_algorithm'], 
                        transform_alpha=params_init_A['transform_alpha'],
                        random_state=params_init_A['random_state'],
                    )
    
    max_points = np.min([params_init_A['max_points_in_dim_reduction'], data.shape[1] ])
    
    if checkEmptyList(durs):
        data_hstack_to_fit = np.hstack([data[:,:max_points, j] for j in trials_consider])
    else:
        data_hstack_to_fit = np.hstack([data[:,:np.min([durs[j],max_points]), j] for j in trials_consider])
        
        
    print('data_hstack_to_fit is of shape %s'%(str(data_hstack_to_fit.shape)))
    print('trials consider is %s'%str(trials_consider))
        
    data_hstack_to_fit_with_transform =  data_hstack_to_fit if sparse_A_or_sparse_phi  == 'A' else data_hstack_to_fit.T
    # if sprse_A is sparse, then we need to transform on data with the same number of As. 
    dict_learner.fit( data_hstack_to_fit_with_transform  )#data_hstack_to_fit.T)
    
    if checkEmptyList(durs):
        data_hstack_to_transform = np.hstack([data[:,:, j] for j in range(n_trials)])
    else:
        data_hstack_to_transform = np.hstack([data[:,:durs[j], j] for j in range(n_trials)])
        
    
    
    if sparse_A_or_sparse_phi  == 'A':
        assert data_hstack_to_fit_with_transform.shape[0] == data_hstack_to_transform.shape[0] , "shape mismatch! %s vs %s "%(str(data_hstack_to_fit_with_transform.shape ) , str(data_hstack_to_transform.shape))
        
        #### take the coeffs (i.e. the ensembles) from the fit
        
        
        
        if params_init_A.get('ensemble_positive'):
            #full_A_hat_non_norm = sparse_encode(data_hstack_to_transform, dict_learner.components_, algorithm='lasso_lars', positive=True)
            full_A_hat_non_norm = sparse_encode(data_hstack_to_fit_with_transform, dict_learner.components_, 
                                                algorithm=params_init_A['transform_algorithm'], 
                                                alpha=params_init_A['transform_alpha'], positive=True)
        else:
            full_A_hat_non_norm = dict_learner.transform(data_hstack_to_fit_with_transform)
        assert full_A_hat_non_norm.shape[0] == data_hstack_to_fit_with_transform.shape[0]
        #else:
        #    full_A_hat_non_norm= dict_learner.transform(data_hstack_to_transform) # this is time X ensembles X states
        #print('shape full_A_hat_non_norm is %s'%str(full_A_hat_non_norm.shape))
        #input('ok?!')
        # learn the dense dict per trial
        full_phi_hat_non_norm =  solve_ls_with_LS(data = data_hstack_to_transform, A = full_A_hat_non_norm).T #dict_learner.components_.T  
    else:
        full_phi_hat_non_norm = dict_learner.transform(data_hstack_to_transform) # this is time X ensembles X states
        full_A_hat_non_norm = dict_learner.components_.T 
    
    #full_phi_hat_non_norm = dict_learner.transform(data_hstack_to_transform.T) # this is time X ensembles X states

    #full_A_hat_non_norm = dict_learner.components_.T # this is neurons X ensembles
    if verbose:
        print('full_A_hat_non_norm.shape %s'%str(full_A_hat_non_norm.shape))
    
    """
    normalize A
    """
    if func_normalize_A != identity:
        full_A_hat_sums = func_normalize_A(np.abs(full_A_hat_non_norm),0)
        if not enable_regular_MILCCI_global:
            assert np.array(full_A_hat_sums > 0.0000000001).all(), "pay attention, some ensembles are completely 0. Check data normalizationand values. maybe allow less ensembles?"
        full_A_hat = full_A_hat_non_norm*factor_A / (full_A_hat_sums.reshape((1,-1)) + 10**-18)
        full_phi_hat = full_phi_hat_non_norm/factor_A  * (full_A_hat_sums.reshape((1,-1)) + 10**-18)
    else:
        full_A_hat = full_A_hat_non_norm
        full_phi_hat= full_phi_hat_non_norm
        
    
    
    T = data.shape[1]
    
    edges = np.linspace(0,full_phi_hat.shape[0], n_trials + 1 ).astype(int)

    full_phi_hat_3d = np.dstack([full_phi_hat[t1: t2] for t1, t2 in zip(edges[:-1], edges[1:])])
    

    
   
    """
    percentiles
    """
    
    l1 = np.percentile(np.abs(full_A_hat).sum(1).flatten(), 40)
    l1_phi = np.percentile(np.abs(full_phi_hat_3d).sum(0).mean(), 40)
    
    A_baseline = full_A_hat# func_baseline(full_A_hat, axis = 2)
    
    params_update_phi = {}
    
    if not check_if_labels_batches(labels) and provide_chuncks:
        raise ValueError('you must provide labels in chunks!')
        
    
    params_lasso_phi  = {**{'solver': 'nnls'}, **{}}
    print(params_init_A.get('ensemble_positive'))
    if params_init_A.get('ensemble_positive'):
        params_lasso_solver =  {**{'solver': 'nnls'}, **{}}
    else:
        params_lasso_A  = {}
    print('solve is :%s'%params_lasso_solver['solver'])
    ############################################################ in current implementation if ensemble_positive then must be via inv
    # TODO  - extend ensemble positive to more solvers
    
    
    ############################################################
    
    params_update_phi  = {'params': params_lasso_phi, 'l1': l1_phi, 'random_state': seed + 4, 'x0': []}
    
    """
    if for each trial individualy 
    """
    if for_each_trial_or_condition == 'trial' and not split_A:
        raise ValueError('fix to inv and ensembles similarity instead of 0!')
     
        full_A_hat_individual = [np.hstack([solve_Lasso_style(
            np.vstack([full_phi_hat_3d[:,:,trial], 
                       lambda_similarity * np.eye(n_ensembles)]), 
    
            np.vstack([data[n,:,trial].reshape((-1,1)), 
                       np.zeros((n_ensembles, 1))]),    
    
            l1 = l1, x0 = [], params = params_lasso_solver,
                                               lasso_params = params_lasso_A , random_state = seed).reshape((-1,1)) for n in range(n_neurons)])
                             for trial in range(num_trials)]
        
        full_A_hat_individual_tensor = np.dstack(full_A_hat_individual).transpose([1,0,2])
        
        if another_update_for_phi:
            
            T = data.shape[1]
            
            if l2_phi > 0:
                data_extend = np.hstack([data, 
                                         np.zeros((n_ensembles,T,1))])
                full_A_hat_individual_tensor_extend =  np.vstack([full_A_hat_individual_tensor, 
                                                                  l2_phi*np.expand_dims(np.eye(n_ensembles),2)])
                
            else:
                data_extend = data.copy()
                full_A_hat_individual_tensor_extend =  full_A_hat_individual_tensor.copy()
            
            full_phi_hat_3d_updated = np.dstack([np.vstack([solve_Lasso_style(full_A_hat_individual_tensor_extend[:,:, trial], data_extend[:,t,trial].flatten(),
                                                                      **params_update_phi).reshape((1,-1)) for t in range(data.shape[1])])
                             for trial in range(num_trials)])
            
        
    elif for_each_trial_or_condition in [ 'condition', 'm_condition'] and not split_A: # m-condition is multi condition (i.e. multi-MILCCI)
        """
        if for each condition (but fixed for trials).
        Here I need to integrate the nu
        """
        full_A_hat_individual = []
        labels_unique_order = make_labels_unique_order(labels)
        for c, label in enumerate(labels_unique_order):
            if verbose:
                print(label)
    
            labels_loc = np.where(labels == label)[0]
    
            # make the phi as small as possible given the durations. i.e. not need to solve for all since maybe duration is smaller for that label. 
            if not checkEmptyList(durs):
                max_dur_for_label = np.max(durs[labels_loc])
                phi_for_label = full_phi_hat_3d[:,:max_dur_for_lab, labels_loc]
            else:
                phi_for_label = full_phi_hat_3d[:,:, labels_loc]
            
            
            phi_for_label_2d = np.vstack([phi_for_label[:,:,trial] for trial in range(phi_for_label.shape[2])]) # this is total T X p (T x num_trials for the condition) X p
            edges = np.linspace(0, phi_for_label_2d.shape[0], len(labels_loc) + 1)
            
            #######################################################################################
            data_for_label = data[:,:, labels_loc]
    
            data_for_label_2d = np.hstack([data[:,:,trial]
                                           for trial in range(data_for_label.shape[2])]) # this is N X total T
            
            
            #A_for_label = full_A_hat[:,:, c] # this is N x p
    
            """
            infer ind A - only if no P
            data_for_label_2d : neurons X (time X trials)
            phi_for_label_2d : (trials X time) X ensembles
            """
            if not with_graph:
                if params_lasso_solver['solver'] == 'inv' or params_lasso_solver['solver'] == 'nnls':
                    # assume \hat{A} = \tilde{A} \cdot \nu
                    # now look for tilde{A}
                    # addi means for individual A
                    # we want to solve: addi = arg min \| \|
                    # todo add comment
                    if decor_A  == 0:
                        
              
                        right = np.vstack([phi_for_label_2d, 
                                    lambda_similarity * np.eye(n_ensembles)*nu.reshape((1,-1))])
                        left = np.vstack([data_for_label_2d.T, 
                                    lambda_similarity * (full_A_hat * nu.reshape((1,-1))).T ])
                        
                        left_not_0 = np.where((left.sum(1) != 0) & (right.sum(1) != 0)  )[0] 
                        left = left[left_not_0]
                        right = right[left_not_0]
                        
                        
                        addi = solve_Lasso_style(                        
                        right,
                        left,            
                        l1 = l1, x0 = [], params = params_lasso_solver,
                                                            lasso_params = params_lasso_A ,random_state = seed)                        
                        
                        # bring back to \hat{A} = \tilde{A} / nu
                        #addi = addi * nu_inv.reshape((-1,1))
                    else:
                        # we want to solve: left = right @ A
             
                        right = np.vstack([phi_for_label_2d, 
                                    lambda_similarity * np.diag(nu.flatten()), decor_A*(np.ones((n_ensembles, n_ensembles)) - np.eye(n_ensembles)) ]) # this is (p + Td) X p 
                        
                        left = np.vstack([data_for_label_2d.T, 
                                    lambda_similarity * (full_A_hat * nu.reshape((1,-1))).T, np.zeros((n_ensembles , n_neurons)) ])  # this is (2p+ Td) X N
                        
                        ##################### do not take time points that are completely zero
                        # this is the first time we update
                        left_not_0 = np.where((left.sum(1) != 0) & (right.sum(1) != 0)  )[0] 
                        left = left[left_not_0]
                        right = right[left_not_0]
                        
                        
                        
                        addi = solve_Lasso_style(                        
                        right,             
                        left,                
                        l1 = l1, x0 = [], params = params_lasso_solver,
                                                            lasso_params = params_lasso_A,random_state = seed)
                        
                        ################## 
                        assert np.all(np.abs(addi).sum(0) < 2*factor_A), "values of full_A_hat_individual will be too large!"
                        
                    full_A_hat_individual.append(addi)

                    
                    
                else:
                    # again find tilde A
                    raise ValueError('change some things....')
                    addi = np.hstack([solve_Lasso_style(
                    np.vstack([phi_for_label_2d*nu_inv.reshape((1,-1)) , 
                               lambda_similarity * np.eye(n_ensembles)]), 
        
                    np.vstack([data_for_label_2d[n,:].reshape((-1,1)), 
                                lambda_similarity * full_A_hat[n].reshape((-1,1)) * nu.reshape((-1,1)) ]),        
                    l1 = l1, x0 = [], params = params_lasso_solver,
                                                       lasso_params = params_lasso_A,random_state = seed).reshape((-1,1)) for n in range(n_neurons)])    
                    
                    
                    
                    addi = addi * nu_inv.reshape((-1,1))                   
                    
                    full_A_hat_individual.append(addi)
                    
            else:
                
                raise ValueError('TODO !')
                # TODO
        
        full_A_hat_individual_tensor = np.dstack(full_A_hat_individual).transpose([1,0,2]) # neurons by ensembles by trials
        print('mmax val of ensembles is %d'%np.nanmax(full_A_hat_individual_tensor))
        if func_normalize_A != identity and (not enable_regular_MILCCI_global):
            summations_1d = func_normalize_A(np.abs(full_A_hat_individual_tensor), axis = 0) 
            
            assert np.array(summations_1d  > 0.0000000001).all(), "pay attention, some ensembles are completely 0. Check data normalizationand values. maybe allow less ensembles?"        
            
            summations = (np.expand_dims(summations_1d, 0) + 10**-18 ) / factor_A
            print('summations')
            print(summations)
            full_A_hat_individual_tensor = full_A_hat_individual_tensor / summations
            assert np.array(np.abs(full_A_hat_individual_tensor).sum(0) < 2*factor_A).all(), "how did you get such high_values of full_A_hat? %s"%np.abs(full_A_hat_individual_tensor).sum(0)
            print('mmax val of ensembles2 is %d'%np.nanmax(full_A_hat_individual_tensor))
            
        """
        now if multi-condition - further change
        
        """
        cond_array = np.arange(num_unique_conditions)
        if for_each_trial_or_condition in ['m_condition', 'cont_trial'] :
            print('is here')
            if not with_nu:
                raise ValueError('you must provide nu if m_condition!')
            # now follow https://mail.google.com/mail/u/0/#inbox/QgrcJHrhwzTdHRKpZfBsWnWKGgtRQwLFsSb
            
            for repeat in range(num_repeats):
                for c, label in enumerate(labels_unique_order):
                    
                    """
                    # how similar each condition is to each other condsidenring the axes label. e.g. (box,odor), (box2,odor) considering box is 0
                    # this (below) is a dictionary of {label : (tuple)}. Each value is a matrix of # ensembles X unique labels,
                    # such that if val = nu_full_each_axes_dict[0], then val[j, u] means weather ensemble j  is 
                    forced to be similar under labels j and u (e.g. does ensemble 4 forced to be similar under (box,odor), (box2,odor) )?
                    """
                    cur_nu_mat_whole = nu_full_each_axes_dict[label] 
                    
                    # print('cur_nu_mat_whole')
                    # print(cur_nu_mat_whole)
                    # input('kjklj?')
                    """
                    build As of all other matrices 
                    """
                    non_c = cond_array[cond_array != c]
                    cur_nu_mat_non_c = cur_nu_mat_whole[:,non_c] # this is a matrix of # ensembles X all labels that are not the current one
                    if len(non_c) != num_unique_conditions - 1:
                        raise ValueError('how?')
                    full_A_not_in_label = full_A_hat_individual_tensor[:,:,non_c]

                    assert full_A_not_in_label.shape[2] == num_unique_conditions - 1
                    if repeat == 0 and c==  0 : print('num unique conditions is %d'%(num_unique_conditions - 1))
                    full_A_not_in_label_vstack = np.vstack([full_A_not_in_label[:,:,layer].T 
                                                            for layer in range(num_unique_conditions - 1)]) # this is (p * (num_conds - 1)) X N
                    
                    if full_A_not_in_label_vstack.shape[1] != n_neurons:
                        raise ValueError('something went wrong! full_A_not_in_label_vstack.shape[1] = %d, while n_ensembles = %d'%(full_A_not_in_label_vstack.shape[1], n_neurons))
                    nus_list = [cur_nu_mat_non_c[:,col].reshape((-1,1)) for col in range(num_unique_conditions - 1)]
                    nus_list_vstack = np.vstack([nu_i.reshape((-1,1)) for nu_i in nus_list]).reshape((-1,1))
                    left_nus = np.vstack([np.diag(nu_i.flatten()) for nu_i in nus_list]) # this is (pX(num conditions-1)) X p. e.g. the first (top, number 0) pXp matrix is a daigonal matrix whoch j,j entry is whether the j-th ensemble need to be similar between current label and label 0
                    assert left_nus.shape[0] == nus_list_vstack.shape[0]
                    assert full_A_not_in_label_vstack.shape[0] == nus_list_vstack.shape[0]
                    if left_nus.shape[1] != n_ensembles:
                        raise ValueError('left_nus.shape[1] = %d != n_ensembles = %d'%(left_nus.shape[1] , n_ensembles))
                    #new_A_baseline = full_A_not_in_label_vstack * np.vstack(nus_list)
                    

                    #######################################################################################
                    labels_loc = np.where(labels == label)[0]
            
                    phi_for_label = full_phi_hat_3d[:,:, labels_loc]
            
                    phi_for_label_2d = np.vstack([phi_for_label[:,:,trial] for trial in range(phi_for_label.shape[2])]) # this is total T X p (T x num_trials for the condition) X p

                    edges = np.linspace(0, phi_for_label_2d.shape[0], len(labels_loc) + 1)
                    
                    #######################################################################################
                    data_for_label = data[:,:, labels_loc]
            
                    data_for_label_2d = np.hstack([data[:,:,trial]
                                                   for trial in range(data_for_label.shape[2])]) # this is N X total T
                    if another_update_for_A:
                        if params_lasso_solver['solver'] == 'inv':
                            print('update 000000000000000000000000000000')
                            print('shapes:\n')
                            
                            print( 'full_A_not_in_label_vstack.shape: %s'%str(full_A_not_in_label_vstack.shape))
                            print( 'left_nus.shape: %s'%str(left_nus.shape))

                            left = np.vstack([phi_for_label_2d, 
                                        lambda_similarity * left_nus])
                            
                            right = np.vstack([data_for_label_2d.T, 
                                        lambda_similarity * full_A_not_in_label_vstack*nus_list_vstack  ])
                            
                                                            
                            ##################### do not take time points that are completely zero

                            left_not_0 = np.where((left.sum(1) != 0) & (right.sum(1) != 0)  )[0] 
                            left = left[left_not_0]
                            right = right[left_not_0]
                            
                            addi = solve_Lasso_style(                        
                            left, right, l1 = l1, x0 = [], params = params_lasso_solver,
                                                                lasso_params = params_lasso_A ,random_state = seed)
                            
                            assert np.all(np.abs(addi).sum(0) < 2*factor_A), "values of full_A_hat_individual will be too large!"
                            
                            full_A_hat_individual_tensor[:,:,c] = addi.T
                            # TODO limit A
    
                            
                        else: # this is the 2nd time we update A
                            print('other update here!')
                            
                            addi = []
                            for n in range(n_neurons):
                                left = np.vstack([phi_for_label_2d, 
                                            lambda_similarity * left_nus])
                                
                                right = np.vstack([data_for_label_2d[n,:].reshape((-1,1)) , 
                                            lambda_similarity * full_A_not_in_label_vstack[:,n].reshape((-1,1))  ])
                                
                                ##################### do not take time points that are completely zero
    
                                left_not_0 = np.where((left.sum(1) != 0) & (right.sum(1) != 0)  )[0] 
                                left = left[left_not_0]
                                right = right[left_not_0]
                                
                                ################## 
                                
                                
                                
                                addi_cur = solve_Lasso_style(
                                left,                 
                                right,           
                                l1 = l1, x0 = [], params = params_lasso_solver,
                                                                   lasso_params = params_lasso_A, random_state = seed).reshape((-1,1))
                                assert np.all(np.abs(addi).sum(0) < 2*factor_A), "values of full_A_hat_individual will be too large!"
                                
                                
                                
                                
                                addi.append(addi_cur)  
                            addi = np.hstack(addi)     # addi is p X N
                            
                            
               
                            full_A_hat_individual_tensor[:,:,c] = addi.T #.append(addi)
            
        print('mmax val of ensembles3 is %d'%np.nanmax(full_A_hat_individual_tensor)) # TODO add comment the problem is 
        """
        now update Phi
        """
        if another_update_for_phi:
            if l2_phi > 0:
                # data is neurons X time X trials. 
                print('data shape is %s'%str(data.shape))
                data_extend = np.vstack([data, np.zeros((n_ensembles,T,data.shape[2]))])
                l2_phi_comp = l2_phi*np.repeat(np.expand_dims(np.eye(n_ensembles),2), full_A_hat_individual_tensor.shape[2], axis = 2)
                assert l2_phi_comp.shape == (n_ensembles,n_ensembles,full_A_hat_individual_tensor.shape[2])
                print('full_A_hat_individual_tensor.shape is %s'%str(full_A_hat_individual_tensor.shape))
                full_A_hat_individual_tensor_extend = np.vstack([full_A_hat_individual_tensor, l2_phi_comp  ])
                #raise ValueError('?!')
                
            else:
                data_extend = data.copy()
                full_A_hat_individual_tensor_extend = full_A_hat_individual_tensor.copy()
                print('full_A_hat_individual_tensor.shape is %s'%str(full_A_hat_individual_tensor.shape))
                
            full_phi_hat_3d_new = []    
            for c, label in enumerate(labels_unique_order):
                cur_A = full_A_hat_individual_tensor_extend[:, :, c]
                print('cur A shape %s'%str(cur_A.shape))
                print('datashape %s' %str(data.shape))
                
                where_label = np.where(labels == label)[0]
                full_phi_hat_3d_label = np.dstack([np.vstack([solve_Lasso_style(cur_A, data_extend[:,t,ind].flatten(),
                                                                              **params_update_phi).reshape((1,-1)) for t in range(data.shape[1])])

                # full_phi_hat_3d_label = np.dstack([np.vstack([solve_Lasso_style(cur_A, data_extend[:,t,ind].flatten(),
                #                                                               **params_update_phi).reshape((1,-1),
                #                                                                                            params = {'solver' : 'inv'} ) for t in range(data.shape[1])])
                                     for ind in where_label])
                full_phi_hat_3d_new.append(full_phi_hat_3d_label)
                
            full_phi_hat_3d_updated = np.dstack(full_phi_hat_3d_new).copy()
    elif for_each_trial_or_condition in [ 'cont_trial'] or split_A:  
        # make sure that cont trial can exist         
        dict_params_cont = {  'durs': durs,    'verbose': verbose,    'lambda_similarity': lambda_similarity,    'nu': nu,      'n_ensembles': n_ensembles,
                            'solve_Lasso_style': solve_Lasso_style,    'l1': l1, 'axes2ensembles':axes2ensembles,
                            'params_lasso_solver': params_lasso_solver,    'seed': seed,    'with_graph': with_graph,
                                    'decor_A': decor_A,    'factor_A': factor_A,    'func_normalize_A': func_normalize_A,    'identity': identity,
                                    'num_unique_conditions': num_unique_conditions,    'for_each_trial_or_condition': for_each_trial_or_condition,
                                    'with_nu': with_nu,    'num_repeats': num_repeats,  #  'cur_nu_mat_whole': cur_nu_mat_whole,
                                    'n_neurons': n_neurons,    'another_update_for_A': another_update_for_A, #   'cur_A': cur_A,
                                    'l2_phi': l2_phi,    'T': T  ,
                                    'split_A':split_A, 'n_trials':n_trials,
                                    'another_update_for_phi':another_update_for_phi, 'params_update_phi':params_update_phi}

        
        full_A_hat_individual_tensor, full_phi_hat_3d_updated, full_A_hat_individual = cal_continous_trials(labels, full_phi_hat_3d, full_A_hat, 
                                                                                                            data = data, is_cont = for_each_trial_or_condition in [ 'cont_trial'] , 
                             cont_labels = np.vstack([numbers2tuples[lab] for lab in labels]),
                             cont_axis_list = cont_axis_list,  
                                 dict_params_cont = dict_params_cont, #is_cont = True,
                                 numbers2tuples = numbers2tuples, nu_full_each_axes_dict = nu_full_each_axes_dict
                                 )


        
        
    else:
        raise ValueError('?')
            
    print('mmax val of ensembles is %d'%np.nanmax(full_A_hat_individual_tensor))        
    params_save['full_A_hat_individual'] = full_A_hat_individual
    params_save['nu_full_each_axes_dict'] = nu_full_each_axes_dict
    if 'labels_unique_order' not in params_save:
        params_save['labels_unique_order'] = labels_unique_order
        
        
    
    
                
            
        
    if return_data_recos:
        """
        data recos
        """
        data_recos = []
        data_recos_updated = []
        for c, label in enumerate(labels_unique_order):            
            cur_A = full_A_hat_individual_tensor[:,:,c]
            labels_loc = np.where(labels == label)[0]
            data_recos.append(np.dstack([cur_A  @  full_phi_hat_3d[:,:,ind].T for ind  in labels_loc] ))
            data_recos_updated.append(np.dstack([cur_A  @  full_phi_hat_3d_updated[:,:,ind].T for ind  in labels_loc] ))
        data_recos = np.dstack(data_recos)
        data_recos_updated = np.dstack(data_recos_updated)
            
        return full_phi_hat_3d_updated, full_phi_hat_3d, full_A_hat,full_A_hat_individual_tensor, data_recos, data_recos_updated, params_save
    
    return full_phi_hat_3d_updated, full_phi_hat_3d, full_A_hat, full_A_hat_individual_tensor, params_save
                
    








    
    


    
    
    
def check_if_labels_batches(labels):
    """
    Check whether identical labels appear as contiguous runs ("batches").

    Walks over consecutive pairs of ``labels``. Every time the label changes,
    the label whose run just ended is recorded as closed; if a closed label
    ever shows up again the sequence is not batched.

    Parameters
    ----------
    labels : sliceable sequence or np.ndarray
        Per-trial labels, compared with ``!=`` and ``in``.

    Returns
    -------
    bool
        False as soon as a label re-appears after its run ended, True
        otherwise (including empty or single-element input).
    """
    closed_runs = []
    consecutive = zip(labels[:-1], labels[1:])
    for previous_label, current_label in consecutive:
        if previous_label != current_label:
            # the run of ``previous_label`` has just ended
            closed_runs.append(previous_label)
        if current_label in closed_runs:
            return False
    return True
    
    
    
    



          

def mov_avg(c, axis = 1, wind = 5):
    """
    Moving average of a 2-D or 3-D array along one axis.

    Output entry ``i`` along the averaged axis is the mean of the input entries
    with indices in ``[max(i - wind, 0), min(max(i + wind, i + 1), T))``, where
    ``T`` is the length of that axis, so the window always contains ``i``.

    Parameters
    ----------
    c : np.ndarray
        2-D array, or 3-D array processed slice-by-slice over its last axis.
    axis : int, optional
        0 or 1: the axis of each 2-D slice to average along. Default 1.
    wind : int, optional
        Half-width of the window (non-negative). Default 5.

    Returns
    -------
    np.ndarray
        Array with the same shape as ``c``.

    Raises
    ------
    ValueError
        If ``wind`` is negative, or ``c``/``axis`` is not a supported combination.
    """
    if wind < 0:
        raise ValueError('wind must be non-negative, got %s' % str(wind))
    if len(c.shape) == 2 and axis == 1:
        n_cols = c.shape[1]
        # lower bound is 0 (python indexing), so the first column can be part of
        # a window; upper bound is at least i + 1 so no window is ever empty
        # (an empty slice would average to NaN).
        return np.hstack([np.mean(c[:, max(i - wind, 0):min(max(i + wind, i + 1), n_cols)], 1).reshape((-1, 1))
                          for i in range(n_cols)])
    elif len(c.shape) == 2 and axis == 0:
        # forward ``wind`` so the caller's window size is honoured
        return mov_avg(c.T, axis = 1, wind = wind).T
    elif len(c.shape) == 3:
        return np.dstack([mov_avg(c[:, :, t], axis = axis, wind = wind) for t in range(c.shape[2])])
    else:
        raise ValueError('how did you arrive here? data dim is %s'%str(c.shape))
    
    
    


def norm(mat):
    """
    Vector 2-norm or matrix spectral norm.

    Parameters
    ----------
    mat : array_like
        Input array; lists and scalars are converted with ``np.asarray``.

    Returns
    -------
    float
        If ``mat`` is a scalar or all of its entries lie along a single axis
        (a vector, possibly with extra singleton dimensions): the Euclidean norm
        ``sqrt(sum(mat**2))``. Otherwise: the largest singular value of ``mat``
        (for ndim > 2, the maximum over the batch of matrices).
    """
    mat = np.asarray(mat)
    if mat.ndim == 0 or mat.size == max(mat.shape):
        return np.sqrt(np.sum(mat**2))
    # only singular values are needed; compute_uv=False avoids allocating the
    # full U / V factors (U alone is m x m for an m x n input).
    singular_values = np.linalg.svd(mat, compute_uv=False)
    return np.max(singular_values)
    
def mkCorrKern(params = None):
    """
    Deprecated: build a Gaussian correlation kernel (space x time).

    This function always raises ``ValueError`` before doing any work; the code
    after the ``raise`` is kept only as a reference implementation.

    Parameters
    ----------
    params : dict or None, optional
        Overrides for the defaults ``{'w_space': 3, 'w_scale': 4,
        'w_scale2': 0.5, 'w_power': 2, 'w_time': 0}``.

    Returns
    -------
    corr_kern
        Never returns (see Raises).

    Raises
    ------
    ValueError
        Always, because the function is deprecated.
    """
    raise ValueError('not clear, depracated')
    # ---- unreachable reference implementation ----
    params = {**{'w_space':3,'w_scale':4,'w_scale2':0.5, 'w_power':2,'w_time':0}, **(params or {})}
    dim1  = np.linspace(-params['w_scale'], params['w_scale'], 1+2*params['w_space']) # space dimension
    # the defaults define 'w_time' (there is no 'time' key)
    dim2  = np.linspace(-params['w_scale2'], params['w_scale2'], 1+2*params['w_time']) # time dimension
    corr_kern  = gaussian_vals(dim1, std = params['w_space'], mean = 0 , norm = True,
                               dimensions = 2, mat2 = dim2, power = 2)
    return corr_kern
    
def checkCorrKern(data, corr_kern, param_kernel = 'embedding', recreate = False, know_empty = False,
                  labels2distance_type =  'identity_boolean'):
    """
    Validate a correlation kernel, optionally recreating it when it is empty.

    Parameters
    ----------
    data : array_like
        Passed to ``mkDataGraph`` when an 'embedding' kernel is recreated.
    corr_kern : sized object
        The kernel (or kernel parameters); returned unchanged when non-empty.
    param_kernel : str, optional
        'embedding' (recreate via ``mkDataGraph``) or 'convolution'
        (recreate via ``mkCorrKern``). Default 'embedding'.
    recreate : bool, optional
        Whether an empty kernel may be rebuilt. Default False.
    know_empty : bool, optional
        If True, suppress the "Empty Kernel" warning. Default False.
    labels2distance_type : str, optional
        Not used by this function; kept for interface compatibility.

    Returns
    -------
    corr_kern
        The original kernel, or the recreated one.

    Raises
    ------
    ValueError
        If the kernel is empty and ``param_kernel`` is unsupported, or the
        kernel is empty and ``recreate`` is False.
    """
    if len(corr_kern) == 0:
        if not know_empty: warnings.warn('Empty Kernel - creating one')
        if param_kernel not in ('embedding', 'convolution'):
            raise ValueError('Invalid param_kernel. Should be "embedding" or "convolution"')
        if not recreate:
            # report the real cause instead of blaming a valid param_kernel
            raise ValueError('Kernel is empty and recreate=False; pass recreate=True to build a "%s" kernel' % str(param_kernel))
        if param_kernel == 'embedding':
            corr_kern  = mkDataGraph(data, corr_kern)
        else:
            corr_kern  = mkCorrKern(corr_kern)

    return corr_kern

def create_readme(text, name="readme.txt", directory=None, encoding="utf-8"):
    """
    Write ``text`` to a text file and return the file's path.

    Parameters
    ----------
    text : str
        Content to write; an existing file with the same path is overwritten.
    name : str, optional
        File name. Default "readme.txt".
    directory : str or None, optional
        Target directory, created if missing. When falsy, ``name`` itself is
        used as the path.
    encoding : str, optional
        Text encoding. Defaults to UTF-8 so the result does not depend on the
        platform's locale encoding (non-ASCII text could otherwise fail).

    Returns
    -------
    str
        Path of the written file.
    """
    if directory:
        os.makedirs(directory, exist_ok=True)  # Ensure the directory exists
        file_path = os.path.join(directory, name)
    else:
        file_path = name

    with open(file_path, "w", encoding=encoding) as file:
        file.write(text)

    return file_path  # Return the file path for reference

def checkEmptyList(obj):
    """
    Return True only for a ``list`` (or list subclass) with no elements.

    Any other object, including empty tuples, strings or numpy arrays,
    yields False.
    """
    if not isinstance(obj, list):
        return False
    return len(obj) == 0
    



def is_1d(mat):
    """
    Check whether all elements of ``mat`` lie along a single axis.

    True when the longest axis length equals the total number of elements,
    i.e. every other axis has length 1 (e.g. shapes (n,), (n, 1), (1, n)).

    Raises
    ------
    ValueError
        If ``mat`` is neither a list nor a numpy array.
    """
    if isinstance(mat, list):
        arr = np.array(mat)
    elif isinstance(mat, np.ndarray):
        arr = mat
    else:
        raise ValueError('Mat must be numpy array or a list')
    longest_axis = np.max(arr.shape)
    n_elements = len(arr.flatten())
    return longest_axis == n_elements

def is_2d(mat, dim = 2):
    """
    Check whether an array has exactly ``dim`` non-singleton axes.

    Parameters
    ----------
    mat : list or np.ndarray
        Input array; lists are converted with ``np.array``.
    dim : int, optional
        Required number of axes whose length is not 1. Default 2.

    Returns
    -------
    bool
        With exactly ``dim`` axes: True when none of them has length 1.
        With more than ``dim`` axes: True when exactly ``dim`` of them have a
        length other than 1 (extra singleton axes are ignored).
        With fewer than ``dim`` axes: False.

    Raises
    ------
    ValueError
        If ``mat`` is neither a list nor a numpy array.
    """
    if isinstance(mat, list):
        arr = np.array(mat)
    elif isinstance(mat, np.ndarray):
        arr = mat
    else:
        raise ValueError('Mat must be numpy array or a list')
    shape = arr.shape
    n_axes = len(shape)
    if n_axes == dim:
        return 1 not in shape
    if n_axes > dim:
        return (np.array(shape) != 1).sum() == dim
    return False


    

def normalizeDictionary(D, cutoff = 1):
    """
    Shrink dictionary columns whose l2 norm exceeds ``cutoff`` to norm ``cutoff``.

    Columns with norm <= ``cutoff`` (including all-zero columns) are left
    unchanged.

    Parameters
    ----------
    D : np.ndarray
        2-D dictionary, one atom per column.
    cutoff : float, optional
        Maximal allowed column norm (expected positive). Default 1.

    Returns
    -------
    np.ndarray
        New array with the same shape as ``D``.
    """
    D_norms = np.sqrt(np.sum(D**2, 0))  # norm of each dictionary element (column)
    # divide by norm/cutoff where norm > cutoff and by 1 elsewhere (no 0-division)
    scale = 1 / (D_norms*(D_norms > cutoff)/cutoff + (D_norms <= cutoff))
    # column-wise broadcasting instead of D @ np.diag(scale): identical values for
    # finite D, no O(p^2) diagonal matrix, and a NaN/inf entry no longer leaks
    # into every column of its row through the matrix product.
    return D * scale.reshape((1, -1))

    
def dictionary_update(dict_old, A, data, step_s, GD_type = 'norm', params = None):
    """
    Take one gradient step on the dictionary, then clean up the result.

    Parameters
    ----------
    dict_old : np.ndarray
        Current dictionary.
    A, data :
        Forwarded unchanged to ``takeGDStep`` / ``takeGDStepPoisson``.
    step_s : float
        Step size, forwarded to the gradient-step function.
    GD_type : str, optional
        Gradient-descent variant, forwarded. Default 'norm'.
    params : dict, optional
        Must contain 'likely_from' ('gaussian' or 'poisson', case-insensitive).
        When ``params.get('normalizeSpatial')`` is falsy the stepped dictionary
        is passed through ``normalizeDictionary(., 1)``. Also forwarded to the
        gradient-step function.

    Returns
    -------
    np.ndarray
        Updated dictionary with NaNs replaced by 0; if its mean is below 1e-9,
        ``np.random.rand`` noise is added to it.

    Raises
    ------
    KeyError
        If ``params`` has no 'likely_from' entry.
    ValueError
        If 'likely_from' is neither 'gaussian' nor 'poisson'.
    """
    if params is None:
        params = {}
    likely_from = params['likely_from'].lower()
    if likely_from == 'gaussian':
        dict_new = takeGDStep(dict_old, A, data, step_s, GD_type, params)
    elif likely_from == 'poisson':
        dict_new = takeGDStepPoisson(dict_old, A, data, step_s, GD_type, params)
    else:
        # without this branch dict_new would be unbound (UnboundLocalError below)
        raise ValueError('Unknown likely_from "%s"; expected "gaussian" or "poisson"' % str(params['likely_from']))
    if not params.get('normalizeSpatial'):
        dict_new = normalizeDictionary(dict_new, 1)  # Normalize the dictionary

    dict_new[np.isnan(dict_new)] = 0
    if np.mean(dict_new) < 1e-9:
        # degenerate (~all-zero) dictionary: re-seed it with random values
        dict_new += np.random.rand(*dict_new.shape)
    return dict_new
    


    
    
    
def dictInitialize(phi = [], shape_dict = [], norm_type = 'unit', to_norm = True, params = {},  to_norm_mat = False):
    """
    Initialize a non-negative dictionary, from a given matrix or a new random one.

    Parameters
    ----------
    phi : list of lists or np.ndarray or empty list, optional
        A given initial dictionary. When empty, a new one is created with
        ``createMat(shape_dict, params)``.
    shape_dict : tuple or list or np.ndarray of 2 ints, optional
        Shape of the dictionary to create when ``phi`` is empty. Default [].
    norm_type : str, optional
        Passed to ``norm_mat`` as ``type_norm``. Default 'unit'.
    to_norm : bool, optional
        Passed to ``norm_mat``. Default True.
    params : dict, optional
        Passed to ``createMat`` (which reads 'dist_init' from it).
    to_norm_mat : bool, optional
        If True, the dictionary (given or newly created) is passed through
        ``norm_mat`` before taking the absolute value. Default False.

    Raises
    ------
    ValueError
        If both ``phi`` and ``shape_dict`` are empty.

    Returns
    -------
    phi : np.ndarray
        Element-wise absolute value of the (optionally normalized) dictionary.
    """
    if len(phi) == 0 and len(shape_dict) == 0:
        raise ValueError('At least one of "phi" or "shape_dict" must not be empty!')
    if len(phi) == 0:
        # create the matrix directly instead of recursing: the old recursive call
        # did not forward ``to_norm_mat``, so a new dictionary was never normalized
        phi = createMat(shape_dict, params)
    if to_norm_mat:
        return np.abs(norm_mat(phi, type_norm = norm_type, to_norm = to_norm))
    return np.abs(phi)

    
def createMat(shape_dict,  params = {} ):
    """
    Draw a 2-D matrix from the distribution named by ``params['dist_init']``.

    Parameters
    ----------
    shape_dict : sequence of int
        Only the first two entries are used: (rows, columns).
    params : dict, optional
        Must contain 'dist_init', one of:
            'uniform' - ``np.random.uniform`` with bounds
                        ``params[uniform_vals[0]]`` / ``params[uniform_vals[1]]``
                        (``uniform_vals`` is looked up outside this function),
            'rand'    - ``np.random.rand``,
            'randn'   - ``np.random.randn``,
            'norm'    - ``mu + randn * std`` (defaults mu=0, std=1),
            'zeros'   - all zeros.

    Raises
    ------
    KeyError
        If 'dist_init' is missing.
    ValueError
        If 'dist_init' is not one of the values above.

    Returns
    -------
    np.ndarray
        Matrix of shape ``(shape_dict[0], shape_dict[1])``.
    """
    settings = {'mu': 0, 'std': 1}
    settings.update(params)
    requested = settings['dist_init']

    def _size():
        return shape_dict[0], shape_dict[1]

    builders = (
        ('uniform', lambda: np.random.uniform(settings[uniform_vals[0]], settings[uniform_vals[1]], size = _size())),
        ('rand', lambda: np.random.rand(*_size())),
        ('randn', lambda: np.random.randn(*_size())),
        ('norm', lambda: settings['mu'] + np.random.randn(*_size())*settings['std']),
        ('zeros', lambda: np.zeros(_size())),
    )
    for dist_name, build in builders:
        if requested == dist_name:
            return build()
    raise ValueError('Unknown dist for createMat')
#<h1 id="header">Header</h1>    

def is_pos_def(x):
    """
    Return whether every eigenvalue of the square matrix ``x`` is strictly positive.

    Eigenvalues come from ``np.linalg.eigvals`` (general solver, no symmetry
    assumption); for a symmetric matrix this is a strict positive-definiteness test.
    The result is a numpy boolean.
    """
    spectrum = np.linalg.eigvals(x)
    return (spectrum > 0).all()


def singleGaussNeuroInfer(lambdas_vec, data, phi, l1,  nonneg, A = [], 
                          ratio_null = 0.1, params = {}, grannet = False):
    """
    Solve a weighted LASSO problem for a single data vector.

    With ``nonneg`` the problem solved is
        minimize ||data - phi @ a||^2 + l1 * lambdas_vec^T a   subject to a >= 0,
    where the non-negativity makes the linear term a weighted L1 penalty.

    Parameters
    ----------
    lambdas_vec : np.ndarray
        Per-atom weights; its length must equal ``phi.shape[1]``.
    data : np.ndarray
        Data vector; flattened before use.
    phi : np.ndarray
        Dictionary with ``phi.shape[0] == len(data.flatten())`` and p columns.
    l1 : float
        Regularization strength. When ``params['solver']`` is 'spgl1' it is
        multiplied by ``lambdas_vec.mean()``.
    nonneg : bool
        If True, solve the quadratic program above with qpsolvers ``solve_qp``
        (falling back to ``solve_Lasso_style`` in one case, see below).
        If False, call ``solve_Lasso_style``.
    A : array-like, optional
        Only used as a flag. When empty or all zeros, the QP is solved directly.
        Otherwise ``phi.T @ phi`` is first checked for positive definiteness:
        if it is not PD, either a ridge term ``epsilon * I`` is added
        (``params['deal_nonneg'] == 'make_nonneg'``) or ``solve_Lasso_style`` is used.
    ratio_null : float, optional
        With ``params['nullify_some']``, entries below ``ratio_null * max(A)`` are zeroed.
    params : dict
        Keys read: 'solver', 'solver_qp', 'nullify_some', and depending on the branch
        'deal_nonneg' and 'norm_by_lambdas_vec'.
    grannet : bool, optional
        Unused here.

    Returns
    -------
    np.ndarray
        Coefficient vector of length p.

    Raises
    ------
    ValueError
        On a dimension mismatch, or when the QP solver fails / returns NaN values.
    """
    if phi.shape[1] != len(lambdas_vec):
        raise ValueError('Dimension mismatch!')

    data = data.flatten()                 # time-trace as a 1D vector
    lambdas_vec = lambdas_vec.flatten()   # weights as a 1D vector
    p = len(lambdas_vec)                  # number of dictionary atoms

    if params['solver'].lower() == 'spgl1':
        l1 *= lambdas_vec.mean()

    def _solve_nonneg_qp(quad_mat):
        # solve_qp minimizes 1/2 a^T P a + q^T a. With P = 2*quad_mat and
        # q = -2 phi^T data + l1*lambdas this is ||data - phi a||^2 + l1*lambdas^T a
        # up to a constant. lb = 0 enforces the non-negativity that turns the
        # linear term into a weighted L1 norm (it was previously missing).
        coeffs = solve_qp(2 * quad_mat, -2 * phi.T @ data + l1 * lambdas_vec,
                          lb=np.zeros(p), solver=params['solver_qp'])
        if coeffs is None or np.isnan(coeffs).any():
            raise ValueError('There are nan values in A (or the QP solver failed)')
        return coeffs

    def _solve_lasso():
        coeffs = solve_Lasso_style(phi, data, l1, [], params=params, random_state=0).flatten()
        if params['norm_by_lambdas_vec']:
            coeffs = coeffs.flatten() / lambdas_vec.flatten()   # re-normalize to weighted LASSO values
        return coeffs

    if len(data) == 0 or np.sum(data**2) == 0:
        A = np.zeros(p)   # trivial solution for an all-zero trace
        print('data activity is zero for this neuron and state')
    elif not nonneg:
        A = _solve_lasso()
    else:
        phi_T_phi = phi.T @ phi
        # np.size / np.any work for the default [] as well as for lists and arrays
        # (the old `A == []` raised for ndarrays, `(A == 0).all()` for lists).
        no_warm_start = np.size(A) == 0 or not np.any(A)
        if no_warm_start or is_pos_def(phi_T_phi):
            A = _solve_nonneg_qp(phi_T_phi)
        elif params['deal_nonneg'] == 'make_nonneg':
            # Ridge term on the diagonal makes the quadratic term positive definite.
            # NOTE(review): assumes `epsilon` is a small positive module-level scalar.
            A = _solve_nonneg_qp(phi_T_phi + epsilon * np.eye(p))
        else:
            # Not PD and no regularization requested: keep the LASSO solution
            # (previously it was computed and then overwritten by the QP).
            A = _solve_lasso()

    if params['nullify_some']:
        A[A < ratio_null * np.max(A)] = 0
    return A






def solve_Lasso_style(A, b, l1, x0 = [], params = {}, lasso_params = {}, random_state = 0):
    """
    Solve the l1-regularized least-squares problem

        minimize (1/2)*||A @ x - b||^2 + l1 * ||x||_1

    or the variant implemented by the chosen solver (see below).

    Parameters
    ----------
    A : np.ndarray
        Design matrix, shape (N, p).
    b : np.ndarray
        Target. Any array with a single non-singleton dimension is treated as a
        column of length N. For the 'nnls' solver a 2D ``b`` is also accepted,
        either (N, K) or its transpose (K, N); the result is then (p, K).
    l1 : float
        Regularization weight. ``l1 == 0`` always uses the 'inv' solver.
    x0 : unused
        Kept for backward compatibility.
    params : dict, optional
        'solver' selects the method (default 'inv', with a warning when absent).
        'threshkind' (default 'soft') and 'num_iters' (default 10) are used by the
        pylops solvers.
    lasso_params : dict, optional
        Extra keyword arguments for ``sklearn.linear_model.Lasso``.
    random_state : int, optional
        Random state for the sklearn Lasso solver. The default is 0.

    Solvers (value of params['solver'])
    -----------------------------------
        - 'inv'   : pseudo-inverse least squares (no l1 term); returns a column array.
        - 'nnls'  : scipy NNLS, minimize ||A x - b|| subject to x >= 0 (no l1 term).
        - 'lasso' : sklearn Lasso with alpha=l1.
        - 'fista' / 'ista' : pylops FISTA / ISTA with eps=l1.
        - 'omp'   : pylops OMP with sigma=l1.
        - 'spgl1' : pylops SPGL1 with tau=l1.
        - 'irls'  : pylops IRLS with epsI=l1.
    'inv', 'nnls' and 'lasso' are matched case-sensitively, the rest case-insensitively.

    Raises
    ------
    NameError
        Unknown solver name.
    ValueError
        'nnls' with a 2D ``b`` whose shape matches neither orientation of ``A``.

    Returns
    -------
    x : np.ndarray
        The solution.
    """
    if np.isnan(A).any():
        print('there is a nan in A')
    if len(b.flatten()) == np.max(b.shape):
        b = b.reshape((-1, 1))
    if 'solver' not in params.keys():
        warnings.warn('Pay Attention: Using Default (inv) solver for updating A. If you want to use lasso please change the solver key in params to lasso or another option from "solve_Lasso_style"')
    params = {**{'threshkind': 'soft', 'solver': 'inv', 'num_iters': 10}, **params}
    solver = params['solver']

    if solver == 'inv' or l1 == 0:
        if len(b.flatten()) == np.max(b.shape):
            x = linalg.pinv(A) @ b.reshape((-1, 1))
        else:
            x = linalg.pinv(A) @ b

    elif solver == 'nnls':
        if b.ndim == 1 or len(b.flatten()) == max(b.shape):
            try:
                x = nnls(A, b.flatten())[0]
            except Exception:
                print('shape A: %s; shape b : %s; b.ndim == 1 %s; len(b.flatten() == max(b.shape)) %s '%(str(A.shape), str(b.shape), str(b.ndim == 1), len(b.flatten()) == max(b.shape)))
                raise   # previously swallowed, which surfaced later as an UnboundLocalError
        elif b.shape[0] == A.shape[0]:
            # b is (N, K): one NNLS problem per column
            x = np.hstack([nnls(A, b[:, j].flatten())[0].reshape((-1, 1)) for j in range(b.shape[1])])
        elif b.shape[1] == A.shape[0]:
            # b is (K, N), i.e. transposed: one NNLS problem per row
            x = np.hstack([nnls(A, b[j, :].flatten())[0].reshape((-1, 1)) for j in range(b.shape[0])])
        else:
            raise ValueError('dimension mismatch?! A dim is %s; b dim is %s' % (str(A.shape), str(b.shape)))

    elif solver == 'lasso':
        clf = linear_model.Lasso(alpha=l1, random_state=random_state, **lasso_params)
        clf.fit(A, b.flatten())
        x = np.array(clf.coef_)

    elif solver.lower() == 'fista':
        Aop = pylops.MatrixMult(A)
        x = pylops.optimization.sparsity.FISTA(Aop, b.flatten(), niter=params['num_iters'],
                                               eps=l1, threshkind=params.get('threshkind'))[0]

    elif solver.lower() == 'ista':
        Aop = pylops.MatrixMult(A)
        x = pylops.optimization.sparsity.ISTA(Aop, b.flatten(), niter=params['num_iters'],
                                              eps=l1, threshkind=params.get('threshkind'))[0]

    elif solver.lower() == 'omp':
        Aop = pylops.MatrixMult(A)
        x = pylops.optimization.sparsity.OMP(Aop, b.flatten(),
                                             niter_outer=params['num_iters'], sigma=l1)[0]

    elif solver.lower() == 'spgl1':
        print('spgl1')
        Aop = pylops.MatrixMult(A)
        x = pylops.optimization.sparsity.SPGL1(Aop, b.flatten(), iter_lim=params['num_iters'],
                                               tau=l1)[0]

    elif solver.lower() == 'irls':
        Aop = pylops.MatrixMult(A)
        # pylops' keyword is `epsI` (the previous `espI` was rejected as unknown)
        x = pylops.optimization.sparsity.IRLS(Aop, b.flatten(), nouter=50, epsI=l1)[0]

    else:
        raise NameError('Unknown update c type')
    return x


    

        


def identify_first(labels_str_list, return_meaning = True):
    """
    Identify the index of the first occurrence of each unique label.

    Parameters:
        labels_str_list (list or np.ndarray): Labels (strings or other values that
            ``np.unique`` can sort).
        return_meaning (bool, optional): If True, also return a label -> index dict
            and the unique labels themselves. Default is True.

    Returns:
        If return_meaning is False:
            np.ndarray: first-occurrence index of each unique label, in the order of
                ``np.unique(labels_str_list)`` (i.e. sorted by label).
        If return_meaning is True, a 3-tuple:
            np.ndarray: the indices above;
            dict: mapping each unique label to its first-occurrence index;
            np.ndarray: the sorted unique labels.
    """
    # np.unique(return_index=True) gives the first occurrence of every unique value
    # in one pass, instead of a full scan (and array conversion) per label.
    unique_labels_str, array_meaning_labels = np.unique(np.asarray(labels_str_list), return_index=True)
    if not return_meaning:
        return array_meaning_labels
    dict_meaning_labels = dict(zip(unique_labels_str, array_meaning_labels))
    return array_meaning_labels, dict_meaning_labels, unique_labels_str
    
    
    
def check_if_labels_batches(labels):
    """
    Check that every label occupies a single contiguous run ("batch") in ``labels``.

    Consecutive pairs are scanned; whenever the label changes, the label whose run
    just ended is recorded as closed. If a closed label appears again later, the
    labels are not batched.

    Parameters:
    labels (sliceable sequence): Labels compared with ``!=`` and ``in``.

    Returns:
    bool: True if each label forms one contiguous run (also for empty or
        single-element input), False otherwise.

    Example:
    >>> check_if_labels_batches(['a', 'a', 'b', 'c'])
    True
    >>> check_if_labels_batches(['a', 'b', 'a', 'b'])
    False
    """
    closed_runs = []

    def _reopens(current, following):
        # record the run that ends between `current` and `following`
        if current != following:
            closed_runs.append(current)
        return following in closed_runs

    # any() short-circuits, so scanning stops at the first reopened label
    return not any(_reopens(cur, nxt) for cur, nxt in zip(labels[:-1], labels[1:]))




    
def create_cbar_pp(color_matrix = [], label_basic = 'k', alpha = 0.99, to_plot = True, inv = False, rot = 0, w_cmap = False, cmap = 'Blues',
                add_yticklabels = False, data_it = True, vmin = 0,  vmax = 1, fig = [], ax = [], fig_size = (0.4,7), 
                num_changes = 10, cbar_label = 'states', label_fontsize = 20, addi_yticklabel = '', yticklabels_or = [], fontsize_ticks = 20):
    """
    Create a colorbar plot with customizable options.

    Parameters:
    -----------
    color_matrix : numpy.ndarray, optional
        Colors / data for the colorbar, one column per colorbar entry. When empty it is
        replaced by a 1 x num_changes zero placeholder and data_it is forced to False.
    label_basic : str, optional
        Base symbol of the tick labels, rendered as "$<label_basic>_{i}$". When empty,
        numeric labels ('%.2f' of the vmin..vmax samples + addi_yticklabel) are used.
    alpha : float, optional
        Alpha of the scatter used to build the colorbar (only when data_it is False).
    to_plot : bool, optional
        If True, create the colorbar plot; if False, only resolve the colormap.
    inv : bool, optional
        Reverse the color order (custom colormap and generated tick labels).
    rot : int, optional
        Tick-label rotation (non-inverted "$label_i$" labels use rot + 90).
    w_cmap : bool, optional
        Build a ListedColormap from color_matrix; only used when cmap is empty.
    cmap : str or Colormap, optional
        Colormap name or object. An empty value (or None) means "not given".
    add_yticklabels : bool, optional
        If True, set explicit tick positions and labels on the colorbar.
    data_it : bool, optional
        If True, the colorbar range is symmetric around 0 and spans the largest
        absolute value in color_matrix; if False, it spans vmin..vmax.
    vmin, vmax : float, optional
        Range used when data_it is False and for the numeric tick labels.
    fig : matplotlib Figure, optional
        Figure that owns ``ax``. If ``ax`` is given without ``fig``, ``ax.figure`` is used.
    ax : matplotlib Axes, optional
        Axes the colorbar is drawn into (used as ``cax``). A new figure is created when empty.
    fig_size : tuple, optional
        Figure size used when a new figure is created.
    num_changes : int, optional
        Number of samples between vmin and vmax (and width of the placeholder matrix).
    cbar_label : str, optional
        Label of the colorbar.
    label_fontsize : int, optional
        Font size of the colorbar label.
    addi_yticklabel : str, optional
        Suffix appended to numeric tick labels.
    yticklabels_or : list, optional
        Explicit tick labels (placed at the vmin..vmax samples) overriding generated ones.
    fontsize_ticks : int, optional
        Font size of the tick labels.

    Returns:
    --------
    (ax, cbar, cmap, fig) when to_plot is True;
    (0, 0, cmap) when to_plot is False (note the different tuple length).
    """
    # Colormap objects are not sized, so only call len() on things that support it.
    cmap_missing = cmap is None or (hasattr(cmap, '__len__') and len(cmap) == 0)
    if w_cmap and cmap_missing:
        if not inv:
            cmap = ListedColormap(color_matrix.T)
        else:
            cmap = ListedColormap(color_matrix.T[::-1,:])
    elif cmap_missing:
        cmap = 'Blues'
    # Independent of the cmap branch: previously an empty color_matrix was only
    # replaced when a non-empty cmap was passed, and then crashed on .shape below.
    if checkEmptyList(color_matrix):
        data_it = False
        color_matrix = np.zeros((1,num_changes))

    if to_plot:
        base_vals = np.linspace(vmin, vmax, num_changes)
        if checkEmptyList(ax):
            fig, ax = plt.subplots(figsize = fig_size)
        elif fig is None or (isinstance(fig, list) and len(fig) == 0):
            # An axes was passed without its figure; fig.colorbar / fig.tight_layout
            # below need a real Figure.
            fig = ax.figure

        if data_it:
            list_min_max = [color_matrix.min(), color_matrix.max()]
            data_min = np.max(np.abs(list_min_max))*(-1)
            data_max = np.max(np.abs(list_min_max))
            im = ax.scatter(np.linspace(data_min,data_max,3),
                            np.linspace(data_min,data_max,3),
                            c = np.linspace(data_min,data_max,3), cmap = cmap)
        else:
            im = ax.scatter(base_vals,base_vals,
                            c = base_vals, cmap = cmap, alpha = alpha)

        # half the height of one colorbar entry, used to center the ticks
        len_each = 1/color_matrix.shape[1]*0.5

        if add_yticklabels:
            cbar = plt.colorbar(im, ax=ax, ticks=np.arange(color_matrix.shape[1]), cax = ax, label = cbar_label)
            cbar.set_label(cbar_label, size = label_fontsize)
            cbar.set_ticks(np.linspace(len_each,1-len_each,color_matrix.shape[1]))
            if checkEmptyList(yticklabels_or):
                if not inv and len(label_basic) > 0:
                    cbar.ax.set_yticklabels(['$%s_{%d}$'%(label_basic,i) for i in np.arange(1, color_matrix.shape[1] +1)],
                                            fontsize = fontsize_ticks, rotation = rot + 90)
                elif len(label_basic) > 0:
                    cbar.ax.set_yticklabels(['$%s_{%d}$'%(label_basic,i) for i in np.arange(1, color_matrix.shape[1] +1)[::-1]],
                                            fontsize = fontsize_ticks, rotation = rot)
                else:
                    cbar.set_ticks(base_vals)
                    base_vals = ['%.2f'%i + addi_yticklabel for i in base_vals]
                    cbar.ax.set_yticklabels(base_vals,
                                            fontsize = fontsize_ticks, rotation = rot)
            else:
                cbar.set_ticks(base_vals)
                cbar.ax.set_yticklabels(yticklabels_or,
                                        fontsize = fontsize_ticks, rotation = rot)
        else:
            cbar = fig.colorbar(im, cax = ax)

        fig.tight_layout()

        return ax, cbar, cmap, fig
    return 0,0, cmap

    



def MovToD2(mov, dimension = 2): 
    """
    Flatten a movie into a 2D (pixels x time) array.

    PAY ATTENTION! THIS ONE IS NOT FOR GRANNET. THE EXPECTATION IS [pixels X pixels X time].

    Parameters
    ----------
    mov : list of np.ndarray frames, or np.ndarray
        - list: each frame is flattened (row-major) into one column.
        - 2D np.ndarray: assumed to already be pixels x time; returned unchanged (same object).
        - 3D np.ndarray [H x W x T]: frame ``mov[:, :, t]`` flattened (row-major) becomes column t.
    dimension : int, optional
        Unused; kept only so existing calls that pass it keep working.

    Returns
    -------
    np.ndarray
        2D array of shape (n_pixels, n_frames). For list or 3D input this is a new array
        (never a view of ``mov``).

    Raises
    ------
    ValueError
        If mov is neither a list nor a 2D / 3D np.ndarray.
    """
    if isinstance(mov, list):
        return np.hstack([frame.flatten().reshape((-1,1)) for frame in mov])
    if isinstance(mov, np.ndarray) and mov.ndim == 2:
        return mov
    if isinstance(mov, np.ndarray) and mov.ndim == 3:
        print('start calculated movtod2')
        # A row-major reshape puts mov[i, j, t] at row i*W + j, column t - the same
        # layout as stacking each flattened frame as a column, without a Python loop.
        # .copy() keeps the result independent of `mov`, as the stacked version was.
        to_d2_return = mov.reshape((mov.shape[0] * mov.shape[1], mov.shape[2])).copy()
        print('end calculated movtod2')
        return to_d2_return
    raise ValueError('Unrecognized dimensions for mov (cannot change its dimensions to 2d)')
    

    
    
    
def D2ToMov_inv(data_2d, shape):
    """
    Rebuild a stack of frames from a 2D (flattened pixels x time) array.

    Column k of ``data_2d`` is reshaped to ``shape`` and the frames are stacked
    along a third axis with ``np.dstack``. For a 2D ``shape`` (H, W) the result
    is H x W x T.
    """
    n_frames = data_2d.shape[1]
    frames = []
    for frame_idx in range(n_frames):
        frames.append(np.reshape(data_2d[:, frame_idx], shape))
    return np.dstack(frames)
    
    
def normalize_to_95_perc(d, perc = 99.999): # normalize data
    """
    Scale each row of a 3D array by a high percentile of that row's values.

    Despite the function name, the default percentile is 99.999 (the value this
    function has always used); pass ``perc=95`` for the 95th percentile.

    Parameters
    ----------
    d : np.ndarray
        3D array. Row n is divided by the ``perc`` percentile of row n of
        ``MovToD2_grannet(d)``. NOTE(review): presumably N x T x conditions, with
        MovToD2_grannet keeping one row per n - confirm against that helper.
    perc : float, optional
        Percentile in [0, 100]. Default 99.999.

    Returns
    -------
    np.ndarray
        New array ``d / ratio`` (``d`` is not modified). Rows whose percentile is 0
        are left unscaled instead of becoming NaN / inf.
    """
    stacked = MovToD2_grannet(d)
    ratio_norm = np.percentile(stacked, perc, axis = 1)
    # avoid 0/0 -> NaN (and x/0 -> inf) for rows that are zero at this percentile
    ratio_norm = np.where(ratio_norm == 0, 1, ratio_norm)
    return d / ratio_norm.reshape((-1,1,1))
        
    
def D2ToMov(mov, frameShape, type_return = 'array'):
    """
    Turn a (flattened pixels x time) array back into frames.

    Parameters
    ----------
    mov : np.ndarray
        2D array; column t holds frame t flattened in row-major order.
    frameShape : sequence
        Shape of one frame; ``frameShape[0] * frameShape[1]`` must equal ``mov.shape[0]``.
    type_return : str, optional
        'array' (default) for frames stacked with ``np.dstack`` (H x W x T),
        or 'list' for a list of H x W frames.

    Raises
    ------
    ValueError
        If the frame size does not match ``mov.shape[0]`` or ``type_return`` is invalid.

    Returns
    -------
    np.ndarray or list
    """
    expected_pixels = frameShape[0] * frameShape[1]
    if mov.shape[0] != expected_pixels:
        raise ValueError('Shape of each frame ("frameShape") is not consistent with the length of the data ("mov")')
    if type_return != 'array' and type_return != 'list':
        raise ValueError('Invalid "type_return" input. Should be "list" or "array"')
    frames = [mov[:, col].reshape(frameShape) for col in range(mov.shape[1])]
    return frames if type_return == 'list' else np.dstack(frames)
    

    
def spec_corr(v1, v2, to_abs = True):
    """
    Pearson correlation coefficient between two arrays, both flattened first.

    Parameters
    ----------
    v1, v2 : np.ndarray
        Arrays with the same number of elements.
    to_abs : bool, optional
        If True (default), return the absolute value of the coefficient.

    Returns
    -------
    float
    """
    r = np.corrcoef(v1.flatten(), v2.flatten())[0, 1]
    return np.abs(r) if to_abs else r
    
    
    
    


def mkDataGraph(data, params = {}, reduceDim = False, reduceDimParams = {}, graph_function = 'gaussian',
                K_sym  = True, use_former = False, data_name = 'none', toNormRows = True,
                graph_params = {},
                grannet = False):
    """
    Build a single data-affinity kernel with a zeroed diagonal.

    Parameters
    ----------
    data : np.ndarray
        2D (neurons x time or neurons x p). A 3D array is accepted only when
        ``grannet`` is False: frame ``data[:, :, i]`` is flattened into column i first.
    params : dict
        Must contain 'n_neighbors' (> 1). Merged with the module-level
        ``params_config`` through ``addKeyToDict``; 'n_comps' is read when reduceDim.
    reduceDim : bool, optional
        If True, project the data with PCA(n_components=params['n_comps']) first.
    reduceDimParams : dict, optional
        Merged with {'alg': 'PCA'} but not otherwise used here.
    graph_function, K_sym, use_former, data_name : optional
        Forwarded to ``calcAffinityMat``. use_former defaults to False.
    toNormRows : bool, optional
        If True (default), divide each row of the kernel by its sum.
    graph_params : dict, optional
        Forwarded to ``calcAffinityMat`` (for grannet it may hold
        'kernel_grannet_type': 'ind', 'averaged', 'combination' or 'one_kernel').
    grannet : bool, optional
        With grannet, only 2D data is allowed here; 3D grannet data must go
        through ``mkDataGraph_grannet``.

    Returns
    -------
    np.ndarray
        The kernel returned by ``calcAffinityMat`` with its diagonal set to 0,
        optionally row-normalized.

    Raises
    ------
    ValueError
        If params['n_neighbors'] <= 1, or grannet is set with non-2D data.
    """
    if params['n_neighbors'] <= 1:
        raise ValueError("['n_neighbors'] must be > 1!")
    if grannet and len(data.shape) != 2:
        raise ValueError('how did you arrive here? if grannet and not ind type, you should use the grannet function (mkDataGraph_grannet), not this one!')

    reduceDimParams = {**{'alg':'PCA'},  **reduceDimParams}
    params = addKeyToDict(params_config, params)

    if len(data.shape) == 3:
        # one column per frame, each frame flattened row-major
        data = np.column_stack([data[:, :, i].ravel() for i in range(data.shape[2])])
        print('data was reshaped to 2d')

    if reduceDim:
        data = PCA(n_components=params['n_comps']).fit_transform(data)

    K = calcAffinityMat(data, params, data_name, use_former, K_sym, graph_function,
                        graph_params = graph_params, grannet = grannet)
    K = K - np.diag(np.diag(K))          # remove self-affinities
    if toNormRows:
        K = K / K.sum(1).reshape((-1, 1))
    return K
    
def mkDataGraph_grannet(data, params = {}, reduceDim = False, reduceDimParams = {}, graph_function = 'gaussian',
                K_sym  = True, use_former = False, 
                data_name = 'none', toNormRows = True,  graph_params = {},
                grannet = False):
    """
    Build one kNN kernel per condition for grannet from averaged or combined distances.

    ONLY for graph_params['kernel_grannet_type'] == "averaged" or "combination".

    Parameters
    ----------
    data : np.ndarray or list of np.ndarray
        3D array (N x T x conditions), or a list with one 2D (N x T_k) array per
        condition. Distances are computed between rows within each condition.
    params : dict
        Keys read: 'n_neighbors', 'graph_params', 'distance_metric'.
    reduceDim, reduceDimParams, graph_function, use_former : unused here.
    K_sym : bool, optional
        If True, every kernel slice is symmetrized as (K + K.T) / 2.
    data_name : str, optional
        When it contains 'trends', term labels are loaded from
        'grannet_trends_for_jupyter_results_march_2023.npy' and a heatmap of the
        first kernel slice is saved (best effort) under a hard-coded results path.
    toNormRows : bool, optional
        Row-normalize each slice after zeroing its diagonal.
    graph_params : dict
        'kernel_grannet_type' selects "averaged" or "combination".
    grannet : bool
        Must be True.

    Returns
    -------
    np.ndarray
        N x N x conditions kernel.

    Raises
    ------
    ValueError
        Outside grannet mode, for the 'ind' kernel type, or for an unknown kernel type.
    """
    if 'trends' in data_name:
        # local, trusted results file (allow_pickle is needed for the stored dict)
        terms_and_labels_trend = np.load('grannet_trends_for_jupyter_results_march_2023.npy', allow_pickle=True).item()
        terms = terms_and_labels_trend['terms']
        labels = terms_and_labels_trend['labels']

    non_zeros = params['n_neighbors'] + 1
    if not grannet or graph_params['kernel_grannet_type'] == 'ind':
        raise ValueError('should not use this function in this case!')

    # 1) distance matrix for each condition
    if isinstance(data, np.ndarray):
        kernels_inds = np.dstack([cal_dist(k, data, graph_params = params['graph_params'],
                                           grannet = True, distance_metric = params['distance_metric'])
                                  for k in range(data.shape[2])])
    else:
        # Lists have no .shape (the old code sent equal-width lists down the ndarray
        # path and crashed). Each condition's own 2D matrix gives the same distances
        # as the 3D path: cal_dist compares rows n and n2 within one condition.
        kernels_inds = np.dstack([cal_dist(0, data_i, graph_params = params['graph_params'],
                                           grannet = True, distance_metric = params['distance_metric'])
                                  for data_i in data])
    print('finished stage 1: calculated distances')

    # 2) average or combine the per-condition distances
    kernel_type = graph_params['kernel_grannet_type']
    if kernel_type == "combination":
        shared_kernel = cal_dist(0, MovToD2_grannet(data), graph_params = params['graph_params'],
                                 grannet = True, distance_metric = params['distance_metric'])
        print('finished stage 2: calculated shared ker')
        dist_mat = kernel_combination(kernels_inds, shared_kernel, w = 0, graph_params = graph_params)
        print('finished stage 2: calculated combination')
    elif kernel_type == "averaged":
        dist_mat = kernel_averaging(kernels_inds, w = [], graph_params = params['graph_params'])
        print('finished stage 2 calculated averaging')
    else:
        raise ValueError('graph_params["kernel_grannet_type"] need to be "averaged" or "combination"')

    # 3) keep the nearest neighbours
    K = take_NN_given_dist_mat(dist_mat, non_zeros, K_sym = True, include_abs = True, toNormRows = True)

    # 4) normalize, optionally symmetrize, then normalize again
    K = np.dstack([normalize_K(K[:,:,k], toNormRows = toNormRows) for k in range(K.shape[2])])
    if K_sym:
        K = np.dstack([(K[:,:,i] + K[:,:,i].T)/2 for i in range(K.shape[2])])
    K = np.dstack([normalize_K(K[:,:,k], toNormRows = toNormRows) for k in range(K.shape[2])])

    if 'trends' in data_name:
        # Best-effort debug figure. `terms` only exists for 'trends' data; the old
        # unconditional attempt always failed for other data and leaked an open figure.
        path_exp = r'C:\Users\14434\Documents\GitHub\g_MILCCI_Python\g_MILCCI_Python\results_march_2023\try_submit_kernels'
        fig = None
        try:
            fig = plt.figure(figsize = (20,9))
            ss = int(str(datetime2.now()).split('.')[-1])
            sns.heatmap(pd.DataFrame(K[:,:,0], terms, terms), robust = True)
            plt.savefig(path_exp + os.sep + 'try_hea%d.png'%ss)
            plt.show()
        except Exception:
            print('did not print graph')
        finally:
            if fig is not None:
                plt.close(fig)

    return K
    
def normalize_K(K, toNormRows = True):
    """
    Zero the diagonal of a square kernel and optionally make each row sum to 1.

    Returns a new array; ``K`` itself is not modified. Rows that sum to 0 produce
    NaN values when ``toNormRows`` is True.
    """
    hollow = K - np.diag(np.diag(K))
    if not toNormRows:
        return hollow
    row_totals = hollow.sum(axis=1)
    return hollow / row_totals[:, None]
    
     
        
        
        
        
        


def dist_vecs(vec1, vec2, distance_metric = 'euclidean'):
    """
    Distance between two arrays, each flattened first.

    For 'euclidean' this returns the SQUARED Euclidean distance (no square root).
    Any other metric raises ValueError.
    """
    if distance_metric != 'euclidean':
        raise ValueError('not defined yet :(')
    diff = vec1.flatten() - vec2.flatten()
    return np.sum(diff ** 2)
    
    


    
def cal_dist(k, data, graph_params = {}, grannet = True, distance_metric = 'euclidean'):
    """
    Pairwise distances between the rows (nodes) of grannet data for one condition.

    Args:
        k (int): Condition index (third axis) used when ``data`` is 3D; ignored for 2D data.
        data (numpy.ndarray): 3D (N x T x conditions) or 2D (N x T) array.
        graph_params (dict, optional): Unused here; kept for signature compatibility.
        grannet (bool, optional): Must be True; this helper is only for grannet.
        distance_metric (str, optional): Passed to ``dist_vecs``. Default 'euclidean'
            (the previous default, ``{}``, was always rejected by ``dist_vecs``).

    Returns:
        numpy.ndarray: 2D N x N matrix whose entry (n, n2) is ``dist_vecs`` between rows
        n and n2 of condition k (for 'euclidean', dist_vecs returns the squared distance).

    Raises:
        ValueError: if ``grannet`` is False or ``data`` is neither 2D nor 3D.

    Used when graph_params['kernel_grannet_type'] is "averaged" or "combination".
    """
    if not grannet:
        raise ValueError('to use cal_dist you must be in grannet mode (this function is only for grannet)')

    num_rows = data.shape[0]
    if len(data.shape) == 3:
        dists_multi_d = np.vstack([[dist_vecs(data[n,:,k], data[n2,:,k], distance_metric = distance_metric)
                                    for n2 in range(num_rows)]
                                   for n in range(num_rows)])
    elif len(data.shape) == 2:
        print('calcultes together')
        dists_multi_d = np.vstack([[dist_vecs(data[n,:], data[n2,:], distance_metric = distance_metric)
                                    for n2 in range(num_rows)]
                                   for n in range(num_rows)])
        print('finished calculte together')
    else:
        # previously fell through to an UnboundLocalError on the return
        raise ValueError('cal_dist expects 2D (N x T) or 3D (N x T x conditions) data, got shape %s' % str(data.shape))
    return dists_multi_d
    
def kernel_averaging(data_3d, w = [], graph_params  =  {} ) :
    """
    Weighted averaging of per-condition kernels (or distance matrices).

    Called AFTER the individual kernels were computed, and only when
    graph_params['kernel_grannet_type'] == "averaged".

    Parameters:
    -----------
    - data_3d : numpy ndarray or list
        Either an N x N x k array (slice k = kernel of condition k) or a list of k matrices.
    - w : number, list, numpy array or tuple, optional
        Weights, one per kernel. If empty or 0, ``graph_params['params_weights']`` is used.
        A single number is repeated for every kernel. Weights are normalized to sum to 1.
    - graph_params : dict
        Must contain 'kernel_grannet_type' (and 'params_weights' when w is empty / 0).

    Returns:
    --------
    - data_3d_weighted : numpy ndarray or list (same container type as the input)
        Output kernel k is sum_j wk[j] * kernel_j, where wk = normalize_w_with_i(w_vec, k).
        NOTE(review): normalize_w_with_i is defined elsewhere; it is assumed to return one
        weight per kernel for target index k - confirm.

    Raises:
    -------
    ValueError
        Wrong kernel type, or weights whose length does not match the number of kernels.
    """
    if graph_params['kernel_grannet_type'] != "averaged":
        raise ValueError('this function is only for the case where graph_params["kernel_grannet_type"] == "averaged"')

    if checkEmptyList(w) or (isinstance(w, numbers.Number) and w == 0):
        w = graph_params['params_weights']

    # make the data a list of kernels
    if isinstance(data_3d, np.ndarray):
        data_3d_list = [data_3d[:,:,k] for k in range(data_3d.shape[2])]
        return_type = 'array'
    else:
        data_3d_list = list(data_3d)
        return_type = 'list'
    # number of kernels (len(data_3d) on an array would count rows, not kernels)
    n_kernels = len(data_3d_list)

    # one weight per kernel
    if isinstance(w, numbers.Number):
        w = [w] * n_kernels
    elif not (isinstance(w, (list, np.ndarray, tuple)) and len(w) == n_kernels):
        w_len = len(w) if hasattr(w, '__len__') else 'n/a'
        raise ValueError("graph_params['params_weights'] must be a number or list with the same len as data but is %s, with len %s" % (str(w), w_len))

    # normalize w (previously w_vec was used before ever being assigned)
    w_vec = np.asarray(w, dtype=float)
    w_vec = w_vec / np.sum(w_vec)

    # averaging: kernel k becomes the weighted sum of all kernels
    data_3d_weighted = [np.sum(np.dstack([data_3d_list[k_weight] * w_k
                                          for k_weight, w_k in enumerate(normalize_w_with_i(w_vec, k))]), 2)
                        for k in range(n_kernels)]

    if return_type == 'array':
        data_3d_weighted = np.dstack(data_3d_weighted)
    return data_3d_weighted




def hard_thres_on_A(A_2d, non_zeros, direction = 1):
    """
    Apply hard thresholding to every row (direction=1, default) or every
    column (direction=0) of A_2d, keeping only the `non_zeros` entries with
    the highest absolute value in each and setting the rest to zero.

    The input matrix is NOT modified; a new thresholded matrix is returned.

    Parameters:
    A_2d (ndarray): Input matrix with shape (N, T).
    non_zeros (int or float): Number of entries to keep per row/column
        (a value in (0, 1) is treated as a fraction by hard_thres_on_vec).
    direction (int): 1 -> threshold each row, 0 -> threshold each column.

    Returns:
    (ndarray): Thresholded matrix with the same shape as A_2d.

    Raises:
    ValueError: if `direction` is not 0 or 1, or the result shape differs.
    """
    # hard_thres_on_vec zeroes entries in place, and A_2d[:, t] / A_2d[t, :]
    # are views, so each slice is copied to leave the caller's matrix intact.
    if direction == 0:
        A_ret = np.hstack([hard_thres_on_vec(A_2d[:, t].copy(), non_zeros).reshape((-1, 1))
                           for t in range(A_2d.shape[1])])
    elif direction == 1:
        A_ret = np.vstack([hard_thres_on_vec(A_2d[t, :].copy(), non_zeros).reshape((1, -1))
                           for t in range(A_2d.shape[0])])
    else:
        raise ValueError('direction must be 0 (columns) or 1 (rows), got %s' % str(direction))
    if A_ret.shape != A_2d.shape:
        raise ValueError('A shapes must be identical but %s and %s'%(str(A_2d.shape), str(A_ret.shape)))
    return A_ret
        
        
def hard_thres_on_vec(vec, non_zeros, include_abs = True)    :
    """
    Keep only the `non_zeros` largest entries of `vec` and zero the rest.

    Entries are ranked by absolute value when `include_abs` is True and by
    signed value otherwise. NOTE: `vec` is modified in place (and also
    returned); pass a copy if the caller's array must stay untouched.

    Parameters:
    vec (ndarray): 1D input vector.
    non_zeros (int or float): How many entries to keep. A value in (0, 1)
        is interpreted as a fraction of len(vec) (at least one entry is
        kept); values >= 1 are truncated to an integer count.
    include_abs (bool): Rank by absolute value. Default is True.

    Returns:
    (ndarray): `vec` itself, after thresholding.

    Raises:
    ValueError: if non_zeros <= 0.
    """
    if non_zeros <= 0:
        raise ValueError('non zeros must be > 0 ')
    if non_zeros < 1:
        # A fraction of the entries; never round down to "keep nothing",
        # which the guard above explicitly rejects.
        non_zeros = max(1, int(len(vec.flatten()) * non_zeros))
    else:
        # Slicing needs an int, so accept counts such as 2.0.
        non_zeros = int(non_zeros)
    ranking = np.abs(vec) if include_abs else vec
    argsort_inds = np.argsort(ranking)[::-1]  # largest first
    vec[argsort_inds[non_zeros:]] = 0  # nullify small values
    return vec
       
#ax.set_xticks(np.arange(corr_kern.shape[0]) + 0.5); ax.set_yticks(np.arange(corr_kern.shape[0])+0.5);    

def normalize_w_with_i(w,i) :
    """
    Build a weight vector that emphasises entry `i`.

    The weights are first scaled to sum to one, entry `i` is then set to 1,
    and the result is rescaled so it again sums to one. The input is not
    modified (a new array is returned).

    Parameters:
    w (array-like): Input weights.
    i (int): Index of the entry to boost to 1 before the final rescaling.

    Returns:
    (ndarray): Float weight vector, same length as w, summing to one.
    """
    scaled = w / np.sum(w)
    scaled[i] = 1
    total = np.sum(scaled)
    return scaled / total
    
    


    

def gaussian_vals(mat, std = 1, mean = 0 , norm = False, dimensions = 1, mat2 = [], power = 2):
    """
    Evaluate a Gaussian-shaped function on the values of `mat`.

    For dimensions == 1 the result is, elementwise,
        g = exp(-((mat - mean) / std) ** power)
    Note there is no 1/2 factor in the exponent, so for power=2 `std` is a
    width scale rather than the standard deviation in the usual sense.

    For dimensions == 2 the 1D values of `mat` (computed as above, and
    normalised if `norm`) form a row vector; `mat2` scaled by the constant
    exp(-0.5 / max(len(mat2), 1)) forms a column vector; their outer product
    is returned with its central entry set to 0 (and normalised again if
    `norm`).

    Parameters
    ----------
    mat : array-like
        Values at which to evaluate; must support numeric arithmetic
        (presumably a numpy array — confirm against callers).
    std : number, optional
        Width scale. The default is 1.
    mean : number, optional
        Centre of the Gaussian. The default is 0.
    norm : bool, optional
        Whether to divide values by their sum (so the sum is 1). The default
        is False.
    dimensions : {1, 2}, optional
        1 for a 1D evaluation of `mat`; 2 for the outer-product form above.
    mat2 : array-like, optional
        Used only when dimensions == 2; must support `mat2 - 1` and
        `.reshape`, so a numpy array is assumed (the default `[]` would
        raise TypeError in that branch). In the 1D branch a non-empty
        `mat2` only triggers a warning.
    power : number, optional
        Exponent applied to the signed scaled distance. The default is 2.
        NOTE(review): for odd powers, values below `mean` give g > 1.

    Returns
    -------
    g : ndarray
        Gaussian values of `mat` (1D) or the outer-product matrix (2D).

    Raises
    ------
    ValueError
        If `dimensions` is not 1 or 2.

    NOTE(review): in the 2D branch `int(len((mat2-1)/2))` equals len(mat2),
    since elementwise arithmetic keeps the length; it may have been meant as
    int((len(mat2)-1)/2). The original author marked this "check_again".
    """    
    if dimensions == 1:
        # mat2 is ignored in 1D mode; warn so the caller notices.
        if not checkEmptyList(mat2): warnings.warn('Pay attention that the calculated Gaussian is 1D. Please change the input "dimensions" in "gaussian_vals" to 2 if you want to consider the 2nd mat as well')


        g = np.exp(-((mat-mean)/std)**power)
        if norm: return g/np.sum(g)

    elif dimensions == 2:

        # 1D profile along mat (already normalised when norm=True)
        g = gaussian_vals(mat, std , mean , norm , dimensions = 1, mat2 = [], power = power)
        g1= g.reshape((1,-1))
        # column factor: mat2 times a scalar depending only on len(mat2)
        g2 = np.exp(-0.5/np.max([int(len((mat2-1)/2)),1])) * mat2.reshape((-1,1))
        g = g2 @ g1 
        
        # zero the central element of the outer product
        g[int(g.shape[0]/2), int(g.shape[1]/2)] = 0
        if norm:
            g = g/np.sum(g)
        
    else:
        raise ValueError('Invalid "dimensions" input')
    return g
        
def cut_gauss(gaussian, t, wind, left, right):
    """
    Trim a Gaussian window so it fits within [left, right].

    The window is assumed to span t - wind .. t + wind. Whatever part of it
    sticks out past `right` is removed from the end, and whatever sticks out
    before `left` is removed from the start. Both sides are trimmed when the
    window overflows both boundaries; previously only the right side was.

    Parameters:
        gaussian (numpy.ndarray): The 1D Gaussian array to be cut.
        t (int): The center index around which the Gaussian array is considered.
        wind (int): The half-size of the window around the center index 't'.
        left (int): The left boundary index of the desired region.
        right (int): The right boundary index of the desired region.

    Returns:
        numpy.ndarray: The trimmed Gaussian array (the input object itself
        when no trimming is needed).
    """
    overflow_right = max(0, t + wind - right)
    overflow_left = max(0, left - (t - wind))
    if overflow_right == 0 and overflow_left == 0:
        return gaussian
    # Clamp at 0 so an overflow longer than the array yields an empty slice,
    # matching gaussian[:-diff], instead of wrapping around.
    end = max(0, len(gaussian) - overflow_right)
    return gaussian[overflow_left:end]


def gaussian_convolve(mat, wind = 10, direction = 1, sigma = 1, norm_sum = True, plot_gaussian = False):
    """
    Smooth a 2D matrix with a Gaussian window along one axis.

    For direction=1, each row is smoothed along axis 1. Output sample t of a
    row is the NaN-ignoring weighted sum of the original samples
    t - wind .. t + wind - 1 (the matrix is NaN-padded by `wind` on both
    sides). Near the edges the missing samples simply contribute nothing;
    the weights are not renormalised.

    Parameters:
        mat (numpy.ndarray): The 2D input matrix.
        wind (int, optional): Half-size of the window; the Gaussian has
            2*wind samples. Default is 10.
        direction (int, optional): 1 smooths each row (along columns),
            0 smooths each column (along rows). Default is 1.
        sigma (float, optional): Width of the Gaussian passed to
            gaussian_array. Default is 1.
        norm_sum (bool, optional): Normalise the Gaussian weights to sum to
            one. Default is True.
        plot_gaussian (bool, optional): Plot the Gaussian weights in a new
            figure. Default is False.

    Returns:
        numpy.ndarray: The smoothed matrix, same shape as `mat`.

    Raises:
        ValueError: If 'direction' is not 0 or 1.
    """
    if direction == 0:
        # Forward every option (the original recursion dropped norm_sum and
        # plot_gaussian, silently using their defaults).
        return gaussian_convolve(mat.T, wind, direction = 1, sigma = sigma,
                                 norm_sum = norm_sum, plot_gaussian = plot_gaussian).T
    if direction != 1:
        raise ValueError('invalid direction')

    gaussian = gaussian_array(2*wind, sigma)
    if norm_sum:
        gaussian = gaussian / np.sum(gaussian)
    if plot_gaussian:
        plt.figure(); plt.plot(gaussian)
    T_or = mat.shape[1]
    padded = pad_mat(mat, np.nan, wind)
    return np.vstack([[np.nansum(padded[row, t:t + 2*wind] * gaussian)
                       for t in range(T_or)]
                      for row in range(padded.shape[0])])
    
    
    
    
def gaussian_array(length,sigma = 1  ):
    """
    Sample a Gaussian bell curve on `length` evenly spaced points in [-3, 3].

    Args:
        length (int): Number of samples.
        sigma (float, optional): Standard deviation of the bell. Default is 1.

    Returns:
        ndarray: The samples, scaled so that the largest value equals 1.
    """
    grid = np.linspace(-3, 3, length)  # adjust the range if needed
    bell = np.exp(-grid ** 2 / (2 * sigma ** 2))
    peak = np.max(bell)
    return bell / peak
        




def interp_different_duration(list_data, type_interp = 'med'):
    """
    Interpolate data with different durations to a common length.

    Each array in `list_data` is (neurons x time). Its time axis is linearly
    resampled so that its first and last samples land on the first and last
    output samples. An array that already has the target length is
    therefore returned unchanged.

    Parameters:
    - list_data (list): A list of 2D numpy arrays (neurons x time) with
      possibly different numbers of time points.
    - type_interp (str or number, optional): How to choose the target length:
        - 'med': the median duration among the input arrays.
        - 'min': the minimum duration among the input arrays.
        - 'max': the maximum duration among the input arrays.
        - 'mean': the mean duration among the input arrays.
        - A number: use that length directly.
      A non-integer target is rounded up (this matches the number of points
      np.arange produced before).

    Returns:
    - list of numpy.ndarray: One (neurons x target_length) array per input.

    Raises:
    - ValueError: If an unfamiliar type_interp is provided or the target
      length is not greater than 1.
    """
    if len(list_data) == 0:
        return []

    lens = [data_s.shape[1] for data_s in list_data]
    if isinstance(type_interp, str):
        reducers = {'med': np.median, 'min': np.min, 'max': np.max, 'mean': np.mean}
        if type_interp not in reducers:
            raise ValueError('unfamiliar lens')
        des_len = reducers[type_interp](lens)
    else:
        des_len = type_interp
    if des_len <= 1:
        raise ValueError('type interp must be a number greater than 1')

    n_out = int(np.ceil(des_len))
    desired_points = np.arange(n_out)

    interpolated = []
    for data_s, len_s in zip(list_data, lens):
        # Map the original samples onto [0, n_out - 1] (not [0, n_out]) so
        # the last original sample coincides with the last output sample.
        source_points = np.linspace(0, n_out - 1, len_s)
        interpolated.append(np.vstack([np.interp(desired_points, source_points, data_s[neuron, :]).reshape((1, -1))
                                       for neuron in range(data_s.shape[0])]))
    return interpolated

    



    
    
    
    
#%%  Other pre-processing
def norm_mat(mat, type_norm = 'evals', to_norm = True):
    """
    Normalise a matrix.

    Inputs:
        mat       = the matrix to normalise
        type_norm = 'evals' divides by the largest eigenvalue magnitude;
                    'unit' scales every column to unit Euclidean norm;
                    'none' (or any other value) leaves the matrix unchanged.
        to_norm   = when False the matrix is returned unchanged.
    Output:
        the (possibly) normalised matrix
    """
    if not to_norm or type_norm == 'none':
        return mat
    if type_norm == 'evals':
        eigenvalues, _ = linalg.eig(mat)
        return mat / np.max(np.abs(eigenvalues))
    if type_norm == 'unit':
        column_norms = np.sqrt(np.sum(mat**2, 0))
        return mat @ np.diag(1 / column_norms)
    return mat


#%% Plotting Functions


#%% Working with files
    
def load_mat_file(mat_name , mat_path = '',sep = sep, squeeze_me = True,simplify_cells = True):
    """
    Load a MATLAB .mat file with scipy.io.loadmat. Useful for the C. elegans data.

    Parameters:
        mat_name (str): File name, or a full path when `mat_path` is ''.
        mat_path (str): Directory containing the file ('' to use mat_name as is).
        sep (str): Separator placed between mat_path and mat_name.
        squeeze_me (bool): Forwarded to loadmat (squeeze unit dimensions).
        simplify_cells (bool): Forwarded to loadmat. Note that loadmat forces
            squeeze_me=True when this is True.

    Returns:
        dict: The variables stored in the .mat file.

    Example:
        load_mat_file('**.mat',[some path])
    """
    full_path = mat_name if mat_path == '' else mat_path + sep + mat_name
    # squeeze_me is now honoured in both branches (it used to be hard-coded
    # to True whenever mat_path was given).
    return sio.loadmat(full_path, squeeze_me = squeeze_me, simplify_cells = simplify_cells)
    
#%% Data Pre-Processing    
 
def take_df_under_col(df, col_name, val_col):
    """
    Select the rows of `df` whose `col_name` entry equals `val_col`.

    Selection is positional, so duplicate index labels are handled
    correctly, and the original row order is preserved.

    Parameters:
    df (pandas.DataFrame): the DataFrame to filter
    col_name (str): the column to compare
    val_col (any): the value to match

    Returns:
    pandas.DataFrame: the matching rows (all columns)
    """
    is_match = (df[col_name] == val_col).to_numpy()
    row_positions = np.nonzero(is_match)[0]
    return df.iloc[row_positions, :]

def to_smooth_data(data, kernel = 'gaussian', direction = 0, 
                   kernel_size_take = 40,
                   kernel_size = 40, freq = 150,
                kernel_params = None):
    """
    Smooth 1D, 2D or 3D data by convolution with a kernel (mode='same').

    Parameters
    ----------
    data : numpy.ndarray
        1D signal, 2D matrix, or 3D array whose 3rd axis holds repetitions
        (each 2D slice is smoothed independently).
    kernel : str or array-like, optional
        Either an explicit kernel, or 'gaussian' to build one from
        gaussian_vals over kernel_size points spaced 1/freq apart (then
        resampled to kernel_size_take + 2 taps). The default is 'gaussian'.
    direction : {0, 1}
        For 2D/3D data: 0 smooths each column, 1 smooths each row.
    kernel_size_take : int
        Number of taps (minus 2) kept from the generated Gaussian.
    kernel_size : int
        Support size of the generated Gaussian (made odd).
    freq : number
        Sampling frequency used to scale the Gaussian support.
    kernel_params : dict, optional
        Keyword arguments for gaussian_vals. Defaults to {'std': 1, 'mean': 0}.

    Returns
    -------
    smoothed_data : numpy.ndarray
        Same shape as `data`.

    Raises
    ------
    ValueError for an invalid direction, kernel name or data rank;
    IndexError when the kernel is not shorter than the smoothed axis.
    """
    if direction not in [0,1]: raise ValueError('Invalid Direction!') 
    if kernel_params is None:
        kernel_params = {'std' : 1, 'mean' : 0}

    if isinstance(kernel , str):            
        if kernel != 'gaussian':
            raise ValueError('Unknown kernel type!')
        if np.mod(kernel_size,2) == 0: kernel_size =kernel_size + 1
        mat = np.arange(-(kernel_size-1)/2, (kernel_size-1)/2)/freq
        kernel = gaussian_vals(mat, norm = True, power = 2, **kernel_params)
        take_in_kernel = np.round(np.linspace(0,kernel_size-2, kernel_size_take + 2)).astype(int)
        kernel = kernel[take_in_kernel]

    if len(data.shape) == 3: # Here the 3rd dim must be the repetitions
        # kernel is already an array here; pass by keyword (the positional
        # call used to shift kernel_size/freq/kernel_params into the wrong slots).
        smoothed_data = np.dstack([to_smooth_data(data[:,:,i], kernel = kernel, direction = direction)
                                   for i in range(data.shape[2])])
        
    elif len(data.shape) == 2:
        if len(kernel) >= data.shape[direction]:
            raise IndexError('Kernel should be shorter than the data')
        if direction == 0:
            smoothed_data = np.hstack([np.convolve(data[:,col], kernel, mode = 'same').reshape((-1,1))
                                       for col in np.arange(data.shape[1]) ]  )
        else:
            smoothed_data = np.vstack([np.convolve(data[row,:], kernel, mode = 'same').reshape((1,-1))
                                       for row in np.arange(data.shape[0]) ]  )

    elif len(data.shape) == 1:
        if len(kernel) >=  len(data):
            raise IndexError('Kernel should be shorter than the data')
        smoothed_data = np.convolve(data, kernel, mode = 'same')      
    else:
        raise ValueError('Invalid Data Dimension for Smoothing')
    
    return smoothed_data
    
    
def merge_dicts(list_of_dicts, dict_01 = None):
    """
    Merge a list of dictionaries into a single new dictionary.

    Later dictionaries override earlier ones on duplicate keys, and every
    dictionary in the list overrides `dict_01`. None of the inputs are
    modified.

    Args:
        list_of_dicts (iterable of dict): Dictionaries to merge, in order.
            May be empty.
        dict_01 (dict, optional): Initial dictionary to start from.
            Defaults to an empty dictionary.

    Returns:
        dict: A new dictionary with all key-value pairs merged.

    Example:
        dict_list = [{'a': 1, 'b': 2}, {'b': 3, 'c': 4}, {'d': 5}]
        result = merge_dicts(dict_list)
        # Output: {'a': 1, 'b': 3, 'c': 4, 'd': 5}
    """
    merged = {} if dict_01 is None else dict(dict_01)
    for one_dict in list_of_dicts:
        merged.update(one_dict)
    return merged

    

def from_counted_dict_to_array(counted_dict,  to_make_array = True):
    """
    Stack a dict of equally-shaped pandas DataFrames into a 3D numpy array.

    Args:
        counted_dict (dict): Maps sortable keys to DataFrames of identical shape.
        to_make_array (bool): When False the dict is returned unchanged.

    Returns:
        numpy.ndarray of shape (rows, cols, len(counted_dict)) whose k-th
        slice along axis 2 is the DataFrame with the k-th smallest key,
        or the input dict itself when to_make_array is False.
    """
    if not to_make_array:
        return counted_dict
    ordered_keys = np.sort(list(counted_dict.keys()))
    layers = [counted_dict[key].values[:, :, np.newaxis] for key in ordered_keys]
    return np.dstack(layers)
        

    
def plot_grphs_networks(type_mapping = 'trends', path_result = None):
    """
    Plot one graph per column of a saved result matrix 'A'.

    Each column of A (normalised to sum to one) is drawn with graph_one_col
    in its own subplot, with node labels taken from the loaded mapping.

    Args:
        type_mapping (str): Which mapping to use. Only 'trends' is supported;
            it loads 'mapping_trends.npy' from the current directory.
        path_result (str, optional): Path of the saved result dict (an .npy
            file holding a dict with key 'A'). Defaults to the original
            hard-coded trends result path.

    Returns:
        None

    Raises:
        ValueError: if type_mapping is not 'trends' (previously this failed
            later with an UnboundLocalError).
    """
    if type_mapping != 'trends':
        raise ValueError("unsupported type_mapping %r; only 'trends' is implemented" % (type_mapping,))
    # NOTE: allow_pickle=True executes pickled data; only load trusted local files.
    mapping = np.load('mapping_trends.npy', allow_pickle=True).item()
    if path_result is None:
        path_result = r'E:\CODES FROM GITHUB\g_MILCCI_Python\g_MILCCI_Python\trends\2022-10-31\kk.npy'

    n = np.load(path_result, allow_pickle=True).item()
    na = n['A']/n['A'].sum(0)
    # squeeze=False keeps axs 2D, so a single column still indexes correctly.
    fig, axs = plt.subplots(1, na.shape[1], figsize = (20,5), squeeze = False)
    for col_num in range(na.shape[1]):
        graph_one_col(na[:, col_num], mapping, axs[0, col_num])




def _first_last_tick_labels(tick_values, num_digits):
    """Format tick values as strings and return only the first and the last."""
    values = tick_values
    if num_digits > 0:
        # shift by +0.5, then round to multiples of 1/num_digits
        # (presumably for cell-centred heatmap ticks — confirm)
        values = (np.round((values + 0.5)*num_digits)).astype(int)/num_digits
    if (values.astype(int) == values).all():
        values = values.astype(int)  # drop a trailing ".0" for whole numbers
    as_text = values.astype(str)
    return [as_text[0], as_text[-1]]


def keep_only_first_last_ticklabels(ax, xticklabels  = [], yticklabels = [],  xticks = [], yticks = [],
                                    fontsize = 14, apply_to_x = True, apply_to_y = True, num_digits = 10):
    """
    Reduce an axis to two ticks per direction: the first and the last.

    Parameters:
    ax (matplotlib.axes.Axes): The axis to modify.
    xticklabels (list, optional): Labels for the two x ticks. If empty, they
        are derived from ax.get_xticks() (shifted by 0.5 and rounded).
    yticklabels (list, optional): Same for the y axis.
    xticks (list, optional): Tick positions; only the first and last are
        kept. If empty, ax.get_xticks() is used.
    yticks (list, optional): Same for the y axis.
    fontsize (int, optional): Font size for the tick labels. Default is 14.
    apply_to_x (bool, optional): Modify the x axis. Default is True.
    apply_to_y (bool, optional): Modify the y axis. Default is True.
    num_digits (int, optional): Derived labels are rounded to multiples of
        1/num_digits; 0 or less disables rounding. Default is 10.

    Returns:
    matplotlib.axes.Axes: The same axis, modified.
    """
    if checkEmptyList(xticks):
        xticks = ax.get_xticks()
    if checkEmptyList(yticks):
        yticks = ax.get_yticks()
    if checkEmptyList(xticklabels):
        xticklabels = _first_last_tick_labels(ax.get_xticks(), num_digits)
    if checkEmptyList(yticklabels):
        yticklabels = _first_last_tick_labels(ax.get_yticks(), num_digits)

    per_axis = (
        (apply_to_x, xticks, xticklabels, ax.set_xticks, ax.set_xticklabels),
        (apply_to_y, yticks, yticklabels, ax.set_yticks, ax.set_yticklabels),
    )
    for enabled, ticks, labels, set_ticks, set_labels in per_axis:
        if enabled:
            set_ticks([ticks[0], ticks[-1]])
            set_labels(labels,  fontsize =  fontsize)
    return ax
        


def add_labels(ax, xlabel='X', ylabel='Y', zlabel='', title='', xlim = None, ylim = None, zlim = None,xticklabels = np.array([None]),
               yticklabels = np.array([None] ), xticks = [], yticks = [], legend = [], ylabel_params = {'fontsize':18},
               zlabel_params = {'fontsize':18}, 
               xlabel_params = {'fontsize':18}, 
               title_params = {'fontsize':26}):
    """
    Add labels, a title, limits, tick labels and a legend to an axis.

    Inputs:
        ax            = the subplot to edit
        xlabel/ylabel/zlabel/title = texts; '' or None skips that element
                        (use zlabel='' or None for 2D axes)
        xlim/ylim/zlim = axis limits; None skips
        xticklabels/yticklabels = tick labels; all-None skips. When the
                        matching xticks/yticks is empty, x ticks default to
                        0..n-1 and y ticks to 0.5..n-0.5
        legend        = legend entries; empty skips
        *_params      = keyword arguments forwarded to the matching setter
    """
    if xlabel is not None and xlabel != '':
        ax.set_xlabel(xlabel, **xlabel_params)
    if ylabel is not None and ylabel != '':
        ax.set_ylabel(ylabel, **ylabel_params)
    if zlabel is not None and zlabel != '':
        # Previously passed ylabel_params here by mistake.
        ax.set_zlabel(zlabel, **zlabel_params)
    if title is not None and title != '':
        ax.set_title(title, **title_params)
    # 'is not None' also works for array-valued limits ('!= None' is elementwise there).
    if xlim is not None: ax.set_xlim(xlim)
    if ylim is not None: ax.set_ylim(ylim)
    if zlim is not None: ax.set_zlim(zlim)

    # Elementwise comparison on purpose: skip only when every label is None.
    if (np.array(xticklabels) != None).any():
        if len(xticks) == 0: xticks = np.arange(len(xticklabels))
        ax.set_xticks(xticks)
        ax.set_xticklabels(xticklabels)
    if (np.array(yticklabels) != None).any():
        if len(yticks) == 0: yticks = np.arange(len(yticklabels)) + 0.5
        ax.set_yticks(yticks)
        ax.set_yticklabels(yticklabels)
    if len(legend) > 0: ax.legend(legend)


def remove_edges(ax, include_ticks = True, top = False, right = False, bottom = True, left = True):
    """
    Show or hide each spine of a Matplotlib axis and optionally clear ticks.

    Parameters:
        ax (matplotlib.axes.Axes): The axis object to modify.
        include_ticks (bool): When False, all x and y ticks are removed.
            Default is True (ticks kept).
        top (bool): Visibility of the top spine. Default is False (hidden).
        right (bool): Visibility of the right spine. Default is False (hidden).
        bottom (bool): Visibility of the bottom spine. Default is True (shown).
        left (bool): Visibility of the left spine. Default is True (shown).

    Returns:
        None

    Example usage:
        import matplotlib.pyplot as plt
        fig, ax = plt.subplots()
        remove_edges(ax, include_ticks=False, top=True, right=True, bottom=True, left=True)
    """
    spine_visibility = {'top': top, 'right': right, 'bottom': bottom, 'left': left}
    for side, visible in spine_visibility.items():
        ax.spines[side].set_visible(visible)
    if include_ticks:
        return
    ax.get_xaxis().set_ticks([])
    ax.get_yaxis().set_ticks([])

def combine_mid_interpolate(x):
    """
    Interleave the values of `x` with the midpoints between neighbours.

    For input [x0, x1, ..., xn] the result is
    [x0, (x0+x1)/2, x1, (x1+x2)/2, ..., xn], i.e. 2*len(x) - 1 values.
    The earlier implementation interpolated at the wrong positions and
    produced [1, 1, 3, 2, 5, 4, 7] for the example below.

    Parameters:
        x (array-like): Input values (list, tuple or 1D numpy array).

    Returns:
        list: Original values interleaved with the midpoints (midpoints are
        numpy floats). An empty input gives an empty list.

    Examples:
        >>> [float(v) for v in combine_mid_interpolate([1, 3, 5, 7])]
        [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
    """
    l = len(x)
    if l == 0:
        return []
    # interpolate at half-integer positions between consecutive samples
    midpoints = np.interp(np.arange(l - 1) + 0.5, np.arange(l), x)
    combined = []
    for value, mid in zip(x[:-1], midpoints):
        combined.extend([value, mid])
    combined.append(x[-1])
    return combined
    

def downsample(x, size)    :
    """
    Return `x` resized to `size` with numpy.resize.

    Despite the name this performs no filtering or decimation: numpy.resize
    flattens `x`, keeps its leading elements when shrinking, and repeats the
    data cyclically when growing.
    """
    resized = np.resize(x, size)
    return resized

    
def create_AP_tempolate(to_plot = False, norm = False):
    """
    Return a coarse 12-sample action-potential template (values in mV).

    Parameters:
        to_plot (bool): Also draw the template in a new figure.
        norm (bool): Divide by the largest absolute value; this returns a
            numpy array instead of a list.

    Returns:
        list or numpy.ndarray: The template.
    """
    template = [-70, -70, -50, 25, 40, 25, -60, -85, -90, -75, -70, -70]
    if to_plot:
        fig, ax = plt.subplots()
        ax.plot(template)
        remove_edges(ax, left = True, bottom = True)
    if not norm:
        return template
    peak = np.max(np.abs(template))
    return template / peak
    


    
    
def create_colors(len_colors, perm = [0,1,2], style = 'random', cmap  = 'viridis', shuffle_colors = False, shuffle_seed = 0):
    """
    Create `len_colors` discrete colours with a one-directional order.

    Input:
        len_colors     = number of colours needed
        perm           = row order of the RGB channels (style 'random' only)
        style          = 'random' builds a 3 x len_colors RGB matrix from
                         linear ramps; anything else samples the matplotlib
                         colormap `cmap`
        shuffle_colors = shuffle the colormap colours (seeded by
                         shuffle_seed); not supported for style 'random'
    Output:
        'random': 3 x len_colors matrix with one colour per column;
        otherwise: list of RGBA tuples
    """
    if style != 'random':
        colormap = plt.get_cmap(cmap)
        colors = [colormap(position) for position in np.linspace(0, 1, len_colors)]
        if shuffle_colors:
            random.seed(shuffle_seed)
            random.shuffle(colors)
        return colors

    ramp = np.linspace(0, 1, len_colors)
    colors = np.vstack([ramp, (1 - ramp)**2, 1 - ramp])[perm, :]
    assert not shuffle_colors, "TB done"
    return colors




def remove_background(ax, grid = False, axis_off = True):
    """
    Make the three panes of a 3D axis transparent white.

    Also turns the grid off when `grid` is False and hides the whole axis
    when `axis_off` is True. Requires an axis with xaxis/yaxis/zaxis
    (a 3D projection).
    """
    transparent_white = (1.0, 1.0, 1.0, 0.0)
    for pane_axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        pane_axis.set_pane_color(transparent_white)
    if not grid:
        ax.grid(grid)
    if axis_off:
        ax.set_axis_off()
    
        
def create_ax(ax, nums = (1,1), size = (10,10), proj = 'd2',return_fig = False,sharey = False, sharex = False, fig = []):
    """
    Return `ax` unchanged, or create new subplots when `ax` is an empty list.

    Parameters:
        ax: existing axis/axes, or [] to request new ones
        nums: (rows, cols) of the new subplot grid
        size: figure size for the new figure
        proj: 'd2' for regular axes, 'd3' for 3D axes; anything else raises
            NameError (only checked when new axes are created)
        return_fig: also return the figure (the new one, or `fig` as passed)
        sharey, sharex: forwarded to plt.subplots
        fig: figure returned alongside an existing `ax` when return_fig is True

    Returns:
        ax, or (fig, ax) when return_fig is True
    """
    wants_new_axes = isinstance(ax, list) and len(ax) == 0
    if wants_new_axes:
        subplot_options = dict(figsize = size, sharey = sharey, sharex = sharex)
        if proj == 'd3':
            subplot_options['subplot_kw'] = {'projection': '3d'}
        elif proj != 'd2':
            raise NameError('Invalid proj input')
        fig, ax = plt.subplots(nums[0], nums[1], **subplot_options)
    return (fig, ax) if return_fig else ax

def plot_3d(mat, params_fig = {}, fig = [], ax = [], params_plot = {}, type_plot = 'plot', to_return = False):
    """
    Draw 3D data as a line or scatter plot.

    Parameters:
    - mat: indexable whose first three entries (mat[0], mat[1], mat[2]) are
      the x, y and z coordinates.
    - params_fig (dict): Options passed to create_3d_ax when a new axis is made.
    - fig: Unused except when a new axis is created (then it is overwritten).
    - ax: Existing 3D axis; when empty a new one is created via create_3d_ax.
    - params_plot (dict): Keyword arguments for ax.plot / ax.scatter.
    - type_plot (str): 'plot' for a line plot; any other value scatters.
    - to_return (bool): Return the created artist.

    Returns:
    - The object returned by ax.plot / ax.scatter when to_return is True,
      otherwise None (the figure and axis are NOT returned).
    """
    if checkEmptyList(ax):
        fig, ax = create_3d_ax(1,1, params_fig)
    draw = ax.plot if type_plot == 'plot' else ax.scatter
    artist = draw(mat[0], mat[1], mat[2], **params_plot)
    if to_return:
        return artist
    



        
def vstack_f(ar1, ar2, direction = 0)    :
    """
    Append ar2 to ar1 along rows (direction=0) or columns (direction=1).

    An empty ar1 acts as an initial value, which makes this convenient for
    growing an array inside a loop. For direction=0, a 1D (or single-row /
    single-column) input is first reshaped into a single row.

    Parameters:
    ar1 (numpy.ndarray or empty list): The array to extend, or empty.
    ar2 (numpy.ndarray): The array to append.
    direction (int, optional): 0 stacks vertically, 1 horizontally (default=0).

    Returns:
    numpy.ndarray: The stacked array.

    Raises:
    ValueError: If the shapes are not consistent or direction is not 0/1
        (an invalid direction used to silently return None).

    Example:
    >>> a = np.array([[1, 2], [3, 4]])
    >>> b = np.array([[5, 6]])
    >>> vstack_f(a, b)
    array([[1, 2],
           [3, 4],
           [5, 6]])
    """
    if direction == 1:
        # horizontal stacking is vertical stacking of the transposes
        a1_T = [] if checkEmptyList(ar1) else ar1.T
        return vstack_f(a1_T, ar2.T, 0).T
    if direction != 0:
        raise ValueError('direction must be 0 or 1, got %s' % str(direction))

    if len(ar2.flatten()) == np.max(ar2.shape):
        ar2 = ar2.reshape((1,-1))
    if checkEmptyList(ar1):
        return ar2
    if len(ar1.flatten()) == np.max(ar1.shape):
        ar1 = ar1.reshape((1,-1))
    if ar1.shape[1]  != ar2.shape[1]:
        raise ValueError('shapes are not consistent! shape ar1 is %s and shape ar2 is %s'%(str(ar1.shape), str(ar2.shape)))
    return np.vstack([ar1,ar2])
    

def norm_trends(data, factor = 100):
    """
    Min / 99th-percentile normalise every row of every layer, in place.

    For each layer data[:, :, l], each row is shifted by its minimum and
    divided by (99th percentile - minimum). Values above 1 are then clipped
    to 1 and everything is multiplied by `factor`. `data` (a 3D array) is
    modified in place and also returned.
    """
    for layer in range(data.shape[2]):
        layer_view = data[:, :, layer]
        row_min = np.min(layer_view, 1).reshape((-1, 1))
        row_span = (np.percentile(layer_view, 99, axis = 1) - np.min(layer_view, axis = 1)).reshape((-1, 1))
        data[:, :, layer] = (layer_view - row_min) / row_span

    data[data > 1] = 1
    data *= factor
    return data
        

    




def pad_mat(mat, pad_val = np.nan, size_each = 1, axis = 1):
    """
    Pad `mat` with a constant on both sides of one axis.

    Parameters:
        mat: 2D array (axis 0 or 1), or a 2D/3D array (axis 2).
        pad_val: fill value. Default NaN.
        size_each: pad width on each side, or a (left, right) pair.
        axis: 0 pads rows, 1 pads columns, 2 pads along the third axis.

    Returns:
        numpy.ndarray: the padded array (pads are float blocks of
        ones * pad_val).

    Raises:
        ValueError: for an axis other than 0, 1 or 2.
    """
    if isinstance(size_each,(list, tuple, np.ndarray)):
        size_left, size_right = size_each[0], size_each[1]
    else:
        size_left = size_right = size_each

    if axis == 0:
        pad_shapes = [(n, mat.shape[1]) for n in (size_left, size_right)]
        stacker = np.vstack
    elif axis == 1:
        pad_shapes = [(mat.shape[0], n) for n in (size_left, size_right)]
        stacker = np.hstack
    elif axis == 2:
        pad_shapes = [(mat.shape[0], mat.shape[1], n) for n in (size_left, size_right)]
        stacker = np.dstack
    else:
        raise ValueError('undefined axis for padding')

    left_block, right_block = (np.ones(shape)*pad_val for shape in pad_shapes)
    return stacker([left_block, mat, right_block])

def zero_pad(mat, max_pad, left = False, right = True, axis = 1):
    """
    DEPRECATED. SEE PAD_MAT
    Zero-pad a 2D numpy array up to a target length along one axis.

    Parameters:
        mat (numpy.ndarray): The input 2D array to be zero-padded.
        max_pad (int): Target length along the padded axis; the number of
            zeros added is max_pad - mat.shape[axis].
        left (bool, optional): Pad on the left/top (default False).
        right (bool, optional): Pad on the right/bottom (default True).
            With both, the padding is split, and the extra zero goes right.
        axis (int, optional): 0 pads rows, anything else pads columns
            (default 1).

    Returns:
        numpy.ndarray: A new (float) array with zero-padding added, or the
        input array itself when both `left` and `right` are False.

    Example:
        >>> arr = np.array([[1, 2], [3, 4]])
        >>> zero_pad(arr, 4, left=True, right=False)
        array([[0., 0., 1., 2.],
               [0., 0., 3., 4.]])

    Note:
        For axis=0 the array is transposed, padded along columns, and
        transposed back (this used to recurse with axis=0 forever).
    """
    if axis == 0:
        return zero_pad(mat.T, max_pad, left, right, axis = 1).T
    if not (left or right):
        return mat

    missing = max_pad - mat.shape[1]
    if left and right:
        num_pad_left = int(np.floor(missing / 2))
        num_pad_right = int(np.ceil(missing / 2))
    elif left:
        num_pad_left, num_pad_right = missing, 0
    else:
        num_pad_left, num_pad_right = 0, missing

    N = mat.shape[0]
    return np.hstack([np.zeros((N, num_pad_left)), mat, np.zeros((N, num_pad_right))])


"""
existing methods

"""

from tensorly.decomposition import tucker, parafac, non_negative_tucker
from tensorly import random as random_tl

def assertions_existing_methods(existing_methods_results, desired_shape_A, desired_shape_phi, method_name):
    """
    Sanity-check the result dictionary produced by one decomposition method.

    ``existing_methods_results`` must contain the keys 'phi' and 'A' (checked in
    that order). A non-empty 'A' must agree with ``desired_shape_A`` on its first
    two dimensions (convention: N_channels x p x K). A non-empty 'phi' must equal
    ``desired_shape_phi`` exactly (convention: T x p x K). Empty entries only
    produce a printed warning.

    Raises:
        AssertionError: on a missing key or a wrongly shaped non-empty entry.
    """
    for required_key in ('phi', 'A'):
        assert required_key in existing_methods_results, 'no %s keys in existing_methods_results, keys are %s' % (
            required_key, str(existing_methods_results.keys()))

    A_result = existing_methods_results['A']
    if len(A_result) == 0:
        print('pay atttention! A is empty for %s' % method_name)
    else:
        assert A_result.shape[:2] == desired_shape_A[:2], (
            "something is wrong with As dimension of method %s (existing_methods_results['A'].shape != desired_shape_A): %s vs %s"
            % (method_name, A_result.shape, desired_shape_A))

    phi_result = existing_methods_results['phi']
    if len(phi_result) == 0:
        print('pay atttention! phi is empty for %s' % method_name)
    else:
        assert phi_result.shape == desired_shape_phi, (
            "something is wrong with phis dimension of method %s (existing_methods_results['phi'].shape != desired_shape_phi): %s vs %s"
            % (method_name, phi_result.shape, desired_shape_phi))
    
    
    
    
    
    
def run_existing_methods(data, p, methods_to_compare = ['flattened_svd','parafac','tucker'],
                         params_parafac = {}, params_tucker = {}, noise_add_std = 0.005):
    """
    Run baseline decompositions on ``data`` and collect their A / phi.

    The tensor methods come from tensorly
    (http://tensorly.org/stable/modules/api.html#module-tensorly.decomposition,
    user guide http://tensorly.org/stable/user_guide/quickstart.html#tensor-decomposition).

    ``data`` is [n channels, n times, num_state]. Every result is checked with
    ``assertions_existing_methods``:
    - A must match (channels, p, states) on its first two axes.
    - phi must be (times, p, states).

    For each tensorly-based method, a first attempt that returns an empty ``A``
    triggers one retry of THE SAME method on ``data`` plus Gaussian noise with
    std ``noise_add_std``. The global numpy RNG is seeded with 11 first.

    Parameters:
        data: 3D array (channels, times, states).
        p: number of components.
        methods_to_compare: currently ignored; every method below is always run.
        params_parafac: keyword arguments forwarded to the PARAFAC runners.
        params_tucker: keyword arguments forwarded to ``run_tucker``.
        noise_add_std: std of the noise added for a retry.

    Returns:
        dict with the following keys:
        - 'flattened': holds 'A' and 'phi'.
        - 'tucker': holds 'A', 'phi', 'factors' and 'core'.
        - 'parafac', 'parafac_scale', 'parafac_nonneg_scale', 'parafac_nonneg':
          each holds 'A', 'phi' and 'factors'.
    """
    np.random.seed(11)
    num_states = data.shape[2]
    num_times = data.shape[1]
    num_channels = data.shape[0]

    desired_shape_A = (num_channels, p, num_states)
    desired_shape_phi = (num_times, p, num_states)

    def _tucker(tensor):
        # Tucker: A = mode-0 factor, phi built from mode-1/mode-2 factors.
        A, phi, core, factors = run_tucker(tensor, p = p, params_tucker = params_tucker)
        return {'A': A, 'phi': phi, 'factors': factors, 'core': core}

    def _parafac_variant(runner, scale):
        # Bind runner / scale now so every table entry keeps its own settings.
        def _run(tensor):
            A, phi, factors = runner(tensor, p = p, params_parafac = params_parafac, scale = scale)
            return {'A': A, 'phi': phi, 'factors': factors}
        return _run

    def _run_checked(name, method):
        # Run on the clean data; if A comes back empty, retry once on a noisy copy.
        result = method(data)
        assertions_existing_methods(result, desired_shape_A, desired_shape_phi, name)
        if len(result['A']) == 0:
            data_noisy = data + np.random.randn(*data.shape) * noise_add_std
            result = method(data_noisy)
            assertions_existing_methods(result, desired_shape_A, desired_shape_phi, name)
        return result

    results = {}
    A_flattened, phi_flattened = run_flattened_svd(data, p)
    results['flattened'] = {'A': A_flattened, 'phi': phi_flattened}
    assertions_existing_methods(results['flattened'], desired_shape_A, desired_shape_phi, 'flattened')

    # Each method is stored under its own key and, on failure, retried with
    # itself. Previously the PARAFAC fallbacks re-ran Tucker, and the first
    # non-negative PARAFAC run overwrote results['tucker'].
    methods = (
        ('tucker', _tucker),
        ('parafac', _parafac_variant(run_parafac, False)),
        ('parafac_scale', _parafac_variant(run_parafac, True)),
        ('parafac_nonneg_scale', _parafac_variant(run_nonneg_parafac, True)),
        ('parafac_nonneg', _parafac_variant(run_nonneg_parafac, False)),
    )
    for name, method in methods:
        results[name] = _run_checked(name, method)

    return results
    


from tensorly.decomposition import non_negative_parafac


def run_nonneg_parafac_diagnostics(data, rank=10, params_parafac={}, scale=False):
    """
    Run tensorly's ``non_negative_parafac`` on ``data`` with basic sanity checks.

    Parameters:
        data: input tensor (numpy array).
        rank: CP rank.
        params_parafac: extra keyword arguments for ``non_negative_parafac``.
        scale: if True, standardise ``data`` with ``StandardScaler`` after
            flattening it to (-1, last_dim), then restore the original shape.

    Returns:
        The CP result of ``non_negative_parafac`` on success. Returns an empty
        list if any of the following happens:
        - the input contains NaN/Inf;
        - any resulting factor contains NaN/Inf;
        - the decomposition raises.
    """
    for is_bad, message in ((np.isnan, "X Input tensor contains NaNs."),
                            (np.isinf, "X Input tensor contains Infs.")):
        if is_bad(data).any():
            print(message)
            return []

    tensor = data.copy()
    if scale:
        # NOTE(review): StandardScaler centers every column to zero mean, so the
        # scaled tensor generally contains negative entries even though the
        # decomposition below constrains the factors to be non-negative.
        original_shape = data.shape
        flat = tensor.reshape(-1, original_shape[-1])
        tensor = StandardScaler().fit_transform(flat).reshape(original_shape)

    try:
        cp_result = non_negative_parafac(tensor, rank=rank, **params_parafac)
        for idx, factor in enumerate(cp_result.factors):
            if np.isnan(factor).any():
                print(f"x Factor {idx} contains NaNs.")
                return []
            if np.isinf(factor).any():
                print(f"x Factor {idx} contains Infs.")
                return []
        print("V Non-negative PARAFAC succeeded.")
        return cp_result
    except Exception as err:
        print("X Non-negative PARAFAC failed with error:", err)
        return []

def run_nonneg_parafac(data, p=10, params_parafac={}, scale = False):
    """
    Non-negative PARAFAC (CP) decomposition of a 3D tensor.

    Parameters:
        data: 3D array; its third axis indexes states (k).
        p: CP rank.
        params_parafac: overrides for the defaults init='svd', n_iter_max=2000,
            tol=1e-6, normalize_factors=True.
        scale: forwarded to ``run_nonneg_parafac_diagnostics``.

    Returns:
        (A_parafac, phi_parafac, factors), where:
        - A_parafac = factors.factors[0];
        - phi_parafac[:, :, s] = factors.factors[1] * factors.factors[2][s, :];
        - factors is the CP result.
        Returns ([], [], []) if the decomposition failed.
    """
    merged_params = {"init": "svd", "n_iter_max": 2000, "tol": 1e-6, "normalize_factors": True}
    merged_params.update(params_parafac)
    _, _, num_states = data.shape

    cp_result = run_nonneg_parafac_diagnostics(data, rank = p, params_parafac = merged_params, scale = scale)
    if len(cp_result) == 0:
        return [], [], []

    A_parafac, phi_base, state_weights = cp_result.factors[:3]
    # Each state's temporal traces are the shared basis scaled per component.
    per_state = [phi_base * state_weights[state, :][np.newaxis, :] for state in range(num_states)]
    phi_parafac = np.stack(per_state, axis=2)
    return A_parafac, phi_parafac, cp_result

    

    
def run_flattened_svd(data, p = 10, max_TK = 1000, verbose = False):
    """
    Rank-``p`` truncated SVD of the state-concatenated data matrix.

    The states (third axis of ``data``) are concatenated along the time axis
    into a 2D matrix of shape (N, T*K), with state ``k`` occupying columns
    [k*T, (k+1)*T). That matrix is factorised with an SVD and truncated to its
    top ``p`` components.

    Parameters:
        data: 3D array (N, T, K).
        p: number of components kept; must not exceed min(N, T*K).
        max_TK: unused; kept for backward compatibility.
        verbose: print intermediate shapes.

    Returns:
        A: (N, p) left singular vectors.
        phi: (T, p, K); ``phi[:, :, k]`` is the k-th state's block of S @ VT,
            transposed, so ``A @ phi[:, :, k].T`` is the rank-p approximation
            of ``data[:, :, k]``.
    """
    num_states = data.shape[2]
    data2d = np.hstack([data[:, :, layer] for layer in range(num_states)])
    print('data2d.shape %s' % str(data2d.shape))
    print('data.shape %s' % str(data.shape))

    # Column boundaries of each state's block inside data2d.
    data_edges = np.linspace(0, data2d.shape[1], 1 + num_states).astype(int)
    assert data2d.shape[0] == data.shape[0], '%d vs %d' % (data2d.shape[0], data.shape[0])
    assert data2d.shape[1] == data.shape[1] * num_states, '%d vs %d' % (data2d.shape[1], data.shape[1] * num_states)

    A, s, VT = np.linalg.svd(data2d, full_matrices=False)
    A, s, VT = A[:, :p], s[:p], VT[:p, :]

    if verbose:
        print('s.shape %s' % str(s.shape))
        print('VT shape %s' % str(VT.shape))
        print('A.shape %s' % str(A.shape))
        print('------------------------------------')

    # Both arrays are 2D; the message previously indexed shape[2] and raised IndexError.
    assert VT.shape[1] == data2d.shape[1], '%d vs %d' % (VT.shape[1], data2d.shape[1])
    assert A.shape[0] == data2d.shape[0]
    assert len(s) == p, 'p=%d exceeds the number of singular values (%d)' % (p, len(s))

    phi_full = np.diag(s) @ VT
    phi_full_3d = np.dstack([phi_full[:, start:end]
                             for start, end in zip(data_edges[:-1], data_edges[1:])]).transpose(1, 0, 2)
    if verbose:
        print('phi_full_3d .shape %s' % str(phi_full_3d.shape))
    return A, phi_full_3d

def run_tucker_diagnostics(data, rank=10, params_tucker={}):
    """
    Run ``tucker`` on ``data`` after screening it for numerical problems.

    The input is first checked for NaN/Inf. Warnings are printed for extreme
    magnitudes and for ill-conditioned mode unfoldings (condition number > 1e12).

    Returns:
        (core, factors) from ``tucker`` on success. Returns ([], []) if any of
        the following happens:
        - the input contains NaN/Inf;
        - an unfolding SVD fails;
        - any factor contains NaN/Inf;
        - the decomposition raises.
    """
    if np.isnan(data).any():
        print("X TUCKER Input tensor contains NaNs.")
        return [], []
    if np.isinf(data).any():
        print("X TUCKER Input tensor contains Infs.")
        return [], []

    peak = np.max(np.abs(data))
    if peak > 1e300:
        print("WARN TUCKER Very large values may cause overflow.")
    if peak < 1e-300 and peak != 0:
        print("WARN TUCKER Very small values may cause underflow.")

    for mode in range(data.ndim):
        unfolded = np.moveaxis(data, mode, 0).reshape(data.shape[mode], -1)
        try:
            singular_values = np.linalg.svd(unfolded, full_matrices=False)[1]
            smallest = singular_values[-1]
            cond_num = singular_values[0] / smallest if smallest > 1e-12 else np.inf
            if cond_num > 1e12:
                print("WARN TUCKER Mode-%d unfolding is ill-conditioned (cond=%.2e)" % (mode, cond_num))
        except Exception as err:
            print("X TUCKER SVD failed for mode-%d unfolding: %s" % (mode, str(err)))
            return [], []

    try:
        core, factors = tucker(data, rank = rank, **params_tucker)
        if factors is not None:
            for idx, factor in enumerate(factors):
                if np.isnan(factor).any():
                    print("X TUCKER Factor %d contains NaNs." % idx)
                    return [], []
                if np.isinf(factor).any():
                    print("X TUCKER Factor %d contains Infs." % idx)
                    return [], []
        print("V TUCKER decomposition seems OK.")
        return core, factors
    except Exception as err:
        message = str(err)
        print("TUCKER params tucker are %s" % str(params_tucker))
        print("X TUCKER Tucker failed with error: %s" % message)
        if "SVD did not converge" in message:
            print("TUCKER Reason: Likely degeneracy, ill-conditioning, or too high rank.")
        else:
            print("TUCKER Reason: Unknown, check tensor values and scaling.")
        print('finished tucker! \n\n\n')
        return [], []

    
    
def run_tucker(data, p = 10, params_tucker = {}):
    """
    Tucker decomposition of a 3D tensor via ``run_tucker_diagnostics``.

    ``factors[0]`` is used as A and ``factors[1]`` as the shared temporal basis.
    Each state's phi is that basis reweighted by the matching row of
    ``factors[2]``. The core tensor is returned but NOT used to build phi.
    See http://tensorly.org/stable/modules/generated/tucker-function.html

    Mode ranks are [p, p, p] when T >= k and [p, p, T] otherwise. In the latter
    case each row of ``factors[2]`` has length T and scales the time points
    (rows) of the basis instead of its p components (columns).

    Parameters:
        data: 3D array; axis 1 is time (T) and axis 2 indexes states (k).
            NaN/Inf are replaced by 0 and values are clipped to [-1e6, 1e6]
            (on a copy; the caller's array is not modified).
        p: rank for the first two modes (and for the third when T >= k).
        params_tucker: overrides for the defaults
            {'n_iter_max': 100, 'init': 'svd', 'random_state': 42}.

    Returns:
        (A_tucker, phi_tucker, core, factors), with phi_tucker of shape (T, p, k).
        If the decomposition fails:
        - A_tucker and phi_tucker are empty;
        - core and factors are empty lists, or None after a LinAlgError.
    """
    from numpy.linalg import LinAlgError

    params_tucker_default = {
        'n_iter_max': 100,
        'init': 'svd',
        'random_state': 42,
    }
    params_tucker = {**params_tucker_default, **params_tucker}

    k = data.shape[2]
    T = data.shape[1]

    # Sanitise: non-finite -> 0, then bound the magnitude. (A former line
    # `data[data == np.isnan(data)] = 0` compared values to a boolean mask and
    # was a no-op, so it has been removed.)
    data = np.nan_to_num(data, nan=0.0, posinf=0.0, neginf=0.0)
    data = np.clip(data, -1e6, 1e6)

    more_times_than_states = T >= k
    third_rank = p if more_times_than_states else T
    case_label = 'T>=k' if more_times_than_states else 'T<k'
    print('3rd dim rank will be %d' % third_rank)
    try:
        core, factors = run_tucker_diagnostics(data, rank=[p, p, third_rank], params_tucker = params_tucker)
    except LinAlgError:
        print("SVD did not converge in %s case — returning empty results." % case_label)
        return np.array([]), np.array([]), None, None

    if len(factors) == 0:
        return [], [], core, factors

    A_tucker = factors[0]
    phi_base_tucker = factors[1]
    k_tucker = factors[2]

    print('k_tucker.shape')
    print(k_tucker.shape)
    print('phi_base_tucker')
    print(phi_base_tucker.shape)
    # (1, -1): weight the p components (columns); (-1, 1): weight the T time points (rows).
    weight_shape = (1, -1) if more_times_than_states else (-1, 1)
    phi_tucker = np.dstack([phi_base_tucker * k_tucker[k_spec, :].reshape(weight_shape) for k_spec in range(k)])
    print('phi tucker shape %s' % str(phi_tucker.shape))

    return A_tucker, phi_tucker, core, factors

    
from sklearn.preprocessing import StandardScaler


    
def run_parafac_diagnostics(data, rank=10, params_parafac={}, scale=False):
    """
    Run tensorly's ``parafac`` on ``data`` with numerical sanity checks.

    Steps:
    1. Reject NaN/Inf input.
    2. If ``scale``, standardise with ``StandardScaler`` over a (-1, last_dim)
       reshape.
    3. Warn about extreme magnitudes and ill-conditioned mode unfoldings
       (condition number > 1e12).
    4. Run the decomposition and verify every factor is finite.

    Returns:
        The CP result of ``parafac`` on success. Returns an empty list on any
        failure; a short diagnosis of decomposition errors is printed.
    """
    if np.isnan(data).any():
        print("X Input tensor contains NaNs.")
        return []
    if np.isinf(data).any():
        print("X Input tensor contains Infs.")
        return []

    tensor = data.copy()
    if scale:
        original_shape = data.shape
        flat = tensor.reshape(-1, original_shape[-1])
        tensor = StandardScaler().fit_transform(flat).reshape(original_shape)

    peak = np.max(np.abs(tensor))
    if peak > 1e300:
        print("! Warning: Very large values may cause overflow.")
    if peak < 1e-300 and peak != 0:
        print("! Warning: Very small values may cause underflow.")

    for mode in range(tensor.ndim):
        unfolded = np.moveaxis(tensor, mode, 0).reshape(tensor.shape[mode], -1)
        try:
            singular_values = np.linalg.svd(unfolded, full_matrices=False)[1]
            smallest = singular_values[-1]
            cond_num = singular_values[0] / smallest if smallest > 1e-12 else np.inf
            if cond_num > 1e12:
                print(f"⚠️ Mode-{mode} unfolding is ill-conditioned (cond={cond_num:.2e})")
        except Exception as err:
            print(f"X SVD failed for mode-{mode} unfolding:", err)
            return []

    try:
        cp_result = parafac(tensor, rank=rank, **params_parafac)
        for idx, factor in enumerate(cp_result.factors):
            if np.isnan(factor).any():
                print(f"x Factor {idx} contains NaNs.")
                return []
            if np.isinf(factor).any():
                print(f"x Factor {idx} contains Infs.")
                return []
        print("V PARAFAC succeeded.")
        return cp_result
    except Exception as err:
        message = str(err)
        print("X PARAFAC failed with error:", message)
        if "SVD did not converge" in message:
            print("Reason: Likely degeneracy, ill-conditioning, or too high rank.")
        elif "singular" in message.lower():
            print("Reason: Likely rank-deficient or ill-conditioned input.")
        else:
            print("Reason: Unknown, check tensor values and scaling.")
        return []

def run_parafac(data, p=10, params_parafac={}, scale=False):
    """
    PARAFAC (CP) decomposition with optional standardisation.

    Parameters:
        data: 3D array; its third axis indexes states (k).
        p: CP rank.
        params_parafac: overrides for the defaults init='svd', n_iter_max=2000,
            tol=1e-6, normalize_factors=True.
        scale: forwarded to ``run_parafac_diagnostics``.

    Returns:
        (A_parafac, phi_parafac, factors), where:
        - A_parafac = factors.factors[0];
        - phi_parafac[:, :, s] = factors.factors[1] * factors.factors[2][s, :];
        - factors is the CP result.
        Returns ([], [], []) if the decomposition failed.
    """
    merged_params = {"init": "svd", "n_iter_max": 2000, "tol": 1e-6, "normalize_factors": True}
    merged_params.update(params_parafac)
    _, _, num_states = data.shape

    cp_result = run_parafac_diagnostics(data, rank=p, params_parafac=merged_params, scale=scale)
    if len(cp_result) == 0:
        return [], [], []

    A_parafac, phi_base, state_weights = cp_result.factors[:3]
    # Per-state traces: the shared temporal basis scaled column-wise by that state's weights.
    per_state = [phi_base * state_weights[state, :][np.newaxis, :] for state in range(num_states)]
    phi_parafac = np.stack(per_state, axis=2)
    return A_parafac, phi_parafac, cp_result
    
    

    
def create_dict_of_clusters_multi_d_A(full_A, labels = [], terms = [], perc_null = 90, thres = 0):
    """
    Apply ``create_dict_of_clusters_single_A`` to every slice ``full_A[:, :, i]``.

    Parameters:
        full_A: 3D array (terms x components x labels).
        labels: one label per slice along axis 2; defaults to 0..full_A.shape[2]-1.
        terms: names of the rows of full_A; defaults to 0..full_A.shape[0]-1.
        perc_null: percentile forwarded to the single-slice function.
        thres: forwarded to the single-slice function. If it is a sequence of
            sequences, ``thres[i]`` is used for the i-th label.

    Returns:
        (dict label -> cluster dict, labels, terms)
    """
    if checkEmptyList(terms):
        terms = np.arange(full_A.shape[0])
    if checkEmptyList(labels):
        # One label per slice along axis 2. (This previously read ``labels.shape``
        # on the empty default and raised AttributeError.)
        labels = np.arange(full_A.shape[2])

    per_label_thres = (isinstance(thres, (list, tuple, np.ndarray))
                       and isinstance(thres[0], (list, tuple, np.ndarray)))
    clusters = {
        label: create_dict_of_clusters_single_A(full_A[:, :, label_count], terms = terms, perc_null = perc_null,
                                                thres = thres[label_count] if per_label_thres else thres)
        for label_count, label in enumerate(labels)
    }
    return clusters, labels, terms

def create_dict_of_clusters_single_A(A_2d, terms = [], perc_null = 80, thres = 0):
    """
    Group terms (rows of ``A_2d``) by component (columns) using a magnitude threshold.

    For column i the group 'group_i' holds the terms whose |A_2d[:, i]| exceeds
    the cutoff:
    - thres == 0 (scalar): the cutoff is the ``perc_null`` percentile of |A_2d|
      over the whole matrix, shared by all columns.
    - scalar thres != 0: that value for every column.
    - sequence thres: ``thres[i]`` for column i; it must not contain 0.

    Parameters:
        A_2d: 2D array (terms x components).
        terms: row names (list or array); defaults to 0..A_2d.shape[0]-1.

    Returns:
        dict 'group_i' -> numpy array of selected terms.

    Raises:
        ValueError: if a threshold sequence contains 0.
    """
    if checkEmptyList(terms):
        terms = np.arange(A_2d.shape[0])
    # Boolean-mask indexing below requires an array; plain lists would raise.
    terms = np.asarray(terms)

    thres_is_sequence = isinstance(thres, (list, tuple, np.ndarray))
    if thres_is_sequence and 0 in thres:
        # Wrap in a 1-tuple so tuple-valued thres formats instead of raising TypeError.
        raise ValueError('if providing thres list, 0 should not be there. but %s' % (thres,))

    magnitudes = np.abs(A_2d)
    if not thres_is_sequence and thres == 0:
        cutoff = np.percentile(magnitudes, perc_null)  # computed once, shared by all columns
        return {'group_%d' % i: terms[magnitudes[:, i] > cutoff]
                for i in np.arange(A_2d.shape[1])}

    if not thres_is_sequence:
        thres = [thres] * A_2d.shape[1]
    return {'group_%d' % i: terms[magnitudes[:, i] > thres[i]]
            for i in np.arange(A_2d.shape[1])}

    


    






    
    
    
def l2_vec(vec):
    """Euclidean (L2) norm of ``vec``; NaN entries are ignored (treated as 0)."""
    sum_of_squares = np.nansum(vec ** 2)
    return np.sqrt(sum_of_squares)
def nan_corr(vec1, vec2):
    """
    Uncentred correlation (cosine similarity) of two vectors, ignoring NaNs.

    No means are subtracted. NaNs are ignored separately in the dot product and
    in each norm, so the entries used for the norms can differ from those in
    the numerator.
    """
    numerator = np.nansum(vec1 * vec2)
    denominator = l2_vec(vec1) * l2_vec(vec2)
    return numerator / denominator

    
    
def snythetic_evaluation(full_A, full_phi, real_full_A, real_full_phi):
    """
    Align estimated components with ground-truth synthetic traces, state by state.

    If either ground truth is empty, both are loaded from
    'grannet_synth_results_march_2023.npy' (keys 'A' and 'phi').

    ``real_full_phi`` is iterated as a dict. For its i-th entry,
    ``match_times`` (defined elsewhere) is called with three arguments:
    - the true traces;
    - ``full_phi[:, :, i].T``;
    - ``full_A[:, :, i]``.
    Its first two outputs are presumably the reordered phi and A (TODO confirm).

    Returns:
        (full_phi_ordered, full_A_ordered): the per-state outputs stacked along axis 2.
    """
    if len(real_full_A) == 0 or len(real_full_phi) == 0:
        ground_truth = np.load(r'grannet_synth_results_march_2023.npy', allow_pickle=True).item()
        real_full_A = ground_truth['A']
        real_full_phi = ground_truth['phi']

    phis_matched = []
    As_matched = []
    for state_idx, (_, true_phi) in enumerate(real_full_phi.items()):
        phi_matched, A_matched, _ = match_times(true_phi, full_phi[:, :, state_idx].T, full_A[:, :, state_idx])
        phis_matched.append(phi_matched)
        As_matched.append(A_matched)

    return np.dstack(phis_matched), np.dstack(As_matched)
    
    
from scipy.stats import norm    
def spike_times_to_rate_single_neuron(spike_times_single, max_time = 0, 
                                      padded = False,
                                      window_params = {'wind_type':'gauss', 'wind':1, 'std':0.1, 'interval':0.3}, time_axis = []):
    """
    Estimate the firing rate of a single neuron from its spike times (no binning).

    At each time ``t`` of ``time_axis`` the rate is the value returned by
    ``gaussian_val_given_t``. That is a sum of Gaussian densities
    (std ``window_params['std']``) over the spikes within ``wind / 2`` of ``t``.

    Parameters:
    - spike_times_single (array-like): spike times of one neuron.
    - max_time (float, optional): end of the recording. 0 (default) means use
      the latest spike.
    - padded (bool, optional): if False (default), spikes within ``wind / 2`` of
      an edge are mirrored across it to reduce edge effects:
      ``-s`` near 0, and ``2 * max_time - s`` near ``max_time``.
    - window_params (dict, optional): keys used:
      - 'wind_type': only 'gauss' is supported;
      - 'wind': window width;
      - 'std': kernel std;
      - 'interval': time-axis step, used only when ``time_axis`` is empty.
    - time_axis (array, optional): evaluation times; defaults to
      ``np.arange(0, max_time, window_params['interval'])``.

    Returns:
    - vals_rate (array): one rate value per entry of ``time_axis``.
    - time_axis (array): the evaluation times.

    Raises:
    - ValueError: If the window type is undefined.

    Example:
    ```
    spike_times = np.array([0.1, 0.3, 0.7, 1.2, 1.5])
    rate, time = spike_times_to_rate_single_neuron(spike_times)
    ```
    """
    wind = window_params['wind']
    spikes = np.ravel(np.asarray(spike_times_single))

    if max_time == 0:
        max_time = np.max(spikes)

    wind_p = wind / 2
    if not padded:
        # Reflect spikes near each edge across that edge. The right edge must
        # be mirrored as 2*max_time - s. The former max_time + s put the copies
        # near 2*max_time, where they generally fell outside every evaluation
        # window.
        left_mirror = -spikes[spikes <= wind_p]
        right_mirror = 2 * max_time - spikes[spikes >= max_time - wind_p]
        spikes = np.concatenate([left_mirror, spikes, right_mirror])

    if checkEmptyList(time_axis):
        time_axis = np.arange(0, max_time, window_params['interval'])

    if window_params['wind_type'] != 'gauss':
        raise ValueError('wind type undefined!')

    vals_rate = np.array([gaussian_val_given_t(t, spikes, wind_p, max_time, window_params['std'],
                                               to_plot_example=False) for t in time_axis])
    return vals_rate, time_axis
    
def gaussian_val_given_t(t, vals_all, wind_p, max_t, sigma, to_plot_example = False, path_save_fig = '.'):
    """
    Sum a Gaussian density centred at ``t`` over the values of ``vals_all`` near ``t``.

    Only entries ``v`` with ``t - wind_p <= v < t + wind_p`` contribute. Each
    adds ``gaussian_pdf(v, t, sigma)``.

    ``max_t``, ``to_plot_example`` and ``path_save_fig`` are accepted for
    interface compatibility but are not used.

    Returns:
        The summed density (0.0 when no value falls inside the window).
    """
    window_start = np.max(t - wind_p)
    window_end = np.min(t + wind_p)
    in_window = (vals_all < window_end) & (vals_all >= window_start)
    return np.sum(gaussian_pdf(vals_all[in_window], t, sigma))


def gaussian_pdf(x, mu, sigma):
    """
    Evaluate the normal probability density N(mu, sigma**2) at ``x``.

    Parameters:
    - x: scalar or array-like of evaluation points.
    - mu: mean of the distribution.
    - sigma: standard deviation of the distribution.

    Returns:
    - Density value(s) at ``x``, from ``scipy.stats.norm.pdf``.
    """
    return norm.pdf(x, loc=mu, scale=sigma)



    
def from_cont_times_to_count(spike_times, max_time, interval = 50, wind = 200, additional_signals_to_find_times = {}, mid_points = []):
    """
    Discrete firing-rate estimate over fixed windows.

    Windows are (begin, end]:
    - If ``mid_points`` is non-empty, each window is centred on a mid point
      with width ``wind``.
    - Otherwise windows start every ``interval`` from 0 up to
      ``max_time - wind`` (exclusive).

    Each window's value is the sum of ``gaussian_vals`` (defined elsewhere)
    applied to the spikes inside it together with the window bounds.

    ``additional_signals_to_find_times`` may be empty, an array, or a dict of
    arrays. Each array is passed to ``find_indices`` (defined elsewhere) with
    ``max_time``, ``interval`` and ``wind``.

    Returns:
        (inside, begins, ends, mid_points, indices)
        - mid_points: as given; a list is returned as an array.
        - indices: a dict, a single result, or [] depending on the input type.

    Raises:
        ValueError: if ``additional_signals_to_find_times`` is neither a dict
            nor a numpy array.
    """
    spike_times = np.asarray(spike_times)  # boolean masking below needs an array
    if len(mid_points) > 0:
        mid_points = np.asarray(mid_points)  # list input previously failed on arithmetic
        begins = mid_points - wind / 2
        ends = mid_points + wind / 2
    else:
        begins = np.arange(0, max_time - wind, interval)
        ends = begins + wind

    inside = [np.sum(gaussian_vals(spike_times[(spike_times > begin) & (spike_times <= end)], begin, end))
              for begin, end in zip(begins, ends)]

    if len(additional_signals_to_find_times) > 0:
        if isinstance(additional_signals_to_find_times, dict):
            indices = {key: find_indices(addi, max_time, interval, wind)
                       for key, addi in additional_signals_to_find_times.items()}
        elif isinstance(additional_signals_to_find_times, np.ndarray):
            indices = find_indices(additional_signals_to_find_times, max_time, interval, wind)
        else:
            raise ValueError('wrong type additional signals!')
    else:
        indices = []
    return inside, begins, ends, mid_points, indices




def spike_times_to_rate_several_neurons(spike_times, to_plot = False, window_params = {}, max_time = 0,
                                        return_time_axis = False):
    """
    Gaussian-kernel firing rates for several neurons on a shared time axis.

    Parameters:
        spike_times: list of per-neuron spike-time arrays, or a dict whose
            values are such arrays (the keys are discarded).
        to_plot: unused.
        window_params: overrides for the defaults
            {'wind_type': 'gauss', 'wind': 0.25, 'std': 0.03, 'interval': 0.05}.
        max_time: end of the time axis; 0 means use the latest spike of any
            neuron (ideally the trial end time).
        return_time_axis: also return the time axis.

    Returns:
        rates (neurons x times), or (rates, time_axis) if ``return_time_axis``.
    """
    if isinstance(spike_times, dict):
        spike_times = list(spike_times.values())

    params = {'wind_type': 'gauss', 'wind': 0.25, 'std': 0.03, 'interval': 0.05}
    params.update(window_params)

    if max_time == 0:
        max_time = np.max([np.max(neuron_times) for neuron_times in spike_times])
    time_axis = np.arange(0, max_time, params['interval'])
    print('time axis size')
    print(time_axis.shape)

    per_neuron = [spike_times_to_rate_single_neuron(neuron_times, window_params = params,
                                                    max_time = max_time, time_axis = time_axis)[0]
                  for neuron_times in spike_times]
    rates = np.vstack(per_neuron)
    return (rates, time_axis) if return_time_axis else rates

from scipy.sparse import coo_matrix  


def from_spike_times_to_rate(spike_dict, type_convert = 'discrete',
                             res = 0.01, max_min_val = [], return_T = False, 
                             T_max = np.inf, T_min = 0,  return_time_axis = False,
                             need_to_take_T_resolution = False, # recommend True if T_max is that same res and spike dict
                             params_gauss = {'wind' : 10, 'direction' : 1, 'sigma' : 1, 'norm_sum' : True, 'plot_gaussian' : False},
                             limit_to_t_min_t_max = False):
    """
    Convert spike times into a binned spike-count matrix and a Gaussian-smoothed copy.

    Spike times are divided by ``res`` so that one unit of the result is one
    column (bin). For example, with ``res = 0.01`` each interval [0, 1) of the
    original time unit spans 100 columns. With times in ms and 20 ms bins
    wanted, use ``res = 20``. Alternatively use ``res = 1`` and set the
    smoothing in ``params_gauss`` (e.g. ``wind = 20``, ``sigma = 5``).

    Parameters:
    - spike_dict (dict | list | np.ndarray): unit -> spike times. A bare list or
      array is treated as a single unit (key 1). Keys that are not 0..N-1 are
      remapped to row indices.
    - type_convert (str): only 'discrete' is supported.
    - res (float): divisor applied to the spike times.
    - max_min_val (list): optional values whose maximum replaces the latest
      processed spike time when sizing the matrix.
      NOTE(review): assumed to be in the processed units (after dividing by
      ``res`` and subtracting ``T_min``); confirm with callers.
    - return_T (bool): return ``(firing_rate_mat, firing_rate_mat_gauss, return_T)``.
    - T_max, T_min (float): time limits in processed units:
      - If ``T_min > 0``, times are shifted by ``-T_min`` and non-positive
        ones are dropped.
      - Spikes with (shifted) time ``>= T_max`` are not counted.
    - return_time_axis (bool): also return a time axis (exclusive with ``return_T``).
    - need_to_take_T_resolution (bool): divide ``T_max`` / ``T_min`` by ``res`` first.
    - params_gauss (dict): keyword arguments forwarded to ``gaussian_convolve``.
    - limit_to_t_min_t_max (bool): size the matrix as ``T_max - T_min + 1``
      columns instead of from the data.

    Returns:
    - By default ``(firing_rate_mat, firing_rate_mat_gauss, old2new)``.
      ``old2new`` maps original keys to row indices; it is empty if the keys
      were already 0..N-1.
    - With ``return_time_axis``, a fourth element ``time_axis`` is added.
    - With ``return_T``, the result is ``(firing_rate_mat, firing_rate_mat_gauss, return_T)``.

    Raises:
    - ValueError: if ``T_min >= T_max`` or ``type_convert`` is not 'discrete'.
    """
    assert 1*return_T + 1*return_time_axis < 2, "you cannot have both return_T and return_time_axis. chooe one"
    if type_convert != 'discrete':
        # Previously an unsupported value fell through to a NameError at return.
        raise ValueError("unsupported type_convert %r; only 'discrete' is implemented" % (type_convert,))
    if isinstance(spike_dict, (np.ndarray, list)):
        spike_dict = {1: spike_dict}

    if T_min >= T_max:
        raise ValueError('T_min must be smaller than T_max')

    if need_to_take_T_resolution:
        T_max = T_max / res
        T_min = T_min / res

    # Work with arrays so the masks and comparisons below also accept list inputs.
    spike_dict = {key: np.asarray(val) for key, val in spike_dict.items()}
    if res != 1:
        spike_dict = {key: val / res for key, val in spike_dict.items()}
    if T_min > 0:
        spike_dict = {key: val - T_min for key, val in spike_dict.items()}
        spike_dict = {key: val[val > 0] for key, val in spike_dict.items()}

    dict_spike_times = np.concatenate([np.ravel(val) for val in spike_dict.values()])
    assert dict_spike_times.max() > T_min and dict_spike_times.min() < T_max, \
        "T_max and T_min results in empty spikes %s %s %s %s" % (dict_spike_times.max(), T_min, dict_spike_times.min(), T_max)

    # Make sure the keys are the contiguous row indices 0..N-1.
    if set(np.arange(len(spike_dict))) != set(list(spike_dict.keys())):
        new_keys = np.arange(len(spike_dict))
        old_keys = np.sort(list(spike_dict.keys()))
        old2new = {old: new for old, new in zip(old_keys, new_keys)}
        spike_dict = {old2new[key]: val for key, val in spike_dict.items()}
    else:
        old2new = {}

    spike_dict_keys_sorted = np.sort(list(spike_dict.keys()))
    times_dict_sorted = [spike_dict[neuron] for neuron in spike_dict_keys_sorted]

    if checkEmptyList(max_min_val):
        try:
            max_val = np.max([np.max(val) for val in spike_dict.values() if len(val) > 0])
        except ValueError:
            print(spike_dict)
            if not limit_to_t_min_t_max:
                raise  # max_val is required below
    else:
        # Previously a provided max_min_val was ignored and max_val stayed undefined.
        max_val = np.max(max_min_val)

    N = len(spike_dict)

    if not limit_to_t_min_t_max:
        # max_val already refers to the T_min-shifted times, so it is not shifted
        # again (it used to be, which undersized the matrix when T_max was huge).
        max_val = int(np.ceil(max_val))
        max_val = 1 + int(T_max if T_max < 10**7 else int(np.min([max_val, T_max])))
    else:
        max_val = int(T_max - T_min) + 1

    T_thres = T_max
    kept_times = [times[times < T_thres] for times in times_dict_sorted]
    rows = np.concatenate([np.full(len(times), neuron, dtype=int)
                           for neuron, times in zip(spike_dict_keys_sorted, kept_times)])
    cols = np.concatenate(kept_times).astype(int)
    assert len(rows) == len(cols), '%d_%d' % (len(rows), len(cols))
    if len(cols) > 0:
        # Column indices must be < max_val for a matrix with max_val columns.
        assert cols.max() < int(max_val), 'max_val_%d, col max %d' % (max_val, cols.max())

    data = np.ones(len(rows))  # each spike contributes one count; duplicates are summed
    sparse_mat = coo_matrix((data, (rows, cols)), shape=(N, max_val))
    firing_rate_mat = sparse_mat.toarray()
    firing_rate_mat_gauss = gaussian_convolve(firing_rate_mat, **params_gauss)

    if return_T:
        return firing_rate_mat, firing_rate_mat_gauss, return_T
    if return_time_axis:
        time_axis = np.arange(T_min*res, T_max*res, res)
        assert len(time_axis) in [firing_rate_mat.shape[1] - 1, firing_rate_mat.shape[1], firing_rate_mat.shape[1] + 1], \
            "len mismatch. durations are time_axis %d vs rate_mat %d" % (len(time_axis), firing_rate_mat.shape[1])
        return firing_rate_mat, firing_rate_mat_gauss, old2new, time_axis

    return firing_rate_mat, firing_rate_mat_gauss, old2new

def plot_raster(dict_spike_time, ax = [], fig = [], max_T = np.inf, res = 0.05, plot_params = None):
    """
    Scatter-plot a spike raster: one row per key, one '|' marker per spike.

    Parameters
    ----------
    dict_spike_time : dict
        {key : spike times}. The key is used as the y value of every spike of that entry.
        Values must support division by ``res`` (e.g. numpy arrays).
    ax : Axes or []
        Axes to draw on. When empty (per ``checkEmptyList``) a new figure/axes pair is created.
    fig : Figure or []
        Only replaced when a new figure is created; otherwise unused.
    max_T : float
        Spikes whose plotted x value (``time / res``) is >= max_T are dropped.
        Default np.inf keeps everything.
    res : float
        Time resolution; spike times are divided by it to obtain the x coordinate.
    plot_params : dict or None
        Extra keyword arguments forwarded to ``ax.scatter``.

    Returns
    -------
    ax : the axes that were drawn on (previously nothing was returned).
    """
    if plot_params is None:
        plot_params = {}
    if checkEmptyList(ax):
        fig, ax = plt.subplots()

    yval = np.array(lists2list([[key] * len(val) for key, val in dict_spike_time.items()]))
    xval = np.array(lists2list([list(val / res) for val in dict_spike_time.values()]))

    # Cut in the same units that are plotted (time / res). The old guard compared max_T
    # with the *unscaled* maximum spike time, so the cut could be skipped even though
    # plotted points exceeded max_T; it also crashed on an empty dict (np.max of []).
    keep = xval < max_T
    yval = yval[keep]
    xval = xval[keep]

    ax.scatter(xval, yval, marker = '$|$', **plot_params)
    return ax

import copy    


    
    
def update_labels_to_add_axes(list_labels_add,
                              dict_variables_to_update = {'labels_concise': [], 'number2labels': [],
                                                          'labels': [], 'labels_text': [], 'labels_unique': [],
                                                          'labels_unique_text': [], 'num_unique_labels': [],
                                                          'labels_unique_order': [], 'labels_text_tuples': [],
                                                          'labels_unique_tuples': [],
                                                          'number2labelsstr': [],
                                                          'num_axes': [], 'labels2number': []},
                              omit_existing_labbels = False):
    """
    Append one more text label to every trial's label tuple and rebuild the label bookkeeping.

    Parameters
    ----------
    list_labels_add : sequence
        One new (text) label per trial; its length must equal len(dict_variables_to_update['labels']).
    dict_variables_to_update : dict
        Existing label structure. The keys read here are 'labels', 'labels_unique_text'
        and 'labels_text_tuples'.
    omit_existing_labbels : bool
        When True, each trial's tuple is replaced by (new_label,) instead of being extended.

    Returns
    -------
    new_dict : dict
        Rebuilt label variables (tuples, their str forms, number<->label maps, unique
        labels, per-trial label numbers, number of axes).
    old_dict : dict
        Deep copy of ``dict_variables_to_update`` taken before any work.

    Raises
    ------
    AssertionError
        On a length mismatch, or if a new label already appears in 'labels_unique_text'.
    """
    assert len(dict_variables_to_update['labels']) == len(list_labels_add), "labels must match in length"
    snapshot = copy.deepcopy(dict_variables_to_update)

    known_unique_text = dict_variables_to_update.get('labels_unique_text')
    collisions = [new_label for new_label in list_labels_add if new_label in known_unique_text]
    assert len(collisions) == 0, "label already exists"

    # extend (or replace) every trial's tuple with its new label
    previous_tuples = dict_variables_to_update.get('labels_text_tuples')
    paired = zip(previous_tuples, list_labels_add)
    if omit_existing_labbels:
        labels_text_tuples = [(new_label,) for _, new_label in paired]
    else:
        labels_text_tuples = [tuple(old_tuple) + (new_label,) for old_tuple, new_label in paired]

    # unique label combinations, ordered by the project helper sort_tuple
    unique_labels_text_unordered = sort_tuple(set(labels_text_tuples))
    num_unique_labels = len(unique_labels_text_unordered)

    number2labels = dict(enumerate(unique_labels_text_unordered))
    labels2number = {tup: idx for idx, tup in number2labels.items()}
    number2labelsstr = {idx: str(tup) for idx, tup in number2labels.items()}

    labels_concise = np.unique(np.vstack(labels_text_tuples).ravel())
    labels_unique_tuples = unique_labels_text_unordered.copy()
    labels_unique = np.unique(list(number2labels))
    labels_unique_order = labels_unique.copy()
    labels_unique_text = np.array([number2labelsstr[idx] for idx in labels_unique])

    labels_text = np.array([str(tup) for tup in labels_text_tuples])
    labels = np.array([labels2number[tup] for tup in labels_text_tuples])
    num_axes = len(labels_text_tuples[0])

    new_dict = dict(
        list_labels_add = list_labels_add,
        labels_unique_text = labels_unique_text,
        labels_text_tuples = labels_text_tuples,
        labels_text = labels_text,
        unique_labels_text_unordered = unique_labels_text_unordered,
        num_unique_labels = num_unique_labels,
        number2labels = number2labels,
        labels2number = labels2number,
        number2labelsstr = number2labelsstr,
        labels_concise = labels_concise,
        labels_unique_tuples = labels_unique_tuples,
        labels_unique = labels_unique,
        labels_unique_order = labels_unique_order,
        labels = labels,
        num_axes = num_axes,
    )
    return new_dict, snapshot
    
    
import numpy as np
def transform_A_with_repeats2dict(full_A, dict_of_unique_labels , labels_dict, axis_names = [] , class_names = []):
    """
    Split a per-trial factor tensor into one tensor per class, keeping one slice per unique label.

    Parameters
    ----------
    full_A : ndarray or dict
        Returned unchanged if it is already a dict. Otherwise it is indexed as
        full_A[:, ensemble_columns, trial], so it must be 3-D with axis 1 matching axis_names.
    dict_of_unique_labels : dict
        {class name : iterable of unique labels of that class}.
    labels_dict : dict
        {class name : per-trial labels of that class}.
    axis_names : sequence
        Class name of each ensemble (axis-1 column) of full_A, e.g. [odor, odor, context, context].
    class_names : sequence
        Class names; when empty they become 0..(number of distinct axis_names - 1).

    Returns
    -------
    dict
        {class name : array of shape (full_A.shape[0], number of ensembles of the class,
        number of unique labels of the class)}.

    Raises
    ------
    AssertionError
        If trials sharing a label do not carry identical slices of full_A, or on
        inconsistent arguments.
    """
    if checkEmptyList(class_names):
        class_names = np.arange(len(np.unique(axis_names)))
    if isinstance(full_A, dict):
        return full_A

    assert isinstance(full_A, np.ndarray), "full_A must be matrix"
    assert full_A.shape[1] == len(axis_names), "mismatch between axis names and shape of A!"
    assert set(class_names) == set(axis_names), "set(class_names) == set(axis_names) must be the same but set(class_names): %s vs set(axis_names): %s"%(str(set(class_names)) , str(set(axis_names)))
    assert  isinstance(dict_of_unique_labels, dict)

    axis_names_arr = np.array(axis_names)
    columns_of_class = {name: np.where(axis_names_arr == name)[0] for name in class_names}

    per_class = {}
    for class_name, class_labels in dict_of_unique_labels.items():
        # fancy indexing -> copy holding only this class's ensemble columns
        class_block = full_A[:, columns_of_class[class_name], :]
        trial_labels = np.array(labels_dict[class_name])
        slices = []
        for label in class_labels:
            matching_trials = np.where(trial_labels == label)[0]
            reference = class_block[:, :, matching_trials[0]]
            # every trial with this label must carry exactly the same factor slice
            for trial in matching_trials:
                assert (class_block[:, :, trial] == reference).all(), 'something is wrong for clas %s label %s'%(class_name, label)
            slices.append(reference)
        per_class[class_name] = np.stack(slices, axis=2)

    return per_class
        
        

    
    
    
    
def transform_labels_as_list_of_tuples2dict(labels_list_of_tuples, class_names = []):
    """
    Regroup per-trial label tuples into per-class label arrays.

    Parameters
    ----------
    labels_list_of_tuples : list, tuple or ndarray
        One tuple per trial; position i of a tuple is the label of class i.
    class_names : sequence
        Name of each tuple position. When empty, positions are numbered 0..n_classes-1.

    Returns
    -------
    dict_of_unique_labels : dict
        {class name : sorted array of the unique labels of that class}.
    labels_dict : dict
        {class name : array with that class's label for every trial}.

    Raises
    ------
    AssertionError
        If the input type is wrong or len(class_names) differs from the tuple length.
    """
    assert isinstance(labels_list_of_tuples, (list,tuple,np.ndarray))
    num_classes = len(labels_list_of_tuples[0])
    if checkEmptyList(class_names):
        class_names = np.arange(num_classes)
    assert num_classes == len(class_names), "num_classes %d !=  len(class_names) %d vs %s" %(num_classes , len(class_names), class_names)

    dict_of_unique_labels = {}
    labels_dict = {}
    for position, name in enumerate(class_names):
        column = np.array([entry[position] for entry in labels_list_of_tuples])
        dict_of_unique_labels[name] = np.sort(np.unique(column))
        labels_dict[name] = column

    return dict_of_unique_labels , labels_dict
    
    
    
    
    
import numpy as np



def aic_bic(data, A, phi, param_count, alpha=1.0):
    """
    Compute AIC and BIC with the parameter count scaled by alpha for similarity regularization.

    The reconstruction of trial k is ``A_k @ phi_k``. The noise is treated as i.i.d. Gaussian
    with the maximum-likelihood variance ``rss / n_obs``, floored at 1e-12 so that a perfect
    fit gives finite criteria (the same floor as ``calculate_logl_and_mse``).

    Parameters
    ----------
    data : ndarray
        (N, T, K)
    A : ndarray
        (N, p) shared by all trials, or (N, p, K) with one slice per trial.
    phi : ndarray
        (p, T) shared by all trials, or (T, p, K). In the 3-D case each slice
        phi[:, :, k] is transposed to (p, T) before use.
    param_count : int
        Number of free parameters without regularization.
    alpha : float, optional
        Similarity factor in [0, 1] multiplying param_count. Default 1 (no regularization).

    Returns
    -------
    AIC : float
    BIC : float
    logL : float
        Gaussian log-likelihood at the ML variance.

    Raises
    ------
    AssertionError
        If the inner dimensions of A_k and phi_k do not match.
    """
    N, T, K = data.shape
    rss = 0.0
    for k in range(K):
        A_k = A if A.ndim == 2 else A[:, :, k]
        phi_k = phi if phi.ndim == 2 else phi[:, :, k].T
        assert A_k.shape[1] == phi_k.shape[0], (A_k.shape, phi_k.shape)
        pred = A_k @ phi_k
        rss += np.sum((data[:, :, k] - pred) ** 2)

    n_obs = N * T * K
    # floor the variance: rss == 0 would otherwise give log(0) -> logL = +inf, AIC/BIC = -inf
    sigma2 = rss / n_obs if rss > 0 else 1e-12
    logL = -0.5 * n_obs * (np.log(2 * np.pi * sigma2) + 1)

    eff_param_count = alpha * param_count

    AIC = 2 * eff_param_count - 2 * logL
    BIC = eff_param_count * np.log(n_obs) - 2 * logL

    return AIC, BIC, logL

def calculate_logl_and_mse(Y, reco):
    """
    Gaussian log-likelihood and mean squared error of a reconstruction.

    The variance is the MSE itself, replaced by 1e-12 when the MSE is not positive so
    that a perfect reconstruction does not evaluate log(0).

    Parameters
    ----------
    Y : ndarray
        Original data.
    reco : ndarray
        Reconstruction, broadcastable against Y.

    Returns
    -------
    log_l : float
        -0.5 * Y.size * (log(2*pi*sigma2) + 1).
    mse : float
        Mean of the squared residuals.
    """
    squared_error = (Y - reco) ** 2
    mse = np.mean(squared_error)

    if mse > 0:
        sigma2 = mse
    else:
        sigma2 = 1e-12  # prevent log(0)

    n_obs = Y.size
    log_l = -0.5 * n_obs * (np.log(2 * np.pi * sigma2) + 1)
    return log_l, mse




def penalized_aic_multi_MILCCI(Y, full_A, Phi, labels_list, lambdas_per_ensemble = {},
                                  lambdas_per_class = {}, axis_names=[], class_names = []):
    """
    Compute penalized AIC for multi-siblings model.

    For every class d and every unique label l of that class, the trials carrying label l
    are stacked along time. With the class-d ensemble traces Phi_l of those trials:
        reg_l = sum_{k != l} lambda_d[l, k]
        Q_l   = (Phi_l^T Phi_l + reg_l I)^{-1}
        H_l   = Q_l Phi_l^T                      (effective df: trace(Phi_l H_l))
        Y_hat = Phi_l H_l Y_l + Phi_l Psi_l
    The residuals of all labels of all classes are summed into one residual sum.

    Args:
        Y: np.array of shape (N, T, M) - data
        full_A: dict of {d: np.array of shape (N, p_d, num_unique_labels_class_d)}, or a
            per-trial np.ndarray that is converted with transform_A_with_repeats2dict
            (which then requires axis_names).
        Phi: np.array of shape (T, p, M) - trial factor matrices
        labels_list: list or array with one tuple of labels (one entry per class) per trial
        lambdas_per_ensemble: {ensemble_number: (L_d x L_d) array of penalties between the
            labels of that ensemble's class}. Used only when lambdas_per_class is empty;
            ensembles of the same class must carry identical matrices.
        lambdas_per_class: {class name: (L_d x L_d) array of penalties between labels}
        axis_names: list of class names corresponding to each ensemble (axes) in Phi
        class_names: list of class names (forwarded to transform_labels_as_list_of_tuples2dict)
    Returns:
        AIC_penalized: float
        df_total: total effective degrees of freedom
        log_l: Gaussian log-likelihood with sigma^2 = resid_sum / (N * T * M)
    Raises:
        AssertionError on inconsistent arguments or shapes.
    """
    assert isinstance(Y, np.ndarray)
    assert len(lambdas_per_ensemble) > 0 or len(lambdas_per_class) > 0
    # full_A only has .shape in its ndarray form; reading it unconditionally crashed on the
    # documented dict form. Phi is (T, p, M), so its axis 1 counts ensembles as well.
    n_ensembles = full_A.shape[1] if isinstance(full_A, np.ndarray) else Phi.shape[1]

    if len(lambdas_per_ensemble) > 0 and len(lambdas_per_class) == 0:
        assert set(list(lambdas_per_ensemble.keys())) == set(np.arange(n_ensembles))
        assert len(axis_names) > 0, "if building on lambdas_per_ensemble, then ensembles meaning must be non empty!"
        assert set(axis_names)==set(class_names), "%s != %s"%(str(set(axis_names)),str(set(class_names)))

        # make sure that lambda is fixed for ensembles of the same class
        lambdas_per_class = {}
        for ensemble_index, ensemble_meaning in enumerate(axis_names):
            if ensemble_meaning not in lambdas_per_class:
                lambdas_per_class[ensemble_meaning] = lambdas_per_ensemble[ensemble_index]
            else:
                assert (lambdas_per_class[ensemble_meaning] == lambdas_per_ensemble[ensemble_index]).all(), 'TB implemented. For now, if 2 ensembles are of the same class, they must have the same lambda value! but for class %s it is not the case'%ensemble_meaning
    lambdas = lambdas_per_class.copy()
    N, T, M = Y.shape
    assert isinstance(full_A, (dict, np.ndarray))
    assert isinstance(labels_list, (list, np.ndarray))

    dict_of_unique_labels, labels_dict = transform_labels_as_list_of_tuples2dict(
        labels_list, class_names=class_names
    )

    if isinstance(full_A, np.ndarray):
        full_A = transform_A_with_repeats2dict(
            full_A, dict_of_unique_labels, labels_dict, axis_names=axis_names, class_names=class_names
        )

    assert isinstance(dict_of_unique_labels, dict)
    assert Phi.shape[0] == T and Phi.shape[2] == M, "Phi must match Y in time and trials"

    # Key by the classes actually produced above, so the default 0..n-1 numbering used when
    # class_names is empty is also covered (keying by the raw class_names gave {} then).
    axis_names_arr = np.array(axis_names)
    ensemble_indices_per_class = {class_name: np.where(axis_names_arr == class_name)[0]
                                  for class_name in dict_of_unique_labels}

    df_total = 0
    resid_sum = 0

    for class_name, A_d in full_A.items():  # class_name is the class, A_d its slice of A
        num_unique_labels_class_d = A_d.shape[2]  # number of unique labels in class
        p_d = A_d.shape[1]  # number of ensembles in class
        assert A_d.shape[0] == N, f"A_d for class {class_name} must have N rows"
        unique_labels_d = dict_of_unique_labels[class_name]
        assert len(unique_labels_d) == num_unique_labels_class_d, \
            "A_d for class %s has %d label slices but the labels contain %d unique values" % (
                class_name, num_unique_labels_class_d, len(unique_labels_d))
        ensemble_idx = ensemble_indices_per_class[class_name]
        labels_class_d = np.array(labels_dict[class_name])
        df_d = 0

        for label_count, label_meaning in enumerate(unique_labels_d):
            trial_idx = np.where(labels_class_d == label_meaning)[0]
            assert len(trial_idx) > 0, f"No trials for class {class_name} label # {label_count}, {label_meaning}"

            # Vertical concatenation over the trials of this label. Building the stack
            # directly keeps it 2-D even when N == 1 or T == 1 (the previous
            # expand_dims/squeeze workaround collapsed it to 1-D there).
            Y_current_label = np.vstack([Y[:, :, t].T for t in trial_idx])  # (T*n_trials) x N
            Phi_current_label = np.vstack([Phi[:, ensemble_idx, t] for t in trial_idx])  # (T*n_trials) x p_d
            assert Phi_current_label.shape[1] == p_d, (Phi_current_label.shape, p_d)

            # regularization sum over the other labels of this class
            lambda_d = lambdas[class_name]
            reg = sum(lambda_d[label_count, k] for k in range(num_unique_labels_class_d) if k != label_count)

            # Q: inverse of regularized Phi^T Phi, computed once and shared by the hat
            # matrix and Psi
            Q = np.linalg.inv(Phi_current_label.T @ Phi_current_label + reg * np.eye(p_d))
            H_ell = Q @ Phi_current_label.T
            df_d += np.trace(Phi_current_label @ H_ell)

            # Term1: kept exactly as before. NOTE(review): the original comment described this
            # as Phi^T @ Phi_other @ A_other, but the computation uses the current label's Phi on
            # both sides (a Phi_other of the other label was built only for shape checks, and it
            # has a compatible row count only when both labels have equally many trials).
            # Confirm the intended formula.
            term1 = np.zeros((p_d, N))
            for other_label in range(num_unique_labels_class_d):
                if other_label == label_count:
                    continue
                term1 += Phi_current_label.T @ (Phi_current_label @ A_d[:, :, other_label].T)

            # Term2: sum over other labels of lambda * A_other
            term2 = np.sum([lambda_d[label_count, k] * A_d[:, :, k]
                            for k in range(num_unique_labels_class_d) if k != label_count], axis=0)

            Psi = Q @ (-term1 + term2.T)

            Y_hat_current_label = Phi_current_label @ H_ell @ Y_current_label + Phi_current_label @ Psi
            resid_sum += np.sum((Y_current_label - Y_hat_current_label) ** 2)

        df_total += df_d

    # estimate sigma^2 from residuals; floor it so a zero residual does not divide by zero
    n_obs = N * T * M
    sigma2_hat = resid_sum / n_obs if resid_sum > 0 else 1e-12

    # Gaussian log-likelihood. The log term counts the same N*T*M observations as sigma2_hat
    # (it previously used N*M, dropping T).
    log_l = -0.5 / sigma2_hat * resid_sum - 0.5 * n_obs * np.log(2 * np.pi * sigma2_hat)

    # penalized AIC
    AIC_penalized = 2 * df_total - 2 * log_l
    return AIC_penalized, df_total, log_l



    