import torch


def get_activation_function(activation_name):
    """Look up a ``torch.nn`` activation class by its name.

    Args:
        activation_name: Case-sensitive name; one of "ReLU", "LeakyReLU",
            "SiLU", "GELU", "Mish" or "Tanh".

    Returns:
        The matching ``torch.nn`` class itself (not an instance), so the
        caller decides when and how often to instantiate it.

    Raises:
        ValueError: If ``activation_name`` is not one of the supported names.
    """
    supported_names = ("ReLU", "LeakyReLU", "SiLU", "GELU", "Mish", "Tanh")
    for candidate in supported_names:
        if activation_name == candidate:
            # Resolve the attribute only for the requested activation.
            return getattr(torch.nn, candidate)
    raise ValueError(f"Unknown activation function: {activation_name}")


def enable_dropout(model):
    """Switch ``model`` to eval mode but keep its dropout layers in train mode.

    The whole model is first put into eval mode; afterwards every submodule
    whose class name starts with "Dropout" is put back into train mode.
    Matching is done on the class-name prefix, so a class such as
    ``AlphaDropout`` (name does not start with "Dropout") is left in eval mode.

    Args:
        model: Module whose submodules are inspected via ``model.modules()``.
    """
    model.eval()
    dropout_layers = [
        layer
        for layer in model.modules()
        if layer.__class__.__name__.startswith("Dropout")
    ]
    for layer in dropout_layers:
        layer.train()


class MLP(torch.nn.Module):
    """Fully connected feed-forward network.

    Layer layout, in order:
        1. ``Linear(input_dim, hidden_dim)`` followed by the activation
           (no dropout after this first layer),
        2. ``depth - 1`` repetitions of
           ``Linear(hidden_dim, hidden_dim) -> activation -> Dropout(p=dropout)``,
        3. a final ``Linear(hidden_dim, output_dim)`` without activation.

    Args:
        input_dim: Number of input features.
        output_dim: Number of output features.
        hidden_dim: Width of every hidden layer.
        depth: Number of hidden ``Linear`` layers (see layout above).
        dropout: Dropout probability used in the repeated hidden blocks.
        activation: Name understood by ``get_activation_function``.
    """

    def __init__(
        self, input_dim, output_dim, hidden_dim, depth, dropout, activation="ReLU"
    ):
        super().__init__()
        make_activation = get_activation_function(activation)

        # Input projection.
        stack = [torch.nn.Linear(input_dim, hidden_dim), make_activation()]
        # Hidden blocks; each one gets its own activation and dropout instance.
        for _ in range(depth - 1):
            stack += [
                torch.nn.Linear(hidden_dim, hidden_dim),
                make_activation(),
                torch.nn.Dropout(p=dropout),
            ]
        # Output projection.
        stack.append(torch.nn.Linear(hidden_dim, output_dim))

        self.model = torch.nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the sequential layer stack."""
        return self.model(x)


class AffineCouplingLayer(torch.nn.Module):
    """Two-sided affine coupling layer (RealNVP-style) for INNs.

    The input is split column-wise into ``u1`` (first ``split_dim`` columns)
    and ``u2`` (remaining columns). Scale (s) and translation (t) sub-networks
    are instances of the ``MLP`` class.

    Forward pass:
        v1 = u1 * exp(s2(u2)) + t2(u2)
        v2 = u2 * exp(s1(v1)) + t1(v1)

    Inverse pass:
        u2 = (v2 - t1(v1)) * exp(-s1(v1))
        u1 = (v1 - t2(u2)) * exp(-s2(u2))

    Every scale output is soft-clamped before the exponential:
        s_clamped = clamp * tanh(s / clamp)

    Args:
        dim: Total input dimension.
        split_dim: Number of columns assigned to ``u1``.
        hidden_dim: Hidden width of the s and t networks.
        subnet_depth: Depth passed to the s and t MLPs.
        clamp: Soft clamping value for the scale factor.
        activation: Activation function name for the sub-networks.
    """

    def __init__(
        self,
        dim: int,
        split_dim: int,
        hidden_dim: int = 256,
        subnet_depth: int = 2,
        clamp: float = 2.0,
        activation: str = "ReLU",
    ):
        super().__init__()
        self.dim = dim
        self.split_dim = split_dim
        self.clamp = clamp

        width_first = split_dim
        width_second = dim - split_dim

        def build_subnet(in_features, out_features):
            # All four sub-networks share the same hyper-parameters.
            return MLP(
                input_dim=in_features,
                output_dim=out_features,
                hidden_dim=hidden_dim,
                depth=subnet_depth,
                dropout=0,
                activation=activation,
            )

        # s2/t2 read the second part and act on the first part.
        self.s2 = build_subnet(width_second, width_first)
        self.t2 = build_subnet(width_second, width_first)

        # s1/t1 read the (transformed) first part and act on the second part.
        self.s1 = build_subnet(width_first, width_second)
        self.t1 = build_subnet(width_first, width_second)

    def _soft_clamp(self, s: torch.Tensor) -> torch.Tensor:
        """Squash ``s`` smoothly into (-clamp, clamp) via a scaled tanh."""
        limit = self.clamp
        return limit * torch.tanh(s / limit)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass: x -> v"""
        cut = self.split_dim
        u1 = x[:, :cut]
        u2 = x[:, cut:]

        # First part is transformed using information from the second part.
        log_scale_2 = self._soft_clamp(self.s2(u2))
        shift_2 = self.t2(u2)
        v1 = u1 * torch.exp(log_scale_2) + shift_2

        # Second part is transformed using the already-updated first part.
        log_scale_1 = self._soft_clamp(self.s1(v1))
        shift_1 = self.t1(v1)
        v2 = u2 * torch.exp(log_scale_1) + shift_1

        return torch.cat([v1, v2], dim=1)

    def inverse(self, v: torch.Tensor) -> torch.Tensor:
        """Inverse pass: v -> x"""
        cut = self.split_dim
        v1 = v[:, :cut]
        v2 = v[:, cut:]

        # Undo the second-part transformation first (it only depends on v1).
        log_scale_1 = self._soft_clamp(self.s1(v1))
        shift_1 = self.t1(v1)
        u2 = (v2 - shift_1) * torch.exp(-log_scale_1)

        # With u2 recovered, the first-part transformation can be undone.
        log_scale_2 = self._soft_clamp(self.s2(u2))
        shift_2 = self.t2(u2)
        u1 = (v1 - shift_2) * torch.exp(-log_scale_2)

        return torch.cat([u1, u2], dim=1)


