from tabsyn.process_dataset import get_info_from_domain, pipeline_process_data
from tabsyn.tabsyn.vae.main import train_vae, train_cluster_vae, train_vae_vade
from tabsyn.tabsyn.train_classifier import train_classifier as tabsyn_train_classifier
from tabsyn.tabsyn.main import train_tabsyn
from tabsyn.tabsyn.sample import cond_sample
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from pipeline_utils import *
from sklearn.cluster import KMeans
from collections import defaultdict

import os
import json
import numpy as np

def discretize_array_evenly(arr, n):
    """
    Bin the values of an array into quantile-based (equal-frequency) buckets.

    The bin edges are the 0%, 100/n %, ..., 100% percentiles of ``arr``, so
    each bucket receives roughly the same number of elements. When the data
    contain many ties, several edges can coincide and fewer than ``n``
    distinct labels may actually appear in the output.

    Parameters:
    - arr: NumPy array to discretize.
    - n: Number of bins.

    Returns:
    - Array of integer bin indices, each in the range 0 to n-1.
    """
    # n + 1 evenly spaced percentile levels delimit the n bins
    quantile_levels = np.linspace(0, 100, n + 1)
    bin_edges = np.percentile(arr, quantile_levels)

    # digitize counts from 1 relative to the first edge, so shift to 0-based
    bin_ids = np.digitize(arr, bin_edges) - 1

    # The maximum sits on the last edge and would land in bin n; clamp it
    # (and anything below the first edge) back into 0..n-1
    return np.clip(bin_ids, 0, n - 1)

def discretize_array(arr, n):
    """
    Discretize an array into at most ``n`` integer bins of equal width.

    If the array holds ``n`` or fewer distinct values, each value is simply
    replaced by its rank among the sorted distinct values (0, 1, 2, ...).
    Otherwise the values are min-max normalized and split into ``n``
    equal-width bins labelled 0 to n-1.

    Parameters:
    - arr: NumPy array to discretize (any shape).
    - n: Number of bins; a Python int greater than 1.

    Returns:
    - Integer array with the same shape as ``arr``.

    Raises:
    - ValueError: if ``arr`` is not a NumPy array, or ``n`` is not an int
      greater than 1.
    """
    if not isinstance(arr, np.ndarray):
        raise ValueError("Input must be a numpy array.")
    if not isinstance(n, int) or n <= 1:
        raise ValueError("n must be an integer greater than 1.")

    # One pass gives both the sorted distinct values and, for every element,
    # the index of its value in that sorted list (i.e. its rank).
    unique_values, ranks = np.unique(arr, return_inverse=True)

    # Edge case: few distinct values -> use the ranks directly. This also
    # covers a constant array, so the division below never sees max == min.
    if len(unique_values) <= n:
        # The shape of the inverse for N-d input differs between NumPy
        # versions; reshaping makes the result match the input either way.
        return ranks.reshape(arr.shape)

    # Normalize the array to [0, 1]
    lowest = np.min(arr)
    normalized_arr = (arr - lowest) / (np.max(arr) - lowest)

    # Scale to just under n before flooring so the maximum lands in bin n-1
    # instead of opening an extra bin n.
    epsilon = 1e-10
    return np.floor(normalized_arr * (n - epsilon)).astype(int)

def discretize_array_old(arr, n):
    """
    Discretize an array by linear scaling to [0, n-1] and rounding.

    Parameters:
    - arr: NumPy array to discretize.
    - n: Number of levels; a Python int greater than 1.

    Returns:
    - Integer array shaped like ``arr`` with values from 0 to n-1. A constant
      array maps to all zeros.

    Raises:
    - ValueError: if ``arr`` is not a NumPy array, or ``n`` is not an int
      greater than 1.
    """
    if not isinstance(arr, np.ndarray):
        raise ValueError("Input must be a numpy array.")
    if not (isinstance(n, int) and n > 1):
        raise ValueError("n must be an integer greater than 1.")

    lowest, highest = arr.min(), arr.max()

    # A constant array has zero range; return a single level rather than
    # dividing by zero
    if lowest == highest:
        return np.zeros(arr.shape, dtype=int)

    # Map [lowest, highest] linearly onto [0, n-1], then snap to the
    # nearest integer level
    levels = (arr - lowest) / (highest - lowest) * (n - 1)
    return np.round(levels).astype(int)


