"""
Modules for labeling edges.
"""
import math
from typing import Sequence, Union

import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import orthogonal
from gauche.kernels.fingerprint_kernels.tanimoto_kernel import TanimotoKernel
from gpytorch.kernels import RQKernel

from krt.models.mlp import MLP


###########################################################################
###########################################################################
#                        Edge Types Over X Embeddings                     #
###########################################################################
###########################################################################

class L2(nn.Module):
    """Pairwise Euclidean distance between every pair of embedded points."""

    def forward(self, x_embed: torch.Tensor) -> torch.Tensor:
        """Return the Euclidean distance matrix of the embeddings with themselves.

        Args:
            x_embed: Embedded x inputs of shape (..., N, d), e.g.
                (B, nhead, L_C + L_T, D // nhead).

        Returns: Distance matrix of shape (..., N, N), e.g.
            (B, nhead, L_C + L_T, L_C + L_T).
        """
        points = x_embed
        return torch.cdist(points, points, p=2.0)


class RBF(nn.Module):
    """Unit-lengthscale squared-exponential similarity between embedded points."""

    def forward(self, x_embed: torch.Tensor) -> torch.Tensor:
        """Return exp(-||a - b||^2 / 2) for every pair of embedded points.

        Args:
            x_embed: Embedded x inputs of shape (..., N, d), e.g.
                (B, nhead, L_C + L_T, D // nhead).

        Returns: Similarity matrix of shape (..., N, N), e.g.
            (B, nhead, L_C + L_T, L_C + L_T).
        """
        squared_dist = torch.cdist(x_embed, x_embed).pow(2)
        return torch.exp(-0.5 * squared_dist)


class Matern(nn.Module):
    """Unit-lengthscale Matern-5/2 similarity between embedded points."""

    def forward(self, x_embed: torch.Tensor) -> torch.Tensor:
        """Return (1 + sqrt(5) d + 5 d^2 / 3) * exp(-sqrt(5) d) for every pair.

        Here d is the Euclidean distance between two embedded points.

        Args:
            x_embed: Embedded x inputs of shape (..., N, d), e.g.
                (B, nhead, L_C + L_T, D // nhead).

        Returns: Similarity matrix of shape (..., N, N), e.g.
            (B, nhead, L_C + L_T, L_C + L_T).
        """
        root5 = math.sqrt(5.0)
        distance = torch.cdist(x_embed, x_embed)
        polynomial = 1 + root5 * distance + 5.0 * distance.pow(2) / 3.0
        decay = torch.exp(-root5 * distance)
        return polynomial * decay


class Periodic(nn.Module):
    """Unit-period, unit-lengthscale periodic similarity between embedded points."""

    def forward(self, x_embed: torch.Tensor) -> torch.Tensor:
        """Return exp(-2 sin^2(pi * d)) for every pair of embedded points.

        Here d is the Euclidean distance between two embedded points.

        Args:
            x_embed: Embedded x inputs of shape (..., N, d), e.g.
                (B, nhead, L_C + L_T, D // nhead).

        Returns: Similarity matrix of shape (..., N, N), e.g.
            (B, nhead, L_C + L_T, L_C + L_T).
        """
        distance = torch.cdist(x_embed, x_embed).abs()
        sine_sq = torch.sin(math.pi * distance).pow(2)
        return torch.exp(-2.0 * sine_sq)

class Tanimoto(nn.Module):
    """Tanimoto similarity between embedded points via gauche's TanimotoKernel."""

    def __init__(self):
        super().__init__()
        self.kernel = TanimotoKernel()

    def forward(self, x_embed: torch.Tensor) -> torch.Tensor:
        """Return the dense Tanimoto kernel matrix of the embeddings with themselves.

        Args:
            x_embed: Embedded x inputs of shape (..., N, d), e.g.
                (B, nhead, L_C + L_T, D // nhead).

        Returns: Similarity matrix of shape (..., N, N), e.g.
            (B, nhead, L_C + L_T, L_C + L_T).
        """
        # The kernel call returns a lazily-evaluated kernel object; densify it.
        # NOTE(review): newer gpytorch/linear_operator versions deprecate
        # `.evaluate()` in favour of `.to_dense()` — confirm the pinned version.
        lazy_gram = self.kernel(x_embed, x_embed)
        return lazy_gram.evaluate()

