"""
Prior distributions for normalizing flows.

This module provides base distributions that serve as the starting point
for normalizing flow transformations. Each prior supports sampling and
log-probability computation.
"""

import torch
from torch import Tensor
import logging

logger = logging.getLogger("SESaMo")


class Prior(torch.nn.Module):
    """
    Common base for the base (prior) distributions of a normalizing flow.

    Concrete priors are expected to override ``sample`` and ``log_prob``;
    the versions defined here only raise ``NotImplementedError``.
    ``sample_with_logprob`` is built on top of those two methods and works
    for any subclass that provides them.

    Attributes
    ----------
    name : str
        Short identifier of the prior type (``"base"`` for this class).
    """

    name = "base"

    def sample(self, n: int) -> Tensor:
        """
        Draw ``n`` samples from the prior.

        Parameters
        ----------
        n : int
            How many samples to draw.

        Returns
        -------
        Tensor
            Samples with shape (n, *lat_shape).

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses must override it.
        """
        raise NotImplementedError("sample method not implemented")

    def log_prob(self, x: Tensor) -> Tensor:
        """
        Evaluate the log density of the prior at ``x``.

        Parameters
        ----------
        x : Tensor
            Points to evaluate, shape (batch, *lat_shape).

        Returns
        -------
        Tensor
            One log probability per batch entry, shape (batch,).

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses must override it.
        """
        raise NotImplementedError("log_prob method not implemented")

    def sample_with_logprob(self, n: int) -> tuple:
        """
        Draw ``n`` samples and return them together with their log density.

        Parameters
        ----------
        n : int
            How many samples to draw.

        Returns
        -------
        tuple
            ``(samples, log_probs)``, where ``samples`` is what ``sample(n)``
            returned and ``log_probs`` is ``log_prob(samples)``.
        """
        samples = self.sample(n)
        return samples, self.log_prob(samples)


class GaussianPrior(Prior):
    """
    Gaussian (normal) prior distribution.

    Every element of the latent tensor is drawn independently from the same
    normal distribution N(mean, var). The mean and the standard deviation
    are registered as buffers, so they are stored with dtype ``dtype`` and
    are moved by ``Module.to()`` and included in the state dict.

    Parameters
    ----------
    lat_shape : list
        Shape of the latent tensor (excluding batch dimension).
    dtype : torch.dtype, optional
        Data type for computations. Default is torch.float64.
    mean : float, optional
        Mean of the Gaussian. Default is 0.
    var : float, optional
        Variance of the Gaussian. Default is 1.
    verbose : bool, optional
        Whether to log initialization. Default is True.
    """

    name = "gaussian"

    def __init__(
        self,
        lat_shape: list,
        dtype: torch.dtype = torch.float64,
        mean: float = 0,
        var: float = 1,
        verbose: bool = True
    ):
        """
        Initialize the Gaussian prior.

        Parameters
        ----------
        lat_shape : list
            Shape of the latent tensor.
        dtype : torch.dtype, optional
            Data type. Default is torch.float64.
        mean : float, optional
            Mean of the distribution. Default is 0.
        var : float, optional
            Variance of the distribution (must be positive). Default is 1.
        verbose : bool, optional
            Whether to log initialization. Default is True.

        Raises
        ------
        ValueError
            If the variance is not positive (this includes NaN).
        """
        super(GaussianPrior, self).__init__()
        # Written as ``not var > 0`` rather than ``var <= 0`` so that NaN is
        # also rejected: every ordered comparison with NaN is False, so the
        # old guard let a NaN variance through and produced a NaN sigma.
        if not var > 0:
            raise ValueError("Variance must be positive")

        self.lat_shape = lat_shape
        self.dtype = dtype
        # Stored as buffers (not parameters): they are not trained, but they
        # follow the module across devices and are serialized with it.
        self.register_buffer("mean", torch.tensor(mean, dtype=dtype))
        self.register_buffer("sigma", torch.tensor(var**0.5, dtype=dtype))

        if verbose:
            logger.info(f"Initialized gaussian prior with: mean = {mean}, var = {var}")

    def log_prob(self, x: Tensor) -> Tensor:
        """
        Compute the log probability of samples.

        The element-wise normal log densities are summed over all non-batch
        dimensions, i.e. the elements are treated as independent.

        Parameters
        ----------
        x : Tensor
            Samples to evaluate. Shape: (batch, *lat_shape).

        Returns
        -------
        Tensor
            Log probabilities. Shape: (batch,).
        """
        log_prob = torch.distributions.Normal(self.mean, self.sigma).log_prob(x)
        # Flatten everything after the batch axis and sum it away.
        return log_prob.reshape(x.shape[0], -1).sum(-1)

    def sample(self, n: int) -> Tensor:
        """
        Draw samples from the Gaussian prior.

        Uses ``Distribution.sample`` (not ``rsample``), so the result does
        not carry gradients with respect to the buffers.

        Parameters
        ----------
        n : int
            Number of samples to draw.

        Returns
        -------
        Tensor
            Samples of shape (n, *lat_shape), with the buffers' dtype.
        """
        return torch.distributions.Normal(self.mean, self.sigma).sample((n, *self.lat_shape))
    


