import warnings
from robustbench.utils import clean_accuracy, update_json
from robustbench.eval import corruptions_evaluation
from robustbench.model_zoo.enums import BenchmarkDataset, ThreatModel
from robustbench.data import CORRUPTIONS_DICT, get_preprocessing, load_clean_dataset
from typing import Callable, Optional, Sequence, Tuple, Union
from pathlib import Path
from torch import nn
from autoattack import AutoAttack
from autoattack.state import EvaluationState

import torch
import json
from tqdm import tqdm
import numpy as np
import random 
def _shap_json_default(obj):
    """Fallback serializer for values ``json`` cannot encode natively.

    Tensors and NumPy arrays/scalars are converted to plain Python lists or
    scalars; anything else raises ``TypeError`` exactly like the default
    encoder would.
    """
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu().tolist()
    if isinstance(obj, (np.ndarray, np.generic)):
        return obj.tolist()
    raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')


def compute_shap_class_acc_tracking(model, attack, save_path, nat_loader, att_loader, eps=0, num_classes = 10, seed = 0):
    """Evaluate total and per-class accuracy while tracking per-batch attack statistics.

    Every batch of ``nat_loader`` is moved to CUDA, optionally perturbed by
    ``att_loader``, and classified by ``model``. The resulting summary is
    printed, written to ``save_path`` as JSON and returned.

    :param model: Callable mapping a batch of inputs to per-class scores; the
        prediction is the ``argmax`` over dimension 1.
    :param attack: Label that is only used in the printed accuracy message.
    :param save_path: File the JSON summary is written to.
    :param nat_loader: Iterable yielding ``(inputs, targets)`` batches.
    :param att_loader: ``None`` or a callable ``(inputs, targets) ->
        (adv_x, shap_list, cost_list)``. It is only invoked when ``eps != 0``.
    :param eps: Only compared against 0 here; ``eps == 0`` means the clean
        inputs are evaluated and nothing is appended to the tracking lists.
    :param num_classes: Class labels ``0 .. num_classes - 1`` are reported.
    :param seed: Seed applied to ``torch``, ``torch.cuda``, ``numpy`` and
        ``random`` before evaluation.

    :return: Dict with ``class_<i>`` accuracies in percent (``None`` for a
        class without samples), ``total acc``, ``highest perform`` /
        ``lowest perform`` (tied classes and their score), and the per-batch
        ``shap_list`` / ``cost_list``.
    :raises ValueError: If ``nat_loader`` yields no samples.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # With eps == 0 there is nothing to perturb, so the attack (and its
    # three-value return) is skipped instead of unpacking the clean inputs.
    use_attack = att_loader is not None and eps != 0

    total, total_correct = 0, 0
    # Per-batch chunks are concatenated once at the end rather than growing a
    # tensor with torch.cat on every iteration.
    pred_chunks, label_chunks = [], []
    total_shap_list = []
    total_cost_list = []
    for inputs, targets in tqdm(nat_loader, desc='Evaluation-acc'):
        inputs, targets = inputs.cuda(), targets.cuda()

        # apply attack
        if use_attack:
            adv_x, shap_list, cost_list = att_loader(inputs, targets)
            total_shap_list.append(shap_list)
            total_cost_list.append(cost_list)
        else:
            adv_x = inputs

        predict = model(adv_x).detach().argmax(1)

        pred_chunks.append(predict.cpu())
        label_chunks.append(targets.cpu())

        total += targets.size(0)
        total_correct += (predict == targets).sum().item()

    if total == 0:
        raise ValueError('nat_loader yielded no samples; accuracy is undefined.')

    total_acc = (total_correct / total) * 100
    print(f'Attack_{attack} / Total Accuracy: {total_acc:.2f}%')

    preds = torch.cat(pred_chunks)
    labels = torch.cat(label_chunks)

    results = {}
    class_la = []
    class_ac = []
    for c_label in tqdm(range(num_classes), desc='Calculate class accuracy'):
        in_class = labels == c_label
        n_in_class = int(in_class.sum())
        if n_in_class == 0:
            # No sample of this class was seen: report it as missing and keep
            # it out of the best/worst ranking instead of dividing by zero.
            results[f'class_{c_label}'] = None
            continue
        n_hit = int((preds[in_class] == c_label).sum())
        class_acc = (n_hit / n_in_class) * 100
        results[f'class_{c_label}'] = class_acc
        class_la.append(c_label)
        class_ac.append(class_acc)

    results['total acc'] = total_acc

    # All classes tied for the best / worst score are reported.
    high_acc = max(class_ac) if class_ac else None
    low_acc = min(class_ac) if class_ac else None
    best = [c for c, acc in zip(class_la, class_ac) if acc == high_acc]
    worst = [c for c, acc in zip(class_la, class_ac) if acc == low_acc]
    results['highest perform'] = {'class':best, 'score':high_acc}
    results['lowest perform'] = {'class':worst, 'score':low_acc}

    results['shap_list'] = total_shap_list
    results['cost_list'] = total_cost_list
    print(f'Shap_list: {total_shap_list}')
    print(f'Cost_list: {total_cost_list}')

    # Serialize before opening the file so a failure cannot leave a truncated
    # file behind. The fallback covers att_loader returning tensors/arrays
    # (their exact type is not visible here).
    payload = json.dumps(results, indent=4, default=_shap_json_default)
    with open(save_path, 'w') as f:
        f.write(payload)

    return results

def compute_shap_class_acc(model, attack, save_path, nat_loader, att_loader, eps=0, num_classes = 10, seed = 0):
    """Evaluate total and per-class accuracy and save the summary as JSON.

    This function performed exactly the same steps as ``compute_class_acc``
    (seeding, optional attack via ``att_loader(inputs, targets)``, accuracy
    bookkeeping, JSON dump), so it now delegates to it instead of keeping a
    second copy of the same code.

    :param model: Callable mapping a batch of inputs to per-class scores.
    :param attack: Label that is only used in the printed accuracy message.
    :param save_path: File the JSON summary is written to.
    :param nat_loader: Iterable yielding ``(inputs, targets)`` batches.
    :param att_loader: ``None`` or a callable ``(inputs, targets) -> adv_x``;
        only invoked when ``eps != 0``.
    :param eps: Only compared against 0; ``eps == 0`` evaluates clean inputs.
    :param num_classes: Class labels ``0 .. num_classes - 1`` are reported.
    :param seed: Seed for ``torch``, ``torch.cuda``, ``numpy`` and ``random``.

    :return: The results dict produced by ``compute_class_acc``.
    """
    return compute_class_acc(model, attack, save_path, nat_loader, att_loader,
                             eps=eps, num_classes=num_classes, seed=seed)

def compute_shap_class_acc_two_stage(model, attack, save_path, nat_loader, att_loader, eps=0, num_classes = 10, seed = 0):
    """Evaluate total and per-class accuracy with a score-conditioned attack.

    For every batch, per-sample scores are obtained from the model's
    ``_compute_taylor_scores`` method, their standard deviation over the whole
    batch is computed, and that value is handed to ``att_loader`` as a third
    argument. The summary is printed, written to ``save_path`` as JSON and
    returned.

    :param model: Callable mapping inputs to per-class scores. Either
        ``model.module`` or ``model`` itself must provide
        ``_compute_taylor_scores(sample, pred)``; ``model.module`` is
        preferred when both exist.
    :param attack: Label that is only used in the printed accuracy message.
    :param save_path: File the JSON summary is written to.
    :param nat_loader: Iterable yielding ``(inputs, targets)`` batches.
    :param att_loader: ``None`` or a callable
        ``(inputs, targets, shap_batch_std) -> adv_x``; only invoked when
        ``eps != 0``.
    :param eps: Only compared against 0; ``eps == 0`` evaluates clean inputs.
    :param num_classes: Class labels ``0 .. num_classes - 1`` are reported.
    :param seed: Seed for ``torch``, ``torch.cuda``, ``numpy`` and ``random``.

    :return: Dict with ``class_<i>`` accuracies in percent (``None`` for a
        class without samples), ``total acc`` and ``highest perform`` /
        ``lowest perform`` (tied classes and their score).
    :raises AttributeError: If no ``_compute_taylor_scores`` method is found.
    :raises ValueError: If ``nat_loader`` yields no samples.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Resolve the scoring method once, preferring the wrapped module. A
    # missing method is reported immediately instead of being printed and then
    # failing later on an unbound (or stale) score variable; AttributeErrors
    # raised inside the scorer itself are no longer swallowed either.
    taylor_scores = getattr(getattr(model, 'module', None), '_compute_taylor_scores', None)
    if taylor_scores is None:
        taylor_scores = getattr(model, '_compute_taylor_scores', None)
    if taylor_scores is None:
        raise AttributeError(
            'Neither `model.module` nor `model` provides `_compute_taylor_scores`.')

    use_attack = att_loader is not None and eps != 0

    total, total_correct = 0, 0
    pred_chunks, label_chunks = [], []
    for inputs, targets in tqdm(nat_loader, desc='Evaluation-acc'):
        inputs, targets = inputs.cuda(), targets.cuda()
        if len(inputs) == 0:
            continue

        # Compute SHAP values
        # NOTE(review): each sample is passed without a batch dimension, as in
        # the original code; presumably the model and scorer accept that --
        # confirm against the model implementation.
        # NOTE(review): the scores are computed even when no attack is applied
        # (original behavior kept in case the scorer has side effects).
        shap_batch = None
        for it, sample in enumerate(inputs):
            pred = model(sample).detach().argmax(1)
            flat = taylor_scores(sample, pred)[0][0].squeeze().reshape(-1)
            if shap_batch is None:
                # Width is taken from the first score instead of a hard-coded
                # 640, so models with other feature sizes work as well.
                shap_batch = torch.zeros((len(inputs), flat.numel()))
            shap_batch[it] = flat

        shap_batch_std = shap_batch.std()

        adv_x = att_loader(inputs, targets, shap_batch_std) if use_attack else inputs
        predict = model(adv_x).detach().argmax(1)

        pred_chunks.append(predict.cpu())
        label_chunks.append(targets.cpu())

        total += targets.size(0)
        total_correct += (predict == targets).sum().item()

    if total == 0:
        raise ValueError('nat_loader yielded no samples; accuracy is undefined.')

    total_acc = (total_correct / total) * 100
    print(f'Attack_{attack} / Total Accuracy: {total_acc:.2f}%')

    preds = torch.cat(pred_chunks)
    labels = torch.cat(label_chunks)

    results = {}
    class_la = []
    class_ac = []
    for c_label in tqdm(range(num_classes), desc='Calculate class accuracy'):
        in_class = labels == c_label
        n_in_class = int(in_class.sum())
        if n_in_class == 0:
            # No sample of this class was seen: report it as missing and keep
            # it out of the best/worst ranking instead of dividing by zero.
            results[f'class_{c_label}'] = None
            continue
        n_hit = int((preds[in_class] == c_label).sum())
        class_acc = (n_hit / n_in_class) * 100
        results[f'class_{c_label}'] = class_acc
        class_la.append(c_label)
        class_ac.append(class_acc)

    results['total acc'] = total_acc

    # All classes tied for the best / worst score are reported.
    high_acc = max(class_ac) if class_ac else None
    low_acc = min(class_ac) if class_ac else None
    best = [c for c, acc in zip(class_la, class_ac) if acc == high_acc]
    worst = [c for c, acc in zip(class_la, class_ac) if acc == low_acc]
    results['highest perform'] = {'class':best, 'score':high_acc}
    results['lowest perform'] = {'class':worst, 'score':low_acc}

    with open(save_path, 'w') as f:
        json.dump(results, f, indent=4)

    return results

