from functools import partial
from typing import Dict, List
from warnings import warn

import torch
import torch.nn as nn
from torch import Tensor
from torch_geometric.data import Data


class CKA:
    """Centered Kernel Alignment (CKA) metric,
    where the features of two networks are compared layer by layer.

    Forward hooks are registered on the selected layers of both models.
    Every hooked layer whose output is a 2-D tensor contributes one row
    (model 1) or one column (model 2) to the CKA matrix.

    Parameters
    ----------
    model1 : nn.Module
        model 1
    model2 : nn.Module
        model 2
    model1_name : str, optional
        name of model 1, by default None
    model2_name : str, optional
        name of model 2, by default None
    model1_layers : List[str], optional
        List of layers to extract features from, by default None
        (hook every module of the model).
    model2_layers : List[str], optional
        List of layers to extract features from, by default None
        (hook every module of the model).
    training : bool, optional
        whether to set training mode (True) or evaluation
        mode (False) for models. by default False.
    device : str, optional
        device to run the models, by default 'cpu'

    Example
    -------
    .. code-block:: python

        data = ... # get your graph
        m1 = ... # get your model1
        m2 = ... # get your model2
        cka = CKA(m1, m2)
        cka.compare(data)
        cka.plot_results()

    Reference:

    * Paper: https://arxiv.org/abs/2010.15327
    * Code: https://github.com/AntixK/PyTorch-Model-Compare
    """
    def __init__(self, model1: nn.Module, model2: nn.Module,
                 model1_name: str = None, model2_name: str = None,
                 model1_layers: List[str] = None,
                 model2_layers: List[str] = None, training: bool = False,
                 device: str = 'cpu'):
        self.model1 = model1
        self.model2 = model2

        self.device = torch.device(device)

        self.model1_info = {}
        self.model2_info = {}

        # Fall back to the leading token of the module repr when no
        # explicit name is given.
        if model1_name is None:
            self.model1_info['Name'] = repr(model1).split('(')[0]
        else:
            self.model1_info['Name'] = model1_name

        if model2_name is None:
            self.model2_info['Name'] = repr(model2).split('(')[0]
        else:
            self.model2_info['Name'] = model2_name

        if self.model1_info['Name'] == self.model2_info['Name']:
            warn("Both model have identical names - "
                 f"{self.model2_info['Name']}. "
                 "It may cause confusion when interpreting the results. "
                 "Consider giving unique names to the models :)")

        self.model1_info['Layers'] = []
        self.model2_info['Layers'] = []

        self.model1_features = {}
        self.model2_features = {}

        self.model1_layers = model1_layers
        self.model2_layers = model2_layers

        self._insert_hooks()
        self.model1 = self.model1.to(self.device)
        self.model2 = self.model2.to(self.device)

        self.model1.train(training)
        self.model2.train(training)

    def _log_layer(self, model: str, name: str, layer: nn.Module, inp: Tensor,
                   out: Tensor):
        """Forward hook that stores the output of layer ``name``.

        Only 2-D tensor outputs are recorded. Anything else (tensors of
        another rank, tuples, ``None``, ...) is ignored so that hooking
        such a layer does not break the forward pass.

        Raises
        ------
        RuntimeError
            If ``model`` is neither ``"model1"`` nor ``"model2"``.
        """
        if not isinstance(out, Tensor) or out.ndim != 2:
            # ignore outputs that are not 2-D tensors
            return

        if model == "model1":
            self.model1_features[name] = out
        elif model == "model2":
            self.model2_features[name] = out
        else:
            raise RuntimeError(f"Unknown model name `{model}`.")

    def _insert_hooks(self):
        """Register the feature-logging forward hooks on both models."""
        self._register_hooks(self.model1, "model1", self.model1_layers,
                             self.model1_info)
        self._register_hooks(self.model2, "model2", self.model2_layers,
                             self.model2_info)

    def _register_hooks(self, model: nn.Module, tag: str, wanted, info: dict):
        """Hook the modules of ``model`` selected by ``wanted``.

        ``wanted`` is a collection of module names, or ``None`` to hook
        every module. Hooked names are appended to ``info['Layers']``.
        A warning is issued for requested names that do not exist in
        ``model``.
        """
        selected = None if wanted is None else set(wanted)
        for name, layer in model.named_modules():
            if selected is not None and name not in selected:
                continue
            info['Layers'].append(name)
            layer.register_forward_hook(partial(self._log_layer, tag, name))

        if selected is not None:
            missing = sorted(selected.difference(info['Layers']))
            if missing:
                warn(f"Layers {missing} were not found in {tag} "
                     "and will be ignored.")

    @staticmethod
    def _gram(feat: Tensor) -> Tensor:
        """Return the linear Gram matrix of ``feat`` with a zeroed diagonal."""
        flat = feat.flatten(1)
        gram = flat @ flat.t()
        gram.fill_diagonal_(0.0)
        return gram

    def _HSIC(self, K, L):
        """Computes the unbiased estimate of HSIC metric.
        Reference: https://arxiv.org/pdf/2010.15327.pdf Eq (3)

        Parameters
        ----------
        K, L : Tensor
            Square Gram matrices of identical shape whose diagonals have
            already been set to zero.

        Returns
        -------
        float
            The HSIC estimate.

        Raises
        ------
        ValueError
            If fewer than 4 samples are given; the estimator divides by
            ``(N - 2)`` and ``N * (N - 3)``.
        """
        N = K.shape[0]
        if N < 4:
            raise ValueError(
                "The unbiased HSIC estimator requires at least 4 samples, "
                f"got {N}.")

        # trace(K @ L) == sum_ij K_ij * L_ji; the element-wise form avoids
        # an O(N^3) matrix product.
        trace_kl = (K * L.t()).sum()
        # (1^T K 1) * (1^T L 1)
        total = K.sum() * L.sum()
        # 1^T K L 1 == (column sums of K) . (row sums of L)
        cross = (K.sum(dim=0) * L.sum(dim=1)).sum()

        result = trace_kl + total / ((N - 1) * (N - 2))
        result = result - cross * 2 / (N - 2)
        return (result / (N * (N - 3))).item()

    @torch.no_grad()
    def compare(self, data1: Data, data2: Data = None) -> 'CKA':
        """
        Computes the feature similarity between the models on the
        given datasets.

        The result is stored in ``self.hsic_matrix`` with one row per
        captured layer of model 1 and one column per captured layer of
        model 2. ``model*_info['Layers']`` is updated to the captured
        layer names in exactly that row/column order.

        Parameters
        ----------
        data1 : Data
            the dataset where model 1 run on.
        data2 : Data, optional
            If given, model 2 will run on this dataset. by default None

        Returns
        -------
        CKA
            ``self``, to allow call chaining.

        Raises
        ------
        RuntimeError
            If the features of the two models have a different number
            of rows.
        """
        data1 = data1.to(self.device)
        if data2 is None:
            warn("Data for Model 2 is not given. "
                 "Using the same data for both models.")
            data2 = data1
        else:
            data2 = data2.to(self.device)

        self.model1_features = {}
        self.model2_features = {}

        self.model1(data1.x, data1.edge_index, data1.edge_weight)
        self.model2(data2.x, data2.edge_index, data2.edge_weight)

        # Size the matrix from what the hooks really captured: requested
        # layers may be missing or may not produce 2-D outputs, and the
        # hooks fire in forward order rather than `named_modules()` order.
        self.model1_info['Layers'] = list(self.model1_features)
        self.model2_info['Layers'] = list(self.model2_features)
        N = len(self.model1_features)
        M = len(self.model2_features)
        num_batches = 1

        hsic = torch.zeros(N, M, 3)

        for i, feat1 in enumerate(self.model1_features.values()):
            K = self._gram(feat1)
            hsic[i, :, 0] += self._HSIC(K, K) / num_batches

            for j, feat2 in enumerate(self.model2_features.values()):
                L = self._gram(feat2)
                if K.shape != L.shape:
                    raise RuntimeError(
                        f"Feature shape mistach! {K.shape} and {L.shape}")

                hsic[i, j, 1] += self._HSIC(K, L) / num_batches
                if i == 0:
                    # HSIC(L, L) depends on the column only: compute it
                    # once and share it across all rows.
                    hsic[:, j, 2] += self._HSIC(L, L) / num_batches

        self.hsic_matrix = hsic[:, :, 1] / (hsic[:, :, 0].sqrt() *
                                            hsic[:, :, 2].sqrt())

        assert not torch.isnan(
            self.hsic_matrix).any(), "HSIC computation resulted in NANs"
        return self

    def export(self) -> Dict:
        """
        Exports the CKA data along with the respective model layer names.

        Returns
        -------
        Dict
            Dictionary with the keys ``model1_name``, ``model2_name``,
            ``CKA``, ``model1_layers`` and ``model2_layers``.
            Requires :meth:`compare` to have been called first.
        """
        return {
            "model1_name": self.model1_info['Name'],
            "model2_name": self.model2_info['Name'],
            "CKA": self.hsic_matrix,
            "model1_layers": self.model1_info['Layers'],
            "model2_layers": self.model2_info['Layers'],
        }

    def plot_results(self, save_path: str = None, title: str = None):
        """Plot the CKA matrix as a heat map.

        Requires :meth:`compare` to have been called first.

        Parameters
        ----------
        save_path : str, optional
            If given, the figure is also saved to this path,
            by default None
        title : str, optional
            Title of the figure. by default None, in which case
            ``"<model1 name> vs <model2 name>"`` is used.
        """
        import matplotlib.pyplot as plt
        fig, ax = plt.subplots()
        im = ax.imshow(self.hsic_matrix, origin='lower', cmap='magma')
        ax.set_xlabel(f"Layers of {self.model2_info['Name']}", fontsize=15)
        ax.set_ylabel(f"Layers of {self.model1_info['Name']}", fontsize=15)

        if title is not None:
            ax.set_title(f"{title}", fontsize=18)
        else:
            ax.set_title(
                f"{self.model1_info['Name']} vs {self.model2_info['Name']}",
                fontsize=18)

        add_colorbar(im)
        plt.tight_layout()

        if save_path is not None:
            plt.savefig(save_path, dpi=300)

        plt.show()


