import torch
import torch.nn as nn
import torch.nn.functional as F

class CouplingLayer(nn.Module):
    """Affine coupling layer: keep the first half of the features, and
    scale/shift the second half conditioned on the first.

    The last axis of the input is split into ``x_a`` (the first
    ``input_dim // 2`` features) and ``x_b`` (the remaining features)::

        y_a = x_a
        y_b = x_b * exp(s(x_a)) + t(x_a)

    ``s`` ends in ``tanh``, so the per-feature log-scale lies in (-1, 1).

    Args:
        input_dim: Size of the last axis of the input. Must be >= 2. Odd
            values are supported; the second half gets the extra feature.
        hidden_dim: Width of the hidden layer of the scale/translate MLPs.

    Raises:
        ValueError: If ``input_dim < 2`` (there would be nothing to
            condition on).
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        if input_dim < 2:
            raise ValueError(f"input_dim must be >= 2, got {input_dim}")
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        # Explicit split sizes: ``Tensor.chunk(2)`` puts the *larger* piece
        # first for odd sizes, which would not match the ``input_dim // 2``
        # input width of the nets below.
        self.dim_a = input_dim // 2
        self.dim_b = input_dim - self.dim_a

        self.scale_net = nn.Sequential(
            nn.Linear(self.dim_a, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, self.dim_b),
            nn.Tanh(),
        )
        self.translate_net = nn.Sequential(
            nn.Linear(self.dim_a, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, self.dim_b),
        )

    def _split(self, x):
        """Split the last axis into the (conditioning, transformed) halves."""
        return torch.split(x, [self.dim_a, self.dim_b], dim=-1)

    def forward(self, x):
        """Map ``x`` to ``y``; the last axis of ``x`` must have size ``input_dim``."""
        x_a, x_b = self._split(x)
        s = self.scale_net(x_a)
        t = self.translate_net(x_a)
        y_b = x_b * torch.exp(s) + t
        return torch.cat([x_a, y_b], dim=-1)

    def inverse(self, y):
        """Exact inverse of :meth:`forward` (``y_a`` equals ``x_a``, so s and t are recomputable)."""
        y_a, y_b = self._split(y)
        s = self.scale_net(y_a)
        t = self.translate_net(y_a)
        x_b = (y_b - t) * torch.exp(-s)
        return torch.cat([y_a, x_b], dim=-1)

class RealNVP(nn.Module):
    """Stack of ``CouplingLayer`` modules applied in sequence.

    NOTE(review): no permutation is applied between layers, and every
    ``CouplingLayer`` passes its first half through unchanged. The first
    ``dim // 2`` input features therefore leave the whole stack untouched.
    A standard RealNVP alternates which half is transformed; confirm that
    this is intended.

    Args:
        dim: Size of the last axis of the input.
        hidden_dim: Hidden width of each coupling layer's MLPs.
        n_coupling_layers: Number of coupling layers to stack.
    """

    def __init__(self, dim=16, hidden_dim=128, n_coupling_layers=4):
        super().__init__()
        self.dim = dim
        self.n_coupling_layers = n_coupling_layers
        self.layers = nn.ModuleList(
            [CouplingLayer(dim, hidden_dim) for _ in range(n_coupling_layers)]
        )

    def forward(self, x):
        """Apply every coupling layer to ``x`` in order and return the result."""
        for layer in self.layers:
            x = layer(x)
        return x

    def randomize_params_partial(self, percent=0.5, control_vector=None):
        """Overwrite the first scale-net bias of a random subset of layers.

        ``int(len(self.layers) * percent)`` distinct layers are drawn with
        ``torch.randperm``, using the global torch RNG. For each selected
        layer, the bias of ``scale_net[0]`` is overwritten in place with
        ``control_vector``, which is first moved to that bias's device.
        When ``control_vector`` is None, no parameter is changed, but the
        random draw still happens.

        Args:
            percent: Fraction of layers to modify. Must be >= 0; values
                above 1 select every layer.
            control_vector: Tensor copied into each selected bias. It must
                be broadcastable to the bias shape ``(hidden_dim,)``.

        Raises:
            ValueError: If ``percent`` is negative.
        """
        if percent < 0:
            # A negative count would turn the slice below into ``[:-k]``
            # and silently select ``n - k`` layers.
            raise ValueError(f"percent must be non-negative, got {percent}")
        n = len(self.layers)
        k = min(int(n * percent), n)
        indices = torch.randperm(n)[:k].tolist()
        if control_vector is None:
            return
        with torch.no_grad():
            for idx in indices:
                bias = self.layers[idx].scale_net[0].bias
                bias.copy_(control_vector.to(bias.device))

def _init_small_linear(module):
    """Reset an ``nn.Linear`` to N(0, 0.05) weights and zero bias; ignore other modules."""
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, 0, 0.05)
        nn.init.constant_(module.bias, 0)


def run_realnvp_batch(batch_input, *, hidden_dim=128, n_coupling_layers=4):
    """Push each sample of a 3-D batch through its own freshly built RealNVP.

    For every index ``i`` along the first axis, a new ``RealNVP`` is built.
    Its ``nn.Linear`` layers are re-initialized (N(0, 0.05) weights, zero
    bias), and half of its layers get their first scale-net bias set to
    ``randn(hidden_dim) * 0.01``. ``batch_input[i]`` is then transformed
    without gradient tracking. All randomness uses the global torch RNG.

    Args:
        batch_input: Tensor of shape ``(batchsize, seq_len, dim)``. Each
            model operates on the last axis.
        hidden_dim: Hidden width of the coupling MLPs. The control vector
            is sized to match it.
        n_coupling_layers: Number of coupling layers per model.

    Returns:
        Tensor with the same shape, dtype and device as ``batch_input``.
    """
    batchsize, _seq_len, dim = batch_input.shape
    outputs = torch.zeros_like(batch_input)

    for i in range(batchsize):
        model = RealNVP(dim=dim, hidden_dim=hidden_dim, n_coupling_layers=n_coupling_layers)
        model.apply(_init_small_linear)

        # Must have exactly ``hidden_dim`` entries: it replaces the bias
        # of each selected layer's first scale-net Linear.
        control_vector = torch.randn(hidden_dim) * 0.01
        model.randomize_params_partial(percent=0.5, control_vector=control_vector)

        # Match the input's device/dtype only after all random draws, so
        # the RNG stream is independent of where the data lives.
        model.to(batch_input.device)
        if batch_input.is_floating_point():
            model.to(batch_input.dtype)

        with torch.no_grad():
            # Slice with i:i+1 to keep a leading batch axis of size 1.
            outputs[i] = model(batch_input[i:i + 1]).squeeze(0)

    return outputs

if __name__ == "__main__":
    # Demo input laid out as (batchsize=4, seq_len=5000, dim=16).
    demo_input = torch.randn(4, 5000, 16)
    demo_output = run_realnvp_batch(demo_input)
    print(f"Output shape: {demo_output.shape}")

