import copy
from os.path import join

import matplotlib.pyplot as plt
import numpy as np
import torch
from einops import reduce
from torch import nn
from torchvision.transforms.functional import gaussian_blur
from tqdm import tqdm


class ParallelBackprop():
    '''Computes parallel backpropagation.

    For one readout neuron, the input-Jacobian of every retained latent feature
    is rectified, summed over colour channels, blurred and L2-normalised per
    feature. The maps are then combined with the weights
    omega = (a_0 * w / ||a_0 * w||) * (a_1 * w / ||a_1 * w||). Here a_0 and a_1
    are the latent features of the first two images, and w are the rectified
    readout weights of the neuron.
    '''
    def __init__(self, model, device):
        '''
        Args:
            model: model of class GaussianReadoutModel
            device: device for calculations, e.g. 'cpu' or 'cuda'
        '''
        self.model = model
        self.device = device

    @staticmethod
    def _stack_images(imgs, device):
        '''Copy ``imgs`` into one (len(imgs), 3, H, W) float tensor on ``device``.

        Each image is broadcast into its slot, so (H, W), (1, H, W) and
        (3, H, W) inputs are all accepted. H and W are read separately, so
        non-square images work.
        '''
        height, width = imgs[0].shape[-2], imgs[0].shape[-1]
        batch = torch.zeros((len(imgs), 3, height, width))
        for i, img in enumerate(imgs):
            batch[i] = img
        return batch.to(device)

    @staticmethod
    def _feature_weights(model, img_tensor, neuron):
        '''Compute the feature weights (omega) for ``neuron`` and keep the relevant dims.

        Only dimensions whose weight exceeds 0.1 % of the maximum are kept. This
        sparsifies the later Jacobian computation. The first two images of the
        batch define the weights.

        Returns:
            grad_weights: 1-D tensor with the weights of the retained dimensions.
            jacobian_dims: 1-D long tensor with the indices of those dimensions
                (may be empty).
        '''
        with torch.no_grad():  # weights are constants; no graph is needed
            _, w, a = model(img_tensor, return_features=True)
            a = a.squeeze(dim=-1)[:, :, neuron]
            w = w.squeeze()[:, neuron]
            w = nn.functional.relu(w)  # Rectify w to exclude dimensions leading to less activity
            a0w = a[0, :] * w
            a1w = a[1, :] * w
            grad_weights = (a0w / torch.linalg.norm(a0w)) * (a1w / torch.linalg.norm(a1w))
            # argwhere on a 1-D tensor returns (k, 1); flatten so that k == 1 stays 1-D
            jacobian_dims = torch.argwhere(
                grad_weights > 0.001 * torch.max(grad_weights)).squeeze(-1)
        return grad_weights[jacobian_dims], jacobian_dims

    @staticmethod
    def _smooth_and_weight(jacob, grad_weights, kernel_size=41, sigma=5, scale=50):
        '''Blur and normalise per-feature maps (k, H, W), then return their weighted sum.

        Returns a numpy array of shape (H, W), multiplied by ``scale``. The
        scaling is arbitrary and only serves plotting.
        '''
        jacob = gaussian_blur(jacob, kernel_size=kernel_size, sigma=sigma)
        jacob = jacob / (torch.linalg.norm(jacob, dim=(1, 2), keepdim=True) + 1e-5)
        grad = (jacob * grad_weights[:, None, None]).sum(0)
        return scale * grad.detach().cpu().numpy()

    def joint_gradient(self, imgs, neuron, return_activities=False, *,
                       kernel_size=41, sigma=5, scale=50):
        '''
        Main function. Computes parallel backpropagation.

        Args:
            imgs: list containing images as torch.tensor(). The feature weights
                are computed from the first two images, so at least two are needed.
            neuron: neuron idx w.r.t. the matrix of readout weights of the model
            return_activities: if True, also return the model output for
                ``neuron`` on each image (as CPU tensors).
            kernel_size: size of the gaussian smoothing kernel.
            sigma: standard deviation of the gaussian smoothing kernel.
            scale: arbitrary multiplier applied to every map (for plotting).

        Returns: list containing the gradients (numpy arrays of shape (H, W)) in
            the same order as the img input. If ``return_activities`` is True,
            returns the tuple ``(grads, activities)`` instead.
        '''
        img_tensor = self._stack_images(imgs, self.device)
        height, width = img_tensor.shape[-2], img_tensor.shape[-1]

        # Work on a copy so repeated forward passes never touch the caller's model.
        model = copy.deepcopy(self.model)

        # The weights do not depend on the image index, so compute them once.
        grad_weights, jacobian_dims = self._feature_weights(model, img_tensor, neuron)

        grads = []
        for i in range(len(imgs)):
            if jacobian_dims.numel() == 0:
                # No feature contributes: nothing to backprop, and jacobian()
                # would fail on an empty output.
                grads.append(np.zeros((height, width), dtype=np.float32))
                continue

            def func(x, i=i):
                # forward pass restricted to the retained features of image i
                _, _, feats = model(x, return_features=True)
                # NOTE(review): squeeze only the trailing dim, matching how the
                # weights are computed; assumes features are (batch, feat, neuron, 1).
                return feats.squeeze(dim=-1)[i, :, neuron][jacobian_dims]

            # jacobian() differentiates a detached copy of its input, so no .grad
            # fields are populated and nothing needs zeroing afterwards.
            jacob = torch.autograd.functional.jacobian(func, img_tensor, vectorize=False)
            # (k, batch, 3, H, W) -> keep image i, rectify, sum over channels -> (k, H, W)
            jacob = nn.functional.relu(jacob[:, i]).sum(1)

            grads.append(self._smooth_and_weight(
                jacob, grad_weights, kernel_size=kernel_size, sigma=sigma, scale=scale))

        if not return_activities:
            return grads
        with torch.no_grad():
            y = model(img_tensor)[:, neuron]
        activities = [y[i].detach().cpu() for i in range(len(imgs))]
        return grads, activities

    def show_image(self, imgs, grads, activities=None, print_activities=True, save_path=None):
        '''Overlay each gradient map on its image, one plot per image.

        Args:
            imgs: images accepted by ``plt.imshow``.
            grads: gradient maps, in the same order as ``imgs``.
            activities: optional scalar tensors (e.g. from
                ``joint_gradient(..., return_activities=True)``).
            print_activities: title each plot with its activity (when available).
            save_path: directory to save plots to (file name = image index). If
                None, each plot is shown instead.
        '''
        from itertools import repeat

        if activities is None:
            activities = repeat(None)  # previously zip(..., None) raised TypeError
        for i, (img, grad, activ) in enumerate(zip(imgs, grads, activities)):
            plt.imshow(img, 'gray', interpolation='none')
            # matplotlib requires alpha values inside [0, 1]
            plt.imshow(grad, cmap='hot', interpolation='nearest',
                       alpha=np.clip(grad + 1e-5, 0, 0.9))
            if print_activities and activ is not None:
                plt.title(str(activ.item()))
            if save_path is None:
                plt.show()
            else:
                plt.savefig(join(save_path, str(i)))
            plt.cla()

    def show_side_by_side(self, imgs, grads, save_path=None):
        '''
        Plots images and gradients as seen in the paper.
        Args:
            imgs: list containing imgs (concatenated horizontally, so they must
                share their height)
            grads: list containing gradients, same order and sizes as ``imgs``
            save_path: file path to save the figure to (300 dpi, tight bounding
                box). If None, the figure is shown instead.
        '''
        img = np.concatenate(imgs, axis=1)
        grad = np.concatenate(grads, axis=1)

        plt.imshow(img, 'gray', interpolation='none')
        # matplotlib requires alpha values inside [0, 1]
        plt.imshow(grad, cmap='hot', interpolation='nearest', alpha=np.clip(grad, 0, 1),
                   vmin=0, vmax=1)
        plt.axis('off')
        if save_path is None:
            plt.show()
        else:
            plt.savefig(save_path, dpi=300, bbox_inches='tight', pad_inches=0)
        plt.cla()
        return None


