import numpy as np
import seaborn as sns 
import pandas as pd 
from matplotlib import pyplot as plt
import sys
sys.path.append('/home3/name/what-is-brainscore/')
from helper_funcs import * 
from scipy.stats import pearsonr 
import matplotlib
import nibabel as nib

def load_bil_r2(model_name, resultsFolder, neural_data):
    """Compute R2 on the pereira dataset for the best intermediate layer of a model.

    The best layer is looked up with ``find_best_layer`` (required string
    'encoder'; excluded strings '384', '243' and 'layer1'), its result name is
    converted into the layer key expected by ``compute_R2``, and R2 is computed
    for that single layer with ``exp='both'``.

    :param str model_name: model whose results are loaded
    :param str resultsFolder: folder to retrieve results from
    :param neural_data: passed through unchanged to ``compute_R2``
    :return: the object returned by ``compute_R2`` for the selected layer
    """
    best_layer = find_best_layer(
        model_name,
        resultsFolder,
        required_str=['encoder'],
        exclude_str=['384', '243', 'layer1'],
    )
    print("Best Layer: ", best_layer)

    # The layer identifier occupies a different number of underscore-separated
    # fields depending on the model family (roberta vs. gpt naming), so
    # bert-style names need two fields and all others need one.
    name_parts = best_layer.split('_')
    if 'bert' in model_name:
        layer_key = f"{name_parts[2]}_{name_parts[3]}_1"
    else:
        layer_key = f"{name_parts[2]}_1"

    return compute_R2(
        {model_name: [layer_key]},
        neural_data,
        dataset='pereira',
        resultsFolder=resultsFolder,
        exp='both',
    )

def _nanmean_from_npz(filename, key):
    # Mean (ignoring NaNs) of one array stored in an .npz archive; the archive
    # is closed again as soon as the value has been read.
    with np.load(filename) as results:
        return np.nanmean(results[key])


def plot_test_perf_across_layers(model_arr, dataset, layers_range, layer_name_arr, best_layer, saveName, 
                                 figurePath, resultsFolder, c, yticks, exp=None, val=False):
    
    '''
    Plot the mean score of every layer for each model. For each model a dashed
    horizontal line marks the mean score of its static embedding layer and a
    dashed vertical line marks its best layer.

    :param list model_arr: model names to load
    :param str dataset: which dataset to load data from 
    :param list layers_range: per model, a (first, last) pair of layer numbers (both inclusive)
    :param list layer_name_arr: per model, the layer-name prefix used in the result filenames
    :param list best_layer: best layer of each model in model_arr (x position of the vertical line)
    :param str saveName: currently unused, the figure is only shown
    :param str figurePath: currently unused, the figure is only shown
    :param str resultsFolder: where to retrieve results from (used as a raw filename prefix)
    :param list c: colors for each model 
    :param list yticks: yticks to plot
    :param exp: optional experiment tag appended to the result filenames
    :param bool val: if True plot 'val_perf', otherwise 'out_of_sample_r2'
    '''

    load_str = 'val_perf' if val else 'out_of_sample_r2'
    suffix = f"_{exp}" if exp is not None else ""
    
    plt.figure(figsize=(14,8))
    
    for idx, (model, layer_range, layer_name, bl) in enumerate(
            zip(model_arr, layers_range, layer_name_arr, best_layer)):
        
        # because the embedding layer is always sum pooled, I don't add an sp to indicate
        # sum pooling in its filename (that's why the .replace('-sp', '') is there).
        emb_filename = f"{resultsFolder}{dataset}_{model.replace('-sp', '')}-static_layer1_1{suffix}.npz"
        r2_emb_pos_m = _nanmean_from_npz(emb_filename, load_str)
        
        r2_layer = [
            _nanmean_from_npz(f"{resultsFolder}{dataset}_{model}_{layer_name}{i}_1{suffix}.npz", load_str)
            for i in range(layer_range[0], layer_range[1]+1)
        ]
        
        color = c[idx]
        plt.axhline(r2_emb_pos_m, linestyle='--', color=color, label=model)
        # NOTE(review): r2_layer is drawn against positions 0..len-1, so the
        # best-layer line only lines up with the curve when layer_range starts
        # at 0 -- confirm this matches how best_layer is expressed.
        plt.plot(r2_layer, marker='o', color=color)
        plt.axvline(bl, color=color, linestyle='--')
        
    plt.xlabel("Layer number", fontsize=40)
    plt.ylabel('R2' + r"$_{oos}$", fontsize=40)
    plt.xticks(fontsize=30) 
    plt.yticks(yticks, fontsize=30) 
    plt.legend(fontsize=20)
    plt.show()
    
