import torch
import torch.nn as nn
import math

class GatingNet(nn.Module):
    """Two-layer MLP that emits one gate value in [0, 1] per input feature.

    The network maps ``input_dim -> hidden_dim -> input_dim`` with a tanh in
    between. Its raw output ``mu`` is shifted by 0.5 and clamped to [0, 1] to
    obtain the gates; in training mode Gaussian noise is added to ``mu``
    before the clamp.

    Args:
        input_dim: Number of input features (and therefore of gates).
        hidden_dim: Width of the hidden layer.
        sigma: Standard deviation of the sampled noise and the scale used by
            ``regularization``.
    """

    def __init__(self, input_dim, hidden_dim=128, sigma=0.5):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.sigma = sigma
        # Cached constant for the erf-based term in ``regularization``.
        self._sqrt_2 = math.sqrt(2)
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, input_dim),
        )
        # Overwrite the default nn.Linear initialisation of every sub-layer.
        self.apply(self.init_weights)

    @staticmethod
    def init_weights(m):
        """Initialise linear layers: weights ~ N(0, 0.001**2), biases = 0.1.

        Modules that are not ``nn.Linear`` are left untouched.
        """
        if not isinstance(m, nn.Linear):
            return
        nn.init.normal_(m.weight, std=0.001)
        if getattr(m, "bias", None) is not None:
            nn.init.constant_(m.bias, 0.1)

    def forward(self, x):
        """Gate ``x`` feature-wise.

        Returns:
            A tuple ``(mu, sparse_x, gates)`` where ``mu`` is the raw network
            output, ``gates`` is the clamped (and, in training mode, noised)
            gate tensor, and ``sparse_x`` is ``x * gates``.
        """
        mu = self.net(x)
        if self.training:
            # Noise is drawn with std ``sigma`` and then scaled by 0.1.
            noise = torch.normal(0, self.sigma, size=x.size(), device=x.device)
            pre_gate = mu + 0.1 * noise
        else:
            # No noise in eval mode: gates depend on ``mu`` alone.
            pre_gate = mu
        gates = self.hard_sigmoid(pre_gate)
        return mu, x * gates, gates

    @staticmethod
    def hard_sigmoid(x):
        """Return ``x + 0.5`` clamped to the interval [0, 1]."""
        shifted = x + 0.5
        return shifted.clamp(min=0.0, max=1.0)

    def regularization(self, mu, reduction_func=torch.sum):
        """Reduce the Gaussian-CDF term ``0.5 + 0.5 * erf(mu / (sigma * sqrt(2)))``.

        Args:
            mu: Raw gate pre-activations, e.g. the first output of ``forward``.
            reduction_func: Callable applied to the element-wise term
                (``torch.sum`` by default).
        """
        # NOTE(review): forward() adds noise with std 0.1 * sigma and opens a
        # gate once mu > -0.5, whereas this term uses sigma and no shift --
        # confirm that the mismatch is intended.
        scaled = mu / (self.sigma * self._sqrt_2)
        gate_term = 0.5 + 0.5 * torch.erf(scaled)
        return reduction_func(gate_term)

    def get_gates(self, x):
        """Return the noise-free gates for ``x`` without tracking gradients."""
        with torch.no_grad():
            return self.hard_sigmoid(self.net(x))

    def num_open_gates(self, x):
        """Return the median over dim 0 of the gate values summed over dim 1."""
        gate_mass = self.get_gates(x).sum(dim=1).cpu()
        return gate_mass.median(dim=0).values.item()

