from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F


def make_n_orderd_dict(n, input_size, hidden_size, output_size):
    """Build an ordered mapping of ``n`` Linear layers separated by ReLUs.

    The returned OrderedDict can be passed to ``nn.Sequential``. Linear layers
    are named ``fc0 .. fc{n-1}`` and activations ``relu0 .. relu{n-2}``; no
    activation follows the final Linear layer.

    Args:
        n: number of Linear layers; must be >= 1.
        input_size: in_features of the first Linear layer.
        hidden_size: width of every intermediate layer (unused when n == 1).
        output_size: out_features of the last Linear layer.

    Returns:
        OrderedDict mapping layer names to modules, in forward order.

    Raises:
        ValueError: if ``n < 1`` (there would be no layer mapping
            ``input_size`` to ``output_size``).
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")
    # Feature widths at each layer boundary: input -> hidden * (n - 1) -> output.
    sizes = [input_size] + [hidden_size] * (n - 1) + [output_size]
    layers = []
    for i in range(n):
        layers.append((f"fc{i}", nn.Linear(sizes[i], sizes[i + 1])))
        # Activations only between Linear layers, never after the last one.
        if i < n - 1:
            layers.append((f"relu{i}", nn.ReLU()))
    return OrderedDict(layers)


class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one.

    A tensor of shape ``(B, d1, d2, ...)`` comes out as ``(B, d1 * d2 * ...)``.
    """

    def forward(self, x):
        batch_size = x.shape[0]
        return x.reshape(batch_size, -1)


class Conv2dSame(torch.nn.Module):
    """Stride-1 2-D convolution whose output keeps the input's height and width.

    Each spatial axis is padded by ``kernel_size - 1`` pixels in total before a
    plain ``Conv2d``; for even kernels the extra pixel goes on the left/top.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the convolution.
        kernel_size: square kernel size, as a single int.
        bias: whether the convolution has a bias term.
        padding_layer: padding module class called with a
            ``(left, right, top, bottom)`` tuple; reflection padding by default.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        bias=True,
        padding_layer=nn.ReflectionPad2d,
    ):
        super().__init__()
        # Split kernel_size - 1 pixels of padding; the leading side takes the
        # larger half when the total is odd (i.e. for even kernels).
        lead = kernel_size // 2
        trail = (kernel_size - 1) // 2
        pad = padding_layer((lead, trail, lead, trail))
        conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size, bias=bias)
        # Padding at index 0, conv at index 1, so conv parameters live under net.1.*.
        self.net = torch.nn.Sequential(pad, conv)

    def forward(self, x):
        return self.net(x)


class ResidualBlock(nn.Module):
    """Two 3x3 same-padded convolutions with a skip connection.

    Computes ``relu(conv2(relu(conv1(x))) + shortcut(x))``. Spatial size is
    preserved by ``Conv2dSame``. The skip path is the identity when
    ``in_channels == out_channels`` and a bias-free 1x1 convolution otherwise,
    so the sum is always shape-compatible.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by the block.
    """

    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            Conv2dSame(in_channels, out_channels, 3),
            nn.ReLU(),
            # This conv consumes the first conv's output, so its input width
            # is out_channels, not in_channels.
            Conv2dSame(out_channels, out_channels, 3),
        )
        # nn.Identity has no parameters, so when the channel counts match the
        # module's state_dict keys are just those of `block`.
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)

    def forward(self, x):
        out = self.block(x)
        out = out + self.shortcut(x)
        return F.relu(out)
