import torch
# import torchvision
# import torch.nn as nn
# import torch.optim as optim
# from torch.optim import lr_scheduler
# import torchvision.datasets as datasets
# import torch.utils.data as data
import torchvision.transforms as transforms
# from torch.autograd import Variable
# from torch.utils.data import Dataset, DataLoader
# import torchvision.models as models
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from captum.attr import Saliency, IntegratedGradients, LayerGradCam
from captum.attr import NoiseTunnel
import matplotlib.pyplot as plt
import time, os, copy, numpy as np
from sklearn.metrics import confusion_matrix
from tqdm import tqdm
import sys

from utils_RISE import *
from RISE import RISE

from pdb import set_trace as bp

def get_annt_transform(shape):
    """Return the transform used to bring an annotation map to ``shape``.

    The pipeline first resizes to ``shape`` and then center-crops a square
    whose side length is ``shape[0]``.

    Args:
        shape: Two-element size sequence handed to ``transforms.Resize``.

    Returns:
        A ``transforms.Compose`` applying the resize and then the crop.
    """
    assert len(shape) == 2
    # NOTE(review): for a non-square ``shape`` the square crop below yields
    # (shape[0], shape[0]) rather than ``shape`` -- confirm this is intended.
    resize_step = transforms.Resize(shape)
    crop_step = transforms.CenterCrop(shape[0])
    return transforms.Compose([resize_step, crop_step])

