import os
import pandas as pd
import numpy as np
from tqdm import tqdm 
from scipy import stats

def group_summary_stats(df, groupby_cols, value_col, method='mean', use_value_name = True):
    '''
    Summarise ``value_col`` within each group of a pandas dataframe.

    For every group the result holds the number of non-null observations,
    the central tendency selected by ``method`` and a 95% interval:

    * ``method='mean'``: Student-t confidence interval around the mean.
    * ``method='median'``: the 2.5th and 97.5th percentiles of the values.

    Args:
        df: dataframe containing ``groupby_cols`` and ``value_col``.
        groupby_cols: a column name or a list of column names to group by.
        value_col: name of the column to summarise.
        method: either ``'mean'`` or ``'median'``.
        use_value_name: if True, prefix the statistic columns with
            ``value_col`` (e.g. ``'<value_col>_mean'``, ``'<value_col>_lower_ci'``).

    Returns:
        A dataframe with one row per group and the columns
        ``groupby_cols + ['count', method, 'lower_ci', 'upper_ci']``
        (statistic columns renamed when ``use_value_name`` is True).

    Raises:
        ValueError: if ``method`` is not ``'mean'`` or ``'median'``.
    '''
    if method not in ['mean', 'median']:
        raise ValueError("method should be either 'mean' or 'median'")

    if not isinstance(groupby_cols, list):
        groupby_cols = [groupby_cols]

    def mean_ci(values):
        # The confidence level is passed positionally on purpose: SciPy renamed
        # the keyword from `alpha` to `confidence`, so a keyword call only
        # works on one side of that change.
        lower, upper = stats.t.interval(0.95, df=len(values) - 1,
                                        loc=np.mean(values), scale=stats.sem(values))
        return pd.Series({'lower_ci': lower, 'upper_ci': upper})

    def percentile_ci(values):
        return pd.Series({'lower_ci': np.percentile(values, 2.5),
                          'upper_ci': np.percentile(values, 97.5)})

    grouped = df.groupby(groupby_cols)[value_col]
    ci_func = mean_ci if method == 'mean' else percentile_ci

    summary = grouped.agg(['count', method]).reset_index()

    # Applying a Series-returning function to a SeriesGroupBy yields a long
    # Series indexed by (group keys, statistic name); pivot the statistic
    # names into columns so they can be selected by name below.
    ci_df = grouped.apply(ci_func)
    if isinstance(ci_df, pd.Series):
        ci_df = ci_df.unstack()
    ci_df = ci_df.reset_index()

    # Both frames are sorted by the group keys and carry a fresh RangeIndex,
    # so a positional column-wise concat lines the rows up.
    summary = pd.concat([summary, ci_df[['lower_ci', 'upper_ci']]], axis=1)
    summary.columns = groupby_cols + ['count', method, 'lower_ci', 'upper_ci']

    if use_value_name:
        new_columns = {method: f'{value_col}_{method}',
                       'lower_ci': f'{value_col}_lower_ci',
                       'upper_ci': f'{value_col}_upper_ci'}
        summary.rename(columns = new_columns, inplace = True)

    return summary

def nested_dict_values(dictionary):
    '''
    Yield every non-dict value of a (possibly nested) dictionary.

    Values that are themselves dicts are descended into, depth first, in
    insertion order; all other values are yielded as-is.
    '''
    # Stack of value iterators: the top entry is the dict currently being walked.
    pending = [iter(dictionary.values())]
    while pending:
        for item in pending[-1]:
            if isinstance(item, dict):
                # Descend into the nested dict; the parent iterator keeps its
                # position and is resumed once the child is exhausted.
                pending.append(iter(item.values()))
                break
            yield item
        else:
            pending.pop()