class IntegratedParallelBackprop(ParallelBackprop):
    '''Computes parallel backprop using integrated gradients'''
    def integrated_gradient(self, imgs, neuron, n_steps=50, return_activities=False, *,
                            kernel_size=41, sigma=5, scale=50):
        '''
        Parallel backprop in which each Jacobian is averaged along a straight
        path from an all-zero input to the images.

        The whole batch is scaled by ``step / n_steps`` for
        ``step = 0 .. n_steps - 1``. This is a left Riemann sum, so the unscaled
        images themselves are not evaluated. The rectified, channel-summed
        Jacobians are then averaged over the steps.

        Args:
            imgs: list containing images as torch.tensor(). The feature weights
                are computed from the first two images, so at least two are needed.
            neuron: neuron idx w.r.t. the matrix of readout weights of the model
            n_steps: Number of steps used for approximating the path integral (>= 1)
            return_activities: if True, also return the model output for
                ``neuron`` on each image (as CPU tensors).
            kernel_size: size of the gaussian smoothing kernel.
            sigma: standard deviation of the gaussian smoothing kernel.
            scale: arbitrary multiplier applied to every map (for plotting).

        Returns: list containing the gradients (numpy arrays of shape (H, W)) in
            the same order as the img input. If ``return_activities`` is True,
            returns the tuple ``(grads, activities)`` instead.

        Raises:
            ValueError: if ``n_steps`` is smaller than 1.
        '''
        if n_steps < 1:
            raise ValueError(f'n_steps must be >= 1, got {n_steps}')

        # Stack images into (B, 3, H, W); H and W are read separately so
        # non-square images work, and broadcasting accepts (H, W) inputs.
        height, width = imgs[0].shape[-2], imgs[0].shape[-1]
        all_imgs = torch.zeros((len(imgs), 3, height, width))
        for i, img in enumerate(imgs):
            all_imgs[i] = img
        img_tensor = all_imgs.to(self.device)

        # Work on a copy so repeated forward passes never touch the caller's model.
        model = copy.deepcopy(self.model)

        # Feature weights (omega) do not depend on the image index: compute once.
        with torch.no_grad():
            _, w, a = model(img_tensor, return_features=True)
            # Select proper dimensions
            a = a.squeeze(dim=-1)[:, :, neuron]
            w = w.squeeze()[:, neuron]
            w = nn.functional.relu(w)  # Rectify w to exclude dimensions leading to less activity
            a0w = a[0, :] * w
            a1w = a[1, :] * w
            grad_weights = (a0w / torch.linalg.norm(a0w)) * (a1w / torch.linalg.norm(a1w))
            # Keep dims above 0.1 % of the max; flatten the (k, 1) argwhere
            # result so that k == 1 stays 1-D.
            jacobian_dims = torch.argwhere(
                grad_weights > 0.001 * torch.max(grad_weights)).squeeze(-1)
            grad_weights = grad_weights[jacobian_dims]

        grads = []
        # Loop over images
        for i in range(len(imgs)):
            if jacobian_dims.numel() == 0:
                # No feature contributes: nothing to backprop, and jacobian()
                # would fail on an empty output.
                grads.append(np.zeros((height, width), dtype=np.float32))
                continue

            def func(x, i=i):
                # forward pass restricted to the retained features of image i
                _, _, feats = model(x, return_features=True)
                # NOTE(review): squeeze only the trailing dim, matching how the
                # weights are computed; assumes features are (batch, feat, neuron, 1).
                return feats.squeeze(dim=-1)[i, :, neuron][jacobian_dims]

            # Running sum instead of storing all n_steps Jacobians in memory.
            jacob_sum = None
            for step in tqdm(range(n_steps)):
                int_tensor = (step / n_steps) * img_tensor
                # jacobian() differentiates a detached copy of its input, so no
                # .grad fields are populated and nothing needs zeroing.
                jacob = torch.autograd.functional.jacobian(func, int_tensor, vectorize=False)
                # (k, batch, 3, H, W) -> keep image i, rectify, sum channels -> (k, H, W)
                jacob = nn.functional.relu(jacob[:, i])
                jacob = reduce(jacob, 'feats c h w -> feats h w', 'sum')
                jacob_sum = jacob if jacob_sum is None else jacob_sum + jacob
            jacob = jacob_sum / n_steps

            # Spatially smooth, normalise each feature map, and combine with omega.
            jacob = gaussian_blur(jacob, kernel_size=kernel_size, sigma=sigma)
            jacob = jacob / (torch.linalg.norm(jacob, dim=(1, 2), keepdim=True) + 1e-5)
            grad = (jacob * grad_weights[:, None, None]).sum(0).detach().cpu().numpy()
            # Arbitrary scaling for plotting purposes
            grads.append(scale * grad)

        if not return_activities:
            return grads
        with torch.no_grad():
            y = model(img_tensor)[:, neuron]
        activities = [y[i].detach().cpu() for i in range(len(imgs))]
        return grads, activities