import argparse
import torch
import torchvision
import numpy as np
from tqdm import tqdm
import os
import pickle
import sys
sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))

def _resolve_taylor_scorer(model, network, prefer_model_first=False):
    """Return the bound ``_compute_taylor_scores`` method to use for *model*.

    Candidates are probed by attribute lookup only, so an ``AttributeError``
    raised *inside* the scoring method is no longer mistaken for "method
    missing" and silently routed to a different object.

    Args:
        model: Object that (directly, via ``model.module``, or via its
            ``model`` submodule) exposes ``_compute_taylor_scores``.
        network: Architecture name; ``'resnet50'`` and ``'trans'`` additionally
            probe the ``model`` entry of ``model._modules``.
        prefer_model_first: Probe ``model`` before ``model.module`` instead of
            the other way round.

    Raises:
        AttributeError: If none of the candidates has the method.
    """
    wrapped = getattr(model, 'module', None)
    candidates = [model, wrapped] if prefer_model_first else [wrapped, model]
    if network in ('resnet50', 'trans'):
        submodules = getattr(model, '_modules', None) or {}
        candidates.append(submodules.get('model'))
    for candidate in candidates:
        scorer = getattr(candidate, '_compute_taylor_scores', None)
        if scorer is not None:
            return scorer
    raise AttributeError(
        f"No '_compute_taylor_scores' method found on model for network {network!r}"
    )


def compute_shap(model, network, attack, save_path, nat_loader, att_loader, eps=0, num_classes=10, batch_size=1):
    """Average the per-sample Taylor/SHAP score vectors for each predicted class.

    For every batch of ``nat_loader`` the (optionally attacked) input is fed
    through ``model``; the arg-max of the first output row selects the class
    whose list receives ``_compute_taylor_scores(...)[0][0]``. Afterwards the
    collected vectors are averaged per class.

    Args:
        model: Callable network exposing ``_compute_taylor_scores`` (directly
            or through a wrapper, see ``_resolve_taylor_scorer``).
        network: One of ``'densenet121'``, ``'resnet50'``, ``'wrn28-10'``,
            ``'wrn34-10'``, ``'vgg'``, ``'trans'``; determines the feature
            dimension.
        attack: ``'no'`` scores the natural images; any other value scores
            ``att_loader(image, labels)`` when ``eps != 0``.
        save_path: If not ``None``, the result is pickled to this path.
        nat_loader: Iterable of ``(image, labels)`` batches.
        att_loader: Callable producing adversarial inputs from
            ``(image, labels)``; only used when attacking with ``eps != 0``.
        eps: Attack budget; ``0`` disables the attack even if ``attack`` is set.
        num_classes: Number of rows of the result.
        batch_size: ``1`` accumulates into a ``(featdim,)`` vector, otherwise
            into a ``(batch_size, featdim)`` matrix that is summed over axis 0.

    Returns:
        ``numpy.ndarray`` of shape ``(num_classes, featdim)``. Rows of classes
        that were never predicted are NaN.

    Raises:
        ValueError: If ``network`` is not a supported architecture name.
        AttributeError: If no ``_compute_taylor_scores`` method can be found.
    """
    if network == 'densenet121':
        featdim = 1024
    elif network == 'resnet50':
        featdim = 2048
    elif network in ('wrn28-10', 'wrn34-10'):
        featdim = 640
    elif network == 'vgg':
        featdim = model.classifier.in_features
    elif network == 'trans':
        featdim = model.model.num_features
    else:
        # Previously fell through and crashed later with an unbound 'featdim'.
        raise ValueError(f'Unsupported network: {network!r}')

    attacked = attack != 'no'
    # The attacked path historically called the method on `model` itself, so
    # keep that preference there and only fall back to the wrappers.
    score_fn = _resolve_taylor_scorer(model, network, prefer_model_first=attacked)

    shap_class = np.zeros((num_classes, featdim))  # e.g. resnet50 on ImageNet: (1000, 2048)
    shap = [[] for _ in range(num_classes)]

    for ind, (image, labels) in tqdm(enumerate(nat_loader), desc='Extracting SHAP'):
        image = image.cuda()
        inputs = att_loader(image, labels) if attacked and eps != 0 else image

        logits = model(inputs)
        # Only the first row of the output decides the class for this batch.
        predict = np.argmax(logits[0].cpu().detach().numpy())
        shap_batch = score_fn(inputs, predict)
        shap[predict].append(shap_batch[0][0].squeeze().cpu().detach().numpy())

        if ind % 1000 == 0:
            if num_classes == 1000:
                print(f'{ind}/{len(nat_loader.dataset.samples)}')
            elif num_classes == 10:
                print(f'{ind}/{nat_loader.dataset.data.shape[0]}')

    for c, class_scores in enumerate(shap):
        if not class_scores:
            # Same value the former 0/0 division produced, minus the
            # RuntimeWarning: no sample was predicted as this class.
            shap_class[c, :] = np.nan
            continue

        if batch_size == 1:
            shap_temp = np.zeros(featdim)
        else:
            shap_temp = np.zeros((batch_size, featdim))

        for scores in class_scores:
            if shap_temp.shape != scores.shape:
                # Zero-pad a short (presumably final) batch up to batch_size rows.
                missing_rows = shap_temp.shape[0] - scores.shape[0]
                pad = np.zeros((missing_rows, featdim))
                scores = np.concatenate((scores, pad), axis=0)
            shap_temp += scores.squeeze()

        if shap_temp.ndim == 2:
            shap_temp = shap_temp.sum(axis=0)  # collapse to one value per feature
        shap_class[c, :] = shap_temp / len(class_scores)

    if save_path is not None:
        with open(save_path, 'wb') as f:
            pickle.dump(shap_class, f)

    return shap_class
        
