import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from utils import load_json, save_json
from tqdm import tqdm

### AUC HELPERS

def get_observed_vs_detected(group_ids, flipped_indices, sorted_unique_group_ids):
    """Build the detection curve obtained by inspecting samples group by group.

    Groups are visited in the order given by ``sorted_unique_group_ids``. The
    members of each group are visited in a random order drawn with
    ``torch.randperm``, so the curve depends on torch's global RNG state
    whenever a group mixes flipped and non-flipped samples.

    Args:
        group_ids: Tensor of per-sample group ids (expected to be 1-D; its
            length is used as the number of samples).
        flipped_indices: Tensor of the sample indices that should be detected.
        sorted_unique_group_ids: Iterable of group ids in inspection order.

    Returns:
        Tuple ``(frac_observed, frac_detected)`` of NumPy arrays of length
        ``len(group_ids) + 1``. Entry ``k`` holds the fraction of samples
        inspected, respectively the fraction of detected flipped samples,
        after ``k`` inspections. When a denominator would be zero (no samples,
        or no flipped sample encountered) the corresponding array is all
        zeros instead of NaN.
    """
    remaining_flipped = set(flipped_indices.tolist())
    n_samples = len(group_ids)
    n_observed = np.arange(n_samples + 1)
    n_detected = np.zeros(n_samples + 1)
    idx = 1
    for unique_group_id in sorted_unique_group_ids:
        group_indices = torch.where(group_ids == unique_group_id)[0]
        # Visit the members of the group in random order.
        group_indices = group_indices[torch.randperm(len(group_indices))]
        for sample_idx in group_indices.tolist():
            if sample_idx in remaining_flipped:
                n_detected[idx] = n_detected[idx - 1] + 1
                # Each flipped index can only be detected once.
                remaining_flipped.remove(sample_idx)
            else:
                n_detected[idx] = n_detected[idx - 1]
            idx += 1

    # If some samples belong to groups that were never visited, keep the
    # running count flat over the unvisited tail instead of dropping to zero.
    n_detected[idx:] = n_detected[idx - 1]

    total_observed = n_observed[-1]
    total_detected = n_detected[-1]
    # Guard the normalisations against a zero denominator (would give NaN/inf).
    if total_observed > 0:
        frac_observed = n_observed / total_observed
    else:
        frac_observed = np.zeros(n_samples + 1)
    if total_detected > 0:
        frac_detected = n_detected / total_detected
    else:
        frac_detected = np.zeros(n_samples + 1)
    return frac_observed, frac_detected

def get_save_dir(config, invert, exp='attr_scores'):
    """Compose the relative results directory for one experiment setup.

    Args:
        config: Mapping providing ``dataset_name``, ``model_name``,
            ``attributor_name`` and ``grouping``.
        invert: When truthy, ``_invert`` is appended to the last path segment.
        exp: Name of the experiment sub-folder under ``experiments/results``.

    Returns:
        A ``/``-separated relative path of the form
        ``experiments/results/<exp>/<dataset>_<model>/<attributor>_<grouping>[_invert]``.
    """
    dataset, model = config['dataset_name'], config['model_name']
    attributor, grouping = config['attributor_name'], config['grouping']

    leaf = f'{attributor}_{grouping}'
    if invert:
        leaf += '_invert'

    return '/'.join((f'experiments/results/{exp}', f'{dataset}_{model}', leaf))

def get_group_ids(load_dir, group_size, trial, use_magnitude=False):
    """Load the group assignment and the score-based group ordering of a trial.

    Reads ``scores_json/group_scores_size_<group_size>_trial_<trial>.json``
    (per-group scores) and the matching ``scores_pt/...pt`` file (whose
    ``'group_ids'`` entry is returned) from ``load_dir``.

    Args:
        load_dir: Directory containing the ``scores_json`` and ``scores_pt`` folders.
        group_size: Group size used in the file name.
        trial: Trial number used in the file name.
        use_magnitude: When true, groups are ranked by absolute score from
            high to low; otherwise by raw score from low to high.

    Returns:
        Tuple ``(group_ids, sorted_unique_group_ids)``.
    """
    stem = f'group_scores_size_{group_size}_trial_{trial}'
    json_path = os.path.join(load_dir, 'scores_json', stem + '.json')
    pt_path = os.path.join(load_dir, 'scores_pt', stem + '.pt')

    group_scores = load_json(json_path, convert_keys_to_int=True)
    if use_magnitude:
        group_scores = {gid: np.abs(score) for gid, score in group_scores.items()}

    group_ids = torch.load(pt_path)['group_ids']

    # Ascending by score, or descending by |score| when use_magnitude is set.
    ranked_group_ids = sorted(group_scores, key=group_scores.get, reverse=use_magnitude)
    return group_ids, ranked_group_ids