def clustering_naive(
        child_df,
        child_domain_dict,
        parent_df,
        parent_domain_dict,
        child_primary_key,
        parent_primary_key,
        num_clusters,
        parent_scale,
        key_scale,
        parent_name,
        child_name,
        args,
        handle_size1=False
    ):
    """
    Label every child group with its own cluster id (no model is trained).

    Child rows are sorted by foreign key and split into groups with
    ``CRF.tools.get_group_data``; the i-th group receives label ``i``.
    Each parent inherits the label of its children, and parents without any
    child row all share one extra label (largest assigned label + 1).

    Parameters:
    - child_df: child table (DataFrame). Must contain ``child_primary_key``
      and a foreign-key column named like ``parent_primary_key``.
    - child_domain_dict: domain description of the child table; an entry for
      the new cluster column is added in place.
    - parent_df: parent table (DataFrame) containing ``parent_primary_key``.
    - parent_domain_dict: domain description of the parent table; an entry
      for the new cluster column is added in place.
    - child_primary_key: name of the child primary-key column, used to put
      the labelled child rows back into the order of ``child_df``.
    - parent_primary_key: name of the parent primary-key column (and of the
      foreign-key column in the child table).
    - num_clusters, parent_scale, key_scale, handle_size1: not used by this
      variant; accepted so the signature lines up with the other clustering
      functions in this module.
    - parent_name, child_name: table names; they form the cluster column
      name ``{parent_name}_{child_name}_cluster`` and the output directories.
    - args: object with a ``working_dir`` attribute.

    Returns:
    - (parent_df_with_cluster, child_df_with_cluster, group_lengths_prob_dicts)
      where the two DataFrames are the inputs plus the cluster column, and
      ``group_lengths_prob_dicts`` maps each label to the result of
      ``freq_to_prob`` applied to that label's {group size: count} histogram.

    Side effects:
    - Saves the parent and child label columns as
      ``<working_dir>/<table name>/<cluster column>_label.npy``.
    - Prints the number of distinct labels.
    """
    original_child_cols = list(child_df.columns)
    original_parent_cols = list(parent_df.columns)

    relation_cluster_name = f'{parent_name}_{child_name}_cluster'

    child_data = child_df.to_numpy()
    parent_data = parent_df.to_numpy()

    parent_primary_key_index = original_parent_cols.index(parent_primary_key)
    foreing_key_index = original_child_cols.index(parent_primary_key)

    # sort child data by foreign key
    sorted_child_data = child_data[np.argsort(child_data[:, foreing_key_index])]
    child_group_data_dict = CRF.tools.get_group_data_dict(sorted_child_data, [foreing_key_index,])

    # sort parent data by primary key
    sorted_parent_data = parent_data[np.argsort(parent_data[:, parent_primary_key_index])]
    sorted_parent_keys = sorted_parent_data[:, parent_primary_key_index]

    # number of children per parent, in sorted parent order (0 if childless);
    # the group dict is keyed by 1-tuples of the foreign key
    group_lengths = np.array(
        [
            len(child_group_data_dict[(parent_id,)]) if (parent_id,) in child_group_data_dict else 0
            for parent_id in sorted_parent_keys
        ],
        dtype=int,
    )

    # Integrity check: repeating each parent key once per child must reproduce
    # the sorted foreign-key column, i.e. every child row points at an
    # existing parent and both tables sort the same way.
    repeated_parent_keys = np.repeat(sorted_parent_keys, group_lengths)
    assert((repeated_parent_keys == sorted_child_data[:, foreing_key_index]).all())

    # NOTE: no feature matrix is built here. The labels below depend only on
    # how the child rows are grouped, never on their column values.
    # NOTE(review): assumes get_group_data yields the groups in the same
    # order as the rows of sorted_child_data -- confirm against CRF.tools.
    child_group_data = CRF.tools.get_group_data(sorted_child_data, [foreing_key_index,])
    child_group_lengths = np.array([len(group) for group in child_group_data], dtype=int)

    # one fresh label per child group: 0, 1, 2, ...
    group_cluster_labels = list(range(len(child_group_lengths)))

    # expand the per-group labels to one label per child row
    group_assignment = np.repeat(group_cluster_labels, child_group_lengths, axis=0).reshape((-1, 1))

    # obtain the child data with clustering
    sorted_child_data_with_cluster = np.concatenate(
        [
            sorted_child_data,
            group_assignment
        ],
        axis=1
    )

    # per label: histogram of group sizes, then handed to freq_to_prob
    group_lengths_dict = {}
    for group_label, group_length in zip(group_cluster_labels, child_group_lengths.tolist()):
        if group_label not in group_lengths_dict:
            group_lengths_dict[group_label] = defaultdict(int)
        group_lengths_dict[group_label][group_length] += 1

    group_lengths_prob_dicts = {
        group_label: freq_to_prob(freq_dict)
        for group_label, freq_dict in group_lengths_dict.items()
    }

    # recover the preprocessed data back to dataframe
    child_df_with_cluster = pd.DataFrame(
        sorted_child_data_with_cluster,
        columns=original_child_cols + [relation_cluster_name]
    )

    # recover child df order
    child_df_with_cluster = pd.merge(
        child_df[[child_primary_key]],
        child_df_with_cluster,
        on=child_primary_key,
        how='left',
    )

    # parent id -> label, checking that all children of one parent agree
    parent_id_to_cluster = {}
    child_row_labels = sorted_child_data_with_cluster[:, -1]
    for parent_id, row_label in zip(sorted_child_data[:, foreing_key_index], child_row_labels):
        if parent_id in parent_id_to_cluster:
            assert(parent_id_to_cluster[parent_id] == row_label)
        else:
            parent_id_to_cluster[parent_id] = row_label

    # parents without children all fall into one extra cluster
    childless_label = max(parent_id_to_cluster.values()) + 1

    # labels follow the ORIGINAL parent row order, not the sorted one
    parent_data_clusters = np.array(
        [
            parent_id_to_cluster.get(parent_id, childless_label)
            for parent_id in parent_data[:, parent_primary_key_index]
        ]
    ).reshape(-1, 1)

    parent_data_with_cluster = np.concatenate(
        [
            parent_data,
            parent_data_clusters
        ],
        axis=1
    )
    parent_df_with_cluster = pd.DataFrame(
        parent_data_with_cluster,
        columns=original_parent_cols + [relation_cluster_name]
    )

    num_distinct_labels = len(set(parent_data_clusters.flatten()))
    new_col_entry = {
        'type': 'discrete',
        'size': num_distinct_labels
    }

    print('num clusters: ', num_distinct_labels)

    # register the new column in both domain dicts (separate copies)
    parent_domain_dict[relation_cluster_name] = new_col_entry.copy()
    child_domain_dict[relation_cluster_name] = new_col_entry.copy()

    parent_labels = parent_df_with_cluster[relation_cluster_name].to_numpy()
    child_labels = child_df_with_cluster[relation_cluster_name].to_numpy()

    parent_save_dir = os.path.join(args.working_dir, parent_name)
    child_save_dir = os.path.join(args.working_dir, child_name)

    # exist_ok avoids the check-then-create race of a separate exists() test
    os.makedirs(parent_save_dir, exist_ok=True)
    os.makedirs(child_save_dir, exist_ok=True)

    label_file_name = f'{relation_cluster_name}_label.npy'
    np.save(os.path.join(parent_save_dir, label_file_name), parent_labels)
    np.save(os.path.join(child_save_dir, label_file_name), child_labels)

    return parent_df_with_cluster, child_df_with_cluster, group_lengths_prob_dicts