class UniformPrior(Prior):
    """
    Uniform prior distribution.

    Every element of the latent tensor is drawn independently from the
    half-open interval [low, high), i.e. samples lie in the hypercube
    [low, high)^d. The bounds are registered as buffers, so they are stored
    with dtype ``dtype`` and are moved by ``Module.to()`` and included in
    the state dict.

    Parameters
    ----------
    lat_shape : list
        Shape of the latent tensor (excluding batch dimension).
    low : float, optional
        Lower bound of the uniform distribution. Default is -1.
    high : float, optional
        Upper bound of the uniform distribution. Default is 1.
    dtype : torch.dtype, optional
        Data type for computations. Default is torch.float64.
    verbose : bool, optional
        Whether to log initialization. Default is True.
    """

    name = "uniform"

    def __init__(
        self,
        lat_shape: list,
        low: float = -1,
        high: float = 1,
        dtype: torch.dtype = torch.float64,
        verbose: bool = True
    ):
        """
        Initialize the uniform prior.

        Parameters
        ----------
        lat_shape : list
            Shape of the latent tensor.
        low : float, optional
            Lower bound. Default is -1.
        high : float, optional
            Upper bound. Default is 1.
        dtype : torch.dtype, optional
            Data type. Default is torch.float64.
        verbose : bool, optional
            Whether to log initialization. Default is True. Added after the
            existing parameters so positional callers are unaffected.

        Raises
        ------
        ValueError
            If ``high`` is not strictly greater than ``low`` (this includes
            NaN bounds).
        """
        super(UniformPrior, self).__init__()
        # ``not high > low`` also rejects NaN bounds, which the former
        # ``high <= low`` check silently accepted.
        if not high > low:
            raise ValueError("High bound must be larger than low bound")

        self.lat_shape = lat_shape
        # Kept for consistency with GaussianPrior, which also exposes dtype.
        self.dtype = dtype
        self.register_buffer("low", torch.tensor(low, dtype=dtype))
        self.register_buffer("high", torch.tensor(high, dtype=dtype))

        if verbose:
            # Log the plain bounds; interpolating the buffers would print
            # full tensor reprs such as "tensor(-1., dtype=torch.float64)".
            logger.info(f"Initialized uniform prior with: low = {low}, high = {high}")

    def log_prob(self, x: Tensor) -> Tensor:
        """
        Compute the log probability of samples.

        The element-wise uniform log densities are summed over all non-batch
        dimensions. For values outside [low, high), torch's ``Uniform`` either
        raises (when distribution argument validation is enabled) or yields
        ``-inf``.

        Parameters
        ----------
        x : Tensor
            Samples to evaluate. Shape: (batch, *lat_shape).

        Returns
        -------
        Tensor
            Log probabilities. Shape: (batch,).
        """
        log_prob = torch.distributions.Uniform(self.low, self.high).log_prob(x)
        # Flatten everything after the batch axis and sum it away.
        return log_prob.reshape(x.shape[0], -1).sum(-1)

    def sample(self, n: int) -> Tensor:
        """
        Draw samples from the uniform prior.

        Parameters
        ----------
        n : int
            Number of samples to draw.

        Returns
        -------
        Tensor
            Samples of shape (n, *lat_shape), with the buffers' dtype.
        """
        return torch.distributions.Uniform(self.low, self.high).sample((n, *self.lat_shape))