def score_calc(model, xai_dataloader, xai_dataset_size, save_explanations_path, expl_thr, device_id):
    """Run ``model`` over an explanation dataset and collect confidence scores.

    For every correctly classified sample, the softmax confidence and the true
    class are recorded. Misclassified samples are only counted. The
    explanation/IoU pipeline (Grad-CAM, RISE, mask thresholding) is kept below
    as commented-out code and is not executed.

    Args:
        model: Classifier called as ``model(inputs)``. Its output is softmaxed
            over dim 1. It is switched to eval mode once before the loop.
        xai_dataloader: Iterable yielding ``(inputs, labels, annts)``. Each
            batch must hold exactly one sample, which is enforced by an assert
            on the prediction batch size.
        xai_dataset_size: Currently unused.
        save_explanations_path: Directory for saved explanations. It is
            created, including any missing parents, if it does not exist.
        expl_thr: Explanation threshold, referenced only by the commented-out
            IoU code. Currently unused.
        device_id: Device used when CUDA is available. Referenced only by the
            commented-out code paths.

    Returns:
        dict with the following keys:
        - ``'scores'``: softmax confidence for each correctly classified sample.
        - ``'classes'``: the labels of those samples.
        - ``'ious'`` and ``'ious_per_class'``: filled only by the commented-out
          IoU code, so they stay empty as the code stands.
        - ``'miss_count'``: number of misclassified samples.
    """
    # Only referenced by the commented-out code paths below.
    device = torch.device(device_id if torch.cuda.is_available() else "cpu")
    # makedirs(exist_ok=True) also creates missing parents and avoids the
    # race between an existence check and mkdir.
    os.makedirs(save_explanations_path, exist_ok=True)
    print('...Evaluating explanations...')

    miss_count = 0
    n_seen = 0  # number of batches processed; 0 for an empty dataloader
    store_ious = []
    store_ious_cls = {x: [] for x in range(21)}
    store_scores = []
    store_cls = []

    # Evaluation mode is loop-invariant, so set it once up front.
    model.eval()

    # Iterate over data.
    for i, (inputs, labels, annts) in enumerate(xai_dataloader):
        n_seen = i + 1
        if i % 100 == 0:
            print('iter:', i)
        # inputs = inputs.to(device)
        # labels = labels.to(device)

        # weights = torch.tensor([torch.count_nonzero(labels==i)  for i in range(21)]).to(device)
        # weights = weights/weights.sum()

        # Only the class scores are needed here. The commented-out explainers
        # below are given `inputs`, not `outputs`, so no autograd graph is kept.
        with torch.no_grad():
            outputs = model(inputs).softmax(dim=1)
        score, preds = torch.max(outputs, 1)

        assert preds.shape[0] == 1

        if preds[0].item() == labels[0].item():
            # int_grad = LayerGradCam(model, model.layer4[-1])
            # # vanilla_gradient = Saliency(model)
            # # noise_tunnel = NoiseTunnel(vanilla_gradient)
            # # attributions_ig = int_grad.attribute(inputs, nt_samples=30, nt_type='smoothgrad', target=labels)
            # # attributions_ig = noise_tunnel.attribute(img, nt_samples=30, nt_type='smoothgrad', target=labels, abs=False)
            # attributions_ig = int_grad.attribute(inputs, target=labels)
            # upsample = torch.nn.Upsample(size=(inputs.shape[-2], inputs.shape[-1]), mode='bilinear')
            # attributions_ig = upsample(attributions_ig).squeeze(0)
            # # attributions_ig = attributions_ig.mean(axis=1)

            # #......alternate implementation.........
            # cam = GradCAMPlusPlus(model=model, target_layers=[model.layer4[-1]], use_cuda=True)
            # attributions_ig = cam(input_tensor=inputs, targets=[ClassifierOutputTarget(labels.item())])
            # attributions_ig = torch.from_numpy(attributions_ig)
            # #.......................................

            print(f'score: {score.item()}')
            store_scores.append(score.item())
            store_cls.append(labels[0].item())

            # # Generate masks for RISE or use the saved ones.
            # explainer = RISE(model, (inputs.shape[-2], inputs.shape[-1]), device_id = device_id, gpu_batch=32)

            # maskspath = 'utilities/masks.npy'
            # generate_new = True
            # p1_mask = 0.1
            # if generate_new or not os.path.isfile(maskspath):
            #     explainer.generate_masks(N=3000, s=8, p1=p1_mask, savepath=maskspath)
            # else:
            #     explainer.load_masks(maskspath, p1_mask)
            #     print('Masks are loaded.')

            # saliency = explainer(inputs).cpu().numpy()
            # attributions_ig = saliency[labels[0].item()]
            # attributions_ig = (attributions_ig-attributions_ig.min())/(attributions_ig.max()-attributions_ig.min())

            # save_expl_cls_path = os.path.join(save_explanations_path, str(labels[0].item()))
            # if not os.path.exists(save_expl_cls_path): os.mkdir(save_expl_cls_path)
            # np.save(os.path.join(save_expl_cls_path, 'mask_'+str(i)+'.npy'), attributions_ig)

            # attributions_ig = torch.from_numpy(attributions_ig).unsqueeze(0)
            
            # mask = torch.zeros(attributions_ig.shape)
            # attributions_ig = (attributions_ig - attributions_ig.min())/(attributions_ig.max()-attributions_ig.min())
            # if expl_thr == 'mean':
            #     mask[attributions_ig>attributions_ig.mean()] = 1
            # else:
            #     mask[attributions_ig>expl_thr] = 1
            # mask = mask.squeeze(0)
            # assert len(mask.shape) == 2

            # annt_transf = get_annt_transform(mask.shape)
            # annt_mask = torch.zeros(annts.shape)
            # annt_mask[annts==labels[0].item()] = 1
            # annt_mask = annt_transf(annt_mask).squeeze(0)
            # annt_mask[annt_mask>=0.4] = 1
            # annt_mask[annt_mask<0.4] = 0

            # mask_iou = mask+annt_mask
            # iou = torch.sum(mask_iou==2)/torch.sum(mask_iou>=1)
            # print(f'IoU: {iou}')
            # store_ious.append(iou)
            # store_ious_cls[labels[0].item()].append(iou)
            # sys.stdout.flush()
            # sys.stderr.flush()
        else:
            miss_count += 1

    # Results are returned, not inspected through a debugger breakpoint, so the
    # function can run unattended. Guard against an empty dataloader.
    miss_pct = miss_count * 100 / n_seen if n_seen else 0.0
    print(f'miss classfication percentage that shd not have happened: {miss_pct}')
    return {
        'scores': store_scores,
        'classes': store_cls,
        'ious': store_ious,
        'ious_per_class': store_ious_cls,
        'miss_count': miss_count,
    }

