import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader, Dataset, TensorDataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
import argparse
import warnings

import os
from tqdm import tqdm
import json
import time
import math

from tabsyn.tabsyn.vae.model import Model_VAE, Encoder_model, Decoder_model, gmm_prob, logsumexp, log_gmm_probs
from tabsyn.utils_train import preprocess, preprocess_from_numpy, preprocess_group_data, TabularDataset

# Silence every warning category for the whole process.
warnings.filterwarnings('ignore')


# Optimizer settings: passed to torch.optim.Adam as lr / weight_decay.
LR = 1e-3
WD = 0
# Token width handed to Model_VAE / Encoder_model / Decoder_model.
# train_vae_vade and train_vae overwrite it when args contains 'D_TOKEN'.
D_TOKEN = 4
# Not referenced in the visible code (Model_VAE is built with a literal bias = True).
TOKEN_BIAS = True

# Forwarded to the model constructors as n_head, factor and the first
# positional argument respectively.
N_HEAD = 1
FACTOR = 32
NUM_LAYERS = 2

class GMM(nn.Module):
    """Gaussian-mixture parameters loaded from ``.npy`` files as trainable tensors.

    Three arrays are read from ``path``:

    * ``means_.npy``        -> ``u_p``
    * ``covariances_.npy``  -> ``lambda_p``
    * ``weights_.npy``      -> ``theta_p`` (stored as the element-wise log)

    Args:
        latent_dim: Unused; kept so existing callers keep working.
        n_centroid: Unused; kept so existing callers keep working.
        path: Directory holding the three ``.npy`` files. When omitted the
            previously hard-coded location ``GMM.DEFAULT_PATH`` is used.

    Raises:
        FileNotFoundError: If one of the three files does not exist.
    """

    # Location that used to be hard-coded inside ``__init__``.
    DEFAULT_PATH = '/u3/w3pang/guided_tab_ddpm/tabsyn_workspace/california_20_quick_test/individual/household_individual/vae'

    def __init__(self, latent_dim, n_centroid, path=None):
        super().__init__()
        if path is None:
            path = self.DEFAULT_PATH

        means = np.load(os.path.join(path, 'means_.npy'))
        covariances = np.load(os.path.join(path, 'covariances_.npy'))
        weights = np.load(os.path.join(path, 'weights_.npy'))

        # nn.Parameter makes the tensors leaves of the autograd graph and
        # registers them with the module (state_dict, .to(), optimizers).
        self.u_p = nn.Parameter(torch.from_numpy(means))
        self.lambda_p = nn.Parameter(torch.from_numpy(covariances))
        self.theta_p = nn.Parameter(torch.log(torch.from_numpy(weights)))

    def forward(self, *args, **kwargs):
        """Do nothing and return ``None``; the module is only a parameter holder.

        The original declaration lacked ``self``, so calling the module (or
        ``forward``) raised ``TypeError`` instead of being a no-op.
        """
        return None

class GroupDataset(Dataset):
    """Dataset whose items are whole groups plus their centroid arrays.

    Each index yields a 4-tuple of tensors
    ``(num, cat, centroid_num, centroid_cat)``. Any component that is
    ``None`` for that index is replaced by an empty ``(rows, 0)`` tensor,
    where ``rows`` is the row count of the numerical group (or of the
    categorical group when the numerical one is ``None``).
    """

    def __init__(self, num_group_data, cat_group_data, centroid_num, centroid_cat):
        self.num_group_data = num_group_data
        self.cat_group_data = cat_group_data
        self.centroid_num = centroid_num
        self.centroid_cat = centroid_cat

    def __len__(self):
        # The categorical container defines the number of groups.
        return len(self.cat_group_data)

    def get_tensor(self, data, batch_size, dtype):
        """Convert ``data`` to a tensor, or build a ``(batch_size, 0)`` placeholder for ``None``."""
        if data is None:
            return torch.empty((batch_size, 0), dtype=dtype)
        return torch.tensor(data, dtype=dtype)

    def __getitem__(self, idx):
        group_num = self.num_group_data[idx]
        group_cat = self.cat_group_data[idx]

        # Row count comes from whichever of the two group arrays is present.
        reference = group_cat if group_num is None else group_num
        rows = reference.shape[0]

        components = (
            (group_num, torch.float),
            (group_cat, torch.long),
            (self.centroid_num[idx], torch.float),
            (self.centroid_cat[idx], torch.long),
        )
        return tuple(self.get_tensor(values, rows, kind) for values, kind in components)


def get_recon_loss(x, x_decoded_mean, original_dim, alpha, datatype=''):
    """Return the scaled, sum-reduced reconstruction loss.

    ``datatype == 'sigmoid'`` selects binary cross-entropy on logits; any
    other value selects mean-squared error. The summed loss is multiplied
    by ``alpha * original_dim``.
    """
    if datatype == 'sigmoid':
        criterion = F.binary_cross_entropy_with_logits
    else:
        criterion = F.mse_loss

    summed = criterion(x_decoded_mean, x, reduction='sum')
    return alpha * original_dim * summed


def vae_loss(x, z, z_mean, z_log_var, u_p, lambda_p, theta_p, n_centroid, latent_dim):
    """Mixture-prior regularisation term, averaged over the batch.

    Args:
        x: Unused; kept for interface compatibility (it used to supply only
            the batch size, which is now taken from broadcasting).
        z: Sampled latents, indexed ``[batch, latent]``.
        z_mean: Posterior means, same layout as ``z``.
        z_log_var: Posterior log-variances, same layout as ``z``.
        u_p: Component means, indexed ``[centroid, latent]``.
        lambda_p: Component variances, same layout as ``u_p``.
        theta_p: 1-D unnormalised log mixture weights (softmax is applied here).
        n_centroid: Unused; the component count is taken from the tensors.
        latent_dim: Latent width; multiplies the ``log(2*pi)`` constant.

    Returns:
        0-dim tensor: mean over the batch of the four summed terms below.

    Compared with the previous version the softmax axis is explicit (the
    implicit-``dim`` form is deprecated) and operands are broadcast instead
    of being copied with ``repeat``; the arithmetic is unchanged.
    """
    theta = F.softmax(theta_p, dim=0)
    log_two_pi = math.log(2 * math.pi)

    # Per-sample quantities as [batch, latent, 1]; per-component quantities as
    # [1, latent, centroid]. Broadcasting produces [batch, latent, centroid].
    z_col = z.unsqueeze(-1)
    z_mean_col = z_mean.unsqueeze(-1)
    z_var_col = torch.exp(z_log_var).unsqueeze(-1)
    u = u_p.transpose(0, 1).unsqueeze(0)
    lam = lambda_p.transpose(0, 1).unsqueeze(0)
    log_theta = torch.log(theta).view(1, 1, -1)

    # Responsibilities gamma[b, c]. log_theta is deliberately summed over the
    # latent axis as well, exactly as before. The 1e-10 floor keeps gamma
    # strictly positive so log(gamma) below stays finite.
    log_lambda = torch.log(2 * math.pi * lam)
    squared_diff = (z_col - u) ** 2
    p_c_z = torch.exp(
        torch.sum(log_theta - 0.5 * log_lambda - squared_diff / (2 * lam), dim=1)
    ) + 1e-10
    gamma = p_c_z / torch.sum(p_c_z, dim=-1, keepdim=True)

    first_term = 0.5 * gamma.unsqueeze(1) * (
        latent_dim * log_two_pi
        + torch.log(lam)
        + z_var_col / lam
        + (z_mean_col - u).pow(2) / lam
    )
    first_sum = torch.sum(first_term, dim=(1, 2))
    second_sum = -0.5 * torch.sum(z_log_var + 1, dim=-1)
    third_sum = -torch.sum(torch.log(theta) * gamma, dim=-1)
    fourth_sum = torch.sum(torch.log(gamma) * gamma, dim=-1)

    return (first_sum + second_sum + third_sum + fourth_sum).mean()

    


