"""
The graph transformer neural process.
"""
import math
from typing import Optional, Sequence, Union

from attrdict import AttrDict
import torch
from torch.distributions.normal import Normal
import torch.nn as nn
from torch.nn import functional as F

from krt.models.mlp import MLP
from krt.models.edges import EdgeEncoder
from krt.utils import get_activation


class GTNPBlock(nn.Module):
    """One transformer block of the graph transformer neural process.

    Nodes attend to each other with multi-head self-attention whose scores and
    values are additionally modulated by projected edge features, followed by
    a position-wise feedforward net. Both sub-layers are pre-LayerNorm with
    residual connections. The node sequence is laid out as
    ``[context | targets | padded targets]`` and an attention mask controls
    which nodes may see which (see ``forward``).
    """

    def __init__(
        self,
        dim_y: int,
        d_model: int,
        nhead: int,
        dim_feedforward: int,
        hidden_activation: str = 'relu',
    ):
        """Constructor.

        Args:
            dim_y: Y data dimension. Not used by the block itself.
            d_model: Embedding size of the nodes and edges.
            nhead: Number of attention heads; must divide ``d_model``.
            dim_feedforward: Hidden width of the feedforward net.
            hidden_activation: Name of the activation (resolved with
                ``get_activation``) used in the feedforward net.
        """
        super().__init__()
        assert d_model % nhead == 0
        self.d_model = d_model
        self.nhead = nhead
        self.d_head = d_model // nhead
        # Joint query/key/value projection of the nodes.
        self.sa_aproj = nn.Linear(self.d_model,
                                  3 * self.d_model, bias=False)
        # Joint key/value projection of the edges.
        self.edge_aproj = nn.Linear(self.d_model,
                                    2 * self.d_model, bias=False)
        self.out_proj = nn.Linear(self.d_model, self.d_model, bias=False)
        self.ln1, self.ln2 = [nn.LayerNorm(d_model) for _ in range(2)]
        self.compute_fc = nn.Linear(d_model, dim_feedforward, bias=False)
        self.compute_proj = nn.Linear(dim_feedforward, d_model, bias=False)
        self.hidden_activation = get_activation(hidden_activation)
        # Kept for backward compatibility; ``forward`` takes the device from
        # its inputs so it also works after ``.cuda()`` / ``.to(dtype)``.
        self.device = 'cpu'

    def to(self, device):
        """Move the block to ``device`` and record it in ``self.device``."""
        self.device = device
        return super().to(device)

    def forward(
        self,
        nodes: torch.Tensor,
        edges: torch.Tensor,
        num_ctx: int,
    ):
        """Forward pass to evolve the nodes.

        The ``L`` nodes are laid out as ``[context (LC) | targets (LT) |
        padded targets (LT)]`` with ``LT = (L - LC) // 2``. The attention mask
        lets:
            * context nodes attend only to context nodes,
            * target i attend to the context and to targets 0..i,
            * padded target i attend to the context and to targets 0..i-1,
        and no node ever attends to a padded target. A node with nothing to
        attend to (e.g. the first padded target when ``num_ctx == 0``) gets
        zero attention output instead of NaN.

        Args:
            nodes: The node features (B, L, D), D == d_model.
            edges: The edge features (B, L, L, D).
            num_ctx: The number of context points LC at the front of the set.

        Returns: The next evolution of the nodes, shape (B, L, D).
        """
        B, L, D = nodes.shape
        LC = num_ctx
        LT = (L - LC) // 2
        # End of the target section and start of the padded-target section.
        # Using ``L - LT`` instead of a ``-LT`` slice keeps LT == 0 correct
        # (``x[-0:]`` would select everything).
        tgt_end = LC + LT
        pad_start = L - LT
        # Each of these has shape (B, nhead, L, d_head).
        q, k, v = [
            t.view(B, L, self.nhead, self.d_head).transpose(1, 2)
            for t in self.sa_aproj(self.ln1(nodes)).split(self.d_model, dim=2)
        ]
        # Each of these has shape (B, nhead, L, L, d_head). Because of
        # transpose(1, 3), ek[b, h, i, j] is the projection of edges[b, j, i].
        # NOTE(review): that is the transpose of the usual (query, key)
        # indexing; it only matters if the edge features are asymmetric.
        ek, ev = [
            t.view(B, L, L, self.nhead, self.d_head).transpose(1, 3)
            for t in self.edge_aproj(edges).split(self.d_model, dim=-1)
        ]
        # Node-node scores plus query-edge scores.
        attn = q @ k.transpose(-2, -1) + torch.einsum('bnid,bnijd->bnij', q, ek)
        # NOTE(review): scaled by sqrt(d_model) rather than the per-head
        # sqrt(d_head) of standard attention; kept to preserve behavior.
        attn = attn / math.sqrt(D)
        mask = torch.ones(L, L, device=nodes.device)
        mask[:LC, LC:] = 0.0  # No context should attend to targets.
        mask[LC:tgt_end, LC:tgt_end].tril_()  # Target set is causal.
        # Padded targets attend only to strictly previous targets.
        mask[pad_start:, LC:tgt_end].tril_(diagonal=-1)
        mask[:, pad_start:] = 0.0  # Never attend to any padded targets.
        blocked = mask == 0
        attn = attn.masked_fill(blocked, float('-inf'))
        attn = F.softmax(attn, dim=-1)
        # Rows with every key masked softmax to NaN, which would then leak into
        # all nodes through ``attn @ v`` in later blocks. Zero them instead.
        attn = attn.masked_fill(blocked.all(dim=-1, keepdim=True), 0.0)
        attn_out = attn @ v + torch.einsum('bnij,bnijd->bnid', attn, ev)
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, L, D)
        nodes = nodes + self.out_proj(attn_out)
        nodes = nodes + self.compute_proj(
            self.hidden_activation(self.compute_fc(self.ln2(nodes))))
        return nodes