def learn_to_cluster_vade(
        child_df,
        child_domain_dict,
        parent_df,
        parent_domain_dict,
        child_primary_key,
        parent_primary_key,
        num_clusters,
        parent_scale,
        key_scale,
        parent_name,
        child_name,
        args,
    ):
    """
    Cluster parent rows using the labels produced by ``train_vae_vade``.

    Child rows are sorted by foreign key and grouped per parent. Each child
    row is concatenated with its parent row (all columns, treated as
    numerical; the categorical block passed to the model is empty) and the
    groups are handed to ``train_vae_vade``. The returned per-row labels are
    reduced to one label per group by majority vote. Parents without any
    child row share one extra label (largest voted label + 1).

    Parameters:
    - child_df: child table (DataFrame). Must contain ``child_primary_key``
      and a foreign-key column named like ``parent_primary_key``.
    - child_domain_dict: domain description of the child table; an entry for
      the new cluster column is added in place.
    - parent_df: parent table (DataFrame) containing ``parent_primary_key``.
    - parent_domain_dict: domain description of the parent table; an entry
      for the new cluster column is added in place.
    - child_primary_key: name of the child primary-key column, used to put
      the labelled child rows back into the order of ``child_df``.
    - parent_primary_key: name of the parent primary-key column (and of the
      foreign-key column in the child table).
    - num_clusters: forwarded to the model as ``num_clusters``.
    - parent_scale, key_scale: not used by this function.
    - parent_name, child_name: table names; they form the cluster column
      name ``{parent_name}_{child_name}_cluster`` and the output directories.
    - args: object with attributes ``max_beta``, ``min_beta``, ``lambd``,
      ``device``, ``working_dir``, ``learn_to_cluster_epochs``,
      ``vae_batch_size``, ``has_y``, ``has_test`` and ``read_ckpt``.

    Returns:
    - (parent_df_with_cluster, child_df_with_cluster, group_lengths_prob_dicts)
      where the two DataFrames are the inputs plus the cluster column, and
      ``group_lengths_prob_dicts`` maps each label to the result of
      ``freq_to_prob`` applied to that label's {group size: count} histogram.

    Side effects:
    - Creates ``<working_dir>/<child_name>/<parent_name>_<child_name>/vae_vade``
      and passes it to the model as its checkpoint directory.
    - Saves the parent and child label columns as
      ``<working_dir>/<table name>/<cluster column>_label.npy``.
    - Prints the average vote agreement and the number of distinct labels.
    """
    original_child_cols = list(child_df.columns)
    original_parent_cols = list(parent_df.columns)

    relation_cluster_name = f'{parent_name}_{child_name}_cluster'

    child_data = child_df.to_numpy()
    parent_data = parent_df.to_numpy()

    parent_primary_key_index = original_parent_cols.index(parent_primary_key)
    foreing_key_index = original_child_cols.index(parent_primary_key)

    # sort child data by foreign key
    sorted_child_data = child_data[np.argsort(child_data[:, foreing_key_index])]
    child_group_data_dict = CRF.tools.get_group_data_dict(sorted_child_data, [foreing_key_index,])

    # sort parent data by primary key
    sorted_parent_data = parent_data[np.argsort(parent_data[:, parent_primary_key_index])]

    # number of children per parent, in sorted parent order (0 if childless);
    # the group dict is keyed by 1-tuples of the foreign key
    group_lengths = np.array(
        [
            len(child_group_data_dict[(parent_id,)]) if (parent_id,) in child_group_data_dict else 0
            for parent_id in sorted_parent_data[:, parent_primary_key_index]
        ],
        dtype=int,
    )

    # one parent row per child row, aligned with sorted_child_data
    sorted_parent_data_repeated = np.repeat(sorted_parent_data, group_lengths, axis=0)
    assert((sorted_parent_data_repeated[:, parent_primary_key_index] == sorted_child_data[:, foreing_key_index]).all())

    # NOTE(review): assumes get_group_data yields the groups in the same
    # order as the rows of sorted_child_data -- confirm against CRF.tools.
    child_group_data = CRF.tools.get_group_data(sorted_child_data, [foreing_key_index,])

    # Treat all features as numerical: every group is the child rows joined
    # with their parent rows; the categorical part has zero columns.
    num_group_data = []
    cat_group_data = []
    curr_index = 0
    for child_group in child_group_data:
        next_index = curr_index + len(child_group)
        parent_data_group = sorted_parent_data_repeated[curr_index:next_index]
        curr_index = next_index
        num_group_data.append(np.concatenate([child_group, parent_data_group], axis=1))
        cat_group_data.append(np.empty((child_group.shape[0], 0)))

    vae_ckpt_dir = os.path.join(
        args.working_dir,
        child_name,
        f'{parent_name}_{child_name}',
        'vae_vade'
    )
    # exist_ok avoids the check-then-create race of a separate exists() test
    os.makedirs(vae_ckpt_dir, exist_ok=True)

    # train cluster_vae
    cluster_vae_args = {
        'cat_group_data': cat_group_data,
        'num_group_data': num_group_data,
        'max_beta': args.max_beta,
        'min_beta': args.min_beta,
        'lambd': args.lambd,
        'device': torch.device(args.device),
        'info': {
            'task_type': 'None',
            'n_classes': 0
        },
        'ckpt_dir': vae_ckpt_dir,
        'vae_epochs': args.learn_to_cluster_epochs,
        'vae_batch_size': args.vae_batch_size,
        'has_y': args.has_y,
        'has_test': args.has_test,
        'read_ckpt': args.read_ckpt,
        'num_clusters': num_clusters,
    }

    clusters = train_vae_vade(cluster_vae_args)

    # NOTE(review): assumes one non-negative integer label per child row, in
    # group order (np.bincount below requires that) -- confirm.
    cluster_labels = clusters.flatten()

    child_group_lengths = np.array([len(group) for group in child_group_data], dtype=int)

    # voting to determine the cluster label for each parent
    group_cluster_labels = []
    agree_rates = []
    curr_index = 0
    for group_length in child_group_lengths:
        # count the labels of this group's rows once, reuse for both stats
        label_counts = np.bincount(cluster_labels[curr_index: curr_index + group_length])
        group_cluster_labels.append(np.argmax(label_counts))

        # share of the group's rows that carry the winning label
        agree_rates.append(label_counts.max() / group_length)

        curr_index += group_length

    # Compute the average agree rate across all groups
    average_agree_rate = np.mean(agree_rates)
    print('average agree rate: ', average_agree_rate)

    # expand the per-group labels to one label per child row
    group_assignment = np.repeat(group_cluster_labels, child_group_lengths, axis=0).reshape((-1, 1))

    # obtain the child data with clustering
    sorted_child_data_with_cluster = np.concatenate(
        [
            sorted_child_data,
            group_assignment
        ],
        axis=1
    )

    # per label: histogram of group sizes, then handed to freq_to_prob
    group_lengths_dict = {}
    for group_label, group_length in zip(group_cluster_labels, child_group_lengths.tolist()):
        if group_label not in group_lengths_dict:
            group_lengths_dict[group_label] = defaultdict(int)
        group_lengths_dict[group_label][group_length] += 1

    group_lengths_prob_dicts = {
        group_label: freq_to_prob(freq_dict)
        for group_label, freq_dict in group_lengths_dict.items()
    }

    # recover the preprocessed data back to dataframe
    child_df_with_cluster = pd.DataFrame(
        sorted_child_data_with_cluster,
        columns=original_child_cols + [relation_cluster_name]
    )

    # recover child df order
    child_df_with_cluster = pd.merge(
        child_df[[child_primary_key]],
        child_df_with_cluster,
        on=child_primary_key,
        how='left',
    )

    # parent id -> label, checking that all children of one parent agree
    parent_id_to_cluster = {}
    child_row_labels = sorted_child_data_with_cluster[:, -1]
    for parent_id, row_label in zip(sorted_child_data[:, foreing_key_index], child_row_labels):
        if parent_id in parent_id_to_cluster:
            assert(parent_id_to_cluster[parent_id] == row_label)
        else:
            parent_id_to_cluster[parent_id] = row_label

    # parents without children all fall into one extra cluster
    childless_label = max(parent_id_to_cluster.values()) + 1

    # labels follow the ORIGINAL parent row order, not the sorted one
    parent_data_clusters = np.array(
        [
            parent_id_to_cluster.get(parent_id, childless_label)
            for parent_id in parent_data[:, parent_primary_key_index]
        ]
    ).reshape(-1, 1)

    parent_data_with_cluster = np.concatenate(
        [
            parent_data,
            parent_data_clusters
        ],
        axis=1
    )
    parent_df_with_cluster = pd.DataFrame(
        parent_data_with_cluster,
        columns=original_parent_cols + [relation_cluster_name]
    )

    num_distinct_labels = len(set(parent_data_clusters.flatten()))
    new_col_entry = {
        'type': 'discrete',
        'size': num_distinct_labels
    }

    print('num clusters: ', num_distinct_labels)

    # register the new column in both domain dicts (separate copies)
    parent_domain_dict[relation_cluster_name] = new_col_entry.copy()
    child_domain_dict[relation_cluster_name] = new_col_entry.copy()

    parent_labels = parent_df_with_cluster[relation_cluster_name].to_numpy()
    child_labels = child_df_with_cluster[relation_cluster_name].to_numpy()

    parent_save_dir = os.path.join(args.working_dir, parent_name)
    child_save_dir = os.path.join(args.working_dir, child_name)

    os.makedirs(parent_save_dir, exist_ok=True)
    os.makedirs(child_save_dir, exist_ok=True)

    label_file_name = f'{relation_cluster_name}_label.npy'
    np.save(os.path.join(parent_save_dir, label_file_name), parent_labels)
    np.save(os.path.join(child_save_dir, label_file_name), child_labels)

    return parent_df_with_cluster, child_df_with_cluster, group_lengths_prob_dicts



