import numpy as np
from scipy.stats import ttest_rel
import pandas as pd
from scipy.stats import false_discovery_control

def mse_max_model(mse_models):
    """Return, for every voxel, the MSE values of the model with the lowest mean MSE.

    Parameters
    ----------
    mse_models : array-like, shape (num_models, num_samples, num_voxels)
        MSE values. Axis 1 is averaged over to rank the models per voxel
        (presumably samples / CV folds -- confirm with the caller).

    Returns
    -------
    numpy.ndarray, shape (num_voxels, num_samples)
        Row ``i`` is ``mse_models[best_i, :, i]``, where ``best_i`` is the
        model with the smallest mean over axis 1 for voxel ``i``. Note that
        the voxel axis comes FIRST in the output.

    Raises
    ------
    ValueError
        If ``mse_models`` is not 3-D.
    """
    mse_models = np.asarray(mse_models)
    if mse_models.ndim != 3:
        raise ValueError(
            "mse_models must have shape (num_models, num_samples, num_voxels), "
            f"got an array with shape {mse_models.shape}"
        )

    mse_avg = mse_models.mean(axis=1)          # (num_models, num_voxels)
    best_model_idx = mse_avg.argmin(axis=0)    # (num_voxels,)
    voxel_idx = np.arange(mse_models.shape[2])

    # Two advanced indices separated by a slice: NumPy puts the broadcast
    # index dimension (voxels) first, giving (num_voxels, num_samples) --
    # the same layout as stacking one row per voxel, but without a Python
    # loop, and it also works when there are zero voxels.
    return mse_models[best_model_idx, :, voxel_idx]

def compute_p_val(exp, num_vox_dict, mse_A, mse_B):
    """One-sided paired t-test p-values that ``mse_A`` is lower than ``mse_B``.

    The test runs along axis 0 (``scipy.stats.ttest_rel`` with
    ``alternative='less'``), so one p-value is returned per position of the
    remaining axes.

    ``exp`` and ``num_vox_dict`` are not used; they are kept only so that
    existing call sites continue to work.
    """
    paired_test = ttest_rel(mse_A, mse_B, axis=0, alternative='less')
    return paired_test.pvalue

def arrange_pvals_pd(pvals, exp, subjects_dict, br_labels_dict, non_nan_indices):
    
    
    pvals_dict = pd.DataFrame({'pvals': pvals, 
                               'subjects': subjects_dict[exp][non_nan_indices], 
                               'network': br_labels_dict[exp][non_nan_indices]})
    pvals_dict_updated = {}
    
    pvals_adj = []
    subjects = []
    network = []
    pvals_list = []
    for s in np.unique(subjects_dict[exp]):
        for n in np.unique(br_labels_dict[exp]):
            pvals_sn = pvals_dict.loc[(pvals_dict.subjects==s)&(pvals_dict.network==n)]['pvals']
            pvals_adj_sn = false_discovery_control(pvals_sn, method='bh')
            pvals_adj.extend(pvals_adj_sn)
            subjects.extend(np.repeat(s, len(pvals_adj_sn)))
            network.extend(np.repeat(n, len(pvals_adj_sn)))
            pvals_list.extend(pvals_sn)
        
    pvals_dict['pvals'] = pvals_list   
    pvals_dict['pvals_adj'] = pvals_adj
    pvals_dict['subjects'] = subjects
    pvals_dict['network'] = network
    
    return pd.DataFrame(pvals_dict)

def max_across_nested(df, updated_model_name):
    """Keep, for each (voxel_id, Network, subjects) group, the row with the highest ``r2``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``voxel_id``, ``Network``, ``subjects`` and ``r2``
        columns. It should have a unique index: the selected rows are fetched
        with ``df.loc`` using the labels returned by ``idxmax``, so duplicate
        labels would pull in extra rows.
    updated_model_name : str
        Value written into the ``Model`` column of every returned row.

    Returns
    -------
    result : pandas.DataFrame
        One row per group (in sorted group-key order), with a fresh
        RangeIndex and ``Model`` set to ``updated_model_name``.
    max_indices : pandas.Series
        For each group key, the label in ``df`` of the row that was selected.
    """
    max_indices = df.groupby(['voxel_id', 'Network', 'subjects'])['r2'].idxmax()

    # Extract the winning rows and renumber them 0..n-1.
    result = df.loc[max_indices].reset_index(drop=True)

    # Item assignment rather than ``result.Model = ...``: attribute assignment
    # only updates an already-existing column and otherwise just sets a plain
    # Python attribute on the frame.
    result['Model'] = updated_model_name

    return result, max_indices
        
def remove_neg_r2(df):
    """Clip negative values in the ``r2`` column to zero.

    The column is overwritten on ``df`` itself (the caller's frame is
    modified) and the same frame is returned for convenience. NaN values are
    left unchanged.
    """
    df['r2'] = df['r2'].clip(lower=0)
    return df

def modified_r2_and_idxs(r2_stacked_pd, BL_str, nested_name, full_name):
    """Split stacked R^2 results into baseline, best-nested and best-full model tables.

    Parameters
    ----------
    r2_stacked_pd : pandas.DataFrame
        Stacked results with at least ``Model``, ``r2``, ``voxel_id``,
        ``Network`` and ``subjects`` columns.
    BL_str : str
        Name of the baseline model. It is matched literally: exactly for the
        baseline rows, and as a substring to decide which models are "full".
    nested_name, full_name : str
        Model names written into the nested / full result tables.

    Returns
    -------
    BL_model : pandas.DataFrame
        Rows whose ``Model`` equals ``BL_str``, with negative ``r2`` clipped to
        0 and the original index moved into an ``index`` column.
    nested_model : pandas.DataFrame
        Per (voxel_id, Network, subjects), the best-``r2`` row among models
        whose name does NOT contain ``BL_str``; ``Model`` set to ``nested_name``.
    full_model : pandas.DataFrame
        Same, among models whose name DOES contain ``BL_str`` (this includes
        the baseline model itself); ``Model`` set to ``full_name``.

    Note: despite the function name, the row indices selected by
    ``max_across_nested`` are not returned.
    """
    is_baseline = r2_stacked_pd.Model == BL_str
    # regex=False: model names can contain regex metacharacters ('+', '.',
    # '(' ...), which would otherwise be interpreted as a pattern and
    # mis-assign models to the nested/full groups.
    contains_baseline = r2_stacked_pd.Model.str.contains(BL_str, regex=False)

    # Explicit .copy() so the in-place clipping in remove_neg_r2 acts on an
    # independent frame (avoids pandas' SettingWithCopyWarning).
    BL_model = remove_neg_r2(r2_stacked_pd.loc[is_baseline].copy()).reset_index()
    nested_model, _ = max_across_nested(
        r2_stacked_pd.loc[~contains_baseline].reset_index(), nested_name)
    full_model, _ = max_across_nested(
        r2_stacked_pd.loc[contains_baseline].reset_index(), full_name)

    return BL_model, nested_model, full_model

