#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch
import numpy as np
from tqdm import tqdm
import captum.attr as attr
import skimage
import saliency.core as saliency
from functools import partial

from xai.thresholds import make_order_thresholds, make_fraction_thresholds


def ig(model, img, label, baseline, num_steps, internal_batch_size=1):
    """Compute an Integrated Gradients attribution with Captum.

    Args:
        model: Model handed to ``captum.attr.IntegratedGradients``.
        img: Input to attribute.
        label: Target passed to Captum as ``target``.
        baseline: Reference input passed to Captum as ``baselines``.
        num_steps: Number of integration steps (``n_steps``).
        internal_batch_size: Forwarded unchanged to Captum.

    Returns:
        Whatever ``IntegratedGradients.attribute`` returns for these arguments.
    """
    explainer = attr.IntegratedGradients(model)
    return explainer.attribute(
        img,
        baselines=baseline,
        target=label,
        n_steps=num_steps,
        internal_batch_size=internal_batch_size,
    )

def gradcam(model, img, label, gradcam_layer, resize=True):
    """Compute a Grad-CAM attribution with Captum, optionally resized to the input.

    Args:
        model: Model handed to ``captum.attr.LayerGradCam``.
        img: Input to attribute; its last two dimensions give the resize target.
        label: Target passed to Captum as ``target``.
        gradcam_layer: Layer handed to ``LayerGradCam``.
        resize: If true, resize the map to ``img.shape[-2:]``.

    Returns:
        The raw Captum attribution when ``resize`` is false. Otherwise a tensor
        with two leading singleton dimensions holding the resized map, placed
        on the device of the raw attribution.

    Note:
        Only element ``[0, 0]`` of the raw attribution is resized, so any
        further batch entries or channels are dropped when ``resize`` is true.
    """
    explainer = attr.LayerGradCam(model, gradcam_layer)
    heatmap = explainer.attribute(img, target=label)
    if not resize:
        return heatmap

    # Import the submodule explicitly: a bare `import skimage` does not
    # guarantee that `skimage.transform` has been loaded as an attribute.
    from skimage.transform import resize as resize_image

    device = heatmap.device
    first_map = heatmap.detach().cpu().numpy()[0, 0, ...]
    # order=1 selects first-order spline interpolation in scikit-image.
    resized = resize_image(first_map, img.shape[-2:], order=1)
    return torch.from_numpy(resized)[None, None, ...].to(device=device)

def convert_to_saliency_format(img):
    """Convert a tensor to a NumPy array with axis 1 moved to axis 3.

    Args:
        img: 4-D tensor (presumably ``(N, C, H, W)`` -- confirm with callers).

    Returns:
        NumPy array of the same data with axis 1 moved to position 3.
    """
    # detach() so tensors that are part of an autograd graph can be converted;
    # Tensor.numpy() raises on tensors that require grad.
    return np.moveaxis(img.detach().cpu().numpy(), 1, 3)

def convert_from_saliency_format(img_sal, device):
    """Convert a NumPy array back to a float32 tensor with axis 3 moved to axis 1.

    Args:
        img_sal: 4-D NumPy array (presumably ``(N, H, W, C)`` -- confirm with callers).
        device: Device the resulting tensor is placed on.

    Returns:
        ``torch.float32`` tensor on ``device`` with axis 3 moved to position 1.
    """
    channels_second = np.moveaxis(img_sal, 3, 1)
    return torch.from_numpy(channels_second).to(dtype=torch.float32, device=device)

