import os, sys
# Append the (normalised) current working directory to sys.path, presumably
# so that ``src.*`` modules loaded via importlib in this file can be resolved
# when running from the repository root.
CODE_HOME = os.path.normpath(os.getcwd())
sys.path.append(CODE_HOME)

import torch
import torch.nn as nn

from ..model_abstract import *
# from ..util import init_params
from .block import prob_block

class VFAE_Adult(VFAE_abstract):
    """Fully connected VFAE variant (Adult configuration).

    Input width comes from ``cfg['train_info']['x_dim']``. ``y_dim``, ``s_dim``
    and ``z_dim`` are not set here and presumably come from ``VFAE_abstract``.

    Sub-networks are created in the order embed_data, encoder, decoder_z1,
    decoder, disc. Keep that order: it fixes parameter-initialisation RNG
    consumption and the state-dict keys used when loading checkpoints.
    """

    def __init__(self, cfg, log, verbose = 1):
        super().__init__(cfg, log, verbose)
        self.x_dim = cfg['train_info']['x_dim']
        hidden = 100  # width of every fully connected hidden layer

        # x_dim -> 100 features
        self.embed_data = nn.Sequential(
            nn.Linear(self.x_dim, hidden), nn.BatchNorm1d(hidden), nn.ReLU(True),
        )

        # (100 + y_dim + s_dim) -> z_dim via prob_block (defined in .block)
        self.encoder = nn.Sequential(
            prob_block(hidden + self.y_dim + self.s_dim, self.z_dim),
        )

        # (z_dim + y_dim) -> 100 -> z_dim via prob_block
        self.decoder_z1 = nn.Sequential(
            nn.Linear(self.z_dim + self.y_dim, hidden), nn.BatchNorm1d(hidden), nn.ReLU(True),
            prob_block(hidden, self.z_dim),
        )

        # (z_dim + s_dim) -> 100 -> x_dim, squashed into (0, 1) by Sigmoid
        self.decoder = nn.Sequential(
            nn.Linear(self.z_dim + self.s_dim, hidden), nn.BatchNorm1d(hidden), nn.ReLU(True),
            nn.Linear(hidden, self.x_dim), nn.Sigmoid(),
        )

        # z_dim -> single unbounded score; four Linear+ReLU stages of width 4*z_dim
        width = 4 * self.z_dim
        disc_layers = [nn.Linear(self.z_dim, width), nn.ReLU(True)]
        for _ in range(3):
            disc_layers += [nn.Linear(width, width), nn.ReLU(True)]
        disc_layers.append(nn.Linear(width, 1))
        self.disc = nn.Sequential(*disc_layers)

        # Module groups exposed to the (base-class) optimisation logic
        self.encoder_trainable = [self.encoder, self.embed_data]
        self.decoder_trainable = [self.decoder_z1, self.decoder]
        self.disc_trainable = [self.disc]

class VFAE_Health(VFAE_abstract):
    """Fully connected VFAE variant (Health configuration).

    The architecture is the same shape as the Adult variant: a 100-unit
    embedding of x, a prob_block encoder over [embedding, y, s], a
    prob_block latent decoder over [z, y], a sigmoid reconstruction decoder
    over [z, s], and an MLP discriminator on z.

    Construction order is preserved exactly (it determines init RNG use and
    state-dict keys).
    """

    def __init__(self, cfg, log, verbose = 1):
        super().__init__(cfg, log, verbose)
        self.x_dim = cfg['train_info']['x_dim']

        def dense(n_in):
            # Linear -> BatchNorm1d -> ReLU projecting n_in features to 100 units
            return [nn.Linear(n_in, 100), nn.BatchNorm1d(100), nn.ReLU(True)]

        self.embed_data = nn.Sequential(*dense(self.x_dim))

        self.encoder = nn.Sequential(
            prob_block(100 + self.y_dim + self.s_dim, self.z_dim),
        )

        self.decoder_z1 = nn.Sequential(
            *dense(self.z_dim + self.y_dim),
            prob_block(100, self.z_dim),
        )

        self.decoder = nn.Sequential(
            *dense(self.z_dim + self.s_dim),
            nn.Linear(100, self.x_dim),
            nn.Sigmoid(),
        )

        # Discriminator: z_dim -> 4z -> 4z -> 4z -> 4z (ReLU after each) -> 1
        w = 4 * self.z_dim
        sizes = [self.z_dim, w, w, w, w]
        disc_layers = []
        for n_in, n_out in zip(sizes, sizes[1:]):
            disc_layers += [nn.Linear(n_in, n_out), nn.ReLU(True)]
        disc_layers.append(nn.Linear(w, 1))
        self.disc = nn.Sequential(*disc_layers)

        self.encoder_trainable = [self.encoder, self.embed_data]
        self.decoder_trainable = [self.decoder_z1, self.decoder]
        self.disc_trainable = [self.disc]