def learn_to_cluster(
        child_df,
        child_domain_dict,
        parent_df,
        parent_domain_dict,
        child_primary_key,
        parent_primary_key,
        num_clusters,
        parent_scale,
        key_scale,
        parent_name,
        child_name,
        args,
    ):
    """
    Cluster parent rows from the output of ``train_cluster_vae``.

    Child rows are sorted by foreign key and grouped per parent. For every
    group, the child columns listed in ``child_domain_dict`` are split into a
    categorical part (type ``'discrete'``) and a numerical part (everything
    else) and handed to ``train_cluster_vae``. Its output is quantile-binned
    into ``num_clusters`` labels with ``discretize_array_evenly``, and each
    group takes the majority label of its rows. Parents without any child
    row share one extra label (largest voted label + 1).

    Parameters:
    - child_df: child table (DataFrame). Must contain ``child_primary_key``
      and a foreign-key column named like ``parent_primary_key``.
    - child_domain_dict: domain description of the child table, mapping
      column name to a dict with a ``'type'`` key; an entry for the new
      cluster column is added in place.
    - parent_df: parent table (DataFrame) containing ``parent_primary_key``.
    - parent_domain_dict: domain description of the parent table; an entry
      for the new cluster column is added in place.
    - child_primary_key: name of the child primary-key column, used to put
      the labelled child rows back into the order of ``child_df``.
    - parent_primary_key: name of the parent primary-key column (and of the
      foreign-key column in the child table).
    - num_clusters: number of quantile bins, also forwarded to the model.
    - parent_scale, key_scale: not used by this function.
    - parent_name, child_name: table names; they form the cluster column
      name ``{parent_name}_{child_name}_cluster`` and the output directories.
    - args: object with attributes ``max_beta``, ``min_beta``, ``lambd``,
      ``device``, ``working_dir``, ``learn_to_cluster_epochs``,
      ``vae_batch_size``, ``has_y``, ``has_test`` and ``read_ckpt``.

    Returns:
    - (parent_df_with_cluster, child_df_with_cluster, group_lengths_prob_dicts)
      where the two DataFrames are the inputs plus the cluster column, and
      ``group_lengths_prob_dicts`` maps each label to the result of
      ``freq_to_prob`` applied to that label's {group size: count} histogram.

    Side effects:
    - Creates ``<working_dir>/<child_name>/<parent_name>_<child_name>/cluster_vae``
      and passes it to the model as its checkpoint directory.
    - Saves the parent and child label columns as
      ``<working_dir>/<table name>/<cluster column>_label.npy``.
    - Prints the average vote agreement and the number of distinct labels.
    """
    original_child_cols = list(child_df.columns)
    original_parent_cols = list(parent_df.columns)

    relation_cluster_name = f'{parent_name}_{child_name}_cluster'

    child_data = child_df.to_numpy()
    parent_data = parent_df.to_numpy()

    # Positions of the child feature columns, split by domain type. Columns
    # absent from the domain dict are left out of the model input.
    child_cat_indices = []
    child_num_indices = []
    for col_index, col in enumerate(original_child_cols):
        if col not in child_domain_dict:
            continue
        if child_domain_dict[col]['type'] == 'discrete':
            child_cat_indices.append(col_index)
        else:
            child_num_indices.append(col_index)

    parent_primary_key_index = original_parent_cols.index(parent_primary_key)
    foreing_key_index = original_child_cols.index(parent_primary_key)

    # sort child data by foreign key
    sorted_child_data = child_data[np.argsort(child_data[:, foreing_key_index])]
    child_group_data_dict = CRF.tools.get_group_data_dict(sorted_child_data, [foreing_key_index,])

    # sort parent data by primary key
    sorted_parent_data = parent_data[np.argsort(parent_data[:, parent_primary_key_index])]
    sorted_parent_keys = sorted_parent_data[:, parent_primary_key_index]

    # number of children per parent, in sorted parent order (0 if childless);
    # the group dict is keyed by 1-tuples of the foreign key
    group_lengths = np.array(
        [
            len(child_group_data_dict[(parent_id,)]) if (parent_id,) in child_group_data_dict else 0
            for parent_id in sorted_parent_keys
        ],
        dtype=int,
    )

    # Integrity check: repeating each parent key once per child must reproduce
    # the sorted foreign-key column, i.e. every child row points at an
    # existing parent and both tables sort the same way.
    repeated_parent_keys = np.repeat(sorted_parent_keys, group_lengths)
    assert((repeated_parent_keys == sorted_child_data[:, foreing_key_index]).all())

    # NOTE(review): assumes get_group_data yields the groups in the same
    # order as the rows of sorted_child_data -- confirm against CRF.tools.
    child_group_data = CRF.tools.get_group_data(sorted_child_data, [foreing_key_index,])

    # split every group into its categorical and numerical child columns
    child_group_data_cat = [group_data[:, child_cat_indices] for group_data in child_group_data]
    child_group_data_num = [group_data[:, child_num_indices] for group_data in child_group_data]

    vae_ckpt_dir = os.path.join(
        args.working_dir,
        child_name,
        f'{parent_name}_{child_name}',
        'cluster_vae'
    )
    # exist_ok avoids the check-then-create race of a separate exists() test
    os.makedirs(vae_ckpt_dir, exist_ok=True)

    # train cluster_vae
    cluster_vae_args = {
        'cat_group_data': child_group_data_cat,
        'num_group_data': child_group_data_num,
        'max_beta': args.max_beta,
        'min_beta': args.min_beta,
        'lambd': args.lambd,
        'device': torch.device(args.device),
        'info': {
            'task_type': 'None',
            'n_classes': 0
        },
        'ckpt_dir': vae_ckpt_dir,
        'vae_epochs': args.learn_to_cluster_epochs,
        'vae_batch_size': args.vae_batch_size,
        'has_y': args.has_y,
        'has_test': args.has_test,
        'read_ckpt': args.read_ckpt,
        'num_clusters': num_clusters,
    }

    train_h = train_cluster_vae(cluster_vae_args)

    # NOTE(review): assumes train_h holds one scalar per child row, in group
    # order, so the flattened bins line up with sorted_child_data -- confirm.
    discretized = discretize_array_evenly(train_h, num_clusters)
    cluster_labels = discretized.flatten()

    child_group_lengths = np.array([len(group) for group in child_group_data], dtype=int)

    # voting to determine the cluster label for each parent
    group_cluster_labels = []
    agree_rates = []
    curr_index = 0
    for group_length in child_group_lengths:
        # count the labels of this group's rows once, reuse for both stats
        label_counts = np.bincount(cluster_labels[curr_index: curr_index + group_length])
        group_cluster_labels.append(np.argmax(label_counts))

        # share of the group's rows that carry the winning label
        agree_rates.append(label_counts.max() / group_length)

        curr_index += group_length

    # Compute the average agree rate across all groups
    average_agree_rate = np.mean(agree_rates)
    print('average agree rate: ', average_agree_rate)

    # expand the per-group labels to one label per child row
    group_assignment = np.repeat(group_cluster_labels, child_group_lengths, axis=0).reshape((-1, 1))

    # obtain the child data with clustering
    sorted_child_data_with_cluster = np.concatenate(
        [
            sorted_child_data,
            group_assignment
        ],
        axis=1
    )

    # per label: histogram of group sizes, then handed to freq_to_prob
    group_lengths_dict = {}
    for group_label, group_length in zip(group_cluster_labels, child_group_lengths.tolist()):
        if group_label not in group_lengths_dict:
            group_lengths_dict[group_label] = defaultdict(int)
        group_lengths_dict[group_label][group_length] += 1

    group_lengths_prob_dicts = {
        group_label: freq_to_prob(freq_dict)
        for group_label, freq_dict in group_lengths_dict.items()
    }

    # recover the preprocessed data back to dataframe
    child_df_with_cluster = pd.DataFrame(
        sorted_child_data_with_cluster,
        columns=original_child_cols + [relation_cluster_name]
    )

    # recover child df order
    child_df_with_cluster = pd.merge(
        child_df[[child_primary_key]],
        child_df_with_cluster,
        on=child_primary_key,
        how='left',
    )

    # parent id -> label, checking that all children of one parent agree
    parent_id_to_cluster = {}
    child_row_labels = sorted_child_data_with_cluster[:, -1]
    for parent_id, row_label in zip(sorted_child_data[:, foreing_key_index], child_row_labels):
        if parent_id in parent_id_to_cluster:
            assert(parent_id_to_cluster[parent_id] == row_label)
        else:
            parent_id_to_cluster[parent_id] = row_label

    # parents without children all fall into one extra cluster
    childless_label = max(parent_id_to_cluster.values()) + 1

    # labels follow the ORIGINAL parent row order, not the sorted one
    parent_data_clusters = np.array(
        [
            parent_id_to_cluster.get(parent_id, childless_label)
            for parent_id in parent_data[:, parent_primary_key_index]
        ]
    ).reshape(-1, 1)

    parent_data_with_cluster = np.concatenate(
        [
            parent_data,
            parent_data_clusters
        ],
        axis=1
    )
    parent_df_with_cluster = pd.DataFrame(
        parent_data_with_cluster,
        columns=original_parent_cols + [relation_cluster_name]
    )

    num_distinct_labels = len(set(parent_data_clusters.flatten()))
    new_col_entry = {
        'type': 'discrete',
        'size': num_distinct_labels
    }

    print('num clusters: ', num_distinct_labels)

    # register the new column in both domain dicts (separate copies)
    parent_domain_dict[relation_cluster_name] = new_col_entry.copy()
    child_domain_dict[relation_cluster_name] = new_col_entry.copy()

    parent_labels = parent_df_with_cluster[relation_cluster_name].to_numpy()
    child_labels = child_df_with_cluster[relation_cluster_name].to_numpy()

    parent_save_dir = os.path.join(args.working_dir, parent_name)
    child_save_dir = os.path.join(args.working_dir, child_name)

    os.makedirs(parent_save_dir, exist_ok=True)
    os.makedirs(child_save_dir, exist_ok=True)

    label_file_name = f'{relation_cluster_name}_label.npy'
    np.save(os.path.join(parent_save_dir, label_file_name), parent_labels)
    np.save(os.path.join(child_save_dir, label_file_name), child_labels)

    return parent_df_with_cluster, child_df_with_cluster, group_lengths_prob_dicts


