"""
TNP, PFN embedders, adapted from their original implementations.
TNP implementation: github.com/tung-nd/TNP-pytorch
PFN implementation: github.com/automl/TransformersCanDoBayesianInference
                    github.com/automl/PFNs
"""
import torch
from torch import nn
from src.utils import DataAttr
from src.models.utils import positional_encoding_init

### TNP Embedder
###
def build_mlp(d_in, d_hid, d_out, depth):
    """Assemble a ReLU MLP mapping ``d_in`` features to ``d_out`` features.

    The network has ``depth`` linear layers (never fewer than two). Every
    hidden layer is ``d_hid`` wide and a ReLU follows each linear layer
    except the last one.
    """
    # A depth below 2 still yields one hidden layer plus the output layer.
    n_hidden = max(depth - 1, 1)
    widths = [d_in] + [d_hid] * n_hidden

    layers = []
    for fan_in, fan_out in zip(widths, widths[1:]):
        layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
    layers.append(nn.Linear(d_hid, d_out))
    return nn.Sequential(*layers)


class TNPEmbedder(nn.Module):
    """TNP embedder with additional context/target markers.

    Points are encoded by an MLP, and a learned marker embedding (one row
    per role listed in ``marker_lookup``) is added to the encoding so that
    the roles can be told apart.

    Args:
        dim_x: Size of the last dimension of the inputs ``x``.
        dim_y: Size of the last dimension of the outputs ``y``.
        d_hidden: Hidden width of both encoder MLPs.
        d_model: Size of the produced embeddings.
        emb_depth: Depth passed to ``build_mlp`` for both encoders.
        pos_emb_init: If True, replace the marker embedding table with the
            result of ``positional_encoding_init(3, d_model, 2)``.
    """

    def __init__(self, dim_x, dim_y, d_hidden, d_model, emb_depth, pos_emb_init: bool = False):
        super().__init__()
        self.marker_lookup = {"target": 0, "context": 1, "buffer": 2}
        # ``enc`` encodes concatenated (x, y) pairs; ``pos_enc`` encodes x alone.
        self.enc = build_mlp(dim_x + dim_y, d_hidden, d_model, emb_depth)
        self.pos_enc = build_mlp(dim_x, d_hidden, d_model, emb_depth)
        self.marker_embed = torch.nn.Embedding(3, d_model)

        if pos_emb_init:
            init_weight = positional_encoding_init(3, d_model, 2)
            # nn.Module only accepts an nn.Parameter (or None) for a
            # registered parameter, so wrap the value in case the helper
            # returns a plain tensor.
            if not isinstance(init_weight, nn.Parameter):
                init_weight = nn.Parameter(init_weight)
            self.marker_embed.weight = init_weight

    def _get_marker_embedding(
        self,
        batch_size: int,
        marker_type: str,
        device: torch.device,
    ) -> torch.Tensor:
        """Get marker embedding for the specified type.

        Args:
            batch_size: Number of rows to produce.
            marker_type: A key of ``marker_lookup``.
            device: Device on which the index tensor is created.

        Returns:
            Tensor of shape ``(batch_size, 1, d_model)``.

        Raises:
            KeyError: If ``marker_type`` is not in ``marker_lookup``.
        """
        marker = self.marker_lookup[marker_type]
        marker_idx = torch.full(
            (batch_size, 1), marker, dtype=torch.long, device=device
        )
        return self.marker_embed(marker_idx)

    def _embed_pairs(self, x: torch.Tensor, y: torch.Tensor, marker_type: str) -> torch.Tensor:
        """Encode concatenated (x, y) pairs and add the given marker."""
        marker_emb = self._get_marker_embedding(x.size(0), marker_type, x.device)
        # The marker has a singleton middle dimension, so it is shared by
        # every point of a batch element (assumes x is (batch, points, dim_x)).
        return self.enc(torch.cat([x, y], dim=-1)) + marker_emb

    def embed_context(self, batch: DataAttr) -> torch.Tensor:
        """Embed context pairs (xc, yc) with context marker."""
        return self._embed_pairs(batch.xc, batch.yc, "context")

    def embed_buffer(self, batch: DataAttr) -> torch.Tensor:
        """Reject buffer embedding: TNP has no buffer points.

        Raises:
            ValueError: Always.
        """
        raise ValueError("TNP does not use buffer points.")

    def embed_target(self, batch: DataAttr) -> torch.Tensor:
        """Embed target points with target marker.

        When ``batch.yt`` is None only ``xt`` is encoded (through
        ``pos_enc``); otherwise the (xt, yt) pairs are encoded through ``enc``.
        """
        if batch.yt is None:
            marker_emb = self._get_marker_embedding(
                batch.xt.size(0), "target", batch.xt.device
            )
            return self.pos_enc(batch.xt) + marker_emb
        return self._embed_pairs(batch.xt, batch.yt, "target")

    def forward(self, x, y):
        """Encode points without any marker.

        Uses ``enc`` on the concatenated (x, y) when ``y`` is given and
        ``pos_enc`` on ``x`` alone when ``y`` is None.
        """
        if y is None:
            return self.pos_enc(x)
        return self.enc(torch.cat([x, y], dim=-1))


