import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import pandas as pd
import os
import json
import seaborn as sns
import numpy as np
from sklearn.metrics import auc
import json

from datasets import load_dataset
from sklearn.metrics import auc
import random

from utils.misc_utils import *
from statsmodels.stats.anova import anova_lm
from scipy import stats

# Number of examples drawn per task when estimating task difficulty; also the
# number of consecutive cosine-similarity rows pooled into one chunk below.
BATCH_SIZE = 32


def aggregate_chunks(group, chunk_size=BATCH_SIZE):
    """Pool consecutive rows of *group* into chunks and take per-chunk medians.

    Rows are assigned to chunks by position: rows 0..chunk_size-1 form chunk 0,
    the next chunk_size rows form chunk 1, and so on. The last chunk may hold
    fewer than ``chunk_size`` rows.

    Args:
        group: DataFrame containing at least 'cos_sim' and 'cos_sim_log' columns.
        chunk_size: Rows per chunk. Defaults to ``BATCH_SIZE`` so the cosine
            chunks stay in sync with the per-task sample size.

    Returns:
        DataFrame indexed by chunk number (0, 1, ...) with the median
        'cos_sim' and 'cos_sim_log' of each chunk.
    """
    chunk_ids = np.arange(len(group)) // chunk_size
    return group.groupby(chunk_ids).agg({
        'cos_sim': 'median',
        'cos_sim_log': 'median',
    })

# 1.a Load per-sample task-difficulty metrics: gradient statistics, model
# probabilities, and two cosine-similarity files. df_cos (the `_conf0` file)
# and df_cos_avg are later chunked and fed in as separate predictors.
PER_SAMPLE_DIR = "../results/per_sample_result"

df_grad = load_gradient_stats(path=f"{PER_SAMPLE_DIR}/llama_gradient_metrics_per_sample_2500_v2.json")
df_proba = load_proba_stats(path=f"{PER_SAMPLE_DIR}/llama_model_proba_per_sample_2500_v2.json")
df_cos_avg = load_cosine_stats(path=f"{PER_SAMPLE_DIR}/llama_cosine_similarity_metrics_32_v2.json")
df_cos = load_cosine_stats(path=f"{PER_SAMPLE_DIR}/llama_cosine_similarity_metrics_32_v2_conf0.json")

# 1.b Load the per-task AUC table; no tasks are excluded.
excluded_tasks = []
df_auc = load_auc(
    path="../results/auc_res/llama_auc_logscale_by_task_2500.json",
    tasks_to_exclude=excluded_tasks,
)

# 1.c Attach the per-sample gradient and probability metrics to the per-task
# AUC table. The per-sample frames are keyed on ('task', 'example_num'); the
# AUC table is keyed on 'task' alone, so pandas joins on the shared level.
sample_keys = ['task', 'example_num']
grad_by_sample = df_grad.set_index(sample_keys)
proba_by_sample = df_proba.set_index(sample_keys)
df_stats = (
    df_auc.set_index(['task'])
    .join(grad_by_sample)
    .join(proba_by_sample)
)

# 2. Set up the sampling simulation: BATCH_SIZE examples per task, repeated
#    for 10 iterations by the loop below.
agg_method = 'median'  # statistic used to collapse each sampled batch per task
res = list()           # one aggregated DataFrame per iteration

# 3. Pool the cosine-similarity samples into BATCH_SIZE-row chunks per task.

def _chunk_cosine_stats(df_cos_samples):
    """Chunk each task's cosine rows and tag every chunk with its task's last chunk id.

    Args:
        df_cos_samples: Frame with a 'task' column plus 'cos_sim' / 'cos_sim_log'.

    Returns:
        Frame indexed by 'task' with columns 'iter' (chunk number within the
        task), the per-chunk 'cos_sim' / 'cos_sim_log' medians, and 'iter_max'
        (the largest chunk number for that task), used later to cycle chunks.
    """
    chunked = (
        df_cos_samples.groupby('task')
        .apply(aggregate_chunks)
        .reset_index()
        # The unnamed inner index level produced by apply is the chunk number.
        .rename(columns={'level_1': 'iter'})
    )
    chunked['iter_max'] = chunked.groupby('task')['iter'].transform('max')
    return chunked.set_index('task')


