from collections import OrderedDict
import torch
import torch.nn as nn


class MLP(nn.Module):
    """Fully connected feed-forward network with an optional scaled label input.

    The network is a stack of ``torch.nn.Linear`` layers with an activation
    between consecutive layers and none after the last one. Submodules are
    registered inside ``self.layers`` under the names ``layer_<i>`` and
    ``activation_<i>``, so ``state_dict`` keys follow that scheme.

    Args:
        layers: Sequence of layer widths ``[in, hidden..., out]``. It must
            contain at least two entries. When ``forward`` is called with a
            ``label``, ``layers[0]`` must include the extra label column.
        scale: Multiplier applied to ``label`` before it is concatenated to
            the input.
        activation: Name of an activation class in ``torch.nn`` (e.g.
            ``'SiLU'``, ``'Tanh'``), or the class itself. A fresh instance is
            created for every hidden layer.

    Raises:
        ValueError: If ``layers`` has fewer than two entries.
        AttributeError: If ``activation`` is a string that is not an attribute
            of ``torch.nn``.
    """

    def __init__(self, layers, scale=1.0, activation='SiLU'):
        super().__init__()
        widths = list(layers)
        if len(widths) < 2:
            raise ValueError(
                f"`layers` needs at least an input and an output width, got {widths!r}."
            )
        # Number of Linear layers.
        self.depth = len(widths) - 1
        # Accept either a torch.nn attribute name or an activation class.
        self.activation = (
            getattr(torch.nn, activation) if isinstance(activation, str) else activation
        )
        self.scale = scale

        layer_list = []
        for i, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            layer_list.append(("layer_%d" % i, torch.nn.Linear(fan_in, fan_out)))
            # No activation after the final (output) layer.
            if i < self.depth - 1:
                layer_list.append(("activation_%d" % i, self.activation()))

        self.layers = torch.nn.Sequential(OrderedDict(layer_list))

    def forward(self, x, label=None):
        """Run the network on ``x``, optionally appending a scaled label column.

        Args:
            x: Input tensor. The label is concatenated along dim 1, so ``x`` is
                presumably shaped ``(batch, features)``.
            label: Optional tensor with one value per row of ``x``. It is
                reshaped to ``(-1, 1)``, multiplied by ``self.scale``, cast to
                ``x.dtype``, and concatenated to ``x`` along dim 1.

        Returns:
            The output of the final linear layer.
        """
        if label is None:
            return self.layers(x)
        # Cast so that, for example, a float64 label does not promote the
        # concatenated input away from x's dtype. The Linear weights are
        # presumably in x's dtype, so promotion would make them mismatch.
        label = (label.reshape(-1, 1) * self.scale).to(x.dtype)
        return self.layers(torch.cat((x, label), dim=1))

class MLP_twins(nn.Module):
    """Pair of ``MLP`` networks with the same ``scale`` and ``activation``.

    The two networks are selected at call time by ``network_mode``:
    ``'first'`` routes to ``model_oab``, built from ``layers_oab``, and
    ``'second'`` routes to ``model_oa``, built from ``layers_oa``.
    """

    def __init__(self, layers_oab, layers_oa, scale=1.0, activation='SiLU'):
        super().__init__()
        shared_kwargs = dict(scale=scale, activation=activation)
        self.model_oab = MLP(layers=layers_oab, **shared_kwargs)
        self.model_oa = MLP(layers=layers_oa, **shared_kwargs)

    def forward(self, x, network_mode, label=None):
        """Forward ``x`` (and ``label``) through the network chosen by ``network_mode``.

        Raises:
            ValueError: If ``network_mode`` is neither ``'first'`` nor ``'second'``.
        """
        routes = (('first', self.model_oab), ('second', self.model_oa))
        for mode_name, branch in routes:
            if network_mode == mode_name:
                return branch(x, label=label)
        raise ValueError(f"Invalid mode specified: {network_mode}. Expected 'first' or 'second'.")

