import os, glob
from tqdm import tqdm
import numpy as np
import argparse
import sys
sys.path.append(os.getcwd())

import torch

# Import saliency methods and models
from dataset.expl_hdf5 import ImagenetResults
from baselines.models import load_baseline_models

from src.utilities import normalize


def setup_perturbations(config, base_vit='vit'):
    """Prepare directories, dataset and model, then run the perturbation test.

    Reads the keys ``method``, ``neg``, ``vis_class``, ``class_id``,
    ``is_ablation``, ``wrong``, ``attacked`` and ``noise_level`` from
    ``config`` and writes ``runs_dir``, ``experiment_dir`` and ``device``
    back into it (``config`` is mutated in place).

    Args:
        config: dict of run options (e.g. ``vars(args)`` from the CLI).
        base_vit: identifier forwarded to ``load_baseline_models``.

    Returns:
        Whatever ``run_perturbations`` returns (the AUC of the accuracy curve).
    """
    # PATH variables
    PATH = os.getcwd()

    os.makedirs(os.path.join(PATH, 'experiments'), exist_ok=True)
    os.makedirs(os.path.join(PATH, 'experiments/perturbations'), exist_ok=True)

    exp_name = config['method']
    exp_name += '_neg' if config['neg'] else '_pos'
    print(exp_name)

    if config['vis_class'] == 'index':
        config['runs_dir'] = os.path.join(
            PATH,
            'experiments/perturbations/{}/{}_{}'.format(exp_name, config['vis_class'], config['class_id']))
    else:
        ablation_fold = 'ablation' if config['is_ablation'] else 'not_ablation'
        config['runs_dir'] = os.path.join(
            PATH,
            'experiments/perturbations/{}/{}/{}'.format(exp_name, config['vis_class'], ablation_fold))

    if config['wrong']:
        config['runs_dir'] += '_wrong'

    # Pick the next free experiment id by comparing ids numerically. Sorting
    # the directory names as strings would rank 'experiment_9' after
    # 'experiment_10' and hand out an id that is already taken.
    existing_ids = []
    for path in glob.glob(os.path.join(config['runs_dir'], 'experiment_*')):
        suffix = os.path.basename(path).rsplit('_', 1)[-1]
        if suffix.isdigit():
            existing_ids.append(int(suffix))
    experiment_id = max(existing_ids) + 1 if existing_ids else 0
    config['experiment_dir'] = os.path.join(config['runs_dir'], 'experiment_{}'.format(str(experiment_id)))
    os.makedirs(config['experiment_dir'], exist_ok=True)

    config['device'] = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    attack_type = 'attack_'+str(config['noise_level']) if config['attacked'] else 'no_attack'

    # Directory holding the previously saved visualizations for this method.
    if config['vis_class'] == 'index':
        vis_method_dir = os.path.join(
            PATH,
            'visualizations/{}/{}_{}'.format(config['method'], config['vis_class'], config['class_id']))
    else:
        ablation_fold = 'ablation' if config['is_ablation'] else 'not_ablation'
        vis_method_dir = os.path.join(
            PATH,
            'visualizations/{}/{}/{}/{}'.format(config['method'], config['vis_class'],
                                                ablation_fold, attack_type))
    print(vis_method_dir)
    imagenet_ds = ImagenetResults(vis_method_dir)

    models = load_baseline_models(base_vit, config['device'])

    # NOTE(review): the 'vit' entry is always used, independent of base_vit —
    # confirm this matches the mapping returned by load_baseline_models.
    return run_perturbations(config, models['vit'], imagenet_ds)