def plot_by_network(models, layers, updated_names, neural_data, dataset, resultsFolder, saveName, figurePath):
    """Bar plot of R2 per brain network, one hue per model, saved as a PDF and shown.

    :param list models: model names to load
    :param list layers: per model, the layers passed to ``compute_R2``
    :param list updated_names: per model, display names paired with its layers
    :param neural_data: passed through unchanged to ``compute_R2``
    :param str dataset: which dataset to load data from (also part of the output filename)
    :param str resultsFolder: where to retrieve results from
    :param str saveName: suffix of the output filename
    :param str figurePath: folder the figure is written to
    """
    model_layer_dict = {m: layers[i] for i, m in enumerate(models)}

    # Every (layer, name) pair of a model writes to the same key, so each model
    # ends up with the last display name paired with one of its layers.
    model_new_name_dict = {}
    for i, m in enumerate(models):
        for _, n in zip(layers[i], updated_names[i]):
            model_new_name_dict[f"{m}"] = n

    R2_dict = compute_R2(model_layer_dict, neural_data, dataset, resultsFolder)

    plt.figure(figsize=(14,8))
    
    R2_dict.rename(columns={'brain_network': 'Network'}, inplace=True)
    
    for key, val in model_new_name_dict.items():
        R2_dict['Model'] = R2_dict['Model'].replace(key, val)
   
    sns.set_theme()
    sns.barplot(data=R2_dict, y='r2', x='Network', hue='Model')
    # despine acts on the axes of the current figure, so it has to run after
    # the barplot has created them.
    sns.despine()
    plt.ylabel('R2' + r"$_{oos}$", fontsize=40)
    plt.xticks(fontsize=30) 
    plt.yticks([0, 0.01, 0.02], fontsize=30) 
    plt.xlabel('')
    plt.legend(fontsize=30)
    plt.savefig(f'{figurePath}/{dataset}{saveName}.pdf', bbox_inches='tight')
    plt.show()


# define some helpful functions
def extract_number_from_string(input_string):
    """Return the integer N of the first 'encoder.h.N' occurrence in a string.

    :param str input_string: text expected to contain 'encoder.h.<digits>'
    :return int: the digits that follow the first 'encoder.h.'
    :raises ValueError: if the string contains no 'encoder.h.<digits>' pattern
    """
    import re  # kept local: the module has no top-level ``re`` import

    match = re.search(r'encoder\.h\.(\d+)', input_string)
    if match is None:
        raise ValueError(
            f"No 'encoder.h.<number>' pattern found in {input_string!r}"
        )
    return int(match.group(1))

def load_r2_across_seeds(model_dict, dataset, neural_data, resultsFolder, num_seeds, exp):
    """Load R2 values via ``compute_R2`` and arrange them with one row per seed.

    :param dict model_dict: models/layers passed to ``compute_R2``
    :param str dataset: which dataset to load data from
    :param neural_data: passed through unchanged to ``compute_R2``
    :param str resultsFolder: where to retrieve results from
    :param int num_seeds: number of rows of the returned array
    :param exp: experiment selector forwarded to ``compute_R2``
    :return: 2-D array of shape (num_seeds, -1), filled row by row from the
        ``r2`` values (assumes ``compute_R2`` orders them seed by seed -- TODO confirm)
    """
    r2_table = compute_R2(model_dict, neural_data, dataset, resultsFolder, exp=exp)
    flat_r2 = np.array(r2_table.r2)
    return flat_r2.reshape((num_seeds, -1), order='C')
        
