import torch
import torch.nn as nn
import torch.nn.functional as F


class DynamicMLP(nn.Module):
    """Multi-layer perceptron whose hidden stack is built from configuration lists.

    Each hidden layer is ``Linear -> [BatchNorm1d] -> activation -> [Dropout]``;
    the stack is followed by a plain ``Linear`` output layer with no activation.
    """

    # Maps supported activation names to the nn.Module class implementing them.
    _ACTIVATIONS = {
        'relu': nn.ReLU,
        'tanh': nn.Tanh,
        'sigmoid': nn.Sigmoid,
        'gelu': nn.GELU,
        'leaky_relu': nn.LeakyReLU,
        'elu': nn.ELU,
        'identity': nn.Identity,
    }

    def __init__(self, input_size, layer_sizes, output_size, activations, use_bn=False, dropout_rate=0.5):
        """Initialize the dynamic MLP.

        Args:
            input_size: Dimension of the input features.
            layer_sizes: Sizes of the hidden layers, in order.
            output_size: Dimension of the output layer.
            activations: Activation name for each hidden layer, e.g.
                ``['relu', 'tanh']``, or a single name applied to every
                hidden layer.
            use_bn: Whether to insert ``BatchNorm1d`` after each hidden
                ``Linear`` layer.
            dropout_rate: Dropout probability applied after each hidden
                activation; a value ``<= 0`` disables dropout entirely.

        Raises:
            ValueError: If the number of activations differs from the number
                of hidden layers, or an activation name is unsupported.
        """
        super().__init__()
        self.use_bn = use_bn
        self.dropout_rate = dropout_rate

        layer_sizes = list(layer_sizes)
        if isinstance(activations, str):
            # One name for all hidden layers.
            activations = [activations] * len(layer_sizes)
        else:
            activations = list(activations)
        # zip() would silently truncate to the shorter list, dropping hidden
        # layers (or ignoring activations); reject the mismatch explicitly.
        if len(activations) != len(layer_sizes):
            raise ValueError(
                f"Expected {len(layer_sizes)} activations (one per hidden layer), "
                f"got {len(activations)}"
            )

        self.layers = nn.ModuleList()
        last_size = input_size

        # Hidden layers.
        for layer_size, activation in zip(layer_sizes, activations):
            self.layers.append(nn.Linear(last_size, layer_size))
            if use_bn:
                self.layers.append(nn.BatchNorm1d(layer_size))
            self.layers.append(self.get_activation(activation))
            if dropout_rate > 0:
                self.layers.append(nn.Dropout(dropout_rate))
            last_size = layer_size

        # Output layer (no activation; callers apply any final nonlinearity).
        self.output_layer = nn.Linear(last_size, output_size)

    def forward(self, x):
        """Run ``x`` through the hidden stack and then the output layer.

        NOTE(review): assumes ``x`` is ``(batch, input_size)``. With
        ``use_bn=True``, BatchNorm1d normalizes dim 1, so other ranks would
        normalize the wrong axis. Confirm against callers.
        """
        for layer in self.layers:
            x = layer(x)
        return self.output_layer(x)

    def get_activation(self, name):
        """Return a new activation module for ``name``.

        Raises:
            ValueError: If ``name`` is not a supported activation name.
        """
        try:
            factory = self._ACTIVATIONS[name]
        except (KeyError, TypeError):
            # TypeError covers unhashable names; report both uniformly.
            raise ValueError(f"Unsupported activation: {name}") from None
        return factory()

