import torch
import torchvision.transforms as transforms

from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from captum.attr import Saliency, IntegratedGradients, LayerGradCam
from captum.attr import NoiseTunnel
import matplotlib.pyplot as plt
import time, os, copy, numpy as np
from sklearn.metrics import confusion_matrix
from skimage.transform import rescale, resize
from tqdm import tqdm
import sys
sys.path.append(os.path.abspath(os.path.join('..', '..', '..', 'mask_learning_V5_kl_div_lowscale_fast_modular')))
import explain_VOC as DAME

from pdb import set_trace as bp

def get_annt_transform(shape):
    """Build the transform that maps an annotation tensor onto a mask grid.

    Args:
        shape: target ``(height, width)`` of the explanation mask (any
            2-element sequence, e.g. a ``torch.Size``).

    Returns:
        A ``transforms.Compose`` that resizes the annotation to ``shape`` and
        center-crops it to that same ``(height, width)``.
    """
    assert len(shape) == 2
    target = tuple(shape)
    return transforms.Compose([
        transforms.Resize(target),
        # Crop to the full (h, w) target. Cropping to shape[0] alone yields an
        # (h, h) square, which no longer matches a non-square mask of `shape`.
        transforms.CenterCrop(target),
    ])

def _minmax_normalize(x):
    """Linearly rescale ``x`` (NumPy array or torch tensor) to [0, 1].

    A constant input (max == min) is mapped to all zeros instead of producing
    NaNs from a 0/0 division.
    """
    lo, hi = x.min(), x.max()
    span = hi - lo
    if span == 0:
        return x - lo
    return (x - lo) / span


def _binary_iou(pred_mask, gt_mask):
    """Return the intersection-over-union of two {0, 1} masks as a 0-d tensor.

    NOTE(review): if both masks are empty the union is 0 and the result is NaN,
    which then propagates into the aggregate ``np.mean``; this matches the
    original metric and is kept deliberately.
    """
    overlap = pred_mask + gt_mask
    return torch.sum(overlap == 2) / torch.sum(overlap >= 1)


def evaluate_explanation(model, xai_dataloader, xai_dataset_size, batch_predict, save_explanations_path, expl_thr, device_id, num_classes=21):
    """Explain correctly classified samples with DAME and score them by IoU.

    For every sample the model classifies correctly:

    1. An attribution map is produced with ``DAME.generate_explanation``.
    2. The map is min-max normalised and resized to the input's spatial size.
    3. It is saved as ``<save_explanations_path>/<label>/mask_<i>.npy``.
    4. It is binarised and compared with the binarised annotation for the true
       class. The annotation is saved as ``mask_<i>_annt.npy``.

    Args:
        model: classifier called as ``model(inputs)``; put in eval mode.
        xai_dataloader: iterable of ``(inputs, labels, annts)``. A single
            prediction per batch is asserted, i.e. batch size 1.
        xai_dataset_size: not used here; kept for interface compatibility.
        batch_predict: classification function forwarded to DAME.
        save_explanations_path: output directory (created if missing).
        expl_thr: ``'mean'`` to binarise at the map's mean, otherwise a
            numeric threshold applied to the [0, 1]-normalised map.
        device_id: device identifier forwarded to DAME.
        num_classes: number of per-class IoU buckets (default 21).

    Returns:
        Tuple ``(np.ndarray of all IoUs, dict class_index -> list of IoUs)``.
    """
    os.makedirs(save_explanations_path, exist_ok=True)
    print('...Evaluating explanations...')

    miss_count = 0
    store_ious = []
    store_ious_cls = {c: [] for c in range(num_classes)}
    i = -1  # stays -1 when the dataloader yields nothing
    for i, (inputs, labels, annts) in enumerate(xai_dataloader):
        if i % 100 == 0: print('iter:', i)
        model.eval()

        # Only the argmax is needed, so don't build an autograd graph.
        with torch.no_grad():
            outputs = model(inputs)
        _, preds = torch.max(outputs, 1)

        assert preds.shape[0] == 1

        if preds[0] != labels[0]:
            miss_count += 1
            continue

        label = labels[0].item()
        attributions_ig = DAME.generate_explanation(inputs,
                                                    batch_predict,  # classification function
                                                    top_labels=1,
                                                    hide_color=None,
                                                    batch_size=128,
                                                    num_samples=5000,  # number of images sent to the classification function
                                                    idx_expl=labels.item(), device_id=device_id, random_seed=42)

        attributions_ig = _minmax_normalize(attributions_ig)
        # skimage's output_shape is (rows, cols), i.e. (H, W) assuming a
        # channels-first input whose last two dims are height and width.
        attributions_ig = resize(attributions_ig, (inputs.shape[-2], inputs.shape[-1]), anti_aliasing=True)

        save_expl_cls_path = os.path.join(save_explanations_path, str(label))
        os.makedirs(save_expl_cls_path, exist_ok=True)
        np.save(os.path.join(save_expl_cls_path, 'mask_' + str(i) + '.npy'), attributions_ig)

        # Re-normalise: anti-aliased resizing can move min/max away from 0/1.
        attributions_ig = _minmax_normalize(torch.from_numpy(attributions_ig).unsqueeze(0))
        thr = attributions_ig.mean() if expl_thr == 'mean' else expl_thr
        mask = (attributions_ig > thr).float().squeeze(0)
        assert len(mask.shape) == 2

        # Binary annotation of the true class, brought onto the mask grid;
        # interpolated values >= 0.4 count as foreground.
        annt_transf = get_annt_transform(mask.shape)
        annt_mask = (annts == label).float()
        annt_mask = annt_transf(annt_mask).squeeze(0)
        annt_mask = (annt_mask >= 0.4).float()
        np.save(os.path.join(save_expl_cls_path, 'mask_' + str(i) + '_annt.npy'), annt_mask)

        iou = _binary_iou(mask, annt_mask)
        print(f'IoU: {iou}')
        store_ious.append(iou)
        store_ious_cls[label].append(iou)
        sys.stdout.flush()
        sys.stderr.flush()

    num_seen = i + 1
    print(f'miss classfication percentage that shd not have happened: {miss_count*100/max(num_seen, 1)}')
    print(f'IoU stats: mean- {np.mean(store_ious)}, std- {np.std(store_ious)}')
    print(f'classwise IoU values: {[[np.mean(store_ious_cls[c]), np.std(store_ious_cls[c])] for c in range(num_classes)]}')
    return np.array(store_ious), store_ious_cls

