import numpy as np
from scipy.ndimage.filters import gaussian_filter
from scipy import interpolate
from shap import KernelExplainer
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import torch
from tqdm.auto import tqdm
import math
import shap

def mask_image(masks, image, background=None): # for kernelshap
    """Apply flat binary coalition masks to an image (helper for KernelSHAP).

    Each row of ``masks`` is reshaped to a square grid, every grid cell is
    repeated along the two spatial axes until the grid matches the image
    width, and the result is multiplied with ``image``.

    Args:
        masks: 2-D array ``(n_masks, n_features)``; ``n_features`` must be a
            perfect square because every row is reshaped to a square grid.
        image: Array whose last axis length is a multiple of the grid side.
            # NOTE(review): callers in this file pass (1, C, H, W) with H == W — confirm.
        background: Optional fill used where the mask is 0.
            * ``None``: masked-out pixels become 0.
            * 3-D array: ``background[0]`` is used as the fill.
            * otherwise: iterated along its first axis, one output per entry.

    Returns:
        ``np.ndarray`` of stacked masked images when ``background`` is ``None``
        or 3-D; a ``list`` with one stacked array per mask otherwise.

    Raises:
        ValueError: if the image width is not an integer multiple of the
            mask grid side (or ``n_features`` is not a perfect square, raised
            by ``np.reshape``).
    """
    # Reshape/size Mask
    mask_shape = int(masks.shape[1] ** .5)
    masks = np.reshape(masks, (masks.shape[0], 1, mask_shape, mask_shape))
    # np.repeat needs an integer repeat count; true division would hand it a
    # float, which NumPy refuses to cast to an integer under its safe-casting rule.
    resize_aspect, remainder = divmod(image.shape[-1], mask_shape)
    if remainder:
        raise ValueError(
            "image width {} is not a multiple of the mask grid side {}".format(
                image.shape[-1], mask_shape
            )
        )
    masks = np.repeat(masks, resize_aspect, axis=2)
    masks = np.repeat(masks, resize_aspect, axis=3)

    # Mask Image
    if background is None:
        return np.vstack([mask * image for mask in masks])

    if len(background.shape) == 3:
        fill = background[0]
        return np.vstack([
            np.expand_dims((mask * image) + ((1 - mask) * fill), 0)
            for mask in masks
        ])

    # Fill with Background: one stacked array per mask, holding the image
    # completed with every background sample in turn.
    masked_images = []
    for mask in masks:
        kept = mask * image
        masked_images.append(np.vstack([
            np.expand_dims(kept + im * (1 - mask), 0) for im in background
        ]))
    return masked_images

def normalize(data, mtype='del'):
    """Rescale each row of a score curve so it runs between 0 and 1.

    For ``mtype == "del"`` the last column is subtracted from every row and
    the row is then divided by its (shifted) first column; for
    ``mtype == "ins"`` the roles of the first and last columns are swapped.
    Any other ``mtype`` returns ``data`` untouched.

    Args:
        data: 2-D array or tensor, one curve per row.
        mtype: ``"del"`` or ``"ins"``.

    Returns:
        The rescaled curves (same container type as ``data``).
    """
    if mtype == "del":
        floor_col, peak_col = -1, 0
    elif mtype == "ins":
        floor_col, peak_col = 0, -1
    else:
        return data

    # Shift so the reference end of every curve sits at zero ...
    shifted = data - data[:, floor_col][:, np.newaxis]
    # ... then scale so the opposite end sits at one.
    return shifted / shifted[:, peak_col][:, np.newaxis]
