import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import kl_divergence

from torch.distributions.normal import Normal
from attrdict import AttrDict

from .dttanp import DTTANP

from ..utils.misc import stack, logmeanexp

class DTTANPD(DTTANP):
    """Latent-variable variant of ``DTTANP`` with a diagonal-Gaussian output head.

    The parent class supplies the encoders used here (``lencode`` for the
    latent path and ``encode`` for the per-target representation); this
    subclass adds a two-layer MLP ``predictor`` that maps each target
    representation to the mean and (pre-activation) standard deviation of a
    Normal over y.
    """

    def __init__(
        self,
        dim_x,
        dim_y,
        d_model,
        emb_depth,
        dim_feedforward,
        nhead,
        dropout,
        num_layers,
        bound_std=True
    ):
        super().__init__(
            dim_x,
            dim_y,
            d_model,
            emb_depth,
            dim_feedforward,
            nhead,
            dropout,
            num_layers,
            bound_std
        )

        # The input width assumes ``DTTANP.encode`` returns features of size
        # dim_feedforward + d_model -- NOTE(review): confirm against the parent.
        # The output has 2 * dim_y channels: first half mean, second half raw std.
        self.predictor = nn.Sequential(
            nn.Linear(dim_feedforward + d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, dim_y * 2)
        )

    def forward(self, batch, num_samples=None, reduce_ll=True):
        """Compute the training loss (train mode) or log-likelihoods (eval mode).

        Args:
            batch: AttrDict with at least ``xc``, ``yc`` (context) and ``x``,
                ``y`` (all points). In eval mode the first ``xc.shape[-2]``
                points of ``x``/``y`` are scored as context and the rest as
                targets, so context points are presumed to come first.
            num_samples: Number of latent samples, or ``None`` to draw a single
                sample without a leading sample dimension.
            reduce_ll: Eval mode only. If True, combine samples with
                ``logmeanexp`` and average over points to scalars; if False,
                return the unreduced log-likelihood tensors.

        Returns:
            AttrDict. Train mode: ``loss`` (plus ``recon`` and ``kld`` when at
            most one sample is drawn). Eval mode: ``ctx_ll`` and ``tar_ll``.
        """
        outs = AttrDict()
        if self.training:
            pz, qz = self.lencode(batch)

            z = qz.rsample() if num_samples is None else \
                    qz.rsample([num_samples])

            py = self.predict(batch.xc, batch.yc, batch.x, z=z,
                              num_samples=num_samples)

            # Guard against None: ``None > 1`` raises TypeError in Python 3.
            if num_samples is not None and num_samples > 1:
                # Importance-weighted objective:
                # log w = log p(y|z) + log p(z) - log q(z), with z ~ q.
                recon = py.log_prob(stack(batch.y, num_samples)).sum(-1)
                log_qz = qz.log_prob(z).sum(-1)
                log_pz = pz.log_prob(z).sum(-1)

                log_w = recon.sum(-1) + log_pz - log_qz
                outs.loss = -logmeanexp(log_w).mean() / batch.x.shape[-2]
            else:
                # Single-sample objective. recon is averaged over points, so the
                # KL term is divided by the number of points to keep the scales
                # comparable.
                outs.recon = py.log_prob(batch.y).sum(-1).mean()
                outs.kld = kl_divergence(qz, pz).sum(-1).mean()
                outs.loss = -outs.recon + outs.kld / batch.x.shape[-2]
        else:
            py = self.predict(batch.xc, batch.yc, batch.x,
                              num_samples=num_samples)
            if num_samples is None:
                ll = py.log_prob(batch.y).sum(-1)
            else:
                y = torch.stack([batch.y] * num_samples)
                ll = py.log_prob(y).sum(-1)
                if reduce_ll:
                    # Presumably averages likelihoods over the leading sample
                    # dimension in log space -- confirm in utils.misc.
                    ll = logmeanexp(ll)

            num_ctx = batch.xc.shape[-2]
            ctx_ll = ll[..., :num_ctx]
            tar_ll = ll[..., num_ctx:]
            if reduce_ll:
                ctx_ll = ctx_ll.mean()
                tar_ll = tar_ll.mean()
            # With reduce_ll=False the per-point values are returned so the
            # caller can aggregate them itself.
            outs.ctx_ll = ctx_ll
            outs.tar_ll = tar_ll

        return outs

    def predict(self, xc, yc, xt, z=None, num_samples=None):
        """Return the predictive Normal over y at target inputs ``xt``.

        Args:
            xc, yc: Context inputs and outputs.
            xt: Target inputs. A zero placeholder ``yt`` of shape
                ``(xt.shape[0], xt.shape[1], yc.shape[2])`` is built for them.
            z: Optional latent sample, forwarded to ``self.encode``.
            num_samples: Forwarded to ``self.encode``.

        Returns:
            ``Normal(mean, std)``, where mean and std are the two halves of the
            predictor output along the last dimension.
        """
        batch = AttrDict()
        batch.xc = xc
        batch.yc = yc
        batch.xt = xt
        # Allocate the placeholder on the inputs' device instead of a hard-coded
        # 'cuda', so the model also runs on CPU or a non-default GPU.
        batch.yt = torch.zeros(
            (xt.shape[0], xt.shape[1], yc.shape[2]),
            device=xt.device,
        )

        z_target = self.encode(batch, z=z, num_samples=num_samples, autoreg=False)
        out = self.predictor(z_target)
        mean, std = torch.chunk(out, 2, dim=-1)
        if self.bound_std:
            # softplus > 0, so std is kept strictly above 0.1.
            std = 0.1 + 0.9 * F.softplus(std)
        else:
            std = torch.exp(std)

        return Normal(mean, std)