def tabsyn_pair_clustering_keep_id(
        child_df,
        child_domain_dict,
        parent_df,
        parent_domain_dict,
        child_primary_key,
        parent_primary_key,
        num_clusters,
        parent_scale,
        key_scale,
        parent_name,
        child_name,
        args,
        handle_size1=False
    ):
    """
    Cluster the rows of a parent/child table pair and attach one cluster label per parent.

    Every child row is joined with its parent row, the joint features are
    normalized and clustered with KMeans, and each parent receives the label
    that the majority of its children were assigned. The label is appended to
    both tables as a new column named '<parent_name>_<child_name>_cluster'.
    All original columns of both frames (including key columns) are kept.

    Parameters:
    - child_df: child table; must contain `child_primary_key` and a foreign-key
      column named `parent_primary_key`.
    - child_domain_dict: column name -> domain entry for the child table. Columns
      whose entry has type 'discrete' are treated as categorical, every other
      listed column as numerical; unlisted columns are not used as features.
      A new entry for the cluster column is added in place.
    - parent_df: parent table; must contain `parent_primary_key`.
    - parent_domain_dict: same as `child_domain_dict` for the parent table;
      also updated in place.
    - child_primary_key: name of the child primary-key column.
    - parent_primary_key: name of the parent primary-key column.
    - num_clusters: number of KMeans clusters.
    - parent_scale: factor applied to the parent-side features before clustering.
    - key_scale: factor applied to the min-max normalized parent key feature.
    - parent_name, child_name: table names; used for the cluster column name
      and for the sub-directories the labels are saved into.
    - args: object providing `working_dir`.
    - handle_size1: accepted for interface compatibility; not used here.

    Returns:
    - (parent_df_with_cluster, child_df_with_cluster, group_lengths_prob_dicts),
      where the last item maps each cluster label to the result of
      `freq_to_prob` over the child-group sizes observed in that cluster.

    Side effects:
    - Saves '<cluster column>_label.npy' under args.working_dir/<parent_name>
      and args.working_dir/<child_name>.
    """
    original_child_cols = list(child_df.columns)
    original_parent_cols = list(parent_df.columns)

    relation_cluster_name = f'{parent_name}_{child_name}_cluster'

    child_data = child_df.to_numpy()
    parent_data = parent_df.to_numpy()

    child_num_cols = []
    child_cat_cols = []

    parent_num_cols = []
    parent_cat_cols = []

    # Split the feature columns of each table by domain type.
    for col_index, col in enumerate(original_child_cols):
        if col in child_domain_dict:
            if child_domain_dict[col]['type'] == 'discrete':
                child_cat_cols.append((col_index, col))
            else:
                child_num_cols.append((col_index, col))

    for col_index, col in enumerate(original_parent_cols):
        if col in parent_domain_dict:
            if parent_domain_dict[col]['type'] == 'discrete':
                parent_cat_cols.append((col_index, col))
            else:
                parent_num_cols.append((col_index, col))

    # The child key index is not needed below, but the lookup fails early
    # (ValueError) when the column is missing.
    child_primary_key_index = original_child_cols.index(child_primary_key)
    parent_primary_key_index = original_parent_cols.index(parent_primary_key)
    foreign_key_index = original_child_cols.index(parent_primary_key)

    # sort child data by foreign key
    sorted_child_data = child_data[np.argsort(child_data[:, foreign_key_index])]
    child_group_data_dict = CRF.tools.get_group_data_dict(sorted_child_data, [foreign_key_index,])

    # sort parent data by primary key
    sorted_parent_data = parent_data[np.argsort(parent_data[:, parent_primary_key_index])]

    # Number of children per parent, in sorted parent order (0 for childless parents).
    group_lengths = []
    unique_group_ids = sorted_parent_data[:, parent_primary_key_index]
    for group_id in unique_group_ids:
        group_id = tuple([group_id])
        if group_id not in child_group_data_dict:
            group_lengths.append(0)
        else:
            group_lengths.append(len(child_group_data_dict[group_id]))

    group_lengths = np.array(group_lengths, dtype=int)

    # Repeat each parent row once per child so that row i of the repeated parent
    # matrix is the parent of row i of the sorted child matrix.
    sorted_parent_data_repeated = np.repeat(sorted_parent_data, group_lengths, axis=0)
    assert((sorted_parent_data_repeated[:, parent_primary_key_index] == sorted_child_data[:, foreign_key_index]).all())

    child_group_data = CRF.tools.get_group_data(sorted_child_data, [foreign_key_index,])

    sorted_child_num_data = sorted_child_data[:, [col_index for col_index, col in child_num_cols]]
    sorted_child_cat_data = sorted_child_data[:, [col_index for col_index, col in child_cat_cols]]
    sorted_parent_num_data = sorted_parent_data_repeated[:, [col_index for col_index, col in parent_num_cols]]
    sorted_parent_cat_data = sorted_parent_data_repeated[:, [col_index for col_index, col in parent_cat_cols]]

    joint_num_matrix = np.concatenate([sorted_child_num_data, sorted_parent_num_data], axis=1)
    joint_cat_matrix = np.concatenate([sorted_child_cat_data, sorted_parent_cat_data], axis=1)

    # Column offsets at which the parent part starts in the joint matrices.
    # Defined unconditionally: the numerical offset is needed even when the
    # pair has no categorical columns at all.
    joint_cat_matrix_p_index = sorted_child_cat_data.shape[1]
    joint_num_matrix_p_index = sorted_child_num_data.shape[1]

    # One-hot encode every categorical column with at most 1000 distinct values.
    one_hot_blocks = []
    # Width of the child part of the one-hot matrix. This generally differs from
    # the number of child categorical columns because each column expands to one
    # indicator per category and high-cardinality columns are skipped.
    child_one_hot_width = 0
    for i in range(joint_cat_matrix.shape[1]):
        if len(np.unique(joint_cat_matrix[:, i])) > 1000:
            continue
        label_encoder = LabelEncoder()
        codes = label_encoder.fit_transform(joint_cat_matrix[:, i]).astype(float)
        encoder = OneHotEncoder(sparse_output=False)
        encoded_column = encoder.fit_transform(codes.reshape(-1, 1))
        one_hot_blocks.append(encoded_column)
        if i < joint_cat_matrix_p_index:
            child_one_hot_width += encoded_column.shape[1]

    # None when there is no usable categorical column (none present, or all skipped).
    cat_one_hot = None
    if one_hot_blocks:
        cat_one_hot = np.concatenate(one_hot_blocks, axis=1)
        # Scale only the parent-side indicator columns.
        cat_one_hot[:, child_one_hot_width:] = parent_scale * cat_one_hot[:, child_one_hot_width:]

    # Normalize the numerical features. The quantile variants are not part of
    # the clustering input below; they are kept for the alternative feature set
    # in the commented-out line further down.
    num_quantile = quantile_normalize_sklearn(joint_num_matrix)
    num_min_max = min_max_normalize_sklearn(joint_num_matrix)

    key_quantile = quantile_normalize_sklearn(sorted_parent_data_repeated[:, parent_primary_key_index].reshape(-1, 1))
    key_min_max = min_max_normalize_sklearn(sorted_parent_data_repeated[:, parent_primary_key_index].reshape(-1, 1))

    # key_scaled = key_scaler * key_quantile
    key_scaled = key_scale * key_min_max

    num_quantile[:, joint_num_matrix_p_index:] = parent_scale * num_quantile[:, joint_num_matrix_p_index:]
    num_min_max[:, joint_num_matrix_p_index:] = parent_scale * num_min_max[:, joint_num_matrix_p_index:]

    # cluster_data = np.concatenate((num_quantile, cat_one_hot, key_scaled), axis=1)

    if cat_one_hot is not None:
        cluster_data = np.concatenate((num_min_max, cat_one_hot, key_scaled), axis=1)
    else:
        cluster_data = np.concatenate((num_min_max, key_scaled), axis=1)
    kmeans = KMeans(n_clusters=num_clusters, n_init='auto', init='k-means++')

    print('clustering')
    kmeans.fit(cluster_data)

    cluster_labels = kmeans.labels_

    child_group_lengths = np.array([len(group) for group in child_group_data], dtype=int)

    # voting to determine the cluster label for each parent
    group_cluster_labels = []
    curr_index = 0
    agree_rates = []
    for group_length in child_group_lengths:
        # Count the labels of the current group once; the winner is the most
        # frequent label (lowest label on ties, as returned by argmax).
        label_counts = np.bincount(cluster_labels[curr_index: curr_index + group_length])
        group_cluster_label = np.argmax(label_counts)
        most_common_label_count = label_counts[group_cluster_label]
        group_cluster_labels.append(group_cluster_label)

        # Fraction of the group's rows that agree with the winning label
        agree_rate = most_common_label_count / group_length
        agree_rates.append(agree_rate)

        # Then, update the curr_index for the next iteration
        curr_index += group_length

    # Compute the average agree rate across all groups
    average_agree_rate = np.mean(agree_rates)
    print('average agree rate: ', average_agree_rate)

    # Every child row inherits the voted label of its group.
    group_assignment = np.repeat(group_cluster_labels, child_group_lengths, axis=0).reshape((-1, 1))

    # obtain the child data with clustering
    sorted_child_data_with_cluster = np.concatenate(
        [
            sorted_child_data,
            group_assignment
        ],
        axis=1
    )

    group_labels_list = group_cluster_labels
    group_lengths_list = child_group_lengths.tolist()

    # Per cluster label: histogram of child-group sizes.
    group_lengths_dict = {}
    for i in range(len(group_labels_list)):
        group_label = group_labels_list[i]
        if group_label not in group_lengths_dict:
            group_lengths_dict[group_label] = defaultdict(int)
        group_lengths_dict[group_label][group_lengths_list[i]] += 1

    group_lengths_prob_dicts = {}
    for group_label, freq_dict in group_lengths_dict.items():
        group_lengths_prob_dicts[group_label] = freq_to_prob(freq_dict)

    # recover the preprocessed data back to dataframe
    child_df_with_cluster = pd.DataFrame(
        sorted_child_data_with_cluster,
        columns=original_child_cols + [relation_cluster_name]
    )

    # recover child df order
    child_df_with_cluster = pd.merge(
        child_df[[child_primary_key]],
        child_df_with_cluster,
        on=child_primary_key,
        how='left',
    )

    # Map every parent id that has children to the label shared by its children.
    parent_id_to_cluster = {}
    for i in range(len(sorted_child_data)):
        parent_id = sorted_child_data[i, foreign_key_index]
        if parent_id in parent_id_to_cluster:
            assert(parent_id_to_cluster[parent_id] == sorted_child_data_with_cluster[i, -1])
            continue
        parent_id_to_cluster[parent_id] = sorted_child_data_with_cluster[i, -1]

    max_cluster_label = max(parent_id_to_cluster.values())

    # Parents without any child get one extra label past the largest used one.
    parent_data_clusters = []
    for i in range(len(parent_data)):
        if parent_data[i, parent_primary_key_index] in parent_id_to_cluster:
            parent_data_clusters.append(parent_id_to_cluster[parent_data[i, parent_primary_key_index]])
        else:
            parent_data_clusters.append(max_cluster_label + 1)

    parent_data_clusters = np.array(parent_data_clusters).reshape(-1, 1)
    parent_data_with_cluster = np.concatenate(
        [
            parent_data,
            parent_data_clusters
        ],
        axis=1
    )
    parent_df_with_cluster = pd.DataFrame(
        parent_data_with_cluster,
        columns=original_parent_cols + [relation_cluster_name]
    )

    # Register the new cluster column in both domains; its size is the number
    # of distinct labels found on the parent side.
    new_col_entry = {
        'type': 'discrete',
        'size': len(set(parent_data_clusters.flatten()))
    }

    print('num clusters: ', len(set(parent_data_clusters.flatten())))

    parent_domain_dict[relation_cluster_name] = new_col_entry.copy()
    child_domain_dict[relation_cluster_name] = new_col_entry.copy()

    parent_labels = parent_df_with_cluster[relation_cluster_name].to_numpy()
    child_labels = child_df_with_cluster[relation_cluster_name].to_numpy()

    parent_label_save_path = os.path.join(
        args.working_dir,
        parent_name,
        f'{relation_cluster_name}_label.npy'
    )

    child_label_save_path = os.path.join(
        args.working_dir,
        child_name,
        f'{relation_cluster_name}_label.npy'
    )

    os.makedirs(os.path.join(args.working_dir, parent_name), exist_ok=True)
    os.makedirs(os.path.join(args.working_dir, child_name), exist_ok=True)

    np.save(parent_label_save_path, parent_labels)
    np.save(child_label_save_path, child_labels)

    return parent_df_with_cluster, child_df_with_cluster, group_lengths_prob_dicts

