"""
Transformer Neural Process - Autoregressive (TNPA).
Based on github.com/tung-nd/TNP-pytorch/
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal

from src.models.benchmarks.tnp import TNP, SampleReshaper, create_mask
from src.utils import DataAttr, LossAttr


class TNPA(TNP):
    """TNP with autoregressive decoding."""
    
    def __init__(
        self,
        dim_x: int,
        dim_y: int,
        d_model: int,
        emb_depth: int,
        dim_feedforward: int,
        nhead: int,
        dropout: float,
        num_layers: int,
        bound_std: bool = False,
        permute: bool = False,
    ):
        """Build the base TNP and attach the Gaussian prediction head.

        Every argument except ``permute`` is forwarded positionally, in the
        order listed, to ``TNP.__init__``; their exact meaning is defined
        there (NOTE(review): presumably input/output dims and transformer
        hyper-parameters -- confirm against the base class).

        Args:
            dim_y: Also sizes the head output, which has ``2 * dim_y``
                features (split into mean and raw std by ``forward``).
            d_model: Also the input width of the prediction head.
            dim_feedforward: Also the hidden width of the prediction head.
            permute: If True, ``predict`` samples the targets in a random
                order and restores the original order afterwards.
        """
        super().__init__(
            dim_x,
            dim_y,
            d_model,
            emb_depth,
            dim_feedforward,
            nhead,
            dropout,
            num_layers,
            bound_std,
        )

        # Only consulted by ``predict``.
        self.permute = permute

        # Two-layer MLP head: d_model -> dim_feedforward -> (mean, raw std).
        hidden_layer = nn.Linear(d_model, dim_feedforward)
        output_layer = nn.Linear(dim_feedforward, 2 * dim_y)
        self.predictor = nn.Sequential(hidden_layer, nn.ReLU(), output_layer)

    def create_mask(self, batch, device: str):
        """Return the attention mask that ``encode`` passes to the encoder.

        Delegates to the module-level ``create_mask`` helper imported from
        ``src.models.benchmarks.tnp``, always requesting ``autoreg=True``.
        ``batch`` and ``device`` are forwarded unchanged.
        """
        return create_mask(batch, device=device, autoreg=True)

    def encode(self, batch: DataAttr) -> torch.Tensor:
        """Encode context and target points."""
        # Embed context
        xc_enc = self.embedder(batch.xc, batch.yc)
        # Embed targets (without y values)
        xt_0_enc = self.embedder(batch.xt, torch.zeros_like(batch.yt))
        # Embed targets (with y values)
        if self.training and self.bound_std:
            yt_noise = batch.yt + 0.05 * torch.randn_like(batch.yt)
        else:
            yt_noise = batch.yt
        xt_enc = self.embedder(batch.xt, yt_noise)

        # Concatenate all embeddings
        encoder_input = torch.cat([xc_enc, xt_enc, xt_0_enc], dim=-2)

        # Pass through transformer
        mask = self.create_mask(batch, device=encoder_input.device)
        out = self.encoder(encoder_input, mask=mask)
        num_targets = batch.xt.shape[-2]

        # Return only target encodings
        return out[:, -num_targets:, :]

    def forward(self, batch: DataAttr, reduce_ll: bool = True) -> LossAttr:
        """Compute the Gaussian negative log-likelihood of the targets.

        The target encodings from ``encode`` are mapped by ``self.predictor``
        to a mean and a raw scale per output dimension. The raw scale becomes
        a standard deviation via ``0.05 + 0.95 * softplus`` when
        ``self.bound_std`` is set, otherwise via ``exp``.

        Args:
            batch: Batch passed to ``encode``; ``batch.yt`` is scored under
                the predicted distribution.
            reduce_ll: If True the loss is the mean negative log-likelihood
                (a scalar); otherwise it is returned unreduced.

        Returns:
            ``LossAttr`` holding ``loss``, the unreduced ``log_likelihood``
            (summed over the last dimension), and the predicted ``mean`` and
            ``std``.
        """
        head_out = self.predictor(self.encode(batch))
        pred_mean, raw_scale = torch.chunk(head_out, 2, dim=-1)

        # Turn the unconstrained head output into a positive std.
        pred_std = (
            0.05 + 0.95 * F.softplus(raw_scale)
            if self.bound_std
            else torch.exp(raw_scale)
        )

        # Per-target log density, summed over the output dimension.
        log_lik = Normal(pred_mean, pred_std).log_prob(batch.yt).sum(-1)
        nll = -log_lik.mean() if reduce_ll else -log_lik

        return LossAttr(
            loss=nll,
            log_likelihood=log_lik,
            mean=pred_mean,
            std=pred_std,
        )
    
    def predict(
        self,
        xc: torch.Tensor,
        yc: torch.Tensor,
        xt: torch.Tensor,
        num_samples: int = 50,
        return_samples: bool = False
    ) -> "torch.Tensor | Normal":
        """Generate predictions autoregressively.

        ``num_samples`` independent trajectories are drawn in parallel: at
        step ``i`` the whole sequence is re-encoded, target ``i`` is sampled
        from its predicted Gaussian, and the sample is written back so later
        steps condition on it. When ``self.permute`` is set, each trajectory
        visits the targets in its own random order and the samples are put
        back into the original order before returning.

        The whole procedure runs under ``torch.no_grad()``; the returned
        values never carried gradients (they are built from ``.sample()``
        draws), so this only avoids building unused autograd graphs. The
        train/eval mode of the module is not changed here.

        Args:
            xc: Context inputs [B, Nc, Dx]
            yc: Context outputs [B, Nc, Dy]
            xt: Target inputs [B, Nt, Dx]
            num_samples: Number of samples to generate
            return_samples: If True, return samples; else return distribution

        Returns:
            If ``return_samples``: the result of
            ``SampleReshaper.torch_dist2custom`` applied to the samples of
            shape [num_samples, B, Nt, Dy] (per the existing code comment this
            is [B, Nt, num_samples, Dy]). Otherwise a ``Normal`` whose mean
            and std are the empirical mean and standard deviation across the
            samples. With ``num_samples == 1`` that standard deviation is
            undefined (NaN).
        """
        batch_size = xc.shape[0]
        num_target = xt.shape[1]
        dim_y = yc.shape[-1]
        device = xc.device

        # Explicit sizes (rather than -1) so that empty tensors reshape too.
        def squeeze(x):
            # [S, B, N, D] -> [S * B, N, D]
            return x.reshape(num_samples * batch_size, x.shape[-2], x.shape[-1])

        def unsqueeze(x):
            # [S * B, N, D] -> [S, B, N, D]
            return x.reshape(num_samples, batch_size, x.shape[-2], x.shape[-1])

        with torch.no_grad():
            # Stack samples for parallel generation
            xc_stacked = self._stack(xc, num_samples)
            yc_stacked = self._stack(yc, num_samples)
            xt_stacked = self._stack(xt, num_samples)

            # Sample buffer [S, B, Nt, Dy], initialised with zeros and filled
            # one target position per step. It uses yc's dtype so embedded
            # context and target values agree.
            yt_samples = torch.zeros(
                (num_samples, batch_size, num_target, dim_y),
                dtype=yc.dtype,
                device=device,
            )

            # Optionally permute target order
            perm_info = None
            if self.permute:
                xt_stacked, yt_samples, perm_info = self._permute_targets(
                    xt_stacked, yt_samples, num_samples, batch_size, num_target
                )

            # Create batch for autoregressive generation
            batch_stacked = DataAttr(
                xc=squeeze(xc_stacked),
                yc=squeeze(yc_stacked),
                xt=squeeze(xt_stacked),
                yt=squeeze(yt_samples)
            )

            # Autoregressively generate each target
            for step in range(num_target):
                # Encode with autoregressive mask
                z_target_stacked = self.encode(batch_stacked)

                # Predict for current step
                out = self.predictor(z_target_stacked)
                mean, std = torch.chunk(out, 2, dim=-1)

                if self.bound_std:
                    std = 0.05 + 0.95 * F.softplus(std)
                else:
                    std = torch.exp(std)

                # Reshape predictions to [S, B, Nt, Dy]
                mean, std = unsqueeze(mean), unsqueeze(std)

                # Sample the current position and expose it to the next step
                yt_samples[:, :, step] = Normal(
                    mean[:, :, step], std[:, :, step]
                ).sample()
                batch_stacked.yt = squeeze(yt_samples)

            # Restore the original target order if it was permuted
            if self.permute:
                yt_samples = self._unpermute_targets(
                    yt_samples, None, perm_info
                )[0]

        if return_samples:
            # Samples are [num_samples, B, Nt, Dy] at this point
            return SampleReshaper.torch_dist2custom(yt_samples) # [B, Nt, num_samples, Dy]

        # Return distribution with moments taken across samples
        sample_mean = yt_samples.mean(dim=0)
        sample_std = yt_samples.std(dim=0)

        return Normal(sample_mean, sample_std)

    def sample_joint_predictive(
        self,
        xc: torch.Tensor,
        yc: torch.Tensor,
        xt: torch.Tensor,
        num_samples: int = 50
    ) -> torch.Tensor:
        return self.predict(xc, yc, xt, num_samples=num_samples, return_samples=True) # [B, Nt, num_samples, Dy]

    def _stack(self, x: torch.Tensor, num_samples: int) -> torch.Tensor:
        """Stack tensor for parallel sampling."""
        return x.unsqueeze(0).repeat(num_samples, 1, 1, 1)

    def _permute_targets(
        self,
        xt: torch.Tensor,
        yt: torch.Tensor,
        num_samples: int,
        batch_size: int,
        num_target: int
    ) -> tuple:
        """Permute target order for each sample."""
        device = xt.device
        
        # Generate random permutations (same permutation within each batch)
        perm_ids = torch.rand(
            num_samples, num_target, device=device
        ).unsqueeze(1).repeat(1, batch_size, 1)
        perm_ids = torch.argsort(perm_ids, dim=-1)
        
        # Also compute inverse permutation for unpermuting later
        deperm_ids = torch.argsort(perm_ids, dim=-1)
        
        # Create indices for gathering
        dim_sample = torch.arange(num_samples, device=device).view(-1, 1, 1).expand(
            num_samples, batch_size, num_target
        )
        dim_batch = torch.arange(batch_size, device=device).view(1, -1, 1).expand(
            num_samples, batch_size, num_target
        )
        
        # Apply permutation
        xt_perm = xt[dim_sample, dim_batch, perm_ids]
        yt_perm = yt[dim_sample, dim_batch, perm_ids]
        
        perm_info = (dim_sample, dim_batch, deperm_ids)
        return xt_perm, yt_perm, perm_info
    
    def _unpermute_targets(
        self,
        tensor1: torch.Tensor,
        tensor2: torch.Tensor,
        perm_info: tuple
    ) -> tuple:
        """Unpermute targets back to original order."""
        dim_sample, dim_batch, deperm_ids = perm_info
        
        tensor1_unperm = tensor1[dim_sample, dim_batch, deperm_ids]
        
        if tensor2 is not None:
            tensor2_unperm = tensor2[dim_sample, dim_batch, deperm_ids]
            return tensor1_unperm, tensor2_unperm
        
        return tensor1_unperm, None
