import torch
import torch.nn as nn

# --------- MLP for Tabular Data (already in use)
class MLP(nn.Module):
    """Feed-forward network with two hidden layers for flat feature vectors.

    Layout: Linear -> ReLU -> Dropout, repeated twice (the first maps
    ``input_dim`` to ``hidden_dim``, the second ``hidden_dim`` to itself),
    followed by a final Linear projection to ``output_dim``.

    Args:
        input_dim: Number of input features.
        hidden_dim: Width of both hidden layers.
        output_dim: Size of the output layer.
        dropout: Dropout probability applied after each hidden activation.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.2):
        super().__init__()
        stages = []
        # Two hidden stages; only the input width of the first one differs.
        for fan_in in (input_dim, hidden_dim):
            stages.extend([nn.Linear(fan_in, hidden_dim), nn.ReLU(), nn.Dropout(dropout)])
        stages.append(nn.Linear(hidden_dim, output_dim))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the network to ``x`` (last dimension must equal ``input_dim``)."""
        logits = self.net(x)
        return logits

# --------- LSTM for Tabular Data
class LSTMClassifier(nn.Module):
    """LSTM classifier that maps each input sequence to a single output vector.

    2-D inputs ``[batch, features]`` are treated as length-1 sequences. The
    LSTM output at the last time step is projected to ``output_dim``.

    Args:
        input_dim: Number of features per time step.
        hidden_dim: LSTM hidden-state size.
        output_dim: Size of the output layer.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability between stacked LSTM layers. PyTorch only
            applies it between layers, so it is not passed when
            ``num_layers == 1``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=1, dropout=0.2):
        super().__init__()
        # nn.LSTM applies dropout after every layer except the last. With a
        # single layer it has no effect and PyTorch warns on construction
        # (which happened with this class's own defaults), so only forward
        # the value when it can actually take effect.
        lstm_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(
            input_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=lstm_dropout,
        )
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return output of shape ``[batch, output_dim]``.

        Args:
            x: ``[batch, features]`` (reshaped to ``[batch, 1, features]``) or
               ``[batch, seq, features]`` (batch-first, as configured above).
        """
        if x.dim() == 2:
            x = x.unsqueeze(1)
        out, _ = self.lstm(x)
        # Keep only the output at the final time step.
        out = out[:, -1, :]
        return self.fc(out)

# --------- Transformer for Tabular Data
class TransformerClassifier(nn.Module):
    """Transformer-encoder classifier for tabular or sequential inputs.

    Inputs are linearly embedded to ``hidden_dim`` and run through a stack of
    ``num_layers`` encoder layers. The representation at the last sequence
    position is then projected to ``output_dim``.

    Args:
        input_dim: Number of features per position.
        hidden_dim: Model width (``d_model``). PyTorch's multi-head attention
            requires it to be divisible by ``nhead``.
        output_dim: Size of the output layer.
        nhead: Number of attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        dropout: Dropout probability inside each encoder layer.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, nhead=4, num_layers=2, dropout=0.2):
        super().__init__()
        self.embedding = nn.Linear(input_dim, hidden_dim)
        # Template layer; TransformerEncoder holds num_layers copies of it.
        template = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=nhead,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(template, num_layers=num_layers)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return output of shape ``[batch, output_dim]``.

        Args:
            x: ``[batch, features]`` (treated as a length-1 sequence) or
               ``[batch, seq, features]`` (batch-first).
        """
        seq = x.unsqueeze(1) if x.dim() == 2 else x
        # For 2-D input the sequence has a single position, so self-attention
        # only attends to that one token.
        hidden = self.transformer_encoder(self.embedding(seq))
        last_position = hidden[:, -1, :]
        return self.fc(last_position)

# --------- Simple CNN for Tabular Data (1D)
class CNNClassifier(nn.Module):
    """1-D convolutional classifier that treats a feature vector as a one-channel signal.

    Two Conv1d -> BatchNorm1d -> ReLU stages keep the feature axis at length
    ``input_dim``. The activations are then dropped out, flattened, and
    projected to ``output_dim``.

    Args:
        input_dim: Number of features per sample (length of the 1-D signal).
        output_dim: Size of the output layer.
        n_channels: Number of channels produced by both convolutions.
        kernel_size: Convolution kernel width. Odd sizes use symmetric padding
            of ``kernel_size // 2``. Even sizes use PyTorch's ``padding="same"``.
        dropout: Dropout probability applied before the final linear layer.
    """

    def __init__(self, input_dim, output_dim, n_channels=16, kernel_size=3, dropout=0.2):
        super().__init__()
        # The fc layer below expects n_channels * input_dim features, so both
        # convolutions must preserve the sequence length. A fixed padding=1
        # only did that for kernel_size=3; any other size crashed in forward().
        # Odd kernels keep integer padding (identical to the old behaviour at
        # the default kernel_size=3). Even kernels need asymmetric padding,
        # which "same" provides.
        padding = kernel_size // 2 if kernel_size % 2 == 1 else "same"
        self.conv1 = nn.Conv1d(1, n_channels, kernel_size=kernel_size, padding=padding)
        self.bn1 = nn.BatchNorm1d(n_channels)
        self.conv2 = nn.Conv1d(n_channels, n_channels, kernel_size=kernel_size, padding=padding)
        self.bn2 = nn.BatchNorm1d(n_channels)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(n_channels * input_dim, output_dim)

    def forward(self, x):
        """Return output of shape ``[batch, output_dim]``.

        Args:
            x: ``[batch, features]`` (a channel axis is inserted) or
               ``[batch, 1, features]``, since conv1 takes one input channel.
        """
        if x.dim() == 2:
            x = x.unsqueeze(1)
        x = torch.relu(self.bn1(self.conv1(x)))
        x = torch.relu(self.bn2(self.conv2(x)))
        x = self.dropout(x)
        # Flatten channels x length into one feature vector per sample.
        x = torch.flatten(x, 1)
        return self.fc(x)
