#!/usr/bin/env python
import functools
import operator

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.autograd import grad
from torch.utils.data import DataLoader

# Module-level default device: CUDA when it is available, otherwise the CPU.
DEFAULT_DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def gather_nd(params, indices):
    """
    Gather one element of ``params`` for every row of ``indices``.

    Args:
        params: Tensor to index.
        indices: 2-dimensional tensor of integers with shape (n, k). Row ``i``
            holds the coordinates of the element to fetch. Offsets are
            computed from the sizes of the first ``k`` dimensions of
            ``params``, so ``k`` should equal ``params.dim()``.
    Returns:
        output: 1-dimensional tensor with n elements of ``params``, where
            output[i] = params[indices[i][0]][indices[i][1]]...

            params   indices   output

            1 2       1 1       4
            3 4       2 0 ----> 5
            5 6       0 0       1

    Note:
        Coordinates are not validated one by one. A row whose flattened
        offset is negative or past the last element is silently redirected
        to the first element of ``params`` instead of raising.
    """
    # One row per dimension makes the row-major offset easy to accumulate.
    coords = indices.t().long()
    flat_idx = torch.zeros_like(coords[0])

    # Walk from the innermost dimension outwards, growing the stride as we go.
    stride = 1
    for dim in reversed(range(coords.size(0))):
        flat_idx += coords[dim] * stride
        stride *= params.size(dim)

    # torch.take rejects out-of-range offsets; point them at element 0.
    last_valid = params.numel() - 1
    flat_idx[(flat_idx < 0) | (flat_idx > last_valid)] = 0
    return torch.take(params, flat_idx)


