"""Head modules used to build ACEv2 model."""

import abc
import math
from typing import Any, Callable, List, Optional, Tuple

import torch
from torch import nn
from torch.distributions import Categorical, Normal
from torch.nn import functional as F

from src.utils import DataAttr, LossAttr
from src.models.benchmarks.modules.pfn_head import get_bucket_borders, FullSupportBarDistribution


class Head(nn.Module):
    """Bank of ``K`` independent two-layer MLPs applied to the same input.

    Component ``k`` computes
    ``gelu(x @ w1[k].T + b1[k]) @ w2[k].T + b2[k]``.

    Args:
        dim_model: Size of the last dimension of the input.
        dim_feedforward: Hidden width of every component MLP.
        dim_out: Output width of every component MLP.
        K: Number of component MLPs.
    """

    def __init__(
        self, dim_model: int, dim_feedforward: int, dim_out: int, K: int
    ) -> None:
        super().__init__()
        self.K = K

        # weight shapes → [K, out, in]
        self.w1 = nn.Parameter(torch.empty(K, dim_feedforward, dim_model))
        self.b1 = nn.Parameter(torch.empty(K, dim_feedforward))
        self.w2 = nn.Parameter(torch.empty(K, dim_out, dim_feedforward))
        self.b2 = nn.Parameter(torch.empty(K, dim_out))

        # Initialise every [out, in] block uniformly in ±1/sqrt(in), which is
        # what kaiming_uniform_(a=sqrt(5)) produces for a 2-D weight.  Calling
        # kaiming_uniform_ on the stacked 3-D tensor would instead derive
        # fan_in = out * in (dim 1 times the trailing dims) and make the
        # weights a factor sqrt(out) too small.
        for w in (self.w1, self.w2):
            bound = 1 / math.sqrt(w.shape[-1])
            nn.init.uniform_(w, -bound, bound)
        for b, fan_in in ((self.b1, dim_model), (self.b2, dim_feedforward)):
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(b, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the input through all ``K`` component MLPs.

        Args:
            x: Tensor of shape ``[..., dim_model]`` (e.g. ``[B, T, dim_model]``).

        Returns:
            Tensor of shape ``[..., K, dim_out]``.
        """
        # The einsum broadcasts x over the component axis, so no explicit
        # expand to [..., K, dim_model] is required.
        h = torch.einsum("...d,khd->...kh", x, self.w1) + self.b1
        h = F.gelu(h)
        return torch.einsum("...kh,koh->...ko", h, self.w2) + self.b2


class NeuralProcessHead(abc.ABC, nn.Module):
    """Abstract interface implemented by the prediction heads in this module.

    Subclasses must provide :meth:`forward`, :meth:`sample` and
    :meth:`log_likelihood`.
    """

    @abc.abstractmethod
    def forward(
        self,
        zt: torch.Tensor,
        yt: Optional[torch.Tensor] = None,
        *,
        loss_mask: Optional[torch.Tensor] = None,
        num_samples: int = 0,
    ) -> LossAttr:
        """Evaluate the head on target embeddings.

        Args:
            zt: Target embeddings, ``[B, T, dim_model]``.
            yt: Target values, ``[B, T, dim_y]`` (optional).
            loss_mask: Mask over target points, ``[B, T]`` (optional,
                keyword-only).
            num_samples: How many samples to draw from the predictive
                distribution (keyword-only).

        Returns:
            A ``LossAttr`` object containing the loss and other relevant
            information.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def sample(self, zt: torch.Tensor, num_samples: int = 1) -> torch.Tensor:
        """Draw samples conditioned on the embedding ``zt``.

        Args:
            zt: Target embeddings, ``[B, T, dim_model]``.
            num_samples: How many samples to generate.

        Returns:
            Sampled outputs of shape ``[B, T, num_samples, dim_y]``.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def log_likelihood(self, zt: torch.Tensor, yt: torch.Tensor) -> torch.Tensor:
        """Evaluate the log likelihood of i.i.d. targets.

        Args:
            zt: Target embeddings, ``[B, T, dim_model]``.
            yt: Target values, ``[B, T, dim_y]``.

        Returns:
            Log likelihood values of shape ``[B, T]``.
        """
        raise NotImplementedError


class MixtureGaussian(NeuralProcessHead):
    """Mixture-of-Gaussians head for probabilistic predictions.

    A :class:`Head` with one MLP per mixture component maps every embedding to
    a raw mean, a raw standard deviation and a raw mixture logit for each of
    the ``dim_y`` output dimensions.  Every output dimension therefore gets
    its own one-dimensional mixture over ``num_components`` Gaussians.

    Args:
        dim_y: Number of output dimensions.
        dim_model: Size of the input embedding.
        dim_feedforward: Hidden width of each component MLP.
        num_components: Number of mixture components ``K`` (at least 1).
        name: Optional label, stored as ``self.name``.
        trange: ``(min, max)`` interval over which the learnable global mean
            biases are initially spread.  Only used when
            ``num_components > 1``.
        std_min: Constant added to every standard deviation.
    """

    # Upper bound applied to softplus(raw_std) before ``std_min`` is added.
    _STD_CAP = 2.0

    def __init__(
        self,
        dim_y: int,
        dim_model: int,
        dim_feedforward: int,
        num_components: int,
        name: Optional[str] = None,
        trange: Tuple[float, float] = (-1.0, 1.0),
        std_min: float = 1e-3,
    ) -> None:

        assert num_components >= 1

        super().__init__()
        self.name = name
        self.dim_model = dim_model
        self.dim_feedforward = dim_feedforward
        self.dim_y = dim_y

        # One MLP per component; each emits (mean, std, logit) per output dim.
        self.head = Head(
            dim_model=dim_model,
            dim_feedforward=dim_feedforward,
            dim_out=3 * dim_y,
            K=num_components,
        )

        self.num_components = num_components
        min_range, max_range = tuple(trange)

        if num_components > 1:
            # Component means start evenly spaced over the target range.
            self.mean_global_bias = nn.Parameter(
                torch.linspace(min_range, max_range, num_components)
            )
            # Std bias chosen so that softplus(bias) equals half the spacing
            # between neighbouring component means.
            delta = 0.5 * (max_range - min_range) / (num_components - 1)
            self.std_global_bias = nn.Parameter(
                torch.ones_like(self.mean_global_bias)
                * self._inverse_softplus(torch.tensor(delta))
            )
            # Zero logit bias -> no component is favoured by the bias.
            self.weights_global_bias = nn.Parameter(torch.zeros(num_components))

        self.std_min = std_min

    def forward(
        self,
        zt: torch.Tensor,
        yt: Optional[torch.Tensor] = None,
        *,
        loss_mask: Optional[torch.Tensor] = None,
        num_samples: int = 0,
    ) -> LossAttr:
        """Compute mixture parameters and, optionally, loss and samples.

        Args:
            zt: Target embeddings, ``[B, T, dim_model]``.
            yt: Target values, ``[B, T, dim_y]``.  When ``None`` no loss or
                log-likelihood is computed.
            loss_mask: Optional ``[B, T]`` mask multiplied into the
                per-point log-likelihood.
            num_samples: Number of samples to draw; ``0`` disables sampling.

        Returns:
            ``LossAttr`` with ``means``/``sds``/``weights`` of shape
            ``[B, T, K, dim_y]``, ``samples`` of shape
            ``[B, T, num_samples, dim_y]`` (or ``None``), the scalar ``loss``
            and ``log_likelihood``.  ``log_likelihood`` is ``[B, T, dim_y]``
            without a mask and the masked mean over ``dim_y`` (``[B, T]``)
            with one.
        """
        mean, std, weights = self._parameterize(zt)

        ll = None
        loss = None
        if yt is not None:
            ll = self._loglikelihood(yt, mean, std, weights)  # [B, T, dim_y]
            if loss_mask is not None:
                ll = ll.mean(-1) * loss_mask
                denom = loss_mask.sum().clamp(min=1)
            else:
                # Guard against empty input, mirroring the clamp above.
                denom = max(ll.numel(), 1)
            loss = -ll.sum() / denom

        samples = None
        if num_samples > 0:
            samples = self._draw_samples(mean, std, weights, num_samples)

        return LossAttr(
            log_likelihood=ll,
            loss=loss,
            means=mean,
            sds=std,
            weights=weights,
            samples=samples,
        )

    def sample(self, zt: torch.Tensor, num_samples: int = 1) -> torch.Tensor:
        """Generate samples from the mixture Gaussian.

        Args:
            zt: Input features [B, T, dim_model]
            num_samples: Number of samples to generate

        Returns:
            Samples of shape [B, T, num_samples, dim_y]
        """
        means, stds, weights = self._parameterize(zt)
        return self._draw_samples(means, stds, weights, num_samples)

    def log_likelihood(
        self,
        zt: torch.Tensor,
        yt: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute log likelihood of i.i.d. targets.

        The per-dimension mixture log-likelihoods are summed over ``dim_y``.

        Args:
            zt: [B, T, dim_model]
            yt: [B, T, dim_y]
        Returns:
            A tensor of shape [B, T] containing the log likelihood values.
        """
        mean, std, weights = self._parameterize(zt)
        ll_full = self._loglikelihood(yt, mean, std, weights)  # [B, T, dim_y]
        return ll_full.sum(-1)

    def _draw_samples(
        self,
        mean: torch.Tensor,
        std: torch.Tensor,
        weights: torch.Tensor,
        num_samples: int,
    ) -> torch.Tensor:
        """Sample from ``[B, T, K, dim_y]`` parameters -> ``[B, T, S, dim_y]``."""
        B, T = mean.shape[:2]
        flat_samples = self._sample_mixture(
            self._flat(mean),
            self._flat(std),
            self._flat(weights),
            num_sample=num_samples,
        )  # [num_samples, B*T, dim_y]
        return flat_samples.permute(1, 0, 2).reshape(B, T, num_samples, self.dim_y)

    def _flat(self, x: torch.Tensor) -> torch.Tensor:
        """Merge the leading dims: ``[B, T, K, dim_y]`` -> ``[B*T, K, dim_y]``."""
        # reshape (not view) so a non-contiguous input cannot raise.
        return x.reshape(-1, self.num_components, self.dim_y)

    def add_bias(
        self, raw_mean: torch.Tensor, raw_std: torch.Tensor, raw_weights: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Turn raw head outputs into valid mixture parameters.

        When there is more than one component the learnable global biases are
        added first.  Standard deviations are ``softplus`` outputs capped at
        ``_STD_CAP`` plus ``std_min``; weights are a softmax over the
        component dimension.

        Args:
            raw_mean: ``[B, T, K, dim_y]``.
            raw_std: ``[B, T, K, dim_y]``.
            raw_weights: ``[B, T, K, dim_y]``.

        Returns:
            ``(mean, std, weights)``, each ``[B, T, K, dim_y]``.
        """
        if self.num_components > 1:
            # Reshape biases to [1, 1, K, 1] for broadcasting with [B, T, K, dim_y]
            raw_mean = raw_mean + self.mean_global_bias[None, None, :, None]
            raw_std = raw_std + self.std_global_bias[None, None, :, None]
            raw_weights = raw_weights + self.weights_global_bias[None, None, :, None]

        mean = raw_mean
        # clamp avoids allocating a CPU scalar tensor on every call.
        std = F.softplus(raw_std).clamp(max=self._STD_CAP) + self.std_min
        # Softmax must be applied over the component dimension (K = dim 2)
        weights = F.softmax(raw_weights, dim=2)

        return mean, std, weights

    def _parameterize(
        self, zt: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Map embeddings ``[B, T, dim_model]`` to ``(mean, std, weights)``."""
        B, T, _ = zt.shape
        # Head output [B, T, K, 3 * dim_y] -> [B, T, K, 3, dim_y]
        outputs = self.head(zt).view(B, T, self.num_components, 3, self.dim_y)
        means, stds, weights = outputs.unbind(dim=3)
        return self.add_bias(means, stds, weights)

    @staticmethod
    def _loglikelihood(
        value: torch.Tensor,
        means: torch.Tensor,
        stds: torch.Tensor,
        weights: torch.Tensor,
    ) -> torch.Tensor:
        """Per-dimension mixture log-likelihood.

        Args:
            value: ``[B, T, dim_y]``.
            means: ``[B, T, K, dim_y]``.
            stds: ``[B, T, K, dim_y]``.
            weights: ``[B, T, K, dim_y]``.

        Returns:
            Tensor of shape ``[B, T, dim_y]``.
        """
        # [B, T, 1, dim_y] broadcasts against the component dimension.
        value = value.unsqueeze(2)
        components = Normal(means, stds, validate_args=False)
        # Weights are clamped away from zero so the log stays finite.
        log_probs = components.log_prob(value) + torch.log(weights.clamp(min=1e-12))
        # logsumexp over K (dim=2) -> [B, T, dim_y]
        return torch.logsumexp(log_probs, dim=2)

    @staticmethod
    def _inverse_softplus(y: torch.Tensor) -> torch.Tensor:
        """Return ``x`` such that ``softplus(x) == y`` (for ``y > 0``).

        Uses ``y + log(1 - exp(-y))``, which is algebraically equal to
        ``log(exp(y) - 1)`` but does not overflow to ``inf`` for large ``y``.
        """
        return y + torch.log(-torch.expm1(-y))

    @staticmethod
    def _sample_mixture(
        means: torch.Tensor,  # [B*T, K, dim_y]
        stds: torch.Tensor,  # [B*T, K, dim_y]
        weights: torch.Tensor,  # [B*T, K, dim_y]
        num_sample: int = 1000,
    ) -> torch.Tensor:
        """Sample from the mixture distribution.

        A component index is drawn independently for every row and output
        dimension, then a Gaussian sample is drawn from the chosen component.

        Returns:
            Tensor of shape ``[num_sample, B*T, dim_y]``.
        """
        _, num_comp, dim_y = weights.shape

        # One categorical per (row, output dim): [B*T*dim_y, K]
        weights_flat = weights.permute(0, 2, 1).reshape(-1, num_comp)
        # [num_sample, B*T*dim_y]
        component_indices = Categorical(weights_flat).sample((num_sample,))
        # -> [num_sample, B*T, 1, dim_y] so it can index the K dimension
        indices = component_indices.reshape(num_sample, -1, 1, dim_y)

        # [B*T, K, dim_y] -> [num_sample, B*T, K, dim_y] (expand copies nothing)
        means_expanded = means.unsqueeze(0).expand(num_sample, -1, -1, -1)
        stds_expanded = stds.unsqueeze(0).expand(num_sample, -1, -1, -1)

        chosen_means = torch.gather(means_expanded, 2, indices).squeeze(2)
        chosen_stds = torch.gather(stds_expanded, 2, indices).squeeze(2)

        # Sample from the chosen components. Shape: [num_sample, B*T, dim_y]
        return Normal(chosen_means, chosen_stds).sample()


class MultiChannelMixtureGaussian(NeuralProcessHead):
    """
    Gaussian mixture model for multi-channel outputs.

    Each output channel has its own mixture parameters, but all channels are
    predicted from the same input features by one shared MLP.  Per-channel
    log-likelihoods are computed independently; :meth:`log_likelihood` sums
    them over channels, whereas the training loss in :meth:`forward` averages
    them over channels.

    Args:
        dim_y: Number of output channels (e.g., 7 for EEG)
        dim_model: Input feature dimension
        dim_feedforward: Hidden dimension for the head
        num_components: Number of mixture components per channel
        std_min: Minimum standard deviation for numerical stability
    """

    # Upper bound applied to softplus(raw_std) before ``std_min`` is added.
    _STD_CAP = 2.0

    def __init__(
        self,
        dim_y: int,
        dim_model: int,
        dim_feedforward: int,
        num_components: int,
        std_min: float = 1e-3,
    ):

        assert num_components >= 1

        super().__init__()
        self.dim_y = dim_y
        self.dim_model = dim_model
        self.dim_feedforward = dim_feedforward
        self.num_components = num_components
        self.std_min = std_min

        # Single MLP emitting the parameters of all channels at once.
        # Output: (raw mean, raw std, weight logit) per component and channel,
        # i.e. K * dim_y * 3 values.  The raw std goes through a softplus in
        # _parameterize.
        self.head = nn.Sequential(
            nn.Linear(dim_model, dim_feedforward),
            nn.GELU(),
            nn.Linear(dim_feedforward, num_components * dim_y * 3)
        )

        # Initialize the final layer with small values
        nn.init.normal_(self.head[-1].weight, mean=0.0, std=0.02)
        nn.init.zeros_(self.head[-1].bias)

    def forward(
        self,
        zt: torch.Tensor,  # (B, T, dim_model)
        yt: Optional[torch.Tensor] = None,  # (B, T, dim_y)
        loss_mask: Optional[torch.Tensor] = None,  # (B, T)
        num_samples: int = 0,
    ) -> LossAttr:
        """Compute mixture parameters and, optionally, loss and samples.

        Args:
            zt: Input features ``(B, T, dim_model)``.
            yt: Targets ``(B, T, dim_y)``.  When ``None`` no loss or
                log-likelihood is computed.
            loss_mask: Optional ``(B, T)`` mask applied to the loss only.
            num_samples: Number of samples to draw; ``0`` disables sampling.

        Returns:
            ``LossAttr`` whose ``log_likelihood`` is the unmasked
            channel-averaged log-probability ``(B, T)``, whose ``loss`` is its
            negative (masked) mean, and whose ``means``/``sds``/``weights``
            have shape ``(B, T, K, D)``.  ``samples`` is
            ``(B, T, num_samples, D)`` or ``None``.
        """
        means, stds, weights = self._parameterize(zt)

        loss = None
        log_likelihood = None
        if yt is not None:
            channel_log_probs = self._loglikelihood(yt, means, stds, weights)  # (B, T, D)

            # Average over channels to get one log-probability per point
            log_likelihood = channel_log_probs.mean(dim=-1)  # (B, T)

            if loss_mask is not None:
                masked_log_prob = log_likelihood * loss_mask
                denom = loss_mask.sum().clamp(min=1)
            else:
                masked_log_prob = log_likelihood
                # Guard against empty input, mirroring the clamp above.
                denom = max(log_likelihood.numel(), 1)

            loss = -masked_log_prob.sum() / denom

        samples = None
        if num_samples > 0:
            samples = self._sample(means, stds, weights, num_samples)

        return LossAttr(
            log_likelihood=log_likelihood,
            loss=loss,
            means=means,  # (B, T, K, D)
            sds=stds,  # (B, T, K, D)
            weights=weights,  # (B, T, K, D)
            samples=samples,
        )

    def sample(self, zt: torch.Tensor, num_samples: int = 1) -> torch.Tensor:
        """Generate samples from the multi-channel mixture Gaussian.

        Args:
            zt: Input features [B, T, dim_model]
            num_samples: Number of samples to generate

        Returns:
            Samples of shape [B, T, num_samples, dim_y]
        """
        means, stds, weights = self._parameterize(zt)
        return self._sample(means, stds, weights, num_samples)

    def log_likelihood(
        self,
        zt: torch.Tensor,
        yt: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute log likelihood of i.i.d. targets.

        The per-channel mixture log-likelihoods are summed over channels.

        Args:
            zt: [B, T, dim_model]
            yt: [B, T, dim_y]
        Returns:
            A tensor of shape [B, T] containing the log likelihood values.
        """
        means, stds, weights = self._parameterize(zt)
        channel_log_probs = self._loglikelihood(yt, means, stds, weights)
        return channel_log_probs.sum(dim=-1)  # (B, T)

    def _parameterize(
        self, zt: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Map features ``(B, T, dim_model)`` to ``(means, stds, weights)``.

        Each returned tensor has shape ``(B, T, K, D)``.
        """
        B, T, _ = zt.shape
        K = self.num_components
        D = self.dim_y

        # (B, T, K*D*3) -> (B, T, K, D, 3)
        raw_params = self.head(zt).view(B, T, K, D, 3)
        raw_means, raw_stds, raw_logit_weights = raw_params.unbind(dim=-1)

        means = raw_means  # No transformation needed
        # Positive, capped at _STD_CAP, then shifted by std_min.  clamp avoids
        # allocating a CPU scalar tensor on every call.
        stds = F.softplus(raw_stds).clamp(max=self._STD_CAP) + self.std_min

        # Normalize weights per channel using softmax over K dimension
        weights = F.softmax(raw_logit_weights, dim=2)  # (B, T, K, D)

        return means, stds, weights

    def _loglikelihood(
        self,
        value: torch.Tensor,
        means: torch.Tensor,
        stds: torch.Tensor,
        weights: torch.Tensor,
    ) -> torch.Tensor:
        """Per-channel mixture log-likelihood.

        Args:
            value: ``(B, T, D)``.
            means: ``(B, T, K, D)``.
            stds: ``(B, T, K, D)``.
            weights: ``(B, T, K, D)``.

        Returns:
            Tensor of shape ``(B, T, D)``.
        """
        # (B, T, 1, D) broadcasts against the component dimension.
        diff = value.unsqueeze(2) - means
        variance = stds**2
        # Gaussian log-density, written out explicitly
        log_probs = -0.5 * (
            torch.log(2 * torch.pi * variance) + (diff**2) / variance
        )  # (B, T, K, D)

        # Add log weights (clamped away from zero so the log stays finite)
        weighted_log_probs = log_probs + weights.clamp(min=1e-12).log()  # (B, T, K, D)

        # Log-sum-exp over components for each channel
        return torch.logsumexp(weighted_log_probs, dim=2)  # (B, T, D)

    def _sample(
        self,
        means: torch.Tensor,  # (B, T, K, D)
        stds: torch.Tensor,  # (B, T, K, D)
        weights: torch.Tensor,  # (B, T, K, D)
        num_samples: int
    ) -> torch.Tensor:
        """Sample from the mixture distribution.

        A component is drawn independently for every point, channel and
        sample, followed by a Gaussian draw from that component.

        Returns:
            Tensor of shape ``(B, T, num_samples, D)``.
        """
        B, T, K, D = means.shape

        def per_channel_rows(t: torch.Tensor) -> torch.Tensor:
            # (B, T, K, D) -> (B*T*D, K): one row per (point, channel)
            return t.permute(0, 1, 3, 2).reshape(B * T * D, K)

        # Sample component indices, (B*T*D, num_samples)
        component_samples = torch.multinomial(
            per_channel_rows(weights),
            num_samples,
            replacement=True,
        )

        # Gather means and stds for selected components
        selected_means = torch.gather(
            per_channel_rows(means), 1, component_samples
        ).view(B, T, D, num_samples)
        selected_stds = torch.gather(
            per_channel_rows(stds), 1, component_samples
        ).view(B, T, D, num_samples)

        # Sample from Gaussians via mean + std * standard normal noise
        eps = torch.randn_like(selected_means)
        samples = selected_means + selected_stds * eps

        # Reorder to (B, T, num_samples, D)
        return samples.permute(0, 1, 3, 2)


# PFN Head in our API
class RiemannHead(NeuralProcessHead):
    """Bucketised head for scalar targets (``dim_y == 1``).

    An MLP maps every embedding to ``head_num_buckets`` logits.  A
    ``FullSupportBarDistribution`` built on the bucket borders turns those
    logits into log-probabilities, means and samples.

    Args:
        dim_model: Size of the input embedding.
        dim_feedforward: Hidden width of the logits MLP.
        head_num_buckets: Number of buckets, i.e. logits per target point.
        head_bucket_samples: Optional tensor forwarded to
            ``get_bucket_borders`` as ``ys``.
    """

    def __init__(
        self,
        dim_model: int,
        dim_feedforward: int,
        head_num_buckets: int,
        head_bucket_samples: Optional[torch.Tensor] = None,
    ):

        super().__init__()

        self.logits_decoder = nn.Sequential(
            nn.Linear(dim_model, dim_feedforward),
            nn.GELU(),
            nn.Linear(dim_feedforward, head_num_buckets)
        )

        bucket_borders = get_bucket_borders(head_num_buckets, ys=head_bucket_samples)
        self.predictor = FullSupportBarDistribution(bucket_borders)

    def forward(
        self,
        zt: torch.Tensor,
        yt: Optional[torch.Tensor] = None,
        *,
        loss_mask: Optional[torch.Tensor] = None,
        num_samples: int = 0,
    ) -> LossAttr:
        """
        Forward pass through the head module.

        :param zt: Tensor of shape [batch_size, num_target, dim_model] representing the encoded target data.
        :param yt: Optional tensor of shape [batch_size, num_target, 1] representing the target data (dim_y = 1).
        :param loss_mask: Optional tensor of shape [batch_size, num_target] for masking the loss computation.
        :param num_samples: Number of samples to generate during inference. If 0, no sampling is done (training).
        :return: LossAttr containing loss, log likelihood of yt | zt, if yt is given (training), or containing samples for test.
        """

        logits = self.logits_decoder(zt)  # [batch_size, num_target, num_buckets]

        # During inference (yt is None) no loss or log_likelihood is computed
        log_likelihood = None
        loss = None
        if yt is not None:
            self._check_targets(zt, yt)

            # note: bar distribution log_prob expects yt to be of shape [batch_shape, num_target, num_samples, 1]
            log_likelihood = self.predictor.log_prob(logits, yt)  # [batch_shape, num_target, 1]
            if loss_mask is not None:
                log_likelihood = log_likelihood.squeeze(-1) * loss_mask
                denom = loss_mask.sum().clamp(min=1)
            else:
                # Guard against empty input, mirroring the clamp above.
                denom = max(log_likelihood.numel(), 1)
            loss = -log_likelihood.sum() / denom

        samples = None
        if num_samples > 0:
            samples = self._draw_samples(logits, num_samples)

        return LossAttr(
            log_likelihood=log_likelihood,
            loss=loss,
            means=self.predictor.mean(logits).unsqueeze(-1),  # [B, T, 1]
            samples=samples,
        )

    def sample(self, zt: torch.Tensor, num_samples: int = 1) -> torch.Tensor:
        """Generate samples.

        Args:
            zt: Input features [B, T, dim_model]
            num_samples: Number of samples to generate

        Returns:
            Samples of shape [B, T, num_samples, dim_y]
        """
        logits = self.logits_decoder(zt)  # [batch_size, num_target, num_buckets]
        return self._draw_samples(logits, num_samples)

    def log_likelihood(
        self,
        zt: torch.Tensor,
        yt: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute the log likelihood of i.i.d. target.

        Args:
            zt: [B, T, dim_model]
            yt: [B, T, dim_y]
        Returns:
            A tensor of shape [B, T] containing the log likelihood values.
        """
        self._check_targets(zt, yt)
        logits = self.logits_decoder(zt)  # [batch_size, num_target, num_buckets]

        # note: bar distribution log_prob expects yt to be of shape [batch_shape, num_target, num_samples, 1]
        log_likelihood = self.predictor.log_prob(logits, yt)  # [batch_shape, num_target, 1]
        return log_likelihood.squeeze(-1)

    @staticmethod
    def _check_targets(zt: torch.Tensor, yt: torch.Tensor) -> None:
        """Assert that ``yt`` matches ``zt`` in its leading dims and is scalar."""
        assert zt.shape[:-1] == yt.shape[:-1], f"Expected zt and yt to have the same batch shape, got {zt.shape[:-1]} and {yt.shape[:-1]}"
        assert yt.shape[-1] == 1, f"Expected yt to have shape [batch_size, num_target, 1], got {yt.shape}"

    def _draw_samples(self, logits: torch.Tensor, num_samples: int) -> torch.Tensor:
        """Sample from the bar distribution and append a trailing ``dim_y`` axis."""
        samples = self.predictor.sample(logits, num_samples=num_samples)  # [B, T, num_samples]
        return samples.unsqueeze(-1)  # [B, T, num_samples, 1]
