import os
import torch

# import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from utils import device

# __all__ = ['VAE','vae_retrain']

class _vae(torch.nn.Module):
    """Fully connected VAE with one hidden layer in the encoder and decoder.

    Encoder: ``input_dim -> hidden_dim`` (ReLU), followed by two linear heads of
    size ``latent_dim`` producing the posterior mean and log-variance.
    Decoder: ``latent_dim -> hidden_dim`` (ReLU) ``-> input_dim`` with no output
    activation, so reconstructions are raw linear outputs.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        # Attribute names (and creation order) are part of the state_dict
        # layout, so they are kept stable.
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim)    # encoder hidden layer
        self.fc21 = torch.nn.Linear(hidden_dim, latent_dim)  # mean head
        self.fc22 = torch.nn.Linear(hidden_dim, latent_dim)  # log-variance head
        self.fc3 = torch.nn.Linear(latent_dim, hidden_dim)   # decoder hidden layer
        self.fc4 = torch.nn.Linear(hidden_dim, input_dim)    # decoder output layer

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for ``x``."""
        hidden = F.relu(self.fc1(x))
        mu = self.fc21(hidden)
        logvar = self.fc22(hidden)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + sigma * eps`` with ``eps ~ N(0, I)`` and ``sigma = exp(logvar / 2)``."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def encoding(self, x):
        """Encode ``x`` and return a single stochastic latent sample."""
        return self.reparameterize(*self.encode(x))

    def decode(self, z):
        """Map latent codes ``z`` back to input space (linear output, no squashing)."""
        return self.fc4(F.relu(self.fc3(z)))

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for the batch ``x``."""
        mu, logvar = self.encode(x)
        reconstruction = self.decode(self.reparameterize(mu, logvar))
        return reconstruction, mu, logvar


# Reconstruction (MSE, mean-reduced) + KL divergence (summed over all elements) + pairwise metric loss
def loss_function_1(recon_x, x, y, mu, logvar, args):
    """Unweighted VAE loss: reconstruction MSE + KL divergence + pairwise metric term.

    Pairs of samples whose target distance is below ``(mean - min) / 2`` of all
    pairwise target distances are treated as positives; all other pairs are
    negatives. Equal-sized random subsets of both groups are compared through
    ``softplus(d_pos - d_neg)`` and averaged.

    Args:
        recon_x: Model reconstruction, same shape as ``x``.
        x: Input batch.
        y: Targets; must be at least 2-D because ``torch.cdist`` requires it.
        mu: Posterior means from the encoder.
        logvar: Posterior log-variances from the encoder.
        args: Not read here; kept for signature parity with ``loss_function``.

    Returns:
        A scalar tensor. The metric term is omitted when either pair group is
        empty (e.g. all targets equal), since its mean would be NaN.
    """
    BCE = F.mse_loss(recon_x, x, reduction='mean')

    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    # NOTE(review): the metric term below depends only on ``y``, so it adds a
    # value to the loss but no gradient for the model parameters.
    distance = torch.cdist(y, y, p=2)

    # Positive pairs: closer than half the gap between mean and minimum distance.
    mask_pos = distance < (distance.mean() - distance.min()) / 2
    elements_pos = distance[mask_pos]
    # Negative pairs: the complement. ``~`` is the element-wise negation; Python's
    # ``not`` on a multi-element tensor raises, and ``1 - bool_tensor`` is rejected.
    elements_neg = distance[~mask_pos]

    min_num = min(elements_pos.numel(), elements_neg.numel())
    if min_num == 0:
        return BCE + KLD

    random_indices_pos = torch.randperm(elements_pos.numel())[:min_num]
    random_indices_neg = torch.randperm(elements_neg.numel())[:min_num]

    # softplus(t) == log(1 + exp(t)) without overflowing for large t; ``.mean()``
    # reduces to a scalar so the total loss can be back-propagated.
    METRIC = F.softplus(
        elements_pos[random_indices_pos] - elements_neg[random_indices_neg]
    ).mean()

    return BCE + KLD + METRIC

# weighted version
def loss_function(recon_x, x, y, mu, logvar, args):
    """Rank-weighted VAE loss: reconstruction MSE + KL divergence + triplet term.

    Pairs of samples are labelled by the Euclidean distance between their
    targets ``y``:
    - positives are pairs below the 30th-percentile pairwise target distance;
    - negatives are pairs above the 70th-percentile pairwise target distance.

    Every triplet ``(a, b, c)`` with ``(a, b)`` positive and ``(b, c)``
    negative contributes
    ``w[a] * w[b] * w[c] * softplus(dz(a, b) - dz(b, c))``, where:
    - ``dz`` is the Euclidean distance between latent means ``mu``;
    - ``w[i] = 1 / (args.k * batch + rank(y[i]))``, with rank 0 for the
      smallest target.

    Args:
        recon_x: Model reconstruction, same shape as ``x``.
        x: Input batch; ``x.shape[0]`` is the batch size.
        y: Targets. NOTE(review): assumed to be ``(batch, 1)`` — confirm with callers.
        mu: Posterior means, one row per sample.
        logvar: Posterior log-variances.
        args: Object providing ``k`` (weight offset scale).

    Returns:
        Scalar tensor. The triplet term is omitted when no triplet exists
        (e.g. a batch of one or all targets equal), since its mean would be NaN.
    """
    batch_size = x.shape[0]

    # Rank of each sample by its target (0 = smallest). squeeze(-1) keeps the
    # batch dimension even when batch_size == 1.
    sorted_indices = torch.argsort(y.squeeze(-1))
    ranks = torch.empty_like(sorted_indices)
    ranks[sorted_indices] = torch.arange(len(y), device=sorted_indices.device)
    weight = 1 / (args.k * batch_size + ranks.to(x.device))

    BCE = F.mse_loss(recon_x, x, reduction='mean')

    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    # Pairs are *selected* by target distance (y carries no gradient) ...
    distance = torch.cdist(y, y, p=2)
    # ... but *scored* by latent distance, so the triplet term actually trains
    # the encoder. The non-matmul path yields an exact 0 (and a zero gradient)
    # on the diagonal instead of a rounding residue under a sqrt.
    distance_z = torch.cdist(mu, mu, p=2, compute_mode='donot_use_mm_for_euclid_dist')

    value_sorted, _ = torch.sort(distance.reshape(-1))
    num_pairs = value_sorted.numel()
    cut = int(num_pairs * 0.3)
    # Positive pairs (a, b): the nearest ~30% by target distance.
    mask_pos = distance < value_sorted[cut]
    # Negative pairs (b, c): the farthest ~30%. max(cut, 1) avoids
    # value_sorted[-0], which would be the *smallest* distance.
    mask_neg = distance > value_sorted[num_pairs - max(cut, 1)]

    # triplet_mask[a, b, c] is True when (a, b) is positive and (b, c) negative.
    # O(batch^3) booleans; far less than matching every positive pair against
    # every negative pair.
    triplet_mask = (mask_pos.unsqueeze(2) & mask_neg.unsqueeze(0)).to(mu.device)
    if not triplet_mask.any():
        return BCE + KLD

    a, b, c = triplet_mask.nonzero(as_tuple=True)
    v_pos = distance_z[a, b]
    v_neg = distance_z[b, c]

    # softplus(t) == log(1 + exp(t)) without overflowing for large t.
    METRIC = (weight[a] * weight[b] * weight[c] * F.softplus(v_pos - v_neg)).mean()

    return BCE + KLD + METRIC

class CustomDataset(Dataset):
    """Map-style dataset pairing ``x[i]`` with ``y[i]``.

    Both sequences are stored as given (no copy); the length comes from ``x``.
    """

    def __init__(self, x, y):
        super().__init__()
        self.x, self.y = x, y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        sample = self.x[idx]
        target = self.y[idx]
        return sample, target

def vae_retrain(model, x, y, logger, args):
    """Fine-tune ``model`` on ``(x, y)`` with ``loss_function`` and Adam (lr=1e-3).

    Args:
        model: Module whose forward returns ``(recon_x, mu, logvar)``.
        x: Input tensor; detached before being wrapped in a dataset.
        y: Targets paired row-wise with ``x``; detached likewise.
        logger: Object with an ``info(str)`` method for progress messages.
        args: Provides ``vae_batch_size``, ``num_vae_epoch`` and whatever
            ``loss_function`` reads from it.

    Returns:
        The same ``model`` instance, trained in place and left in train mode.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    dataset = CustomDataset(x.detach(), y.detach())
    train_loader = DataLoader(dataset, batch_size=args.vae_batch_size, shuffle=True)
    # Guard the epoch average against an empty loader.
    num_batches = max(len(train_loader), 1)

    model.train()
    for epoch in range(args.num_vae_epoch):
        # Reset every epoch so the logged average describes this epoch only.
        train_loss = 0.0
        for batch_idx, (batch_x, batch_y) in enumerate(train_loader):
            optimizer.zero_grad()
            recon_x, mu, logvar = model(batch_x)
            loss = loss_function(recon_x, batch_x, batch_y, mu, logvar, args)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            if batch_idx % 2 == 0:
                logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(batch_x), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader),
                    loss.item() / len(batch_x)))

        # Whole-dataset reconstruction error; no_grad avoids building an
        # autograd graph that would never be used.
        model.eval()
        with torch.no_grad():
            recon_x, _, _ = model(x.detach())
            BCE = F.mse_loss(recon_x, x.detach(), reduction='mean').item()

        # Each batch loss is already (mostly) a batch mean, so average per batch.
        logger.info('====> Epoch: {} reconstruction error: {:.4f} Average loss: {:.4f}'.format(
            epoch, BCE, train_loss / num_batches))

        model.train()

    return model