"""
Prior-Data Fitted Networks (PFNs).
Adapted from
github.com/automl/TransformersCanDoBayesianInference
github.com/automl/PFNs

Mueller et al., "Transformers Can Do Bayesian Inference", ICLR 2022.
Appendix E1:
 - x, y are encoded (Linear maps) and summed
 - no positional encoding
 - target predictions: outputs as positioned

"""

import logging
import torch
import torch.nn as nn
from typing import Optional

from .tnp import TNP, SampleReshaper
from .modules import (
    PFNv1Embedder,
    BarDistribution,
    FullSupportBarDistribution,
    LogitsKnownDistribution,
    get_bucket_borders,
)
from src.utils import DataAttr, LossAttr

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

class PFN(TNP):
    """Prior-Data Fitted Network with a bar-distribution (bucketed) output head.

    Context (x, y) pairs and target x locations are embedded, concatenated
    along the sequence axis and passed through a transformer encoder with an
    attention mask in which targets see the context and themselves only. The
    encoder outputs at the target positions are decoded into per-bucket logits
    that are consumed by ``self.predictor``.
    """

    def __init__(
        self,
        dim_x: int,
        dim_y: int,
        d_model: int,
        dim_feedforward: int,
        nhead: int,
        dropout: float,
        num_layers: int,
        head_num_buckets: int,
        head_bucket_samples: Optional[torch.Tensor] = None,
        pos_emb_init: bool = False,
    ):
        """
        Args:
            dim_x: Dimension of input x.
            dim_y: Dimension of input y. PFN supports only dim_y=1.
            d_model: embedding size.
            dim_feedforward: Dimension of hidden layers.
            nhead: Number of attention heads.
            dropout: Dropout rate.
            num_layers: Number of transformer encoder layers.
            head_num_buckets: Number of buckets for the head distribution.
            head_bucket_samples: Samples to determine each bucket border. This can be None, for example when we later load bucket borders (torch Module buffer) from trained models.
            pos_emb_init: Whether to use positional encoding initialization for embedding markers.

        Raises:
            AssertionError: If ``dim_y`` is not 1.
        """

        # Kept as an assert so the raised exception type is unchanged; the
        # message now reports the supported value and the offending one.
        assert dim_y == 1, f"PFN only supports dim_y==1, got dim_y={dim_y}"

        # Deliberately bypasses TNP.__init__ and calls the next initializer in
        # the MRO (nn.Module.__init__ according to the original author).
        super(TNP, self).__init__()

        self.embedder = PFNv1Embedder(dim_x, dim_y, d_model, pos_emb_init=pos_emb_init)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        # Maps each target encoding to one logit per bucket.
        self.decoder = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.GELU(),
            nn.Linear(dim_feedforward, head_num_buckets),
        )

        # Initialize bar distribution
        if head_bucket_samples is None:
            logger.info(
                "Head bucket samples are not provided, "
                "using the default range (-2.0, 2.0) for bucket borders."
            )
            # the borders will be torch Module buffer (can be loaded)
            bucket_borders = get_bucket_borders(head_num_buckets, full_range=(-2.0, 2.0))
        else:
            bucket_borders = get_bucket_borders(head_num_buckets, ys=head_bucket_samples)
        self.predictor = FullSupportBarDistribution(bucket_borders)

    def create_mask(self, batch: DataAttr, device: torch.device) -> torch.Tensor:
        """Create the additive attention mask for the transformer encoder.

        Rows are queries and columns are keys, ordered as context points
        followed by target points. Allowed pairs hold 0.0, blocked pairs -inf:
        every point may attend to all context points, and each target may
        additionally attend to itself.

        Args:
            batch: Provides ``xc`` and ``xt``; their second-to-last dimensions
                give the number of context (nc) and target (nt) points.
            device: Device on which the mask is allocated.

        Returns:
            Float tensor of shape [nc + nt, nc + nt].
        """
        nc = batch.xc.shape[-2]  # Number of context points
        nt = batch.xt.shape[-2]  # Number of target points
        total = nc + nt

        # Start fully blocked, then open up the permitted connections.
        mask = torch.full((total, total), float("-inf"), device=device)
        mask[:, :nc] = 0.0  # context -> context and target -> context
        target_idx = torch.arange(nc, total, device=device)
        mask[target_idx, target_idx] = 0.0  # target self attention

        return mask

    def encode(self, batch: DataAttr) -> torch.Tensor:
        """Encode context and target points and return the target encodings.

        Args:
            batch: DataAttr with context and target data.

        Returns:
            Encoder outputs at the trailing ``batch.xt.shape[1]`` sequence
            positions (the target points).
        """
        # Embed context
        xc_enc = self.embedder.embed_context(batch)
        # Embed targets (without y values)
        xt_enc = self.embedder.embed_target(batch)

        # Concatenate all embeddings
        encoder_input = torch.cat([xc_enc, xt_enc], dim=1)

        # Pass through transformer
        mask = self.create_mask(batch, device=encoder_input.device)
        out = self.encoder(encoder_input, mask=mask)

        # Return only target encodings. An explicit start index is used because
        # ``out[:, -0:]`` would return the whole sequence when there are no targets.
        num_targets = batch.xt.shape[1]
        return out[:, out.shape[1] - num_targets:, :]

    def forward(self, batch: DataAttr, reduce_ll: bool = True) -> LossAttr:
        """
        Forward pass through the model.

        Args:
            batch: DataAttr containing context and target data.
            reduce_ll: If True, reduce log likelihood to a scalar.

        Returns:
            LossAttr containing loss and log likelihood of target data.
        """

        # Encode
        out_encoder = self.encode(batch)  # [batch_size, num_target, d_model]

        # Decode to get logits for bar distribution
        logits = self.decoder(out_encoder)  # [batch_size, num_target, num_buckets]

        # Compute loss and log likelihood (dim_y == 1, so the last axis is dropped)
        return self.predictor(logits, batch.yt[..., 0], reduce_ll=reduce_ll)

    def predict(
        self,
        xc: torch.Tensor,
        yc: torch.Tensor,
        xt: torch.Tensor,
        num_samples: int = 50,
        return_samples: bool = False
    ) -> "torch.Tensor | LogitsKnownDistribution":
        """Make predictions at target locations.

        Args:
            xc: Context inputs.
            yc: Context outputs.
            xt: Target inputs.
            num_samples: Number of samples drawn per target when
                ``return_samples`` is True.
            return_samples: If True, return samples; otherwise return the
                predictive distribution object.

        Returns:
            If ``return_samples`` is True, the predictor's samples with a
            trailing singleton axis appended ([B, Nt, num_samples, Dy=1]
            according to the original comments). Otherwise a
            ``LogitsKnownDistribution`` wrapping the predictor and the logits.
        """
        # Create batch
        batch = DataAttr(
            xc=xc,
            yc=yc,
            xt=xt,
            yt=torch.zeros_like(xt[..., :1])  # Dummy target for prediction
        )

        # Encode and decode
        logits = self.decoder(self.encode(batch))

        if return_samples:
            samples = self.predictor.sample(logits, num_samples=num_samples)  # [B, Nt, num_samples]
            return samples.unsqueeze(-1)  # [B, Nt, num_samples, Dy=1]

        return LogitsKnownDistribution(self.predictor, logits)

    def sample_joint_predictive(
        self,
        xc: torch.Tensor,
        yc: torch.Tensor,
        xt: torch.Tensor,
        num_samples: int = 50
    ) -> torch.Tensor:
        """Draw joint samples over the targets autoregressively.

        The inputs are replicated ``num_samples`` times along the batch axis.
        Targets are then processed one at a time: a value is sampled for
        target ``t`` and written back so that it acts as context for the
        following targets.

        Args:
            xc: Context inputs [B, Nc, Dx]
            yc: Context outputs [B, Nc, Dy=1]
            xt: Target inputs [B, Nt, Dx]
            num_samples: Number of joint samples per batch element.

        Returns:
            The stacked samples [num_samples, B, Nt, Dy] passed through
            ``SampleReshaper.torch_dist2custom`` ([B, Nt, num_samples, Dy]
            according to the original comment).
        """
        B = xc.shape[0]
        Dx = xc.shape[-1]
        Dy = yc.shape[-1]
        C = xc.shape[1]
        T = xt.shape[1]
        assert Dy == 1, "PFN only supports dim_y=1"

        output = []

        xc_stacked = xc.unsqueeze(0).repeat(num_samples, 1, 1, 1).view(-1, xc.shape[-2], Dx)  # [num_samples*B, Nc, Dx]
        yc_stacked = yc.unsqueeze(0).repeat(num_samples, 1, 1, 1).view(-1, yc.shape[-2], Dy)  # [num_samples*B, Nc, Dy=1]
        xt_stacked = xt.unsqueeze(0).repeat(num_samples, 1, 1, 1).view(-1, T, Dx)  # [num_samples*B, T, Dx]
        # Placeholder for the target values; allocated in yc's dtype so the
        # concatenation below does not change the dtype of the context values.
        yt_stacked = torch.zeros(num_samples * B, T, Dy, dtype=yc.dtype, device=xt.device)  # [num_samples*B, T, Dy]

        x_stacked = torch.cat([xc_stacked, xt_stacked], dim=1)
        y_stacked = torch.cat([yc_stacked, yt_stacked], dim=1)

        for t in range(T):
            # Everything before position C+t is context; position C+t is the
            # single target of this step.
            batch = DataAttr(
                xc=x_stacked[:, :C+t, :],
                yc=y_stacked[:, :C+t, :],
                xt=x_stacked[:, C+t:C+t+1, :],
                yt=y_stacked[:, C+t:C+t+1, :],
            )
            torch.cuda.empty_cache()
            zt_stacked = self.encode(batch)  # [num_samples*B, 1, D_model]
            logits = self.decoder(zt_stacked)  # [num_samples*B, 1, num_buckets]

            # Sample without extra dimension
            yhat = self.predictor.sample(logits, num_samples=1)  # [num_samples*B, 1, 1]

            output.append(yhat.view(num_samples, B, 1, Dy))

            # Feed the sampled value back in as context for the next target.
            if t < T - 1:
                y_stacked[:, C+t, :] = yhat.squeeze(1)

        samples = torch.cat(output, dim=-2)  # [num_samples, B, T, Dy]
        return SampleReshaper.torch_dist2custom(samples)  # [B, T, num_samples, Dy]

    def eval_log_joint_likelihood(
        self,
        xc: torch.Tensor,
        yc: torch.Tensor,
        xt: torch.Tensor,
        yt: torch.Tensor,
    ) -> torch.Tensor:
        """
        Evaluate log likelihood at all target points jointly.

        The joint is evaluated autoregressively: target ``t`` is scored with
        the original context plus targets ``0..t-1`` (with their true values)
        as context, and the per-step log likelihoods are summed.

        Args:
            xc: Context inputs [B, Nc, Dx]
            yc: Context outputs [B, Nc, Dy]
            xt: Target inputs [B, Nt, Dx]
            yt: Target outputs [B, Nt, Dy]

        Returns:
            Tensor [B], log p(yt|xt, yc, xc)
        """
        B = xc.shape[0]
        C = xc.shape[1]  # Number of context points
        T = xt.shape[1]  # Number of target points

        # An empty target set has log likelihood 0; torch.cat would otherwise
        # fail on an empty list.
        if T == 0:
            return yt.new_zeros(B)

        x_full = torch.cat([xc, xt], dim=1)
        y_full = torch.cat([yc, yt], dim=1)

        output = []
        for t in range(T):
            batch = DataAttr(
                xc=x_full[:, :C+t, :],
                yc=y_full[:, :C+t, :],
                xt=x_full[:, C+t:C+t+1, :],
                yt=y_full[:, C+t:C+t+1, :]
            )
            torch.cuda.empty_cache()

            out_encoder = self.encode(batch)  # [B, 1, D_model]
            logits = self.decoder(out_encoder)
            # Request the unreduced log likelihood explicitly: one value per
            # batch element is needed for the concatenation below.
            tar_ll = self.predictor(
                logits, batch.yt[..., 0], reduce_ll=False
            ).log_likelihood

            output.append(tar_ll.reshape(B, 1))  # [B, 1]

        tar_ll = torch.cat(output, dim=-1)  # [B, T]
        return tar_ll.sum(-1)

    def eval_log_likelihood(
        self,
        xc: torch.Tensor,
        yc: torch.Tensor,
        xt: torch.Tensor,
        yt: torch.Tensor,
    ) -> torch.Tensor:
        """
        Evaluate log likelihood at individual target point.

        Args:
            xc: Context inputs [B, Nc, Dx]
            yc: Context outputs [B, Nc, Dy]
            xt: Target inputs [B, Nt, Dx]
            yt: Target outputs [B, Nt, Dy]

        Returns:
            Tensor [B], sum_i log p( [yt]_i | [xt]_i, yc, xc )
        """
        batch = DataAttr(
            xc=xc,
            yc=yc,
            xt=xt,
            yt=yt
        )
        # reduce_ll=False keeps one value per target; the default (True) would
        # collapse everything to a scalar and lose the batch axis.
        ll = self.forward(batch, reduce_ll=False).log_likelihood
        # Reshape instead of squeeze(-1) so a single target (Nt == 1) does not
        # lose its axis before the sum over targets.
        ll = ll.reshape(yt.shape[0], yt.shape[1])  # [B, Nt]
        return ll.sum(-1)