def saliency_call_model_function(img_sal, model, device, call_model_args=None, expected_keys=None):
    """Gradient callback in the form expected by the ``saliency`` library.

    Args:
        img_sal: Batch of inputs in saliency layout (converted with
            ``convert_from_saliency_format``).
        model: Model whose output is indexed as ``model(x)[:, class_idx]``.
        device: Device the converted input is placed on.
        call_model_args: Used directly as the target class index.
        expected_keys: Keys requested by the caller; must contain
            ``saliency.base.INPUT_OUTPUT_GRADIENTS``.

    Returns:
        Dict mapping ``saliency.base.INPUT_OUTPUT_GRADIENTS`` to the gradients
        of the selected output w.r.t. the input, in saliency layout.

    Raises:
        Exception: If input/output gradients are not among ``expected_keys``
            (including when ``expected_keys`` is ``None``).
    """
    if expected_keys is None or saliency.base.INPUT_OUTPUT_GRADIENTS not in expected_keys:
        raise Exception("One-hot Saliency call_model_function not implemented.")

    img = convert_from_saliency_format(img_sal, device)
    img.requires_grad_(True)
    target_class_idx = call_model_args

    output = model(img)[:, target_class_idx]

    # `output` has one entry per batch element; without explicit grad_outputs
    # autograd only accepts a single-element output, i.e. a batch of one.
    grads = torch.autograd.grad(output, img, grad_outputs=torch.ones_like(output))[0]
    grads_sal = convert_to_saliency_format(grads)
    return {saliency.base.INPUT_OUTPUT_GRADIENTS: grads_sal}


def gig(model, img, label, baseline, num_steps, max_dist=1, fraction=0.1):
    """Compute a Guided Integrated Gradients attribution via ``saliency``.

    Args:
        model: Model evaluated through ``saliency_call_model_function``.
        img: Input tensor; only its first batch element is attributed.
        label: Passed to ``GetMask`` as the call-model argument (target class).
        baseline: Reference tensor; only its first batch element is used.
        num_steps: Forwarded as ``x_steps``.
        max_dist: Forwarded unchanged to ``GetMask``.
        fraction: Forwarded unchanged to ``GetMask``.

    Returns:
        float32 tensor on ``img.device`` with a leading batch dimension of one,
        converted back from saliency layout.
    """
    target_device = img.device
    explainer = saliency.GuidedIG()
    call_model_function = partial(saliency_call_model_function, model=model, device=target_device)

    # The saliency API works on a single un-batched image in its own layout.
    img_sal = convert_to_saliency_format(img)[0, ...]
    baseline_sal = convert_to_saliency_format(baseline)[0, ...]

    mask_sal = explainer.GetMask(
        img_sal,
        call_model_function,
        label,
        x_steps=num_steps,
        x_baseline=baseline_sal,
        max_dist=max_dist,
        fraction=fraction,
    )
    return convert_from_saliency_format(mask_sal[None, ...], device=target_device)

def xrai(model, img, label, baselines=None):
    """Compute an XRAI attribution via ``saliency``.

    Args:
        model: Model evaluated through ``saliency_call_model_function``.
        img: Input tensor; only its first batch element is attributed.
        label: Passed to ``GetMask`` as the call-model argument (target class).
        baselines: Optional iterable of reference tensors; the first batch
            element of each is used. ``None`` is forwarded as ``None``.

    Returns:
        float32 tensor on ``img.device`` holding the XRAI mask with two leading
        singleton dimensions added.
    """
    target_device = img.device
    explainer = saliency.XRAI()
    img_sal = convert_to_saliency_format(img)[0, ...]

    baselines_sal = (
        None
        if baselines is None
        else [convert_to_saliency_format(b)[0, ...] for b in baselines]
    )
    call_model_function = partial(saliency_call_model_function, model=model, device=target_device)

    mask_sal = explainer.GetMask(img_sal, call_model_function, label, baselines=baselines_sal)
    # Add two leading singleton dimensions before moving to torch.
    mask_sal = mask_sal[None, None]

    return torch.from_numpy(mask_sal).to(dtype=torch.float32, device=target_device)