class RQ(nn.Module):
    """Rational-quadratic similarity between embedded points via gpytorch's RQKernel."""

    def __init__(self):
        super().__init__()
        self.kernel = RQKernel()

    def forward(self, x_embed: torch.Tensor) -> torch.Tensor:
        """Return the dense RQ kernel matrix of the embeddings with themselves.

        Args:
            x_embed: Embedded x inputs of shape (..., N, d), e.g.
                (B, nhead, L_C + L_T, D // nhead).

        Returns: Similarity matrix of shape (..., N, N), e.g.
            (B, nhead, L_C + L_T, L_C + L_T).
        """
        # The kernel call returns a lazily-evaluated kernel object; densify it.
        # NOTE(review): newer gpytorch/linear_operator versions deprecate
        # `.evaluate()` in favour of `.to_dense()` — confirm the pinned version.
        lazy_gram = self.kernel(x_embed, x_embed)
        return lazy_gram.evaluate()

###########################################################################
###########################################################################
#                        Modules for Embedding X                          #
###########################################################################
###########################################################################

class IdentityXEmbedder(nn.Module):
    """Rescale x with a separate learnable per-dimension scaling for each head."""

    def __init__(
        self,
        dim_x: int,
        d_model: int,
        nhead: int,
    ):
        """Constructor.

        Args:
            dim_x: x dimension.
            d_model: The encoding dimension. Stored as an attribute but not used
                by the embedding itself.
            nhead: how many heads.
        """
        super().__init__()
        self.dim_x = dim_x
        self.d_model = d_model
        self.nhead = nhead
        self.d_head = d_model // nhead
        # One scaling vector per head, initialised around 1 (std 0.1) so every
        # head starts close to the raw x while the heads still differ.
        self.x_scalings = nn.Parameter(torch.empty(nhead, dim_x))
        nn.init.normal_(self.x_scalings, mean=1.0, std=0.1)

    def forward(self, xpts: torch.Tensor) -> torch.Tensor:
        """Embed x.

        Args:
            xpts: The x points (B, L, x dim). Any number of leading dims is
                accepted; the last dim must equal dim_x.

        Returns: (B, L, nhead, x dim), where entry [..., h, :] is xpts scaled
            elementwise by x_scalings[h].
        """
        # (..., 1, x dim) * (nhead, x dim) broadcasts to (..., nhead, x dim).
        return xpts.unsqueeze(-2) * self.x_scalings


class OrthoXEmbedder(nn.Module):
    """Project x separately for each head with an orthogonally parametrised map."""

    def __init__(
        self,
        dim_x: int,
        d_model: int,
        nhead: int,
    ):
        """Constructor.

        Args:
            dim_x: x dimension.
            d_model: The encoding dimension. Each head gets d_model // nhead
                features.
            nhead: how many heads.
        """
        super().__init__()
        self.d_model = d_model
        self.nhead = nhead
        self.d_head = d_model // nhead
        # Bias-free linear maps whose weights are constrained to be
        # (semi-)orthogonal; when d_head >= dim_x each map preserves distances
        # and angles of x.
        self.x_embedder = nn.ModuleList([
            orthogonal(nn.Linear(dim_x, self.d_head, bias=False))
            for _ in range(nhead)
        ])
        # Learnable per-head, per-feature scale applied after the projection,
        # initialised around 1 (std 0.1).
        self.embedding_scaling = nn.Parameter(torch.empty(nhead, self.d_head))
        nn.init.normal_(self.embedding_scaling, mean=1.0, std=0.1)

    def forward(self, xpts: torch.Tensor) -> torch.Tensor:
        """Embed x.

        Args:
            xpts: The x points (B, L, x dim).

        Returns: (B, L, nhead * (d_model // nhead)). This equals (B, L, d_model)
            whenever nhead divides d_model. Head h occupies the h-th contiguous
            chunk of d_model // nhead features.
        """
        B, L, _ = xpts.shape
        x_embed = torch.stack(
            [embd(xpts) for embd in self.x_embedder],
            dim=2,
        )  # (B, L, nhead, d_head)
        x_embed = x_embed * self.embedding_scaling.view(1, 1, self.nhead, self.d_head)
        # Flatten with nhead * d_head rather than d_model: the two differ when
        # d_model is not a multiple of nhead, and viewing to d_model would fail.
        return x_embed.reshape(B, L, self.nhead * self.d_head)

