"""
Diagonal Transformer Neural Process (TNPD).
Based on github.com/tung-nd/TNP-pytorch/
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal

from src.models.benchmarks.tnp import TNP, SampleReshaper, create_mask
from src.utils import DataAttr, LossAttr


class TNPD(TNP):
    """Diagonal Transformer Neural Process.

    Context pairs and target inputs are embedded and passed jointly through
    ``self.encoder``; every target encoding is then mapped by ``predictor`` to
    the mean and standard deviation of an independent Normal per output
    dimension.

    ``self.embedder``, ``self.encoder`` and ``self.bound_std`` are not defined
    in this class; they are expected to be set up by ``TNP.__init__``.
    """

    def __init__(
        self,
        dim_x,
        dim_y,
        d_model,
        emb_depth,
        dim_feedforward,
        nhead,
        dropout,
        num_layers,
        bound_std=False
    ):
        """Build the base model and the diagonal prediction head.

        All arguments are forwarded unchanged (positionally, in this order)
        to ``TNP.__init__``. Only ``d_model``, ``dim_feedforward`` and
        ``dim_y`` are used directly here, to size the prediction head.
        """
        super().__init__(
            dim_x, dim_y, d_model, emb_depth,
            dim_feedforward, nhead, dropout, num_layers, bound_std
        )

        # Two-layer MLP head; the final layer emits 2 * dim_y values that are
        # later split into a mean half and a raw (unconstrained) std half.
        self.predictor = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, dim_y*2)
        )

    def create_mask(self, batch, device: str):
        """Return the attention mask for ``batch`` in non-autoregressive mode.

        Delegates to the module-level ``create_mask`` helper (imported from
        the ``tnp`` module) with ``autoreg=False``. ``device`` is passed
        through as-is; ``encode`` calls this with a tensor's ``.device``.
        """
        return create_mask(batch, device=device, autoreg=False)

    def encode(self, batch: DataAttr) -> torch.Tensor:
        """Encode context and target points and return the target encodings.

        Args:
            batch: Object exposing ``xc``, ``yc``, ``xt`` and ``yt``. ``yt`` is
                used only as a shape/dtype template for the zero placeholder
                that stands in for the unknown target outputs.

        Returns:
            The slice of the encoder output that corresponds to the target
            positions, i.e. the last ``batch.xt.shape[-2]`` sequence entries.
        """
        # Embed context pairs.
        xc_enc = self.embedder(batch.xc, batch.yc)
        # Embed targets with zeros in place of the (unknown) y values.
        xt_0_enc = self.embedder(batch.xt, torch.zeros_like(batch.yt))

        # Sequence layout is [context..., targets...].
        encoder_input = torch.cat([xc_enc, xt_0_enc], dim=-2)

        # Pass through transformer.
        mask = self.create_mask(batch, device=encoder_input.device)
        out = self.encoder(encoder_input, mask=mask)

        # Use an explicit start index rather than ``-num_targets:``: with zero
        # targets, ``-0:`` would select the whole sequence (context included)
        # instead of an empty slice.
        num_targets = batch.xt.shape[-2]
        start = out.shape[1] - num_targets
        return out[:, start:, :]

    def _tnpd_mean_std(self, z_target: torch.Tensor):
        """Map target encodings to the predictive ``(mean, std)`` pair.

        The head output is split in half along the last dimension. The second
        half is made positive either through a softplus bounded below by 0.05
        (when ``self.bound_std`` is truthy) or through ``exp``.
        """
        out = self.predictor(z_target)
        mean, raw_std = torch.chunk(out, 2, dim=-1)

        if self.bound_std:
            std = 0.05 + 0.95 * F.softplus(raw_std)
        else:
            std = torch.exp(raw_std)

        return mean, std

    def forward(self, batch: DataAttr, reduce_ll: bool = True) -> LossAttr:
        """Compute the negative log-likelihood of the targets in ``batch``.

        Args:
            batch: Object exposing ``xc``, ``yc``, ``xt`` and ``yt``.
            reduce_ll: If True, ``loss`` is the mean of the negative
                log-likelihood over all remaining dimensions; otherwise it is
                returned unreduced.

        Returns:
            ``LossAttr`` holding ``loss``, the per-target ``log_likelihood``
            (summed over the last dimension), and the predictive ``mean`` and
            ``std``.
        """
        z_target = self.encode(batch)
        mean, std = self._tnpd_mean_std(z_target)

        # Independent Normal per output dimension; the joint log-density of a
        # target is the sum over its output dimensions.
        pred_tar = Normal(mean, std)
        tar_ll = pred_tar.log_prob(batch.yt).sum(-1)

        if reduce_ll:
            loss = -tar_ll.mean()
        else:
            loss = -tar_ll

        return LossAttr(
            loss=loss,
            log_likelihood=tar_ll,
            mean=mean,
            std=std
        )

    def predict(
        self,
        xc: torch.Tensor,
        yc: torch.Tensor,
        xt: torch.Tensor,
        num_samples: int = 50,
        return_samples: bool = False
    ) -> torch.Tensor:
        """Make predictions at target locations.

        Args:
            xc: Context inputs [B, Nc, Dx]
            yc: Context outputs [B, Nc, Dy]
            xt: Target inputs [B, Nt, Dx]
            num_samples: Number of samples to generate
            return_samples: If True, return samples; else return distribution

        Returns:
            If ``return_samples`` is True, samples reshaped by
            ``SampleReshaper.torch_dist2custom`` (documented by the original
            code as [B, Nt, num_samples, Dy]); otherwise the predictive
            ``Normal`` distribution itself (note: not a tensor, despite the
            return annotation).
        """
        dim_y = yc.shape[-1]

        # Dummy targets: only their shape/dtype matter to ``encode``. Build
        # them from the leading dims of ``xt`` plus ``dim_y`` explicitly;
        # slicing ``xt[..., :dim_y]`` would yield too few columns whenever
        # dim_y > Dx.
        batch = DataAttr(
            xc=xc,
            yc=yc,
            xt=xt,
            yt=xt.new_zeros((*xt.shape[:-1], dim_y))
        )

        z_target = self.encode(batch)
        mean, std = self._tnpd_mean_std(z_target)

        if return_samples:
            print("Original TNPD did not implement sampling. Sample from prediction Normal.")
            samples = Normal(mean, std).sample([num_samples])  # [num_samples, B, Nt, Dy]
            samples = SampleReshaper.torch_dist2custom(samples)  # Reshape to [B, Nt, num_samples, Dy]
            return samples

        return Normal(mean, std)

    def sample_joint_predictive(
        self,
        xc: torch.Tensor,
        yc: torch.Tensor,
        xt: torch.Tensor,
        num_samples: int = 50
    ) -> torch.Tensor:
        """Draw joint samples over the targets by autoregressive sampling.

        Targets are processed one at a time in the order given by ``xt``. Each
        sampled value is appended to the context before the next target is
        predicted, so later targets are conditioned on earlier samples.

        Args:
            xc: Context inputs [B, Nc, Dx]
            yc: Context outputs [B, Nc, Dy]
            xt: Target inputs [B, T, Dx]
            num_samples: Number of independent sample paths per batch element

        Returns:
            Samples assembled as [num_samples, B, T, Dy] and then passed
            through ``SampleReshaper.torch_dist2custom`` (documented by the
            original code as producing [B, T, num_samples, Dy]).
        """
        B = xc.shape[0]
        Dy = yc.shape[-1]
        T = xt.shape[1]

        if T == 0:
            # Nothing to sample; return an empty result instead of failing in
            # ``torch.cat`` on an empty list.
            # NOTE(review): assumes the reshaper accepts a zero-length target
            # axis - confirm.
            empty = xt.new_zeros((num_samples, B, 0, Dy))
            return SampleReshaper.torch_dist2custom(empty)

        # Tile every input num_samples times along the batch axis so that all
        # sample paths run in parallel. ``repeat`` on dim 0 gives the same
        # sample-major ordering as unsqueeze(0).repeat(...).view(-1, ...).
        xc = xc.repeat(num_samples, 1, 1)  # [num_samples*B, Nc, Dx]
        yc = yc.repeat(num_samples, 1, 1)  # [num_samples*B, Nc, Dy]
        xt = xt.repeat(num_samples, 1, 1)  # [num_samples*B, T, Dx]

        # Placeholder target outputs; created once, matching the inputs'
        # dtype and device rather than the global default dtype.
        dummy_yt = xt.new_zeros((num_samples * B, 1, Dy))

        output = []
        for t in range(T):
            xt_step = xt[:, t:t + 1, :]
            batch = DataAttr(xc=xc, yc=yc, xt=xt_step, yt=dummy_yt)

            zt = self.encode(batch)  # [num_samples*B, 1, D_model]
            mean, std = self._tnpd_mean_std(zt)

            # Sample without extra dimension.
            yhat = Normal(mean, std).sample()  # [num_samples*B, 1, Dy]
            output.append(yhat.view(num_samples, B, 1, Dy))

            # Feed the sampled point back in as context for the next target.
            if t < T - 1:
                xc = torch.cat([xc, xt_step], dim=1)
                yc = torch.cat([yc, yhat], dim=1)

        samples = torch.cat(output, dim=-2)  # [num_samples, B, T, Dy]
        return SampleReshaper.torch_dist2custom(samples)  # [B, T, num_samples, Dy]