class IntegratedGradients(object):
    """
    Gradient-based input attribution for a PyTorch model.

    With ``scale_by_inputs=True`` the attribution is the mean, over ``k``
    points interpolated between a reference and the input, of
    ``(input - reference) * d output / d input``. With
    ``scale_by_inputs=False`` :meth:`shap_values` returns plain input
    gradients.

    All helper tensors are created on the device of the data they are combined
    with, so the class works regardless of where the model and inputs live.
    """

    def __init__(self, model, k=1, scale_by_inputs=True):
        """
        Args:
            model (torch.nn.Module): Model to explain. It is switched to
                evaluation mode here.
            k (int): Number of interpolation points per input.
            scale_by_inputs (bool): Whether gradients are multiplied by the
                difference between input and reference.
        """
        self.model = model
        self.model.eval()
        self.k = k
        self.scale_by_inputs = scale_by_inputs

    def _expand_reference(self, reference, input_tensor):
        """
        Give every input its own reference at each of the k steps.

        Args:
            reference: Tensor broadcastable to the shape of ``input_tensor``.
            input_tensor: Tensor of shape (batch, ...).
        Returns:
            A (batch, k, ...) view in which entry ``[b, j]`` is the reference
            of sample ``b`` for every step ``j``.
        """
        per_sample = reference.expand_as(input_tensor)
        return per_sample.unsqueeze(1).expand(
            per_sample.size(0), self.k, *per_sample.shape[1:])

    def _select_label_outputs(self, output, sparse_labels):
        """
        Keep, for each sample, only the output of its label.

        Args:
            output: Tensor of shape (batch, num_outputs).
            sparse_labels: Tensor of shape (batch,) with one output index per
                sample.
        Returns:
            1-dimensional tensor with one entry per sample, as produced by
            ``gather_nd``.
        """
        rows = torch.arange(0, output.size(0), device=output.device)
        labels = sparse_labels.to(output.device)
        coords = torch.cat([rows.unsqueeze(1), labels.unsqueeze(1)], dim=1)
        return gather_nd(output, coords)

    def _backprop_to_input(self, output, inputs, require_grad):
        """
        Return the gradient of ``output.sum()`` with respect to ``inputs``.

        The autograd graph of the gradient itself is only built when
        ``require_grad`` is true, i.e. when the caller wants to differentiate
        through the result.
        """
        self.model.zero_grad()
        return grad(
            outputs=output,
            inputs=inputs,
            grad_outputs=torch.ones_like(output),
            create_graph=require_grad)[0]

    def _get_samples_input(self, input_tensor, reference_tensor):
        '''
        calculate interpolation points
        Args:
            input_tensor: Tensor of shape (batch, ...), where ... indicates
                          the input dimensions.
            reference_tensor: A tensor of shape (batch, k, ...) where ...
                indicates dimensions, and k represents the number of
                reference samples per input in the batch.
        Returns:
            samples_input: A tensor of shape (batch, k, ...) with the
                interpolated points between input and ref.
        '''
        num_input_dims = input_tensor.dim() - 1
        device = reference_tensor.device

        # A single step evaluates at the input itself; linspace(0, 1, 1)
        # would instead evaluate at the reference.
        if self.k == 1:
            alphas = torch.ones(1, device=device)
        else:
            alphas = torch.linspace(0, 1, self.k, device=device)

        # Shape (1, k, 1, ..., 1) broadcasts over batch and input dimensions.
        interp_coef = alphas.view([1, self.k] + [1] * num_input_dims)

        # Affine combination of the two end points.
        return (interp_coef * input_tensor.unsqueeze(1)
                + (1.0 - interp_coef) * reference_tensor)

    def _get_samples_delta(self, input_tensor, reference_tensor):
        """Return ``input - reference`` with shape (batch, k, ...)."""
        return input_tensor.unsqueeze(1) - reference_tensor

    def _get_grads(self, samples_input, sparse_labels=None, require_grad=True):
        """
        Compute model gradients at each of the k interpolation points.

        Args:
            samples_input: Tensor of shape (batch, k, ...).
            sparse_labels (optional): Tensor of shape (batch,). When given and
                the model has more than one output, only the output of each
                sample's label is differentiated; otherwise all outputs are
                summed.
            require_grad (bool): If false the gradients are detached.
        Returns:
            float32 tensor with the shape of ``samples_input``.
        """
        samples_input.requires_grad_(True)

        grad_tensor = torch.zeros(
            samples_input.shape, dtype=torch.float, device=samples_input.device)

        for step in range(self.k):
            particular_slice = samples_input[:, step]
            output = self.model(particular_slice)

            # Only look at the user-specified label
            if sparse_labels is not None and output.size(1) > 1:
                output = self._select_label_outputs(output, sparse_labels)

            step_grads = self._backprop_to_input(
                output, particular_slice, require_grad)
            grad_tensor[:, step] = (
                step_grads if require_grad else step_grads.detach())
        return grad_tensor

    def _get_input_grads(self, samples_input, sparse_labels=None, require_grad=True, log_loss=False, taylor_sri=False, stability=True):
        """
        Compute plain input gradients of the model.

        Args:
            samples_input: Tensor of shape (batch, ...). Gradient tracking is
                enabled on it in place.
            sparse_labels (optional): Tensor of shape (batch,) selecting one
                output per sample. Used only when ``stability`` is true.
            require_grad (bool): If false the gradients are detached.
            log_loss (bool): Differentiate ``log_softmax`` of the outputs.
                Used only when ``stability`` is true.
            taylor_sri (bool): Differentiate ``output - log_softmax(output)``,
                applied after ``log_loss``. Used only when ``stability`` is
                true.
            stability (bool): If false, differentiate the summed
                ``exp(output)`` and divide by each sample's
                ``sum(exp(output))`` instead.
        Returns:
            Tensor with the shape of ``samples_input``.
        """
        samples_input.requires_grad_(True)
        output = self.model(samples_input)

        if stability:
            if log_loss:
                output = torch.log_softmax(output, 1)
            if taylor_sri:
                output = output - torch.log_softmax(output, 1)
            if sparse_labels is not None:
                output = self._select_label_outputs(output, sparse_labels)
            grads = self._backprop_to_input(output, samples_input, require_grad)
        else:
            exp_output = torch.exp(output)
            grads = self._backprop_to_input(
                exp_output, samples_input, require_grad)
            # One normaliser per sample, broadcast over all input dimensions.
            normaliser = exp_output.sum(dim=1).view(
                [-1] + [1] * (samples_input.dim() - 1))
            grads = grads / normaliser

        return grads if require_grad else grads.detach()

    def shap_values(self, input_tensor, sparse_labels=None, require_grad=True, log_loss=False, taylor_sri=False, stability=True):
        """
        Calculate attributions for the sample ``input_tensor``.

        With ``scale_by_inputs=True`` an all-zero reference is used and the
        result is the mean over the k interpolation points of
        ``input * gradient``. Otherwise plain input gradients are returned.

        Args:
            input_tensor (torch.Tensor): Pytorch tensor of shape (batch, ...)
                representing the input to be explained.
            sparse_labels (optional, default=None): Tensor of shape (batch,)
                selecting the output to explain for each sample.
            require_grad (bool): If false the returned gradients are detached.
            log_loss, taylor_sri, stability: Forwarded to the plain-gradient
                computation; ignored when ``scale_by_inputs`` is true.
        Returns:
            Tensor with the shape of ``input_tensor``.
        """
        if not self.scale_by_inputs:
            return self._get_input_grads(
                input_tensor, sparse_labels, require_grad,
                log_loss, taylor_sri, stability)

        reference_tensor = self._expand_reference(
            torch.zeros_like(input_tensor), input_tensor)
        samples_input = self._get_samples_input(input_tensor, reference_tensor)
        samples_delta = self._get_samples_delta(input_tensor, reference_tensor)
        grad_tensor = self._get_grads(samples_input, sparse_labels, require_grad)
        return (samples_delta * grad_tensor).mean(1)

    def hessian_values(self, input_tensor, sparse_labels=None):
        """
        Calculate a second-order term of the negative log-likelihood loss.

        The first-order input gradients of the summed per-sample loss are
        summed and differentiated again, which yields the Hessian of the
        summed loss multiplied by a vector of ones (not the full Hessian).

        Args:
            input_tensor (torch.Tensor): Input of shape (batch, ...). Gradient
                tracking is enabled on it in place.
            sparse_labels: Class index per sample. Although it defaults to
                None, the loss cannot be computed without it.
        Returns:
            Tensor with the shape of ``input_tensor``.
        """
        input_tensor.requires_grad_(True)
        logits = self.model(input_tensor)
        log_probs = torch.nn.functional.log_softmax(logits, 1)
        loss = torch.nn.functional.nll_loss(
            log_probs, sparse_labels, reduction='none')
        # create_graph keeps the first-order gradients differentiable.
        gradients = grad(
            outputs=loss.sum(), inputs=input_tensor, create_graph=True)[0]
        hessian_grads = grad(
            outputs=gradients.sum(), inputs=input_tensor, create_graph=True)

        return hessian_grads[0]

    def shap_values_diff(self, input_tensor, ref_tensor, sparse_labels=None):
        """
        Calculate attributions for ``input_tensor`` relative to ``ref_tensor``.

        Gradients are taken at k points interpolated between each sample's
        reference and the sample itself, multiplied by ``input - reference``
        when ``scale_by_inputs`` is true, and averaged over the k points.

        Args:
            input_tensor (torch.Tensor): Pytorch tensor of shape (batch, ...)
                representing the input to be explained.
            ref_tensor (torch.Tensor): Reference for every input; must be
                broadcastable to the shape of ``input_tensor``.
            sparse_labels (optional, default=None): Tensor of shape (batch,)
                selecting the output to explain for each sample.
        Returns:
            Tensor with the shape of ``input_tensor``.
        """
        reference_tensor = self._expand_reference(ref_tensor, input_tensor)

        samples_input = self._get_samples_input(input_tensor, reference_tensor)
        samples_delta = self._get_samples_delta(input_tensor, reference_tensor)
        grad_tensor = self._get_grads(samples_input, sparse_labels)
        mult_grads = samples_delta * grad_tensor if self.scale_by_inputs else grad_tensor
        return mult_grads.mean(1)
