"""Transformer Neural Process (TNP) base implementation.

This is a reimplementation of TNP from github.com/tung-nd/TNP-pytorch
that is compatible with our DataAttr format.
"""

import torch
import torch.nn as nn
from typing import Optional
import abc

from .modules import TNPEmbedder
from src.utils import DataAttr, LossAttr

class TNP(abc.ABC, nn.Module):
    """Transformer Neural Process base class.

    Builds the components shared by every TNP variant -- a ``TNPEmbedder``
    and a batch-first ``nn.TransformerEncoder`` -- and declares the abstract
    interface that concrete variants must implement: attention masking,
    encoding, the training forward pass, prediction, joint sampling and
    joint log-likelihood evaluation.
    """

    # Whether the model can draw joint samples without autoregressive
    # rollout and permutation; subclasses that can do so override this.
    _support_non_ar_joint: bool = False

    @property
    def support_non_ar_joint(self) -> bool:
        """Whether joint samples can be drawn without AR rollout/permutation."""
        return self._support_non_ar_joint

    def __init__(
        self,
        dim_x: int,
        dim_y: int,
        d_model: int,
        emb_depth: int,
        dim_feedforward: int,
        nhead: int,
        dropout: float,
        num_layers: int,
        bound_std: bool = False,
        pos_emb_init: bool = False,
    ):
        """Build the shared embedder and transformer encoder.

        Args:
            dim_x: Dimensionality of the inputs x (Dx).
            dim_y: Dimensionality of the outputs y (Dy).
            d_model: Width of the transformer (embedding size).
            emb_depth: Forwarded to ``TNPEmbedder``.
            dim_feedforward: Feed-forward width of each encoder layer; also
                forwarded to ``TNPEmbedder``.
            nhead: Number of attention heads per encoder layer.
            dropout: Dropout probability inside each encoder layer.
            num_layers: Number of stacked encoder layers.
            bound_std: Stored as ``self.bound_std``; unused by this base
                class (presumably read by subclass prediction heads).
            pos_emb_init: Forwarded to ``TNPEmbedder``.
        """
        super().__init__()

        # Use TNP's original embedder.
        self.embedder = TNPEmbedder(dim_x, dim_y, dim_feedforward, d_model, emb_depth, pos_emb_init=pos_emb_init)

        # Standard PyTorch encoder; batch_first=True means inputs are
        # laid out as [batch, sequence, d_model].
        encoder_layer = nn.TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)

        self.bound_std = bound_std

    @abc.abstractmethod  # force subclasses to implement
    def create_mask(self, batch: DataAttr, device: str) -> torch.Tensor:
        """Create the attention mask for the transformer encoder.

        Args:
            batch: DataAttr containing context and target data.
            device: Device on which to allocate the mask.

        Returns:
            Attention mask tensor passed to the encoder.
        """
        raise NotImplementedError

    @abc.abstractmethod  # force subclasses to implement
    def encode(self, batch: DataAttr) -> torch.Tensor:
        """Encode context and target points.

        Args:
            batch: DataAttr containing context and target data; in
                particular, ``xt`` has shape [B, Nt, Dx].

        Returns:
            Encoded target representations of shape [B, Nt, d_model].
        """
        raise NotImplementedError

    @abc.abstractmethod  # force subclasses to implement
    def forward(self, batch: DataAttr, reduce_ll: bool = True) -> LossAttr:
        """Run the training forward pass.

        Args:
            batch: DataAttr containing context, buffer and target data.
            reduce_ll: Whether to reduce the log-likelihood across the batch.

        Returns:
            LossAttr containing the loss and predictions.
        """
        raise NotImplementedError

    @abc.abstractmethod  # force subclasses to implement
    def predict(
        self,
        xc: torch.Tensor,
        yc: torch.Tensor,
        xt: torch.Tensor,
        num_samples: Optional[int] = None,
        return_samples: bool = False,
    ) -> torch.Tensor:
        """Predict target values.

        Args:
            xc: Context inputs [B, Nc, Dx].
            yc: Context outputs [B, Nc, Dy].
            xt: Target inputs [B, Nt, Dx].
            num_samples: Number of samples to generate.
            return_samples: If True, return samples; otherwise return the
                predictive distribution.

        Returns:
            Samples [B, Nt, num_samples, Dy], or a Normal distribution when
            ``return_samples`` is False.
        """
        raise NotImplementedError

    def sample(
        self,
        xc: torch.Tensor,
        yc: torch.Tensor,
        xt: torch.Tensor,
        num_samples: int = 50
    ) -> torch.Tensor:
        """Draw samples from the model (convenience wrapper around ``predict``).

        Equivalent to ``predict(xc, yc, xt, num_samples, return_samples=True)``.

        Args:
            xc: Context inputs [B, Nc, Dx].
            yc: Context outputs [B, Nc, Dy].
            xt: Target inputs [B, Nt, Dx].
            num_samples: Number of samples to generate.

        Returns:
            Samples [B, Nt, num_samples, Dy].
        """
        return self.predict(xc, yc, xt, num_samples, return_samples=True)

    @abc.abstractmethod  # force subclasses to implement
    def sample_joint_predictive(
        self,
        xc: torch.Tensor,
        yc: torch.Tensor,
        xt: torch.Tensor,
        num_samples: int = 50
    ) -> torch.Tensor:
        """Sample from the joint predictive distribution over all targets.

        Args:
            xc: Context inputs [B, Nc, Dx].
            yc: Context outputs [B, Nc, Dy].
            xt: Target inputs [B, Nt, Dx].
            num_samples: Number of samples to generate.

        Returns:
            Samples [B, Nt, num_samples, Dy].
        """
        raise NotImplementedError

    @abc.abstractmethod  # force subclasses to implement
    def eval_log_joint_likelihood(
        self,
        xc: torch.Tensor,
        yc: torch.Tensor,
        xt: torch.Tensor,
        yt: torch.Tensor,
    ) -> torch.Tensor:
        """Evaluate the log-likelihood of all target points jointly.

        Args:
            xc: Context inputs [B, Nc, Dx].
            yc: Context outputs [B, Nc, Dy].
            xt: Target inputs [B, Nt, Dx].
            yt: Target outputs [B, Nt, Dy].

        Returns:
            Log-likelihood of shape [B], i.e. log p(yt | xt, xc, yc).
        """
        raise NotImplementedError

