# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Passed as ``align_corners`` to the bilinear F.interpolate calls below.
align_corners: bool = True
# Width of an optional fully connected latent bottleneck; the layers that
# would use it are currently commented out in Encoder and Decoder.
N_latent: int = 256

class Encoder(nn.Module):
    """Convolutional encoder that halves the spatial size ``compress_mode`` times.

    Each stage is ``Conv2d(kernel_size=4, stride=2, padding=1) -> ReLU ->
    BatchNorm2d``. The first stage multiplies the channel count by 4 for
    3-channel inputs (by 2 otherwise); every further stage doubles it. A
    final standalone ``BatchNorm2d`` is applied before the output sigmoid.

    Args:
        resolution: Input shape ``(channels, height, width)``. Only
            ``resolution[0]`` (channels) and ``resolution[1]`` (used as the
            square spatial size when ``resize_dim`` is not given) are read.
        compress_mode: Number of stride-2 downsampling stages. Values below
            1 still build one stage.
        resize_dim: If truthy, inputs are bilinearly resized to
            ``(resize_dim, resize_dim)`` before encoding.

    Attributes:
        prior_channels: Number of input channels (``resolution[0]``).
        representation: Number of elements in one encoded sample, i.e.
            ``channels * height * width`` of the encoder output.
    """

    def __init__(self, resolution, compress_mode=1, resize_dim=None):
        super().__init__()

        self.resize_dim = resize_dim
        self.prior_channels = resolution[0]
        channels = self.prior_channels
        # At least one downsampling stage is always built.
        n_stages = max(1, int(compress_mode))

        # With kernel_size=4, stride=2, padding=1 each conv maps a spatial
        # size s to floor((s + 2*1 - 4) / 2) + 1 == s // 2.
        layers = []
        for stage in range(n_stages):
            if stage == 0:
                multiply = 4 if channels == 3 else 2
            else:
                multiply = 2
            out_channels = channels * multiply
            layers.append(nn.Sequential(
                nn.Conv2d(in_channels=channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1),
                nn.ReLU(),
                nn.BatchNorm2d(out_channels),
            ))
            channels = out_channels

        # Extra normalisation of the final feature map. It is kept as its own
        # Sequential so module indices (and thus state_dict keys) are unchanged.
        layers.append(nn.Sequential(
            nn.BatchNorm2d(channels)
        ))

        self.encoder = nn.Sequential(*layers)

        # Spatial size of the encoder output, computed with the same integer
        # arithmetic as the convolutions (float division overstated it when
        # the input size was not divisible by 2 ** compress_mode).
        dim = resize_dim if resize_dim else resolution[1]
        for _ in range(n_stages):
            dim //= 2
        self.representation = int(channels * dim * dim)
        # Disabled fully connected bottleneck:
        # self.fc = nn.Linear(self.representation, N_latent)

    def forward(self, x):
        """Encode ``x`` and squash the result into (0, 1) with a sigmoid.

        ``x`` must be a 4-D ``(batch, channels, height, width)`` tensor, as
        required by the BatchNorm2d layers. If ``resize_dim`` is set, it is
        first bilinearly resized to ``(resize_dim, resize_dim)``.
        """
        if self.resize_dim:
            x = F.interpolate(x, mode='bilinear', size=(self.resize_dim, self.resize_dim), align_corners=align_corners)
        x = self.encoder(x)
        # Disabled fully connected bottleneck:
        # x = x.view(-1, self.representation)
        # x = self.fc(x)

        return torch.sigmoid(x)


class Decoder(nn.Module):
    """Transposed-convolution decoder that doubles the spatial size ``compress_mode`` times.

    The input channel count is ``out_channels * 4 * 2**(compress_mode - 1)``
    for 3-channel outputs (``* 2`` instead of ``* 4`` otherwise), intended to
    match the Encoder's output channels for the same ``compress_mode``. Each
    intermediate stage is ``ConvTranspose2d -> ReLU -> BatchNorm2d`` and
    halves the channels; the last stage maps to ``out_channels``.

    Args:
        resolution: Output shape ``(channels, height, width)``.
            ``resolution[0]`` sets the output channels; ``resolution[1]`` is
            only used for ``dim_compress`` when ``original_dim`` is falsy.
        compress_mode: Number of stride-2 upsampling stages.
        original_dim: If truthy, the decoded map is bilinearly resized to
            ``(original_dim, original_dim)`` before the sigmoid.
        encoder_resize_dim: Square spatial size assumed for the encoder input
            when ``original_dim`` is set; only used to compute
            ``dim_compress``. Defaults to 128, the previously hard-coded value.

    Attributes:
        prior_channels: Expected number of input channels.
        dim_compress: Expected spatial size of the latent input (only needed
            by the disabled fully connected bottleneck).
    """

    def __init__(self, resolution, compress_mode=1, original_dim=224, encoder_resize_dim=128):
        super().__init__()

        out_channels = resolution[0]
        base = 4 if out_channels == 3 else 2
        prior_channels = int(out_channels * base * (2 ** (compress_mode - 1)))

        # Note: despite the name, this is the *output* resize target.
        self.resize_dim = original_dim
        self.prior_channels = prior_channels

        img_dim = encoder_resize_dim if original_dim else resolution[1]
        self.dim_compress = int(img_dim / (2 ** compress_mode))

        # Disabled fully connected bottleneck:
        # self.fc = nn.Linear(N_latent, prior_channels * self.dim_compress ** 2)

        # One channel-halving stage per extra level of compression. Each
        # ConvTranspose2d(kernel_size=4, stride=2, padding=1) maps a spatial
        # size s to (s - 1) * 2 - 2 * 1 + 4 == 2 * s.
        layers = []
        channels = prior_channels
        for _ in range(int(compress_mode) - 1):
            layers.append(nn.Sequential(
                nn.ConvTranspose2d(channels, channels // 2, kernel_size=4, stride=2, padding=1),
                nn.ReLU(),
                nn.BatchNorm2d(channels // 2),
            ))
            channels //= 2

        layers.append(nn.Sequential(
            nn.ConvTranspose2d(channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
        ))

        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        """Decode ``x``, optionally resize it, and squash it into (0, 1).

        ``x`` must be a 4-D ``(batch, prior_channels, h, w)`` tensor, as
        required by the ConvTranspose2d/BatchNorm2d layers.
        """
        # Disabled fully connected bottleneck:
        # x = self.fc(x)
        # x = x.view(-1, self.prior_channels, self.dim_compress, self.dim_compress)
        x = self.decoder(x)

        if self.resize_dim:
            x = F.interpolate(x, mode='bilinear', size=(self.resize_dim, self.resize_dim), align_corners=align_corners)

        return torch.sigmoid(x)
