"""
Diagonal Transformer Neural Process (TNPD).
Based on github.com/tung-nd/TNP-pytorch/
"""

from typing import Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal

from src.models.benchmarks.tnp import TNP, SampleReshaper, create_mask
from src.utils import DataAttr, LossAttr


class TNPD(TNP):
    """Diagonal Transformer Neural Process.

    Target points are embedded with their ``y`` values replaced by zeros,
    concatenated after the context embeddings, passed through the encoder
    and decoded by an MLP head into an independent ``Normal`` per target
    point and output dimension.
    """

    def __init__(
        self,
        dim_x,
        dim_y,
        d_model,
        emb_depth,
        dim_feedforward,
        nhead,
        dropout,
        num_layers,
        bound_std=False,
        pos_emb_init: bool = False,
    ):
        """Build the base TNP and attach the diagonal prediction head.

        All arguments are forwarded unchanged to ``TNP.__init__``. In
        addition, ``d_model``, ``dim_feedforward`` and ``dim_y`` size the
        prediction head: ``d_model -> dim_feedforward -> 2 * dim_y``.
        """
        super().__init__(
            dim_x,
            dim_y,
            d_model,
            emb_depth,
            dim_feedforward,
            nhead,
            dropout,
            num_layers,
            bound_std,
            pos_emb_init=pos_emb_init,
        )

        # The head emits 2 * dim_y values per target: the mean and the
        # unconstrained std parameter, split apart in _diag_gaussian_params.
        self.predictor = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, dim_y * 2),
        )

    def create_mask(self, batch, device: str):
        """Return the attention mask for ``batch`` built with ``autoreg=False``."""
        # Inside a method the bare name resolves to the module-level helper
        # imported from tnp, not to this method.
        return create_mask(batch, device=device, autoreg=False)

    def encode(self, batch: DataAttr) -> torch.Tensor:
        """Encode context and target points and return the target encodings.

        The targets' ``y`` values are replaced by zeros before embedding, so
        only ``batch.yt``'s shape, dtype and device are used here.

        Returns:
            The encoder output restricted to the last ``Nt`` sequence
            positions, where ``Nt = batch.xt.shape[-2]``.
        """
        context_emb = self.embedder.embed_context(batch)
        target_emb = self.embedder.embed_target(
            DataAttr(xt=batch.xt, yt=torch.zeros_like(batch.yt))
        )

        # Context first, targets last along the sequence axis.
        encoder_input = torch.cat([context_emb, target_emb], dim=-2)

        mask = self.create_mask(batch, device=encoder_input.device)
        out = self.encoder(encoder_input, mask=mask)

        num_targets = batch.xt.shape[-2]
        # Use an explicit start index: ``out[:, -0:, :]`` would return the
        # whole sequence (context included) when there are no targets.
        start = max(out.shape[1] - num_targets, 0)
        return out[:, start:, :]

    def _diag_gaussian_params(
        self, batch: DataAttr
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode ``batch`` and return ``(mean, std)`` of the predictive Normal.

        The head output is split in half along the last axis. The second
        half is mapped to a positive std either via ``0.05 + 0.95 *
        softplus`` (when ``self.bound_std`` is truthy) or via ``exp``.
        """
        head_out = self.predictor(self.encode(batch))
        mean, raw_std = torch.chunk(head_out, 2, dim=-1)

        if self.bound_std:
            std = 0.05 + 0.95 * F.softplus(raw_std)
        else:
            std = torch.exp(raw_std)
        return mean, std

    def forward(self, batch: DataAttr, reduce_ll: bool = True) -> LossAttr:
        """Compute the negative log-likelihood loss on ``batch``.

        Args:
            batch: Data with context (``xc``, ``yc``) and targets
                (``xt``, ``yt``).
            reduce_ll: If True, ``loss`` is the negated mean of the
                log-likelihood; otherwise it is the negated, unreduced
                log-likelihood.

        Returns:
            ``LossAttr`` with ``loss``, ``log_likelihood`` (always the
            unreduced value), ``mean`` and ``std``.

        NOTE(review): the log-probability is *averaged* over the last
        (output) axis here, whereas ``eval_log_joint_likelihood`` *sums*
        over it. The two agree only when that axis has size 1 -- confirm
        this is intended.
        """
        mean, std = self._diag_gaussian_params(batch)

        tar_ll = Normal(mean, std).log_prob(batch.yt).mean(-1)
        loss = -tar_ll.mean() if reduce_ll else -tar_ll

        return LossAttr(
            loss=loss,
            log_likelihood=tar_ll,
            mean=mean,
            std=std,
        )

    def predict(
        self,
        xc: torch.Tensor,
        yc: torch.Tensor,
        xt: torch.Tensor,
        num_samples: int = 50,
        return_samples: bool = False
    ) -> Union[torch.Tensor, Normal]:
        """Make predictions at target locations.

        Args:
            xc: Context inputs [B, Nc, Dx]
            yc: Context outputs [B, Nc, Dy]
            xt: Target inputs [B, Nt, Dx]
            num_samples: Number of samples to generate
            return_samples: If True, return samples; else return distribution

        Returns:
            Samples [B, Nt, num_samples, Dy] or Normal distribution
        """
        # Dummy targets: encode() zeroes the target values anyway, so only
        # the shape, dtype and device of this tensor are used.
        dummy_yt = torch.zeros(
            (xc.shape[0], xt.shape[1], yc.shape[-1]),
            device=xt.device,
            dtype=xt.dtype,
        )
        batch = DataAttr(xc=xc, yc=yc, xt=xt, yt=dummy_yt)

        mean, std = self._diag_gaussian_params(batch)
        predictive = Normal(mean, std)

        if not return_samples:
            return predictive

        print("Original TNPD did not implement sampling. Sample from prediction Normal.")
        samples = predictive.sample([num_samples])  # [num_samples, B, Nt, Dy]
        # Reshape to [B, Nt, num_samples, Dy]
        return SampleReshaper.torch_dist2custom(samples)

    def sample_joint_predictive(
        self,
        xc: torch.Tensor,
        yc: torch.Tensor,
        xt: torch.Tensor,
        num_samples: int = 50
    ) -> torch.Tensor:
        """Draw joint samples over the targets, one target at a time.

        At step ``t`` the model is conditioned on the original context plus
        the ``t`` target points already sampled, predicts target ``t``, and
        the drawn value is written back so that it becomes context for the
        following steps.

        Args:
            xc: Context inputs [B, Nc, Dx]
            yc: Context outputs [B, Nc, Dy]
            xt: Target inputs [B, Nt, Dx]
            num_samples: Number of joint samples to draw per batch element

        Returns:
            Samples [B, Nt, num_samples, Dy]
        """
        B = xc.shape[0]
        Dy = yc.shape[-1]
        C = xc.shape[1]
        T = xt.shape[1]

        if T == 0:
            # Nothing to sample; torch.cat on an empty list would raise.
            empty = torch.zeros(
                num_samples, B, 0, Dy, device=xt.device, dtype=xt.dtype
            )
            return SampleReshaper.torch_dist2custom(empty)

        # Tile the batch num_samples times along dim 0 (sample-major: row
        # s * B + b holds batch element b of sample s).
        xc_stacked = xc.repeat(num_samples, 1, 1)  # [num_samples*B, Nc, Dx]
        yc_stacked = yc.repeat(num_samples, 1, 1)  # [num_samples*B, Nc, Dy]
        xt_stacked = xt.repeat(num_samples, 1, 1)  # [num_samples*B, T, Dx]
        yt_stacked = torch.zeros(
            num_samples * B, T, Dy, device=xt.device, dtype=xt.dtype
        )  # [num_samples*B, T, Dy]

        x_stacked = torch.cat([xc_stacked, xt_stacked], dim=1)
        y_stacked = torch.cat([yc_stacked, yt_stacked], dim=1)

        output = []
        for t in range(T):
            batch = DataAttr(
                xc=x_stacked[:, :C + t, :],
                yc=y_stacked[:, :C + t, :],
                xt=x_stacked[:, C + t:C + t + 1, :],
                yt=y_stacked[:, C + t:C + t + 1, :],
            )
            torch.cuda.empty_cache()

            mean, std = self._diag_gaussian_params(batch)
            yhat = Normal(mean, std).sample()  # [num_samples*B, 1, Dy]
            output.append(yhat.view(num_samples, B, 1, Dy))

            # Feed the sampled value back as context for the next step; the
            # last target is never used as context, so skip the write.
            if t < T - 1:
                y_stacked[:, C + t, :] = yhat.squeeze(1)

        samples = torch.cat(output, dim=-2)  # [num_samples, B, T, Dy]
        return SampleReshaper.torch_dist2custom(samples)  # [B, T, num_samples, Dy]

    def eval_log_joint_likelihood(
        self,
        xc: torch.Tensor,
        yc: torch.Tensor,
        xt: torch.Tensor,
        yt: torch.Tensor,
    ) -> torch.Tensor:
        """
        Evaluate log likelihood at all target points jointly.

        Target ``t`` is scored conditioned on the context plus the true
        values of targets ``0..t-1``; the per-target terms are summed.

        Args:
            xc: Context inputs [B, Nc, Dx]
            yc: Context outputs [B, Nc, Dy]
            xt: Target inputs [B, Nt, Dx]
            yt: Target outputs [B, Nt, Dy]

        Returns:
            Log likelihood [B], log p(yt|xt, yc, xc)
        """
        C = xc.shape[1]  # Number of context points
        T = xt.shape[1]  # Number of target points

        if T == 0:
            # Empty sum; torch.cat on an empty list would raise.
            return yt.new_zeros(xc.shape[0])

        x_full = torch.cat([xc, xt], dim=1)
        y_full = torch.cat([yc, yt], dim=1)

        output = []
        for t in range(T):
            batch = DataAttr(
                xc=x_full[:, :C + t, :],
                yc=y_full[:, :C + t, :],
                xt=x_full[:, C + t:C + t + 1, :],
                yt=y_full[:, C + t:C + t + 1, :],
            )
            torch.cuda.empty_cache()

            mean, std = self._diag_gaussian_params(batch)
            # Sum over the output dimension -> [B, 1]
            step_ll = Normal(mean, std).log_prob(batch.yt).sum(-1)
            output.append(step_ll)

        tar_ll = torch.cat(output, dim=-1)  # [B, T]
        return tar_ll.sum(-1)

    def eval_log_likelihood(
        self,
        xc: torch.Tensor,
        yc: torch.Tensor,
        xt: torch.Tensor,
        yt: torch.Tensor,
    ) -> torch.Tensor:
        """
        Evaluate log likelihood at individual target points.

        Every target is scored given the original context only, and the
        per-target values from ``forward`` are summed over the last axis.

        Args:
            xc: Context inputs [B, Nc, Dx]
            yc: Context outputs [B, Nc, Dy]
            xt: Target inputs [B, Nt, Dx]
            yt: Target outputs [B, Nt, Dy]

        Returns:
            Log likelihood [B], sum_i log p( [yt]_i | [xt]_i, yc, xc )

        NOTE(review): ``forward`` averages the log-probability over the
        output dimension, so for Dy > 1 each term is a per-dimension mean
        and not the summed log density -- confirm this is intended.
        """
        batch = DataAttr(xc=xc, yc=yc, xt=xt, yt=yt)
        per_target_ll = self.forward(batch).log_likelihood
        return per_target_ll.sum(-1)