def get_runtimes(config, times='attribution', group_sizes=[1, 4, 16, 64, 256, 1024]):
    """Read mean runtimes for the requested group sizes.

    The values come from ``mean_times.csv`` inside the non-inverted
    ``attr_scores`` results directory of ``config``.

    Args:
        config: Experiment configuration passed on to ``get_save_dir``.
        times: Which timing to read; selects the column ``mean_<times>_time``.
        group_sizes: Group sizes whose rows are kept.

    Returns:
        NumPy array with the selected column's values, in the row order of
        the CSV file (not in the order of ``group_sizes``).
    """
    results_dir = get_save_dir(config, invert=False, exp='attr_scores')
    timing_table = pd.read_csv(f'{results_dir}/mean_times.csv')

    # Keep only the rows of the requested group sizes.
    wanted_rows = timing_table['group_size'].isin(group_sizes)
    return timing_table.loc[wanted_rows, f'mean_{times}_time'].values

def save_auc_results(config, group_sizes=[1, 4, 16, 64, 256, 1024], use_magnitude=False):
    """Compute detection AUCs for every grouping strategy and save them.

    For each grouping and each group size, the detection curve of every trial
    is built with ``get_observed_vs_detected`` and integrated with the
    trapezoidal rule. Per grouping, four files are written under the
    ``group_auc`` results directory (resolved against the current working
    directory): mean AUCs and AUC standard deviations as JSON, and the
    per-trial observed/detected fraction arrays as ``.pt`` files.

    The caller's ``config`` is not modified; each grouping works on a copy.

    Args:
        config: Experiment configuration; must provide ``dataset_name``,
            ``invert`` and ``n_trials`` plus the keys needed by
            ``get_save_dir``. Its own ``grouping`` entry is ignored.
        group_sizes: Group sizes to evaluate (only read, never mutated).
        use_magnitude: Forwarded to ``get_group_ids``; also adds a ``_mag``
            suffix to the output file names.
    """
    if config['dataset_name'] != 'qnli_noisy':
        groupings = ['equal', 'kmeans', 'grad_kmeans', 'repr_kmeans']
    else:
        groupings = ['equal', 'grad_kmeans', 'repr_kmeans']
    flipped_indices = torch.load(f"../data/{config['dataset_name']}/flipped_indices.pt")
    mag_str = '_mag' if use_magnitude else ''
    # np.trapezoid only exists in NumPy >= 2.0; older versions call it np.trapz.
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz

    for grouping in tqdm(groupings):
        # Work on a copy so the caller's config keeps its own 'grouping'.
        grouping_config = {**config, 'grouping': grouping}
        # The score directory does not depend on the group size.
        load_dir = get_save_dir(grouping_config, grouping_config['invert'])

        aucs, auc_stds = {}, {}
        grouping_fracs_observed, grouping_fracs_detected = {}, {}
        for group_size in group_sizes:
            trial_aucs = []
            fracs_observed, fracs_detected = [], []
            for trial in range(1, grouping_config['n_trials'] + 1):
                group_ids, sorted_unique_group_ids = get_group_ids(
                    load_dir, group_size=group_size, trial=trial, use_magnitude=use_magnitude)
                frac_observed, frac_detected = get_observed_vs_detected(
                    group_ids, flipped_indices, sorted_unique_group_ids)
                trial_aucs.append(trapezoid(frac_detected, frac_observed).item())
                fracs_observed.append(frac_observed)
                fracs_detected.append(frac_detected)
            grouping_fracs_observed[group_size] = np.array(fracs_observed)
            grouping_fracs_detected[group_size] = np.array(fracs_detected)
            aucs[group_size] = np.mean(trial_aucs).item()
            auc_stds[group_size] = np.std(trial_aucs).item()

        # The last path segment becomes the file-name prefix; the rest is the
        # output directory.
        result_stem = get_save_dir(grouping_config, grouping_config['invert'], exp='group_auc')
        parent_dir, _, filename = result_stem.rpartition('/')
        save_dir = os.path.join(os.getcwd(), parent_dir)
        os.makedirs(save_dir, exist_ok=True)

        # Save as json
        save_json(aucs, f'{save_dir}/{filename}_aucs{mag_str}.json')
        save_json(auc_stds, f'{save_dir}/{filename}_aucs_std{mag_str}.json')

        # Save the observed and detected fractions as pt
        torch.save(grouping_fracs_observed, f'{save_dir}/{filename}_fracs_observed{mag_str}.pt')
        torch.save(grouping_fracs_detected, f'{save_dir}/{filename}_fracs_detected{mag_str}.pt')

