import numpy as np
from torch import nn
from torch.nn import functional as F

from manifold_flow.utils import various


class UNet(nn.Module):
    """Fully connected U-Net built from ``nn.Linear`` layers.

    The network first projects the input to ``max_hidden_features`` units,
    then halves the width with each of ``num_layers`` "down" layers, applies a
    width-preserving middle layer, and finally doubles the width again with
    ``num_layers`` "up" layers. Before every up layer, the activation produced
    by the down layer of the same width is added as a skip connection.

    Args:
        in_features: Size of the input's last dimension.
        max_hidden_features: Width of the widest hidden layer; must be a
            power of two.
        num_layers: Number of down layers (and, equally, of up layers). The
            narrowest layer must keep more than one unit.
        out_features: Size of the output's last dimension.
        nonlinearity: Callable applied to hidden activations.
    """

    def __init__(self, in_features, max_hidden_features, num_layers, out_features, nonlinearity=F.relu):
        super().__init__()

        assert various.is_power_of_two(max_hidden_features), "'max_hidden_features' must be a power of two."
        assert max_hidden_features // 2 ** num_layers > 1, "'num_layers' must be {} or fewer".format(int(np.log2(max_hidden_features) - 1))

        self.nonlinearity = nonlinearity
        self.num_layers = num_layers

        # Hidden widths from widest to narrowest:
        # widths[i] == max_hidden_features // 2 ** i for i in 0..num_layers.
        widths = [max_hidden_features // 2 ** i for i in range(num_layers + 1)]

        self.initial_layer = nn.Linear(in_features, max_hidden_features)

        # Each down layer halves the width: widths[i] -> widths[i + 1].
        self.down_layers = nn.ModuleList(
            [nn.Linear(in_features=wide, out_features=narrow) for wide, narrow in zip(widths[:-1], widths[1:])]
        )

        # The middle layer keeps the narrowest width.
        self.middle_layer = nn.Linear(in_features=widths[-1], out_features=widths[-1])

        # Up layers mirror the down layers in reverse order: widths[i + 1] -> widths[i].
        self.up_layers = nn.ModuleList(
            [nn.Linear(in_features=widths[i + 1], out_features=widths[i]) for i in range(num_layers - 1, -1, -1)]
        )

        self.final_layer = nn.Linear(max_hidden_features, out_features)

    def forward(self, inputs):
        """Run the U-Net on ``inputs`` and return the output of the final layer.

        Args:
            inputs: Tensor accepted by ``self.initial_layer`` (an ``nn.Linear``
                with ``in_features`` inputs).

        Returns:
            The result of ``self.final_layer`` applied to the last up-layer
            output; no nonlinearity is applied after it.
        """
        hidden = self.nonlinearity(self.initial_layer(inputs))

        # Activations after each down layer, kept for the skip connections.
        skips = []
        for down in self.down_layers:
            hidden = self.nonlinearity(down(hidden))
            skips.append(hidden)

        hidden = self.nonlinearity(self.middle_layer(hidden))

        # Walk the skips from narrowest to widest so each one matches the
        # width of the tensor entering the corresponding up layer.
        for up, skip in zip(self.up_layers, reversed(skips)):
            # Out-of-place addition on purpose: `hidden` may be the direct
            # output of the nonlinearity, which autograd can keep for the
            # backward pass; modifying it in place would invalidate it.
            hidden = hidden + skip
            hidden = self.nonlinearity(hidden)
            hidden = up(hidden)

        return self.final_layer(hidden)
