"""
Graph Edge Evolution Neural Process.

Structure of code inspired by:
    https://github.com/karpathy/nanoGPT
"""
import math
from typing import Optional

from attrdict import AttrDict
import torch
from torch.distributions.normal import Normal
import torch.nn as nn

from krt.models.mlp import MLP
from krt.models.edges import EdgeModule


class RowAttention(nn.Module):
    """Multi-head self-attention that mixes the rows of a (B, Nr, Nc, D) matrix.

    Every column produces its own row-to-row attention logits; these logits are
    averaged over the columns before the softmax, so each head uses a single
    set of row-mixing weights that is shared by all columns.

    Args:
        d_model: Feature width D of each matrix entry.
        nhead: Number of attention heads; must be positive and divide d_model.

    Raises:
        ValueError: If nhead is not positive or does not divide d_model.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
    ):
        super().__init__()
        # Validate up front: otherwise a bad split only surfaces later as an
        # opaque shape error inside ``forward``'s ``view``.
        if nhead <= 0 or d_model % nhead != 0:
            raise ValueError(
                f'd_model ({d_model}) must be divisible by a positive nhead ({nhead}).'
            )
        self.d_model = d_model
        self.nhead = nhead
        self.d_head = d_model // self.nhead
        self.c_attn = nn.Linear(self.d_model, 3 * self.d_model)
        self.c_proj = nn.Linear(self.d_model, self.d_model)

    def forward(self, x):
        """Compute row-attention residuals.

        The idea is to take linear combinations of the rows of K, similar to
        Gaussian elimination, which would let the network learn operations
        akin to matrix inversion (the core of classical GP inference). Row
        attention is one way to do this, but the row-attention weights must
        also be invariant to permutations of the columns of the original K
        matrix: self-attention applied independently per column only handles
        the rows. Averaging the per-column attention logits over the columns
        before the softmax supplies the column invariance.

        Args:
            x: Current matrix of shape (B, Nr, Nc, D).

        Returns: Residuals to apply, of shape (B, Nr, Nc, D).
        """
        B, Nr, Nc, _ = x.shape
        q, k, v = self.c_attn(x).split(self.d_model, dim=3)

        def _split_heads(t):
            # (B, Nr, Nc, D) -> (B, nh, Nc, Nr, d): each column is an
            # independent sequence of Nr rows for every head.
            return t.view(B, Nr, Nc, self.nhead, self.d_head).permute(0, 3, 2, 1, 4)

        q, k, v = _split_heads(q), _split_heads(k), _split_heads(v)
        # Per-column row-to-row logits, shape (B, nh, Nc, Nr, Nr).
        att = (q @ k.transpose(-1, -2)) * (1.0 / math.sqrt(self.d_head))
        # Average over columns to fuse information across columns while keeping
        # column-permutation invariance; keep a singleton column axis so the
        # same weights broadcast over every column: (B, nh, 1, Nr, Nr).
        att = torch.softmax(att.mean(dim=2), dim=-1).unsqueeze(2)
        # (B, nh, 1, Nr, Nr) @ (B, nh, d, Nr, Nc) -> (B, nh, d, Nr, Nc)
        out = att @ v.transpose(2, 4)
        # Merge heads back: (B, nh, d, Nr, Nc) -> (B, Nr, Nc, D)
        out = out.permute(0, 3, 4, 1, 2).reshape(B, Nr, Nc, self.d_model)
        return self.c_proj(out)


class GEENPBlock(nn.Module):
    """Pre-LayerNorm residual block over a (B, Nr, Nc, D) edge matrix.

    A ``RowAttention`` sub-layer is followed by a feed-forward MLP applied to
    every entry's feature vector; each sub-layer's output is added back onto
    its input. ``dropout`` is applied only at the end of the MLP.

    Args:
        d_model: Feature width D.
        nhead: Number of row-attention heads.
        dim_feedforward: Hidden width of the MLP (defaults to ``4 * d_model``).
        dropout: Dropout probability at the MLP output.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: Optional[int] = None,
        dropout: float = 0.0,
    ):
        super().__init__()
        hidden_width = 4 * d_model if dim_feedforward is None else dim_feedforward
        # Attribute names and creation order are kept as-is so state_dict keys
        # and parameter-initialisation order stay unchanged.
        self.ln_1 = nn.LayerNorm(d_model)
        self.ln_2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, hidden_width),
            nn.GELU(),
            nn.Linear(hidden_width, d_model),
            nn.Dropout(dropout),
        )
        self.attn = RowAttention(d_model, nhead)

    def forward(self, x):
        """Apply the attention and MLP residual updates.

        Args:
            x: Matrix of shape (B, Nr, Nc, D).

        Returns: Matrix of shape (B, Nr, Nc, D).
        """
        attn_update = self.attn(self.ln_1(x))
        h = x + attn_update
        mlp_update = self.mlp(self.ln_2(h))
        return h + mlp_update


