"""
Code adapted from
https://github.com/tung-nd/TNP-pytorch/blob/master/regression/models/canp.py
"""
import torch
import torch.nn as nn
from attrdict import AttrDict

from krt.models.np.modules import CrossAttnEncoder, Decoder, PoolingEncoder


class AttnCNP(nn.Module):
    """Attentive Conditional Neural Process.

    Each target point gets two representations of the context set:

    * a target-specific one from ``CrossAttnEncoder``, which receives the
      target inputs;
    * a global one from a self-attentive ``PoolingEncoder``, which does not
      see the targets.

    The two are concatenated and decoded into a predictive distribution
    over ``y``.

    Attributes:
        device: Last device explicitly passed to :meth:`to` (initially
            ``'cpu'``). ``.cuda()`` / ``.cpu()`` do not go through
            :meth:`to`, so they do not update it; code in this class does not
            rely on it.
    """

    def __init__(
        self,
        dim_x=1,
        dim_y=1,
        dim_hid=128,
        enc_v_depth=4,
        enc_qk_depth=2,
        enc_pre_depth=4,
        enc_post_depth=2,
        dec_depth=3,
    ):
        """Build the two encoders and the decoder.

        Args:
            dim_x: Dimensionality of the inputs ``x``.
            dim_y: Dimensionality of the outputs ``y``.
            dim_hid: Hidden width shared by the encoders and the decoder.
            enc_v_depth: ``v_depth`` of the cross-attention encoder.
            enc_qk_depth: ``qk_depth`` of the cross-attention encoder.
            enc_pre_depth: ``pre_depth`` of the pooling encoder.
            enc_post_depth: ``post_depth`` of the pooling encoder.
            dec_depth: Depth of the decoder.
        """

        super().__init__()

        self.enc1 = CrossAttnEncoder(
                dim_x=dim_x,
                dim_y=dim_y,
                dim_hid=dim_hid,
                v_depth=enc_v_depth,
                qk_depth=enc_qk_depth)

        self.enc2 = PoolingEncoder(
                dim_x=dim_x,
                dim_y=dim_y,
                dim_hid=dim_hid,
                self_attn=True,
                pre_depth=enc_pre_depth,
                post_depth=enc_post_depth)

        # The decoder consumes the concatenation of both encodings, hence
        # 2 * dim_hid (assumes each encoder emits dim_hid features).
        self.dec = Decoder(
                dim_x=dim_x,
                dim_y=dim_y,
                dim_enc=2*dim_hid,
                dim_hid=dim_hid,
                depth=dec_depth)
        self.device = 'cpu'

    def to(self, *args, **kwargs):
        """Move/cast the module exactly like ``nn.Module.to``.

        All ``nn.Module.to`` calling conventions are forwarded unchanged
        (device, dtype, tensor, ``non_blocking``, ``memory_format``).
        ``self.device`` is updated only when a device is actually supplied, so
        a dtype-only cast such as ``model.to(torch.float64)`` no longer
        overwrites it with a dtype.

        Returns:
            ``self``, as returned by ``nn.Module.to``.
        """
        module = super().to(*args, **kwargs)
        device = kwargs.get('device')
        if device is None and isinstance(kwargs.get('tensor'), torch.Tensor):
            device = kwargs['tensor'].device
        if device is None and args:
            first = args[0]
            if isinstance(first, torch.Tensor):
                # ``to(other_tensor)`` moves to that tensor's device.
                device = first.device
            elif isinstance(first, (str, torch.device, int)):
                device = first
        if device is not None:
            self.device = device
        return module

    def predict(self, xc, yc, xt, num_samples=None):
        """Return the predictive distribution at the target inputs ``xt``.

        Args:
            xc: Context inputs, presumably (batch, L_C, dim_x).
            yc: Context outputs, presumably (batch, L_C, dim_y).
            xt: Target inputs; the target-point axis is ``-2``.
            num_samples: Unused; accepted for interface compatibility.

        Returns:
            Whatever ``self.dec`` returns. Callers in this class use its
            ``mean``, ``scale`` and ``log_prob``; presumably it is a
            ``torch.distributions.Normal``.
        """
        # Target-specific representation (cross-attention uses xt).
        theta1 = self.enc1(xc, yc, xt)
        # Global representation of the context set (does not see xt).
        theta2 = self.enc2(xc, yc)
        # Repeat the global representation for every target point. Expand is
        # a view and yields the same values as stacking copies along -2.
        theta2 = theta2.unsqueeze(-2).expand(
            *theta2.shape[:-1], xt.shape[-2], theta2.shape[-1])
        encoded = torch.cat([theta1, theta2], -1)
        return self.dec(encoded, xt)

    def forward(self, batch):
        """Predict at ``batch.x`` conditioned on ``(batch.xc, batch.yc)``.

        Returns:
            ``AttrDict`` with the predictive ``mean``, ``std`` (the
            distribution's ``scale``) and the distribution itself as ``dist``.
        """
        py = self.predict(batch.xc, batch.yc, batch.x)
        return AttrDict({
            'mean': py.mean,
            'std': py.scale,
            'dist': py,
        })

    def loss(self, batch, model_out, reduce_ll=True):
        """Negative log-likelihood of ``batch.y`` under ``model_out.dist``.

        The log-probability is summed over the last (output) axis, then
        averaged over all remaining axes.

        Args:
            batch: Object with a ``y`` attribute.
            model_out: Output of :meth:`forward`.
            reduce_ll: Currently unused; kept for interface compatibility.

        Returns:
            ``AttrDict`` with the scalar ``loss``.
        """
        outs = AttrDict()
        outs.loss = -model_out.dist.log_prob(batch.y).sum(-1).mean()
        return outs

    def seq_ll(
        self,
        xc: torch.Tensor,
        yc: torch.Tensor,
        xt: torch.Tensor,
        yt: torch.Tensor,
        autoreg: bool = True,
        **kwargs
    ) -> torch.Tensor:
        """Get the log likelihood of the target set given the condition set.

        Predictions are made under ``torch.no_grad()``, so the result carries
        no gradient with respect to the model parameters.

        Args:
            xc: The x conditional points with shape (batch, L_C, D_X).
            yc: The y conditional points with shape (batch, L_C, D_Y).
            xt: The x target points with shape (batch, L_T, D_X).
            yt: The y target points with shape (batch, L_T, D_Y).
            autoreg: Whether to compute the joint log likelihood
                autoregressively: target ``t`` is scored conditioned on the
                context plus targets ``0..t-1``. If False, targets are
                assumed conditionally independent given the context.
            **kwargs: Ignored; accepted for interface compatibility.

        Returns:
            Log likelihood of each sequence, shape (batch,).
        """
        B, LT, _ = yt.shape
        if not autoreg:
            with torch.no_grad():
                dist = self.predict(xc, yc, xt)
            return dist.log_prob(yt).sum(-1).sum(-1)

        # Allocate on the data's device rather than the recorded
        # ``self.device``, which can be stale (e.g. after ``.cuda()``).
        # The sum is computed out of place so the result dtype follows the
        # log-probs (e.g. float64 data is not truncated to float32).
        lls = torch.zeros(B, device=yt.device)
        for lt in range(LT):
            # Condition on the context plus all already-scored targets.
            curr_xc = torch.cat([xc, xt[:, :lt]], dim=1)
            curr_yc = torch.cat([yc, yt[:, :lt]], dim=1)
            curr_xt = xt[:, lt:]
            curr_yt = yt[:, lt:]
            with torch.no_grad():
                dist = self.predict(curr_xc, curr_yc, curr_xt)
            # Keep only the log-prob of the next target (index 0).
            lls = lls + dist.log_prob(curr_yt).sum(dim=-1)[:, 0]
        return lls