def make_ranklists(data):
    '''
    Rank models by mean validation score for each electrode and alignment.

    ``val_score`` is summarised per (model_id, electrode_id, data_alignment)
    with ``group_summary_stats``. For every electrode and for each of the
    'vision' and 'language' alignments, the models are then sorted by their
    mean score (best first) and given a 1-based ``rank``. Combinations with
    fewer than two models are skipped.

    Args:
        data: dataframe with at least the columns 'model_id', 'electrode_id',
            'data_alignment' and 'val_score'.

    Returns:
        The concatenated ranklists with columns 'model_id', 'data_alignment',
        'electrode_id', 'score', 'lower_ci', 'upper_ci' and 'rank'. Each
        ranklist keeps its own 0-based index. An empty dataframe with these
        columns is returned when no combination has at least two models.
    '''
    group_vars = ['model_id', 'electrode_id', 'data_alignment']
    ranklist_columns = ['model_id', 'data_alignment', 'electrode_id',
                        'score', 'lower_ci', 'upper_ci']

    val_summary = group_summary_stats(data, group_vars, 'val_score')
    val_summary = val_summary.rename(columns={'val_score_mean': 'score',
                                              'val_score_lower_ci': 'lower_ci',
                                              'val_score_upper_ci': 'upper_ci'})

    model_ranklists = []
    for electrode_id in tqdm(data['electrode_id'].unique()):
        electrode_summary = val_summary[val_summary['electrode_id'] == electrode_id]
        for alignment in ['vision', 'language']:
            candidates = electrode_summary[electrode_summary['data_alignment'] == alignment]
            # A ranking needs at least two models to compare.
            if len(candidates) <= 1:
                continue

            # The summary already holds exactly one row per model here, so
            # the ranklist is just those rows ordered by score.
            model_ranklist = candidates[ranklist_columns].sort_values('score', ascending=False)
            model_ranklist = model_ranklist.reset_index(drop=True)
            model_ranklist['rank'] = model_ranklist.index + 1

            model_ranklists.append(model_ranklist)

    # pd.concat raises on an empty list; hand back an empty, well-formed frame.
    if not model_ranklists:
        return pd.DataFrame(columns=ranklist_columns + ['rank'])

    return pd.concat(model_ranklists)

def run_bootstrap(target_data, model_ranklist, model1, model2, n_boots = 1000):
    '''
    Run a bootstrap comparison of the test scores of model1 and model2.

    On every bootstrap iteration the 'test_score_original' values of each
    model are resampled with replacement (same size as the original sample)
    and the two resampled means are compared.

    Args:
        target_data: dataframe with the columns 'model_id' and
            'test_score_original'.
        model_ranklist: dataframe with the columns 'model_id' and 'rank';
            its first row is reported as the top model.
        model1: model_id of the first model.
        model2: model_id of the second model.
        n_boots: number of bootstrap iterations.

    Returns:
        A dict with the number of bootstraps, the top model of the ranklist,
        both model ids, their observation counts and ranks, the average
        bootstrapped mean score of each model ('score_1', 'score_2') and
        'p_value', the fraction of iterations in which model1's resampled
        mean was not greater than model2's.

    Raises:
        IndexError: if model1 or model2 is missing from ``model_ranklist``.
    '''
    model1_scores = target_data.loc[target_data['model_id'] == model1, 'test_score_original']
    model2_scores = target_data.loc[target_data['model_id'] == model2, 'test_score_original']

    model1_rank = model_ranklist[model_ranklist['model_id'] == model1]['rank'].values[0]
    model2_rank = model_ranklist[model_ranklist['model_id'] == model2]['rank'].values[0]

    top_model = model_ranklist.iloc[0,:]['model_id']

    model1_means = np.empty(n_boots)
    model2_means = np.empty(n_boots)
    for i in range(n_boots):
        # Only the scores are needed, so resample the score column alone.
        model1_means[i] = model1_scores.sample(frac=1, replace=True).mean()
        model2_means[i] = model2_scores.sample(frac=1, replace=True).mean()

    differences = model1_means - model2_means

    results_dict = {'n_bootstraps': n_boots,
                    'top_model': top_model,
                    'model_id_1': model1,
                    'model_id_2': model2,
                    # frac=1 resampling keeps the sample size, so the number of
                    # observations is simply the number of available scores.
                    'model1_nobs': len(model1_scores),
                    'model2_nobs': len(model2_scores),
                    'model1_rank': model1_rank,
                    'model2_rank': model2_rank,
                    'score_1': model1_means.mean(),
                    'score_2': model2_means.mean(),
                    'p_value': 1 - (differences > 0).mean()}

    return results_dict

