import os
import torch

# import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from utils import device

# __all__ = ['VAE','vae_retrain']

class _vae(torch.nn.Module):
    """Variational autoencoder with one hidden layer on each side.

    The encoder maps ``input_dim -> hidden_dim`` and then branches into two
    ``latent_dim``-sized heads (mean and log-variance). The decoder maps
    ``latent_dim -> hidden_dim -> input_dim`` and applies no activation to
    its final layer, so reconstructions are unbounded.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        linear = torch.nn.Linear
        # Encoder trunk followed by the mean (fc21) and log-variance (fc22) heads.
        self.fc1 = linear(input_dim, hidden_dim)
        self.fc21 = linear(hidden_dim, latent_dim)
        self.fc22 = linear(hidden_dim, latent_dim)
        # Decoder layers.
        self.fc3 = linear(latent_dim, hidden_dim)
        self.fc4 = linear(hidden_dim, input_dim)

    def encode(self, x):
        """Return the ``(mu, logvar)`` pair produced by the encoder for ``x``."""
        hidden = F.relu(self.fc1(x))
        mu = self.fc21(hidden)
        logvar = self.fc22(hidden)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * exp(0.5 * logvar)`` with ``eps`` standard normal."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def encoding(self, x):
        """Encode ``x`` and return one sampled latent vector per row."""
        return self.reparameterize(*self.encode(x))

    def decode(self, z):
        """Map latent ``z`` back to input space (linear output layer)."""
        hidden = F.relu(self.fc3(z))
        return self.fc4(hidden)

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for the batch ``x``."""
        mu, logvar = self.encode(x)
        reconstruction = self.decode(self.reparameterize(mu, logvar))
        return reconstruction, mu, logvar


def loss_function(recon_x, x, mu, logvar):
    """Return reconstruction error plus KL divergence as a scalar tensor.

    The reconstruction term is the MSE between ``recon_x`` and ``x``
    averaged over all elements (``reduction='mean'``); the KL term is
    summed over every element of ``mu``/``logvar``. The two terms are
    therefore on different scales.
    """
    reconstruction = F.mse_loss(recon_x, x, reduction='mean')

    # Closed-form KL term, see Appendix B of
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # KL = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_divergence = -0.5 * torch.sum(kl_terms)

    return reconstruction + kl_divergence


class CustomDataset(Dataset):
    """Map-style dataset that pairs ``x[idx]`` with ``y[idx]``."""

    def __init__(self, x, y):
        # Both containers are stored as given; no copy or length check is made.
        self.x, self.y = x, y

    def __len__(self):
        # The dataset size is defined by ``x`` alone.
        return len(self.x)

    def __getitem__(self, idx):
        sample, target = self.x[idx], self.y[idx]
        return sample, target

def vae_retrain(model, x, y, logger, args):
    """Train ``model`` on ``x`` for ``args.num_vae_epoch`` epochs and return it.

    Each epoch iterates over shuffled mini-batches of ``x`` and minimises
    mean-reduced MSE reconstruction error plus the sum-reduced KL term
    ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar))`` with Adam (lr=1e-3).
    After every epoch the reconstruction error on the full ``x`` and the
    epoch's accumulated loss divided by the number of samples are logged.

    No device transfer is performed here; batches are fed to ``model`` as
    they come out of the loader.

    Args:
        model: module whose call returns ``(reconstruction, mu, logvar)``.
        x: tensor of training inputs; it is detached before batching.
        y: unused by the training loop; kept so existing callers need no change.
        logger: object exposing ``info(message)``.
        args: object providing ``vae_batch_size`` and ``num_vae_epoch``.

    Returns:
        The same ``model`` object, left in training mode.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    inputs = x.detach()
    # The loop never reads labels, so only the inputs are batched.
    dataset = torch.utils.data.TensorDataset(inputs)
    train_loader = DataLoader(dataset, batch_size=args.vae_batch_size, shuffle=True)
    num_samples = len(dataset)
    num_batches = len(train_loader)

    model.train()
    for epoch in range(args.num_vae_epoch):
        # Reset every epoch so the reported average covers this epoch only.
        train_loss = 0.0
        for batch_idx, (batch_x,) in enumerate(train_loader):
            optimizer.zero_grad()
            recon_x, mu, logvar = model(batch_x)
            mse = F.mse_loss(recon_x, batch_x, reduction='mean')
            kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            loss = mse + kld
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            train_loss += batch_loss
            if batch_idx % 2 == 0:
                logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(batch_x), num_samples,
                    100. * batch_idx / num_batches,
                    batch_loss / len(batch_x)))

        # Evaluation pass over the full data: no graph is needed, and the
        # result is converted to a plain float for logging.
        model.eval()
        with torch.no_grad():
            recon_x, _, _ = model(inputs)
            recon_error = F.mse_loss(recon_x, x, reduction='mean').item()

        logger.info('====> Epoch: {} reconstruction error: {:.4f} Average loss: {:.4f}'.format(
              epoch, recon_error, train_loss / num_samples))

        model.train()

    return model