"""Transformer Neural Process (TNP) base implementation.

This is a reimplementation of TNP from github.com/tung-nd/TNP-pytorch
that is compatible with our DataAttr format.
"""

import torch
import torch.nn as nn
from typing import Optional
import abc

from src.utils import DataAttr, LossAttr


def build_mlp(d_in, d_hid, d_out, depth):
    """Return an ``nn.Sequential`` MLP mapping ``d_in`` features to ``d_out``.

    The network alternates ``nn.Linear`` and ``nn.ReLU`` and ends with a bare
    ``nn.Linear``. ``depth`` is the number of Linear layers; any value below 2
    still yields two Linear layers (input -> hidden -> output).
    """
    # Number of Linear layers that output ``d_hid`` features.
    num_hidden = max(depth - 1, 1)
    dims = [d_in, *([d_hid] * num_hidden), d_out]

    layers = []
    for idx, (fan_in, fan_out) in enumerate(zip(dims[:-1], dims[1:])):
        # A ReLU separates consecutive Linear layers; none follows the last one.
        if idx:
            layers.append(nn.ReLU())
        layers.append(nn.Linear(fan_in, fan_out))
    return nn.Sequential(*layers)


class TNPEmbedder(nn.Module):
    """Embedder from the original TNP code (not the same as our ``Embedder``).

    Two MLPs of depth ``emb_depth`` map into ``d_model`` dimensions:
    ``enc`` embeds the concatenation ``[x, y]``, and ``pos_enc`` embeds ``x``
    alone.
    """

    def __init__(self, dim_x, dim_y, d_model, emb_depth):
        super().__init__()
        # Creation order (enc first, then pos_enc) fixes parameter init and state_dict order.
        self.enc = build_mlp(dim_x + dim_y, d_model, d_model, emb_depth)
        self.pos_enc = build_mlp(dim_x, d_model, d_model, emb_depth)

    def forward(self, x, y):
        """Embed ``x`` alone when ``y`` is None, otherwise embed ``[x, y]``."""
        if y is None:
            return self.pos_enc(x)
        xy = torch.cat([x, y], dim=-1)  # join along the feature axis
        return self.enc(xy)


class TNP(abc.ABC, nn.Module):
    """Abstract base class for Transformer Neural Processes.

    Holds the components shared by all TNP variants: a ``TNPEmbedder`` and a
    stack of ``nn.TransformerEncoderLayer`` blocks. It also declares the
    interface that concrete subclasses must implement: masking, encoding, the
    loss-producing forward pass, prediction and joint sampling.
    """

    def __init__(
        self,
        dim_x: int,
        dim_y: int,
        d_model: int,
        emb_depth: int,
        dim_feedforward: int,
        nhead: int,
        dropout: float,
        num_layers: int,
        bound_std: bool = False
    ):
        """
        Args:
            dim_x: Dimension of the inputs x.
            dim_y: Dimension of the outputs y.
            d_model: Width of the embeddings and of the transformer layers.
            emb_depth: Depth of the embedding MLPs.
            dim_feedforward: Hidden size of each encoder layer's feed-forward block.
            nhead: Number of attention heads per encoder layer.
            dropout: Dropout probability inside each encoder layer.
            num_layers: Number of stacked encoder layers.
            bound_std: Stored as ``self.bound_std``; how it is used is up to
                subclasses (presumably to bound the predicted std).
        """
        super().__init__()

        # Use TNP's original embedder (not this project's Embedder).
        self.embedder = TNPEmbedder(dim_x, dim_y, d_model, emb_depth)

        # PyTorch transformer encoder; batch_first=True means inputs are [B, N, d_model].
        encoder_layer = nn.TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)

        self.bound_std = bound_std

    @abc.abstractmethod  # force subclasses to implement
    def create_mask(self, batch: DataAttr, device: str) -> torch.Tensor:
        """Create the attention mask for the transformer encoder.

        Args:
            batch: DataAttr containing context and target data.
            device: Device on which to allocate the mask.

        Returns:
            Attention mask tensor to be used by ``self.encoder``.
        """
        raise NotImplementedError

    @abc.abstractmethod  # force subclasses to implement
    def encode(self, batch: DataAttr) -> torch.Tensor:
        """Encode context and target points.

        Args:
            batch: DataAttr containing context and target data; in particular,
                ``batch.xt`` has shape [B, Nt, Dx].

        Returns:
            Encoded tensor of shape [B, Nt, d_model].
        """
        raise NotImplementedError

    @abc.abstractmethod  # force subclasses to implement
    def forward(self, batch: DataAttr, reduce_ll: bool = True) -> LossAttr:
        """Forward pass through the TNP.

        Args:
            batch: DataAttr containing context, buffer and target data.
            reduce_ll: Whether to reduce the log-likelihood across the batch.

        Returns:
            LossAttr containing the loss and predictions.
        """
        raise NotImplementedError

    @abc.abstractmethod  # force subclasses to implement
    def predict(
        self,
        xc: torch.Tensor,
        yc: torch.Tensor,
        xt: torch.Tensor,
        num_samples: Optional[int] = None,
        return_samples: bool = False,
    ) -> torch.Tensor:
        """Predict target values.

        Args:
            xc: Context inputs [B, Nc, Dx]
            yc: Context outputs [B, Nc, Dy]
            xt: Target inputs [B, Nt, Dx]
            num_samples: Number of samples to generate
            return_samples: If True, return samples; otherwise return the
                predictive distribution.

        Returns:
            Samples [B, Nt, num_samples, Dy] when ``return_samples`` is True,
            otherwise a Normal predictive distribution.
        """
        raise NotImplementedError

    def sample(
        self,
        xc: torch.Tensor,
        yc: torch.Tensor,
        xt: torch.Tensor,
        num_samples: int = 50
    ) -> torch.Tensor:
        """Sample from the model (convenience wrapper around ``predict``).

        Calls ``self.predict(xc, yc, xt, num_samples, return_samples=True)``.

        Args:
            xc: Context inputs [B, Nc, Dx]
            yc: Context outputs [B, Nc, Dy]
            xt: Target inputs [B, Nt, Dx]
            num_samples: Number of samples to generate

        Returns:
            Samples [B, Nt, num_samples, Dy]
        """
        return self.predict(xc, yc, xt, num_samples, return_samples=True)

    @abc.abstractmethod  # force subclasses to implement
    def sample_joint_predictive(
        self,
        xc: torch.Tensor,
        yc: torch.Tensor,
        xt: torch.Tensor,
        num_samples: int = 50
    ) -> torch.Tensor:
        """Sample from the joint predictive distribution.

        Args:
            xc: Context inputs [B, Nc, Dx]
            yc: Context outputs [B, Nc, Dy]
            xt: Target inputs [B, Nt, Dx]
            num_samples: Number of samples to generate

        Returns:
            Samples [B, Nt, num_samples, Dy]
        """
        raise NotImplementedError

