from torch import nn
from torch.utils.data import DataLoader
import torch


class NeuralDiversitySurrogate(nn.Module):
    """Pairwise MLP that maps two inputs to a single scalar prediction.

    Both inputs are passed through the same stack of
    ``Linear -> LayerNorm -> activation -> Dropout`` layers (the weights are
    shared), the two resulting embeddings are concatenated along dim 1 and a
    final linear layer reduces them to one output value per row.

    All parameters are created in the default floating dtype so the model is
    internally consistent; call ``model.double()`` to run it on float64 data.

    Args:
        num_layer: Number of shared hidden layers; must be at least 1.
        num_feat: Size of the feature dimension of each input.
        num_hidden: Width of every hidden layer.
        dropout_ratio: Dropout probability applied after each hidden layer.
        activation: Either ``'relu'`` or ``'tanh'``.

    Raises:
        ValueError: If ``num_layer`` is smaller than 1 or ``activation`` is
            not one of the supported names.
    """

    def __init__(self, num_layer=3, num_feat=0, num_hidden=32, dropout_ratio=0.4, activation='relu'):
        super().__init__()
        # Fail fast: with no hidden layer the forward pass could never work.
        if num_layer < 1:
            raise ValueError('num_layer must be at least 1!')
        if activation not in ('tanh', 'relu'):
            raise ValueError('Invalid activation function!')

        self.num_layer = num_layer
        self.num_feat = num_feat
        self.num_hidden = num_hidden
        self.dropout_ratio = dropout_ratio
        self.activation = activation

        # Only the first layer consumes the raw features; the rest are
        # hidden-to-hidden.
        self.dense_layers = nn.ModuleList(
            [nn.Linear(num_feat if i == 0 else num_hidden, num_hidden) for i in range(num_layer)])
        # The attribute is called ``bn`` although it holds LayerNorm modules;
        # the name is kept because it is part of the state-dict keys.
        # The layers are left in the default dtype so they match the Linear
        # layers above (a float64-only norm cannot consume float32 output).
        self.bn = nn.ModuleList([nn.LayerNorm(num_hidden) for _ in range(num_layer)])
        if activation == 'tanh':
            self.activations = nn.ModuleList([nn.Tanh() for _ in range(num_layer)])
        else:
            self.activations = nn.ModuleList([nn.ReLU(inplace=True) for _ in range(num_layer)])

        self.dropout_layers = nn.ModuleList([nn.Dropout(p=dropout_ratio) for _ in range(num_layer)])

        # Takes the concatenation of the two embeddings, hence 2 * num_hidden.
        self.final_fc = nn.Linear(num_hidden * 2, 1)

    def shared_forward(self, x):
        """Embed ``x`` with the shared hidden stack.

        Args:
            x: Tensor whose last dimension has size ``num_feat``.

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            size ``num_hidden``.
        """
        x_ = x
        for dense, norm, act, drop in zip(
                self.dense_layers, self.bn, self.activations, self.dropout_layers):
            x_ = drop(act(norm(dense(x_))))
        return x_

    def forward(self, c1, c2):
        """Predict one scalar for each pair of rows in ``c1`` and ``c2``.

        Args:
            c1: Tensor of shape ``(batch, num_feat)``.
            c2: Tensor of shape ``(batch, num_feat)``.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        c1_embedding = self.shared_forward(c1)
        c2_embedding = self.shared_forward(c2)
        embedding = torch.cat([c1_embedding, c2_embedding], dim=1)
        return self.final_fc(embedding)


def _align_target(target, pred):
    """Reshape ``target`` to ``pred``'s shape when they hold the same number of elements."""
    if target.shape != pred.shape and target.numel() == pred.numel():
        return target.reshape(pred.shape)
    return target


def train(model, data, val_data=None, batch_size=256, lr=1e-4, num_epoch=100, num_workers=1):
    """Fit ``model`` with Adam on an MSE objective and print per-epoch losses.

    Args:
        model: Module called as ``model(c1, c2)``.
        data: Training dataset whose items unpack to ``(c1, c2, target)``.
        val_data: Optional validation dataset with the same item layout.
            When given, its MSE is evaluated and printed after every epoch.
        batch_size: Batch size used for both loaders.
        lr: Learning rate for the Adam optimizer.
        num_epoch: Number of passes over ``data``.
        num_workers: Worker processes used by each ``DataLoader``.

    Returns:
        The same ``model`` object, trained in place. It is left in eval mode
        when a validation set was supplied, otherwise in train mode.
    """
    train_data_loader = DataLoader(data, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    # Only build the validation loader when there is something to validate on.
    val_data_loader = None
    if val_data is not None:
        val_data_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_func = nn.MSELoss()

    for i in range(num_epoch):
        model.train()
        train_loss = 0.0
        for c1, c2, target in train_data_loader:
            optimizer.zero_grad()
            pred = model(c1, c2)
            # A (B,) target against a (B, 1) prediction would broadcast to
            # (B, B) inside MSELoss and silently produce a wrong loss.
            loss = loss_func(pred, _align_target(target, pred))
            loss.backward()
            optimizer.step()
            # loss is a batch mean; weight by batch size so the epoch value
            # is a true per-sample mean even with a smaller last batch.
            train_loss += loss.item() * len(c1)
        train_loss = train_loss / len(data)
        print('Epoch %d: Training MSE loss %.5f' % (i, train_loss))
        if val_data_loader is not None:
            model.eval()
            val_loss = 0.0
            # No gradients are needed for evaluation; skip graph construction.
            with torch.no_grad():
                for c1, c2, target in val_data_loader:
                    pred = model(c1, c2)
                    loss = loss_func(pred, _align_target(target, pred))
                    val_loss += loss.item() * len(c1)
            val_loss = val_loss / len(val_data)
            print('Epoch %d: Valid MSE loss %.5f' % (i, val_loss))
    return model
