import numpy as np
import torch
import time
import argparse
from utils.metrics import AverageDrop, AverageGain, AverageIncrease
from captum.attr import DeepLiftShap, DeepLift
from lime import lime_image
from utils.util_models import (
    resnet50_model_target_layer,
    convnext_base_model_target_layer,
    vgg16_model_target_layer,
)
from utils.util_datasets import get_imagenet

import tqdm
class _TQDM(tqdm.tqdm):
    """tqdm progress bar that is disabled unless explicitly overridden.

    Every instance is created with ``disable=True``. A caller can opt back in
    (or pass any other ``disable`` value) through the private keyword
    ``disable_override``; the string ``'def'`` means "no override given".
    """

    def __init__(self, *argv, **kwargs):
        # Pop (not get) the private flag so it is never forwarded to
        # tqdm.__init__: tqdm raises on unknown keyword arguments once the
        # bar is enabled, which made ``disable_override=False`` unusable.
        override = kwargs.pop('disable_override', 'def')
        kwargs['disable'] = True
        if override != 'def':
            kwargs['disable'] = override
        super().__init__(*argv, **kwargs)


# Rebind the package attribute so that code resolving ``tqdm.tqdm`` from here
# on gets the silenced variant. Modules that already did
# ``from tqdm import tqdm`` keep their own reference to the original class.
tqdm.tqdm = _TQDM