def add_colorbar(im, aspect=10, pad_fraction=0.5, **kwargs):
    """Attach a vertical color bar to the right of an image plot.

    Parameters
    ----------
    im :
        The image whose ``axes`` receive the color bar.
    aspect : float, optional
        The bar width is sized as ``1 / aspect`` relative to the height
        of the image axes, by default 10
    pad_fraction : float, optional
        Padding between image and bar, as a fraction of the bar width,
        by default 0.5
    **kwargs :
        Forwarded to ``Figure.colorbar``.

    Returns
    -------
    The object returned by ``Figure.colorbar``.
    """
    import matplotlib.pyplot as plt
    from mpl_toolkits import axes_grid1

    host_axes = im.axes
    locatable = axes_grid1.make_axes_locatable(host_axes)
    bar_width = axes_grid1.axes_size.AxesY(host_axes, aspect=1. / aspect)
    bar_pad = axes_grid1.axes_size.Fraction(pad_fraction, bar_width)

    # Remember the current axes so it can be made current again after
    # the color-bar axes has been appended.
    previous_axes = plt.gca()
    bar_axes = locatable.append_axes("right", size=bar_width, pad=bar_pad)
    plt.sca(previous_axes)

    return host_axes.figure.colorbar(im, cax=bar_axes, **kwargs)