df_cos1_pre = _chunk_cosine_stats(df_cos)
df_cos2_pre = _chunk_cosine_stats(df_cos_avg)

def _select_cos_chunk(df_pre, iteration, suffix, method):
    """Pick one cosine chunk per task for *iteration* and suffix its columns.

    Chunks are cycled per task via ``iter == iteration % (iter_max + 1)``, so a
    task with fewer chunks than iterations reuses its earlier chunks. Chunk
    numbers are unique within a task, so one row per task survives the filter.
    The groupby therefore mainly turns it into a task-indexed frame.

    Args:
        df_pre: Task-indexed chunk frame with 'iter', 'iter_max', 'cos_sim',
            'cos_sim_log' columns.
        iteration: Current sampling iteration (0-based).
        suffix: Appended to the output column names, e.g. 'conf0' -> 'cos_sim_conf0'.
        method: Aggregation name passed to ``agg``.
    """
    subset = df_pre[df_pre['iter'] == iteration % (df_pre['iter_max'] + 1)]
    return subset.groupby('task').agg({
        'cos_sim': method,
        'cos_sim_log': method,
    }).rename(columns={
        'cos_sim': f'cos_sim_{suffix}',
        'cos_sim_log': f'cos_sim_log_{suffix}',
    })


# Per-sample metrics aggregated over each sampled batch.
stat_columns = [
    'avg_confidence', 'avg_error', 'l1_norm', 'l2_norm',
    'fisher', 'variance', 'min_acc', 'auc',
]

for i in range(10):
    # Seed each draw with the iteration number so the sampled batches, and
    # therefore the saved prediction files, are reproducible across runs.
    df_stats_subset = df_stats.groupby('task').sample(BATCH_SIZE, random_state=i)
    df_stats_agg = df_stats_subset.groupby('task').agg(
        {col: agg_method for col in stat_columns}
    )

    df_cos1_agg = _select_cos_chunk(df_cos1_pre, i, 'conf0', agg_method)
    df_cos2_agg = _select_cos_chunk(df_cos2_pre, i, 'avg', agg_method)

    df_agg = df_stats_agg.join(df_cos1_agg).join(df_cos2_agg).reset_index()
    df_agg['iter'] = i
    res.append(df_agg)
    
# Stack the per-iteration frames and coerce every metric column to float;
# the 'task' and 'iter' identifier columns keep their original dtypes.
df_res = pd.concat(res)
float_casts = {col: float for col in df_res.columns if col not in ('task', 'iter')}
df_res = df_res.astype(float_casts)

# 4. For every sampling iteration, fit a leave-one-out model predicting 'auc'
#    from each single difficulty metric. Each call's result is collected per
#    metric; presumably a per-task prediction frame (see step 5) — confirm
#    against run_leave_one_out.
predictor_columns = {
    'conf_avg': 'avg_confidence',
    'grad_norm': 'l2_norm',
    'cos_avg': 'cos_sim_avg',
    'cos_low': 'cos_sim_conf0',
}
d_res = {name: [] for name in predictor_columns}
for i in range(10):
    train = df_res[df_res['iter'] == i]
    for name, column in predictor_columns.items():
        d_res[name].append(run_leave_one_out(train, depvar='auc', predvar=[column]))

# 5. Save the prediction results: for each metric, average the leave-one-out
#    outputs of all sampling iterations per task and write them as JSON.
prediction_output_dir = '../results/metrics_to_auc'
# to_json does not create missing parent directories.
os.makedirs(prediction_output_dir, exist_ok=True)

# Metric name -> predictor column; the fitted frames carry that predictor's
# coefficient and p-value as 'coeff_<column>' / 'pval_<column>'.
saved_predictors = {
    'cos_low': 'cos_sim_conf0',
    'conf_avg': 'avg_confidence',
    'grad_norm': 'l2_norm',
    'cos_avg': 'cos_sim_avg',
}
for metric_name, predictor in saved_predictors.items():
    summary_columns = [
        'pred', 'depvar', 'abs_diff',
        f'coeff_{predictor}', f'pval_{predictor}',
        'intercept', 'pval_intercept',
    ]
    (
        pd.concat(d_res[metric_name])
        .groupby('task')
        .agg({col: 'mean' for col in summary_columns})
        .to_json(f'{prediction_output_dir}/{metric_name}_prediction.json', orient='index', indent=4)
    )