def compute_loss_recon(X_num, X_cat, Recon_X_num, Recon_X_cat, mu_z, logvar_z, use_sum=False):
    """Reconstruction losses scaled by column count, plus KL term and accuracy.

    Args:
        X_num: Numerical targets, 2-D (``shape[1]`` scales the MSE).
        X_cat: Integer categorical targets, 2-D, or ``None``.
        Recon_X_num: Reconstructed numerical values, same shape as ``X_num``.
        Recon_X_cat: Sequence of per-column logits; entry ``i`` is scored
            against ``X_cat[:, i]``. ``None`` entries are skipped.
        mu_z: Posterior means.
        logvar_z: Posterior log-variances.
        use_sum: Sum (instead of average) the squared error, and skip the
            division of the cross-entropy by the number of heads.

    Returns:
        ``(mse_loss, ce_loss, loss_kld, acc)``. ``mse_loss`` is multiplied by
        ``X_num.shape[1]`` and ``ce_loss`` by ``X_cat.shape[1]``. ``ce_loss``
        and ``acc`` are zero tensors when there are no categorical columns.

    Fixes over the previous version:
        * ``X_cat=None`` no longer crashes on ``X_cat.shape[1]``.
        * A ``None`` head no longer reuses the previous head's predictions
          (or raises on the first head) when accumulating accuracy.
    """
    ce_loss_fn = nn.CrossEntropyLoss()

    squared_error = (X_num - Recon_X_num).pow(2)
    mse_loss = squared_error.sum() if use_sum else squared_error.mean()
    mse_loss = mse_loss * X_num.shape[1]

    ce_loss = torch.tensor(0.0).to(X_num.device)
    acc = torch.tensor(0.0).to(X_num.device)

    if X_cat is not None and X_cat.shape[1] > 0:
        n_heads = 0
        n_correct = 0
        total_num = 0
        for idx, logits in enumerate(Recon_X_cat):
            n_heads = idx + 1
            if logits is None:
                continue
            target = X_cat[:, idx]
            ce_loss = ce_loss + ce_loss_fn(logits, target)
            n_correct = n_correct + (logits.argmax(dim=-1) == target).float().sum()
            total_num += target.shape[0]

        # Average over every enumerated head, as before (idx + 1).
        if not use_sum and n_heads > 0:
            ce_loss = ce_loss / n_heads
        if total_num > 0:
            acc = n_correct / total_num
        ce_loss = ce_loss * X_cat.shape[1]

    temp = 1 + logvar_z - mu_z.pow(2) - logvar_z.exp()

    # temp.mean(-1).mean() is already a scalar, so the outer sum/mean below
    # yield the same value; both branches are kept for parity with the
    # previous code.
    if use_sum:
        loss_kld = -0.5 * torch.sum(temp.mean(-1).mean())
    else:
        loss_kld = -0.5 * torch.mean(temp.mean(-1).mean())
    return mse_loss, ce_loss, loss_kld, acc


def compute_loss(X_num, X_cat, Recon_X_num, Recon_X_cat, mu_z, logvar_z, use_sum=False):
    """Reconstruction losses, KL term and categorical accuracy for a VAE batch.

    Args:
        X_num: Numerical targets.
        X_cat: Integer categorical targets, 2-D, or ``None``.
        Recon_X_num: Reconstructed numerical values, same shape as ``X_num``.
        Recon_X_cat: Sequence of per-column logits; entry ``i`` is scored
            against ``X_cat[:, i]``. ``None`` entries are skipped.
        mu_z: Posterior means.
        logvar_z: Posterior log-variances.
        use_sum: Sum (instead of average) the squared error, and skip the
            division of the cross-entropy by the number of heads.

    Returns:
        ``(mse_loss, ce_loss, loss_kld, acc)``. ``ce_loss`` and ``acc`` are
        zero tensors when there are no categorical columns.

    Fix over the previous version: the accuracy bookkeeping ran even for a
    ``None`` head, reusing the previous head's predictions or raising
    ``UnboundLocalError`` on the first head. ``None`` heads are now skipped.
    """
    ce_loss_fn = nn.CrossEntropyLoss()

    squared_error = (X_num - Recon_X_num).pow(2)
    mse_loss = squared_error.sum() if use_sum else squared_error.mean()

    ce_loss = torch.tensor(0.0).to(X_num.device)
    acc = torch.tensor(0.0).to(X_num.device)

    if X_cat is not None and X_cat.shape[1] > 0:
        n_heads = 0
        n_correct = 0
        total_num = 0
        for idx, logits in enumerate(Recon_X_cat):
            n_heads = idx + 1
            if logits is None:
                continue
            target = X_cat[:, idx]
            ce_loss = ce_loss + ce_loss_fn(logits, target)
            n_correct = n_correct + (logits.argmax(dim=-1) == target).float().sum()
            total_num += target.shape[0]

        # Average over every enumerated head, as before (idx + 1).
        if not use_sum and n_heads > 0:
            ce_loss = ce_loss / n_heads
        if total_num > 0:
            acc = n_correct / total_num

    temp = 1 + logvar_z - mu_z.pow(2) - logvar_z.exp()

    # temp.mean(-1).mean() is already a scalar, so the outer sum/mean below
    # yield the same value; both branches are kept for parity with the
    # previous code.
    if use_sum:
        loss_kld = -0.5 * torch.sum(temp.mean(-1).mean())
    else:
        loss_kld = -0.5 * torch.mean(temp.mean(-1).mean())
    return mse_loss, ce_loss, loss_kld, acc


