import torch.nn as nn
import torch

# Length of the latent noise vector z consumed by Generator.
nz: int = 100
# Base channel width of the generator; its layers use multiples of this.
ngf: int = 64
# Base channel width of the discriminator; its layers use multiples of this.
ndf: int = 64


class Generator(nn.Module):
    """DCGAN-style generator turning a noise vector into an image.

    The noise is reshaped to ``(-1, nz, 1, 1)`` and up-sampled by transposed
    convolutions:

        nz x 1 x 1 -> (ngf*4) x 4 x 4 -> (ngf*2) x 8 x 8 -> (ngf) x 16 x 16

    The head depends on ``nc``:

    * ``nc == 1``: up-sample to (ngf) x 32 x 32. A kernel-1 transposed conv
      with ``padding=2`` then crops 2 pixels per side, giving 1 x 28 x 28.
    * otherwise: a single stride-2 transposed conv to (nc) x 32 x 32.

    Both heads end with ``Tanh``, so outputs lie in (-1, 1).

    Args:
        nc: Number of output image channels.
    """

    def __init__(self, nc=1):
        super().__init__()

        self.gan_type = "dcgan"

        def up(in_ch, out_ch, kernel, stride, padding):
            # Transposed conv -> BatchNorm -> ReLU; the conv has no bias
            # because BatchNorm follows it.
            return [
                nn.ConvTranspose2d(in_ch, out_ch, kernel, stride, padding, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]

        self.main = nn.Sequential(
            *up(nz, ngf * 4, 4, 1, 0),       # -> (ngf*4) x 4 x 4
            *up(ngf * 4, ngf * 2, 4, 2, 1),  # -> (ngf*2) x 8 x 8
            *up(ngf * 2, ngf, 4, 2, 1),      # -> (ngf) x 16 x 16
        )

        if nc != 1:
            # 16x16 -> 32x32, straight to the image channels.
            head = [nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False)]
        else:
            # 16x16 -> 32x32, then a kernel-1 transposed conv whose padding=2
            # trims the map to 28x28.
            head = up(ngf, ngf, 4, 2, 1) + [
                nn.ConvTranspose2d(ngf, nc, kernel_size=1, stride=1, padding=2, bias=False),
            ]
        self.layer = nn.Sequential(*head, nn.Tanh())

    def forward(self, input):
        # Any tensor whose element count is a multiple of nz is accepted;
        # it is treated as a batch of nz-dim noise vectors.
        z = input.view(-1, nz, 1, 1)
        return self.layer(self.main(z))


class Discriminator(nn.Module):
    """DCGAN-style discriminator producing a flat tensor of scores.

    Three stride-2 convolutions halve the spatial size each time. A final
    ``Conv2d`` reduces the channels to 1, and ``forward`` flattens the
    result to 1-D.

    Args:
        act: Name of an optional activation ("sigmoid", "tanh", "softplus",
            "relu"). It is stored as ``self.activation``; any other value,
            including the default ``False``, stores ``None``.
        eps, norm, clip: Stored verbatim as attributes.
        bn: If truthy, ``self.bn`` is a ``BatchNorm2d(1)``, else ``None``.
        nc: Number of input channels. It also selects the final conv (see
            ``__init__``).

    NOTE(review): ``activation``, ``bn``, ``eps``, ``norm`` and ``clip`` are
    not used in ``forward``. Presumably external training code reads them;
    verify against callers.
    """

    def __init__(
        self,
        act=False,
        eps=1,
        norm=False,
        bn=False,
        clip=False,
        nc=3,
    ):
        super().__init__()

        self.gan_type = "dcgan"

        # Ordered name -> constructor table; the first name equal to `act`
        # wins, and no match leaves the activation unset (None).
        activation_choices = (
            ("sigmoid", nn.Sigmoid),
            ("tanh", nn.Tanh),
            ("softplus", nn.Softplus),
            ("relu", nn.ReLU),
        )
        chosen = None
        for name, factory in activation_choices:
            if act == name:
                chosen = factory()
                break
        self.activation = chosen

        self.bn = nn.BatchNorm2d(1) if bn else None

        self.act = act
        self.eps = eps
        self.norm = norm
        self.clip = clip

        def down(in_ch, out_ch, batch_norm=True):
            # Stride-2 4x4 conv (no bias), optional BatchNorm, LeakyReLU(0.2).
            block = [nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False)]
            if batch_norm:
                block.append(nn.BatchNorm2d(out_ch))
            block.append(nn.LeakyReLU(0.2, inplace=True))
            return block

        self.main = nn.Sequential(
            *down(nc, ndf, batch_norm=False),  # 32x32 input -> (ndf) x 16 x 16
            *down(ndf, ndf * 2),               # -> (ndf*2) x 8 x 8
            *down(ndf * 2, ndf * 4),           # -> (ndf*4) x 4 x 4
            # NOTE(review): this Sigmoid squashes hidden features right after
            # a LeakyReLU, before the final conv. That is unusual for DCGAN and
            # may be a misplaced output activation; kept as-is to preserve
            # behaviour and trained checkpoints.
            nn.Sigmoid(),
        )

        # Final 4x4 conv down to one channel.
        # - nc == 3: valid conv (stride 1, no padding); 4x4 map -> 1x1 for a
        #   32x32 input.
        # - otherwise: stride 2, padding 1; a 28x28 input gives a 3x3 map
        #   here -> 1x1.
        # Assumes non-3-channel inputs are 28x28 (TODO confirm). A 32x32 input
        # in this branch would yield a 2x2 map, i.e. 4 scores per sample.
        stride, padding = (1, 0) if nc == 3 else (2, 1)
        self.layer = nn.Conv2d(ndf * 4, 1, 4, stride, padding, bias=False)

    def forward(self, input):
        scores = self.layer(self.main(input))
        # Flatten to 1-D; for the expected input sizes this is one raw
        # (un-activated) score per sample.
        return scores.view(-1)