def run_perturbations(config, model, imagenet_ds='data_dir/'):
    """Run the pixel-perturbation test and return the accuracy-curve AUC.

    For every batch the model is first evaluated on the unperturbed images.
    Then, for each perturbation step, the top-k locations of the
    visualization (or of its negation when ``config['neg']`` is set) are set
    to zero in every channel and the model is evaluated again.

    Reads ``batch_size``, ``device``, ``scale``, ``wrong``, ``neg`` and
    ``to_save_results`` from ``config`` (plus ``experiment_dir`` when results
    are saved).

    Args:
        config: dict of run options.
        model: callable mapping a normalized image batch to class logits.
        imagenet_ds: dataset yielding ``(data, vis, target)`` triples; it is
            passed to a ``DataLoader`` and must support ``len()``.

    Returns:
        The trapezoidal area under the mean-accuracy curve over the
        perturbation steps. Statistics are taken over the samples that were
        actually perturbed; if there were none they are NaN.

    Raises:
        Exception: if ``config['scale']`` is neither ``'per'`` nor ``'100'``.
    """
    sample_loader = torch.utils.data.DataLoader(
        imagenet_ds,
        batch_size=config['batch_size'],
        num_workers=2,
        shuffle=False)

    device = config['device']
    len_image_data = len(imagenet_ds)
    model_index = 0
    num_correct_model = np.zeros((len_image_data))
    dissimilarity_model = np.zeros((len_image_data))

    if config['scale'] == 'per':
        base_size = 224 * 224
        perturbation_steps = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
    elif config['scale'] == '100':
        base_size = 100
        perturbation_steps = [5, 10, 15, 20, 25, 30, 35, 40, 45]
    else:
        raise Exception('scale not valid')

    # Number of locations removed at each step (loop-invariant).
    topk_sizes = [int(base_size * step) for step in perturbation_steps]
    num_steps = len(perturbation_steps)

    num_correct_pertub = np.zeros((num_steps, len_image_data))
    dissimilarity_pertub = np.zeros((num_steps, len_image_data))
    logit_diff_pertub = np.zeros((num_steps, len_image_data))
    prob_diff_pertub = np.zeros((num_steps, len_image_data))
    perturb_index = 0

    for data, vis, target in tqdm(sample_loader):

        data = data.to(device)
        vis = vis.to(device)
        target = target.to(device)
        norm_data = normalize(data.clone())

        # Compute model accuracy on the unperturbed batch
        pred = model(norm_data)
        probs = torch.softmax(pred, dim=1)
        pred_org_logit, pred_class = pred.data.max(1)
        pred_org_prob = probs.data.max(1)[0]
        tgt_pred = (target == pred_class).type(target.type()).data.cpu().numpy()
        batch_len = len(tgt_pred)
        num_correct_model[model_index:model_index+batch_len] = tgt_pred

        # Log-ratio between the target-class probability and the second
        # highest probability.
        target_probs = torch.gather(probs, 1, target[:, None])[:, 0]
        second_probs = probs.data.topk(2, dim=1)[0][:, 1]
        temp = torch.log(target_probs / second_probs).data.cpu().numpy()
        dissimilarity_model[model_index:model_index+batch_len] = temp

        # The per-model arrays cover every sample, so advance by the full
        # batch here, before any filtering or early `continue` below.
        model_index += batch_len

        if config['wrong']:
            # Keep only the samples the model misclassified.
            wid = np.argwhere(tgt_pred == 0).flatten()
            if len(wid) == 0:
                continue
            wid = torch.from_numpy(wid).to(vis.device)
            vis = vis.index_select(0, wid)
            data = data.index_select(0, wid)
            target = target.index_select(0, wid)
            # The reference predictions must be filtered the same way so the
            # differences computed below stay aligned sample by sample.
            pred_org_logit = pred_org_logit.index_select(0, wid)
            pred_org_prob = pred_org_prob.index_select(0, wid)

        # Save original shape
        org_shape = data.shape

        if config['neg']:
            vis = -vis

        # Flatten the visualization per sample; the resulting indices address
        # the flattened trailing dimensions of `data` (assumes vis holds one
        # value per spatial location — TODO confirm against the dataset).
        vis = vis.reshape(org_shape[0], -1)

        for i, k in enumerate(topk_sizes):
            _data = data.clone()

            # Zero the k highest-scoring locations in every channel.
            _, idx = torch.topk(vis, k, dim=-1)
            idx = idx.unsqueeze(1).repeat(1, org_shape[1], 1)
            _data = _data.reshape(org_shape[0], org_shape[1], -1)
            _data = _data.scatter_(-1, idx, 0)
            _data = _data.reshape(*org_shape)

            _norm_data = normalize(_data)

            out = model(_norm_data)
            probs_pertub = torch.softmax(out, dim=1)

            pred_prob = probs_pertub.data.max(1)[0]
            diff = (pred_prob - pred_org_prob).data.cpu().numpy()
            prob_diff_pertub[i, perturb_index:perturb_index+len(diff)] = diff

            pred_logit, target_class = out.data.max(1)
            diff = (pred_logit - pred_org_logit).data.cpu().numpy()
            logit_diff_pertub[i, perturb_index:perturb_index+len(diff)] = diff

            temp = (target == target_class).type(target.type()).data.cpu().numpy()
            num_correct_pertub[i, perturb_index:perturb_index+len(temp)] = temp

            target_probs = torch.gather(probs_pertub, 1, target[:, None])[:, 0]
            second_probs = probs_pertub.data.topk(2, dim=1)[0][:, 1]
            temp = torch.log(target_probs / second_probs).data.cpu().numpy()
            dissimilarity_pertub[i, perturb_index:perturb_index+len(temp)] = temp

        perturb_index += len(target)

    # Only the first `perturb_index` columns were filled (fewer than the
    # dataset size when running on wrong predictions only).
    num_correct_pertub = num_correct_pertub[:, :perturb_index]
    dissimilarity_pertub = dissimilarity_pertub[:, :perturb_index]
    logit_diff_pertub = logit_diff_pertub[:, :perturb_index]
    prob_diff_pertub = prob_diff_pertub[:, :perturb_index]

    if config['to_save_results']:
        np.save(os.path.join(config['experiment_dir'], 'model_hits.npy'), num_correct_model)
        np.save(os.path.join(config['experiment_dir'], 'model_dissimilarities.npy'), dissimilarity_model)
        np.save(os.path.join(config['experiment_dir'], 'perturbations_hits.npy'), num_correct_pertub)
        np.save(os.path.join(config['experiment_dir'], 'perturbations_dissimilarities.npy'), dissimilarity_pertub)
        np.save(os.path.join(config['experiment_dir'], 'perturbations_logit_diff.npy'), logit_diff_pertub)
        np.save(os.path.join(config['experiment_dir'], 'perturbations_prob_diff.npy'), prob_diff_pertub)

    mean_acc = np.mean(num_correct_pertub, axis=1)  # mean accuracy per perturbation step

    # np.trapz was renamed to np.trapezoid in NumPy 2.0; support both.
    trapezoid = np.trapezoid if hasattr(np, 'trapezoid') else np.trapz
    auc = trapezoid(mean_acc, perturbation_steps)
    print("Model Accuracy")
    print(np.mean(num_correct_model), np.std(num_correct_model))
    print("Model Dissimilarity")
    print(np.mean(dissimilarity_model), np.std(dissimilarity_model))
    # print(perturbation_steps)
    print("Mean and Std Stats for every perturbation step")
    print("Model Accuracy")
    print(mean_acc, np.std(num_correct_pertub, axis=1))
    print("Model Dissimilarity")
    print(np.mean(dissimilarity_pertub, axis=1), np.std(dissimilarity_pertub, axis=1))
    print("AUC")
    print(auc)
    return auc


