import matplotlib.pyplot as plt
from PIL import Image
import torch.nn as nn
import numpy as np
import os, json
from pdb import set_trace as bp
import torch
import random
import cv2
from torchvision import models, transforms
from torch.autograd import Variable
from captum.attr import Saliency, IntegratedGradients, InputXGradient
from captum.attr import NoiseTunnel
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from matplotlib.colors import LinearSegmentedColormap
from captum.attr import visualization as viz
import torch.nn.functional as F
import captum.attr as attr
from skimage.segmentation import mark_boundaries

from utils_RISE import *
from RISE import RISE
from lime import lime_image

# Seed every random number generator used by this script (Python, NumPy and
# PyTorch on CPU and, when present, on all CUDA devices) with the same value.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
# Ask cuDNN to restrict itself to deterministic algorithms.
torch.backends.cudnn.deterministic = True

def get_image(path):
    """Open the image file at *path* and return it converted to RGB mode.

    The path is made absolute before opening; both the file handle and the
    PIL image opened on it are closed when the function returns.
    """
    absolute_path = os.path.abspath(path)
    with open(absolute_path, 'rb') as handle, Image.open(handle) as picture:
        return picture.convert('RGB')
        

#.............NeurIPS23...........
# Example images: exactly one get_image(...) line is active; the other
# candidates stay commented out so they can be switched in quickly.
# img = get_image('../strawberry.jpeg')
# img = get_image('../boat.jpeg')
# img = get_image('../wine.jpeg')
# img = get_image('../palace.jpeg')
# img = get_image('../cello.jpeg')
# img = get_image('../bus.jpeg')
# img = get_image('../bird.jpeg')
# img = get_image('../ant.jpeg')
# img = get_image('../barbecue.jpg')
# img = get_image('../tent.jpg')
# img = get_image('../icecream.jpg')
img = get_image('../crab.jpeg')
# img = get_image('../calculator.jpeg')
# img = get_image('../stethoscope.jpeg')
# img = get_image('../fountain_pen.jpeg')
#..............................

# Compute device: the GPU named by device_id when CUDA is available, else CPU.
device_id = "cuda:0"
if torch.cuda.is_available():
    device = torch.device(device_id)
else:
    device = torch.device("cpu")

# Resize the image and keep its central part so it matches what the model expects.
def get_input_transform():
    """Build the full PIL-image-to-model-input preprocessing pipeline.

    Returns:
        A ``transforms.Compose`` that resizes to 256x256, center-crops to
        224, converts to a tensor and normalizes each channel with the
        per-channel mean/std constants below.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)

def get_input_tensors(img):
    """Preprocess *img* with get_input_transform() and return a batch of one."""
    preprocessed = get_input_transform()(img)
    # unsqueeze(0) adds a leading batch dimension: single image -> batch of 1
    return preprocessed.unsqueeze(0)

# Classifier under explanation: torchvision ResNet-101 with pretrained weights.
# NOTE(review): recent torchvision releases deprecate `pretrained=` in favour
# of `weights=` — confirm against the installed version before changing.
model = models.resnet101(pretrained=True)

# Class-index JSON, expected under the user's home directory. expanduser('~')
# is used rather than os.getenv('HOME') + ..., which raises TypeError when the
# HOME environment variable is not set.
labels_path = os.path.expanduser('~') + '/.torch/models/imagenet_class_index.json'
idx2label, cls2label, cls2idx = [], {}, {}
with open(os.path.abspath(labels_path), 'r') as read_file:
    # class_idx maps str(index) -> sequence whose item 0 is the class id and
    # item 1 the human-readable label (as indexed below).
    class_idx = json.load(read_file)
    entries = [class_idx[str(k)] for k in range(len(class_idx))]
    idx2label = [entry[1] for entry in entries]
    cls2label = {entry[0]: entry[1] for entry in entries}
    cls2idx = {entry[0]: k for k, entry in enumerate(entries)}

# Classify the normalized input once and print the five most probable classes
# as (probability, class index, label) triples.
img_t = get_input_tensors(img)
model.eval()
with torch.no_grad():  # pure inference: no autograd graph is needed here
    logits = model(img_t)

probs = F.softmax(logits, dim=1)
probs5 = probs.topk(5)
print(tuple((p,c, idx2label[c]) for p, c in zip(probs5[0][0].detach().numpy(), probs5[1][0].detach().numpy())))

def get_pil_transform():
    """Build the geometric-only pipeline: resize to 256x256, center-crop to 224.

    No tensor conversion or normalization is included in this pipeline.
    """
    return transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
    ])

def get_preprocess_transform():
    """Build the value-only pipeline: tensor conversion then normalization.

    Uses the same per-channel mean/std constants as get_input_transform();
    no resizing or cropping is included.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ])

