import torch
import torch.nn as nn
import torch.nn.functional as F

class Baseline_linear(nn.Module):
    """Linear baseline: one affine layer over [handcrafted meta-features, hyperparameters].

    Args:
        size_mf_hc: Number of handcrafted meta-feature columns.
        size_hparams: Number of hyperparameter columns.
    """

    def __init__(self, size_mf_hc, size_hparams):
        super().__init__()
        self.layer = nn.Linear(size_mf_hc + size_hparams, 1)

    @staticmethod
    def _assert_in_range(x, low, high, name):
        """Assert that every element of ``x`` lies in ``[low, high]``.

        The lower bound is checked against ``x.min()`` and the upper bound
        against ``x.max()``. Empty tensors are accepted because ``min``/``max``
        are undefined on them. NaNs fail the check, since comparisons with NaN
        are False.
        """
        if x.numel() == 0:
            return
        lo, hi = x.min().item(), x.max().item()
        assert low <= lo and hi <= high, (
            f"{name} values must lie in [{low}, {high}], got [{lo}, {hi}]"
        )

    def forward(self, handcrafted_mf, params):
        """Score a batch of (meta-features, hyperparameters) rows.

        Args:
            handcrafted_mf: Meta-features, expected to lie in [-4, 4].
            params: Hyperparameters, expected to be scaled to [0, 1].
                Both are concatenated along dim 1, so all other dims must match.
                The inputs are presumably (batch, features); confirm with the callers.

        Returns:
            ``self.layer`` applied to the concatenation (last dim of size 1).

        Raises:
            AssertionError: If any input value falls outside its expected range.
        """
        self._assert_in_range(handcrafted_mf, -4, 4, "handcrafted_mf")
        self._assert_in_range(params, 0.0, 1.0, "params")

        z = torch.cat([handcrafted_mf, params], dim=1)
        return self.layer(z)


class Baseline_linear_ranking(nn.Module):
    """Pairwise scorer: rates two hyperparameter configs for the same meta-features.

    Both configurations are scored by one shared ``NN_Surrogate`` instance, so
    the two scores come from identical weights. Despite the class name, the
    shared scorer is ``NN_Surrogate``, not a single ``nn.Linear``.

    Args:
        size_mf_hc: Number of handcrafted meta-feature columns.
        size_hparams: Number of hyperparameter columns per configuration.
    """

    def __init__(self, size_mf_hc, size_hparams):
        super().__init__()
        self.layer = NN_Surrogate(size_mf_hc, size_hparams)

    @staticmethod
    def _assert_in_range(x, low, high, name):
        """Assert that every element of ``x`` lies in ``[low, high]``.

        The lower bound is checked against ``x.min()`` and the upper bound
        against ``x.max()``. Empty tensors are accepted because ``min``/``max``
        are undefined on them. NaNs fail the check.
        """
        if x.numel() == 0:
            return
        lo, hi = x.min().item(), x.max().item()
        assert low <= lo and hi <= high, (
            f"{name} values must lie in [{low}, {high}], got [{lo}, {hi}]"
        )

    def forward(self, handcrafted_mf, params_1, params_2):
        """Score ``params_1`` and ``params_2`` against the same meta-features.

        Args:
            handcrafted_mf: Meta-features, expected to lie in [-4, 4].
            params_1: First hyperparameter configuration, scaled to [0, 1].
            params_2: Second hyperparameter configuration, scaled to [0, 1].

        Returns:
            The two scores concatenated along dim 1, first config first. For
            2-D ``(batch, ·)`` inputs this is ``(batch, 2)``.

        Raises:
            AssertionError: If any input value falls outside its expected range.
        """
        self._assert_in_range(handcrafted_mf, -4, 4, "handcrafted_mf")
        self._assert_in_range(params_1, 0.0, 1.0, "params_1")
        self._assert_in_range(params_2, 0.0, 1.0, "params_2")

        z_1 = self.layer(handcrafted_mf, params_1)
        z_2 = self.layer(handcrafted_mf, params_2)
        return torch.cat([z_1, z_2], dim=1)


class NN_Surrogate(nn.Module):
    """One-hidden-layer MLP that scores a (meta-features, hyperparameters) pair.

    Architecture: ``cat([mf, hp], dim=1) -> fc_1 -> ReLU -> fc_2``. ``fc_2``
    has a single output unit.

    Args:
        size_meta_features: Number of meta-feature columns.
        size_params: Number of hyperparameter columns.
        hidden_size: Width of the hidden layer. Defaults to 16, the
            previously hard-coded width, so existing checkpoints still load.
    """

    def __init__(self, size_meta_features, size_params, hidden_size=16):
        super().__init__()
        self.fc_1 = nn.Linear(size_meta_features + size_params, hidden_size)
        self.fc_2 = nn.Linear(hidden_size, 1)

    def forward(self, mf, hp):
        """Return the surrogate score for each row of ``mf`` paired with ``hp``.

        ``mf`` and ``hp`` are concatenated along dim 1, so every other
        dimension must match. For 2-D ``(batch, ·)`` inputs the result is
        ``(batch, 1)``.
        """
        z = torch.cat([mf, hp], dim=1)
        z = F.relu(self.fc_1(z))
        return self.fc_2(z)