def update_domain(old_domain_dict, df_with_cluster):
    """
    Return a copy of the domain extended with the columns it does not cover yet.

    Every column of `df_with_cluster` that has no entry in `old_domain_dict`
    and whose name does not contain '_id' is added with type 'categorical'.
    `old_domain_dict` itself is not modified; the copy is shallow, so the
    existing entries are shared with the original.
    """
    extra_entries = {
        name: {'type': 'categorical'}
        for name in df_with_cluster.columns
        if name not in old_domain_dict and '_id' not in name
    }
    extended_domain = old_domain_dict.copy()
    extended_domain.update(extra_entries)
    return extended_domain

def update_info(new_domain_dict, df_with_cluster):
    """Return the result of `get_info_from_domain` for the given frame and domain."""
    return get_info_from_domain(
        domain_dict=new_domain_dict,
        data_df=df_with_cluster,
    )

def update_table_info(table_name, old_domain_dict, df_with_cluster):
    """
    Rebuild domain, info and processed data for a table after cluster columns were added.

    Columns whose name contains '_id' are dropped first; the domain is then
    extended via `update_domain`, the info recomputed via `update_info`, and
    the frame passed through `pipeline_process_data` (ratio=1, save=False).

    Returns:
    - (new_data, new_info, new_domain), where new_data and new_info are the
      two values returned by `pipeline_process_data`.
    """
    feature_df = df_with_cluster.drop(
        columns=[name for name in df_with_cluster.columns if '_id' in name]
    )

    domain = update_domain(old_domain_dict=old_domain_dict, df_with_cluster=feature_df)
    info = update_info(new_domain_dict=domain, df_with_cluster=feature_df)

    processed, info = pipeline_process_data(
        name=table_name,
        data_df=feature_df,
        info=info,
        ratio=1,
        save=False,
    )
    return processed, info, domain


