# -*- coding: utf-8 -*-

import math
import torch.nn as nn
import torch.nn.functional as F

"""
Autoencoder (AE) models: a fully connected AE and a convolutional AE.
The layer configuration of each AE is selected through fields of `args`
(e.g. `activation`, `data`, `dropout`).
"""


class Autoencoder(nn.Module):
    """Fully connected autoencoder for fixed-size inputs.

    Each sample is flattened to ``size[0] * size[1] * size[2]`` features,
    encoded by an MLP, decoded by the mirrored MLP and reshaped back to
    ``(size[0], size[1], size[2])``.

    The layer widths depend on ``args.activation``:

    * ``'tanh'``, ``'sigmoid'`` or ``'leaky_relu'``:
      ``n_units -> 128 -> 64 -> 12 -> 3`` and back.
    * anything else (in-place ReLU):
      ``n_units -> 128 -> 64 -> 32 -> 16 -> args.num_classes`` and back.

    The activation follows every linear layer except the encoder's last one,
    so the code is unbounded while the reconstruction passes through the
    activation.

    Args:
        args: Configuration namespace; reads ``activation`` and, for the ReLU
            default only, ``num_classes``.
        size: Per-sample input shape; only its first three entries are used.
        init_scale: Multiplier on the standard deviation of the Xavier-normal
            initialisation of the linear weights.
    """

    def __init__(self, args, size, init_scale=1.0):
        super(Autoencoder, self).__init__()
        self.size = size
        n_units = size[0] * size[1] * size[2]
        self.n_units = n_units

        if args.activation in ('tanh', 'sigmoid', 'leaky_relu'):
            hidden, code_dim = [128, 64, 12], 3
        else:
            # num_classes is only required for the ReLU default.
            hidden, code_dim = [128, 64, 32, 16], args.num_classes
        widths = [n_units] + hidden + [code_dim]

        self.encoder = self._mlp(
            widths, args.activation, final_activation=False)
        self.decoder = self._mlp(
            widths[::-1], args.activation, final_activation=True)
        self._init_weights(init_scale)

    @staticmethod
    def _make_activation(name):
        """Return a new activation module for ``name`` (in-place ReLU by default)."""
        if name == 'tanh':
            return nn.Tanh()
        if name == 'sigmoid':
            return nn.Sigmoid()
        if name == 'leaky_relu':
            return nn.LeakyReLU(inplace=True)
        return nn.ReLU(inplace=True)

    @classmethod
    def _mlp(cls, widths, activation, final_activation):
        """Stack ``nn.Linear`` layers between consecutive entries of ``widths``.

        A fresh activation follows every linear layer except the last; the
        last one is followed by an activation only if ``final_activation``.
        The resulting ``nn.Sequential`` indices alternate Linear/activation.
        """
        pairs = list(zip(widths[:-1], widths[1:]))
        layers = []
        for i, (n_in, n_out) in enumerate(pairs, start=1):
            layers.append(nn.Linear(n_in, n_out))
            if i < len(pairs) or final_activation:
                layers.append(cls._make_activation(activation))
        return nn.Sequential(*layers)

    def _init_weights(self, init_scale):
        """Zero linear biases; draw weights from a scaled Xavier normal."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.bias.data.zero_()
                fan_out, fan_in = m.weight.size()
                std = math.sqrt(2.0 / (fan_in + fan_out))
                m.weight.data.normal_(0.0, init_scale * std)

    def forward(self, x):
        """Encode and reconstruct ``x``.

        Args:
            x: Tensor whose total element count is a multiple of
                ``self.n_units``; it is split into rows of ``n_units``.

        Returns:
            Tensor of shape ``(-1, size[0], size[1], size[2])``.
        """
        # reshape (not view) so non-contiguous inputs, e.g. transposed or
        # sliced batches, are accepted.
        flat = x.reshape(-1, self.n_units)
        code = self.encoder(flat)
        out = self.decoder(code)
        return out.reshape(-1, self.size[0], self.size[1], self.size[2])


class ConvAutoencoder(nn.Module):
    """Convolutional autoencoder with a small fully connected bottleneck.

    ``encoder_features`` reduces each image to a ``[c, 1, 1]`` feature map,
    where ``c = 10`` when ``args.data == 'MNIST'`` and ``c = 5`` otherwise.
    The map is flattened per sample and ``encoder_classifier`` maps the ``c``
    features to ``args.num_classes`` units; ``decoder_classifier`` and
    ``decoder_features`` mirror that path back to an image with
    ``args.input_dim`` channels.

    The pooling/upsampling factors assume 28x28 inputs for MNIST and 32x32
    inputs otherwise.

    Args:
        args: Configuration namespace. Reads ``data``, ``input_dim``,
            ``activation`` and ``num_classes``; ``dropout`` is read only for
            the ReLU default (any ``activation`` other than ``'tanh'``,
            ``'sigmoid'`` or ``'leaky_relu'``).
        init_scale: Multiplier on the standard deviation of the weight
            initialisation.
    """

    def __init__(self, args, init_scale=1.0):
        super(ConvAutoencoder, self).__init__()
        if args.data == 'MNIST':
            # [N, C, 28, 28] -> pools of 2, 2, 7 -> [N, 10, 1, 1]
            bottleneck = 10
            self.encoder_features = nn.Sequential(
                nn.Conv2d(args.input_dim, 40, kernel_size=3, padding=1),
                nn.MaxPool2d(kernel_size=2),
                nn.Conv2d(40, 20, kernel_size=3, padding=1),
                nn.MaxPool2d(kernel_size=2),
                nn.Conv2d(20, bottleneck, kernel_size=3, padding=1),
                nn.MaxPool2d(kernel_size=7),
            )
            self.decoder_features = nn.Sequential(
                nn.UpsamplingBilinear2d(scale_factor=2),
                nn.Conv2d(bottleneck, 20, kernel_size=3, padding=1),
                nn.UpsamplingBilinear2d(scale_factor=2),
                nn.Conv2d(20, 40, kernel_size=3, padding=1),
                nn.UpsamplingBilinear2d(scale_factor=7),
                nn.Conv2d(40, args.input_dim, kernel_size=3, padding=1),
            )
        else:
            # [N, C, 32, 32] -> pools of 4, 2, 4 -> [N, 5, 1, 1]
            bottleneck = 5
            self.encoder_features = nn.Sequential(
                nn.Conv2d(args.input_dim, 20, kernel_size=3, padding=1),
                nn.MaxPool2d(kernel_size=4),
                nn.Conv2d(20, 10, kernel_size=3, padding=1),
                nn.MaxPool2d(kernel_size=2),
                nn.Conv2d(10, bottleneck, kernel_size=3, padding=1),
                nn.MaxPool2d(kernel_size=4),
            )
            self.decoder_features = nn.Sequential(
                nn.UpsamplingBilinear2d(scale_factor=4),
                nn.Conv2d(bottleneck, 10, kernel_size=3, padding=1),
                nn.UpsamplingBilinear2d(scale_factor=2),
                nn.Conv2d(10, 20, kernel_size=3, padding=1),
                nn.UpsamplingBilinear2d(scale_factor=4),
                nn.Conv2d(20, args.input_dim, kernel_size=3, padding=1),
            )

        n_classes = args.num_classes

        def act():
            # New activation module for args.activation (ReLU by default).
            return self._make_activation(args.activation)

        # The MLP's input width and the decoder MLP's output width must equal
        # the bottleneck channel count, because the flattened [c, 1, 1]
        # feature map is fed in and the decoder output is reshaped back to it.
        if args.activation in ('tanh', 'sigmoid', 'leaky_relu'):
            self.encoder_classifier = nn.Sequential(
                nn.Linear(bottleneck, 8), act(),
                nn.Linear(8, 16), act(),
                nn.Linear(16, n_classes),
            )
            self.decoder_classifier = nn.Sequential(
                nn.Linear(n_classes, 16), act(),
                nn.Linear(16, 8), act(),
                nn.Linear(8, bottleneck), act(),
            )
        else:
            if args.dropout == 'dropout':
                # Dropout after both hidden layers.
                self.encoder_classifier = nn.Sequential(
                    nn.Linear(bottleneck, 8), act(), nn.Dropout(),
                    nn.Linear(8, 16), act(), nn.Dropout(),
                    nn.Linear(16, n_classes),
                )
            elif args.dropout == 'dropout_1':
                # Dropout after the second hidden layer only.
                self.encoder_classifier = nn.Sequential(
                    nn.Linear(bottleneck, 8), act(),
                    nn.Linear(8, 16), act(), nn.Dropout(),
                    nn.Linear(16, n_classes),
                )
            else:
                self.encoder_classifier = nn.Sequential(
                    nn.Linear(bottleneck, 10), act(),
                    nn.Linear(10, 20), act(),
                    nn.Linear(20, n_classes),
                )
            self.decoder_classifier = nn.Sequential(
                nn.Linear(n_classes, 20), act(),
                nn.Linear(20, 10), act(),
                nn.Linear(10, bottleneck), act(),
            )

        self._init_weights(init_scale)

    @staticmethod
    def _make_activation(name):
        """Return a new activation module for ``name`` (in-place ReLU by default)."""
        if name == 'tanh':
            return nn.Tanh()
        if name == 'sigmoid':
            return nn.Sigmoid()
        if name == 'leaky_relu':
            return nn.LeakyReLU(inplace=True)
        return nn.ReLU(inplace=True)

    def _init_weights(self, init_scale):
        """Re-initialise conv and linear weights, scaled by ``init_scale``."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He-normal with fan computed from out_channels; conv biases
                # keep PyTorch's default initialisation.
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, init_scale * math.sqrt(2. / n))
            elif isinstance(m, nn.Linear):
                # Xavier-normal weights, zero biases.
                m.bias.data.zero_()
                fan_out, fan_in = m.weight.size()
                std = math.sqrt(2.0 / (fan_in + fan_out))
                m.weight.data.normal_(0.0, init_scale * std)

    def forward(self, x):
        """Encode and reconstruct an image batch.

        Args:
            x: Batch ``[N, args.input_dim, H, W]``; the layer arithmetic
                expects ``H = W = 28`` for MNIST and ``32`` otherwise.

        Returns:
            Reconstruction with the same shape as ``x`` for those sizes.
        """
        features = self.encoder_features(x)
        # Flatten per sample so every row holds exactly one sample's features.
        code = self.encoder_classifier(features.flatten(1))
        decoded = self.decoder_classifier(code)
        # Restore the [N, c, 1, 1] layout the encoder produced.
        return self.decoder_features(decoded.view_as(features))