def plot_across_seeds(r2_dict, br_labels, num_vox, num_seeds, figurePath, yticks=None, saveName=None,
                      custom_palette=None):
    """Bar plot of R2 per brain network with one hue per model.

    With more than one seed, R2 is first averaged within every
    (Network, seed, Model) group, so each bar summarizes per-seed means;
    with a single seed the bars are drawn from the raw per-voxel values.

    :param dict r2_dict: model label -> R2 values holding num_seeds * num_vox
        entries, flattened seed by seed
    :param br_labels: network label of each voxel (length num_vox)
    :param int num_vox: number of voxels per seed
    :param int num_seeds: number of seeds
    :param str figurePath: filename prefix used when saving
    :param yticks: optional y tick positions
    :param str saveName: if given, the figure is saved to f'{figurePath}{saveName}.pdf'
    :param custom_palette: optional palette forwarded to ``sns.barplot``
    """
    sns.set_theme()
    
    per_model_frames = [
        pd.DataFrame({'r2': np.ravel(val), 'seeds': np.repeat(np.arange(num_seeds), num_vox), 
                      'Network': np.tile(br_labels, num_seeds), 
                      'Model': np.repeat(key, num_vox*num_seeds)})
        for key, val in r2_dict.items()
    ]
    store_pd = pd.concat(per_model_frames, ignore_index=True)
    
    if num_seeds > 1:
        # groupby moves the grouping keys into the index; reset_index turns
        # them back into regular columns so the barplot can always look up
        # 'Network' and 'Model' by name.
        grouped_data = store_pd.groupby(['Network', 'seeds', 'Model']).mean().reset_index()
    else:
        grouped_data = store_pd
    
    plt.figure(figsize=(14,10))
    sns.set_style("white")
    sns.barplot(data=grouped_data, y='r2', x='Network', hue='Model', palette=custom_palette, alpha=1.0)
    sns.despine()
    plt.legend(fontsize=30)
    plt.ylabel('R2' + r"$_{oos}$", fontsize=40)
    plt.xlabel('')
    if yticks is not None:
        plt.yticks(yticks)
    plt.tick_params(axis='x', labelsize=30) 
    plt.tick_params(axis='y', labelsize=30) 
    if saveName is not None:
        plt.savefig(f'{figurePath}{saveName}.pdf', bbox_inches='tight')
    
def get_subportion(input_str):
    """Return the part of ``input_str`` that starts at its first "encoder".

    :param str input_str: string to search
    :return str: the substring from the first "encoder" to the end of the
        string; if "encoder" does not occur, a fixed explanatory message is
        returned instead (no exception is raised, so callers that need to
        detect a miss must compare against that message)
    """
    _, marker, remainder = input_str.partition("encoder")
    if not marker:
        return "Word 'encoder' not found in the input string."
    return marker + remainder
    
def load_bil(num_seeds, model_name, which_exp, 
             resultsFolder):
    
    '''
    Collect the best intermediate layer name for every seed of a model.

    :param int num_seeds: number of seeds to load 
    :param str model_name: model results to load
    :param str which_exp: load results for, 384, 243, or both experiments
    :param str resultsFolder: folder to load results from 
    :return list: one entry per seed, the best-layer name from "encoder"
        onwards with '.npz' removed
    '''
    
    # 'encoder' has to be in the name so that only intermediate layers are
    # considered; 'no-nonlin' models are not of interest.
    required_str = ['encoder']
    exclude_str = ['no-nonlin']
    
    if which_exp == 'both':
        # results for both experiments carry neither experiment tag
        exclude_str += ['243', '384']
    elif which_exp in ('243', '384'):
        required_str.append(which_exp)
        
    best_layer = []
    for seed in np.arange(num_seeds):
        layer_file = find_best_layer(model_name, resultsFolder, exclude_str=exclude_str, 
                                     required_str=required_str, model_num=seed)
        best_layer.append(get_subportion(layer_file).replace('.npz', ''))
        
    return best_layer

