import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict


class _DeepSets(nn.Module):
    def __init__(self, dim_output_phi, dim_output_rho, with_label=False):
        super(_DeepSets, self).__init__()
        phi = []
        for i, d in enumerate(dim_output_phi):
            phi.append(("Lin_{}".format(i), nn.Linear(in_features=(13 if i == 0 else dim_output_phi[i-1]),
                                                      out_features=d)))
            phi.append(("Relu_{}".format(i), nn.ReLU()))
        self.phi = nn.Sequential(OrderedDict(phi))
        self.bn_phi = nn.BatchNorm1d(dim_output_phi[-1])

        rho = []
        for i, d in enumerate(dim_output_rho):
            rho.append(("Lin_{}".format(i), nn.Linear((dim_output_phi[-1] + 5) if i == 0 else dim_output_rho[i-1], d)))
            rho.append(("Relu_{}".format(i), nn.ReLU()))
        self.rho = nn.Sequential(OrderedDict(rho))
        self.bn_rho = nn.BatchNorm1d(dim_output_rho[-1])

    def forward(self, X, lab=None):
        X = torch.cat([X, lab], dim=2)
        X = self.phi(X)
        X = X.mean(1)
        X = self.bn_phi(X)

        X = torch.cat([X, lab.mean(1)], dim=1)
        return self.rho(X)


class DeepSets(nn.Module):
    """Label-aware DeepSets encoder followed by a fully connected head.

    The encoder (``self.model``) maps a labelled set to a vector. The head then
    applies ``Linear -> BatchNorm1d -> ReLU -> Dropout`` for every layer except
    the last, which is a bare ``Linear``.

    Args:
        list_dim_output_phi: widths of the encoder's phi layers.
        list_dim_output_rho: widths of the encoder's rho layers. Its last
            entry (or the value itself, if an ``int``) is the head's input width.
        fc_metafeatures: output widths of the head's Linear layers, in order.
        dropout_fc: dropout probabilities indexed in step with
            ``fc_metafeatures``, so it must be at least as long.
    """

    def __init__(self, list_dim_output_phi, list_dim_output_rho, fc_metafeatures, dropout_fc):
        super(DeepSets, self).__init__()
        # Not populated or read anywhere in this class.
        self.list_module = torch.nn.ModuleList()

        # Set encoder, built with labels enabled.
        self.model = _DeepSets(list_dim_output_phi, list_dim_output_rho, with_label=True)

        self.list_fc = torch.nn.ModuleList()
        self.list_bn_fc = torch.nn.ModuleList()
        self.list_dropout_fc = torch.nn.ModuleList()

        # Width of the tensor entering the next head layer.
        if isinstance(list_dim_output_rho, int):
            width_in = list_dim_output_rho
        else:
            width_in = list_dim_output_rho[-1]

        for idx, width_out in enumerate(fc_metafeatures):
            self.list_fc.append(nn.Linear(width_in, width_out))
            # The last layer's norm/dropout are created but skipped in forward.
            self.list_bn_fc.append(nn.BatchNorm1d(width_out))
            self.list_dropout_fc.append(nn.Dropout(dropout_fc[idx]))
            width_in = width_out

    def forward(self, X, lab):
        """Encode ``(X, lab)`` with ``self.model`` and run the result through the head."""
        z = self.model(X, lab)

        # Fully connected head: every layer but the last gets BN, ReLU and dropout.
        last_idx = len(self.list_fc) - 1
        stages = zip(self.list_fc, self.list_bn_fc, self.list_dropout_fc)
        for idx, (linear, norm, drop) in enumerate(stages):
            z = linear(z)
            if idx < last_idx:
                z = drop(F.relu(norm(z)))
        return z