def mmbs_discrete(model, img, label, imputation, num_outer_steps, num_inner_steps, num_paths, order_maps=None, progress_bar=True):
    """Path-based attribution averaged over random orderings, with a fixed inner grid.

    For every path, the input is imputed step by step according to an order
    map. Between two consecutive imputed images, the gradient is evaluated at
    the mixing fractions given by ``make_fraction_thresholds`` and multiplied
    by the step taken since the previously evaluated point.

    Args:
        model: Callable whose output is indexed as ``model(x)[0, label]``.
        img: Input tensor; the heatmap has the same shape.
        label: Index of the output to attribute.
        imputation: Callable ``imputation(img, mask)`` returning an imputed
            image; ``mask`` is ``order_map < threshold``.
        num_outer_steps: Passed to ``make_order_thresholds``.
        num_inner_steps: Passed to ``make_fraction_thresholds``.
        num_paths: Number of paths to average over.
        order_maps: Optional sequence of precomputed order maps, indexed by
            path; when ``None`` a fresh random map is sampled per path.
        progress_bar: Whether to wrap the path loop in ``tqdm``.

    Returns:
        Tensor shaped like ``img`` with the accumulated attributions divided
        by ``num_paths``.
    """
    heatmap = torch.zeros_like(img)
    # Depends only on `img`, so compute it once rather than once per path.
    num_total_pixels = np.prod(list(img.shape))

    paths = range(num_paths)
    if progress_bar:
        paths = tqdm(paths)
    for i in paths:
        if order_maps is None:
            order_map = sample_order_map(img)
        else:
            order_map = order_maps[i]

        imputed_img_prev = img
        in_img_prev = img

        for threshold in make_order_thresholds(num_total_pixels, num_outer_steps):
            mask = (order_map < threshold)

            imputed_img = imputation(img, mask)

            for mix_fraction in make_fraction_thresholds(num_inner_steps):
                # Point on the straight line between the previous and the
                # current imputed image.
                in_img = mix_fraction*imputed_img + (1-mix_fraction)*imputed_img_prev
                in_img.requires_grad_(True)
                out = model(in_img)[0, label]
                grad = torch.autograd.grad(out, in_img, grad_outputs=torch.ones_like(out))[0]
                in_img.requires_grad_(False)

                # Step from the previously evaluated point to this one.
                diff = in_img_prev - in_img
                in_img_prev = in_img

                heatmap += diff * grad

            imputed_img_prev = imputed_img

    return heatmap/num_paths

def sample_order_map(img):
    """Sample a random permutation of ``0..N-1`` arranged in the shape of ``img``.

    Args:
        img: Tensor whose shape and device the order map should match.

    Returns:
        Integer tensor on ``img.device`` with the shape of ``img`` containing
        every value in ``range(img.numel())`` exactly once.
    """
    # numel() is always an int; np.prod of an empty shape (0-dim tensor) would
    # yield the float 1.0, which torch.randperm rejects.
    num_pixels = img.numel()
    return torch.randperm(num_pixels, device=img.device).reshape(img.shape)

def mmbs(model, img, label, imputation, num_steps, num_paths, order_maps=None, progress_bar=True):
    """Path-based attribution averaged over random orderings, with random mixing.

    For every path, the input is imputed step by step according to an order
    map. For each step the gradient is evaluated at one uniformly random point
    between the previous and the current imputed image and multiplied by the
    difference between those two images.

    Args:
        model: Callable whose output is indexed as ``model(x)[0, label]``.
        img: Input tensor; the heatmap has the same shape.
        label: Index of the output to attribute.
        imputation: Callable ``imputation(img, mask)`` returning an imputed
            image; ``mask`` is ``order_map < threshold``.
        num_steps: Passed to ``make_order_thresholds``.
        num_paths: Number of paths to average over.
        order_maps: Optional sequence of precomputed order maps, indexed by
            path; when ``None`` a fresh random map is sampled per path.
        progress_bar: Whether to wrap the path loop in ``tqdm``.

    Returns:
        Tensor shaped like ``img`` with the accumulated attributions divided
        by ``num_paths``.
    """
    total = np.prod(list(img.shape))
    accumulated = torch.zeros_like(img)

    path_indices = tqdm(range(num_paths)) if progress_bar else range(num_paths)
    for path_idx in path_indices:
        order_map = order_maps[path_idx] if order_maps is not None else sample_order_map(img)

        previous = img
        for threshold in make_order_thresholds(total, num_steps):
            current = imputation(img, order_map < threshold)

            # Evaluate the gradient at a random point on the segment between
            # the previous and the current imputed image.
            t = torch.rand(1)[0]
            point = t*current + (1-t)*previous
            point.requires_grad_(True)
            score = model(point)[0, label]
            grad = torch.autograd.grad(score, point, grad_outputs=torch.ones_like(score))[0]
            point.requires_grad_(False)

            accumulated += (previous - current) * grad
            previous = current

    return accumulated/num_paths