class SampleReshaper:
    """Move the sample axis between torch.distributions and custom layouts.

    ``torch.distributions`` draws samples as ``[num_samples, *batch_shape, Dy]``
    whereas the custom layout is ``[*batch_shape, num_samples, Dy]``.
    """

    @staticmethod
    def torch_dist2custom(samples: torch.Tensor) -> torch.Tensor:
        """
        Reshape samples from [num_samples, *batch_shape, Dy] to [*batch_shape, num_samples, Dy].

        The result is a view of ``samples`` (no copy). For contiguous input it
        has the same shape and strides as the previous view/permute/view
        implementation; unlike ``view``, it also accepts non-contiguous and
        zero-sized tensors.

        Args:
            samples: Tensor of shape [num_samples, *batch_shape, Dy]
                (at least 2-D).
        Returns:
            Tensor of shape [*batch_shape, num_samples, Dy]
        """
        # Leading sample axis -> second-to-last position.
        return torch.movedim(samples, 0, -2)

    @staticmethod
    def custom2torch_dist(samples: torch.Tensor) -> torch.Tensor:
        """
        Reshape samples from [*batch_shape, num_samples, Dy] to [num_samples, *batch_shape, Dy].

        The result is a view of ``samples`` (no copy); it works for
        non-contiguous and zero-sized tensors as well.

        Args:
            samples: Tensor of shape [*batch_shape, num_samples, Dy]
                (at least 2-D).
        Returns:
            Tensor of shape [num_samples, *batch_shape, Dy]
        """
        # Second-to-last sample axis -> leading position.
        return torch.movedim(samples, -2, 0)

def create_mask(batch: "DataAttr", device: str, autoreg=False):
    """Build the additive float attention mask for the TNP encoder.

    Entries are ``0.0`` where attention is allowed and ``-inf`` where it is
    blocked (row = query position, column = key position).

    Non-autoregressive (``autoreg=False``): sequence is
    ``[context | targets]`` of length ``num_ctx + num_tar``; every position
    attends only to the context points.

    Autoregressive (``autoreg=True``): sequence is
    ``[context | real targets | fake targets]`` of length
    ``num_ctx + 2 * num_tar``:
      * every position attends to all context points;
      * real target ``i`` also attends to real targets ``0..i``;
      * fake target ``i`` also attends to real targets ``0..i-1``.

    NOTE: when ``num_ctx == 0`` some rows have no allowed column (all
    ``-inf``), which makes softmax attention produce NaN for those rows.

    Args:
        batch: Object exposing ``xc`` and ``xt``; the number of points is
            read from dimension ``-2`` of each.
        device: Device on which to allocate the mask.
        autoreg: Whether to build the autoregressive layout.

    Returns:
        Square float mask tensor (default dtype) on ``device``.
    """
    num_ctx = batch.xc.shape[-2]
    num_tar = batch.xt.shape[-2]
    num_all = num_ctx + num_tar

    size = num_all + num_tar if autoreg else num_all
    mask = torch.full((size, size), float('-inf'), device=device)
    mask[:, :num_ctx] = 0.0  # all points attend to context points

    if autoreg:
        # triu_ zeroes everything below the chosen diagonal of the -inf block,
        # i.e. turns it into "allowed".
        # Each real target attends to itself and preceding real targets.
        mask[num_ctx:num_all, num_ctx:num_all].triu_(diagonal=1)
        # Each fake target attends to strictly preceding real targets.
        mask[num_all:, num_ctx:num_all].triu_(diagonal=0)

    return mask
