import torch
import torch.nn as nn
from mamba_ssm import Mamba
import utils

class Reconst_Mapper_MambaHybrid(nn.Module):
    """Map latent sequences to outputs with a Mamba block followed by an MLP.

    Per sequence: Linear(in_dim -> mamba_dim) -> Mamba with a residual
    connection and LayerNorm applied after the sum -> Linear(mamba_dim -> 200)
    -> ReLU -> Linear(200 -> out_dim).

    Args:
        in_dim: Size of the last (feature) dimension of the input.
        out_dim: Size of the last dimension of the output.
        mamba_dim: Model width used inside the Mamba block.
        seq_len: Unused; kept only so existing constructor calls keep working.
    """

    def __init__(self, in_dim, out_dim, mamba_dim=128, seq_len=100):
        super().__init__()
        self.input_proj = nn.Linear(in_dim, mamba_dim)
        self.mamba = Mamba(d_model=mamba_dim)
        self.norm = nn.LayerNorm(mamba_dim)
        self.mlp = nn.Sequential(
            nn.Linear(mamba_dim, 200),
            nn.ReLU(),
            nn.Linear(200, out_dim)
        )

        # Initialisation is delegated to the project helper: the input
        # projection uses the helper's default method, the MLP uses
        # 'kaiming_uniform_'.
        utils.init_network_weights(self.input_proj)
        utils.init_network_weights(self.mlp, method='kaiming_uniform_')

    def forward(self, z):
        """Apply the mapper along the time axis.

        Args:
            z: Tensor of shape ``(..., T, in_dim)``. The usual case is
                ``(k_iwae, B, T, in_dim)``. All leading dimensions are folded
                into a single batch dimension for the Mamba block.

        Returns:
            Tensor of shape ``(..., T, out_dim)`` with the same leading
            dimensions as ``z``.

        Raises:
            ValueError: if ``z`` has fewer than 2 dimensions.
        """
        if z.dim() < 2:
            raise ValueError(
                f"expected input of shape (..., T, D), got {tuple(z.shape)}")
        *lead, t, d = z.shape
        # reshape (not view): z may be non-contiguous (e.g. after a transpose
        # or expand upstream), in which case view() raises.
        x = z.reshape(-1, t, d)               # -> (N, T, D), N = prod(lead)

        x = self.input_proj(x)                # -> (N, T, mamba_dim)
        res = x
        x = self.mamba(x)                     # -> (N, T, mamba_dim)
        x = self.norm(x + res)                # residual, then LayerNorm

        out = self.mlp(x)                     # -> (N, T, out_dim)
        # Restore the leading dims. Use the explicit feature size rather than
        # -1 so the reshape stays well-defined when N == 0.
        return out.reshape(*lead, t, out.shape[-1])

class Embedding_MLP(nn.Module):
    """Two-layer ReLU MLP applied to ``truth`` concatenated with ``mask``.

    The network input is the concatenation of both tensors on the last axis,
    so the first Linear layer expects ``2 * in_dim`` features.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        hidden = 200
        # Module order (Linear, ReLU, Linear) determines the state_dict keys
        # ``layers.0.*`` / ``layers.2.*``; keep it stable for checkpoints.
        self.layers = nn.Sequential(
            nn.Linear(2 * in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_dim),
        )
        utils.init_network_weights(self.layers, method='kaiming_uniform_')

    def forward(self, truth, mask):
        """Embed ``truth`` together with ``mask``.

        Args:
            truth: Value tensor. Together with ``mask`` it must supply
                ``2 * in_dim`` features on the last axis (presumably
                ``in_dim`` each; confirm with callers).
            mask: Tensor concatenable with ``truth`` along the last axis.

        Returns:
            Tensor whose last dimension is ``out_dim``.

        Raises:
            AssertionError: if the concatenated input contains NaN. This
                check only runs when assertions are enabled (not under -O).
        """
        features = torch.cat([truth, mask], dim=-1)
        assert not torch.isnan(features).any()
        return self.layers(features)


class Reconst_Mapper_MLP(nn.Module):
    """Two-layer ReLU MLP mapping ``in_dim`` features to ``out_dim`` features.

    Architecture: Linear(in_dim -> 200) -> ReLU -> Linear(200 -> out_dim).
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        width = 200
        self.layers = nn.Sequential(
            nn.Linear(in_dim, width),
            nn.ReLU(),
            nn.Linear(width, out_dim),
        )
        utils.init_network_weights(self.layers, method='kaiming_uniform_')

    def forward(self, data):
        """Apply the MLP over the last dimension of ``data``."""
        return self.layers(data)


class Z_to_mu_ReLU(nn.Module):
    """Map a latent of size ``latent_dim`` to a same-sized output.

    The output is named ``mu`` by this module's name, presumably a mean.
    Architecture: Linear(latent_dim -> 100) -> ReLU -> Linear(100 -> latent_dim),
    with no final activation, so the output is unconstrained.
    """

    def __init__(self, latent_dim):
        super().__init__()
        hidden_units = 100
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, latent_dim),
        )
        utils.init_network_weights(self.net, method='kaiming_uniform_')

    def forward(self, data):
        """Return ``net(data)``; the last dim of ``data`` must be ``latent_dim``."""
        mu = self.net(data)
        return mu


class Z_to_std_ReLU(nn.Module):
    """Map a latent of size ``latent_dim`` to a same-sized non-negative output.

    Architecture: Linear(latent_dim -> 100) -> ReLU -> Linear(100 -> latent_dim)
    -> Softplus. The Softplus keeps the output non-negative, as a standard
    deviation (suggested by the name) requires.
    """

    def __init__(self, latent_dim):
        super().__init__()
        hidden_units = 100
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, latent_dim),
            nn.Softplus(),
        )
        # Unlike its siblings, this uses the helper's default init method.
        utils.init_network_weights(self.net)

    def forward(self, data):
        """Return ``net(data)``; the last dim of ``data`` must be ``latent_dim``."""
        std = self.net(data)
        return std


class BinaryClassifier(nn.Module):
    """Three-layer ReLU MLP producing a single unnormalised score per input.

    Architecture: Linear(in_dim -> 300) -> ReLU -> Linear(300 -> 300) -> ReLU
    -> Linear(300 -> 1). No sigmoid is applied here.
    """

    def __init__(self, in_dim):
        super().__init__()
        widths = [in_dim, 300, 300]
        modules = []
        # Hidden blocks: Linear followed by ReLU. Module indices match the
        # original hand-written Sequential, so state_dict keys are unchanged.
        for fan_in, fan_out in zip(widths, widths[1:]):
            modules += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        modules.append(nn.Linear(widths[-1], 1))
        self.layers = nn.Sequential(*modules)
        utils.init_network_weights(self.layers, method='kaiming_uniform_')

    def forward(self, x):
        """Return the raw score tensor (last dim of size 1) for ``x``."""
        score = self.layers(x)
        return score