if __name__ == "__main__":
    def _str2bool(value):
        """Parse a command-line boolean ('true'/'false', '1'/'0', 'yes'/'no').

        ``type=bool`` must not be used with argparse: ``bool('False')`` is
        True, so every non-empty value would enable the option.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError('Boolean value expected, got {!r}'.format(value))

    parser = argparse.ArgumentParser(description='Run the perturbation evaluation')
    parser.add_argument('--batch-size', type=int, default=16, help='Determine Batch size')
    parser.add_argument('--neg', type=_str2bool, default=True, help='To run negative perturbations')
    parser.add_argument('--scale', type=str, default='per', choices=['per', '100'], help='scale of perturbations')
    parser.add_argument('--method', type=str, default='rollout', help='Choose method to run',
                        choices=['rollout', 'lrp', 'transformer_attribution', 'full_lrp', 'v_gradcam', 'lrp_last_layer',
                                 'lrp_second_layer', 'gradcam', 'attn_last_layer', 'attn_gradcam', 'input_grads']
                        )
    parser.add_argument('--vis-class', type=str, default='top',
                        choices=['top', 'target', 'index'], help='Choose type of visualization')
    parser.add_argument('--wrong', action='store_true', default=False, help='To run on wrong predictions')
    parser.add_argument('--to-save-results', action='store_true', default=False, help='To save results')
    parser.add_argument('--class-id', type=int, default=0, help='Class id for index visualization')
    parser.add_argument('--is-ablation', type=_str2bool, default=False, help='')
    parser.add_argument('--attacked', action='store_true', default=False, help='')
    parser.add_argument('--noise-level', type=int, default=4, help='')
    parser.add_argument('--base-vit', type=str, default='vit', help='')

    # Unused arguments (must be registered before parse_args() to be
    # accepted on the command line at all).
    parser.add_argument('--value', action='store_true', default=False, help='')

    args = parser.parse_args()

    torch.multiprocessing.set_start_method('spawn')

    # setup_perturbations expects a plain dict and mutates it in place.
    args = vars(args)

    setup_perturbations(args, args['base_vit'])
