import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict


class _Dense(nn.Module):
    """Plain MLP built from ``Linear -> ReLU`` pairs.

    Args:
        hidden: output width of each Linear layer, in order.
        input_dim: size of the last dimension of the input tensor.

    ``self.model`` stores the layers in the flat order
    ``[Linear_0, ReLU, Linear_1, ReLU, ...]``; subclasses index into it
    directly, so that layout must be preserved.
    """

    def __init__(self, hidden, input_dim):
        super(_Dense, self).__init__()
        self.hidden = hidden
        self.model = torch.nn.ModuleList([])
        for i, dim in enumerate(hidden):
            self.model.append(nn.Linear(input_dim if i == 0 else hidden[i - 1], dim))
            self.model.append(nn.ReLU())

    def forward(self, X):
        """Run ``X`` through every layer in order.

        With an empty ``hidden`` there are no layers and ``X`` is returned
        unchanged.
        """
        # nn.ModuleList is only a container (it has no forward()), so calling
        # self.model(X) raises NotImplementedError; apply the layers manually.
        for layer in self.model:
            X = layer(X)
        return X

class _Res(_Dense):
    """Residual variant of ``_Dense``: returns ``relu(mlp(X) + X)``.

    The inner MLP applies ReLU after every Linear except the last one; the
    final ReLU stored in ``self.model`` is applied after adding the skip
    connection. The addition requires the last entry of ``hidden`` to equal
    the size of ``X``'s last dimension.
    """

    def forward(self, X):
        linears = self.model[0::2]  # even slots hold the Linear layers
        relus = self.model[1::2]    # odd slots hold the ReLU activations
        depth = len(self.hidden)

        h = X + 0
        for k in range(depth):
            h = linears[k](h)
            is_last_linear = k + 1 >= depth
            if not is_last_linear:
                h = relus[k](h)
        # The trailing ReLU is reserved for after the residual sum.
        return relus[-1](h + X)


class _D2V(nn.Module):
    """Dataset2Vec-style encoder mapping each dataset to one 128-dim vector.

    For a dataset ``x`` of shape ``(n, d)`` with targets ``y`` of shape
    ``(n, l)``, every instance ``i`` yields one scalar pair
    ``(x[i, k], y[i, j])`` for each (target ``j``, feature ``k``)
    combination. Each pair is embedded by ``prepool_layers``. The embeddings
    are averaged over the ``n`` instances, refined by ``postpool_layers``,
    averaged over all ``l * d`` (target, feature) combinations, and passed
    through ``prediction_layers``.
    """

    def __init__(self):
        super(_D2V, self).__init__()
        self.prepool_layers = nn.Sequential(
            nn.Linear(2, 128),
            nn.ReLU(),
            _Res(hidden=[128, 128, 128], input_dim=128),
            nn.Linear(128, 128),
            nn.ReLU()
        )

        self.postpool_layers = nn.Sequential(
            nn.Linear(128, 128),
            nn.ReLU(),
            _Res(hidden=[128, 128, 128], input_dim=128),
            nn.Linear(128, 128),
            nn.ReLU()
        )

        self.prediction_layers = nn.Sequential(
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU()
        )

    def forward(self, X, Y):
        """Encode each (features, targets) dataset into one vector.

        Args:
            X: iterable of 2-D feature tensors, one ``(n, d)`` tensor per
                dataset; ``n`` and ``d`` may differ between datasets.
                Presumably floating point, since values feed ``nn.Linear``.
            Y: iterable of matching target tensors of shape ``(n, l)``; a
                1-D ``(n,)`` tensor is treated as a single target column.

        Returns:
            Tensor of shape ``(num_datasets, 128)``, one row per ``(x, y)``
            pair yielded by ``zip(X, Y)``.
        """
        list_feat = []
        for x, y in zip(X, Y):
            if y.dim() == 1:
                y = y.unsqueeze(1)
            n, d = x.size()
            l = y.size(1)
            # pairs[i, j, k] = (x[i, k], y[i, j]), shape (n, l, d, 2).
            pairs = torch.stack(
                [x.unsqueeze(1).expand(n, l, d), y.unsqueeze(2).expand(n, l, d)],
                dim=-1,
            )
            feat = self.prepool_layers(pairs.reshape(n * l * d, 2))
            # The rows above are ordered (instance, target, feature); reshape
            # with exactly that axis order before averaging over instances.
            feat = feat.view(n, l, d, feat.size(-1)).mean(0)

            feat = self.postpool_layers(feat)
            # Average over every (target, feature) combination.
            list_feat.append(feat.mean(dim=(0, 1)))
        list_feat = torch.stack(list_feat, dim=0)
        return self.prediction_layers(list_feat)


class D2V_ours(nn.Module):
    """Dataset2Vec encoder followed by a configurable fully connected head.

    Args:
        fc_metafeatures: sequence of output widths for the head's Linear
            layers, in order; the first layer consumes the 128-dim encoder
            output.
        dropout_fc: dropout probabilities indexed in parallel with
            ``fc_metafeatures`` (must be at least as long).

    Every head layer except the last is followed by BatchNorm, ReLU and
    Dropout. A BatchNorm/Dropout pair is still created for the last layer
    but is never applied in ``forward``.
    """

    def __init__(self, fc_metafeatures, dropout_fc):
        super(D2V_ours, self).__init__()
        # Empty and not used by this class; kept so the attribute exists.
        self.list_module = torch.nn.ModuleList()

        # Dataset2Vec encoder block.
        self.model = _D2V()
        self.list_fc = torch.nn.ModuleList()
        self.list_bn_fc = torch.nn.ModuleList()
        self.list_dropout_fc = torch.nn.ModuleList()

        widths = list(fc_metafeatures)
        fan_ins = [128] + widths[:-1]
        for k, (fan_in, fan_out) in enumerate(zip(fan_ins, widths)):
            self.list_fc.append(nn.Linear(fan_in, fan_out))
            self.list_bn_fc.append(nn.BatchNorm1d(fan_out))
            self.list_dropout_fc.append(nn.Dropout(dropout_fc[k]))

    def forward(self, X, lab):
        """Encode ``(X, lab)`` with the D2V block, then apply the FC head."""
        z = self.model(X, lab)

        last = len(self.list_fc) - 1
        for k, fc in enumerate(self.list_fc):
            z = fc(z)
            if k == last:
                break  # the output layer stays linear
            z = self.list_dropout_fc[k](F.relu(self.list_bn_fc[k](z)))
        return z
