"""
Adapted (with modifications) from
https://github.com/tung-nd/TNP-pytorch/blob/master/regression/models/anp.py
"""
import torch
import torch.nn as nn
from torch.distributions import kl_divergence
from attrdict import AttrDict

from krt.models.np.utils import stack, logmeanexp
from krt.models.np.modules import CrossAttnEncoder, PoolingEncoder, Decoder


class AttnLNP(nn.Module):
    """Attentive latent neural process.

    Combines a deterministic cross-attention path (``self.denc``) with a
    latent path (``self.lenc``). Each latent sample is concatenated with the
    deterministic representation and passed, together with the target
    inputs, to the decoder ``self.dec``.
    """

    def __init__(
        self,
        dim_x=1,
        dim_y=1,
        dim_hid=128,
        dim_lat=128,
        enc_v_depth=4,
        enc_qk_depth=2,
        enc_pre_depth=4,
        enc_post_depth=2,
        dec_depth=3,
        tr_num_samples=4,
        eval_num_samples=50,
    ):
        """Build the deterministic encoder, latent encoder and decoder.

        Args:
            dim_x: Input dimension.
            dim_y: Output dimension.
            dim_hid: Hidden width shared by the encoders and the decoder.
            dim_lat: Dimension of the latent variable z.
            enc_v_depth: Forwarded as ``v_depth`` to ``CrossAttnEncoder``.
            enc_qk_depth: Forwarded as ``qk_depth`` to ``CrossAttnEncoder``.
            enc_pre_depth: Forwarded as ``pre_depth`` to ``PoolingEncoder``.
            enc_post_depth: Forwarded as ``post_depth`` to ``PoolingEncoder``.
            dec_depth: Forwarded as ``depth`` to ``Decoder``.
            tr_num_samples: Number of latent samples used by ``forward`` and
                ``loss``.
            eval_num_samples: Default number of latent samples used by
                ``predict``/``sample`` and by ``seq_ll``.
        """
        super().__init__()
        # Honour the constructor arguments (these used to be hard-coded to
        # 4 and 50, silently ignoring whatever the caller passed).
        self.tr_num_samples = tr_num_samples
        self.eval_num_samples = eval_num_samples

        # Deterministic path: cross-attention from targets to the context.
        self.denc = CrossAttnEncoder(
            dim_x=dim_x,
            dim_y=dim_y,
            dim_hid=dim_hid,
            v_depth=enc_v_depth,
            qk_depth=enc_qk_depth,
        )

        # Latent path: returns a distribution over z (``rsample``/``log_prob``
        # are called on its output below).
        self.lenc = PoolingEncoder(
            dim_x=dim_x,
            dim_y=dim_y,
            dim_hid=dim_hid,
            dim_lat=dim_lat,
            self_attn=True,
            pre_depth=enc_pre_depth,
            post_depth=enc_post_depth,
        )

        # The decoder consumes [deterministic repr, z], hence dim_hid + dim_lat.
        self.dec = Decoder(
            dim_x=dim_x,
            dim_y=dim_y,
            dim_enc=dim_hid + dim_lat,
            dim_hid=dim_hid,
            depth=dec_depth,
        )
        self.device = 'cpu'

    def to(self, device):
        """Move the module to ``device`` and remember it in ``self.device``.

        Only this override updates ``self.device``; moving the module any
        other way (e.g. ``.cuda()``) leaves the attribute unchanged.
        """
        self.device = device
        return super().to(device)

    def predict(self, xc, yc, xt, z=None, num_samples=None):
        """Return the decoder's predictive distribution at the targets ``xt``.

        Args:
            xc: Context inputs.
            yc: Context outputs.
            xt: Target inputs; ``xt.shape[-2]`` is taken as the number of
                target points.
            z: Optional latent samples. When ``None``, ``num_samples`` samples
                are drawn (reparameterised) from ``self.lenc(xc, yc)``. When
                given, presumably shaped (num_samples, batch, dim_lat) so it
                lines up with the replicated deterministic path. TODO confirm.
            num_samples: Number of latent samples; defaults to
                ``self.eval_num_samples``.

        Returns:
            The output of ``self.dec``. Callers in this class treat it as a
            distribution (``loc``, ``mean``, ``scale``, ``log_prob``).
        """
        if num_samples is None:
            num_samples = self.eval_num_samples
        # ``stack`` presumably replicates along a new leading sample dim.
        theta = stack(self.denc(xc, yc, xt), num_samples)
        if z is None:
            pz = self.lenc(xc, yc)
            z = pz.rsample([num_samples])
        # Repeat each latent sample for every target point (dim -2).
        z = stack(z, xt.shape[-2], -2)
        encoded = torch.cat([theta, z], -1)
        return self.dec(encoded, stack(xt, num_samples))

    def sample(self, xc, yc, xt, z=None, num_samples=None):
        """Return ``loc`` of the predictive distribution from ``predict``.

        Note this is the distribution's location parameter, not a random draw
        from it; randomness only enters through the latent samples.
        """
        pred_dist = self.predict(xc, yc, xt, z, num_samples)
        return pred_dist.loc

    def forward(self, batch):
        """Run a training pass on ``batch``.

        The prior ``pz`` is the latent encoder on the context
        ``(batch.xc, batch.yc)``. The approximate posterior ``qz`` is the
        latent encoder on all points ``(batch.x, batch.y)``.
        ``self.tr_num_samples`` latents are drawn from ``qz``, and all points
        ``batch.x`` are decoded conditioned on the context.

        Returns:
            AttrDict with the predictive distribution ``dist`` and its
            ``mean``/``std``, both latent distributions (``qz``, ``pz``) and
            the latent samples ``z``. This is what ``loss`` consumes.
        """
        pz = self.lenc(batch.xc, batch.yc)
        qz = self.lenc(batch.x, batch.y)
        z = qz.rsample([self.tr_num_samples])
        py = self.predict(batch.xc, batch.yc, batch.x, z=z,
                          num_samples=self.tr_num_samples)
        return AttrDict({
            'mean': py.mean,
            'std': py.scale,
            'dist': py,
            'qz': qz,
            'pz': pz,
            'z': z,
        })

    def loss(self, batch, model_out, reduce_ll=True):
        """Return the training objective as ``AttrDict(loss=...)``.

        With ``tr_num_samples > 1`` the loss is the negative importance-
        weighted bound, with ``qz`` as proposal and ``pz`` as prior. With a
        single sample it is the negative ELBO. Both are normalised per point
        (``batch.x.shape[-2]``), so the two branches are on the same scale.

        Args:
            batch: Object with ``x``/``y`` attributes holding all points.
            model_out: Output of ``forward`` for the same batch.
            reduce_ll: Unused; kept for interface compatibility.

        Returns:
            AttrDict with a scalar ``loss`` entry (to be minimised).
        """
        outs = AttrDict()
        # Log-likelihood per sample and point, summed over output dims.
        recon = model_out.dist.log_prob(
            stack(batch.y, self.tr_num_samples)).sum(-1)
        # Points live on dim -2 of x (dim -1 is the input dimension).
        num_points = batch.x.shape[-2]
        if self.tr_num_samples > 1:
            log_qz = model_out.qz.log_prob(model_out.z).sum(-1)
            log_pz = model_out.pz.log_prob(model_out.z).sum(-1)
            # Log importance weights: log p(y|z) + log p(z) - log q(z).
            log_w = recon.sum(-1) + log_pz - log_qz
            outs.loss = -logmeanexp(log_w).mean() / num_points
        else:
            kld = kl_divergence(model_out.qz, model_out.pz).sum(-1).mean()
            # Negative ELBO per point: the log-likelihood must be negated to
            # be a loss, and the KL is spread over the points.
            outs.loss = -recon.mean() + kld / num_points
        return outs

    def seq_ll(
        self,
        xc: torch.Tensor,
        yc: torch.Tensor,
        xt: torch.Tensor,
        yt: torch.Tensor,
        autoreg: bool = True,
        **kwargs
    ) -> torch.Tensor:
        """Get the log likelihood of the target set given the condition set.

        Each likelihood is a Monte Carlo estimate over
        ``self.eval_num_samples`` latent samples: likelihoods are averaged in
        log space with ``logmeanexp``.

        Args:
            xc: The x conditional points w shape (batch, L_C, D_X)
            yc: The y conditional points w shape (batch, L_C, D_Y)
            xt: The x target points w shape (batch, L_T, D_X)
            yt: The y target points w shape (batch, L_T, D_Y).
            autoreg: Whether to use an autoregressive approach to compute the
                joint log likelihood. If this is false then it is assumed that
                the target set is conditionally independent.
            **kwargs: Ignored; accepted for interface compatibility.

        Returns: Log likelihood of each sequence w shape (batch,)
        """
        B, LT, _ = yt.shape
        num_samples = self.eval_num_samples
        if autoreg:
            # Allocate on the data's device: ``self.device`` is only updated
            # by ``to()`` and may disagree with where the inputs live.
            lls = torch.zeros(B, device=yt.device)
            for lt in range(LT):
                # Condition on the context plus every target before lt.
                curr_xc = torch.cat([xc, xt[:, :lt]], dim=1)
                curr_yc = torch.cat([yc, yt[:, :lt]], dim=1)
                curr_xt = xt[:, lt:]
                curr_yt = yt[:, lt:]
                # NOTE(review): only position lt is scored below. If the
                # encoders/decoder treat target points independently, passing
                # xt[:, lt:lt + 1] would give the same result with less work.
                # Confirm before changing.
                with torch.no_grad():
                    dist = self.predict(curr_xc, curr_yc, curr_xt,
                                        num_samples=num_samples)
                curr_lls = dist.log_prob(stack(curr_yt, num_samples))
                # Sum over output dims; keep the first remaining target (lt).
                curr_lls = curr_lls.sum(dim=-1)[..., 0]
                lls += logmeanexp(curr_lls)
            return lls
        with torch.no_grad():
            dist = self.predict(xc, yc, xt, num_samples=num_samples)
        # Sum over output dims and over target points, then average over
        # latent samples in log space.
        lls = dist.log_prob(stack(yt, num_samples)).sum(-1).sum(-1)
        return logmeanexp(lls)

    @property
    def num_samples(self) -> int:
        """Number of latent samples for the current mode (train vs. eval)."""
        if self.training:
            return self.tr_num_samples
        return self.eval_num_samples
