"""
Diagonal TNP. Adapted from
https://github.com/tung-nd/TNP-pytorch/blob/master/regression/models/tnpd.py
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal
from attrdict import AttrDict

from krt.models.tnp.tnp import TNP


class TNPD(TNP):
    """Diagonal TNP: a TNP whose head predicts an independent Gaussian.

    For every target point the model outputs a mean and a standard deviation
    per y-dimension; the predictive distribution is an element-wise
    ``Normal(mean, std)`` with no correlation between points or dimensions.

    The encoder (``self.encode``) and the ``bound_std`` attribute are not
    defined in this class; they are presumably provided by the ``TNP`` base
    class -- confirm against ``krt.models.tnp.tnp``.
    """

    def __init__(
        self,
        dim_x,
        dim_y,
        d_model,
        emb_depth,
        dim_feedforward,
        nhead,
        dropout,
        num_layers,
        bound_std=False
    ):
        """Build the base TNP and the diagonal Gaussian prediction head.

        Args:
            dim_x: Forwarded unchanged to ``TNP.__init__``.
            dim_y: Number of output dimensions; the head emits ``2 * dim_y``
                values per target point. Also forwarded to ``TNP.__init__``.
            d_model: Input width of the prediction head. Also forwarded to
                ``TNP.__init__``.
            emb_depth: Forwarded unchanged to ``TNP.__init__``.
            dim_feedforward: Hidden width of the prediction head. Also
                forwarded to ``TNP.__init__``.
            nhead: Forwarded unchanged to ``TNP.__init__``.
            dropout: Forwarded unchanged to ``TNP.__init__``.
            num_layers: Forwarded unchanged to ``TNP.__init__``.
            bound_std: Forwarded to ``TNP.__init__``. The head reads
                ``self.bound_std`` to pick the std parameterisation
                (presumably the base class stores this flag -- confirm).
        """
        super().__init__(
            dim_x,
            dim_y,
            d_model,
            emb_depth,
            dim_feedforward,
            nhead,
            dropout,
            num_layers,
            bound_std
        )

        # Two outputs per y-dimension: one for the mean and one for the
        # unconstrained std parameter (split apart in _predictive_params).
        self.predictor = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, dim_y*2)
        )

    def _predictive_params(self, batch):
        """Encode ``batch`` and return the predictive ``(mean, std)`` tensors.

        Shared by ``forward`` and ``predict`` so both use exactly the same
        std parameterisation.
        """
        z_target = self.encode(batch, autoreg=False)
        out = self.predictor(z_target)
        # First half of the last dimension is the mean, second half is the
        # raw (unconstrained) std parameter.
        mean, raw_std = torch.chunk(out, 2, dim=-1)
        if self.bound_std:
            # softplus is positive, so std is bounded below by 0.05.
            std = 0.05 + 0.95 * F.softplus(raw_std)
        else:
            # Raw output is treated as log-std.
            std = torch.exp(raw_std)
        return mean, std

    def forward(self, batch):
        """Predict the target distribution parameters for ``batch``.

        Args:
            batch: Passed as-is to ``self.encode``.

        Returns:
            AttrDict with keys ``'mean'`` and ``'std'``.
        """
        mean, std = self._predictive_params(batch)
        return AttrDict({'mean': mean, 'std': std})

    def loss(self, batch, model_out, reduce_ll=True):
        """Compute the target log likelihood and the corresponding loss.

        Args:
            batch: Object exposing the target values as ``batch.yt``.
            model_out: Object exposing ``mean`` and ``std`` attributes, as
                returned by ``forward``.
            reduce_ll: If True, average the log likelihood (already summed
                over the last dimension) into a scalar. If False, keep one
                value per remaining index.

        Returns:
            AttrDict with ``tar_ll`` (the log likelihood) and ``loss``
            (its negation).
        """
        pred_tar = Normal(model_out.mean, model_out.std)
        # Sum over the last (y) dimension: dimensions are independent.
        tar_ll = pred_tar.log_prob(batch.yt).sum(-1)
        if reduce_ll:
            tar_ll = tar_ll.mean()
        outs = AttrDict()
        outs.tar_ll = tar_ll
        outs.loss = - (outs.tar_ll)
        return outs

    def predict(self, xc, yc, xt):
        """Return the predictive distribution at ``xt`` given the context.

        Args:
            xc: Context inputs.
            yc: Context outputs; indexed as a 3-D tensor (``yc.shape[2]``
                is used as the y-dimension).
            xt: Target inputs; indexed as at least a 2-D tensor
                (``xt.shape[0]`` and ``xt.shape[1]`` are used).

        Returns:
            ``torch.distributions.Normal`` built from the predicted mean
            and std.
        """
        batch = AttrDict()
        batch.xc = xc
        batch.yc = yc
        batch.xt = xt
        # The batch needs a ``yt`` entry even though targets are unknown at
        # prediction time, so supply zeros. ``new_zeros`` makes the
        # placeholder match the dtype and device of ``yc``.
        batch.yt = yc.new_zeros((xt.shape[0], xt.shape[1], yc.shape[2]))

        mean, std = self._predictive_params(batch)
        return Normal(mean, std)

    def seq_ll(
        self,
        xc: torch.Tensor,
        yc: torch.Tensor,
        xt: torch.Tensor,
        yt: torch.Tensor,
        **kwargs
    ) -> torch.Tensor:
        """Get the log likelihood of the target set given the condition set.

        The forward pass runs under ``torch.no_grad()``.

        Args:
            xc: The x conditional points w shape (batch, L_C, D_X)
            yc: The y conditional points w shape (batch, L_C, D_Y)
            xt: The x target points w shape (batch, L_T, D_X)
            yt: The y target points w shape (batch, L_T, D_Y).
            **kwargs: Accepted and ignored.

        Returns: Log likelihood of each sequence w shape (batch,)
        """
        batch = AttrDict({'xc': xc, 'yc': yc, 'xt': xt, 'yt': yt})
        with torch.no_grad():
            model_out = self.forward(batch)
        loss_out = self.loss(batch, model_out, reduce_ll=False)
        # loss is the negated per-point log likelihood; undo the negation
        # and sum over the target points of each sequence.
        return -loss_out.loss.sum(dim=1)