class WrappedModel(torch.nn.Module):
    """Delegate to a wrapped module and hand back a clone of its output.

    The wrapped module is stored as ``self.model`` (and is therefore
    registered as a sub-module). ``forward`` passes all positional and keyword
    arguments through unchanged and returns ``output.clone()``, so the caller
    receives a new tensor rather than the object the inner model produced.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, *args, **kwargs):
        """Run the wrapped model and return a cloned copy of its result."""
        return self.model(*args, **kwargs).clone()


def eval(integration_dataset, val_dataset, model, method, batch_size, nb_batches):
    """Score explanation masks with Average Drop / Increase / Gain.

    For each batch, an explanation mask is computed per image through
    ``method.explain_instance``, clipped to ``[0, 1]`` and multiplied into the
    image. The model's softmax score for the batch label is then read for the
    original and for the masked image, and both are fed to the three metrics.
    The final metric values and the wall-clock time are printed.

    NOTE: the name shadows the builtin ``eval``; it is kept because callers
    use it.

    Args:
        integration_dataset: Unused here; kept for call-site compatibility.
        val_dataset: Iterable yielding ``(img, labels)`` tensor batches.
            Images are normalised inside this function with the constants
            below, so they are presumably un-normalised, channels-first
            batches -- TODO confirm against the data loader.
        model: Module mapping a normalised image batch to class logits. It
            is assumed to already live on the device selected below.
        method: Explainer exposing ``explain_instance`` with the keyword
            arguments used below (the LIME image explainer interface).
        batch_size: Unused here; kept for call-site compatibility.
        nb_batches: Index of the last batch to process (see the note at the
            end of the loop).

    Returns:
        None. Results are printed to stdout.
    """
    model.eval()
    ad, ai, ag = AverageDrop(), AverageIncrease(), AverageGain()

    start_time = time.time()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Per-channel normalisation constants, shaped (1, 3, 1, 1) to broadcast
    # over channels-first batches.
    means = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
    stds = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)

    def pred(img):
        # The explainer passes channels-last numpy batches; the model wants
        # channels-first tensors. Use the selected device instead of a
        # hard-coded 'cuda' so the script also runs on CPU-only hosts.
        img = torch.from_numpy(img).permute(0, 3, 1, 2).to(device).float()
        with torch.no_grad():
            return model(img).cpu().detach().numpy()

    for nb_batch, (img, labels) in enumerate(val_dataset):
        img = img.to(device)
        labels = labels.to(device)

        # The explainer perturbs the already-normalised image, so ``pred``
        # can feed its samples to the model without further preprocessing.
        tmp_img = (img - means) / stds
        cams = []
        for i, x in enumerate(tmp_img):
            numpy_img = x.permute(1, 2, 0).cpu().detach().numpy()
            explanation = method.explain_instance(
                numpy_img,
                pred,
                top_labels=1,
                hide_color=0,
                num_samples=100,
                batch_size=1,
                labels=labels[i].detach().cpu().numpy(),
            )
            _, cam = explanation.get_image_and_mask(
                explanation.top_labels[0],
                positive_only=False,
                num_features=10,
                hide_rest=False,
            )
            # Add two leading axes so the per-image masks stack along the
            # batch axis with a singleton channel axis.
            cams.append(cam[np.newaxis, np.newaxis])

        cam = np.clip(np.vstack(cams), 0, 1)
        cam = torch.from_numpy(cam).type_as(img)

        masked_img = (img * cam).type_as(img)

        # One forward pass over [originals ; masked] halves the model calls.
        combin = torch.concat((img, masked_img), dim=0).type_as(img)
        combin = (combin - means) / stds
        with torch.no_grad():
            output = torch.softmax(model(combin), dim=-1).cpu()

        # First half: scores on the original images; second half: masked.
        mid_point = len(output) // 2
        targets, preds = output[:mid_point], output[mid_point:]
        try:
            pairs = [[targets[i, label].item(), preds[i, label].item()]
                     for i, label in enumerate(labels)]
            scores = torch.from_numpy(np.array(pairs))
            targets, preds = scores[:, 0], scores[:, 1]

            ad.update(preds, targets)
            ai.update(preds, targets)
            ag.update(preds, targets)
        except Exception as err:
            # Best effort: a failing batch is skipped, but the reason is now
            # reported and Ctrl-C / SystemExit are no longer swallowed.
            print('skip')
            print(targets.shape, preds.shape, labels.shape)
            print(f'{type(err).__name__}: {err}')

        # NOTE(review): the check runs after batch index ``nb_batches`` has
        # been processed, so ``nb_batches + 1`` batches are evaluated. Left
        # unchanged to keep reported numbers comparable -- confirm intent.
        if nb_batch == nb_batches:
            break
    ad, ai, ag = ad.compute(), ai.compute(), ag.compute()
    print(f'Average Drop : {ad} Average Increase : {ai} Average Gain : {ag}')

    print(f"Execution time : {time.time()-start_time}")


def main():
    """Parse the command line, build model and data, and run ``eval``."""
    models = {
        'resnet50': resnet50_model_target_layer,
        'vgg16': vgg16_model_target_layer,
        'convnext_b': convnext_base_model_target_layer,
    }
    methods = {
        'lift': DeepLift,
        'shap': DeepLiftShap,
        'lime': lime_image.LimeImageExplainer()
    }

    parser = argparse.ArgumentParser(description='CAM evaluation.')
    # ``choices`` replaces the former ``assert`` (stripped under ``python -O``)
    # and the late ``KeyError`` for an unknown method.
    parser.add_argument('--method', type=str, default='lime',
                        choices=list(methods),
                        help='Attribution method to evaluate')
    parser.add_argument('--model', type=str, default='vgg16',
                        choices=list(models),
                        help='Model architecture to evaluate')
    parser.add_argument('--batch_size', type=int, default=2, help='Batch size')
    parser.add_argument('--nb_batches', type=int, default=2,
                        help='Index of the last validation batch to evaluate '
                             '(evaluation stops after this batch)')
    # Not read in this file; presumably consumed by get_imagenet(args).
    parser.add_argument('--image_size', type=int,
                        default=224, help='Image size')
    args = parser.parse_args()

    # eval() calls ``method.explain_instance``; fail fast with a clear message
    # instead of an AttributeError after the dataset and model are loaded.
    method = methods[args.method]
    if not hasattr(method, 'explain_instance'):
        parser.error(
            f"method '{args.method}' is not supported by this script: "
            "the explainer must provide explain_instance()")

    # Loading the dataset
    integration, dataset = get_imagenet(args)

    model, _ = models[args.model]()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    integration = integration.to(device)
    print(len(integration), len(dataset))

    eval(integration_dataset=integration,
         val_dataset=dataset,
         method=method,
         model=WrappedModel(model),
         batch_size=args.batch_size,
         nb_batches=args.nb_batches
         )


if __name__ == '__main__':
    main()
