"""
RealNVP normalizing flow implementation.

This module provides the RealNVP (Real-valued Non-Volume Preserving)
normalizing flow architecture based on affine coupling layers.

References
----------
Dinh, L., Sohl-Dickstein, J., & Bengio, S. (2017). Density estimation
using Real NVP. ICLR 2017.
"""

import logging
from typing import Tuple
import torch
from torch import Tensor
import torch.nn as nn
import numpy as np

# Module-level logger, registered under the fixed name "SESaMo" (not __name__).
logger = logging.getLogger("SESaMo")


class Coupling(nn.Module):
    """
    Affine coupling layer for RealNVP.

    The flattened input is split into two interleaved halves, ``on`` and
    ``off``. The ``off`` half passes through unchanged and conditions an
    affine transformation of the ``on`` half:

        y_on = x_on * exp(s(x_off)) + t(x_off)

    where ``s`` (``scale_net``) and ``t`` (``translation_net``) are
    fully connected networks.

    Parameters
    ----------
    lat_shape : list
        Shape of the latent space (excluding batch dimension). The total
        number of features must be even.
    mask_config : int
        Masking configuration (0 or 1) determining which interleaved half
        is transformed.
    num_hidden_layers : int, optional
        Number of hidden layers in the networks. Default is 4.
    num_hidden_features : int, optional
        Number of features in hidden layers. Default is 40.
    bias : bool, optional
        Whether to use bias in linear layers. Default is False.
    activation : str, optional
        Activation function ('relu', 'tanh', 'leakyrelu'), matched
        case-insensitively. Default is 'relu'.
    dtype : torch.dtype, optional
        Data type for parameters. Default is torch.float32.
    """

    def __init__(
        self,
        lat_shape: list,
        mask_config: int,
        num_hidden_layers: int = 4,
        num_hidden_features: int = 40,
        bias: bool = False,
        activation: str = 'relu',
        dtype: torch.dtype = torch.float32,
    ):
        """
        Build the scale and translation networks of the coupling layer.

        See the class docstring for the meaning of each parameter.

        Raises
        ------
        ValueError
            If activation is not supported, or if the total number of
            features in ``lat_shape`` is odd (the input could not be split
            into two equal halves).
        """
        super().__init__()
        self.mask_config = mask_config
        self.lat_shape = lat_shape
        self.num_hidden_layers = num_hidden_layers
        self.num_hidden_features = num_hidden_features
        self.dtype = dtype

        # np.prod returns a NumPy scalar (a float for an empty shape);
        # convert to a plain int before using it as a layer size.
        in_out_dim = int(np.prod(lat_shape))
        if in_out_dim % 2 != 0:
            # _split reshapes to (batch, features // 2, 2); fail early with a
            # clear message instead of a reshape error on the first call.
            raise ValueError(
                f"Coupling requires an even number of features, got "
                f"{in_out_dim} from lat_shape={lat_shape}"
            )
        half_dim = in_out_dim // 2

        activation_dict = {
            'relu': nn.ReLU(),
            'tanh': nn.Tanh(),
            'leakyrelu': nn.LeakyReLU(0.01),
        }
        if activation.lower() not in activation_dict:
            raise ValueError(f"Activation {activation} not implemented")
        activation_func = activation_dict[activation.lower()]

        # Scaling and translation networks share one architecture.
        self.scale_net = self._build_net(
            half_dim, num_hidden_layers, num_hidden_features,
            bias, activation_func, dtype,
        )
        self.translation_net = self._build_net(
            half_dim, num_hidden_layers, num_hidden_features,
            bias, activation_func, dtype,
        )

        # Xavier-initialise every linear layer. ``modules()`` walks the
        # nested hidden blocks too; iterating the Sequential directly would
        # only reach the first and last linear layers.
        for net in (self.scale_net, self.translation_net):
            for module in net.modules():
                if isinstance(module, nn.Linear):
                    nn.init.xavier_normal_(module.weight)

    @staticmethod
    def _build_net(
        half_dim: int,
        num_hidden_layers: int,
        num_hidden_features: int,
        bias: bool,
        activation_func: nn.Module,
        dtype: torch.dtype,
    ) -> nn.Sequential:
        """Build one conditioner MLP mapping ``half_dim`` -> ``half_dim``."""
        # Each hidden layer is kept as its own nested Sequential so that
        # parameter names in the state dict stay "<net>.<i>.0.weight".
        hidden_blocks = [
            nn.Sequential(
                nn.Linear(num_hidden_features, num_hidden_features, bias=bias, dtype=dtype),
                activation_func,
            )
            for _ in range(num_hidden_layers)
        ]
        return nn.Sequential(
            nn.Linear(half_dim, num_hidden_features, bias=bias, dtype=dtype),
            activation_func,
            *hidden_blocks,
            nn.Linear(num_hidden_features, half_dim, bias=bias, dtype=dtype),
        )

    def _split(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Split the input tensor into two interleaved parts.

        Features (0,1,2,3) are separated into (0,2) and (1,3). With a truthy
        ``mask_config`` the even-indexed features become ``on`` and the
        odd-indexed ones ``off``; otherwise the roles are swapped.

        Parameters
        ----------
        x : Tensor
            Input tensor of shape (batch, features).

        Returns
        -------
        tuple
            (on, off) tensors, each of shape (batch, features//2).
        """
        B, W = x.shape
        x = x.reshape((B, W // 2, 2))
        if self.mask_config:
            on, off = x[:, :, 0], x[:, :, 1]
        else:
            off, on = x[:, :, 0], x[:, :, 1]

        return on, off

    def _join(self, on: Tensor, off: Tensor) -> Tensor:
        """
        Combine two tensor parts back into one (inverse of ``_split``).

        Parameters
        ----------
        on : Tensor
            Transformed part of shape (batch, features//2).
        off : Tensor
            Untransformed part of shape (batch, features//2).

        Returns
        -------
        Tensor
            Combined tensor of shape (batch, features) with the two parts
            interleaved in their original positions.
        """
        if self.mask_config:
            pairs = torch.stack((on, off), dim=2)
        else:
            pairs = torch.stack((off, on), dim=2)

        # Flatten (batch, features//2, 2) back to (batch, features).
        return pairs.reshape(pairs.shape[0], -1)

    def forward(self, x: Tensor, reverse: bool = False) -> Tuple[Tensor, Tensor]:
        """
        Apply the affine coupling transformation.

        Parameters
        ----------
        x : Tensor
            Input tensor of shape (batch, *lat_shape).
        reverse : bool, optional
            Whether to apply the inverse transformation. Default is False.

        Returns
        -------
        tuple
            (output, log_det) where output has the same shape as input
            and log_det has shape (batch,). In reverse mode log_det is the
            negated sum of the scale outputs.
        """
        x_shape = x.shape
        x = x.reshape(x.shape[0], -1)
        on, off = self._split(x)

        # Scale and translation are both conditioned on the untouched half.
        scale = self.scale_net(off)
        translation = self.translation_net(off)

        if not reverse:
            on = on * torch.exp(scale) + translation
            log_det = scale.sum(dim=1)
        else:
            on = (on - translation) * torch.exp(-scale)
            log_det = -scale.sum(dim=1)

        x = self._join(on, off)
        return x.reshape(x_shape), log_det

    def reverse(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Apply the inverse affine coupling transformation.

        Parameters
        ----------
        x : Tensor
            Input tensor of shape (batch, *lat_shape).

        Returns
        -------
        tuple
            (output, log_det) where output has the same shape as input
            and log_det has shape (batch,).
        """
        return self.forward(x, reverse=True)



class RealNVP(nn.Module):
    """
    RealNVP (Real-valued Non-Volume Preserving) normalizing flow.

    Stacks multiple affine coupling layers with alternating masking patterns
    to create a flexible bijective transformation.

    Parameters
    ----------
    lat_shape : list
        Shape of the latent space (excluding batch dimension).
    num_coupling_layers : int, optional
        Number of coupling layers. Default is 6.
    num_hidden_layers : int, optional
        Hidden layers per coupling. Default is 4.
    num_hidden_features : int, optional
        Features in hidden layers. Default is 100.
    mask_config : int, optional
        Masking configuration of the first coupling; subsequent couplings
        alternate between 0 and 1. Default is 1.
    bias : bool, optional
        Use bias in layers. Default is True.
    activation : str, optional
        Activation function. Default is "relu".
    dtype : torch.dtype, optional
        Data type. Default is torch.float32.
    **kwargs
        Additional keyword arguments (ignored).

    Attributes
    ----------
    name : str
        Model identifier ("realnvp").
    couplings : nn.ModuleList
        List of coupling layers.
    """

    name = "realnvp"

    def __init__(
        self,
        lat_shape: list,
        num_coupling_layers: int = 6,
        num_hidden_layers: int = 4,
        num_hidden_features: int = 100,
        mask_config: int = 1,
        bias: bool = True,
        activation: str = "relu",
        dtype: torch.dtype = torch.float32,
        **kwargs,
    ):
        """
        Initialize the RealNVP flow.

        See the class docstring for the meaning of each parameter. Extra
        keyword arguments are accepted and ignored.
        """
        super().__init__()

        # Alternate the mask between consecutive couplings so that each half
        # of the features is transformed in turn.
        self.couplings = nn.ModuleList(
            [
                Coupling(
                    lat_shape=lat_shape,
                    num_hidden_layers=num_hidden_layers,
                    num_hidden_features=num_hidden_features,
                    mask_config=(mask_config + i) % 2,
                    bias=bias,
                    activation=activation,
                    dtype=dtype,
                )
                for i in range(num_coupling_layers)
            ]
        )

        logger.info(f"Initialized RealNVP with {num_coupling_layers} couplings, each with {num_hidden_layers} layers of {num_hidden_features} features")

    def forward(self, z: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Transform samples from the prior to the target distribution.

        Parameters
        ----------
        z : Tensor
            Input samples of shape (batch, *lat_shape).

        Returns
        -------
        tuple
            (x, log_det) where x has shape (batch, *lat_shape)
            and log_det has shape (batch,). log_det is accumulated over all
            couplings and is created with the dtype and device of ``z``.
        """
        # new_zeros inherits dtype and device from z, so a float64 flow does
        # not have its log-determinant silently truncated to float32.
        log_det = z.new_zeros(z.shape[0])

        for coupling in self.couplings:
            z, single_log_det = coupling(z)
            log_det = log_det + single_log_det

        return z, log_det

    def reverse(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Transform samples from the target distribution back to the prior.

        Parameters
        ----------
        x : Tensor
            Input samples of shape (batch, *lat_shape).

        Returns
        -------
        tuple
            (z, log_det) where z has shape (batch, *lat_shape)
            and log_det has shape (batch,). log_det is accumulated over all
            couplings and is created with the dtype and device of ``x``.
        """
        # Same dtype/device handling as in forward().
        log_det = x.new_zeros(x.shape[0])

        # Undo the couplings in the opposite order they were applied.
        for coupling in reversed(self.couplings):
            x, single_log_det = coupling.reverse(x)
            log_det = log_det + single_log_det

        return x, log_det