def get_data(data, has_test=False):
    """
    Regroup the flat arrays under data['numpy'] into a nested dict.

    The result has the keys 'X_cat', 'X_num' and 'y', each mapping to a dict
    with a 'train' entry taken from data['numpy']['<key>_train']. When
    `has_test` is true, a 'test' entry taken from '<key>_test' is added to
    each of them as well.
    """
    arrays = data['numpy']
    parts = ('X_cat', 'X_num', 'y')

    res = {part: {'train': arrays[f'{part}_train']} for part in parts}

    if has_test:
        for part in parts:
            res[part]['test'] = arrays[f'{part}_test']

    return res


def tabsyn_child_training(
    args,
    parent_name,
    child_name,
    data,
    read_ckpt=False
):
    """
    Train the models for one (parent, child) table pair.

    Runs `train_vae`, then `train_tabsyn`, and - unless `parent_name` is None -
    `tabsyn_train_classifier` on the saved cluster labels of the child table.
    All checkpoints are placed under
    args.working_dir/<child_name>/<parent_name>_<child_name>/{vae,model,classifier}.

    Parameters:
    - args: configuration object; the attributes read here are `has_test`,
      `has_y`, `max_beta`, `min_beta`, `lambd`, `device`, `working_dir`,
      `vae_epochs`, `vae_batch_size`, `big_data`, `data_dir` (only when
      `big_data` is set), `tabsyn_num_epochs`, and for the classifier
      `classifier_train_split_ratio`, `batch_size`, `classifier_epochs`.
    - parent_name: name of the parent table, or None for a table without parent.
    - child_name: name of the table being trained.
    - data: mapping keyed by (parent_name, child_name); the entry provides
      'data' (input of `get_data`) and, when a parent exists, 'labels'.
    - read_ckpt: forwarded to all three training routines as 'read_ckpt'.

    Returns:
    - dict with the keys 'vae', 'tabsyn', 'classifier' and 'num_classes'.
      'classifier' and 'num_classes' are None when `parent_name` is None;
      otherwise 'num_classes' is the value returned by `tabsyn_train_classifier`.
    """
    # train vae
    vae_args = {}
    vae_args['data'] = get_data(data[(parent_name, child_name)]['data'], args.has_test)

    vae_args['max_beta'] = args.max_beta
    vae_args['min_beta'] = args.min_beta
    vae_args['lambd'] = args.lambd
    vae_args['device'] = torch.device(args.device)
    vae_args['info'] = {
        'task_type': 'None',
        'n_classes': 0
    }
    vae_ckpt_dir = os.path.join(
        args.working_dir,
        child_name,
        f'{parent_name}_{child_name}',
        'vae'
    )
    os.makedirs(vae_ckpt_dir, exist_ok=True)

    vae_args['ckpt_dir'] = vae_ckpt_dir
    vae_args['vae_epochs'] = args.vae_epochs
    vae_args['vae_batch_size'] = args.vae_batch_size
    vae_args['has_y'] = args.has_y
    vae_args['has_test'] = args.has_test

    if args.big_data:
        # Per-table epoch counts override the global setting. The file is read
        # inside a context manager so the handle is closed deterministically.
        with open(os.path.join(args.data_dir, 'vae_epochs.json')) as epochs_file:
            vae_epochs = json.load(epochs_file)
        vae_args['vae_epochs'] = vae_epochs[child_name]

    vae_args['read_ckpt'] = read_ckpt

    train_vae(vae_args)

    # train the tabsyn model on the embeddings stored in the VAE checkpoint directory
    tabsyn_args = {}
    tabsyn_args['device'] = torch.device(args.device)
    tabsyn_ckpt_dir = os.path.join(
        args.working_dir,
        child_name,
        f'{parent_name}_{child_name}',
        'model'
    )
    os.makedirs(tabsyn_ckpt_dir, exist_ok=True)
    tabsyn_args['ckpt_dir'] = tabsyn_ckpt_dir
    tabsyn_args['tabsyn_num_epochs'] = args.tabsyn_num_epochs
    tabsyn_args['embedding_save_path'] = os.path.join(vae_args['ckpt_dir'], 'train_z.npy')
    tabsyn_args['read_ckpt'] = read_ckpt

    train_tabsyn(tabsyn_args)

    # Without a parent there are no cluster labels, hence no classifier.
    if parent_name is None:
        return {
            'vae': vae_ckpt_dir,
            'tabsyn': tabsyn_ckpt_dir,
            'classifier': None,
            'num_classes': None
        }

    # train the classifier on the cluster labels
    classifier_args = {}
    classifier_args['device'] = torch.device(args.device)
    classifier_args['ckpt_dir'] = tabsyn_args['ckpt_dir']
    classifier_args['classifier_train_split_ratio'] = args.classifier_train_split_ratio
    classifier_args['embedding_save_path'] = tabsyn_args['embedding_save_path']
    classifier_save_path = os.path.join(
        args.working_dir,
        child_name,
        f'{parent_name}_{child_name}',
        'classifier'
    )
    os.makedirs(classifier_save_path, exist_ok=True)
    classifier_args['classifier_save_path'] = classifier_save_path
    classifier_args['label_path'] = os.path.join(
        args.working_dir,
        child_name,
        f'{parent_name}_{child_name}_cluster_label.npy'
    )
    classifier_args['batch_size'] = args.batch_size
    # Sanity check: the labels on disk must match the labels held in memory.
    assert(np.equal(
        np.load(classifier_args['label_path']),
        data[(parent_name, child_name)]['labels']
    ).all())
    classifier_args['classifier_epochs'] = args.classifier_epochs
    if args.big_data:
        # With per-table epochs, the classifier uses half of the VAE epochs.
        classifier_args['classifier_epochs'] = vae_args['vae_epochs'] // 2

    classifier_args['read_ckpt'] = read_ckpt

    num_classes = tabsyn_train_classifier(classifier_args)

    return {
        'vae': vae_ckpt_dir,
        'tabsyn': tabsyn_ckpt_dir,
        'classifier': classifier_save_path,
        'num_classes': num_classes
    }