def auc(data: Tensor, mtype: str):
    """Mean height of each score curve after dividing by its peak column.

    The peak is the first column for deletion curves (``"del"``) and the last
    column for insertion curves (``"ins"``). The area is approximated as the
    plain average of the rescaled values along the last dimension.

    Args:
        data: 2-D tensor, one curve per row.
        mtype: ``"del"`` or ``"ins"``.

    Returns:
        ``np.ndarray`` with one area value per row.

    Raises:
        ValueError: if ``mtype`` is neither ``"del"`` nor ``"ins"``.
    """
    if mtype == "del":
        data_peak = data[:, 0].unsqueeze(-1)
    elif mtype == "ins":
        data_peak = data[:, -1].unsqueeze(-1)
    else:
        # Without this guard data_peak would be unbound further down.
        raise ValueError("mtype must be 'del' or 'ins', got {!r}".format(mtype))
    data = data / data_peak
    area = data.sum(dim=-1) / data.size(-1)
    # detach + cpu so tensors on an accelerator or tracking gradients convert too.
    return area.detach().cpu().numpy()
def gkern(ch, klen, ksig):
    """Build a per-channel Gaussian convolution weight.

    A ``klen x klen`` unit impulse is smoothed with ``gaussian_filter`` and
    the resulting 2-D kernel is placed at every ``[c, c]`` position of a
    ``(ch, ch, klen, klen)`` array; all off-diagonal kernels stay zero.

    Args:
        ch: Number of channels.
        klen: Kernel side length.
        ksig: Sigma passed to ``gaussian_filter``.

    Returns:
        ``torch.Tensor`` of dtype float32 with shape ``(ch, ch, klen, klen)``.
    """
    centre = klen // 2
    impulse = np.zeros((klen, klen))
    impulse[centre, centre] = 1
    kernel_2d = gaussian_filter(impulse, ksig)

    weight = np.zeros((ch, ch, klen, klen))
    diagonal = np.arange(ch)
    # Same kernel for every channel, and no mixing between channels.
    weight[diagonal, diagonal] = kernel_2d
    return torch.from_numpy(weight.astype("float32"))