class SampleReshaper:
    """Move the sample axis between the two sample-tensor layouts.

    ``torch.distributions`` sampling puts the sample dimension first
    (``[num_samples, *batch_shape, Dy]``). This codebase uses
    ``[*batch_shape, num_samples, Dy]``, e.g. ``[B, Nt, num_samples, Dy]``.

    Both conversions use ``Tensor.movedim``, which always returns a view.
    Unlike ``view``-based reshaping, it also works for non-contiguous inputs
    and for tensors with zero elements.
    """

    @staticmethod
    def torch_dist2custom(samples: torch.Tensor) -> torch.Tensor:
        """
        Reshape samples from [num_samples, *batch_shape, Dy] to [*batch_shape, num_samples, Dy].

        Args:
            samples: Tensor of shape [num_samples, *batch_shape, Dy]
        Returns:
            Tensor of shape [*batch_shape, num_samples, Dy] (a view of ``samples``),
            with ``out[..., s, :] == samples[s, ..., :]``.
        """
        # Move the leading sample axis to just before the last (Dy) axis.
        return samples.movedim(0, -2)

    @staticmethod
    def custom2torch_dist(samples: torch.Tensor) -> torch.Tensor:
        """
        Reshape samples from [*batch_shape, num_samples, Dy] to [num_samples, *batch_shape, Dy].

        Args:
            samples: Tensor of shape [*batch_shape, num_samples, Dy]
        Returns:
            Tensor of shape [num_samples, *batch_shape, Dy] (a view of ``samples``);
            the exact inverse of ``torch_dist2custom``.
        """
        # Move the second-to-last sample axis to the front.
        return samples.movedim(-2, 0)

def create_mask(batch: "DataAttr", device: str, autoreg=False):
    """Build an additive attention mask over context and target tokens.

    Entries are ``0.0`` where attention is allowed and ``-inf`` where it is
    blocked. Row ``i`` is the query token and column ``j`` is the key token.

    Token layout along each axis:
        * ``autoreg=False``: ``[ctx (Nc) | tar (Nt)]``
        * ``autoreg=True``:  ``[ctx (Nc) | real tar (Nt) | fake tar (Nt)]``

    Rules:
        * Every token attends to all context tokens.
        * With ``autoreg=True``, real target ``i`` also attends to real targets
          ``j <= i``, and fake target ``i`` attends to real targets ``j < i``.

    Note: with no context points (``Nc == 0``) some rows are entirely
    ``-inf``, which yields NaNs in a softmax-based attention.

    Args:
        batch: Object exposing ``xc`` and ``xt``; the numbers of context and
            target points are read from ``shape[-2]`` of each.
        device: Device on which to allocate the mask.
        autoreg: Whether to build the autoregressive (real + fake target) mask.

    Returns:
        Float tensor of shape ``[N, N]``, where ``N = Nc + Nt`` when
        ``autoreg`` is False and ``N = Nc + 2 * Nt`` when it is True.
    """
    num_ctx = batch.xc.shape[-2]
    num_tar = batch.xt.shape[-2]
    num_all = num_ctx + num_tar
    size = num_all + num_tar if autoreg else num_all

    mask = torch.full((size, size), float("-inf"), device=device)
    mask[:, :num_ctx] = 0.0  # all points attend to context points

    if autoreg:
        idx = torch.arange(num_tar, device=device)
        query, key = idx.unsqueeze(1), idx.unsqueeze(0)
        tar_cols = slice(num_ctx, num_all)
        # Each real target point attends to itself and preceding real target points.
        mask[num_ctx:num_all, tar_cols].masked_fill_(key <= query, 0.0)
        # Each fake target point attends to strictly preceding real target points.
        mask[num_all:, tar_cols].masked_fill_(key < query, 0.0)

    return mask