def tabsyn_sample_child(
    args,
    parent_name,
    child_name,
    child_result,
    info,
    data,
    group_lengths_prob_dicts=None,
    sampled_parent_labels=None,
    num_classes=None
):
    """
    Sample synthetic rows for a table through `cond_sample`.

    Parameters:
    - args: configuration object; reads `device`, `working_dir`,
      `sample_batch_size`, `has_y`, `has_test`, `num_steps` and, when a parent
      exists, `classifier_scale`.
    - parent_name: name of the parent table, or None for a table without parent.
    - child_name: name of the table to sample.
    - child_result: dict with the keys 'vae', 'tabsyn' and 'classifier'
      (checkpoint locations).
    - info: forwarded to `cond_sample` as 'info'.
    - data: input of `get_data`.
    - group_lengths_prob_dicts: cluster label -> distribution accepted by
      `sample_from_dict`; required when `parent_name` is not None.
    - sampled_parent_labels: one cluster label per sampled parent row;
      required when `parent_name` is not None.
    - num_classes: forwarded to `cond_sample` as 'num_classes'.

    Returns:
    - the result of `cond_sample` alone when `parent_name` is None; otherwise
      a tuple (result of `cond_sample`, sampled_group_sizes) with one group
      size per entry of `sampled_parent_labels`.

    Side effects:
    - Creates args.working_dir/<child_name>/unmatched_sample if it is missing.
    """
    sample_args = {
        'device': torch.device(args.device),
        'classifier_ckpt_path': child_result['classifier'],
        'ckpt_dir': child_result['tabsyn'],
    }

    # The directory is only created here; 'save_path' below stays None.
    unmatched_dir = os.path.join(args.working_dir, child_name, 'unmatched_sample')
    if not os.path.exists(unmatched_dir):
        os.makedirs(unmatched_dir)

    sample_args.update({
        'save_path': None,
        'sample_batch_size': args.sample_batch_size,
        'info': info,
        'task_type': 'None',
        'has_y': args.has_y,
        'has_test': args.has_test,
        'embedding_save_path': os.path.join(child_result['vae'], 'train_z.npy'),
        'decoder_save_path': os.path.join(child_result['vae'], 'decoder.pt'),
        'data': get_data(data, args.has_test),
        'num_steps': args.num_steps,
    })

    if parent_name is None:
        return cond_sample(sample_args)

    # Draw a child-group size for every parent and emit the parent's cluster
    # label once per child; labels without a size distribution get no children.
    sampled_group_sizes = []
    ys = []
    for cluster in sampled_parent_labels:
        if cluster in group_lengths_prob_dicts:
            size = sample_from_dict(group_lengths_prob_dicts[cluster])
            sampled_group_sizes.append(size)
            ys.extend([cluster] * size)
        else:
            sampled_group_sizes.append(0)

    sample_args['labels'] = np.array(ys)
    sample_args['classifier_scale'] = args.classifier_scale
    sample_args['num_classes'] = num_classes

    return cond_sample(sample_args), sampled_group_sizes