# Instantiate both pipelines once at module level.
# pill_transf: resize + center-crop only (built by get_pil_transform).
pill_transf = get_pil_transform()
# preprocess_transform: ToTensor + Normalize (built by get_preprocess_transform).
preprocess_transform = get_preprocess_transform()

def batch_predict(images):
    """Return softmax class probabilities for a batch of images.

    Each image is converted with ``transforms.ToTensor`` only; no
    normalization is applied here, so callers must pass images that are
    already in the value range they want the model to see.

    Args:
        images: Iterable of images accepted by ``transforms.ToTensor``.

    Returns:
        A NumPy array with one row of class probabilities per input image.

    Side effects:
        Puts the module-level ``model`` in eval mode and moves it to the
        selected device.
    """
    to_tensor = transforms.ToTensor()
    # Fall back to the CPU when CUDA is unavailable instead of unconditionally
    # requesting device_id, which fails on machines without a GPU.
    target_device = torch.device(device_id if torch.cuda.is_available() else "cpu")

    model.eval()
    model.to(target_device)
    batch = torch.stack(tuple(to_tensor(i) for i in images), dim=0).to(target_device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        probs = F.softmax(model(batch), dim=1)
    return probs.detach().cpu().numpy()

# Predicted class for the resized/cropped PIL image; its argmax is the target
# class handed to every explanation method further down.
# NOTE(review): batch_predict applies only ToTensor, so this prediction is made
# on an un-normalized image, whereas `logits` above was computed from the
# normalized img_t — confirm that both agree on the top-1 class.
# This call also moves `model` to the device selected inside batch_predict.
test_pred = batch_predict([pill_transf(img)])
print(test_pred.squeeze().argmax())

# Min-max rescale the normalized input (channels moved last) to [0, 1] and
# save it as a reference picture.
img_orig = img_t[0].permute((1,2,0)).numpy()
img_orig = (img_orig-img_orig.min())/(img_orig.max()-img_orig.min())
plt.imsave('img-orig.png', img_orig)

def _rescale_unit(values):
    """Min-max rescale *values* to [0, 1]; a constant input maps to all zeros."""
    values = np.asarray(values)
    lowest = values.min()
    span = values.max() - lowest
    if span == 0:
        # Avoid 0/0 -> NaN when every entry is identical.
        return np.zeros(values.shape, dtype=float)
    return (values - lowest) / span


# common function to save explanations at once
def save_explanation_and_image(img, mask, method):
    """Save a JET heat map of *mask* and a masked copy of *img*.

    Writes two files into the current directory:
      * ``mask_<method>.png`` - the min-max normalized mask rendered with
        OpenCV's JET colormap.
      * ``img_<method>.png``  - the resized/cropped image in which every
        pixel whose normalized mask value is below ``mean + 1.2 * std`` of
        the normalized mask is replaced by the constant 0.6.

    Args:
        img: PIL image; it is passed through the module-level ``pill_transf``.
        mask: NumPy array of saliency values for the transformed image.
        method: Name of the explanation method, used in the file names.

    Returns:
        None.
    """
    unit_mask = _rescale_unit(mask)

    # Scale to 0..255. Scaling by 256 would map the maximum to 256, which does
    # not fit in uint8, so the most salient pixel would be mis-coloured.
    heat_levels = (unit_mask * 255).astype(np.uint8)
    heatmap_img = cv2.applyColorMap(heat_levels, cv2.COLORMAP_JET)
    # OpenCV produces BGR; PIL expects RGB. (The former addWeighted call
    # blended the heat map with itself at weight 0 and is dropped as a no-op.)
    RGBimage = cv2.cvtColor(heatmap_img, cv2.COLOR_BGR2RGB)
    PILimage = Image.fromarray(RGBimage)
    PILimage.save('mask_'+method+'.png', dpi=(172,172))

    img = np.array(pill_transf(img))
    img_for_masking = _rescale_unit(img)
    # Keep only pixels whose saliency is clearly above average.
    threshold = unit_mask.mean() + 1.2*unit_mask.std()
    img_for_masking[unit_mask < threshold] = 0.6
    plt.imsave('img_'+method+'.png', img_for_masking)
    return None

#...................collection of XAI methods ........................
def get_mask(attr_hmap):
    """Render an attribution tensor as a heat-map mask.

    Args:
        attr_hmap: Attribution tensor that is 3-D with channels first once
            squeezed (it is transposed with axes (1, 2, 0) below).

    Returns:
        The first element of the result of
        ``viz.visualize_image_attr_custom`` called in 'heat_map' mode on the
        positive attributions, with the module-level ``img_t`` as the image.
    """
    attribution_hwc = attr_hmap.detach().cpu().squeeze().numpy().transpose(1, 2, 0)
    image_hwc = img_t.detach().cpu().permute(2,3,1,0).squeeze().numpy()

    # Colormap running from white at 0 to black from 0.25 upwards.
    white_to_black = LinearSegmentedColormap.from_list(
        'custom blue',
        [(0, '#ffffff'), (0.25, '#000000'), (1, '#000000')],
        N=256,
    )

    rendered = viz.visualize_image_attr_custom(
        attribution_hwc,
        image_hwc,
        method='heat_map',
        cmap=white_to_black,
        show_colorbar=False,
        sign='positive',
        outlier_perc=1,
    )
    return rendered[0]

# Gradient-based attributions (captum). Every method explains the class that
# batch_predict ranked first for the input image.
target_class = torch.tensor(test_pred.squeeze().argmax())

# vanilla gradient
vanilla_gradient = Saliency(model)
attr_hmap = vanilla_gradient.attribute(img_t.to(device), target=target_class)
mask = get_mask(attr_hmap)
save_explanation_and_image(img, mask, 'vanilla-grad')

# smoothgrad: Saliency wrapped in a NoiseTunnel with 100 samples
noise_tunnel = NoiseTunnel(Saliency(model))
attr_hmap = noise_tunnel.attribute(
    img_t.to(device),
    nt_samples=100,
    nt_type='smoothgrad',
    target=target_class,
)
mask = get_mask(attr_hmap)
save_explanation_and_image(img, mask, 'smoothgrad')

# integrated gradients, 200 integration steps
int_grad = IntegratedGradients(model, multiply_by_inputs=True)
attr_hmap = int_grad.attribute(img_t.to(device), target=target_class, n_steps=200)
mask = get_mask(attr_hmap)
save_explanation_and_image(img, mask, 'int-grad')

# guided gradcam on the last block of layer4
explainer = attr.GuidedGradCam(model, model.layer4[-1])
attr_hmap = explainer.attribute(img_t.to(device), target_class)
mask = get_mask(attr_hmap)
save_explanation_and_image(img, mask, 'guided-gc')

# input x gradient
int_grad = InputXGradient(model)
attr_hmap = int_grad.attribute(img_t.to(device), target=target_class)
mask = get_mask(attr_hmap)
save_explanation_and_image(img, mask, 'grad_input')
#............................................
#...........................................................
# gradcam and gradcam++ (pytorch_grad_cam), both on the last block of layer4.
# use_cuda follows the selected `device` instead of being hard-coded to True,
# so the script also runs on machines without CUDA.
# NOTE(review): newer pytorch-grad-cam releases appear to have dropped the
# use_cuda argument — confirm against the installed version.
gradcam_on_gpu = device.type == 'cuda'
cam_targets = [ClassifierOutputTarget(int(test_pred.squeeze().argmax()))]
for cam_name, cam_cls in (('gradcam', GradCAM), ('gradcam++', GradCAMPlusPlus)):
    cam = cam_cls(model=model, target_layers=[model.layer4[-1]], use_cuda=gradcam_on_gpu)
    attributions_ig = cam(input_tensor=img_t.to(device), targets=cam_targets)
    # One input image in, so exactly one CAM is expected back.
    assert attributions_ig.shape[0]==1
    mask = attributions_ig[0]
    save_explanation_and_image(img, mask, cam_name)
#...........................................................
#RISE
# The explainer is given the spatial size (last two dimensions) of the input.
# NOTE(review): device_id is passed as the raw "cuda:0" string even when
# `device` fell back to CPU — confirm how RISE uses it.
input_hw = tuple(img_t.shape[-2:])
explainer = RISE(model, input_hw, device_id = device_id, gpu_batch=128)

maskspath = 'utilities/masks.npy'
generate_new = False
p1_mask = 0.1
# Reuse the masks stored on disk unless regeneration is requested or the
# file does not exist yet.
if not generate_new and os.path.isfile(maskspath):
    explainer.load_masks(maskspath, p1_mask)
    print('Masks are loaded.')
else:
    explainer.generate_masks(N=1500, s=8, p1=p1_mask, savepath=maskspath)

# The explainer output is indexed by class; keep the map of the predicted class.
saliency = explainer(img_t.to(device)).cpu().numpy()
predicted_class = int(test_pred.squeeze().argmax())
mask = saliency[predicted_class]
save_explanation_and_image(img, mask, 'rise')

#LIME
# LIME is given the normalized input as a channels-last NumPy array and uses
# batch_predict as its classification function.
lime_input = img_t[0].detach().cpu().permute(1,2,0).numpy()
explainer = lime_image.LimeImageExplainer_custom()
explanation = explainer.explain_instance(
    lime_input,
    batch_predict,  # classification function
    top_labels=1,
    hide_color=None,
    batch_size=256,
    num_samples=1500,  # number of images that will be sent to classification function
    random_seed=42,
)

# Mask for the top predicted label; the returned image is not used.
top_label = explanation.top_labels[0]
temp, mask = explanation.get_image_and_mask(
    top_label,
    positive_only=False,
    num_features=100,
    hide_rest=False,
)
save_explanation_and_image(img, mask, 'lime')
#...........................................................