def lossfun_original(model, x_num, recon_x_num, x_cat, recon_x_cat, mu, logvar, z):
    """Direct (non log-space) reconstruction and mixture-prior losses.

    Args:
        model: Object exposing ``mu`` and ``logvar`` (indexed
            ``[cluster, latent]``) and ``weights`` (indexed ``[cluster]``).
        x_num: Numerical targets; ``x_num.size(0)`` is the batch size.
        recon_x_num: Reconstructed numerical values.
        x_cat: Integer categorical targets, 2-D, or ``None``.
        recon_x_cat: Sequence of per-column logits; ``None`` entries are skipped.
        mu: Posterior means, indexed ``[batch, latent]``.
        logvar: Posterior log-variances, same layout as ``mu``.
        z: Sampled latents, same layout as ``mu``.

    Returns:
        ``(recon_num_loss, recon_cat_loss, vade_loss)``, each divided by the
        batch size. A NaN reconstruction loss is replaced by zero.

    Fixes over the previous version: the guard inside the cross-entropy loop
    tested ``x_cat`` (always true at that point) instead of the current head,
    so a ``None`` head crashed; and the categorical loss could remain a plain
    ``int`` on which ``.isnan()`` fails.
    """
    ce_loss_fn = nn.CrossEntropyLoss()
    batch_size = x_num.size(0)
    cluster_var = model.logvar.exp()

    # Responsibilities gamma[b, c]. The constant 1 / (2*pi) is shared by all
    # clusters and cancels in the normalisation.
    diff = z.unsqueeze(1) - model.mu
    h = torch.exp(-0.5 * torch.sum((diff * diff / cluster_var), dim=2))
    h = h / torch.sum(0.5 * model.logvar, dim=1).exp()
    p_z_given_c = h / (2 * math.pi)
    p_z_c = p_z_given_c * model.weights
    gamma = p_z_c / torch.sum(p_z_c, dim=1, keepdim=True)

    h = logvar.exp().unsqueeze(1) + (mu.unsqueeze(1) - model.mu).pow(2)
    h = torch.sum(model.logvar + h / cluster_var, dim=2)
    recon_num_loss = F.mse_loss(recon_x_num, x_num, reduction='sum')

    recon_cat_loss = torch.tensor(0.0).to(x_num.device)
    if x_cat is not None and x_cat.shape[1] > 0:
        for idx, sub_x_cat in enumerate(recon_x_cat):
            if sub_x_cat is not None:
                recon_cat_loss = recon_cat_loss + ce_loss_fn(sub_x_cat, x_cat[:, idx])

    # Deliberate best-effort: a NaN reconstruction term is dropped.
    if recon_num_loss.isnan():
        recon_num_loss = torch.tensor(0.0).to(x_num.device)

    if recon_cat_loss.isnan():
        recon_cat_loss = torch.tensor(0.0).to(x_num.device)

    vade_loss = 0.5 * torch.sum(gamma * h) \
        - torch.sum(gamma * torch.log(model.weights + 1e-9)) \
        + torch.sum(gamma * torch.log(gamma + 1e-9)) \
        - 0.5 * torch.sum(1 + logvar)
    recon_num_loss = recon_num_loss / batch_size
    recon_cat_loss = recon_cat_loss / batch_size
    vade_loss = vade_loss / batch_size
    return recon_num_loss, recon_cat_loss, vade_loss


def lossfun(model, x_num, recon_x_num, x_cat, recon_x_cat, mu, logvar, z):
    """Log-space version of ``lossfun_original``'s losses.

    Args:
        model: Object exposing ``mu`` and ``logvar`` (indexed
            ``[cluster, latent]``) and ``weights`` (indexed ``[cluster]``).
        x_num: Numerical targets; ``x_num.size(0)`` is the batch size.
        recon_x_num: Reconstructed numerical values.
        x_cat: Integer categorical targets, 2-D, or ``None``.
        recon_x_cat: Sequence of per-column logits; ``None`` entries are skipped.
        mu: Posterior means, indexed ``[batch, latent]``.
        logvar: Posterior log-variances, same layout as ``mu``.
        z: Sampled latents, same layout as ``mu``.

    Returns:
        ``(recon_num_loss, recon_cat_loss, vade_loss)``, each divided by the
        batch size. A NaN reconstruction loss is replaced by zero.

    Fixes over the previous version, which had drifted from the direct
    formula it was meant to stabilise:
        * The responsibilities omitted the per-cluster ``-0.5 * sum(logvar_c)``
          term. The ``logsumexp`` used as a "normaliser" was constant across
          clusters and cancelled, so clusters with different variances were
          weighted incorrectly.
        * The second-moment term used ``(z - mu_c)^2`` on top of
          ``exp(logvar)``, double-counting the posterior variance; it now
          uses the posterior mean ``mu``.
        * The guard in the cross-entropy loop tested ``x_cat`` instead of the
          current head.
    """
    ce_loss_fn = nn.CrossEntropyLoss()
    batch_size = x_num.size(0)
    cluster_var = model.logvar.exp()

    # log p(z | c) up to the shared constant log(2*pi): Mahalanobis part plus
    # the log-determinant of each cluster's diagonal covariance.
    z_diff_sq = (z.unsqueeze(1) - model.mu).pow(2)
    log_p_z_given_c = (
        -0.5 * (z_diff_sq / cluster_var).sum(dim=2)
        - 0.5 * model.logvar.sum(dim=1)
        - math.log(2 * math.pi)
    )
    log_p_z_c = log_p_z_given_c + torch.log(model.weights)
    # Normalising in log space avoids the underflow of exp() on large
    # negative exponents.
    log_gamma = log_p_z_c - torch.logsumexp(log_p_z_c, dim=1, keepdim=True)
    gamma = log_gamma.exp()

    # Expected second moment of the posterior about each cluster mean.
    second_moment = logvar.exp().unsqueeze(1) + (mu.unsqueeze(1) - model.mu).pow(2)
    h = torch.sum(model.logvar.unsqueeze(0) + second_moment / cluster_var.unsqueeze(0), dim=2)

    recon_num_loss = F.mse_loss(recon_x_num, x_num, reduction='sum')

    recon_cat_loss = torch.tensor(0.0).to(x_num.device)
    if x_cat is not None and x_cat.shape[1] > 0:
        for idx, sub_x_cat in enumerate(recon_x_cat):
            if sub_x_cat is not None:
                recon_cat_loss = recon_cat_loss + ce_loss_fn(sub_x_cat, x_cat[:, idx])

    # Deliberate best-effort: a NaN reconstruction term is dropped.
    if recon_num_loss.isnan():
        recon_num_loss = torch.tensor(0.0).to(x_num.device)

    if recon_cat_loss.isnan():
        recon_cat_loss = torch.tensor(0.0).to(x_num.device)

    # log(gamma + 1e-9) rather than log_gamma: a responsibility that
    # underflows to 0 would otherwise give 0 * -inf = NaN.
    vade_loss = 0.5 * torch.sum(gamma * h) \
        - torch.sum(gamma * torch.log(model.weights + 1e-9)) \
        + torch.sum(gamma * torch.log(gamma + 1e-9)) \
        - 0.5 * torch.sum(1 + logvar)
    recon_num_loss = recon_num_loss / batch_size
    recon_cat_loss = recon_cat_loss / batch_size
    vade_loss = vade_loss / batch_size
    return recon_num_loss, recon_cat_loss, vade_loss


def predict_clusters(model, dataloader, device):
    """Run ``model.classify`` over every batch and stack the results.

    The model is switched to evaluation mode (and left there). Each batch
    is a ``(numerical, categorical)`` pair that is moved to ``device``
    before classification. Returns one NumPy array, concatenated along
    axis 0 in iteration order.
    """
    model.eval()

    # Inference only, so gradient tracking is disabled.
    with torch.no_grad():
        per_batch = [
            model.classify(nums.to(device), cats.to(device)).cpu().numpy()
            for nums, cats in dataloader
        ]

    return np.concatenate(per_batch, axis=0)


def vade_loss(log_probs, gmm_pi):
    """Ten times the batch-mean entropy of the cluster posteriors.

    ``log_probs`` holds one row per sample and one column per cluster;
    ``gmm_pi`` holds the mixture weights. The posterior of each sample is
    obtained by adding ``log(gmm_pi)`` and normalising with the imported
    ``logsumexp`` helper over dim 1 (presumably a log-sum-exp reduction
    that drops that dim; confirm against ``tabsyn.tabsyn.vae.model``).
    """
    joint = log_probs + torch.log(gmm_pi).unsqueeze(0)
    log_evidence = logsumexp(joint, dim=1)
    log_resp = joint - log_evidence.unsqueeze(1)

    # Entropy per sample: -sum_c q(c) * log q(c).
    entropy = -(torch.exp(log_resp) * log_resp).sum(dim=1)
    return entropy.mean() * 10