def plot_scatter(df, val_range, figurePath, saveName=None, mode=''):
    """2-D histograms of 'Interp' R2 (x) against 'BIL + Interp' R2 (y).

    With ``mode='all_networks'`` a single panel is drawn from all rows of
    ``df``; otherwise one panel per unique network is placed on a 2x2 grid
    (so at most four networks fit). Each panel shows the identity line and
    the Pearson correlation, marked with '**' when p < 0.01.

    :param df: DataFrame with columns 'Network', 'Model' and 'r2'; the rows of
        the models 'Interp' and 'BIL + Interp' are paired by position, so both
        must list the same voxels in the same order
    :param dict val_range: upper axis limit per network (key 'all_networks'
        in single-panel mode)
    :param str figurePath: filename prefix used when saving
    :param str saveName: if given, the figure is saved to f'{figurePath}{saveName}.pdf'
    :param str mode: 'all_networks' for a single panel, anything else for the grid
    """
    sns.set_theme()
    sns.set_style("white")
    
    single_panel = mode == 'all_networks'
    if single_panel:
        fig, axs = plt.subplots(1,1, figsize=(12,10))
        networks = [mode]
    else:
        fig, axs = plt.subplots(2, 2, figsize=(18, 16))
        networks = np.unique(df.Network)
    
    # All axes already exist, so one call covers every panel.
    sns.despine()
        
    for i, region in enumerate(networks):
        row, col = divmod(i, 2)
        
        if single_panel:
            df_region = df
            title = ''
            ax = axs
        else:
            df_region = df.loc[df.Network==region]
            title = region
            ax = axs[row, col]
        
        # Plain arrays pair the two models by position for both the
        # correlation and the histogram; building a DataFrame from the two
        # Series would align them on their index labels instead and fill
        # non-matching labels with NaN.
        interp_r2 = df_region.loc[df_region.Model=='Interp'].r2.to_numpy()
        bil_interp_r2 = df_region.loc[df_region.Model=='BIL + Interp'].r2.to_numpy()
        
        r, p = pearsonr(interp_r2, bil_interp_r2)
        r_label = f'{round(r,3)}'
        if p < 0.01:
            r_label += '**'
            
        ax.hist2d(x=interp_r2, y=bil_interp_r2, 
                  norm=matplotlib.colors.LogNorm(), bins=100, cmap='Reds')
        ax.set_title(f"{title}", fontsize=40)
        
        # there's a very small number of voxels with high R2 in visual
        # which makes it really hard to see anything else in the plot. To enhance
        # visualization, shift the axis to zoom more into the majority of the datapoints.
        max_val = val_range[region]
        min_val = -0.10
        y_text = max_val - 0.05
        
        ax.set_xlim(min_val, max_val)
        ax.set_ylim(min_val, max_val)
            
        # identity line; the color is given once, as a keyword, to avoid
        # a conflicting color in a format string
        ax.plot(ax.get_xlim(), ax.get_ylim(), linestyle='--', alpha=0.75, color='black')
        ax.text(-.04, y_text, f'Pearson r = {r_label}', ha='left', va='top', size=30)
        ax.set_xticks([0, max_val])
        ax.tick_params(axis='x', labelsize=40) 

        ax.set_yticks([0, max_val])
        ax.tick_params(axis='y', labelsize=40) 
        
        # Only the top-left panel carries the y label and only the
        # bottom-left panel carries the x label.
        # NOTE(review): in single-panel mode this leaves the x axis unlabeled
        # -- confirm that is intended.
        if (row, col) == (0, 0): 
            ax.set_ylabel("BIL + Interp R2" + r"$_{oos}$", fontsize=40)
        else:
            ax.set_ylabel("")
            
        if (row, col) == (1, 0): 
            ax.set_xlabel("Interp R2" + r"$_{oos}$", fontsize=40)
        else:
            ax.set_xlabel("")   
            
    plt.tight_layout()
    
    if saveName is not None:
        plt.savefig(f'{figurePath}{saveName}.pdf', bbox_inches='tight')
    
    plt.show()
    