class WFAE_MNIST(WFAE_abstract):
    """Convolutional WFAE model whose decoder emits 1x28x28 images.

    Holds a pretrained embedding network (not used inside this class;
    presumably consumed by ``WFAE_abstract``), a conv encoder to ``z_dim``,
    a transposed-conv decoder from ``y_dim + z_dim`` features, and an MLP
    discriminator on ``z_dim``. Submodules are created in the original order
    so init RNG use and state-dict keys are unchanged.
    """

    def __init__(self, cfg, log, verbose = 1):
        super().__init__(cfg, log, verbose)

        # Pretrained embedding network: compose the "pretrain_MNIST" config,
        # resolve src.model.<model>.<architecture>, load weights from this
        # run's cfg["train_info"]["path_ckpt"], and switch it to eval mode.
        # NOTE(review): eval() is undone if .train() is later called on this
        # module; confirm the base class re-applies it before use.
        emb_cfg = compose(config_name="train", overrides=["train_info=pretrain_MNIST"])
        emb_info = emb_cfg['train_info']
        emb_module = importlib.import_module('src.model.' + emb_info['model'])
        emb_cls = getattr(emb_module, emb_info['architecture'])
        self.embed_network = emb_cls.load_from_checkpoint(
            cfg["train_info"]["path_ckpt"], cfg = emb_cfg, log = None, verbose = 0
        )
        self.embed_network.eval()

        d = 64  # base channel width

        # Encoder: four Conv(k=4, s=2, p=1) -> BN -> ReLU stages with channels
        # 1 -> d -> 2d -> 4d -> 8d, then flatten and project to z_dim.
        # Linear(8*d) assumes the last stage leaves a 1x1 map.
        channels = [1, d, 2 * d, 4 * d, 8 * d]
        enc_layers = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            enc_layers.extend([
                nn.Conv2d(c_in, c_out, kernel_size = 4, stride = 2, padding = 1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ])
        enc_layers.extend([nn.Flatten(), nn.Linear(8 * d, self.z_dim)])
        self.encoder = nn.Sequential(*enc_layers)

        # Decoder: (y_dim + z_dim) -> (4d, 7, 7) -> 14x14 -> 28x28 -> 1 channel in (0, 1)
        self.decoder = nn.Sequential(
            nn.Linear(self.y_dim + self.z_dim, 49 * 4 * d),
            nn.Unflatten(1, (4 * d, 7, 7)),
            nn.ConvTranspose2d(4 * d, 2 * d, kernel_size = 4, stride = 2, padding = 1),
            nn.BatchNorm2d(2 * d),
            nn.LeakyReLU(inplace=True),
            nn.ConvTranspose2d(2 * d, d, kernel_size = 4, stride = 2, padding = 1),
            nn.BatchNorm2d(d),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(d, d, kernel_size = 4, padding = 'same'),
            nn.BatchNorm2d(d),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(d, 1, kernel_size = 4, padding = 'same'),  # reconstruction
            nn.Sigmoid(),
        )

        # Discriminator: z_dim -> four Linear+ReLU stages of width 8*z_dim -> 1
        width = 8 * self.z_dim
        disc_layers = [nn.Linear(self.z_dim, width), nn.ReLU(True)]
        for _ in range(3):
            disc_layers.extend([nn.Linear(width, width), nn.ReLU(True)])
        disc_layers.append(nn.Linear(width, 1))
        self.disc = nn.Sequential(*disc_layers)

        self.encoder_trainable = [self.encoder]
        self.decoder_trainable = [self.decoder]
        self.disc_trainable = [self.disc]

class WFAE_eYaleB(WFAE_attr):
    """Convolutional WFAE model for single-channel Extended Yale B images.

    Components:
      * ``embed_network``: pretrained model rebuilt from the ``pretrain_eYaleB``
        config and loaded from ``cfg["train_info"]["path_ckpt"]``.
      * ``encoder``: five strided convs -> flatten -> ``z_dim``.
      * ``decoder_z1``: MLP (z_dim + y_dim) -> 50 -> (z_dim + y_dim).
      * ``decoder``: (z_dim + y_dim + s_dim) features -> 1x128x128 image in (0, 1).
      * ``disc``: MLP z_dim -> 16 -> 16 -> 16 -> 16 -> 1 (no output activation).

    The construction order below fixes parameter-init RNG use and state-dict
    keys; do not reorder layers without re-checking checkpoint compatibility.
    """
    def __init__(self, cfg, log, verbose = 1):
        super().__init__(cfg, log, verbose)
        # Rebuild the pretrained embedding network: compose the pretraining
        # config, resolve src.model.<model>.<architecture> by name, and load
        # its weights from the checkpoint path in *this* run's config.
        # (compose / importlib are not imported in this file's visible header;
        # presumably provided via the star import.)
        ecfg = compose(config_name="train", overrides=["train_info=pretrain_eYaleB"])
        mod = importlib.import_module('src.model.' + ecfg['train_info']['model'])
        mod_attr = getattr(mod, ecfg['train_info']['architecture'])
        self.embed_network = mod_attr.load_from_checkpoint(cfg["train_info"]["path_ckpt"], cfg = ecfg, log = None, verbose = 0)
        # NOTE(review): eval() is reset if .train() is later called on this
        # module (children follow the parent's mode); confirm the base class
        # re-applies eval() before using embed_network.
        self.embed_network.eval()

        d = 64  # base channel width
        # Encoder: five stride-2 convs. Linear(16*8*d) expects a (8d, 4, 4) map,
        # which corresponds to roughly 128x128 inputs — TODO confirm input size.
        self.encoder = nn.Sequential(
            nn.Conv2d(1, d//2, kernel_size = 5, stride = 2, padding = 2),
            nn.BatchNorm2d(d//2),
            nn.ReLU(True),

            nn.Conv2d(d//2, d, kernel_size = 5, stride = 2, padding = 2),
            nn.BatchNorm2d(d),
            nn.ReLU(True),

            nn.Conv2d(d, 2*d, kernel_size = 5, stride = 2, padding = 2),
            nn.BatchNorm2d(2*d),
            nn.ReLU(True),

            nn.Conv2d(2*d, 4*d, kernel_size = 3, stride = 2, padding = 1),
            nn.BatchNorm2d(4*d),
            nn.ReLU(True),

            nn.Conv2d(4*d, 8*d, kernel_size = 3, stride = 2, padding = 1),
            nn.BatchNorm2d(8*d),
            nn.ReLU(True),

            nn.Flatten(),
            nn.Linear(16*8*d, self.z_dim),
            )

        # Small MLP whose output width equals its input width (z_dim + y_dim).
        self.decoder_z1 = nn.Sequential(
            nn.Linear(self.z_dim + self.y_dim, 50),
            nn.BatchNorm1d(50),
            nn.LeakyReLU(inplace=True),

            nn.Linear(50, self.z_dim + self.y_dim),
        )

        # Image decoder: spatial size 8 -> 16 -> 32 -> 32 -> 64 -> 128 -> 128 -> 128.
        self.decoder = nn.Sequential(
            nn.Linear(self.z_dim + self.y_dim + self.s_dim, 64*16*d),
            nn.Unflatten(1, (16*d, 8, 8)),

            # 8x8 -> 16x16
            nn.ConvTranspose2d(16*d, 8*d, kernel_size = 3, stride = 2, padding = 1, output_padding = 1),
            nn.BatchNorm2d(8*d),
            nn.LeakyReLU(inplace=True),

            # 16x16 -> 32x32
            nn.ConvTranspose2d(8*d, 4*d, kernel_size = 3, stride = 2, padding = 1, output_padding = 1),
            nn.BatchNorm2d(4*d),
            nn.LeakyReLU(inplace=True),

            # size-preserving conv at 32x32
            nn.Conv2d(4*d, 4*d, kernel_size = 3, padding = 1),
            nn.BatchNorm2d(4*d),
            nn.LeakyReLU(inplace=True),

            # 32x32 -> 64x64
            nn.ConvTranspose2d(4*d, 2*d, kernel_size = 5, stride = 2, padding = 2, output_padding = 1),
            nn.BatchNorm2d(2*d),
            nn.LeakyReLU(inplace=True),

            # 64x64 -> 128x128
            nn.ConvTranspose2d(2*d, d, kernel_size = 5, stride = 2, padding = 2, output_padding = 1),
            nn.BatchNorm2d(d),
            nn.LeakyReLU(inplace=True),

            nn.Conv2d(d, d, kernel_size = 5, padding = 2),
            nn.BatchNorm2d(d),
            nn.LeakyReLU(inplace=True),

            # reconstruction: single channel squashed into (0, 1)
            nn.Conv2d(d, 1, kernel_size = 5, padding = 2),
            nn.Sigmoid(),
            )

        # Discriminator on z: fixed hidden width of 16, raw (unactivated) scalar output.
        self.disc = nn.Sequential(
            nn.Linear(self.z_dim, 16),
            nn.ReLU(True),

            nn.Linear(16, 16),
            nn.ReLU(True),

            nn.Linear(16, 16),
            nn.ReLU(True),

            nn.Linear(16, 16),
            nn.ReLU(True),

            nn.Linear(16, 1),
            )

        # Module groups exposed to the (base-class) optimisation logic.
        self.encoder_trainable = [self.encoder]
        self.decoder_trainable = [self.decoder, self.decoder_z1]
        self.disc_trainable = [self.disc]

class WFAE_vggface2(WFAE_attr):
    """Convolutional WFAE model for 3-channel VGGFace2 images.

    Components:
      * ``embed_network``: pretrained model rebuilt from the ``pretrain_vggface2``
        config and loaded from ``cfg["train_info"]["path_ckpt"]``.
      * ``encoder``: conv stack (four stride-2 stages) -> flatten -> ``z_dim``.
      * ``decoder``: (z_dim + y_dim) features -> 3x128x128 image in (0, 1),
        interleaving transposed convs with ``res_block`` modules.
      * ``disc``: MLP z_dim -> four Linear+ReLU stages of width 8*z_dim -> 1.

    ``encode`` is overridden to append the embedding network's outputs to the
    encoder's features.
    """
    def __init__(self, cfg, log, verbose = 1):
        super().__init__(cfg, log, verbose)
        # Rebuild the pretrained embedding network from its pretraining config
        # and load weights from the checkpoint path in this run's config.
        ecfg = compose(config_name="train", overrides=["train_info=pretrain_vggface2"])
        mod = importlib.import_module('src.model.' + ecfg['train_info']['model'])
        mod_attr = getattr(mod, ecfg['train_info']['architecture'])
        self.embed_network = mod_attr.load_from_checkpoint(cfg["train_info"]["path_ckpt"], cfg = ecfg, log = None, verbose = 0)
        self.embed_network.eval()
        # Not read in this class; presumably used by the base class / trainer.
        self.iter_per_epoch = cfg['train_info']['iter_per_epoch']

        d = 64  # base channel width
        # Encoder: Linear(64*16*d) expects a (16d, 8, 8) map after four stride-2
        # stages, i.e. roughly 128x128 inputs — TODO confirm input size.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 2*d, kernel_size = 5, stride = 2, padding = 2),
            nn.BatchNorm2d(2*d),
            nn.ReLU(True),

            nn.Conv2d(2*d, 2*d, kernel_size = 5, padding = 2),
            nn.BatchNorm2d(2*d),
            nn.ReLU(True),

            nn.Conv2d(2*d, 4*d, kernel_size = 5, stride = 2, padding = 2),
            nn.BatchNorm2d(4*d),
            nn.ReLU(True),

            nn.Conv2d(4*d, 4*d, kernel_size = 5, padding = 2),
            nn.BatchNorm2d(4*d),
            nn.ReLU(True),

            nn.Conv2d(4*d, 8*d, kernel_size = 5, stride = 2, padding = 2),
            nn.BatchNorm2d(8*d),
            nn.ReLU(True),

            nn.Conv2d(8*d, 8*d, kernel_size = 3, padding = 1),
            nn.BatchNorm2d(8*d),
            nn.ReLU(True),

            nn.Conv2d(8*d, 16*d, kernel_size = 3, stride = 2, padding = 1),
            nn.BatchNorm2d(16*d),
            nn.ReLU(True),

            nn.Conv2d(16*d, 16*d, kernel_size = 3, padding = 1),
            nn.BatchNorm2d(16*d),
            nn.ReLU(True),

            nn.Flatten(),
            nn.Linear(64*16*d, self.z_dim),
            )

        # Decoder: spatial size 8 -> 16 -> 32 -> 64 -> 128. res_block is not
        # defined in this file (presumably a residual conv block taking
        # (channels, kernel_size, flag) — confirm its signature).
        self.decoder = nn.Sequential(
            nn.Linear(self.z_dim + self.y_dim, 64*16*d),
            nn.Unflatten(1, (16*d, 8, 8)),

            # 8x8 -> 16x16
            nn.ConvTranspose2d(16*d, 8*d, kernel_size = 5, stride = 2, padding = 2, output_padding = 1),
            nn.BatchNorm2d(8*d),
            nn.LeakyReLU(inplace=True),

            res_block(8*d, 5, True),

            # 16x16 -> 32x32
            nn.ConvTranspose2d(8*d, 4*d, kernel_size = 5, stride = 2, padding = 2, output_padding = 1), 
            nn.BatchNorm2d(4*d),
            nn.LeakyReLU(inplace=True),

            res_block(4*d, 5, True),

            # 32x32 -> 64x64
            nn.ConvTranspose2d(4*d, 2*d, kernel_size = 5, stride = 2, padding = 2, output_padding = 1), 
            nn.BatchNorm2d(2*d),
            nn.LeakyReLU(inplace=True),

            res_block(2*d, 3, True),

            # 64x64 -> 128x128
            nn.ConvTranspose2d(2*d, d, kernel_size = 3, stride = 2, padding = 1, output_padding = 1),
            nn.BatchNorm2d(d),
            nn.LeakyReLU(inplace=True),

            res_block(d, 3, True),

            # reconstruction: 3 channels squashed into (0, 1)
            nn.Conv2d(d, 3, kernel_size = 3, padding = 1), 
            nn.Sigmoid(),
            )

        # NOTE(review): disc takes z_dim inputs, while encode() below returns
        # z_dim plus the embedding network's columns; presumably the base class
        # slices the z part before calling disc — confirm.
        self.disc = nn.Sequential(
            nn.Linear(self.z_dim, 8*self.z_dim),
            nn.ReLU(True),

            nn.Linear(8*self.z_dim, 8*self.z_dim),
            nn.ReLU(True),

            nn.Linear(8*self.z_dim, 8*self.z_dim),
            nn.ReLU(True),

            nn.Linear(8*self.z_dim, 8*self.z_dim),
            nn.ReLU(True),

            nn.Linear(8*self.z_dim, 1),
            )

        # Module groups exposed to the (base-class) optimisation logic.
        self.encoder_trainable = [self.encoder]
        self.decoder_trainable = [self.decoder]
        self.disc_trainable = [self.disc]

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and append the pretrained network's outputs.

        Returns ``torch.cat((self.encoder(x), y, z.sigmoid()), dim=1)``, where
        ``(y, z) = self.embed_network.encode(x)`` is computed without gradients.
        """
        # Re-apply eval() on every call: a .train() call on this module would
        # otherwise have switched the pretrained network back to training mode.
        self.embed_network.eval()
        with torch.no_grad():
            y, z = self.embed_network.encode(x)
        return torch.cat((self.encoder(x), y, z.sigmoid()), dim = 1)