"""
Core SESaMo model implementation.

This module provides the main Sesamo class that combines a prior
distribution, normalizing flow, stochastic modulation, and regularization
into a complete symmetry-enforcing normalizing flow model.
"""

import torch
from torch import Tensor
import torch.nn as nn
from typing import Union, Tuple
import logging

from .models.prior import Prior
from .models.stochasticmodulation import StochasticModulation
from .models.regularization import Regularization

# Logger registered under the fixed name "SESaMo" (rather than __name__), so
# output from this module is grouped under that name. Not referenced by the
# code visible in this module.
logger = logging.getLogger("SESaMo")


class Sesamo(nn.Module):
    """
    Symmetry-Enforcing Stochastic Modulation model.

    Combines a prior distribution, normalizing flow(s), stochastic modulation
    layer(s), and regularization to create a normalizing flow that enforces
    symmetries in the output distribution.

    Sampling proceeds as: prior -> flow layer(s) in order -> regularization
    penalty evaluated on the flow output -> stochastic modulation layer(s)
    in order.

    Parameters
    ----------
    prior : Prior
        Base distribution to sample from. Must provide
        ``sample_with_logprob(num_samples)`` returning ``(z, log_prob)``.
    flow : nn.Module or list/tuple of nn.Module
        Normalizing flow transformation(s). Each is called as
        ``z, log_det = flow(z)``.
    stochastic_modulation : StochasticModulation or list/tuple
        Stochastic modulation layer(s) for symmetry enforcement. Each is
        called as ``z, log_prob = stochmod(z)``.
    regularization : Regularization or list/tuple
        Regularization module(s) for canonical region constraints. Each is
        called as ``penalty = regularization(z)``.

    Example
    -------
    >>> sesamo = Sesamo(
    ...     prior=GaussianPrior(var=1, lat_shape=[1, 2]),
    ...     flow=RealNVP(lat_shape=[1, 2], num_coupling_layers=10),
    ...     stochastic_modulation=Z2Modulation(),
    ...     regularization=Z2Regularization(),
    ... )
    >>> samples, log_prob = sesamo.sample_with_logprob(1000)
    """

    def __init__(
        self,
        prior: Prior,
        flow: Union[nn.Module, list, tuple],
        stochastic_modulation: Union[StochasticModulation, list, tuple],
        regularization: Union[Regularization, list, tuple],
    ):
        super().__init__()

        self.prior = prior
        self.flow = nn.ModuleList(self._as_list(flow))
        self.stochastic_modulation = nn.ModuleList(
            self._as_list(stochastic_modulation)
        )
        # NOTE(review): kept as a plain list (not nn.ModuleList), so any
        # parameters/buffers of the regularization modules are NOT registered
        # with this model (not moved by .to(), not in state_dict()). Left as-is
        # to keep existing checkpoint keys stable -- confirm this is intended.
        self.regularization = self._as_list(regularization)

    @staticmethod
    def _as_list(components) -> list:
        """Normalize a single component or a sequence of components to a new list."""
        # A ModuleList has no forward(), so wrapping it as a single element
        # would fail on the first call; unpack it like a list/tuple instead.
        # Copying also avoids aliasing the caller's list.
        if isinstance(components, (list, tuple, nn.ModuleList)):
            return list(components)
        return [components]

    def sample_for_training(self, num_samples: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """
        Creates samples from the model for training, including the log probability of the samples under the model, the log probability of the stochastic modulation, and the regularization penalty.

        Parameters
        ----------
        num_samples : int
            Number of samples to generate.

        Returns
        -------
        z : Tensor
            Samples drawn from the model.
        log_prob : Tensor
            Log-probability of `z` under the model distribution.
        log_prob_stochmod : Tensor
            Total log-probability contribution summed over all stochastic
            modulation layers (zeros if there are none).
        penalty : Tensor
            Regularization penalty associated with the samples, one entry per
            sample (zeros if there are no regularization modules).
        """
        # sample from prior with log prob
        z, log_prob = self.prior.sample_with_logprob(num_samples)

        # apply flow: change of variables subtracts each layer's log|det J|.
        # Out-of-place update so the tensor returned by the prior is never
        # mutated (in-place ops can break autograd or corrupt cached tensors).
        for flow in self.flow:
            z, log_det = flow(z)
            log_prob = log_prob - log_det

        # compute regularization on the flow output (before modulation).
        # `penalty` is a fresh local tensor, so accumulating in place is safe.
        penalty = torch.zeros(z.shape[0], device=z.device, dtype=z.dtype)
        for regularization in self.regularization:
            penalty += regularization(z)

        # apply stochastic modulation. Accumulate the modulation contribution
        # over ALL layers so it is defined even with zero layers and stays
        # consistent with what was added to `log_prob`.
        log_prob_stochmod = torch.zeros_like(log_prob)
        for stochmod in self.stochastic_modulation:
            z, layer_log_prob = stochmod(z)
            log_prob = log_prob + layer_log_prob
            log_prob_stochmod = log_prob_stochmod + layer_log_prob

        return z, log_prob, log_prob_stochmod, penalty

    def sample_with_logprob(self, num_samples: int) -> Tuple[Tensor, Tensor]:
        """
        Creates samples and their log probabilities from the model.

        Delegates to `sample_for_training`, so the regularization penalty is
        still evaluated (and then discarded).

        Parameters
        ----------
        num_samples : int
            Number of samples to generate.

        Returns
        -------
        z : Tensor
            Samples drawn from the model.
        log_prob : Tensor
            Log-probability of `z` under the model distribution.
        """
        return self.sample_for_training(num_samples)[:2]

    def sample(self, num_samples: int) -> Tensor:
        """
        Creates samples from the model.

        Delegates to `sample_for_training` and returns only the samples.

        Parameters
        ----------
        num_samples : int
            Number of samples to generate.

        Returns
        -------
        z : Tensor
            Samples drawn from the model.
        """
        return self.sample_for_training(num_samples)[0]