import sys
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
from args import *

from scipy import stats

def get_score_dist(scores):
    """Summarize a sample of scores by its mean and a 95% percentile interval.

    NaN entries are ignored for all three statistics, so the mean is
    computed over the same values as the quantiles.

    Args:
        scores: Array-like of numeric scores (may contain NaN).

    Returns:
        A tuple ``(mean, lower, upper)`` where ``lower`` is the 2.5%
        quantile and ``upper`` is the 97.5% quantile of the non-NaN scores.
    """
    scores = np.asarray(scores)
    # nanmean (not mean) keeps the mean consistent with the NaN-aware
    # quantiles below; for NaN-free input the result is identical.
    return np.nanmean(scores), np.nanquantile(scores, 0.025), np.nanquantile(scores, 0.975)

def ci_score_analysis(args):
    """Attach bootstrap means and 95% percentile intervals to a results file.

    Reads the bootstrap parquet file and the matching results parquet file
    for the run described by ``args``, summarizes the bootstrap
    ``test_score`` and ``val_score`` values for every (electrode, times)
    pair, stores the summaries as new columns of the results table, and
    writes the table back to the same results file.

    Args:
        args: Object exposing ``subject``, ``alignment``, ``model_name``,
            ``randomized`` and ``model_output``; these select the input
            and output file names.

    Returns:
        The results DataFrame, sorted by electrode and times, with the
        added columns ``lower_ci``, ``upper_ci``, ``mean_bootstrap``,
        ``val_lower_ci``, ``val_upper_ci`` and ``val_mean_bootstrap``.

    Raises:
        ValueError: If the number of (electrode, times) groups in the
            bootstrap file differs from the number of result rows.
    """
    bootstrap_results_path = 'bootstrap_results/{}'
    actual_results_path = 'ecog-multimodal/results/{}'
    randomized_str = '-randomized' if args.randomized else ''
    # Both files share the same run-specific name prefix; build it once.
    run_prefix = f'{args.subject}_trial000_{args.alignment}_{args.model_name}{randomized_str}_{args.model_output}'
    bootstrap_file = os.path.join(bootstrap_results_path.format(args.subject), f'{run_prefix}_bootstrap.parquet.gzip')
    actual_file = os.path.join(actual_results_path.format(args.subject), f'{run_prefix}_200mswindow_results.parquet.gzip')

    bootstrap_results = pd.read_parquet(bootstrap_file)
    actual_results = pd.read_parquet(actual_file)

    # Sort both tables the same way so the i-th bootstrap group lines up
    # with the i-th row of the results table.
    bootstrap_results['times'] = bootstrap_results['times'].astype(int)
    bootstrap_results = bootstrap_results.sort_values(by = ['electrode', 'times', 'bootstrap']).reset_index(drop = True)
    actual_results['times'] = actual_results['times'].astype(int)
    actual_results = actual_results.sort_values(by = ['electrode', 'times']).reset_index(drop = True)

    # Keys are the output column names, in the order they are added.
    summary = {
        'lower_ci': [],
        'upper_ci': [],
        'mean_bootstrap': [],
        'val_lower_ci': [],
        'val_upper_ci': [],
        'val_mean_bootstrap': [],
    }
    # Group by the actual keys instead of slicing fixed 1000-row chunks, so
    # a differing number of bootstraps per cell cannot shift the alignment.
    # sort=False keeps the groups in the (already sorted) row order.
    for _, one_bts_results in bootstrap_results.groupby(['electrode', 'times'], sort = False):
        # get_score_dist returns (mean, 2.5% quantile, 97.5% quantile):
        # the lower bound comes before the upper bound.
        test_mean_bts, test_lower_ci, test_upper_ci = get_score_dist(one_bts_results['test_score'].to_numpy())
        val_mean_bts, val_lower_ci, val_upper_ci = get_score_dist(one_bts_results['val_score'].to_numpy())
        summary['lower_ci'].append(test_lower_ci)
        summary['upper_ci'].append(test_upper_ci)
        summary['mean_bootstrap'].append(test_mean_bts)
        summary['val_lower_ci'].append(val_lower_ci)
        summary['val_upper_ci'].append(val_upper_ci)
        summary['val_mean_bootstrap'].append(val_mean_bts)

    n_groups = len(summary['lower_ci'])
    if n_groups != len(actual_results):
        raise ValueError(
            f'bootstrap file has {n_groups} (electrode, times) groups but '
            f'the results file has {len(actual_results)} rows'
        )

    for column, values in summary.items():
        actual_results[column] = values
    # Overwrites the results file that was read above.
    actual_results.to_parquet(actual_file)
    return actual_results

def stack_results():
    """Concatenate the per-run results files into per-model and global tables.

    For every model, subject, alignment ('language' / 'vision') and
    trained / '-randomized' variant, the matching results parquet file is
    read, tagged with ``subject``, ``alignment`` and ``mul_uni`` columns,
    and stripped of bookkeeping columns. One parquet file per model plus a
    combined ``raw_results.parquet.gzip`` are written under
    ``final-results``.

    Returns:
        The combined DataFrame covering all models.
    """
    results_path = 'results/{}/{}_trial000_{}_{}_{}_200mswindow_results.parquet.gzip'
    final_save_path = 'final-results'
    models = ['albef', 'beit', 'blip', 'convnext', 'flava', 'sbert', 'simcse', 'slip-clip-language', 'slip-clip-vision', 'slip-combo-language', 'slip-combo-vision', 'slip-simclr']
    subjects = ['m00183', 'm00184', 'm00185', 'm00188', 'm00193', 'm00194', 'm00195']
    mul_models = ['albef', 'albef-randomized', 'blip', 'blip-randomized', 'flava', 'flava-randomized', 'clip', 'slip']
    model_output = 'best_layer'
    # Dropped only when present in a given file.
    optional_columns = ('val_score', 'train_score', 'model_layer_index')
    # Always dropped; a missing one raises, as pandas' drop does by default.
    required_columns = ['alpha', 'score_set', 'score_type']

    per_model_frames = []
    for model in models:
        frames = []
        for subject in subjects:
            for alignment in ['language', 'vision']:
                for train_str in ['', '-randomized']:
                    model_str = f'{model}{train_str}'
                    frame = pd.read_parquet(results_path.format(subject, subject, alignment, model_str, model_output))
                    frame['subject'] = subject
                    frame['alignment'] = alignment
                    frame['mul_uni'] = 'multimodal' if model_str in mul_models else 'unimodal'
                    present_optional = [name for name in optional_columns if name in frame.columns]
                    frame = frame.drop(present_optional, axis = 1)
                    frame = frame.drop(required_columns, axis = 1)
                    frames.append(frame)
        stacked = pd.concat(frames).reset_index(drop = True)
        stacked.to_parquet(os.path.join(final_save_path, f'{model}.parquet.gzip'))
        per_model_frames.append(stacked)

    combined = pd.concat(per_model_frames).reset_index(drop = True)
    combined.to_parquet(os.path.join(final_save_path, 'raw_results.parquet.gzip'))
    return combined

if __name__ == '__main__':
    # Script entry point: build the run configuration and compute the
    # bootstrap confidence intervals for that run.
    # NOTE(review): regression_args is not defined in this file; presumably
    # it is provided by `from args import *` -- confirm.
    args = regression_args()
    ci_score_analysis(args)