def _train_batch_taylor_scorer(model, network):
    """Return the bound ``_compute_taylor_scores`` method for a training model.

    Probes ``model.module`` first, then ``model`` itself, and for
    ``'resnet50'`` / ``'trans'`` also the ``model`` entry of
    ``model._modules``. Lookup is by attribute presence, so an
    ``AttributeError`` raised inside the scoring method propagates unchanged.

    Raises:
        AttributeError: If none of the candidates has the method.
    """
    candidates = [getattr(model, 'module', None), model]
    if network in ('resnet50', 'trans'):
        submodules = getattr(model, '_modules', None) or {}
        candidates.append(submodules.get('model'))
    for candidate in candidates:
        scorer = getattr(candidate, '_compute_taylor_scores', None)
        if scorer is not None:
            return scorer
    raise AttributeError(
        f"No '_compute_taylor_scores' method found on model for network {network!r}"
    )


def compute_shap_for_train_batch(model, network, images, prediction, batch_idx):
    """Return the std of the batch-averaged Taylor/SHAP score vector.

    Each image is scored on its own against the arg-max of its row in
    ``prediction``; the per-image vectors are averaged over the batch and the
    standard deviation of that mean vector is returned.

    The model is put into eval mode while scoring and its previous
    train/eval mode is restored afterwards, even if scoring raises.

    Args:
        model: Network exposing ``_compute_taylor_scores`` (directly or via a
            wrapper) plus the usual ``eval``/``train``/``zero_grad`` methods.
        network: One of ``'densenet121'``, ``'resnet50'``, ``'wrn28-10'``,
            ``'wrn34-10'``, ``'vgg'``, ``'trans'``.
        images: Batch tensor; ``images[i]`` is moved to CUDA and scored.
        prediction: Per-image model outputs; ``prediction[i]`` is arg-maxed.
        batch_idx: Unused; kept for call-site compatibility.

    Returns:
        Zero-dimensional tensor holding the standard deviation.

    Raises:
        ValueError: If ``network`` is not a supported architecture name.
        AttributeError: If no ``_compute_taylor_scores`` method can be found.
    """
    if network == 'densenet121':
        featdim = 1024
    elif network == 'resnet50':
        featdim = 2048
    elif network in ('wrn28-10', 'wrn34-10'):
        featdim = 640
    elif network == 'vgg':
        featdim = model.classifier.in_features
    elif network == 'trans':
        # Mirrors compute_shap, which already supports this architecture.
        featdim = model.model.num_features
    else:
        # Previously fell through and crashed later with an unbound 'featdim'.
        raise ValueError(f'Unsupported network: {network!r}')

    # Resolve once up front: a missing method now fails loudly instead of
    # printing and then reusing a stale (or unbound) score from an earlier image.
    score_fn = _train_batch_taylor_scorer(model, network)

    was_training = getattr(model, 'training', True)
    model.eval()
    try:
        batch_size = images.shape[0]
        shap = torch.zeros((batch_size, featdim))

        for i in range(batch_size):
            image = images[i].unsqueeze(0).cuda()
            predict = np.argmax(prediction[i].cpu().detach().numpy())

            shap_batch = score_fn(image, predict)

            # Clear the model's .grad buffers after each scoring call.
            model.zero_grad()
            shap[i] = shap_batch[0][0].squeeze()

        batch_mean_shap = torch.mean(shap, dim=0)
        return batch_mean_shap.std()
    finally:
        model.train(was_training)