class GEENP(nn.Module):
    """Graph Edge Evolution Neural Process.

    Every ordered pair of points in the concatenated (context + target) set is
    turned into an edge feature. Each edge feature is built from:

    - an ``EdgeModule`` embedding of the pair's first ``dim_x_to_kernelize``
      x-dimensions;
    - the observed y values of the pair's endpoints (zero for target
      endpoints);
    - two indicator flags marking which endpoints are context points.

    The resulting (B, L, L, D) edge matrix is refined by a stack of
    ``GEENPBlock``s. The diagonal entries of the target-target sub-matrix are
    then decoded into a Gaussian mean and std for each target.

    Args:
        dim_x: Dimension of the x inputs (stored on the module).
        dim_y: Dimension of the y outputs.
        d_model: Width of the edge features.
        nhead: Number of attention heads; must divide ``d_model``.
        num_blocks: Number of ``GEENPBlock``s.
        dim_x_to_kernelize: Number of leading x dimensions fed to ``EdgeModule``.
        embedder_depth: Passed as ``hidden_layer_depth`` to the edge-feature MLP.
        dim_feedforward: Hidden width of the block MLPs and the decoder
            (defaults to ``4 * d_model``).
        x_embed_type: Forwarded to ``EdgeModule``.
        edge_type: Forwarded to ``EdgeModule``.
        hidden_activation: Forwarded to ``EdgeModule`` and the edge-feature MLP.
        only_one_target: If True, ``forward`` only predicts the first target
            point in each batch.
    """

    def __init__(
        self,
        dim_x: int,
        dim_y: int,
        d_model: int,
        nhead: int,
        num_blocks: int,
        dim_x_to_kernelize: int,
        embedder_depth: int = 4,
        dim_feedforward: Optional[int] = None,
        x_embed_type: str = 'identity',
        edge_type: str = 'L2',
        hidden_activation: str = 'relu',
        only_one_target: bool = False,  # Only support predicting one target at a time.
    ):
        super().__init__()
        assert d_model % nhead == 0, 'd_model must be divisible by nhead'
        self.dim_x = dim_x
        self.dim_x_to_kernelize = dim_x_to_kernelize
        self.dim_y = dim_y
        self.d_model = d_model
        self.nhead = nhead
        self.d_head = d_model // nhead
        self.only_one_target = only_one_target
        if dim_feedforward is None:
            dim_feedforward = 4 * d_model
        # Sub-module creation order is kept unchanged so state_dict keys and
        # parameter initialisation order stay the same.
        self.edge_module = EdgeModule(
            dim_x=dim_x_to_kernelize,
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model,
            x_embed_type=x_embed_type,
            edge_type=edge_type,
            hidden_activation=hidden_activation,
        )
        # Input: x-edge embedding (d_model) + row/column y (2 * dim_y)
        # + row/column "is context point" flags (2).
        self.edge_mlp = MLP(
            input_dim=d_model + 2 * self.dim_y + 2,
            output_dim=d_model,
            hidden_layer_width=d_model,
            hidden_layer_depth=embedder_depth,
            hidden_activation=hidden_activation,
        )
        self.blocks = nn.ModuleList([GEENPBlock(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
        ) for _ in range(num_blocks)])
        # Outputs the mean and a log-std for each y dimension.
        self.decoder = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, 2 * dim_y)
        )
        self.device = 'cpu'

    def to(self, *args, **kwargs):
        """Move/cast the module and remember the target device.

        Accepts every ``nn.Module.to`` call form. ``self.device`` is updated only
        when a device is given: explicitly, positionally, or via a tensor.
        Dtype-only casts such as ``to(torch.float64)`` therefore no longer
        overwrite it with a dtype.
        """
        device = kwargs.get('device', args[0] if args else None)
        if isinstance(device, torch.Tensor):
            device = device.device
        if device is not None and not isinstance(device, torch.dtype):
            self.device = device
        return super().to(*args, **kwargs)

    def forward(
        self,
        batch: AttrDict,
    ):
        """Forward

        Args:
            batch: Batch containing:
                xt: Target x points with shape (batch, L_T, x dim).
                xc: Condition x points with shape (batch, L_C, x dim).
                yc: Condition y points with shape (batch, L_C, y dim)

        Returns: Output containing
            mean: Mean prediction w shape (batch, L_T, dim_y)
                (L_T is 1 when ``only_one_target`` is set).
            std: Standard deviation prediction w shape (batch, L_T, dim_y)
        """
        B, LC, _ = batch.yc.shape
        if self.only_one_target:
            LT = 1
            xt = batch.xt[:, :1]
        else:
            _, LT, _ = batch.xt.shape
            xt = batch.xt
        L = LC + LT
        # Prepare the edge matrix over all (context + target) points.
        x = torch.cat([batch.xc, xt], dim=1)
        x_edges = self.edge_module(x[:, :, :self.dim_x_to_kernelize])  # (B, L, L, D)
        # Allocate the auxiliary channels like x_edges (same device and dtype).
        # The cached ``self.device`` is stale after ``.cuda()``, and the default
        # dtype mismatches after ``.double()``/``.half()``.
        y_edges = x_edges.new_zeros(B, L, L, 2 * self.dim_y)
        # Channels [:dim_y] hold the row point's y, [dim_y:] the column point's
        # y; only context points (the first LC) have observed y.
        y_edges[:, :LC, :, :self.dim_y] = batch.yc.unsqueeze(2)
        y_edges[:, :, :LC, self.dim_y:] = batch.yc.unsqueeze(1)
        real_edge_indicator = x_edges.new_zeros(B, L, L, 2)
        real_edge_indicator[:, :LC, :, 0] = 1  # row endpoint is a context point
        real_edge_indicator[:, :, :LC, 1] = 1  # column endpoint is a context point
        edges = self.edge_mlp(torch.cat([
            x_edges,
            y_edges,
            real_edge_indicator,
        ], dim=-1))  # (B, L, L, D)
        # Pass through row attention blocks.
        for block in self.blocks:
            edges = block(edges)
        # Decode the diagonal (t, t) edges of the target-target sub-matrix:
        # (B, LT, LT, D) -> diagonal (B, D, LT) -> (B, LT, D).
        target_diag = edges[:, LC:, LC:].diagonal(dim1=1, dim2=2).transpose(-1, -2)
        mean, log_std = self.decoder(target_diag).split(self.dim_y, dim=-1)
        return AttrDict({
            'mean': mean,
            'std': log_std.exp(),
        })

    def loss(self, batch, model_out, reduce_ll=True):
        """Negative log likelihood of the targets under ``model_out``.

        Args:
            batch: Batch containing ``yt`` with shape (batch, L_T, dim_y).
            model_out: Output of ``forward`` (``mean`` and ``std``).
            reduce_ll: If True, average the per-target log likelihood (summed
                over y dims) over batch and targets; otherwise keep it
                per (batch, target).

        Returns: AttrDict with ``tar_ll`` and ``loss`` (= -tar_ll).
        """
        pred_tar = Normal(model_out.mean, model_out.std)
        outs = AttrDict()
        # forward only predicted the first target in this mode.
        yt = batch.yt[:, :1] if self.only_one_target else batch.yt
        tar_ll = pred_tar.log_prob(yt).sum(-1)
        outs.tar_ll = tar_ll.mean() if reduce_ll else tar_ll
        outs.loss = - (outs.tar_ll)
        return outs

    def predict(self, xc, yc, xt) -> Normal:
        """Predict y for the x given the condition.

        Args:
            xc: The x data with shape (num_points, num_conditions, dim_x).
            yc: The y data with shape (num_points, num_conditions, dim_y).
            xt: The x points to predict for as shape (num_points, num_targets, dim_x).

        Returns: Normal predictive distribution for the targets (only the
            first target when ``only_one_target`` is set).
        """
        model_out = self.forward(AttrDict({
            'xc': xc,
            'yc': yc,
            'xt': xt,
        }))
        return Normal(model_out.mean, model_out.std)

    def seq_ll(
        self,
        xc: torch.Tensor,
        yc: torch.Tensor,
        xt: torch.Tensor,
        yt: torch.Tensor,
        autoreg: bool = True,
        **kwargs
    ) -> torch.Tensor:
        """Get the log likelihood of the target set given the condition set.

        Args:
            xc: The x conditional points w shape (batch, L_C, D_X)
            yc: The y conditional points w shape (batch, L_C, D_Y)
            xt: The x target points w shape (batch, L_T, D_X)
            yt: The y target points w shape (batch, L_T, D_Y).
            autoreg: Whether to do autoregressive approach to compute joint
                log likelihood. If this is false then it is assumed that
                target set is conditionally independent.

        Returns: Log likelihood of each sequence w shape (batch,)
        """
        B, LT, _ = yt.shape
        if autoreg:
            # Accumulate on the inputs' device: ``self.device`` can be stale
            # (e.g. after ``.cuda()``, which does not go through ``to``).
            lls = torch.zeros(B, device=yt.device)
            for lt in range(LT):
                # Condition on the context plus every already-scored target.
                curr_xc = torch.cat([xc, xt[:, :lt]], dim=1)
                curr_yc = torch.cat([yc, yt[:, :lt]], dim=1)
                with torch.no_grad():
                    dist = self.predict(curr_xc, curr_yc, xt[:, lt:])
                # Only the prediction for the next target (index 0) is scored.
                lls += dist.log_prob(yt[:, lt:]).sum(dim=-1)[:, 0]
            return lls
        if self.only_one_target:
            # The model only predicts the first target it is given; a single
            # joint call would broadcast that one prediction over all targets.
            # Score each target separately against the context instead.
            lls = torch.zeros(B, device=yt.device)
            for lt in range(LT):
                with torch.no_grad():
                    dist = self.predict(xc, yc, xt[:, lt:lt + 1])
                lls += dist.log_prob(yt[:, lt:lt + 1]).sum(-1).sum(-1)
            return lls
        with torch.no_grad():
            dist = self.predict(xc, yc, xt)
        return dist.log_prob(yt).sum(-1).sum(-1)
