"""
Variational Mixture of Normalizing Flows (VMONF).

This module implements VMONF, which uses multiple normalizing flows
combined with a learned mixture distribution. A feedforward network
predicts mixture weights based on the input, allowing different flows
to specialize in different regions of the distribution.

References
----------
Pires, G. G. P. F., & Figueiredo, M. A. T. (2020). Variational Mixture 
of Normalizing Flows. ESANN 2020.
https://www.esann.org/sites/default/files/proceedings/2020/ES2020-188.pdf
"""

import logging
import math
from typing import Tuple

import torch
import torch.nn as nn
from torch import Tensor

from .realnvp import RealNVP

# Shared project-wide logger (same name used across SESaMo modules).
logger = logging.getLogger(name="SESaMo")


class VMONF(nn.Module):
    """
    Variational Mixture of Normalizing Flows.

    Combines multiple RealNVP flows with a learned mixture distribution.
    A feedforward (gating) network predicts per-sample sector probabilities
    from the input, and each sector has its own RealNVP transformation.

    Parameters
    ----------
    lat_shape : list
        Shape of the latent space (excluding batch dimension).
    sectors : int, optional
        Number of mixture components (flows). Must be >= 1. Default is 4.
    coupling : str, optional
        Coupling type, forwarded unchanged to every RealNVP. Default is "altfc".
    num_coupling_layers : int, optional
        Coupling layers per flow. Default is 6.
    num_hidden_layers : int, optional
        Hidden layers per coupling network. Also sets the depth of the gating
        network, which always has at least one hidden layer. Default is 4.
    num_hidden_features : int, optional
        Features in hidden layers (flows and gating network). Default is 100.
    activation : str, optional
        Activation function, forwarded to every RealNVP. The gating network
        always uses ReLU. Default is "relu".
    dtype : str or torch.dtype, optional
        Parameter data type, e.g. "float32" or "float64". The value is
        forwarded unchanged to every RealNVP. The gating network receives the
        corresponding ``torch.dtype``. Default is "float32".

    Attributes
    ----------
    name : str
        Model identifier ("vmonf").
    sectors : int
        Number of mixture components.
    realvnp_list : nn.ModuleList
        List of RealNVP flows, one per sector. The attribute name (sic) is
        part of the state_dict keys and is kept for checkpoint compatibility.
    feedforward : nn.Sequential
        Network that predicts mixture weights (ends in a softmax).
    prob_c : Tensor or None
        Mixture probabilities of shape (sectors, batch) from the last forward
        pass, or None before the first forward pass.

    Raises
    ------
    ValueError
        If ``sectors < 1`` or ``dtype`` does not name a torch dtype.
    """

    name = "vmonf"

    def __init__(
        self,
        lat_shape: list,
        sectors: int = 4,
        coupling: str = "altfc",
        num_coupling_layers: int = 6,
        num_hidden_layers: int = 4,
        num_hidden_features: int = 100,
        activation: str = "relu",
        dtype: str = "float32",
        **kwargs
    ):
        """
        Initialize the VMONF model.

        Parameters
        ----------
        lat_shape : list
            Shape of latent space.
        sectors : int, optional
            Number of mixture components (>= 1). Default is 4.
        coupling : str, optional
            Coupling type forwarded to each RealNVP. Default is "altfc".
        num_coupling_layers : int, optional
            Coupling layers per flow. Default is 6.
        num_hidden_layers : int, optional
            Hidden layers per coupling and in the gating network. Default is 4.
        num_hidden_features : int, optional
            Features in hidden layers. Default is 100.
        activation : str, optional
            Activation function for the flows. Default is "relu".
        dtype : str or torch.dtype, optional
            Data type. Default is "float32".
        **kwargs
            Additional keyword arguments (ignored).

        Raises
        ------
        ValueError
            If ``sectors < 1`` or ``dtype`` cannot be resolved.
        """
        super().__init__()
        if sectors < 1:
            raise ValueError(f"VMONF needs at least one sector, got sectors={sectors}")
        self.sectors = sectors

        # nn.Linear forwards dtype to torch.empty, which rejects strings such
        # as "float32" with a TypeError, so resolve it to a torch.dtype here.
        torch_dtype = self._resolve_dtype(dtype)

        self.realvnp_list = nn.ModuleList(
            [
                RealNVP(
                    lat_shape=lat_shape,
                    coupling=coupling,
                    num_coupling_layers=num_coupling_layers,
                    num_hidden_layers=num_hidden_layers,
                    num_hidden_features=num_hidden_features,
                    activation=activation,
                    dtype=dtype,
                )
                for _ in range(sectors)
            ]
        )

        # Flattened input size as a Python int. math.prod gives 1 (not the
        # float 1.0 that torch.prod of an empty tensor gives) for an empty shape.
        in_dim = int(math.prod(lat_shape))
        # Module nesting is kept identical to preserve state_dict keys.
        self.feedforward = nn.Sequential(
            nn.Linear(in_dim, num_hidden_features, dtype=torch_dtype),
            nn.ReLU(),
            *[
                nn.Sequential(
                    nn.Linear(num_hidden_features, num_hidden_features, dtype=torch_dtype),
                    nn.ReLU(),
                )
                for _ in range(num_hidden_layers - 1)
            ],
            nn.Linear(num_hidden_features, sectors, dtype=torch_dtype),
            nn.Softmax(dim=-1),
        )

        # Populated by forward(); None until the first forward pass.
        self.prob_c = None

        logger.info("Initialized VMONF with %d sectors", sectors)

    @staticmethod
    def _resolve_dtype(dtype):
        """Map a dtype spec ("float32", "torch.float64", torch.dtype, None) to a torch.dtype or None."""
        if dtype is None or isinstance(dtype, torch.dtype):
            return dtype
        # Accept both "float32" and "torch.float32" spellings.
        resolved = getattr(torch, str(dtype).split(".")[-1], None)
        if not isinstance(resolved, torch.dtype):
            raise ValueError(
                f"Unsupported dtype {dtype!r}; expected e.g. 'float32' or 'float64'"
            )
        return resolved

    def forward(self, z: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Transform samples through all flows and compute mixture weights.

        Parameters
        ----------
        z : Tensor
            Input samples of shape (batch, *lat_shape).

        Returns
        -------
        tuple
            (x_i, log_det_i) where:
            - x_i has shape (sectors, batch, *lat_shape)
            - log_det_i has shape (sectors, batch)

        Notes
        -----
        Also sets self.prob_c to the mixture probabilities of shape
        (sectors, batch).
        """
        batch_size = z.shape[0]
        # The gating network sees each sample flattened. Transposing puts the
        # sector axis first so prob_c lines up with x_i / log_det_i.
        self.prob_c = self.feedforward(z.reshape(batch_size, -1)).transpose(0, 1)

        # Assumes each RealNVP returns (x, log_det) with x shaped like z and
        # log_det assignable to a (batch,) slice. TODO confirm against RealNVP.
        x_i = torch.zeros((self.sectors, *z.shape), device=z.device, dtype=z.dtype)
        log_det_i = torch.zeros((self.sectors, batch_size), device=z.device, dtype=z.dtype)

        for i, flow in enumerate(self.realvnp_list):
            x_i[i], log_det_i[i] = flow(z)

        return x_i, log_det_i