class OSDT(nn.Module):
    """Soft oblivious decision tree that returns one probability per leaf.

    Every level ``i`` of the tree applies a single learned linear split
    (``split_weights[i]``, ``feature_thresholds[i]``) to the whole batch, so a
    tree of depth ``depth`` has ``2 ** depth`` leaves. The output is the
    product, over levels, of the soft left/right decisions along the path to
    each leaf.

    Args:
        input_dim: Number of input features.
        depth: Number of split levels.
        bin_fn: Callable mapping the stacked ``[-logit, logit]`` tensor to
            soft left/right memberships. Defaults to ``torch.sigmoid``.
        gating_net: Optional module whose forward returns a tuple with the
            gated input at index 1 (as ``GatingNet`` does). When given, each
            level scores the gated input instead of the raw one.
    """

    def __init__(self, input_dim, depth, bin_fn=torch.sigmoid, gating_net=None):
        super().__init__()
        self.input_dim = input_dim
        self.depth = depth
        # Previously accepted but ignored; forward() now honours it.
        self.bin_fn = bin_fn
        self.gating_net = gating_net
        self.split_weights = nn.Parameter(torch.randn(depth, input_dim))
        self.feature_thresholds = nn.Parameter(torch.randn(depth))
        self.log_temperatures = nn.Parameter(torch.zeros(depth))
        # bin_codes[i, leaf] is 1.0 when bit i of the leaf index is set, i.e.
        # when the path to that leaf goes "right" at level i.
        indices = torch.arange(2 ** depth)
        bits = ((indices[:, None] & (1 << torch.arange(depth))) > 0).float()
        self.register_buffer('bin_codes', bits.t())

    def forward(self, x):
        """Compute leaf probabilities for a batch.

        Args:
            x: Tensor of shape ``(batch, input_dim)``.

        Returns:
            Tensor of shape ``(batch, 2 ** depth)`` with one value per leaf.
        """
        level_scores = []
        for level in range(self.depth):
            # Compare against None explicitly: a module that defines
            # ``__len__`` (e.g. an nn.Sequential subclass) may be falsy.
            # The gating net is deliberately called once per level, so a net
            # that samples noise in its forward keeps drawing per level.
            if self.gating_net is not None:
                x_gated = self.gating_net(x)[1]
            else:
                x_gated = x
            level_scores.append(
                torch.einsum('bi,i->b', x_gated, self.split_weights[level])
            )
        selected_features = torch.stack(level_scores, dim=1)

        # Per-level logit: (score - threshold) / temperature.
        logits = (selected_features - self.feature_thresholds) * torch.exp(-self.log_temperatures)
        # Last axis holds (left, right) evidence for each level.
        logits = torch.stack([-logits, logits], dim=-1)
        bins = self.bin_fn(logits)
        bins_left = bins[..., 0]
        bins_right = bins[..., 1]

        # Pick, for every (level, leaf) pair, the branch that leaf's code uses.
        bin_codes = self.bin_codes.unsqueeze(0)
        bin_match = bins_left.unsqueeze(-1) * (1 - bin_codes) + bins_right.unsqueeze(-1) * bin_codes
        # Multiply the per-level decisions along each path.
        path_probs = bin_match.prod(dim=-2)
        return path_probs

class OSDTEncoder(nn.Module):
    """Ensemble of ``OSDT`` trees whose outputs are averaged.

    Args:
        input_dim: Number of input features passed to every tree.
        num_trees: Number of trees in the ensemble.
        depth: Depth of every tree; the encoder emits ``2 ** depth`` values.
        gating_net: Optional gating module handed to every tree (the same
            instance is shared by all of them).
    """

    def __init__(self, input_dim, num_trees, depth, gating_net=None):
        super().__init__()
        self.trees = nn.ModuleList()
        for _ in range(num_trees):
            self.trees.append(OSDT(input_dim, depth, gating_net=gating_net))
        self.output_dim = 2 ** depth

    def forward(self, x):
        """Return the mean of all tree outputs for ``x``."""
        per_tree = torch.stack([tree(x) for tree in self.trees], dim=0)
        return per_tree.mean(dim=0)

class ModularEncoder(nn.Module):
    """MLP encoder built from ``Linear -> BatchNorm1d -> LeakyReLU`` stages.

    Args:
        input_dim: Number of input features.
        hidden_layers: Sequence of layer widths; one stage is created per
            entry and the last entry is exposed as ``output_dim``.
    """

    def __init__(self, input_dim, hidden_layers):
        super().__init__()
        # Consecutive pairs of this list are the (in, out) sizes of each stage.
        widths = [input_dim, *hidden_layers]
        stages = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stages.extend(
                [nn.Linear(fan_in, fan_out), nn.BatchNorm1d(fan_out), nn.LeakyReLU()]
            )
        self.encoder = nn.Sequential(*stages)
        self.output_dim = hidden_layers[-1]

    def forward(self, x):
        """Run ``x`` through every encoder stage and return the result."""
        return self.encoder(x)

class ModularDecoder(nn.Module):
    """MLP decoder that walks ``hidden_layers`` in reverse order.

    The input width is ``hidden_layers[-1]``. Each remaining width, visited
    from last to first, adds a ``Linear -> BatchNorm1d -> LeakyReLU`` stage;
    a final plain ``Linear`` maps to ``output_dim``.

    Args:
        output_dim: Number of features produced by the decoder.
        hidden_layers: The same width list that was given to the encoder.
    """

    def __init__(self, output_dim, hidden_layers):
        super().__init__()
        # Latent width first, then the earlier widths from last to first.
        widths = [hidden_layers[-1], *reversed(hidden_layers[:-1])]
        stages = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stages.extend(
                [nn.Linear(fan_in, fan_out), nn.BatchNorm1d(fan_out), nn.LeakyReLU()]
            )
        # Output projection: no batch norm and no activation.
        stages.append(nn.Linear(widths[-1], output_dim))
        self.decoder = nn.Sequential(*stages)

    def forward(self, z):
        """Map the latent tensor ``z`` to an ``output_dim``-wide tensor."""
        return self.decoder(z)