class PermutationLayer(torch.nn.Module):
    """Fixed random permutation layer.

    Permutes dimensions to ensure all dimensions interact across coupling layers.

    Args:
        dim: Input dimension.
        seed: Random seed for reproducible permutation. If None, uses random permutation.
    """

    def __init__(self, dim: int, seed: int | None = None):
        super().__init__()
        if seed is not None:
            generator = torch.Generator().manual_seed(seed)
            perm = torch.randperm(dim, generator=generator)
        else:
            perm = torch.randperm(dim)

        # Store permutation and inverse as buffers (not parameters)
        self.register_buffer("perm", perm)
        self.register_buffer("inv_perm", torch.argsort(perm))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply permutation."""
        return x[:, self.perm]

    def inverse(self, x: torch.Tensor) -> torch.Tensor:
        """Apply inverse permutation."""
        return x[:, self.inv_perm]


class INN(torch.nn.Module):
    """Invertible Neural Network with affine coupling layers.

    Maps design parameters x to (labels y, latent z) and vice versa.
    The network is exactly invertible by construction.

    Args:
        input_dim: Dimension of x (design parameters).
        output_dim: Dimension of y (labels).
        num_blocks: Number of coupling blocks.
        hidden_dim: Hidden dimension in subnet MLPs.
        subnet_depth: Depth of s, t networks.
        clamp: Clamping value for scale factors.
        activation: Activation function for subnets.
        permutation_seed: Base seed for permutation layers (incremented per layer).
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        num_blocks: int = 4,
        hidden_dim: int = 256,
        subnet_depth: int = 2,
        clamp: float = 2.0,
        activation: str = "ReLU",
        permutation_seed: int | None = 42,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.latent_dim = input_dim - output_dim

        if self.latent_dim <= 0:
            raise ValueError(
                f"input_dim ({input_dim}) must be greater than output_dim ({output_dim})"
            )

        # Build sequence of coupling blocks with permutations
        self.blocks = torch.nn.ModuleList()
        split_dim = input_dim // 2

        for i in range(num_blocks):
            # Add coupling layer
            self.blocks.append(
                AffineCouplingLayer(
                    dim=input_dim,
                    split_dim=split_dim,
                    hidden_dim=hidden_dim,
                    subnet_depth=subnet_depth,
                    clamp=clamp,
                    activation=activation,
                )
            )
            # Add permutation layer (except after last block)
            if i < num_blocks - 1:
                seed = permutation_seed + i if permutation_seed is not None else None
                self.blocks.append(PermutationLayer(dim=input_dim, seed=seed))

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward pass: x -> (y, z)

        Args:
            x: Design parameters, shape (batch_size, input_dim).

        Returns:
            y: Labels, shape (batch_size, output_dim).
            z: Latent variables, shape (batch_size, latent_dim).
        """
        h = x
        for block in self.blocks:
            h = block.forward(h)

        y = h[:, : self.output_dim]
        z = h[:, self.output_dim :]
        return y, z

    def inverse(self, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """Inverse pass: (y, z) -> x

        Args:
            y: Labels, shape (batch_size, output_dim).
            z: Latent variables, shape (batch_size, latent_dim).

        Returns:
            x: Design parameters, shape (batch_size, input_dim).
        """
        h = torch.cat([y, z], dim=1)
        for block in reversed(self.blocks):
            h = block.inverse(h)
        return h


class ConditionalAffineCouplingLayer(torch.nn.Module):
    """Affine coupling layer whose sub-networks also see a condition ``c``.

    Works like ``AffineCouplingLayer``, except that ``c`` is concatenated
    (along dim 1, after the data columns) to the input of every s and t
    network.

    Forward pass:
        v1 = u1 * exp(s2(u2, c)) + t2(u2, c)
        v2 = u2 * exp(s1(v1, c)) + t1(v1, c)

    Inverse pass:
        u2 = (v2 - t1(v1, c)) * exp(-s1(v1, c))
        u1 = (v1 - t2(u2, c)) * exp(-s2(u2, c))

    Args:
        dim: Total input dimension.
        split_dim: Number of columns assigned to ``u1``.
        cond_dim: Dimension of conditioning variable c.
        hidden_dim: Hidden width of the s and t networks.
        subnet_depth: Depth passed to the s and t MLPs.
        clamp: Soft clamping value for the scale factor.
        activation: Activation function name for the sub-networks.
    """

    def __init__(
        self,
        dim: int,
        split_dim: int,
        cond_dim: int,
        hidden_dim: int = 256,
        subnet_depth: int = 2,
        clamp: float = 2.0,
        activation: str = "ReLU",
    ):
        super().__init__()
        self.dim = dim
        self.split_dim = split_dim
        self.cond_dim = cond_dim
        self.clamp = clamp

        width_first = split_dim
        width_second = dim - split_dim

        def build_subnet(data_features, out_features):
            # Each sub-network receives its data columns plus the condition.
            return MLP(
                input_dim=data_features + cond_dim,
                output_dim=out_features,
                hidden_dim=hidden_dim,
                depth=subnet_depth,
                dropout=0,
                activation=activation,
            )

        # s2/t2 read (second part, c) and act on the first part.
        self.s2 = build_subnet(width_second, width_first)
        self.t2 = build_subnet(width_second, width_first)

        # s1/t1 read (transformed first part, c) and act on the second part.
        self.s1 = build_subnet(width_first, width_second)
        self.t1 = build_subnet(width_first, width_second)

    def _soft_clamp(self, s: torch.Tensor) -> torch.Tensor:
        """Squash ``s`` smoothly into (-clamp, clamp) via a scaled tanh."""
        limit = self.clamp
        return limit * torch.tanh(s / limit)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Forward pass: x -> v (conditioned on c)"""
        cut = self.split_dim
        u1 = x[:, :cut]
        u2 = x[:, cut:]

        # First part is transformed from (u2, c).
        second_with_cond = torch.cat([u2, c], dim=1)
        log_scale_2 = self._soft_clamp(self.s2(second_with_cond))
        shift_2 = self.t2(second_with_cond)
        v1 = u1 * torch.exp(log_scale_2) + shift_2

        # Second part is transformed from (v1, c).
        first_with_cond = torch.cat([v1, c], dim=1)
        log_scale_1 = self._soft_clamp(self.s1(first_with_cond))
        shift_1 = self.t1(first_with_cond)
        v2 = u2 * torch.exp(log_scale_1) + shift_1

        return torch.cat([v1, v2], dim=1)

    def inverse(self, v: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Inverse pass: v -> x (conditioned on c)"""
        cut = self.split_dim
        v1 = v[:, :cut]
        v2 = v[:, cut:]

        # Undo the second-part transformation first; it depends on (v1, c) only.
        first_with_cond = torch.cat([v1, c], dim=1)
        log_scale_1 = self._soft_clamp(self.s1(first_with_cond))
        shift_1 = self.t1(first_with_cond)
        u2 = (v2 - shift_1) * torch.exp(-log_scale_1)

        # With u2 recovered, undo the first-part transformation from (u2, c).
        second_with_cond = torch.cat([u2, c], dim=1)
        log_scale_2 = self._soft_clamp(self.s2(second_with_cond))
        shift_2 = self.t2(second_with_cond)
        u1 = (v1 - shift_2) * torch.exp(-log_scale_2)

        return torch.cat([u1, u2], dim=1)


class ConditionalINN(torch.nn.Module):
    """Conditional Invertible Neural Network with affine coupling layers.

    Maps design parameters x to (labels y, latent z) given conditioning c,
    and vice versa. The network is exactly invertible by construction.

    Args:
        input_dim: Dimension of x (design parameters).
        output_dim: Dimension of y (labels).
        cond_dim: Dimension of conditioning variable c.
        num_blocks: Number of coupling blocks.
        hidden_dim: Hidden dimension in subnet MLPs.
        subnet_depth: Depth of s, t networks.
        clamp: Clamping value for scale factors.
        activation: Activation function for subnets.
        permutation_seed: Base seed for permutation layers (incremented per layer).
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        cond_dim: int,
        num_blocks: int = 4,
        hidden_dim: int = 256,
        subnet_depth: int = 2,
        clamp: float = 2.0,
        activation: str = "ReLU",
        permutation_seed: int | None = 42,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.cond_dim = cond_dim
        self.latent_dim = input_dim - output_dim

        if self.latent_dim <= 0:
            raise ValueError(
                f"input_dim ({input_dim}) must be greater than output_dim ({output_dim})"
            )

        # Build sequence of coupling blocks with permutations
        self.coupling_blocks = torch.nn.ModuleList()
        self.permutation_blocks = torch.nn.ModuleList()
        split_dim = input_dim // 2

        for i in range(num_blocks):
            # Add coupling layer
            self.coupling_blocks.append(
                ConditionalAffineCouplingLayer(
                    dim=input_dim,
                    split_dim=split_dim,
                    cond_dim=cond_dim,
                    hidden_dim=hidden_dim,
                    subnet_depth=subnet_depth,
                    clamp=clamp,
                    activation=activation,
                )
            )
            # Add permutation layer (except after last block)
            if i < num_blocks - 1:
                seed = permutation_seed + i if permutation_seed is not None else None
                self.permutation_blocks.append(PermutationLayer(dim=input_dim, seed=seed))

    def forward(
        self, x: torch.Tensor, c: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward pass: x -> (y, z) given conditioning c

        Args:
            x: Design parameters, shape (batch_size, input_dim).
            c: Conditioning variable, shape (batch_size, cond_dim).

        Returns:
            y: Labels, shape (batch_size, output_dim).
            z: Latent variables, shape (batch_size, latent_dim).
        """
        h = x
        for i, coupling in enumerate(self.coupling_blocks):
            h = coupling.forward(h, c)
            if i < len(self.permutation_blocks):
                h = self.permutation_blocks[i].forward(h)

        y = h[:, : self.output_dim]
        z = h[:, self.output_dim :]
        return y, z

    def inverse(
        self, y: torch.Tensor, z: torch.Tensor, c: torch.Tensor
    ) -> torch.Tensor:
        """Inverse pass: (y, z) -> x given conditioning c

        Args:
            y: Labels, shape (batch_size, output_dim).
            z: Latent variables, shape (batch_size, latent_dim).
            c: Conditioning variable, shape (batch_size, cond_dim).

        Returns:
            x: Design parameters, shape (batch_size, input_dim).
        """
        h = torch.cat([y, z], dim=1)
        for i in range(len(self.coupling_blocks) - 1, -1, -1):
            if i < len(self.permutation_blocks):
                h = self.permutation_blocks[i].inverse(h)
            h = self.coupling_blocks[i].inverse(h, c)
        return h


if __name__ == "__main__":
    # Manual smoke check: build a sample MLP, print its layer structure and
    # the value reported by the project's parameter-counting helper.
    from uq_diagcfm.utils import count_module_parameters

    demo_config = dict(
        input_dim=6,
        output_dim=6,
        hidden_dim=1024,
        depth=3,
        dropout=0.0,
        activation="ReLU",
    )
    model = MLP(**demo_config)

    print(model)
    print("Number of parameters:", count_module_parameters(model))