def train_vae_vade(args):
    """Train a ``Model_VAE`` with the losses from ``lossfun`` and return cluster labels.

    ``args`` is a dict. Keys read here: ``num_group_data``, ``cat_group_data``,
    ``vae_batch_size``, ``num_clusters``, ``max_beta``, ``min_beta``,
    ``lambd``, ``device``, ``ckpt_dir``, ``vae_epochs`` and, optionally,
    ``D_TOKEN`` (which overwrites the module-level constant).

    Side effects: creates ``ckpt_dir`` if missing; writes ``model.pt`` there
    whenever the epoch loss improves, and ``encoder.pt`` / ``decoder.pt``
    after training.

    Returns the concatenated output of ``model.classify`` over the whole
    training set, in dataset order (``predict_clusters`` on an unshuffled
    loader).
    """
    num_group_data = args['num_group_data']
    cat_group_data = args['cat_group_data']
    batch_size = args['vae_batch_size']

    transformed_num_group_data, transformed_cat_group_data, d_numerical, categories = preprocess_group_data(
        num_group_data, 
        cat_group_data,
    )

    # Flatten the per-group arrays into one array of rows each.
    transformed_num_group_data = np.concatenate(transformed_num_group_data, axis=0)
    transformed_cat_group_data = np.concatenate(transformed_cat_group_data, axis=0)

    if categories is None:
        categories = []

    num_clusters = args['num_clusters']

    if 'D_TOKEN' in args:
        global D_TOKEN
        D_TOKEN = args['D_TOKEN']

    max_beta = args['max_beta']
    min_beta = args['min_beta']
    lambd = args['lambd']

    device =  args['device']
    ckpt_dir = args['ckpt_dir']

    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)

    model_save_path = f'{ckpt_dir}/model.pt'
    encoder_save_path = f'{ckpt_dir}/encoder.pt'
    decoder_save_path = f'{ckpt_dir}/decoder.pt'

    transformed_num_group_tensor = torch.tensor(transformed_num_group_data, dtype=torch.float)
    transformed_cat_group_tensor = torch.tensor(transformed_cat_group_data, dtype=torch.long)

    train_data = TensorDataset(
        transformed_num_group_tensor,
        transformed_cat_group_tensor,
    )

    # Shuffled loader for optimisation; unshuffled loader so that predicted
    # cluster labels line up with dataset order.
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    loader = DataLoader(
        train_data,
        batch_size = batch_size,
        shuffle = False,
        num_workers = 4,
    )

    model = Model_VAE(
        NUM_LAYERS, 
        d_numerical, 
        categories, 
        D_TOKEN, 
        n_head = N_HEAD, 
        factor = FACTOR, 
        bias = True,
        num_clusters=num_clusters,
    )
    model = model.to(device)

    # NOTE(review): latent_dim is computed but not used in this function.
    latent_dim = D_TOKEN * (d_numerical + len(categories))
    # Stand-alone encoder/decoder; they receive the trained weights via
    # load_weights(model) after the loop and are then saved.
    pre_encoder = Encoder_model(NUM_LAYERS, d_numerical, categories, D_TOKEN, n_head = N_HEAD, factor = FACTOR).to(device)
    pre_decoder = Decoder_model(NUM_LAYERS, d_numerical, categories, D_TOKEN, n_head = N_HEAD, factor = FACTOR).to(device)

    pre_encoder.eval()
    pre_decoder.eval()

    optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WD)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.95, patience=10, verbose=True)

    num_epochs = args['vae_epochs']
    best_train_loss = float('inf')

    current_lr = optimizer.param_groups[0]['lr']
    # NOTE(review): this overwrites the value read on the previous line, so
    # with LR = 1e-3 the first epoch always prints "Learning rate updated".
    # Confirm whether that is intended.
    current_lr = 1e-4
    patience = 0

    # NOTE(review): beta is decayed and printed below but never enters the
    # loss in this function.
    beta = max_beta
    start_time = time.time()
    for epoch in range(num_epochs):
        pbar = tqdm(train_loader, total=len(train_loader))
        pbar.set_description(f"Epoch {epoch+1}/{num_epochs}")

        # Sample-weighted running sums of the three loss components.
        curr_loss_multi = 0.0
        curr_loss_gauss = 0.0
        curr_loss_vade = 0.0

        curr_count = 0

        for batch_num, batch_cat in pbar:
            model.train()
            optimizer.zero_grad()

            batch_num = batch_num.to(device)
            batch_cat = batch_cat.to(device)

            recon_x_num, recon_x_cat, mu_z, logvar_z, z = model.vade_forward(batch_num, batch_cat)

            loss_mse, loss_ce, loss_vade = lossfun(
                model, 
                batch_num, 
                recon_x_num, 
                batch_cat, 
                recon_x_cat, 
                mu_z, 
                logvar_z, 
                z
            )

            loss = loss_mse + loss_ce + loss_vade

            loss.backward()
            # for name, param in model.named_parameters():
            #     if param.grad is not None and torch.any(torch.isnan(param.grad)):
            #         print(f"NaN gradient in {name}")
            # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            batch_length = batch_num.shape[0]
            curr_count += batch_length
            curr_loss_multi += loss_ce.item() * batch_length
            curr_loss_gauss += loss_mse.item() * batch_length
            curr_loss_vade += loss_vade.item() * batch_length

        # Per-sample epoch averages.
        num_loss = curr_loss_gauss / curr_count
        cat_loss = curr_loss_multi / curr_count
        cluster_loss = curr_loss_vade / curr_count
        
        train_loss = num_loss + cat_loss + cluster_loss
        # First scheduler step of the epoch (epoch-average loss); a second
        # one follows in the evaluation section below.
        scheduler.step(train_loss)

        new_lr = optimizer.param_groups[0]['lr']

        if new_lr != current_lr:
            current_lr = new_lr
            print(f"Learning rate updated: {current_lr}")
            
        # Checkpoint on improvement; otherwise count stale epochs.
        if train_loss < best_train_loss:
            best_train_loss = train_loss
            patience = 0
            torch.save(model.state_dict(), model_save_path)
        else:
            patience += 1
            # beta is decayed once, exactly when the stale count hits 10.
            if patience == 10:
                if beta > min_beta:
                    beta = beta * lambd

            if patience == 500:
                print('Early stopping')
                break

        '''
            Evaluation
        '''
        model.eval()
        with torch.no_grad():
            # Re-scores only the LAST training batch of the epoch (batch_num /
            # batch_cat still hold it), this time in eval mode.
            recon_x_num, recon_x_cat, mu_z, logvar_z, z = model.vade_forward(batch_num, batch_cat)

            train_loss_mse, train_loss_ce, train_loss_vade = lossfun(
                model, 
                batch_num, 
                recon_x_num, 
                batch_cat, 
                recon_x_cat, 
                mu_z, 
                logvar_z, 
                z
            )
            
            train_loss = train_loss_mse.item() + train_loss_ce.item() + train_loss_vade.item()

            # NOTE(review): second scheduler.step() in the same epoch, fed
            # with the last-batch loss; the plateau counter therefore
            # advances twice per epoch. Confirm this is intended.
            scheduler.step(train_loss)

            # Diagnostic: how many distinct clusters are currently in use.
            cluster_assignments = predict_clusters(model, loader, device)
            print('num unique clusters: ', len(np.unique(cluster_assignments)))
            print()

        print(
            'epoch: {}, beta = {:.6f}, Train MSE: {:.6f}, Train CE:{:.6f}, Train Cluster:{:.6f}'.format(
                epoch, beta, num_loss, cat_loss, cluster_loss
            )
        )

    end_time = time.time()
    print('Training time: {:.4f} mins'.format((end_time - start_time)/60))

    # Saving latent embeddings
    # NOTE(review): the weights copied here are those of the final epoch, not
    # the best checkpoint stored in model.pt.
    pre_encoder.load_weights(model)
    pre_decoder.load_weights(model)

    torch.save(pre_encoder.state_dict(), encoder_save_path)
    torch.save(pre_decoder.state_dict(), decoder_save_path)

    loader = DataLoader(
        train_data,
        batch_size = batch_size,
        shuffle = False,
        num_workers = 4,
    )
    cluster_assignments = predict_clusters(model, loader, device)

    return cluster_assignments