def mbshap(model, img, label, imputation, num_paths, order_maps=None, progress_bar=False):
    """Permutation-sampling attribution that changes one entry per model call.

    For every path the threshold runs from ``N-1`` down to ``0`` (``N`` being
    the number of entries of ``img``). At each threshold the image is imputed
    with ``order_map < threshold`` and the resulting change in model output is
    credited to the single entry whose order equals the threshold. Runs
    entirely under ``torch.no_grad()``.

    Args:
        model: Callable whose output is indexed as ``model(x)[0, label]``.
        img: Input tensor; the heatmap has the same shape.
        label: Index of the output to attribute.
        imputation: Callable ``imputation(img, mask)`` returning an imputed image.
        num_paths: Number of orderings to average over.
        order_maps: Optional sequence of precomputed order maps, indexed by
            path; when ``None`` a fresh random map is sampled per path.
        progress_bar: Whether to wrap the path loop in ``tqdm``.

    Returns:
        Tensor shaped like ``img`` with the accumulated output differences
        divided by ``num_paths``.
    """
    total = np.prod(list(img.shape))
    with torch.no_grad():
        attribution = torch.zeros_like(img)

        path_indices = tqdm(range(num_paths)) if progress_bar else range(num_paths)
        for path_idx in path_indices:
            order_map = order_maps[path_idx] if order_maps is not None else sample_order_map(img)

            # Each path starts from the model output on the unmodified input.
            previous_score = model(img)[0, label]

            for threshold in reversed(range(total)):
                current = imputation(img, order_map < threshold)
                score = model(current)[0, label]

                # Only the entry that just left the mask receives the change.
                attribution += (order_map == threshold) * (previous_score - score)
                previous_score = score

    return attribution/num_paths

def smoothgrad(partial_fun, img, noise_std_factor=0.15, num_samples=25, progress_bar=False):
    """Average an attribution function over Gaussian-perturbed copies of the input.

    Args:
        partial_fun: Callable invoked as ``partial_fun(img=noisy_img)`` that
            returns a heatmap tensor.
        img: Input tensor to perturb.
        noise_std_factor: Noise standard deviation as a fraction of the value
            range ``max(img) - min(img)``.
        num_samples: Number of noisy samples to average; must be at least 1.
        progress_bar: Whether to wrap the sample loop in ``tqdm``.

    Returns:
        The mean of the heatmaps returned by ``partial_fun``.

    Raises:
        ValueError: If ``num_samples`` is smaller than 1.
    """
    if num_samples < 1:
        raise ValueError("num_samples must be at least 1")

    # The noise scale depends only on `img`, so compute it once.
    noise_std = (torch.max(img) - torch.min(img)) * noise_std_factor

    samples = range(num_samples)
    if progress_bar:
        samples = tqdm(samples)

    result = None
    for _ in samples:
        noise = torch.randn_like(img) * noise_std
        heatmap = partial_fun(img=img + noise)
        # Accumulate out of place: an in-place add would modify the tensor
        # returned by partial_fun, which may be shared or reused by the caller.
        result = heatmap if result is None else result + heatmap

    return result / num_samples