def gaussian_blur(img, klen, ksig):
    """Blur a batch of images with the Gaussian weight built by ``gkern``.

    Args:
        img: Tensor whose dimension 1 is the channel axis; it is fed to
            ``conv2d`` as the input.
        klen: Kernel side length; ``klen // 2`` is used as padding.
        ksig: Sigma of the Gaussian kernel.

    Returns:
        The convolved tensor, on the same device as ``img``.
    """
    n_channels = img.size(1)
    # Build the weight on CPU, then move it next to the image.
    weight = gkern(n_channels, klen, ksig).to(img.device)
    return nn.functional.conv2d(img, weight, padding=klen // 2)

class Metric:
    """Insertion / deletion evaluation of saliency maps.

    Calling an instance obtains a saliency map from the requested explainer,
    reduces it to a grid of ``superpixel x superpixel`` cells, and then
    measures how the model score of the target class evolves while the most
    salient cells are progressively swapped from a start image to an end
    image (image -> zeros for deletion, blurred image -> image for insertion).
    """

    def __init__(self, step_per: int = 10, klen: int = 11, ksig: int = 5, device=torch.device("cpu"), use_softmax=False, superpixel=1):
        """
        Args:
            step_per (int): Number of saliency cells swapped per evaluation step.
            klen (int): Side length of the Gaussian kernel used for the insertion start image.
            ksig (int): Sigma of that Gaussian kernel.
            device: Device on which masks and score buffers are allocated.
            use_softmax (bool): Apply softmax to the model output before reading the class score.
            superpixel (int): Side length, in pixels, of one saliency cell.
        """
        self.step_per = step_per
        self.device = device
        self.klen = klen
        self.ksig = ksig
        self.use_softmax = use_softmax

        # Not read anywhere in this class; kept for external users of the attribute.
        self.saliency_thresholds = [0.005, 0.01, 0.02, 0.03, 0.04, 0.05, 0.07, 0.10, 0.13, 0.21, 0.34, 0.5, 0.75]
        self.superpixel = superpixel

    def __call__(self, model: nn.Module, img: Tensor, label, layer_name: str, explainer_type: str, explainer):
        """Compute deletion and insertion results for a batch.

        Args:
            model (nn.Module): Network whose class scores are tracked.
            img (Tensor): Batch of images, shape (BS, C, H, W).
            label: Per-sample target class indices; indexed as ``label[i]`` and
                used to index the model output, so it must be a sequence/tensor.
            layer_name (str): Forwarded to the explainer for ``'gradcam'`` only.
            explainer_type (str): One of ``'Fastshap'``, ``'Simshap'``,
                ``'gradcam'``, ``'ig'``, ``'SmoothGrad'``, ``'Deepshap'``,
                ``'CaptumDeepshap'``, ``'Kernelshap'``, ``'Kernelshap-S'``.
            explainer (object): Explanation method; how it is invoked depends
                on ``explainer_type``.

        Returns:
            dict: ``{"deletion": {...}, "insertion": {...}}`` where each entry
            holds ``"auc"`` (from ``auc``) and ``"score"`` (from ``normalize``).
        """
        self.explainer_type = explainer_type
        BS, C, H, W = img.shape
        phi_H = int(H / self.superpixel)
        phi_W = int(W / self.superpixel)
        n_cells = phi_H * phi_W
        # Kept as an attribute for backward compatibility with external readers.
        self.nsteps = (n_cells + self.step_per - 1) // self.step_per
        pool = nn.AvgPool2d(self.superpixel)
        mask = torch.zeros(BS, phi_H, phi_W).to(self.device)

        # Generate prediction and saliency map
        if explainer_type == 'Fastshap' or explainer_type == 'Simshap':
            result = torch.tensor(explainer.shap_values(img))
            for i in range(BS):
                mask[i] = result[i][label[i]]
        elif explainer_type == 'gradcam':
            result = explainer(model, img, layer_name)
            mask = pool(result)
        elif explainer_type == 'ig':
            # clone().detach() is the documented equivalent of torch.tensor(tensor)
            # and avoids the copy-construct warning.
            result = explainer.attribute(img.clone().detach(), target=label)
            result = result.mean(dim=1, keepdim=True)
            mask = pool(result)
        elif explainer_type == 'SmoothGrad':
            for i in range(BS):
                result = explainer.attribute(img[i:i+1], nt_type='smoothgrad', nt_samples=4, target=label[i:i+1])
                result = result.mean(dim=1, keepdim=True)
                mask[i:i+1] = pool(result)
        elif explainer_type == 'Deepshap':
            for i in range(BS):
                # NOTE(review): index 0 is used here rather than label[i] — confirm
                # the explainer is configured so that entry 0 is the target class.
                result = explainer.shap_values(img[i:i+1])[0]
                result = torch.tensor(result).mean(dim=1, keepdim=True)
                mask[i:i+1] = pool(result)
        elif explainer_type == 'CaptumDeepshap':
            background = torch.zeros_like(img)
            result = explainer.attribute(img, target=label, baselines=background)
            result = result.mean(dim=1, keepdim=True)
            mask = pool(result)
        elif explainer_type == 'Kernelshap':
            for i in range(BS):
                image = img[i:i+1].cpu().numpy()

                def f_mask(z):
                    # No background is used: masked-out pixels are zeroed.
                    return self._predict_in_chunks(
                        lambda m: explainer(mask_image(m, image, None)), z
                    )

                values = self._kernelshap_values(f_mask, n_cells, label[i])
                mask[i:i+1] = values.reshape(1, phi_H, phi_W)
        elif explainer_type == 'Kernelshap-S':
            for i in range(BS):
                image = img[i:i+1].cpu().numpy()

                def f_mask(z):
                    # Here the explainer receives the image and the coalitions directly.
                    return self._predict_in_chunks(lambda m: explainer(image, m), z)

                values = self._kernelshap_values(f_mask, n_cells, label[i])
                mask[i:i+1] = values.reshape(1, phi_H, phi_W)
        # NOTE: any other explainer_type falls through with an all-zero mask.
        output = {}

        # Compute deletion
        deletion_score = self._compute_scores(
            start_sample=img.clone(),
            end_sample=torch.zeros_like(img),
            model=model,
            mask=mask,
            pred=label
        )
        output["deletion"] = {
            "auc": auc(deletion_score, "del"),
            "score": normalize(deletion_score, "del")
        }
        # Compute insertion
        insertion_score = self._compute_scores(
            start_sample=gaussian_blur(img, self.klen, self.ksig),
            end_sample=img.clone(),
            model=model,
            mask=mask,
            pred=label
        )
        output["insertion"] = {
            "auc": auc(insertion_score, "ins"),
            "score": normalize(insertion_score, "ins")
        }

        return output

    @staticmethod
    def _predict_in_chunks(predict, z, chunk_size=100):
        """Evaluate ``predict`` on the rows of ``z`` in chunks and stack the outputs.

        A single-row ``z`` is passed straight through and its result returned
        as-is (no stacking), matching what the SHAP wrapper functions did before.
        """
        if z.shape[0] == 1:
            return predict(z)
        n_chunks = int(math.ceil(z.shape[0] / chunk_size))
        outputs = [
            predict(z[k * chunk_size:(k + 1) * chunk_size])
            for k in tqdm(range(n_chunks))
        ]
        return np.vstack(outputs)

    @staticmethod
    def _kernelshap_values(f_mask, n_features, target):
        """Run ``shap.KernelExplainer`` from an all-zeros baseline on an all-ones sample.

        Returns the entry of the SHAP result selected by ``target`` as a tensor.
        """
        kernel_explainer = shap.KernelExplainer(f_mask, np.zeros((1, n_features)))
        shap_values = kernel_explainer.shap_values(np.ones((1, n_features)), nsamples='auto')
        return torch.tensor(shap_values[target])

    def _compute_scores(
            self,
            start_sample: Tensor,
            end_sample: Tensor,
            model: nn.Module,
            mask: Tensor,
            pred
            ):
        """Track the target-class score while swapping cells by saliency rank.

        Cells are ranked per sample by descending ``mask`` value. At every
        step the next ``step_per`` cells are switched from ``start_sample`` to
        ``end_sample`` and the model is re-evaluated.

        Args:
            start_sample (Tensor): Images shown where a cell is not yet switched, (BS, C, H, W).
            end_sample (Tensor): Images shown where a cell is switched, same shape.
            model (nn.Module): Model called on the blended batch.
            mask (Tensor): Saliency values with ``phi_H * phi_W`` entries per sample.
            pred: Per-sample class indices used to index the model output.

        Returns:
            Tensor: ``(BS, nsteps + 1)`` scores; column 0 is the untouched
            ``start_sample``.
        """
        BS, C, H, W = start_sample.shape
        phi_H = int(H / self.superpixel)
        phi_W = int(W / self.superpixel)
        step = self.step_per
        n_cells = phi_H * phi_W
        # Derived from the input itself so the method does not depend on a
        # prior __call__ having set self.nsteps.
        nsteps = (n_cells + step - 1) // step

        scores = torch.zeros(BS, nsteps + 1).to(self.device)
        # reshape (not view) so non-contiguous saliency maps are accepted; the
        # mask is moved next to the buffers it is combined with.
        flat_saliency = mask.reshape(BS, -1).to(self.device)
        _, sort_order = torch.sort(flat_saliency, dim=-1, descending=True)
        switched = torch.zeros((BS, n_cells)).to(self.device)
        upsample = nn.Upsample(scale_factor=self.superpixel, mode='nearest')
        batch_index = torch.arange(BS)
        with torch.no_grad():
            for i in range(nsteps + 1):
                # Expand the cell grid back to pixel resolution.
                interpolation = upsample(switched.view(BS, 1, phi_H, phi_W))
                # all use original model to evaluate regardless of training on surrogate
                output = model(start_sample * (1 - interpolation) + end_sample * interpolation)
                if self.use_softmax:
                    output = torch.softmax(output, dim=-1)
                scores[:, i] = output[batch_index, pred]
                if i < nsteps:
                    # Slicing past the end is fine: the final step may be shorter.
                    cur_ords = sort_order[:, (step * i):(step * (i + 1))]
                    switched.scatter_(1, cur_ords, 1)

        return scores.detach()