def compute_class_acc(model, attack, save_path, nat_loader, att_loader, eps=0, num_classes = 10, seed = 0):
    """Evaluate total and per-class accuracy and save the summary as JSON.

    Every batch of ``nat_loader`` is moved to CUDA, optionally perturbed by
    ``att_loader``, and classified by ``model``. The summary is printed,
    written to ``save_path`` and returned.

    :param model: Callable mapping a batch of inputs to per-class scores; the
        prediction is the ``argmax`` over dimension 1.
    :param attack: Label that is only used in the printed accuracy message.
    :param save_path: File the JSON summary is written to.
    :param nat_loader: Iterable yielding ``(inputs, targets)`` batches.
    :param att_loader: ``None`` or a callable ``(inputs, targets) -> adv_x``;
        only invoked when ``eps != 0``.
    :param eps: Only compared against 0; ``eps == 0`` evaluates clean inputs.
    :param num_classes: Class labels ``0 .. num_classes - 1`` are reported.
    :param seed: Seed for ``torch``, ``torch.cuda``, ``numpy`` and ``random``.

    :return: Dict with ``class_<i>`` accuracies in percent (``None`` for a
        class without samples), ``total acc`` and ``highest perform`` /
        ``lowest perform`` (tied classes and their score).
    :raises ValueError: If ``nat_loader`` yields no samples.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    use_attack = att_loader is not None and eps != 0

    total, total_correct = 0, 0
    # Per-batch chunks are concatenated once at the end rather than growing a
    # tensor with torch.cat on every iteration.
    pred_chunks, label_chunks = [], []
    for inputs, targets in tqdm(nat_loader, desc='Evaluation-acc'):
        inputs, targets = inputs.cuda(), targets.cuda()

        # apply attack
        adv_x = att_loader(inputs, targets) if use_attack else inputs
        predict = model(adv_x).detach().argmax(1)

        pred_chunks.append(predict.cpu())
        label_chunks.append(targets.cpu())

        total += targets.size(0)
        total_correct += (predict == targets).sum().item()

    if total == 0:
        raise ValueError('nat_loader yielded no samples; accuracy is undefined.')

    total_acc = (total_correct / total) * 100
    print(f'Attack_{attack} / Total Accuracy: {total_acc:.2f}%')

    preds = torch.cat(pred_chunks)
    labels = torch.cat(label_chunks)

    results = {}
    class_la = []
    class_ac = []
    for c_label in tqdm(range(num_classes), desc='Calculate class accuracy'):
        in_class = labels == c_label
        n_in_class = int(in_class.sum())
        if n_in_class == 0:
            # No sample of this class was seen: report it as missing and keep
            # it out of the best/worst ranking instead of dividing by zero.
            results[f'class_{c_label}'] = None
            continue
        n_hit = int((preds[in_class] == c_label).sum())
        class_acc = (n_hit / n_in_class) * 100
        results[f'class_{c_label}'] = class_acc
        class_la.append(c_label)
        class_ac.append(class_acc)

    results['total acc'] = total_acc

    # All classes tied for the best / worst score are reported.
    high_acc = max(class_ac) if class_ac else None
    low_acc = min(class_ac) if class_ac else None
    best = [c for c, acc in zip(class_la, class_ac) if acc == high_acc]
    worst = [c for c, acc in zip(class_la, class_ac) if acc == low_acc]
    results['highest perform'] = {'class':best, 'score':high_acc}
    results['lowest perform'] = {'class':worst, 'score':low_acc}

    with open(save_path, 'w') as f:
        json.dump(results, f, indent=4)

    return results

def benchmark(
    model: Union[nn.Module, Sequence[nn.Module]],
    n_examples: int = 10000,
    dataset: Union[str, BenchmarkDataset] = BenchmarkDataset.cifar_10,
    threat_model: Union[str, ThreatModel] = ThreatModel.Linf,
    to_disk: bool = False,
    model_name: Optional[str] = None,
    data_dir: str = "./data",
    corruptions_data_dir: Optional[str] = None,
    device: Optional[Union[torch.device, Sequence[torch.device]]] = None,
    batch_size: int = 32,
    eps: Optional[float] = None,
    log_path: Optional[str] = None,
    preprocessing: Optional[Union[str,
                                  Callable]] = None,
    aa_state_path: Optional[Path] = None,
    seed: int = 0) -> Tuple[float, float]:
    """Benchmarks the given model.

    The model is evaluated on clean data and then either against AutoAttack
    (L2 / Linf threat models) or on corrupted data (corruptions /
    corruptions_3d threat models). Passing a sequence of models or devices is
    not implemented yet.

    Argument combinations that can never succeed (missing `eps`, unsupported
    threat model, `to_disk` without `model_name`) are rejected before any data
    is loaded or any evaluation is run.

    :param model: The model to benchmark.
    :param n_examples: The number of examples to use to benchmark the model.
    :param dataset: The dataset to use to benchmark; any value accepted by
    `BenchmarkDataset`.
    :param threat_model: The threat model to use to benchmark; any value
    accepted by `ThreatModel`. L2, Linf, corruptions and corruptions_3d are
    handled.
    :param to_disk: Whether the results must be saved on disk as .json.
    :param model_name: The name of the model to use to save the results. Must be specified if
    `to_disk` is True.
    :param data_dir: The directory where the dataset is or where the dataset must be downloaded.
    :param corruptions_data_dir: The directory of the corrupted data. Falls back to `data_dir`
    if not given.
    :param device: The device to run the computations. Defaults to CPU.
    :param batch_size: The batch size to run the computations. The larger, the faster the
    evaluation.
    :param eps: The epsilon to use for L2 and Linf threat models. Must not be specified for
    corruptions threat model.
    :param log_path: Passed to `AutoAttack` as its `log_path`.
    :param preprocessing: The preprocessing that should be used for ImageNet benchmarking. Should be
    specified if `dataset` is `imagenet`. For the corruptions threat models 'Res224' is used
    for the corrupted data regardless of this argument.
    :param aa_state_path: The path where the AA state will be saved and from where should be
    loaded if it already exists. If `None` no state will be used. When given, the adversarial
    accuracy is the mean of the robust flags stored in that state.
    :param seed: The seed applied to `torch`, `torch.cuda`, `numpy` and `random`.

    :return: A Tuple with the clean accuracy and the accuracy in the given threat model.
    :raises NotImplementedError: If `model` or `device` is a sequence, or the threat model is
    not one of the handled ones.
    :raises ValueError: If `eps` is missing for L2 / Linf, or `to_disk` is True without
    `model_name`.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    if isinstance(model, Sequence) or isinstance(device, Sequence):
        # Multiple models evaluation in parallel not yet implemented
        raise NotImplementedError

    try:
        if model.training:
            warnings.warn(Warning("The given model is *not* in eval mode."))
    except AttributeError:
        warnings.warn(
            Warning(
                "It is not possible to asses if the model is in eval mode"))

    dataset_: BenchmarkDataset = BenchmarkDataset(dataset)
    threat_model_: ThreatModel = ThreatModel(threat_model)

    # Validate the arguments up front: the evaluation below can run for a long
    # time, and failing only afterwards would throw that work away. The checks
    # keep the precedence they had when they were interleaved with the
    # evaluation (threat-model specific errors first, then the `to_disk` one).
    is_lp_threat = threat_model_ in {ThreatModel.Linf, ThreatModel.L2}
    is_corruption_threat = threat_model_ in [ThreatModel.corruptions,
                                             ThreatModel.corruptions_3d]
    if is_lp_threat and eps is None:
        raise ValueError(
            "If the threat model is L2 or Linf, `eps` must be specified.")
    if not (is_lp_threat or is_corruption_threat):
        raise NotImplementedError
    if to_disk and model_name is None:
        raise ValueError(
            "If `to_disk` is True, `model_name` should be specified.")

    device = device or torch.device("cpu")
    model = model.to(device)

    prepr = get_preprocessing(dataset_, threat_model_, model_name,
                              preprocessing)

    clean_x_test, clean_y_test = load_clean_dataset(dataset_, n_examples,
                                                    data_dir, prepr)

    accuracy = clean_accuracy(model,
                              clean_x_test,
                              clean_y_test,
                              batch_size=batch_size,
                              device=device)
    print(f'Clean accuracy: {accuracy:.2%}')

    extra_metrics = {}  # dict to store corruptions_mce for corruptions threat models
    if is_lp_threat:
        adversary = AutoAttack(model,
                               norm=threat_model_.value,
                               eps=eps,
                               version='standard',
                               device=device,
                               log_path=log_path)
        x_adv = adversary.run_standard_evaluation(clean_x_test,
                                                  clean_y_test,
                                                  bs=batch_size,
                                                  state_path=aa_state_path)
        if aa_state_path is None:
            adv_accuracy = clean_accuracy(model,
                                          x_adv,
                                          clean_y_test,
                                          batch_size=batch_size,
                                          device=device)
        else:
            # With a state file, the robust accuracy is read back from the
            # flags stored in it instead of re-evaluating `x_adv`.
            aa_state = EvaluationState.from_disk(aa_state_path)
            assert aa_state.robust_flags is not None
            adv_accuracy = aa_state.robust_flags.mean().item()
    else:
        corruptions = CORRUPTIONS_DICT[dataset_][threat_model_]
        print(f"Evaluating over {len(corruptions)} corruptions")
        # Exceptionally, for corruptions (2d and 3d) we use only resizing to 224x224
        prepr = get_preprocessing(dataset_, threat_model_, model_name,
                                  'Res224')
        # Save into a dict to make a Pandas DF with nested index
        corruptions_data_dir = corruptions_data_dir or data_dir
        adv_accuracy, adv_mce = corruptions_evaluation(
            batch_size, corruptions_data_dir, dataset_, threat_model_,
            device, model, n_examples, to_disk, prepr, model_name)

        extra_metrics['corruptions_mce'] = adv_mce
    print(f'Adversarial accuracy: {adv_accuracy:.2%}')

    if to_disk:
        update_json(dataset_, threat_model_, model_name, accuracy,
                    adv_accuracy, eps, extra_metrics)

    return accuracy, adv_accuracy
