import os
import torch

# import torch.nn as nn
import torch.nn.functional as F

from torch.distributions.multivariate_normal import MultivariateNormal

from gpytorch.kernels.rbf_kernel import RBFKernel

from torch.utils.data import Dataset, DataLoader

from utils import device

# __all__ = ['VAE','vae_retrain']

class _vae(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim, target_dim=1):
        super().__init__()
        self.latent_dim = latent_dim

        self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
        self.fc21 = torch.nn.Linear(hidden_dim, latent_dim)
        self.fc22 = torch.nn.Linear(hidden_dim, latent_dim)
        self.fc3 = torch.nn.Linear(latent_dim, hidden_dim)
        self.fc4 = torch.nn.Linear(hidden_dim, input_dim)

        self.fc5 = torch.nn.Linear(target_dim, latent_dim)
        self.lengthscale = torch.nn.Parameter(torch.ones(latent_dim))

    def encode(self, x):
        h1 = torch.nn.functional.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def encoding(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return z

    def decode(self, z):
        h3 = torch.nn.functional.relu(self.fc3(z))
        # return torch.sigmoid(self.fc4(h3))
        return self.fc4(h3)

    def forward(self, x, y=None):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)

        if y is not None:
            covariance = self.rbf_kernel(self.fc5(y))
        else:
            covariance = None

        return self.decode(z), mu, logvar, covariance

    def rbf_kernel(self, x):
        """
        Compute the RBF kernel between two sets of inputs.
        Args:
            x1: Tensor of shape (n, d)
            x2: Tensor of shape (m, d)
            lengthscale: Lengthscale parameter (scalar or tensor)
        Returns:
            Covariance matrix of shape (n, m)
        """

        batch_size = x.shape[0]

        # Compute the squared Euclidean distance between each pair of points
        dists = torch.cdist(x, x, p=2) ** 2
        # Compute the RBF kernel
        dists = dists.repeat(self.latent_dim, 1, 1)
        lengthscale = self.lengthscale.unsqueeze(-1).unsqueeze(-1).repeat(1, batch_size, batch_size)
        cov_matrix = torch.exp(-dists / (2 * lengthscale ** 2))
        return cov_matrix


# Reconstruction MSE (mean) + KL between the diagonal posterior and the RBF prior.
def loss_function(recon_x, x, mu, logvar, covariance, *,
                  posterior_eps=1e-5, prior_jitter=1e-3):
    """Return reconstruction MSE plus the GP-prior KL divergence.

    Each latent dimension is treated as one Gaussian over the batch: the
    posterior uses mean ``mu[:, k]`` with diagonal variance
    ``exp(logvar[:, k])``; the prior is zero-mean with covariance
    ``covariance[k]``. The KL is averaged over latent dimensions.

    Args:
        recon_x: Reconstruction, same shape as ``x``.
        x: Original inputs.
        mu: Posterior means, shape (batch, latent_dim).
        logvar: Posterior log-variances, shape (batch, latent_dim).
        covariance: Prior covariances, shape (latent_dim, batch, batch), as
            returned by the model's forward pass when targets are given.
        posterior_eps: Floor added to posterior variances (keyword-only).
        prior_jitter: Diagonal jitter added to the prior covariances so the
            Cholesky factorisation succeeds (keyword-only).

    Returns:
        Scalar tensor ``mse + mean_k KL(posterior_k || prior_k)``.

    Raises:
        ValueError: If ``covariance`` is ``None``.
    """
    if covariance is None:
        raise ValueError(
            "loss_function requires the prior covariance; "
            "call the model with targets y")

    recon_loss = F.mse_loss(recon_x, x, reduction='mean')

    # Transpose to (latent_dim, batch) so each latent dim is one batch-sized
    # event; the eps floor keeps the diagonal covariance positive definite.
    posterior = MultivariateNormal(
        mu.t(), torch.diag_embed(logvar.t().exp() + posterior_eps))

    n_latent, n_batch = covariance.shape[0], covariance.shape[1]
    eye = torch.eye(n_batch, dtype=covariance.dtype, device=covariance.device)
    # Jitter keeps near-singular RBF Gram matrices factorizable.
    covariance = covariance + prior_jitter * eye
    # Stay on covariance's own device/dtype; mixing in a global device could
    # split loc and scale_tril across devices.
    scale_tril = torch.linalg.cholesky(covariance)
    prior_loc = torch.zeros(n_latent, n_batch,
                            dtype=covariance.dtype, device=covariance.device)
    prior = MultivariateNormal(prior_loc, scale_tril=scale_tril)

    kld = torch.distributions.kl.kl_divergence(posterior, prior).mean()
    return recon_loss + kld


class CustomDataset(Dataset):
    """Map-style dataset that pairs ``x[i]`` with ``y[i]``.

    Args:
        x: Indexable inputs; ``len(x)`` determines the dataset length.
        y: Indexable targets, indexed with the same index as ``x``.
    """

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __len__(self):
        """Return the number of samples, as reported by ``len(x)``."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` pair at position ``idx``."""
        sample = self.x[idx]
        target = self.y[idx]
        return sample, target

def vae_retrain(model, x, y, logger, args):
    """Fine-tune ``model`` on ``(x, y)`` with Adam and log progress.

    Each epoch shuffles the data into mini-batches of ``args.vae_batch_size``
    and minimises ``loss_function`` on each batch. After each epoch it logs
    the full-data reconstruction MSE and the epoch's summed batch loss
    divided by the dataset size.

    Args:
        model: Module whose ``forward(x, y)`` returns
            ``(recon_x, mu, logvar, covariance)``.
        x: Input tensor; detached before training.
        y: Target tensor aligned with ``x``; detached before training.
        logger: Object with an ``info(str)`` method.
        args: Namespace providing ``vae_batch_size`` and ``num_vae_epoch``.

    Returns:
        The same ``model`` instance, trained in place and left in train mode.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    x = x.detach()
    dataset = CustomDataset(x, y.detach())
    train_loader = DataLoader(dataset, batch_size=args.vae_batch_size, shuffle=True)
    n_samples = len(dataset)
    n_batches = len(train_loader)

    model.train()
    for epoch in range(args.num_vae_epoch):
        # Reset every epoch so "Average loss" reflects this epoch only.
        train_loss = 0.0
        for batch_idx, (batch_x, batch_y) in enumerate(train_loader):
            optimizer.zero_grad()
            recon_x, mu, logvar, covariance = model(batch_x, batch_y)
            loss = loss_function(recon_x, batch_x, mu, logvar, covariance)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            train_loss += batch_loss
            if batch_idx % 2 == 0:
                logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(batch_x), n_samples,
                    100. * batch_idx / n_batches,
                    batch_loss / len(batch_x)))

        # Full-data reconstruction check; no autograd graph is needed here.
        model.eval()
        with torch.no_grad():
            recon_x, _, _, _ = model(x)
            recon_err = F.mse_loss(recon_x, x, reduction='mean').item()
        logger.info('====> Epoch: {} reconstruction error: {:.4f} Average loss: {:.4f}'.format(
            epoch, recon_err, train_loss / max(n_samples, 1)))
        model.train()

    return model
