"""
Transformer Neural Process with a non-diagonal (full-covariance) Gaussian predictive (TNP-ND).
Heavily borrowed from `github.com/tung-nd/TNP-pytorch/`
"""

import math
import torch
import torch.nn as nn
from torch.distributions import Normal, MultivariateNormal
from torch.distributions.multivariate_normal import _batch_mahalanobis

from src.models.benchmarks.tnp import TNP, SampleReshaper, create_mask
from src.models.benchmarks.modules import build_mlp
from src.enums.model_enums import CovApprox
from src.utils import DataAttr, LossAttr


class MixedPrecisionMultivariateNormal(MultivariateNormal):
    """
    ``MultivariateNormal`` whose ``log_prob`` keeps the triangular solve in float32.

    Adapts ``torch.distributions.MultivariateNormal`` for mixed-precision
    training: the Mahalanobis term (which uses a triangular solve against the
    Cholesky factor) is computed on float32 copies and cast back to the dtype
    of ``value - loc``. The rest of the computation follows the parent class.
    Only the ``covariance_matrix`` and ``scale_tril`` parameterisations are
    supported.
    """
    # NOTE(review): this flag mirrors the one on the model classes; nothing in
    # this class reads it. Kept unchanged in case other code looks it up.
    _support_non_ar_joint: bool = True  # whether the model can do joint samples without ar & permutation

    def __init__(
        self,
        loc,
        covariance_matrix=None,
        precision_matrix=None,
        scale_tril=None,
        validate_args=None,
    ):
        """Create the distribution from a mean and a covariance or Cholesky factor.

        Args:
            loc: mean vector(s).
            covariance_matrix: full covariance matrix (mutually exclusive with
                ``scale_tril``, as enforced by the parent class).
            precision_matrix: not supported; must be ``None``.
            scale_tril: lower-triangular Cholesky factor of the covariance.
            validate_args: forwarded to ``MultivariateNormal``.

        Raises:
            NotImplementedError: if ``precision_matrix`` is given.
        """
        # Compare against None explicitly. Truth-testing a multi-element tensor
        # raises RuntimeError, and a one-element zero tensor would slip through.
        if precision_matrix is not None:
            raise NotImplementedError(
                "MixedPrecisionMultivariateNormal does not support precision_matrix; "
                "pass covariance_matrix or scale_tril instead."
            )

        super().__init__(loc, covariance_matrix=covariance_matrix, scale_tril=scale_tril, validate_args=validate_args)

    def log_prob(self, value):
        """Return the log-density of ``value``, solving in float32 to avoid AMP dtypes."""
        if self._validate_args:
            self._validate_sample(value)
        diff = value - self.loc
        # Convert to float32 so the triangular solve never runs under a reduced
        # AMP dtype, then return the result in the working dtype of ``diff``.
        amp_dtype = diff.dtype
        M = _batch_mahalanobis(self._unbroadcasted_scale_tril.float(), diff.float()).to(amp_dtype)
        # For Sigma = L L^T, 0.5 * log|Sigma| = sum(log(diag(L))).
        half_log_det = (
            self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
        )
        return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + M) - half_log_det


