import os
import torch

# import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from utils import device

# __all__ = ['VAE','vae_retrain']

class _vae(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.fc1 = torch.nn.Identity(input_dim)
        # self.fc21 = torch.nn.Linear(hidden_dim, latent_dim)
        # self.fc22 = torch.nn.Linear(hidden_dim, latent_dim)
        # self.fc3 = torch.nn.Linear(latent_dim, hidden_dim)
        # self.fc4 = torch.nn.Linear(hidden_dim, input_dim)

    def encode(self, x):

        return x

    def encoding(self, x):

        return x

    def decode(self, z):
        # h3 = torch.nn.functional.relu(self.fc3(z))
        # return torch.sigmoid(self.fc4(h3))
        return z

    def forward(self, x):

        return self.fc1(x), x, 0.0001


# Reconstruction loss (MSE, averaged over all elements) + KL divergence (summed over all elements and batch)
def loss_function(recon_x, x, mu, logvar, args):
    """Return the VAE training loss: reconstruction error plus KL divergence.

    Args:
        recon_x: Reconstructed input tensor.
        x: Target tensor; must be compatible with ``recon_x`` for
            ``F.mse_loss``.
        mu: Latent mean, as a tensor or a Python scalar. Scalars are
            converted to a tensor on ``recon_x``'s dtype and device.
        logvar: Latent log-variance, as a tensor or a Python scalar.
            Scalars are converted the same way as ``mu``.
        args: Unused; accepted for call-site compatibility.

    Returns:
        A scalar tensor ``mse + kld``:
        - ``mse`` is the mean squared error averaged over all elements.
        - ``kld`` is the KL divergence summed over all elements and batch.
    """
    # _vae.forward returns a plain Python float in the logvar slot. Promote
    # non-tensor values so .pow()/.exp() below work instead of raising
    # AttributeError. Tensor inputs are left exactly as given.
    if not torch.is_tensor(mu):
        mu = torch.as_tensor(mu, dtype=recon_x.dtype, device=recon_x.device)
    if not torch.is_tensor(logvar):
        logvar = torch.as_tensor(logvar, dtype=recon_x.dtype, device=recon_x.device)

    # Reconstruction term is mean-reduced MSE (not binary cross-entropy).
    recon_loss = F.mse_loss(recon_x, x, reduction='mean')

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    # NOTE(review): the reconstruction term is averaged but KLD is summed, so
    # KLD grows with input size relative to the MSE — confirm this weighting
    # is intended.
    return recon_loss + kld


class CustomDataset(Dataset):
    """Map-style dataset that pairs inputs ``x`` with targets ``y`` by index.

    Both sequences are stored as given (no copy or conversion). The dataset
    length is taken from ``x`` alone.
    """

    def __init__(self, x, y):
        # NOTE(review): y is assumed to have at least len(x) entries; this is
        # not checked here.
        self.x, self.y = x, y

    def __len__(self):
        """Return the number of samples, i.e. ``len(x)``."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return the ``(x[idx], y[idx])`` pair."""
        sample = self.x[idx]
        target = self.y[idx]
        return sample, target

def vae_retrain(model, x, y, logger, args):
    """Return ``model`` unchanged; no retraining is performed.

    Args:
        model: Model to "retrain"; returned as-is (same object).
        x, y, logger, args: Accepted for interface compatibility; unused.
    """
    retrained = model
    return retrained