### PFN Embedder
###
class PFNv1Embedder(nn.Module):
    """PFN embedder with additional context/target markers.

    ``x`` and ``y`` are each projected by their own linear layer and the
    projections are summed; a learned marker embedding (one row per role
    listed in ``marker_lookup``) is added to tell the roles apart.

    Args:
        dim_x: Size of the last dimension of the inputs ``x``.
        dim_y: Size of the last dimension of the outputs ``y``.
        d_model: Size of the produced embeddings.
        pos_emb_init: If True, replace the marker embedding table with the
            result of ``positional_encoding_init(3, d_model, 2)``.
    """

    def __init__(self, dim_x, dim_y, d_model, pos_emb_init: bool = False):
        super().__init__()
        self.marker_lookup = {"target": 0, "context": 1, "buffer": 2}
        self.enc_x = nn.Linear(dim_x, d_model)
        self.enc_y = nn.Linear(dim_y, d_model)
        self.marker_embed = torch.nn.Embedding(3, d_model)

        if pos_emb_init:
            init_weight = positional_encoding_init(3, d_model, 2)
            # nn.Module only accepts an nn.Parameter (or None) for a
            # registered parameter, so wrap the value in case the helper
            # returns a plain tensor.
            if not isinstance(init_weight, nn.Parameter):
                init_weight = nn.Parameter(init_weight)
            self.marker_embed.weight = init_weight

    def _get_marker_embedding(
        self,
        batch_size: int,
        marker_type: str,
        device: torch.device,
    ) -> torch.Tensor:
        """Get marker embedding for the specified type.

        Args:
            batch_size: Number of rows to produce.
            marker_type: A key of ``marker_lookup``.
            device: Device on which the index tensor is created.

        Returns:
            Tensor of shape ``(batch_size, 1, d_model)``.

        Raises:
            KeyError: If ``marker_type`` is not in ``marker_lookup``.
        """
        marker = self.marker_lookup[marker_type]
        marker_idx = torch.full(
            (batch_size, 1), marker, dtype=torch.long, device=device
        )
        return self.marker_embed(marker_idx)

    def embed_context(self, batch: DataAttr) -> torch.Tensor:
        """Embed context pairs (xc, yc) with context marker."""
        x_emb = self.enc_x(batch.xc)
        y_emb = self.enc_y(batch.yc)
        marker_emb = self._get_marker_embedding(x_emb.size(0), "context", x_emb.device)
        # The marker has a singleton middle dimension, so it is shared by
        # every point of a batch element (assumes xc is (batch, points, dim_x)).
        return x_emb + y_emb + marker_emb

    def embed_buffer(self, batch: DataAttr) -> torch.Tensor:
        """Reject buffer embedding: PFN has no buffer points.

        Raises:
            ValueError: Always.
        """
        raise ValueError("PFN does not use buffer points.")

    def embed_target(self, batch: DataAttr) -> torch.Tensor:
        """Embed target inputs (xt) with target marker.

        Only ``batch.xt`` is read; target outputs are not encoded.
        """
        x_emb = self.enc_x(batch.xt)
        marker_emb = self._get_marker_embedding(x_emb.size(0), "target", x_emb.device)
        return x_emb + marker_emb

    def forward(self, x, y):
        """Encode points without any marker.

        PFN projects x and y separately and sums the two embeddings; when
        ``y`` is None only the x projection is returned.
        """
        x_emb = self.enc_x(x)
        if y is None:
            return x_emb
        return x_emb + self.enc_y(y)