class GTNP(nn.Module):
    """Graph transformer neural process.

    The context points, the target points and a second ("padded") copy of the
    target points become the nodes of a graph. Node features are embedded from
    the y values (zeros for the padded copy) and edge features are computed
    from the first ``dim_x_to_kernelize`` dimensions of x. A stack of
    ``GTNPBlock``s evolves the nodes, and the padded-target nodes are decoded
    into the mean and standard deviation of a Normal prediction per target.
    """

    def __init__(
        self,
        dim_x: int,
        dim_y: int,
        d_model: int,
        nhead: int,
        num_blocks: int,
        dim_x_to_kernelize: int,
        dim_feedforward: Optional[int] = None,
        embedder_depth: int = 4,
        x_embed_types: Union[Sequence[str], str] = 'ortho',
        x_embed_depth: int = 4,
        edge_types: Union[Sequence[str], str] = 'RBF',
        hidden_activation: str = 'relu',
        std_bounds: Optional[Sequence[float]] = None,
        log_stats: bool = False,
        conservative_weight_init: bool = False,
    ):
        """Constructor.

        Args:
            dim_x: X data dimension.
            dim_y: Y data dimension.
            d_model: Embedding size.
            nhead: Number of attention heads; must divide ``d_model``.
            num_blocks: Number of ``GTNPBlock``s to stack.
            dim_x_to_kernelize: Number of leading x dimensions used to build
                the edge features.
            dim_feedforward: Hidden width of the blocks' feedforward nets and
                of the decoder. Defaults to ``4 * d_model``.
            embedder_depth: Depth of the embedding for the nodes and edges.
            x_embed_types: Type of embeddings to use for x.
            x_embed_depth: Depth of embedder to use for x if mlp is a type.
            edge_types: Type of distance/kernel to use for edges.
            hidden_activation: Hidden activation to use.
            std_bounds: (min, max) bounds for the standard deviation. If None,
                the standard deviation is exp of the raw output (unbounded);
                otherwise a sigmoid squashes it into (min, max).
            log_stats: Whether ``forward`` and ``loss`` also report summary
                statistics (Python floats) in their ``stats`` entries.
            conservative_weight_init: Whether to initialize the weights closer
                to 0. This is intended to follow the weight initialization of
                Karpathy's nanoGPT repo. However, the model often gets stuck
                at the initialization if this is turned on.
        """
        super().__init__()
        assert d_model % nhead == 0
        self.dim_x = dim_x
        self.dim_x_to_kernelize = dim_x_to_kernelize
        self.dim_y = dim_y
        self.d_model = d_model
        self.log_stats = log_stats
        if dim_feedforward is None:
            dim_feedforward = 4 * d_model
        self.blocks = nn.ModuleList([
            GTNPBlock(
                dim_y=dim_y,
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                hidden_activation=hidden_activation,
            )
            for _ in range(num_blocks)
        ])
        self.node_encoder = MLP(
            input_dim=dim_y,
            output_dim=d_model,
            hidden_layer_width=d_model,
            hidden_layer_depth=embedder_depth,
        )
        self.edge_encoder = EdgeEncoder(
            dim_x=dim_x_to_kernelize,
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model,
            out_net_depth=embedder_depth,
            x_embed_types=x_embed_types,
            x_embed_depth=x_embed_depth,
            edge_types=edge_types,
            hidden_activation=hidden_activation,
        )
        # Maps each node embedding to (mean, raw std) for every y dimension.
        self.decoder = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, 2 * dim_y)
        )
        if conservative_weight_init:
            self.apply(self._init_weights)
            # NOTE(review): nanoGPT applies this depth-scaled init to its
            # output projections (``c_proj``); here it hits parameters whose
            # names end in 'aproj.weight' (the input projections of
            # GTNPBlock). Confirm this is intended.
            for pn, p in self.named_parameters():
                if pn.endswith('aproj.weight'):
                    torch.nn.init.normal_(p, mean=0.0,
                                          std=0.02/math.sqrt(2 * num_blocks))
        if std_bounds is None:
            self.std_min, self.std_range = None, None
        else:
            self.std_min = std_bounds[0]
            self.std_range = std_bounds[1] - std_bounds[0]
        # Kept for backward compatibility; forward/predict take the device
        # from their inputs so they also work after ``.cuda()``.
        self.device = 'cpu'

    def _init_weights(self, module: torch.nn.Module):
        """Initialize Linear (N(0, 0.02), zero bias) and LayerNorm weights."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.ones_(module.weight)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

    def to(self, device):
        """Move the model (and each block) to ``device``, recording it."""
        self.device = device
        for block in self.blocks:
            block.to(self.device)
        return super().to(device)

    def forward(
        self,
        batch: AttrDict,
    ):
        """Forward

        Args:
            batch: Batch containing:
                xt: Target x points with shape (batch, L_T, x dim).
                xc: Condition x points with shape (batch, L_C, x dim).
                yc: Condition y points with shape (batch, L_C, y dim).
                yt: Target y points with shape (batch, L_T, y dim); used as
                    the node values of the (non-padded) target nodes.

        Returns: Output containing
            mean: Mean prediction w shape (batch, L_T, y_dim)
            std: Standard deviation prediction w shape (batch, L_T, y_dim)
            stats: Summary statistics of mean/std if ``log_stats`` else {}.
        """
        B, LC, YD = batch.yc.shape
        _, LT, _ = batch.xt.shape
        # Targets appear twice: once carrying their y values and once as
        # zero-y "padded" copies whose nodes are decoded into predictions.
        x_in = torch.cat([batch.xc, batch.xt, batch.xt], dim=1)
        # Build the padding on the inputs' device/dtype rather than
        # ``self.device``, which is stale after e.g. ``.cuda()``.
        y_in = torch.cat([batch.yc, batch.yt,
                          batch.yc.new_zeros(B, LT, YD)], dim=1)
        edges = self.edge_encoder(x_in[:, :, :self.dim_x_to_kernelize])
        nodes = self.node_encoder(y_in)
        for block in self.blocks:
            nodes = block(nodes, edges, LC)
        # The padded targets are the last LT nodes. Slice from LC + LT rather
        # than with -LT so that LT == 0 decodes nothing instead of every node.
        mean, std_out = self.decoder(nodes[:, LC + LT:]).split(self.dim_y,
                                                               dim=-1)
        if self.std_min is None:
            std = std_out.exp()
        else:
            std = torch.sigmoid(std_out) * self.std_range + self.std_min
        stats = {}
        if self.log_stats:
            for sk, sv in [('mean', mean), ('std', std)]:
                stats[f'{sk}/mean'] = sv.mean().item()
                stats[f'{sk}/std'] = sv.std().item()
                stats[f'{sk}/min'] = sv.min().item()
                stats[f'{sk}/max'] = sv.max().item()
        return AttrDict({
            'mean': mean,
            'std': std,
            'stats': stats,
        })

    def loss(self, batch, model_out, reduce_ll=True):
        """Negative log likelihood of the targets under the prediction.

        Args:
            batch: Batch containing ``yt``, the target y values.
            model_out: Output of ``forward`` (uses ``mean`` and ``std``).
            reduce_ll: If True, average the log likelihood over the batch and
                targets; otherwise keep shape (batch, L_T).

        Returns: AttrDict with ``tar_ll`` (log likelihood summed over the y
            dimension), ``loss`` (its negation) and ``stats``.
        """
        pred_tar = Normal(model_out.mean, model_out.std)
        tar_ll = pred_tar.log_prob(batch.yt).sum(-1)
        outs = AttrDict()
        outs.tar_ll = tar_ll.mean() if reduce_ll else tar_ll
        outs.loss = -outs.tar_ll
        stats = {}
        if self.log_stats:
            stats['mse'] = (model_out.mean - batch.yt).pow(2).sum(dim=-1).mean().item()
        outs.stats = stats
        return outs

    def predict(self, xc, yc, xt) -> Normal:
        """Predict y for the x given the condition.

        Args:
            xc: The x data with shape (num_points, num_conditions, dim_x).
            yc: The y data with shape (num_points, num_conditions, dim_y).
            xt: The x points to predict for as shape (num_points, num_targets, dim_x).

        Returns: Normal distribution with mean/std of shape
            (num_points, num_targets, dim_y).
        """
        B, _, dy = yc.shape
        LT = xt.shape[1]
        # NOTE(review): these placeholder zeros are also the y values of the
        # non-padded target nodes, which later padded targets attend to under
        # GTNPBlock's causal mask, so target i's prediction is conditioned on
        # y == 0 at targets 0..i-1. Confirm this is intended.
        model_out = self.forward(AttrDict({
            'xc': xc,
            'yc': yc,
            'xt': xt,
            'yt': yc.new_zeros(B, LT, dy),
        }))
        return Normal(model_out.mean, model_out.std)

    def seq_ll(
        self,
        xc: torch.Tensor,
        yc: torch.Tensor,
        xt: torch.Tensor,
        yt: torch.Tensor,
        **kwargs
    ) -> torch.Tensor:
        """Get the log likelihood of the target set given the condition set.

        Args:
            xc: The x conditional points w shape (batch, L_C, D_X)
            yc: The y conditional points w shape (batch, L_C, D_Y)
            xt: The x target points w shape (batch, L_T, D_X)
            yt: The y target points w shape (batch, L_T, D_Y).
            kwargs: Ignored.

        Returns: Log likelihood of each sequence w shape (batch,)
        """
        batch = AttrDict({'xc': xc, 'yc': yc, 'xt': xt, 'yt': yt})
        with torch.no_grad():
            model_out = self.forward(batch)
        loss_out = self.loss(batch, model_out, reduce_ll=False)
        # Per-target NLL of shape (batch, L_T); sum over targets and negate.
        return -1 * loss_out.loss.sum(dim=1)