###########################################################################
###########################################################################
#                        Full Edge Encoder                                #
###########################################################################
###########################################################################


class EdgeModule(nn.Module):
    """Embed x per head, compare every pair of points, project to d_model features."""

    # Supported values of ``edge_type`` mapped to the comparison module to build.
    _COMPARISONS = {
        'L2': L2,
        'RBF': RBF,
        'Matern': Matern,
        'Periodic': Periodic,
        'Tanimoto': Tanimoto,
        'RQ': RQ,
    }

    def __init__(
        self,
        dim_x: int,
        d_model: int,
        nhead: int,
        dim_feedforward: int,
        x_embed_type: str = 'identity',
        x_embed_depth: int = 4,
        edge_type: str = 'RBF',
        hidden_activation: str = 'relu',
    ):
        """Constructor.

        Args:
            dim_x: x dimension.
            d_model: The encoding dimension.
            nhead: how many heads.
            dim_feedforward: Hidden layer width of the x embedder. Only used if an
                MLP is used.
            x_embed_type: How to embed x. Options:
                * nolengthscale: Compare the raw x directly; the single comparison
                  is repeated for every head.
                * identity: The regular x with different lengthscales for each head.
                * ortho: This will project x using an orthogonal matrix to preserve
                         distances and angles.
                * mlp: Use a fully connected network with multiple layers.
            x_embed_depth: The depth of the x embedder. Only used if an MLP is used.
            edge_type: Type of edge distance/similarity. Options:
                * L2
                * RBF
                * Matern
                * Periodic
                * Tanimoto
                * RQ
            hidden_activation: Hidden activation function to use.

        Raises:
            ValueError: If x_embed_type or edge_type is not one of the options.
        """
        super().__init__()
        self.nhead = nhead
        if x_embed_type == 'nolengthscale':
            self.x_embedder = None
        elif x_embed_type == 'identity':
            self.x_embedder = IdentityXEmbedder(
                dim_x=dim_x,
                d_model=d_model,
                nhead=nhead,
            )
        elif x_embed_type == 'ortho':
            self.x_embedder = OrthoXEmbedder(dim_x=dim_x, d_model=d_model, nhead=nhead)
        elif x_embed_type == 'mlp':
            self.x_embedder = MLP(
                input_dim=dim_x,
                output_dim=d_model,
                hidden_layer_width=dim_feedforward,
                hidden_layer_depth=x_embed_depth,
                hidden_activation=hidden_activation,
            )
        else:
            raise ValueError(f'Unknown X Embedder {x_embed_type}')
        # Fail fast on an unknown edge type instead of leaving `comparison`
        # unset and crashing later in forward().
        try:
            comparison_cls = self._COMPARISONS[edge_type]
        except KeyError:
            raise ValueError(
                f'Unknown edge type {edge_type}. '
                f'Options: {sorted(self._COMPARISONS)}'
            ) from None
        self.comparison = comparison_cls()
        self.out_proj = nn.Linear(nhead, d_model, bias=False)

    def forward(self, xpts: torch.Tensor):
        """Encode the edges of the graph.

        Args:
            xpts: The x points (B, L, x dim).

        Returns: Edge matrix (B, L, L, d_model)
        """
        B, L, _ = xpts.shape
        if self.x_embedder is not None:
            # (B, L, nhead, d_head) -> (B, nhead, L, d_head): one point set per head.
            x_embed = self.x_embedder(xpts).reshape(B, L, self.nhead, -1).transpose(1, 2)
            # (B, nhead, L, L) -> (B, L, L, nhead) so heads become the feature axis.
            # transpose(1, 3) also swaps the two point axes, which is harmless
            # for the symmetric self-comparisons used here.
            edges = self.comparison(x_embed).transpose(1, 3)
        else:
            # One comparison on the raw x, copied to every head.
            x_embed = xpts.view(B, 1, L, -1)
            edges = self.comparison(x_embed).transpose(1, 3).repeat(1, 1, 1, self.nhead)

        return self.out_proj(edges)