def train_cluster_vae(args):
    """Train a ``Model_VAE`` with ``compress_dim=1`` and return its compressed latents.

    Every row is paired with the centroid of its group (mean of the
    numerical columns, rounded mean of the categorical columns). Besides the
    usual VAE losses, the squared distance between the compressed latent of
    a row and that of its centroid is minimised.

    ``args`` is a dict. Keys read here: ``read_ckpt``, ``ckpt_dir``,
    ``num_group_data``, ``cat_group_data``, ``max_beta``, ``min_beta``,
    ``lambd``, ``device`` and ``vae_epochs``.

    Returns:
        NumPy array: the last output of ``model(num, cat, True)`` evaluated on
        the full data set. If ``read_ckpt`` is truthy and
        ``<ckpt_dir>/1d_train_h.npy`` loads successfully, that array is
        returned immediately and nothing is trained.

    Side effects: creates ``ckpt_dir``; writes ``model.pt`` (on improvement),
    ``encoder.pt``, ``decoder.pt`` and ``1d_train_h.npy`` there.

    Changes from the previous version: the bare ``except:`` around the cache
    load (which also swallowed ``KeyboardInterrupt``) is narrowed to
    ``Exception`` and reports the reason; the directory is created with
    ``exist_ok=True``; a redundant nested ``no_grad`` was removed.
    """
    if args['read_ckpt']:
        latents_path = f'{args["ckpt_dir"]}/1d_train_h.npy'
        if os.path.exists(latents_path):
            try:
                train_h = np.load(latents_path)
            except Exception as err:
                # Best effort: an unreadable cache falls back to training.
                print('Pretrained latents loading failed, train model')
                print(f'Reason: {err}')
            else:
                print('Pretrained latents exist, skip training')
                return train_h

    num_group_data = args['num_group_data']
    cat_group_data = args['cat_group_data']

    transformed_num_group_data, transformed_cat_group_data, d_numerical, categories = preprocess_group_data(
        num_group_data,
        cat_group_data,
    )

    if categories is None:
        categories = []

    # Build, for every row, a copy of its group's centroid.
    centroid_nums = []
    centroid_cats = []
    for batch_num, batch_cat in zip(transformed_num_group_data, transformed_cat_group_data):
        if batch_num is not None:
            batch_size = batch_num.shape[0]
        else:
            batch_size = batch_cat.shape[0]

        if batch_num is not None:
            centroid_num = np.mean(batch_num, axis=0)
            centroid_num = [centroid_num] * len(batch_num)
            centroid_num = np.array(centroid_num)
        else:
            centroid_num = np.empty((batch_size, 0))

        if batch_cat is not None:
            # Rounded mean of the integer codes serves as the categorical centroid.
            centroid_cat = np.round(np.mean(batch_cat.astype(float), axis=0))
            centroid_cat = [centroid_cat] * len(batch_cat)
            centroid_cat = np.array(centroid_cat).astype(int)
        else:
            centroid_cat = np.empty((batch_size, 0))
        centroid_nums.append(centroid_num)
        centroid_cats.append(centroid_cat)

    transformed_num_group_data = np.concatenate(transformed_num_group_data, axis=0)
    transformed_cat_group_data = np.concatenate(transformed_cat_group_data, axis=0)
    centroid_nums = np.concatenate(centroid_nums, axis=0)
    centroid_cats = np.concatenate(centroid_cats, axis=0)

    transformed_num_group_tensor = torch.tensor(transformed_num_group_data, dtype=torch.float)
    transformed_cat_group_tensor = torch.tensor(transformed_cat_group_data, dtype=torch.long)

    train_data = TensorDataset(
        transformed_num_group_tensor,
        transformed_cat_group_tensor,
        torch.tensor(centroid_nums, dtype=torch.float),
        torch.tensor(centroid_cats, dtype=torch.long)
    )

    train_loader = DataLoader(train_data, batch_size=16000, shuffle=True)

    max_beta = args['max_beta']
    min_beta = args['min_beta']
    lambd = args['lambd']

    device = args['device']
    ckpt_dir = args['ckpt_dir']

    os.makedirs(ckpt_dir, exist_ok=True)

    model_save_path = f'{ckpt_dir}/model.pt'
    encoder_save_path = f'{ckpt_dir}/encoder.pt'
    decoder_save_path = f'{ckpt_dir}/decoder.pt'

    model = Model_VAE(
        NUM_LAYERS,
        d_numerical,
        categories,
        D_TOKEN,
        n_head = N_HEAD,
        factor = FACTOR,
        bias = True,
        compress_dim=1
    )
    model = model.to(device)

    # Stand-alone encoder/decoder that receive the trained weights afterwards.
    pre_encoder = Encoder_model(NUM_LAYERS, d_numerical, categories, D_TOKEN, n_head = N_HEAD, factor = FACTOR).to(device)
    pre_decoder = Decoder_model(NUM_LAYERS, d_numerical, categories, D_TOKEN, n_head = N_HEAD, factor = FACTOR).to(device)

    pre_encoder.eval()
    pre_decoder.eval()

    optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WD)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.95, patience=10, verbose=True)

    num_epochs = args['vae_epochs']
    best_train_loss = float('inf')

    current_lr = optimizer.param_groups[0]['lr']
    patience = 0

    beta = max_beta
    start_time = time.time()
    for epoch in range(num_epochs):
        pbar = tqdm(train_loader, total=len(train_loader))
        pbar.set_description(f"Epoch {epoch+1}/{num_epochs}")

        # Sample-weighted running sums of the loss components.
        curr_loss_multi = 0.0
        curr_loss_gauss = 0.0
        curr_loss_kl = 0.0

        curr_count = 0

        for batch_num, batch_cat, centroid_num, centroid_cat in pbar:
            model.train()
            optimizer.zero_grad()

            batch_num = batch_num.to(device)
            batch_cat = batch_cat.to(device)
            centroid_num = centroid_num.to(device)
            centroid_cat = centroid_cat.to(device)

            Recon_X_num, Recon_X_cat, mu_z, std_z, h_compressed = model(batch_num, batch_cat, True)

            loss_mse, loss_ce, loss_kld, train_acc = compute_loss(batch_num, batch_cat, Recon_X_num, Recon_X_cat, mu_z, std_z)

            _, _, _, _, h_compressed_centroid = model(centroid_num, centroid_cat, True)

            # Pull each row's compressed latent towards that of its group centroid.
            centroid_loss = torch.mean((h_compressed - h_compressed_centroid).pow(2))

            if loss_mse.isnan():
                loss_mse = torch.tensor(0.0).to(device)
            loss = loss_mse + loss_ce + beta * loss_kld + centroid_loss
            loss.backward()
            optimizer.step()

            batch_length = batch_num.shape[0]
            curr_count += batch_length
            curr_loss_multi += loss_ce.item() * batch_length
            curr_loss_gauss += loss_mse.item() * batch_length
            curr_loss_kl    += loss_kld.item() * batch_length

        num_loss = curr_loss_gauss / curr_count
        cat_loss = curr_loss_multi / curr_count
        kl_loss = curr_loss_kl / curr_count

        # Model selection and the first scheduler step use reconstruction loss only.
        train_loss = num_loss + cat_loss
        scheduler.step(train_loss)

        new_lr = optimizer.param_groups[0]['lr']

        if new_lr != current_lr:
            current_lr = new_lr
            print(f"Learning rate updated: {current_lr}")

        if train_loss < best_train_loss:
            best_train_loss = train_loss
            patience = 0
            torch.save(model.state_dict(), model_save_path)
        else:
            patience += 1
            # beta is decayed once, exactly when the stale count hits 10.
            if patience == 10:
                if beta > min_beta:
                    beta = beta * lambd

            if patience == 500:
                print('Early stopping')
                break

        model.eval()
        with torch.no_grad():
            # Re-scores only the last training batch of the epoch, in eval mode.
            Recon_X_num, Recon_X_cat, mu_z, std_z, h_compressed = model(batch_num, batch_cat, True)

            train_mse_loss, train_ce_loss, train_kl_loss, train_acc = compute_loss(batch_num, batch_cat, Recon_X_num, Recon_X_cat, mu_z, std_z)

            _, _, _, _, h_compressed_centroid = model(centroid_num, centroid_cat, True)

            centroid_loss = torch.mean((h_compressed - h_compressed_centroid).pow(2))

            if train_mse_loss.isnan():
                train_loss = train_ce_loss.item() + centroid_loss.item()
            else:
                train_loss = train_mse_loss.item() + train_ce_loss.item() + centroid_loss.item()

            # NOTE(review): second scheduler.step() of the epoch, kept as in
            # the previous version; the plateau counter advances twice per
            # epoch. Confirm this is intended.
            scheduler.step(train_loss)

        print('epoch: {}, beta = {:.6f}, Train MSE: {:.6f}, Train CE:{:.6f}, Train KL:{:.6f}, Train Centroid MSE: {:.6f}'.format(
            epoch, beta, num_loss, cat_loss, kl_loss, centroid_loss))

    end_time = time.time()
    print('Training time: {:.4f} mins'.format((end_time - start_time)/60))

    # Saving latent embeddings
    with torch.no_grad():
        pre_encoder.load_weights(model)
        pre_decoder.load_weights(model)

        torch.save(pre_encoder.state_dict(), encoder_save_path)
        torch.save(pre_decoder.state_dict(), decoder_save_path)

        transformed_num_group_tensor = transformed_num_group_tensor.to(device)
        transformed_cat_group_tensor = transformed_cat_group_tensor.to(device)

        print('Successfully load and save the model!')
        model.eval()
        # Whole data set in a single forward pass; the last output is the
        # compressed latent.
        train_h = model(
            transformed_num_group_tensor,
            transformed_cat_group_tensor,
            True
        )[-1].detach().cpu().numpy()

        np.save(f'{ckpt_dir}/1d_train_h.npy', train_h)

        print('Successfully save pretrained embeddings in disk!')

    return train_h


