"""
Diagonal Transformer Neural Process (TNPD) with our Head.
"""

from typing import Union

import torch

from src.models.benchmarks.tnp import TNP, SampleReshaper
from src.models.benchmarks.tnpd import TNPD
from src.models.modules import NeuralProcessHead
from src.utils import DataAttr, LossAttr


class TNPDFlexHead(TNPD):
    """TNPD variant whose output distribution comes from a pluggable head.

    The transformer encoder is set up by ``TNP``; every prediction, sampling
    and likelihood call routes the encoder output ``z_target`` through the
    user-supplied ``NeuralProcessHead`` stored in ``self.head``.
    """

    def __init__(
        self,
        dim_x: int,
        dim_y: int,
        d_model: int,
        emb_depth: int,
        dim_feedforward: int,
        nhead: int,
        dropout: float,
        num_layers: int,
        head: NeuralProcessHead,
        pos_emb_init: bool = False,
    ):
        # Initialise through TNP directly, bypassing TNPD.__init__: the output
        # head is supplied by the caller instead of being built here.
        # NOTE(review): assumes TNPD.__init__ only adds its own predictor on
        # top of TNP -- confirm nothing else this class relies on is set there.
        TNP.__init__(
            self,
            dim_x, dim_y, d_model, emb_depth,
            dim_feedforward, nhead, dropout, num_layers,
            bound_std=False, pos_emb_init=pos_emb_init,
        )

        self.head = head

    def forward(self, batch: DataAttr, reduce_ll: bool = True) -> LossAttr:
        """Encode ``batch`` and score its targets with the head.

        Args:
            batch: Context/target data; ``batch.yt`` is passed to the head.
            reduce_ll: Accepted for signature compatibility; not used here.

        Returns:
            The head's output for ``(z_target, yt=batch.yt, loss_mask=None)``.
        """
        z_target = self.encode(batch)
        return self.head(z_target, yt=batch.yt, loss_mask=None)

    def predict(
        self,
        xc: torch.Tensor,
        yc: torch.Tensor,
        xt: torch.Tensor,
        num_samples: int = 50,
        return_samples: bool = False
    ) -> Union[torch.Tensor, LossAttr]:
        """Make predictions at target locations.

        Args:
            xc: Context inputs [B, Nc, Dx]
            yc: Context outputs [B, Nc, Dy]
            xt: Target inputs [B, Nt, Dx]
            num_samples: Number of samples to generate (used only when
                ``return_samples`` is True)
            return_samples: If True, return samples; else return distribution

        Returns:
            Samples [B, Nt, num_samples, Dy] if ``return_samples``; otherwise
            the head's LossAttr (distribution parameters, e.g. MixtureGaussian).
        """
        batch_size = xc.shape[0]
        dim_y = yc.shape[-1]
        num_target = xt.shape[1]

        batch = DataAttr(
            xc=xc,
            yc=yc,
            xt=xt,
            # Placeholder (all-zero) targets: no real targets exist at
            # prediction time, but DataAttr carries a yt field.
            yt=torch.zeros((batch_size, num_target, dim_y), device=xt.device, dtype=xt.dtype),
        )

        z_target = self.encode(batch)

        if return_samples:
            assert num_samples > 0, "num_samples must be greater than 0 for sampling."
            return self.head.sample(z_target, num_samples=num_samples)
        return self.head(z_target, yt=None, loss_mask=None)

    def sample_joint_predictive(
        self,
        xc: torch.Tensor,
        yc: torch.Tensor,
        xt: torch.Tensor,
        num_samples: int = 50
    ) -> torch.Tensor:
        """Draw joint samples over all targets by autoregressive rollout.

        Targets are visited in order. Each drawn value is written back as an
        observed output so it conditions all later targets. ``num_samples``
        independent rollouts run in parallel by tiling the batch along dim 0.

        Args:
            xc: Context inputs [B, Nc, Dx]
            yc: Context outputs [B, Nc, Dy]
            xt: Target inputs [B, Nt, Dx]
            num_samples: Number of joint samples (rollouts) per batch element

        Returns:
            Samples [B, Nt, num_samples, Dy], produced by
            ``SampleReshaper.torch_dist2custom`` from a [num_samples, B, Nt, Dy]
            tensor. With Nt == 0 the result has an empty Nt axis.
        """
        assert num_samples > 0, "num_samples must be greater than 0 for sampling."

        B = xc.shape[0]
        C = xc.shape[1]
        Dy = yc.shape[-1]
        T = xt.shape[1]

        if T == 0:
            # Nothing to sample; torch.cat below would reject an empty list.
            empty = torch.zeros(num_samples, B, 0, Dy, device=xt.device, dtype=xt.dtype)
            return SampleReshaper.torch_dist2custom(empty)

        # Tile the batch num_samples times along dim 0: row s * B + b holds
        # batch element b for rollout s.
        x_stacked = torch.cat([xc, xt], dim=1).repeat(num_samples, 1, 1)  # [S*B, C+T, Dx]
        y_placeholder = torch.zeros(B, T, Dy, device=xt.device, dtype=xt.dtype)
        # repeat() materialises a fresh tensor, so the in-place writes below
        # never alias the caller's yc.
        y_stacked = torch.cat([yc, y_placeholder], dim=1).repeat(num_samples, 1, 1)  # [S*B, C+T, Dy]

        output = []
        for t in range(T):
            pos = C + t
            batch = DataAttr(
                xc=x_stacked[:, :pos, :],
                yc=y_stacked[:, :pos, :],
                xt=x_stacked[:, pos:pos + 1, :],
                yt=y_stacked[:, pos:pos + 1, :],
            )
            torch.cuda.empty_cache()
            zt_stacked = self.encode(batch)  # [S*B, 1, d_model]
            # One draw per rollout; drop the singleton sample axis.
            yhat = self.head.sample(zt_stacked, num_samples=1).squeeze(-2)  # [S*B, 1, Dy]

            output.append(yhat.view(num_samples, B, 1, Dy))

            if t < T - 1:
                # Feed the draw back as an observed value for later targets.
                y_stacked[:, pos, :] = yhat.squeeze(1)

        samples = torch.cat(output, dim=-2)  # [num_samples, B, T, Dy]
        return SampleReshaper.torch_dist2custom(samples)  # [B, T, num_samples, Dy]

    def eval_log_joint_likelihood(
        self,
        xc: torch.Tensor,
        yc: torch.Tensor,
        xt: torch.Tensor,
        yt: torch.Tensor,
    ) -> torch.Tensor:
        """Evaluate the joint log likelihood of all targets autoregressively.

        Uses the chain rule: target t is scored conditioned on the context
        plus the true values of targets 0..t-1.

        Args:
            xc: Context inputs [B, Nc, Dx]
            yc: Context outputs [B, Nc, Dy]
            xt: Target inputs [B, Nt, Dx]
            yt: Target outputs [B, Nt, Dy]

        Returns:
            Tensor [B], log p(yt | xt, yc, xc). Zeros when Nt == 0.
        """
        C = xc.shape[1]  # Number of context points
        T = xt.shape[1]  # Number of target points

        if T == 0:
            # The log likelihood of an empty target set is 0; torch.cat below
            # would reject an empty list.
            return yt.new_zeros(xt.shape[0])

        x_full = torch.cat([xc, xt], dim=1)
        y_full = torch.cat([yc, yt], dim=1)

        output = []
        for t in range(T):
            pos = C + t
            batch = DataAttr(
                xc=x_full[:, :pos, :],
                yc=y_full[:, :pos, :],
                xt=x_full[:, pos:pos + 1, :],
                yt=y_full[:, pos:pos + 1, :],
            )
            torch.cuda.empty_cache()

            zt = self.encode(batch)  # [B, 1, d_model]
            output.append(self.head.log_likelihood(zt, batch.yt))  # [B, 1]

        return torch.cat(output, dim=-1).sum(-1)  # [B, T] -> [B]

    def eval_log_likelihood(
        self,
        xc: torch.Tensor,
        yc: torch.Tensor,
        xt: torch.Tensor,
        yt: torch.Tensor,
    ) -> torch.Tensor:
        """Evaluate the sum of per-target (marginal) log likelihoods.

        All targets are scored in a single pass, each conditioned only on the
        context set.

        Args:
            xc: Context inputs [B, Nc, Dx]
            yc: Context outputs [B, Nc, Dy]
            xt: Target inputs [B, Nt, Dx]
            yt: Target outputs [B, Nt, Dy]

        Returns:
            Tensor [B], sum_i log p( [yt]_i | [xt]_i, yc, xc )
        """
        batch = DataAttr(xc=xc, yc=yc, xt=xt, yt=yt)
        z_target = self.encode(batch)  # [B, Nt, d_model]
        ll = self.head.log_likelihood(z_target, batch.yt)  # [B, Nt]
        return ll.sum(-1)
