import os
from turtle import forward
import numpy as np
from PIL import Image
from torch.utils import data
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.autograd import Variable

## ---------------------- HVAE ---------------------- ##

class VAE(nn.Module):
    """Variational autoencoder with a frozen VGG19-BN convolutional encoder.

    Encoder: VGG19-BN ``features`` followed by an adaptive average pool to
    4x4, flattened to ``512 * 4 * 4`` and passed through two fully connected
    layers; two linear heads then produce ``mu`` and ``logvar`` of size
    ``CNN_embed_dim``.

    Decoder: two fully connected layers map the latent vector to a
    ``(64, 4, 4)`` tensor, three transposed convolutions upsample it to three
    channels squashed by a sigmoid, and a bilinear resize brings the result
    to ``out_size``.

    The module is built on the default (CPU) device; move it with
    ``.cuda()`` / ``.to(device)`` before use. This class is also exported
    under the name ``HVAE`` (see the alias below the class).

    Args:
        out_size: ``(height, width)`` of the reconstructed image.
        fc_hidden1: width of the first encoder FC layer.
        fc_hidden2: width of the second encoder FC layer (and of the first
            decoder FC layer).
        CNN_embed_dim: dimensionality of the latent space.
        dropout_rate: dropout probability used in the FC stacks.
    """

    def __init__(self, out_size=(112, 112),
                 fc_hidden1=1024,
                 fc_hidden2=768,
                 CNN_embed_dim=256,
                 dropout_rate=0.05):
        # Zero-argument super(): the previous explicit form named a class
        # that does not exist in this module and raised NameError.
        super().__init__()
        self.out_size = out_size
        self.fc_hidden1, self.fc_hidden2, self.CNN_embed_dim = fc_hidden1, fc_hidden2, CNN_embed_dim

        # CNN architectures (only the k2-k4 / s2-s4 / pd2-pd4 entries are
        # consumed by the decoder below; the rest are kept as attributes).
        self.ch1, self.ch2, self.ch3, self.ch4 = 16, 32, 64, 128
        self.k1, self.k2, self.k3, self.k4 = (5, 5), (5, 5), (5, 5), (5, 5)      # 2d kernel size
        self.s1, self.s2, self.s3, self.s4 = (2, 2), (2, 2), (2, 2), (2, 2)      # 2d strides
        self.pd1, self.pd2, self.pd3, self.pd4 = (0, 0), (0, 0), (0, 0), (0, 0)  # 2d padding

        # Encoding components. No device is forced here: the caller moves the
        # whole model at once, which keeps every sub-module on the same
        # device and lets the class be constructed on CPU-only hosts.
        # NOTE(review): newer torchvision releases prefer ``weights=...`` over
        # ``pretrained=True`` -- confirm against the pinned version.
        vgg = models.vgg19_bn(pretrained=True)
        vgg.avgpool = nn.AdaptiveAvgPool2d((4, 4))  # pooled output is 512x4x4
        # Keep ``features`` and ``avgpool``; drop the trailing classifier.
        modules = list(vgg.children())[:-1]
        self.vgg = nn.Sequential(*modules)

        # Set the feature extractor to non-trainable
        for param in vgg.features.parameters():
            param.requires_grad = False

        self.enc_to_latent = nn.Sequential(
            nn.Linear(512 * 4 * 4, self.fc_hidden1),
            nn.BatchNorm1d(self.fc_hidden1, momentum=0.01),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(self.fc_hidden1, self.fc_hidden2),
            nn.BatchNorm1d(self.fc_hidden2, momentum=0.01),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate)
        )

        # Latent vectors mu and log-variance
        self.mu_layer = nn.Linear(self.fc_hidden2, self.CNN_embed_dim)
        self.logvar_layer = nn.Linear(self.fc_hidden2, self.CNN_embed_dim)

        # Latent vector -> flattened (64, 4, 4) decoder input.
        # Layer order is kept exactly as before so that parameter indices in
        # existing checkpoints still match.
        self.latent_to_dec = nn.Sequential(
            nn.Linear(self.CNN_embed_dim, self.fc_hidden2),
            nn.BatchNorm1d(self.fc_hidden2),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(self.fc_hidden2, 64 * 4 * 4),
            nn.BatchNorm1d(64 * 4 * 4),
            nn.Dropout(dropout_rate),
            nn.ReLU(inplace=True)
        )

        # Decoder: with kernel 5, stride 2 and no padding the spatial size
        # grows 4 -> 11 -> 25 -> 53 across the three transposed convolutions.
        self.convTrans6 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=self.k4, stride=self.s4,
                               padding=self.pd4),
            nn.BatchNorm2d(32, momentum=0.01),
            nn.ReLU(inplace=True),
        )
        self.convTrans7 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=32, out_channels=8, kernel_size=self.k3, stride=self.s3,
                               padding=self.pd3),
            nn.BatchNorm2d(8, momentum=0.01),
            nn.ReLU(inplace=True),
        )

        self.convTrans8 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=8, out_channels=3, kernel_size=self.k2, stride=self.s2,
                               padding=self.pd2),
            nn.BatchNorm2d(3, momentum=0.01),
            nn.Sigmoid()    # y = (y1, y2, y3) \in [0 ,1]^3
        )

    def encode(self, x):
        """Map a batch of images to the latent Gaussian parameters.

        Args:
            x: image batch fed to the VGG feature extractor.

        Returns:
            ``(mu, logvar)``, each with ``CNN_embed_dim`` features per sample.
        """
        # Extract features and flatten everything but the batch dimension
        x = self.vgg(x)
        x = x.view(x.size(0), -1)

        # FC layers
        x = self.enc_to_latent(x)

        # Get mu, log-variance
        mu, logvar = self.mu_layer(x), self.logvar_layer(x)

        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` in training mode; return ``mu`` otherwise.

        Args:
            mu: mean of the latent Gaussian.
            logvar: log-variance of the latent Gaussian.

        Returns:
            A sampled latent tensor while ``self.training`` is true, else
            ``mu`` itself (deterministic encoding for evaluation).
        """
        if not self.training:
            return mu

        std = torch.exp(0.5 * logvar)
        # randn_like draws noise with std's shape, dtype and device.
        eps = torch.randn_like(std)
        return eps * std + mu

    def decode(self, z):
        """Reconstruct an image batch of size ``self.out_size`` from latents ``z``."""
        # FC layers, reshaped into a 64-channel 4x4 feature map
        x = self.latent_to_dec(z).view(-1, 64, 4, 4)

        # Decode
        x = self.convTrans6(x)
        x = self.convTrans7(x)
        x = self.convTrans8(x)

        # Increase resolution. align_corners=False is the library default;
        # passing it explicitly keeps results unchanged and avoids the
        # "default behaviour" warning some PyTorch versions emit.
        x = F.interpolate(x, size=self.out_size, mode='bilinear', align_corners=False)

        return x

    def forward(self, x):
        """Encode, sample and decode.

        Returns:
            ``(x_reconst, z, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_reconst = self.decode(z)

        return x_reconst, z, mu, logvar


# Other code in this module refers to the model as ``HVAE``; expose that name
# without renaming the class so existing ``VAE`` imports keep working.
HVAE = VAE


class DecodeModel(nn.Module):
    """Neural decoder that maps voxel activity to images through a frozen HVAE.

    A small fully connected network (``neural_decoder``) projects a voxel
    vector into the latent space of a pretrained HVAE; the HVAE's ``decode``
    then turns that latent vector into an image. The HVAE weights are loaded
    from ``hvae_weights``, frozen, and kept in evaluation mode.

    Args:
        hvae_weights: path (or file-like object) accepted by ``torch.load``
            that holds the HVAE ``state_dict``.
        num_voxels: number of input features of the first linear layer.
        **kwargs: forwarded unchanged to the ``HVAE`` constructor.
    """

    def __init__(self, hvae_weights,
                 num_voxels,
                 **kwargs) -> None:
        super().__init__()
        self.num_voxels = num_voxels

        # Use the GPU when there is one (the previous behaviour) but fall
        # back to CPU instead of crashing on hosts without CUDA.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.hvae = HVAE(**kwargs).to(device)

        # Load pretrained model; map_location lets a checkpoint saved on a
        # GPU be restored on whichever device was selected above.
        self.hvae.load_state_dict(torch.load(hvae_weights, map_location=device))

        # Set hvae to non-trainable
        for param in self.hvae.parameters():
            param.requires_grad = False

        # And model to eval
        self.hvae.eval()

        # The neural decoding model, maps the voxel space to the latent space of the decoder
        self.neural_decoder = nn.Sequential(
            nn.Linear(in_features=num_voxels, out_features=1024),
            nn.BatchNorm1d(1024, momentum=0.01),
            nn.ReLU(),
            nn.Linear(in_features=1024, out_features=512),
            nn.BatchNorm1d(512, momentum=0.01),
            nn.ReLU(),
            nn.Linear(in_features=512, out_features=self.hvae.CNN_embed_dim)
        )

    def train(self, mode=True):
        """Switch train/eval mode while keeping the frozen HVAE in eval mode.

        ``nn.Module.train`` recurses into every child, which would silently
        re-enable Dropout and BatchNorm statistic updates inside the frozen
        HVAE. Re-applying ``eval()`` afterwards keeps it deterministic.
        """
        super().train(mode)
        self.hvae.eval()
        return self

    def forward(self, x):
        """Decode a batch of voxel vectors.

        Args:
            x: the (batched) fMRI data; the last dimension must equal
                ``num_voxels``.

        Returns:
            ``(img, z)``: the image produced by the HVAE decoder and the
            predicted latent vector.
        """
        z = self.neural_decoder(x)
        img = self.hvae.decode(z)
        return img, z


