"""
Transformer Neural Process - Autoregressive (TNPA) with our Head.
"""

import torch

from src.models.benchmarks.tnp import TNP, SampleReshaper
from src.models.benchmarks.tnpa import TNPA
from src.models.modules import NeuralProcessHead
from src.utils import DataAttr, LossAttr


class TNPAFlexHead(TNPA):
    """TNP-A backbone whose output layer is a pluggable ``NeuralProcessHead``.

    The transformer encoder comes from ``TNP``. The training loss, sampling
    and likelihood evaluation are all delegated to ``self.head``.
    """

    def __init__(
        self,
        dim_x: int,
        dim_y: int,
        d_model: int,
        emb_depth: int,
        dim_feedforward: int,
        nhead: int,
        dropout: float,
        num_layers: int,
        head: NeuralProcessHead,
        permute: bool = False,
        pos_emb_init: bool = False,
    ):
        """Build the encoder and attach the prediction head.

        Args:
            dim_x, dim_y, d_model, emb_depth, dim_feedforward, nhead, dropout,
            num_layers, pos_emb_init: Forwarded unchanged to ``TNP.__init__``
                (with ``bound_std=False``).
            head: Module applied to target embeddings. ``forward`` calls it
                directly, ``predict`` uses ``head.sample`` and
                ``eval_log_joint_likelihood`` uses ``head.log_likelihood``.
            permute: If True, ``predict`` generates targets in a permuted order
                (via ``_permute_targets``) when there are several samples and
                several targets.
        """
        # TNP.__init__ is called directly, skipping TNPA.__init__.
        # NOTE(review): presumably this avoids building TNPA's own output
        # layer, since ``self.head`` replaces it. Confirm against TNPA.
        TNP.__init__(
            self,
            dim_x, dim_y, d_model, emb_depth,
            dim_feedforward, nhead, dropout, num_layers,
            bound_std=False, pos_emb_init=pos_emb_init,
        )

        self.head = head
        self.permute = permute

    def forward(self, batch: DataAttr, reduce_ll: bool = True) -> LossAttr:
        """Encode ``batch`` and let the head compute the training output.

        Args:
            batch: Context/target data with ``xc``, ``yc``, ``xt``, ``yt``.
            reduce_ll: Accepted for signature compatibility; not used here.

        Returns:
            Whatever ``self.head`` returns for the target embeddings.
        """
        z_target = self.encode(batch)
        return self.head(z_target, yt=batch.yt, loss_mask=None)

    def predict(
        self,
        xc: torch.Tensor,
        yc: torch.Tensor,
        xt: torch.Tensor,
        num_samples: int = 50,
        return_samples: bool = False
    ) -> torch.Tensor:
        """Draw joint samples at the target inputs autoregressively.

        Targets are filled in one position at a time. At each step the model
        sees the context plus all targets sampled so far. The ``num_samples``
        sample paths are computed in parallel by folding the sample axis into
        the batch axis.

        Args:
            xc: Context inputs [B, Nc, Dx]
            yc: Context outputs [B, Nc, Dy]
            xt: Target inputs [B, Nt, Dx]
            num_samples: Number of joint samples to generate.
            return_samples: Must be True. Returning a distribution is not
                implemented.

        Returns:
            Samples [B, Nt, num_samples, Dy].

        Raises:
            NotImplementedError: If ``return_samples`` is False.
        """
        # Fail fast, before allocating num_samples copies of every input.
        if not return_samples:
            raise NotImplementedError(
                f"return_samples=False not implemented for {type(self).__name__}."
            )

        batch_size = xc.shape[0]
        num_target = xt.shape[1]
        dim_y = yc.shape[-1]
        device = xc.device
        # Permutation is only applied with several samples and several targets.
        permute = self.permute and num_samples > 1 and num_target > 1

        # Replicate the inputs once per sample path.
        xc_stacked = self._stack(xc, num_samples)  # [num_samples, B, Nc, Dx]
        yc_stacked = self._stack(yc, num_samples)  # [num_samples, B, Nc, Dy]
        xt_stacked = self._stack(xt, num_samples)  # [num_samples, B, Nt, Dx]

        # Target outputs start as zero placeholders and are overwritten step by step.
        yt_pred = torch.zeros((batch_size, num_target, dim_y), device=device, dtype=xt.dtype)
        yt_stacked = self._stack(yt_pred, num_samples)  # [num_samples, B, Nt, Dy]

        if permute:
            xt_stacked, yt_stacked, perm_info = self._permute_targets(
                xt_stacked, yt_stacked, num_samples, batch_size, num_target
            )

        def squeeze(x):
            # [num_samples, B, N, D] -> [num_samples*B, N, D]
            return x.view(-1, x.shape[-2], x.shape[-1]).contiguous()

        def unsqueeze(x):
            # [num_samples*B, N, D] -> [num_samples, B, N, D]
            return x.view(num_samples, batch_size, x.shape[-2], x.shape[-1]).contiguous()

        batch_stacked = DataAttr(
            xc=squeeze(xc_stacked),  # [num_samples*B, Nc, Dx]
            yc=squeeze(yc_stacked),  # [num_samples*B, Nc, Dy]
            xt=squeeze(xt_stacked),  # [num_samples*B, Nt, Dx]
            yt=squeeze(yt_stacked),  # [num_samples*B, Nt, Dy]
        )

        for step in range(num_target):
            # TNP-A attends to the context and to the targets generated so far,
            # so new samples do not need to be appended to the context. Only the
            # first step+1 targets are passed in; y at position `step` is still 0.
            batch_stacked_step = DataAttr(
                xc=batch_stacked.xc,                   # [num_samples*B, Nc, Dx]
                yc=batch_stacked.yc,                   # [num_samples*B, Nc, Dy]
                xt=batch_stacked.xt[:, :step + 1, :],  # [num_samples*B, step+1, Dx]
                yt=batch_stacked.yt[:, :step + 1, :],  # [num_samples*B, step+1, Dy]
            )
            # NOTE(review): emptying the CUDA cache on every step lowers peak
            # memory but slows generation; consider dropping it if memory allows.
            torch.cuda.empty_cache()

            # Encode with the autoregressive mask.
            z_target_stacked = self.encode(batch_stacked_step)  # [num_samples*B, step+1, D_model]

            # One draw per row, taken from the embedding at the newest position.
            yhat = self.head.sample(z_target_stacked[:, -1:, :], num_samples=1)  # [num_samples*B, 1, 1, Dy]
            yhat = yhat.squeeze(-2)  # [num_samples*B, 1, Dy]

            # Write the draw in place so that later steps condition on it.
            batch_stacked.yt[:, step, :] = yhat[:, 0, :]

        samples = unsqueeze(batch_stacked.yt)  # [num_samples, B, Nt, Dy]
        if permute:
            samples = self._unpermute_targets(samples, None, perm_info)[0]

        return SampleReshaper.torch_dist2custom(samples)  # [B, Nt, num_samples, Dy]

    def sample_joint_predictive(
        self,
        xc: torch.Tensor,
        yc: torch.Tensor,
        xt: torch.Tensor,
        num_samples: int = 50
    ) -> torch.Tensor:
        """Return ``num_samples`` joint samples [B, Nt, num_samples, Dy] via ``predict``."""
        return self.predict(xc, yc, xt, num_samples=num_samples, return_samples=True)  # [B, Nt, num_samples, Dy]

    def eval_log_joint_likelihood(
        self,
        xc: torch.Tensor,
        yc: torch.Tensor,
        xt: torch.Tensor,
        yt: torch.Tensor,
    ) -> torch.Tensor:
        """Evaluate the joint log-likelihood of all targets.

        The per-target log-likelihoods from the head are summed over targets.
        Treating this sum as the joint log-likelihood assumes ``encode``
        applies the autoregressive target mask (chain rule).

        Args:
            xc: Context inputs [B, Nc, Dx]
            yc: Context outputs [B, Nc, Dy]
            xt: Target inputs [B, Nt, Dx]
            yt: Target outputs [B, Nt, Dy]

        Returns:
            Log-likelihood [B], log p(yt | xt, xc, yc).
        """
        batch = DataAttr(
            xc=xc,
            yc=yc,
            xt=xt,
            yt=yt
        )

        zt = self.encode(batch)  # [B, Nt, D_model]
        tar_ll = self.head.log_likelihood(zt, batch.yt)  # [B, Nt]
        return tar_ll.sum(-1)