class EdgeEncoder(nn.Module):

    """
    Can possibly have multiple EdgeModules. If there are multiple, their encodings
    are averaged elementwise before being put through the output network.
    """

    def __init__(
        self,
        dim_x: int,
        d_model: int,
        nhead: int,
        out_net_depth: int,
        dim_feedforward: int,
        x_embed_types: Union[Sequence[str], str] = 'ortho',
        x_embed_depth: int = 4,
        edge_types: Union[Sequence[str], str] = 'RBF',
        hidden_activation: str = 'relu',
    ):
        """Constructor.

        Args:
            dim_x: x dimension.
            d_model: The encoding dimension.
            nhead: how many heads.
            out_net_depth: Hidden layer depth of the output MLP.
            dim_feedforward: Hidden layer width of the output MLP (also passed to
                each EdgeModule for its MLP x embedder).
            x_embed_types: How to embed x, one entry per EdgeModule. See
                EdgeModule for the supported options. A single string is reused
                for every entry of edge_types; two sequences must have the same
                length.
            x_embed_depth: The depth of the x embedder. Only used if an MLP is used.
            edge_types: Type of edge distance/similarity, one entry per
                EdgeModule. See EdgeModule for the supported options. A single
                string is reused for every entry of x_embed_types.
            hidden_activation: Hidden activation function to use.

        Raises:
            ValueError: If x_embed_types and edge_types are sequences of
                different lengths, or if no edge module would be built.
        """
        super().__init__()
        embed_list = [x_embed_types] if isinstance(x_embed_types, str) else list(x_embed_types)
        edge_list = [edge_types] if isinstance(edge_types, str) else list(edge_types)
        # Broadcast a single-string argument against a sequence on the other
        # side, rather than letting zip() silently drop entries.
        if isinstance(x_embed_types, str):
            embed_list = embed_list * len(edge_list)
        elif isinstance(edge_types, str):
            edge_list = edge_list * len(embed_list)
        if len(embed_list) != len(edge_list):
            raise ValueError(
                f'Got {len(embed_list)} x_embed_types but {len(edge_list)} '
                'edge_types; they must have the same length.'
            )
        if not edge_list:
            raise ValueError('At least one edge type is required.')
        self.edge_modules = nn.ModuleList(
            EdgeModule(
                dim_x=dim_x,
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                x_embed_type=xet,
                x_embed_depth=x_embed_depth,
                edge_type=et,
                hidden_activation=hidden_activation,
            )
            for xet, et in zip(embed_list, edge_list)
        )
        self.out_net = MLP(
            input_dim=d_model,
            output_dim=d_model,
            hidden_layer_width=dim_feedforward,
            hidden_layer_depth=out_net_depth,
            hidden_activation=hidden_activation,
        )

    def forward(self, xpts: torch.Tensor):
        """Encode the edges of the graph.

        Args:
            xpts: The x points (B, L, x dim).

        Returns: Edge matrix (B, L, L, D)
        """
        # Average the per-module edge encodings, then refine with the output MLP.
        edges = torch.stack([em(xpts) for em in self.edge_modules]).mean(dim=0)
        return self.out_net(edges)