def train_vae(args):
    """Train (or reload) a ``Model_VAE`` and save encoder embeddings of the training set.

    ``args`` is a dict. Keys read here: ``data``, ``info`` (with
    ``task_type`` and ``n_classes``), ``has_y``, ``has_test``, ``max_beta``,
    ``min_beta``, ``lambd``, ``device``, ``ckpt_dir``, ``vae_batch_size``,
    ``vae_epochs``, ``read_ckpt`` and, optionally, ``D_TOKEN`` (which
    overwrites the module-level constant).

    Behaviour when ``read_ckpt`` is truthy:
        * if ``<ckpt_dir>/train_z.npy`` exists, return immediately;
        * else, if ``model.pt``, ``encoder.pt`` and ``decoder.pt`` all exist
          and load, skip training and only recompute the embeddings;
        * otherwise train from scratch.

    Side effects: creates ``ckpt_dir``; writes ``model.pt`` (on improvement),
    ``encoder.pt`` / ``decoder.pt`` (after training) and ``train_z.npy``.

    Returns ``None``.

    Fix over the previous version: checkpoints were loaded with
    ``module.load_state_dict(...).to(device)``. ``load_state_dict`` returns a
    report of missing/unexpected keys, not the module, so ``.to`` always
    raised. The bare ``except`` swallowed the error, so existing checkpoints
    were never reused and training always ran again.
    """
    data = args['data']

    if 'D_TOKEN' in args:
        global D_TOKEN
        D_TOKEN = args['D_TOKEN']

    max_beta = args['max_beta']
    min_beta = args['min_beta']
    lambd = args['lambd']

    device = args['device']
    info = args['info']
    ckpt_dir = args['ckpt_dir']

    os.makedirs(ckpt_dir, exist_ok=True)

    model_save_path = f'{ckpt_dir}/model.pt'
    encoder_save_path = f'{ckpt_dir}/encoder.pt'
    decoder_save_path = f'{ckpt_dir}/decoder.pt'

    X_num, X_cat, categories, d_numerical = preprocess_from_numpy(
        data,
        task_type=info['task_type'],
        concat=args['has_y'],
        n_classes=info['n_classes'],
        has_test=args['has_test']
    )

    if args['has_test']:
        # With a test split, both outputs are (train, test) pairs.
        X_train_num, X_test_num = X_num
        X_train_cat, X_test_cat = X_cat

        X_train_num, X_test_num = torch.tensor(X_train_num).float(), torch.tensor(X_test_num).float()
        X_train_cat, X_test_cat = torch.tensor(X_train_cat), torch.tensor(X_test_cat)
    else:
        if X_num is None:
            # No numerical columns: use a zero-width placeholder.
            X_num = np.empty((X_cat.shape[0], 0))
        X_train_num = torch.tensor(X_num).float()
        X_train_cat = torch.tensor(X_cat)

    train_data = TabularDataset(X_train_num.float(), X_train_cat)

    if args['has_test']:
        X_test_num = X_test_num.float().to(device)
        X_test_cat = X_test_cat.to(device)

    batch_size = args['vae_batch_size']
    train_loader = DataLoader(
        train_data,
        batch_size = batch_size,
        shuffle = True,
        num_workers = 4,
    )

    model = Model_VAE(
        NUM_LAYERS,
        d_numerical,
        categories,
        D_TOKEN,
        n_head = N_HEAD,
        factor = FACTOR,
        bias = True
    )
    model = model.to(device)

    # Stand-alone encoder/decoder that receive the trained weights afterwards.
    pre_encoder = Encoder_model(NUM_LAYERS, d_numerical, categories, D_TOKEN, n_head = N_HEAD, factor = FACTOR).to(device)
    pre_decoder = Decoder_model(NUM_LAYERS, d_numerical, categories, D_TOKEN, n_head = N_HEAD, factor = FACTOR).to(device)

    pre_encoder.eval()
    pre_decoder.eval()

    optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WD)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.95, patience=10, verbose=True)

    num_epochs = args['vae_epochs']
    best_train_loss = float('inf')

    current_lr = optimizer.param_groups[0]['lr']
    patience = 0

    do_train = True

    if args['read_ckpt']:
        if os.path.exists(f'{ckpt_dir}/train_z.npy'):
            print('Pretrained embeddings exist, skip training')
            return
        ckpt_exist = os.path.exists(model_save_path) and \
            os.path.exists(encoder_save_path) and \
            os.path.exists(decoder_save_path)

        if ckpt_exist:
            try:
                # The modules already live on `device`; map_location places
                # the loaded tensors there too. load_state_dict() must not
                # be chained with .to(): it returns a key report, not the
                # module.
                model.load_state_dict(torch.load(model_save_path, map_location=device))
                pre_encoder.load_state_dict(torch.load(encoder_save_path, map_location=device))
                pre_decoder.load_state_dict(torch.load(decoder_save_path, map_location=device))
            except Exception as err:
                # Best effort: an unusable checkpoint falls back to training.
                do_train = True
                print('Model loading failed, train model')
                print(f'Reason: {err}')
            else:
                do_train = False
                print('Model loaded from', model_save_path)

    if do_train:

        beta = max_beta
        start_time = time.time()
        for epoch in range(num_epochs):
            pbar = tqdm(train_loader, total=len(train_loader))
            pbar.set_description(f"Epoch {epoch+1}/{num_epochs}")

            # Sample-weighted running sums of the loss components.
            curr_loss_multi = 0.0
            curr_loss_gauss = 0.0
            curr_loss_kl = 0.0

            curr_count = 0

            for batch_num, batch_cat in pbar:
                model.train()
                optimizer.zero_grad()

                batch_num = batch_num.to(device)
                batch_cat = batch_cat.to(device)

                Recon_X_num, Recon_X_cat, mu_z, std_z = model(batch_num, batch_cat)

                loss_mse, loss_ce, loss_kld, train_acc = compute_loss(batch_num, batch_cat, Recon_X_num, Recon_X_cat, mu_z, std_z)

                # An empty numerical block yields a NaN mean; drop that term.
                if loss_mse.isnan():
                    loss_mse = torch.tensor(0.0).to(device)
                loss = loss_mse + loss_ce + beta * loss_kld
                loss.backward()
                optimizer.step()

                batch_length = batch_num.shape[0]
                curr_count += batch_length
                curr_loss_multi += loss_ce.item() * batch_length
                curr_loss_gauss += loss_mse.item() * batch_length
                curr_loss_kl    += loss_kld.item() * batch_length

            num_loss = curr_loss_gauss / curr_count
            cat_loss = curr_loss_multi / curr_count
            kl_loss = curr_loss_kl / curr_count

            # Model selection and the first scheduler step use reconstruction loss only.
            train_loss = num_loss + cat_loss
            scheduler.step(train_loss)

            new_lr = optimizer.param_groups[0]['lr']

            if new_lr != current_lr:
                current_lr = new_lr
                print(f"Learning rate updated: {current_lr}")

            if train_loss < best_train_loss:
                best_train_loss = train_loss
                patience = 0
                torch.save(model.state_dict(), model_save_path)
            else:
                patience += 1
                # beta is decayed once, exactly when the stale count hits 10.
                if patience == 10:
                    if beta > min_beta:
                        beta = beta * lambd

                if patience == 500:
                    print('Early stopping')
                    break

            # Evaluation. NOTE(review): this issues a second scheduler.step()
            # in the same epoch, kept as in the previous version; confirm
            # this is intended.
            model.eval()
            with torch.no_grad():
                if args['has_test']:
                    Recon_X_num, Recon_X_cat, mu_z, std_z = model(X_test_num, X_test_cat)

                    val_mse_loss, val_ce_loss, val_kl_loss, val_acc = compute_loss(X_test_num, X_test_cat, Recon_X_num, Recon_X_cat, mu_z, std_z)
                    # The MSE term is multiplied by 0: only CE drives the scheduler here.
                    val_loss = val_mse_loss.item() * 0 + val_ce_loss.item()

                    scheduler.step(val_loss)
                else:
                    # Without a test split, the last training batch is re-scored in eval mode.
                    Recon_X_num, Recon_X_cat, mu_z, std_z = model(batch_num, batch_cat)

                    train_mse_loss, train_ce_loss, train_kl_loss, train_acc = compute_loss(batch_num, batch_cat, Recon_X_num, Recon_X_cat, mu_z, std_z)
                    if train_mse_loss.isnan():
                        train_loss = train_ce_loss.item()
                    else:
                        train_loss = train_mse_loss.item() + train_ce_loss.item()

                    scheduler.step(train_loss)

            if args['has_test']:
                print('epoch: {}, beta = {:.6f}, Train MSE: {:.6f}, Train CE:{:.6f}, Train KL:{:.6f}, Val MSE:{:.6f}, Val CE:{:.6f}, Train ACC:{:6f}, Val ACC:{:6f}'.format(epoch, beta, num_loss, cat_loss, kl_loss, val_mse_loss.item(), val_ce_loss.item(), train_acc.item(), val_acc.item()))
            else:
                print('epoch: {}, beta = {:.6f}, Train MSE: {:.6f}, Train CE:{:.6f}, Train KL:{:.6f}'.format(epoch, beta, num_loss, cat_loss, kl_loss))

        end_time = time.time()
        print('Training time: {:.4f} mins'.format((end_time - start_time)/60))

    # Saving latent embeddings
    with torch.no_grad():
        pre_encoder.load_weights(model)
        pre_decoder.load_weights(model)

        if do_train:
            torch.save(pre_encoder.state_dict(), encoder_save_path)
            torch.save(pre_decoder.state_dict(), decoder_save_path)

        # Unshuffled so that row i of train_z corresponds to row i of the data.
        loader = DataLoader(
            train_data,
            batch_size = batch_size,
            shuffle = False,
            num_workers = 4,
        )

        print('Successfully load and save the model!')

        embeddings_list = []

        for X_num_batch, X_cat_batch in loader:
            X_num_batch = X_num_batch.to(device)
            X_cat_batch = X_cat_batch.to(device)

            batch_embeddings = pre_encoder(X_num_batch, X_cat_batch).detach().cpu().numpy()
            embeddings_list.append(batch_embeddings)

        train_z = np.concatenate(embeddings_list, axis=0)

        np.save(f'{ckpt_dir}/train_z.npy', train_z)

        print('Successfully save pretrained embeddings in disk!')