class TNPND(TNP):
    """TNP with a joint (non-diagonal) multivariate normal output distribution.

    Context points and value-free target points are encoded together by the
    transformer inherited from ``TNP``; only the target encodings are decoded.
    The decoder has two heads:

    * ``mean_net`` maps every target encoding to ``dim_y`` means.
    * ``std_encoder`` + ``projector`` map every target encoding to ``prj_dim``
      features per output dimension. Their Gram matrix is used either as a
      Cholesky factor (``CovApprox.CHOLESKY``) or, plus an exponentiated
      ``diag_net`` diagonal, as a covariance (``CovApprox.LOWRANK``).

    All targets of a task are flattened into one event of size ``Nt * Dy``
    (target-major: entry ``t * Dy + d``). The predictive therefore models
    correlations across target points and output dimensions.
    """
    _support_non_ar_joint: bool = True  # whether the model can do joint samples without ar & permutation

    def __init__(
        self,
        dim_x: int,
        dim_y: int,
        d_model: int,
        emb_depth: int,
        dim_feedforward: int,
        nhead: int,
        dropout: float,
        num_layers: int,
        num_std_layers: int,
        bound_std: bool = False,
        cov_approx: CovApprox = CovApprox.CHOLESKY,
        prj_dim: int = 5,
        prj_depth: int = 4,
        diag_depth: int = 4,
        pos_emb_init: bool = False,
    ):
        """Build the shared TNP encoder plus the mean and covariance heads.

        Args:
            num_std_layers: number of extra transformer encoder layers used
                only by the covariance head.
            bound_std: in Cholesky mode, replace each diagonal entry ``v`` of
                the Cholesky factor by ``0.05 + 0.95 * tanh(v)``.
            cov_approx: ``CovApprox.CHOLESKY`` or ``CovApprox.LOWRANK``.
            prj_dim: features per (target, output-dim) pair, i.e. the inner
                dimension of the Gram matrix built in ``decode``.
            prj_depth: depth argument passed to ``build_mlp`` for the projector.
            diag_depth: depth argument passed to ``build_mlp`` for ``diag_net``
                (only built in LOWRANK mode).
            Remaining arguments are forwarded to ``TNP.__init__``.
        """
        super().__init__(
            dim_x, dim_y, d_model, emb_depth,
            dim_feedforward, nhead, dropout, num_layers, bound_std=bound_std, pos_emb_init=pos_emb_init
        )

        self.cov_approx = cov_approx

        # Mean prediction network: one mean per output dimension per target.
        self.mean_net = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, dim_y)
        )

        # Covariance head: extra transformer layers on top of the target encodings.
        std_encoder_layer = nn.TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, batch_first=True
        )
        self.std_encoder = nn.TransformerEncoder(std_encoder_layer, num_std_layers)

        # Project to prj_dim features per output dimension (low-rank factor).
        self.projector = build_mlp(
            d_model, dim_feedforward, prj_dim * dim_y, prj_depth
        )

        if cov_approx == CovApprox.LOWRANK:
            # Per-entry log-variance added on the covariance diagonal.
            self.diag_net = build_mlp(
                d_model, dim_feedforward, dim_y, diag_depth
            )

    def create_mask(self, batch, device: "torch.device | str"):
        """Return the attention mask for a single non-autoregressive encoder pass.

        The bare ``create_mask`` below resolves to the module-level helper
        imported from ``tnp``, not to this method.
        """
        return create_mask(batch, device=device, autoreg=False)

    def encode(self, batch: DataAttr) -> torch.Tensor:
        """Encode context and target points; return only the target encodings.

        Targets are embedded with their y-values replaced by zeros, so the
        encoder never sees the true target outputs.

        Returns:
            Tensor of shape ``[B, Nt, d_model]``.
        """
        # Embed context
        xc_enc = self.embedder.embed_context(batch)
        # Embed targets (without y values)
        xt_0_enc = self.embedder.embed_target(
            DataAttr(xt=batch.xt, yt=torch.zeros_like(batch.yt))
        )

        # Context tokens first, then target tokens.
        encoder_input = torch.cat([xc_enc, xt_0_enc], dim=-2)

        mask = self.create_mask(batch, device=encoder_input.device)
        out = self.encoder(encoder_input, mask=mask)
        num_targets = batch.xt.shape[-2]

        # Target tokens are the trailing num_targets positions.
        return out[:, -num_targets:, :]

    def decode(
        self,
        out_encoder: torch.Tensor,
        batch_size: int,
        dim_y: int,
        num_target: int,
    ) -> MultivariateNormal:
        """Decode target encodings into a joint Gaussian over all target outputs.

        Args:
            out_encoder: target encodings ``[B, Nt, d_model]`` from ``encode``.
            batch_size: ``B``.
            dim_y: ``Dy``.
            num_target: ``Nt``.

        Returns:
            A ``MixedPrecisionMultivariateNormal`` with event shape ``[Nt * Dy]``.

        Raises:
            NotImplementedError: if ``self.cov_approx`` is not supported.
        """
        mean = self.mean_net(out_encoder).view(batch_size, -1)  # [B, Nt*Dy]

        # prj_dim features per (target, output-dim) pair, row index t * Dy + d.
        std_prj = self.projector(self.std_encoder(out_encoder))
        std_prj = std_prj.view(batch_size, num_target * dim_y, -1)  # [B, Nt*Dy, prj_dim]
        gram = torch.bmm(std_prj, std_prj.transpose(1, 2))  # [B, Nt*Dy, Nt*Dy]

        if self.cov_approx == CovApprox.CHOLESKY:
            # The lower triangle of the Gram matrix is used directly as the
            # Cholesky factor; its diagonal is a sum of squares (>= 0).
            std_tril = gram.tril()
            if self.bound_std:
                diag_ids = torch.arange(num_target * dim_y, device=std_tril.device)
                std_tril[:, diag_ids, diag_ids] = 0.05 + 0.95 * torch.tanh(
                    std_tril[:, diag_ids, diag_ids]
                )
            return MixedPrecisionMultivariateNormal(mean, scale_tril=std_tril)

        if self.cov_approx == CovApprox.LOWRANK:
            # Low-rank PSD Gram matrix plus a strictly positive diagonal.
            diagonal = torch.exp(self.diag_net(out_encoder)).view(batch_size, -1)  # [B, Nt*Dy]
            cov = gram + torch.diag_embed(diagonal)
            return MixedPrecisionMultivariateNormal(mean, covariance_matrix=cov)

        raise NotImplementedError(f"Unsupported cov_approx: {self.cov_approx!r}")

    def forward(self, batch: DataAttr, reduce_ll: bool = True) -> LossAttr:
        """Compute the per-output-normalised joint log-likelihood of the targets.

        Args:
            batch: task with ``xc``, ``yc``, ``xt``, ``yt``.
            reduce_ll: if True, average the loss over the batch and report
                ``mean_std``; otherwise return per-task losses and
                ``mean_std=None``.

        Returns:
            ``LossAttr(loss, log_likelihood, mean_std)``, where
            ``log_likelihood`` has shape ``[B]``.
        """
        batch_size = batch.xc.shape[0]
        dim_y = batch.yc.shape[-1]
        num_target = batch.xt.shape[1]

        pred_tar = self.decode(self.encode(batch), batch_size, dim_y, num_target)

        # Divide the joint log-likelihood by the number of scalar outputs so its
        # scale does not grow with Nt * Dy.
        tar_ll = pred_tar.log_prob(batch.yt.reshape(batch_size, -1)) / (num_target * dim_y)

        if reduce_ll:
            loss = -tar_ll.mean()
            # NOTE(review): despite the name this is the mean *variance*
            # (diagonal of the covariance), not the mean std. Confirm which
            # quantity the logging code expects before changing it.
            mean_std = torch.diagonal(
                pred_tar.covariance_matrix, dim1=-2, dim2=-1
            ).mean()
        else:
            loss = -tar_ll
            mean_std = None

        return LossAttr(
            loss=loss,
            log_likelihood=tar_ll,
            mean_std=mean_std,
        )

    def predict(
        self,
        xc: torch.Tensor,
        yc: torch.Tensor,
        xt: torch.Tensor,
        num_samples: int = 50,
        return_samples: bool = False,
    ) -> "torch.Tensor | Normal":
        """Make predictions at target locations.

        Args:
            xc: Context inputs [B, Nc, Dx]
            yc: Context outputs [B, Nc, Dy]
            xt: Target inputs [B, Nt, Dx]
            num_samples: Number of joint samples to draw. Must be >= 1, and
                >= 2 when ``return_samples`` is False (a std needs two samples).
            return_samples: If True, return samples; else return a Normal

        Returns:
            Samples [B, Nt, num_samples, Dy], or a ``Normal`` whose loc is the
            predictive mean [B, Nt, Dy] and whose scale is the sample std
            [B, Nt, Dy] of the drawn joint samples.

        Raises:
            ValueError: if ``num_samples`` is too small for the requested output.
        """
        min_samples = 1 if return_samples else 2
        if num_samples < min_samples:
            raise ValueError(
                f"num_samples must be >= {min_samples} "
                f"(return_samples={return_samples}), got {num_samples}"
            )

        batch_size = xc.shape[0]
        dim_y = yc.shape[-1]
        num_target = xt.shape[1]

        # yt is only a placeholder (encode() zeroes it anyway). It stands in
        # for y-values, so it takes yc's dtype rather than xt's.
        batch = DataAttr(
            xc=xc,
            yc=yc,
            xt=xt,
            yt=torch.zeros((batch_size, num_target, dim_y), device=xt.device, dtype=yc.dtype),
        )

        pred_tar = self.decode(self.encode(batch), batch_size, dim_y, num_target)

        yt_samples = pred_tar.rsample([num_samples]).view(
            num_samples, batch_size, num_target, -1
        )  # [num_samples, B, Nt, Dy]

        if return_samples:
            return SampleReshaper.torch_dist2custom(yt_samples)  # [B, Nt, num_samples, Dy]

        # Monte Carlo estimate of the marginal std. NOTE(review): the exact
        # value is available as pred_tar.variance.sqrt(); the sampled estimate
        # is kept so existing results are unchanged.
        std = yt_samples.std(dim=0)
        return Normal(
            pred_tar.mean.view(batch_size, num_target, -1),
            std,
        )

    def sample_joint_predictive(
        self,
        xc: torch.Tensor,
        yc: torch.Tensor,
        xt: torch.Tensor,
        num_samples: int = 50,
    ) -> torch.Tensor:
        """Draw joint predictive samples, shape [B, Nt, num_samples, Dy] (via ``predict``)."""
        return self.predict(xc, yc, xt, num_samples=num_samples, return_samples=True)  # [B, Nt, num_samples, Dy]

    def eval_log_joint_likelihood(
        self,
        xc: torch.Tensor,
        yc: torch.Tensor,
        xt: torch.Tensor,
        yt: torch.Tensor,
    ) -> torch.Tensor:
        """
        Evaluate log likelihood at all target points jointly.

        Unlike ``forward``, the result is NOT divided by ``Nt * Dy``.

        Args:
            xc: Context inputs [B, Nc, Dx]
            yc: Context outputs [B, Nc, Dy]
            xt: Target inputs [B, Nt, Dx]
            yt: Target outputs [B, Nt, Dy]

        Returns:
            Log-likelihood tensor [B], log p(yt|xt, yc, xc)
        """
        batch_size = xc.shape[0]
        dim_y = yc.shape[-1]
        num_target = xt.shape[1]

        # The true yt is never fed to the encoder; only a zero placeholder is.
        batch = DataAttr(
            xc=xc,
            yc=yc,
            xt=xt,
            yt=torch.zeros_like(yt),  # Dummy targets
        )

        pred_tar = self.decode(self.encode(batch), batch_size, dim_y, num_target)
        return pred_tar.log_prob(yt.reshape(batch_size, -1))  # [B]

    def eval_log_likelihood(self, *args, **kwargs):
        """Not available: this model only defines a joint predictive over targets."""
        raise NotImplementedError("TNP-ND does not support independent target evaluation")