def get_bootstrap_dict(electrode_id, alignment, comparison, 
                       target_data, model_ranklist,
                       model1, model2, n_boots = 1000):
    '''
    Run a bootstrap comparison and label its results.

    Calls run_bootstrap for model1 versus model2 and returns its result
    dict preceded by the 'electrode_id', 'alignment' and 'comparison'
    entries identifying the comparison.
    '''
    bootstrap_stats = run_bootstrap(target_data, model_ranklist,
                                    model1, model2, n_boots)

    labelled = {'electrode_id': electrode_id,
                'alignment': alignment,
                'comparison': comparison}
    labelled.update(bootstrap_stats)
    return labelled
    

def _default_winner_record(electrode_id, alignment, model_id, comparison):
    '''Build one row of the default-winners table.'''
    return {'electrode_id': electrode_id, 'alignment': alignment,
            'top_ranking_model': model_id, 'comparison': comparison}


def model_comparisons(data, model_ranklists):
    '''
    Run the model comparisons for every electrode and alignment.

    For each (electrode_id, data_alignment) ranklist the following
    comparisons are made, each via get_bootstrap_dict:

    0. 'rank1-beats-rank2': the two top-ranked models.
    1/2. 'trained-beats-random' / 'random-beats-trained': the best-ranked
       trained model against the best-ranked randomized model, named after
       whichever ranks higher.
    3. 'multi-beats-unimodal': the best multimodal model against the best
       unimodal model. Only evaluated when the top-ranked model is multimodal.
    4. 'slip-beats-simclr': 'slip-combo-vision-trained' against the
       best-ranked SLIP SimCLR model. Only evaluated when the top-ranked
       model is multimodal.

    When only one side of a comparison is present in the ranklist, no
    bootstrap is run; instead a "default winner" row naming the top-ranked
    model is recorded, provided that model has more than 10 observations.

    Args:
        data: dataframe with at least the columns 'model_id', 'model',
            'train_type', 'model_modality', 'electrode_id' and
            'data_alignment', plus whatever get_bootstrap_dict needs.
        model_ranklists: dataframe with at least the columns 'model_id',
            'electrode_id', 'data_alignment' and 'rank'; rows of each
            ranklist are expected to be ordered best model first.

    Returns:
        A tuple ``(default_winners, comparison_data)`` of dataframes: the
        default-winner rows and the bootstrap comparison results.
    '''
    default_winners = []
    comparison_list = []
    boot_min_nobs = 10

    model_data = data[['model_id','model','train_type','model_modality']]
    model_data = model_data.drop_duplicates().reset_index(drop=True)

    for electrode_id in tqdm(model_ranklists['electrode_id'].unique()):
        electrode_ranks = model_ranklists[model_ranklists['electrode_id'] == electrode_id]
        electrode_data = data[data['electrode_id'] == electrode_id]

        for alignment in electrode_ranks['data_alignment'].unique():
            ranklist = electrode_ranks[electrode_ranks['data_alignment'] == alignment]
            ranklist = ranklist.merge(model_data, on='model_id', how='left')
            sub_data = electrode_data[electrode_data['data_alignment'] == alignment]

            # The top-ranked model gets its own name so that later comparisons
            # cannot overwrite what is reported as 'top_ranking_model'.
            top_model, runner_up = ranklist['model_id'].values[:2]
            top_model_nobs = len(sub_data[sub_data['model_id'] == top_model])
            top_has_enough_data = top_model_nobs > boot_min_nobs

            # comparison 0 -- rank1-versus-rank2

            comparison_list.append(get_bootstrap_dict(electrode_id, alignment, 'rank1-beats-rank2',
                                                      sub_data, ranklist, top_model, runner_up))

            # comparison 1+2 -- trained-versus-random

            has_trained = 'trained' in ranklist['train_type'].values
            has_random = 'randomized' in ranklist['train_type'].values

            if has_random and not has_trained and top_has_enough_data:
                default_winners.append(_default_winner_record(electrode_id, alignment, top_model,
                                                              'random-beats-trained'))

            if has_trained and not has_random and top_has_enough_data:
                default_winners.append(_default_winner_record(electrode_id, alignment, top_model,
                                                              'trained-beats-random'))

            if has_trained and has_random:
                # The ranklist is ordered best first, so the first matching
                # row is the best-ranked model of each training type.
                trained_info = ranklist[ranklist['train_type'] == 'trained'].iloc[0,:]
                random_info = ranklist[ranklist['train_type'] == 'randomized'].iloc[0,:]

                trained_rank, random_rank = trained_info['rank'], random_info['rank']

                # Equal ranks give no winner to name, so nothing is recorded.
                if trained_rank != random_rank:
                    if trained_rank < random_rank:
                        comparison = 'trained-beats-random'
                        winner, loser = trained_info['model_id'], random_info['model_id']
                    else:
                        comparison = 'random-beats-trained'
                        winner, loser = random_info['model_id'], trained_info['model_id']

                    comparison_list.append(get_bootstrap_dict(electrode_id, alignment, comparison,
                                                              sub_data, ranklist, winner, loser))

            # comparison 3 -- multimodality-versus-all

            comparison = 'multi-beats-unimodal'

            has_multimodal = 'multimodal' in ranklist['model_modality'].values
            has_unimodal = 'unimodal' in ranklist['model_modality'].values

            if has_multimodal and not has_unimodal and top_has_enough_data:
                default_winners.append(_default_winner_record(electrode_id, alignment, top_model,
                                                              comparison))

            # comparisons onward: only run when the top-ranked model is multimodal

            if ranklist.iloc[0,:]['model_modality'] != 'multimodal':
                continue

            if has_unimodal:
                multimodal_info = ranklist[ranklist['model_modality'] == 'multimodal'].iloc[0,:]
                unimodal_info = ranklist[ranklist['model_modality'] == 'unimodal'].iloc[0,:]

                if multimodal_info['rank'] < unimodal_info['rank']:
                    comparison_list.append(get_bootstrap_dict(electrode_id, alignment, comparison,
                                                              sub_data, ranklist,
                                                              multimodal_info['model_id'],
                                                              unimodal_info['model_id']))

            # comparison 4 -- simclr-versus-slip

            comparison = 'slip-beats-simclr'
            slip_combo_id = 'slip-combo-vision-trained'

            all_slip_info = ranklist[ranklist['model_id'].str.contains('slip')]
            slip_ids = all_slip_info['model_id'].values

            if slip_combo_id in slip_ids:
                has_simclr = ('slip-simclr-trained' in slip_ids
                              or 'slip-simclr-randomized' in slip_ids)

                if not has_simclr:
                    # No SimCLR counterpart at all: the win is by default.
                    if top_has_enough_data:
                        default_winners.append(_default_winner_record(electrode_id, alignment,
                                                                      top_model, comparison))
                else:
                    # Look up the same id the guard above checked for, so the
                    # row is guaranteed to exist.
                    # NOTE(review): confirm 'slip-combo-vision-trained' is the
                    # model_id actually used in the data.
                    slipco_info = all_slip_info[all_slip_info['model_id'] == slip_combo_id].iloc[0,:]
                    simclr_info = all_slip_info[all_slip_info['model_id'].str.contains('simclr')].iloc[0,:]

                    if slipco_info['rank'] < simclr_info['rank']:
                        comparison_list.append(get_bootstrap_dict(electrode_id, alignment, comparison,
                                                                  sub_data, ranklist,
                                                                  slipco_info['model_id'],
                                                                  simclr_info['model_id']))

    default_winners = pd.DataFrame(default_winners)
    comparison_data = pd.DataFrame(comparison_list)
    return default_winners, comparison_data