def main(args):
    """Train the tabular VAE on ``data/<dataname>`` and export its latents.

    Attributes read from ``args``:
        dataname: dataset folder name under ``data/``.
        max_beta: initial weight of the KL term in the loss.
        min_beta: beta is no longer decayed once it is at or below this value.
        lambd: multiplicative decay applied to beta when training stalls.
        device: torch device string used for the model and the batches.
        has_y: forwarded to ``preprocess`` as ``concat``.
        vae_batch_size: mini-batch size for training and for latent export.
        vae_epochs: number of training epochs.

    Files written to ``<this file's dir>/ckpt/<dataname>``:
        model.pt: full VAE weights at the lowest train loss seen so far.
        encoder.pt / decoder.pt: weights copied from the in-memory model as it
            is after the final epoch (not reloaded from ``model.pt``).
        train_z.npy: encoder output for the whole training split.
    """
    dataname = args.dataname
    data_dir = f'data/{dataname}'

    max_beta = args.max_beta
    min_beta = args.min_beta
    lambd = args.lambd

    device = args.device

    info_path = f'{data_dir}/info.json'
    with open(info_path, 'r') as f:
        info = json.load(f)

    curr_dir = os.path.dirname(os.path.abspath(__file__))
    ckpt_dir = f'{curr_dir}/ckpt/{dataname}'
    os.makedirs(ckpt_dir, exist_ok=True)

    model_save_path = f'{ckpt_dir}/model.pt'
    encoder_save_path = f'{ckpt_dir}/encoder.pt'
    decoder_save_path = f'{ckpt_dir}/decoder.pt'

    X_num, X_cat, categories, d_numerical = preprocess(
        data_dir,
        task_type=info['task_type'],
        concat=args.has_y,
    )

    X_train_num, X_test_num = X_num
    X_train_cat, X_test_cat = X_cat

    X_train_num = torch.tensor(X_train_num).float()
    X_test_num = torch.tensor(X_test_num).float()
    X_train_cat = torch.tensor(X_train_cat)
    X_test_cat = torch.tensor(X_test_cat)

    train_data = TabularDataset(X_train_num, X_train_cat)

    # The test split is evaluated in a single forward pass every epoch, so it
    # is moved to the device once up front.
    X_test_num = X_test_num.to(device)
    X_test_cat = X_test_cat.to(device)

    batch_size = args.vae_batch_size
    train_loader = DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
    )

    model = Model_VAE(NUM_LAYERS, d_numerical, categories, D_TOKEN, n_head=N_HEAD, factor=FACTOR, bias=True)
    model = model.to(device)

    pre_encoder = Encoder_model(NUM_LAYERS, d_numerical, categories, D_TOKEN, n_head=N_HEAD, factor=FACTOR).to(device)
    pre_decoder = Decoder_model(NUM_LAYERS, d_numerical, categories, D_TOKEN, n_head=N_HEAD, factor=FACTOR).to(device)

    pre_encoder.eval()
    pre_decoder.eval()

    optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WD)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.95, patience=10, verbose=True)

    num_epochs = args.vae_epochs
    best_train_loss = float('inf')

    current_lr = optimizer.param_groups[0]['lr']
    patience = 0

    beta = max_beta
    start_time = time.time()
    for epoch in range(num_epochs):
        pbar = tqdm(train_loader, total=len(train_loader))
        pbar.set_description(f"Epoch {epoch+1}/{num_epochs}")

        # Sample-weighted running sums so the epoch averages are exact even
        # when the last batch is smaller than the others.
        curr_loss_multi = 0.0
        curr_loss_gauss = 0.0
        curr_loss_kl = 0.0
        curr_count = 0

        model.train()
        for batch_num, batch_cat in pbar:
            optimizer.zero_grad()

            batch_num = batch_num.to(device)
            batch_cat = batch_cat.to(device)

            Recon_X_num, Recon_X_cat, mu_z, std_z = model(batch_num, batch_cat)

            loss_mse, loss_ce, loss_kld, train_acc = compute_loss(batch_num, batch_cat, Recon_X_num, Recon_X_cat, mu_z, std_z)

            loss = loss_mse + loss_ce + beta * loss_kld
            loss.backward()
            optimizer.step()

            batch_length = batch_num.shape[0]
            curr_count += batch_length
            curr_loss_multi += loss_ce.item() * batch_length
            curr_loss_gauss += loss_mse.item() * batch_length
            curr_loss_kl += loss_kld.item() * batch_length

        num_loss = curr_loss_gauss / curr_count
        cat_loss = curr_loss_multi / curr_count
        kl_loss = curr_loss_kl / curr_count

        # Reconstruction loss only (no KL term) drives both the LR schedule
        # and checkpoint selection.
        train_loss = num_loss + cat_loss

        # ReduceLROnPlateau must see exactly one metric per epoch. Stepping it
        # a second time with the validation loss would mix two metrics in its
        # internal `best` value and count patience twice per epoch.
        scheduler.step(train_loss)

        new_lr = optimizer.param_groups[0]['lr']
        if new_lr != current_lr:
            current_lr = new_lr
            print(f"Learning rate updated: {current_lr}")

        if train_loss < best_train_loss:
            best_train_loss = train_loss
            patience = 0
            torch.save(model.state_dict(), model_save_path)
        else:
            patience += 1
            # Equality (not >=): beta is decayed once, on the 10th consecutive
            # epoch without improvement, and not again until patience resets.
            if patience == 10 and beta > min_beta:
                beta = beta * lambd

        # Evaluation on the held-out split; used for reporting only.
        model.eval()
        with torch.no_grad():
            Recon_X_num, Recon_X_cat, mu_z, std_z = model(X_test_num, X_test_cat)

            val_mse_loss, val_ce_loss, val_kl_loss, val_acc = compute_loss(X_test_num, X_test_cat, Recon_X_num, Recon_X_cat, mu_z, std_z)

        # train_acc is the accuracy of the last training batch of the epoch.
        print('epoch: {}, beta = {:.6f}, Train MSE: {:.6f}, Train CE:{:.6f}, Train KL:{:.6f}, Val MSE:{:.6f}, Val CE:{:.6f}, Train ACC:{:6f}, Val ACC:{:6f}'.format(epoch, beta, num_loss, cat_loss, kl_loss, val_mse_loss.item(), val_ce_loss.item(), train_acc.item(), val_acc.item()))

    end_time = time.time()
    print('Training time: {:.4f} mins'.format((end_time - start_time)/60))

    # Saving latent embeddings
    with torch.no_grad():
        pre_encoder.load_weights(model)
        pre_decoder.load_weights(model)

        torch.save(pre_encoder.state_dict(), encoder_save_path)
        torch.save(pre_decoder.state_dict(), decoder_save_path)

        print('Successfully load and save the model!')

        # Encode in mini-batches instead of one pass over the full training
        # set, so peak device memory is bounded by batch_size.
        embeddings_list = []
        for start in range(0, X_train_num.shape[0], batch_size):
            stop = start + batch_size
            num_chunk = X_train_num[start:stop].to(device)
            cat_chunk = X_train_cat[start:stop].to(device)
            embeddings_list.append(pre_encoder(num_chunk, cat_chunk).detach().cpu().numpy())

        train_z = np.concatenate(embeddings_list, axis=0)

        np.save(f'{ckpt_dir}/train_z.npy', train_z)

        print('Successfully save pretrained embeddings in disk!')

if __name__ == '__main__':
    # Command-line entry point: parse the options and resolve the torch device.
    parser = argparse.ArgumentParser(description='Variational Autoencoder')

    parser.add_argument(
        '--dataname', type=str, default='adult',
        help='Name of dataset.',
    )
    parser.add_argument(
        '--gpu', type=int, default=0,
        help='GPU index.',
    )
    parser.add_argument(
        '--max_beta', type=float, default=1e-2,
        help='Initial Beta.',
    )
    parser.add_argument(
        '--min_beta', type=float, default=1e-5,
        help='Minimum Beta.',
    )
    parser.add_argument(
        '--lambd', type=float, default=0.7,
        help='Decay of Beta.',
    )

    args = parser.parse_args()

    # --gpu -1 forces CPU; otherwise use the requested GPU when CUDA is
    # available and fall back to CPU when it is not.
    use_cuda = args.gpu != -1 and torch.cuda.is_available()
    args.device = 'cuda:{}'.format(args.gpu) if use_cuda else 'cpu'