import torch
import torch.nn as nn
import torch.nn.functional as F

from enum import IntEnum
from models.real_nvp.st_resnet import STResNet
from util import checkerboard_like


class MaskType(IntEnum):
    """Integer tags naming the two mask layouts: checkerboard and channel-wise."""

    # Members are defined in value order, so iteration yields
    # CHECKERBOARD (0) first and CHANNEL_WISE (1) second.
    CHECKERBOARD, CHANNEL_WISE = 0, 1


class CouplingLayer(nn.Module):
    """Affine coupling layer in RealNVP for vector inputs.

    The input is split along dim 1 into a conditioning part `x_a`, which is
    passed through unchanged, and a transformed part `x_b`, which is scaled
    and translated by values the `st_net` predicts from `x_a`.

    Args:
        input_dim (int): Dimensionality of the input vector. Must be at least 2.
            Odd values are supported: the first part of the split then holds
            the extra feature (the same split `torch.chunk(x, 2, dim=1)` makes).
        hidden_dim (int): Dimensionality of the hidden layers in the `s` and `t` networks.
        reverse_mask (bool): Whether to reverse the mask for alternating operations.
            If False the first part conditions and the second is transformed;
            if True the roles are swapped.

    Raises:
        ValueError: If `input_dim` is smaller than 2 (nothing left to split).
    """
    def __init__(self, input_dim, hidden_dim, reverse_mask=False):
        super().__init__()
        if input_dim < 2:
            raise ValueError(
                "input_dim must be at least 2 to split the input, got {}".format(input_dim)
            )
        self.input_dim = input_dim
        self.reverse_mask = reverse_mask

        first_dim, second_dim = self._split_sizes()
        if reverse_mask:
            cond_dim, trans_dim = second_dim, first_dim
        else:
            cond_dim, trans_dim = first_dim, second_dim

        # Scale and translate network (s, t). It reads the conditioning part and
        # emits one `s` and one `t` per transformed feature. For even
        # `input_dim` these sizes equal `input_dim // 2` and `input_dim`.
        self.st_net = nn.Sequential(
            nn.Linear(cond_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2 * trans_dim)
        )

        # Learnable scale and shift for s
        self.s_scale = nn.Parameter(torch.ones(1))
        self.s_shift = nn.Parameter(torch.zeros(1))

    def _split_sizes(self):
        """Return the (first, second) part widths; the first gets the extra feature."""
        first_dim = (self.input_dim + 1) // 2
        return first_dim, self.input_dim - first_dim

    def forward(self, x, sldj=None, reverse=False):
        """Apply the coupling transform, or its inverse, to `x`.

        Args:
            x (torch.Tensor): Input whose dim 1 has size `input_dim`.
            sldj (torch.Tensor or None): Running sum of log-determinants of the
                Jacobian. In forward mode `s.sum(dim=1)` is added to it IN
                PLACE; in reverse mode it is returned untouched.
            reverse (bool): If True compute `(x_b - t) * exp(-s)`, otherwise
                `exp(s) * x_b + t`.

        Returns:
            tuple: `(x, sldj)` where `x` has the two parts concatenated back in
            their original order and `sldj` is the (possibly updated) argument.

        Raises:
            RuntimeError: If the exponentiated scale contains NaN entries, or
                if dim 1 of `x` does not have size `input_dim`.
        """
        # Split input into two parts with explicit sizes so odd widths work
        first_dim, second_dim = self._split_sizes()
        x_first, x_second = torch.split(x, [first_dim, second_dim], dim=1)
        if self.reverse_mask:
            x_a, x_b = x_second, x_first
        else:
            x_a, x_b = x_first, x_second

        # Compute s and t using x_a
        st = self.st_net(x_a)
        s, t = torch.chunk(st, 2, dim=1)
        # tanh bounds the raw scale before the learnable rescale and shift
        s = self.s_scale * torch.tanh(s) + self.s_shift

        # Scale and translate x_b
        if reverse:
            # Inverse operation
            inv_exp_s = torch.exp(-s)
            if torch.isnan(inv_exp_s).any():
                raise RuntimeError("Scale factor has NaN entries in inverse mode")
            x_b = (x_b - t) * inv_exp_s
        else:
            # Forward operation
            exp_s = torch.exp(s)
            if torch.isnan(exp_s).any():
                raise RuntimeError("Scale factor has NaN entries in forward mode")
            x_b = exp_s * x_b + t

            # Update log-determinant of the Jacobian (in place, so the
            # caller's tensor is modified as well)
            if sldj is not None:
                sldj += s.sum(dim=1)

        # Concatenate x_a and x_b back in their original positions
        x = torch.cat((x_b, x_a), dim=1) if self.reverse_mask else torch.cat((x_a, x_b), dim=1)
        return x, sldj
