import torch
import torch.nn as nn
from copy import deepcopy


class GeneratorFemnist(nn.Module):
    '''
    The generator for Femnist dataset
    '''
    def __init__(self, noise_dim=100):
        super(GeneratorFemnist, self).__init__()

        module_list = []
        module_list.append(
            nn.Linear(in_features=noise_dim,
                      out_features=4 * 4 * 256,
                      bias=False))
        module_list.append(nn.BatchNorm1d(num_features=4 * 4 * 256))
        module_list.append(nn.ReLU())
        self.body1 = nn.Sequential(*module_list)

        # need to reshape the output of self.body1

        module_list = []

        module_list.append(
            nn.ConvTranspose2d(in_channels=256,
                               out_channels=128,
                               kernel_size=(3, 3),
                               stride=(1, 1),
                               bias=False))
        module_list.append(nn.BatchNorm2d(128))
        module_list.append(nn.ReLU())
        self.body2 = nn.Sequential(*module_list)

        module_list = []
        module_list.append(
            nn.ConvTranspose2d(in_channels=128,
                               out_channels=64,
                               kernel_size=(3, 3),
                               stride=(2, 2),
                               bias=False))
        module_list.append(nn.BatchNorm2d(64))
        module_list.append(nn.ReLU())
        self.body3 = nn.Sequential(*module_list)

        module_list = []
        module_list.append(
            nn.ConvTranspose2d(in_channels=64,
                               out_channels=1,
                               kernel_size=(4, 4),
                               stride=(2, 2),
                               bias=False))
        module_list.append(nn.BatchNorm2d(1))
        module_list.append(nn.Tanh())
        self.body4 = nn.Sequential(*module_list)

    def forward(self, x):

        tmp1 = self.body1(x).view(-1, 256, 4, 4)

        assert tmp1.size()[1:4] == (256, 4, 4)

        tmp2 = self.body2(tmp1)
        assert tmp2.size()[1:4] == (128, 6, 6)

        tmp3 = self.body3(tmp2)

        assert tmp3.size()[1:4] == (64, 13, 13)

        tmp4 = self.body4(tmp3)
        assert tmp4.size()[1:4] == (1, 28, 28)

        return tmp4
