from torch import nn
import learn2learn as l2l


def one_layer_net(n_in=16, n_out=5):
    """Return a single affine layer ``n_in -> n_out`` wrapped in ``nn.Sequential``."""
    layers = [nn.Linear(n_in, n_out)]
    return nn.Sequential(*layers)


def two_layer_net(n_in=16, n_out=5):
    """Return ``Linear(n_in, 16) -> BatchNorm1d -> Sigmoid -> Linear(16, n_out)``."""
    width = 16
    # Modules are instantiated in sequence order so parameter init consumes
    # the RNG in the same order as a literal nn.Sequential(...) call would.
    first = nn.Linear(n_in, width)
    norm = nn.BatchNorm1d(width)
    act = nn.Sigmoid()
    head = nn.Linear(width, n_out)
    return nn.Sequential(first, norm, act, head)


def three_layer_net(n_in=16, n_out=5):
    """Return an MLP with two hidden layers of width 16.

    Layout: ``[Linear -> BatchNorm1d -> Sigmoid] x 2 -> Linear(16, n_out)``.
    """
    width = 16
    layers = []
    # Each hidden stage is Linear followed by BatchNorm over its output width.
    for d_in, d_out in ((n_in, width), (width, width)):
        layers += [nn.Linear(d_in, d_out), nn.BatchNorm1d(d_out), nn.Sigmoid()]
    layers.append(nn.Linear(width, n_out))
    return nn.Sequential(*layers)


def simple_layer_net(n_in=16, n_out=5, hidden=(16,), activation=None):
    """Build an MLP ``Linear -> [BatchNorm1d -> activation -> Linear]*``.

    The first layer maps ``n_in`` to ``hidden[0]``. Each later hidden width,
    and finally ``n_out``, is reached via BatchNorm1d over the previous width,
    then the activation, then a Linear layer.

    Args:
        n_in: number of input features.
        n_out: number of output features.
        hidden: iterable of hidden-layer widths. When empty, the result is a
            single linear layer (``one_layer_net(n_in, n_out)``).
        activation: ``None`` or ``'sigmoid'`` (Sigmoid), ``'relu'`` (ReLU)
            or ``'lrelu'`` (LeakyReLU).

    Returns:
        nn.Sequential: the assembled network.

    Raises:
        ValueError: if ``activation`` is not one of the accepted values.
    """
    activations = {
        None: nn.Sigmoid,
        'sigmoid': nn.Sigmoid,
        'relu': nn.ReLU,
        'lrelu': nn.LeakyReLU,
    }
    try:
        a_func = activations[activation]
    except (KeyError, TypeError):
        # Explicit raise instead of `assert`, which is stripped under -O and
        # would otherwise silently fall through to LeakyReLU.
        raise ValueError(
            f"unknown activation {activation!r}; "
            "expected None, 'sigmoid', 'relu' or 'lrelu'"
        ) from None

    widths = list(hidden)
    if not widths:
        return one_layer_net(n_in, n_out)

    layers = [nn.Linear(n_in, widths[0])]
    # Pair each width with the next one (the output size closes the chain).
    for w_in, w_out in zip(widths, widths[1:] + [n_out]):
        layers += [nn.BatchNorm1d(w_in), a_func(), nn.Linear(w_in, w_out)]
    return nn.Sequential(*layers)


class ConvFeature(nn.Module):
    """Convolutional feature extractor built on learn2learn's CNN4Backbone.

    The encoder's output is flattened to ``(batch, -1)`` in ``forward``.

    Args:
        x_dim: number of input image channels (``channels`` of the backbone).
        hid_dim: backbone width (``hidden_size`` of the backbone).
        z_dim: unused; kept so existing callers' signatures keep working.

    Attributes:
        encoder: the ``l2l.vision.models.CNN4Backbone`` instance.
        out_channels: size of the flattened feature vector.
    """

    def __init__(self, x_dim=3, hid_dim=64, z_dim=64):
        super().__init__()
        self.encoder = l2l.vision.models.CNN4Backbone(
            hidden_size=hid_dim,
            channels=x_dim,
            max_pool=True,
        )
        # The previous hard-coded 1600 equals 64 * 5 * 5, i.e. a 5x5 final
        # spatial map with hid_dim=64 channels. Scaling by hid_dim keeps this
        # consistent for other widths (identical for the default).
        # NOTE(review): assumes the backbone's last block outputs hid_dim
        # channels and that inputs shrink to 5x5 (e.g. 84x84 images) — confirm.
        self.out_channels = hid_dim * 5 * 5

    def forward(self, x):
        """Encode ``x`` and flatten every non-batch dimension."""
        x = self.encoder(x)
        return x.view(x.size(0), -1)