### PLOTTING HELPERS
def get_aucs(config, use_magnitude=False):
    """Load the saved mean AUCs and AUC standard deviations for ``config``.

    Args:
        config: Experiment configuration; must provide ``invert`` plus the
            keys needed by ``get_save_dir``.
        use_magnitude: Selects the files carrying the ``_mag`` suffix.

    Returns:
        Tuple ``(aucs, auc_stds)`` as returned by ``load_json`` with
        ``convert_keys_to_int=True``.
    """
    result_stem = get_save_dir(config, config['invert'], exp='group_auc')
    # Last path segment is the file-name prefix, the rest is the directory.
    parent_dir, _, filename = result_stem.rpartition('/')
    suffix = '_mag' if use_magnitude else ''

    def _load(tag):
        path = os.path.join(os.getcwd(), f'{parent_dir}/{filename}_{tag}{suffix}.json')
        return load_json(path, convert_keys_to_int=True)

    aucs = _load('aucs')
    auc_stds = _load('aucs_std')
    return aucs, auc_stds

# Base experiment configuration; dataset, model and trial count are filled in
# for each run by the loop below.
config = {
    'dataset_name': '',
    'model_name': '',
    'attributor_name': 'influence_lissa',
    'grouping': 'grad_kmeans',
    'invert': False,
}

# Number of trials, keyed by the dataset name up to its first underscore.
n_trials_dict = {
    'heloc': 10,
    'mnist': 5,
    'cifar10': 5,
    'qnli': 3
}

# Models to report, keyed the same way as n_trials_dict.
model_names_dict = {
    'heloc': ['lr'],#, 'ann_s', 'ann_m'],
    'mnist': ['ann_l'],
    'cifar10': ['resnet18'],
    'qnli': ['bert']
}
display_group_sizes = [1, 4, 16, 64, 256]
dataset_names = ['cifar10_noisy']#, 'mnist_noisy', 'cifar10_noisy', 'qnli_noisy']
for dataset_name in dataset_names:
    # Group sizes for the (currently disabled) AUC computation step.
    if dataset_name.startswith('heloc'):
        save_group_sizes = [1, 4, 16, 64, 256]
    else:
        save_group_sizes = [1, 4, 16, 64, 256, 1024]
    print(dataset_name)
    config['dataset_name'] = dataset_name
    model_names = model_names_dict[dataset_name.partition('_')[0]]
    for model_name in model_names:
        # Point the shared config at this dataset/model combination.
        print(model_name)
        n_trials = n_trials_dict[dataset_name.partition('_')[0]]
        config.update(model_name=model_name, n_trials=n_trials)

        # Uncomment to compute and save AUC results:
        # save_auc_results(config, use_magnitude=True, group_sizes=save_group_sizes)
        aucs, auc_stds = get_aucs(config, use_magnitude=False)
        performances = [aucs[size] for size in display_group_sizes]
        stds = [auc_stds[size] for size in display_group_sizes]
        runtimes = get_runtimes(config, times='attribution', group_sizes=display_group_sizes)

        # Report runtimes, AUCs and their standard deviations.
        print('Runtimes:', runtimes)
        print('AUCs:', performances)
        print('STDs:', stds)
        print()