class SSAE(nn.Module):
    """
    Semi-supervised Autoencoder built from a ModularEncoder/ModularDecoder pair.

    The last entry of ``hidden_layers`` (the latent size) must equal
    ``2 ** osdt_depth``; ``osdt_depth`` is only used for that check here.

    Args:
        input_dim: Number of input (and reconstructed) features.
        hidden_layers: Layer widths shared by the encoder and the decoder.
        osdt_depth: Tree depth the latent size has to be compatible with.
    """

    def __init__(self, input_dim, hidden_layers, osdt_depth):
        super().__init__()
        assert hidden_layers[-1] == 2 ** osdt_depth, (
            f"The final latent dimension ({hidden_layers[-1]}) must equal 2**osdt_depth ({2 ** osdt_depth})."
        )
        self.encoder = ModularEncoder(input_dim, hidden_layers)
        self.decoder = ModularDecoder(input_dim, hidden_layers)

    def forward(self, x):
        """Encode ``x`` and return its reconstruction."""
        return self.decoder(self.encoder(x))

class Tandem(nn.Module):
    """
    TANDEM: an MLP encoder and an OSDT encoder that share a single decoder.

    Both encoders end in ``2 ** osdt_depth`` units, which is also the latent
    size the decoder expects, so ``hidden_layers[-1]`` must equal that value.

    Args:
        input_dim: Number of input (and reconstructed) features.
        hidden_layers: Layer widths for the MLP encoder and the decoder.
        num_trees: Number of trees in the OSDT encoder.
        osdt_depth: Depth of every tree.
        gating_net: Optional gating module forwarded to the OSDT encoder.
    """

    def __init__(
        self, input_dim,
        hidden_layers,
        num_trees=1, osdt_depth=7,
        gating_net=None
    ):
        super().__init__()
        assert hidden_layers[-1] == 2 ** osdt_depth, (
            f"The final latent dimension ({hidden_layers[-1]}) must equal 2**osdt_depth ({2 ** osdt_depth})."
        )
        self.nn_encoder = ModularEncoder(input_dim, hidden_layers)
        self.osdt_encoder = OSDTEncoder(
            input_dim, num_trees, osdt_depth, gating_net=gating_net
        )
        self.decoder = ModularDecoder(input_dim, hidden_layers)

    def forward(self, x):
        """Encode ``x`` with both encoders and decode each latent.

        Returns:
            ``(x_rec_nn, x_rec_osdt, z_nn, z_osdt)``: the two reconstructions
            followed by the two latent tensors.
        """
        # Encoders run first (MLP, then trees), then the shared decoder is
        # applied to each latent in the same order.
        latent_nn, latent_tree = self.nn_encoder(x), self.osdt_encoder(x)
        rebuilt_nn, rebuilt_tree = self.decoder(latent_nn), self.decoder(latent_tree)
        return rebuilt_nn, rebuilt_tree, latent_nn, latent_tree

# ==== Example Usage ====

if __name__ == "__main__":
    # Demo configuration. The latent width is tied to the tree depth because
    # SSAE and Tandem both require hidden_layers[-1] == 2 ** osdt_depth.
    osdt_depth = 7
    input_dim, batch_size, num_trees = 48, 16, 2
    hidden_layers = [128, 256, 128, 2 ** osdt_depth]

    # Build the models in a fixed order (gating net, SSAE, Tandem) so that
    # parameter initialisation consumes the RNG in the same sequence.
    gating_net = GatingNet(input_dim, hidden_dim=128)
    ssae = SSAE(
        input_dim=input_dim, hidden_layers=hidden_layers, osdt_depth=osdt_depth
    )
    tandem = Tandem(
        input_dim=input_dim,
        hidden_layers=hidden_layers,
        num_trees=num_trees,
        osdt_depth=osdt_depth,
        gating_net=gating_net,
    )

    # One random batch pushed through both models.
    batch = torch.randn(batch_size, input_dim)
    recon_ssae = ssae(batch)
    recon_nn, recon_osdt, z_nn, z_osdt = tandem(batch)

    # ----- Reconstruction Loss Checks -----
    mse = nn.MSELoss()
    loss_ssae = mse(recon_ssae, batch)
    loss_nn = mse(recon_nn, batch)
    loss_osdt = mse(recon_osdt, batch)
    total_tandem_loss = loss_nn + loss_osdt

    print("=== SHAPES ===")
    for label, tensor in (
        ("recon_ssae shape:", recon_ssae),
        ("recon_nn shape:", recon_nn),
        ("recon_osdt shape:", recon_osdt),
        ("z_nn shape:", z_nn),
        ("z_osdt shape:", z_osdt),
    ):
        print(label, tensor.shape)
    print()

    print("=== RECONSTRUCTION LOSSES ===")
    print(f"SSAE loss:        {loss_ssae.item():.6f}")
    print(f"Tandem NN loss:   {loss_nn.item():.6f}")
    print(f"Tandem OSDT loss: {loss_osdt.item():.6f}")
    print(f"Tandem Total:     {total_tandem_loss.item():.6f}")