import os
import torch

# import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from utils import device

# __all__ = ['VAE','vae_retrain']

class _vae(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
        self.fc21 = torch.nn.Linear(hidden_dim, latent_dim)
        self.fc22 = torch.nn.Linear(hidden_dim, latent_dim)
        self.fc3 = torch.nn.Linear(latent_dim, hidden_dim)
        self.fc4 = torch.nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        h1 = torch.nn.functional.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def encoding(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return z

    def decode(self, z):
        h3 = torch.nn.functional.relu(self.fc3(z))
        # return torch.sigmoid(self.fc4(h3))
        return self.fc4(h3)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


# Rank-weighted reconstruction (MSE) loss + KL divergence, both summed over the batch
def loss_function(recon_x, x, y, mu, logvar, args):
    """Rank-weighted reconstruction loss plus KL divergence for one batch.

    Each sample's reconstruction error (MSE averaged over the last dimension)
    is multiplied by ``1 / (args.k * B + rank)``. Here ``B = x.shape[0]`` and
    ``rank`` is the sample's 0-based position when the batch is sorted by
    ``y`` in ascending order. The weighted errors are summed, and the analytic
    KL divergence between N(mu, exp(logvar)) and N(0, I) is added (summed
    over all elements).

    NOTE(review): ascending order gives the sample with the smallest ``y``
    rank 0 and therefore the largest weight. That suits a minimisation
    objective; confirm this is intended if ``y`` is maximised. Ranks are also
    computed within the batch, not over the whole dataset.

    Args:
        recon_x: decoder output, same shape as ``x``.
        x: input batch; dim 0 is the batch dimension.
        y: one score per sample. Assumed to be shape ``(B,)`` or ``(B, 1)``.
        mu: posterior means from the encoder.
        logvar: posterior log-variances from the encoder.
        args: object with a numeric attribute ``k`` that controls how sharply
            the weights decay with rank.

    Returns:
        Scalar tensor: weighted reconstruction sum + KL divergence.
    """
    # One score per sample, on the same device as x. reshape+squeeze(-1)
    # (instead of y.squeeze()) keeps a 1-D tensor even for a batch of one,
    # e.g. the last partial DataLoader batch. A 0-d tensor would break the
    # rank scatter below.
    scores = y.to(x.device).reshape(len(y), -1).squeeze(-1)

    # ranks[i] = position of sample i when the batch is sorted by score.
    order = torch.argsort(scores)
    ranks = torch.empty_like(order)
    ranks[order] = torch.arange(len(order), device=order.device)

    weight = 1 / (args.k * x.shape[0] + ranks)

    # Per-sample MSE (mean over the last dim), then rank-weighted sum.
    recon_error = (F.mse_loss(recon_x, x, reduction='none').mean(dim=-1) * weight).sum()

    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()).to(x.device)

    return recon_error + kld


class CustomDataset(Dataset):
    """Map-style dataset that pairs ``x[i]`` with ``y[i]``.

    The inputs are stored as given (no copy). The dataset length is taken
    from ``x``.
    """

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        """Return the ``(x[idx], y[idx])`` pair."""
        sample, target = self.x[idx], self.y[idx]
        return sample, target

def vae_retrain(model, x, y, logger, args):
    """Fine-tune ``model`` on ``(x, y)`` using the rank-weighted VAE loss.

    Runs ``args.num_vae_epoch`` epochs of Adam (lr 1e-3) over shuffled
    mini-batches of size ``args.vae_batch_size``, minimising
    ``loss_function``. The batch loss is logged on every other batch. After
    each epoch, two values are logged: the MSE reconstruction error on the
    full ``x`` (computed without gradients) and that epoch's average training
    loss per sample.

    NOTE(review): no device transfer happens here, so ``x``, ``y`` and
    ``model`` are assumed to already be on the same device — confirm with
    callers.

    Args:
        model: VAE whose forward returns ``(recon_x, mu, logvar)``.
        x: training inputs; dim 0 indexes samples.
        y: per-sample scores used to rank-weight the reconstruction loss.
        logger: object with an ``info(str)`` method.
        args: needs ``num_vae_epoch``, ``vae_batch_size`` and ``k`` (the last
            is used by ``loss_function``).

    Returns:
        The same ``model`` object, trained in place and left in train mode.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    dataset = CustomDataset(x.detach(), y.detach())
    train_loader = DataLoader(dataset, batch_size=args.vae_batch_size, shuffle=True)
    n_samples = len(train_loader.dataset)
    n_batches = len(train_loader)

    model.train()
    for epoch in range(args.num_vae_epoch):
        # Reset every epoch so the logged average covers this epoch only.
        epoch_loss = 0.0
        for batch_idx, (batch_x, batch_y) in enumerate(train_loader):
            optimizer.zero_grad()
            recon_x, mu, logvar = model(batch_x)
            loss = loss_function(recon_x, batch_x, batch_y, mu, logvar, args)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()  # single host sync per batch
            epoch_loss += batch_loss
            if batch_idx % 2 == 0:
                line = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(batch_x), n_samples,
                    100. * batch_idx / n_batches,
                    batch_loss / len(batch_x))
                logger.info(line)

        # Full-data reconstruction check. Running under no_grad avoids building
        # an autograd graph over all of x. The forward pass samples z, so this
        # metric is stochastic.
        model.eval()
        with torch.no_grad():
            full_x = x.detach()
            recon_x, _, _ = model(full_x)
            recon_error = F.mse_loss(recon_x, full_x, reduction='mean').item()

        logger.info('====> Epoch: {} reconstruction error: {:.4f} Average loss: {:.4f}'.format(
              epoch, recon_error, epoch_loss / max(n_samples, 1)))

        model.train()

    return model