from pipeline_utils import *

import numpy as np
import os
import pandas as pd
import json
from tab_ddpm.utils import *
from pipeline_modules import *
import argparse

def _load_json(path):
    """Load and return the JSON document stored at ``path`` (file is closed afterwards)."""
    with open(path) as f:
        return json.load(f)


def _dump_json(obj, path, **kwargs):
    """Write ``obj`` as JSON to ``path``; extra kwargs are forwarded to ``json.dump``."""
    with open(path, 'w') as f:
        json.dump(obj, f, **kwargs)


def _split_cols_by_type(domain_dict):
    """Split a domain dict's column names into ``(discrete_cols, other_cols)``.

    A column is discrete when its entry has ``'type' == 'discrete'``; every other
    column is treated as numerical. Insertion order of ``domain_dict`` is kept.
    """
    cat_cols = [col for col, info in domain_dict.items() if info['type'] == 'discrete']
    num_cols = [col for col, info in domain_dict.items() if info['type'] != 'discrete']
    return cat_cols, num_cols


def _recover_splitted_table(role, save_dir, path_prefix, leading_keys, trailing_keys, base):
    """Undo the column splitting for one table and write the recovered outputs.

    Both the generated table (``<save_dir>/<role>_final.csv``) and the original
    splitted table (``<path_prefix>/<role>_splitted.csv``) are passed through
    ``recover_df``. The key columns in ``leading_keys`` and ``trailing_keys`` are
    then placed before and after the recovered columns, respectively.

    Args:
        role: File-name prefix, ``'child'`` or ``'parent'``.
        save_dir: Directory holding ``<role>_final.csv`` and receiving outputs.
        path_prefix: Directory holding the splitting metadata files.
        leading_keys: Key column names placed before the recovered columns.
        trailing_keys: Key column names placed after the recovered columns.
        base: Forwarded to ``recover_df`` as ``base``.
    """
    final_splitted_df = pd.read_csv(os.path.join(save_dir, f'{role}_final.csv'))
    original_splitted_df = pd.read_csv(os.path.join(path_prefix, f'{role}_splitted.csv'))
    cat_domain_splitted_dict_list = _load_json(os.path.join(path_prefix, f'{role}_cat_domain_splitted_list'))
    cat_new_col_list_list = _load_json(os.path.join(path_prefix, f'{role}_cat_new_col_list_list.json'))
    splitted_domain_dict = _load_json(os.path.join(path_prefix, f'{role}_splitted_domain_dict.json'))

    # Both recoveries share the same splitting metadata.
    recover_kwargs = dict(
        prefix=f'{role}_',
        splitted_domain_dict=splitted_domain_dict,
        cat_new_col_list_list=cat_new_col_list_list,
        cat_domain_splitted_dict_list=cat_domain_splitted_dict_list,
        base=base,
    )

    recovered_generated_df, _ = recover_df(splitted_df=final_splitted_df, **recover_kwargs)
    recovered_original_df, recovered_domain_dict = recover_df(splitted_df=original_splitted_df, **recover_kwargs)

    def _with_keys(source_df, recovered_df):
        # Place the key columns from the splitted source around the recovered columns.
        return pd.concat(
            [source_df[key] for key in leading_keys]
            + [recovered_df]
            + [source_df[key] for key in trailing_keys],
            axis=1,
        )

    recovered_generated_df = _with_keys(final_splitted_df, recovered_generated_df)
    recovered_original_df = _with_keys(original_splitted_df, recovered_original_df)

    recovered_original_df.to_csv(os.path.join(save_dir, f'{role}_recovered_original_df.csv'), index=False)
    recovered_generated_df.to_csv(os.path.join(save_dir, f'{role}_recovered_generated_df.csv'), index=False)
    _dump_json(recovered_domain_dict, os.path.join(save_dir, f'{role}_recovered_domain_dict.json'))


def main():
    """Train parent/child diffusion models, sample synthetic tables, and save them.

    The parsed CLI arguments are written to ``<path_prefix>/save/<exp_name>/args``.
    The script then:

    1. clusters the child/parent tables jointly;
    2. trains a child model and a parent model;
    3. samples the parent table, then samples child rows conditioned on the
       sampled parent cluster labels.

    Generated CSVs are written to the save directory. When ``--child_splitted`` /
    ``--parent_splitted`` is 1, the split categorical columns are also recovered.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name', type=str, default='movie_lens_1m')
    parser.add_argument('--parent_splitted', type=int, default=0)
    parser.add_argument('--child_splitted', type=int, default=0)
    parser.add_argument('--path_prefix', type=str, default='movie_lens_1m/')
    parser.add_argument('--child_csv_name', type=str, default='user_rates_movie.csv')
    parser.add_argument('--parent_csv_name', type=str, default='movie.csv')
    parser.add_argument('--child_domain_name', type=str, default='user_rates_movie_domain_continuous.json')
    parser.add_argument('--parent_domain_name', type=str, default='movie_domain_continuous.json')
    parser.add_argument('--KEY_SCALER', type=int, default=1)
    parser.add_argument('--PARENT_SCALER', type=int, default=1)
    parser.add_argument('--NUM_CLUSTERS', type=int, default=20)
    parser.add_argument('--CLASSIFIER_SCALE', type=int, default=2)
    parser.add_argument('--CHILD_PRIMARY_KEY', type=str, default='user_rates_movie_id')
    parser.add_argument('--PARENT_PRIMARY_KEY', type=str, default='movie_id')
    parser.add_argument('--SAMPLE_SCALE', type=int, default=1)
    parser.add_argument('--CLUSTER_WITH_CRF', type=int, default=0)
    parser.add_argument('--CRF_CKPT_PATH', type=str, default='temp/001.pkl')
    parser.add_argument('--BASE', type=int, default=10)
    parser.add_argument('--classifier_steps', type=int, default=1000)
    parser.add_argument('--individual_steps', type=int, default=10000)
    parser.add_argument('--household_steps', type=int, default=10000)
    parser.add_argument('--ih_steps', type=int, default=10000)
    parser.add_argument('--model_type', type=str, default='mlp')
    parser.add_argument('--handle_size1', action='store_true')
    parser.add_argument('--PARENT_NAME', type=str, default='movie')
    parser.add_argument('--CHILD_NAME', type=str, default='user_rates_movie')

    args = parser.parse_args()

    # path_prefix is treated as a directory everywhere (it already was for save_dir).
    save_dir = os.path.join(args.path_prefix, 'save', args.exp_name)
    os.makedirs(save_dir, exist_ok=True)

    _dump_json(vars(args), os.path.join(save_dir, 'args'), indent=4)

    parent_domain_path = os.path.join(args.path_prefix, args.parent_domain_name)
    child_domain_path = os.path.join(args.path_prefix, args.child_domain_name)

    parent_csv_path = os.path.join(args.path_prefix, args.parent_csv_name)
    child_csv_path = os.path.join(args.path_prefix, args.child_csv_name)

    # Each domain file is read once and reused for column typing and training.
    household_domain_dict = _load_json(parent_domain_path)
    individual_domain_dict = _load_json(child_domain_path)

    individual_cat_cols, individual_num_cols = _split_cols_by_type(individual_domain_dict)
    household_cat_cols, household_num_cols = _split_cols_by_type(household_domain_dict)

    print('child_cat_cols: ', individual_cat_cols)
    print()
    print('child_num_cols: ', individual_num_cols)
    print()
    print('parent_cat_cols: ', household_cat_cols)
    print()
    print('parent_num_cols: ', household_num_cols)

    model_type = args.model_type

    lr = 0.0006
    weight_decay = 1e-05
    batch_size = 4096
    scheduler = 'cosine'

    test_num_samples = 50000
    num_timesteps = 2000
    gaussian_loss_type = "mse"

    os.environ["CUDA_VISIBLE_DEVICES"] = '0'

    individual_df = pd.read_csv(child_csv_path)
    household_df = pd.read_csv(parent_csv_path)

    household_df_with_cluster, individual_df_with_cluster, group_lengths_prob_dicts = pair_clustering_keep_id(
        individual_df,
        individual_domain_dict,
        household_df,
        household_domain_dict,
        args.CHILD_PRIMARY_KEY,
        args.PARENT_PRIMARY_KEY,
        args.NUM_CLUSTERS,
        args.PARENT_SCALER,
        args.KEY_SCALER,
        args.PARENT_NAME,
        args.CHILD_NAME,
    )

    # Primary/foreign keys are not modelled; fresh keys are assigned after sampling.
    household_df_with_cluster = household_df_with_cluster.drop(columns=[args.PARENT_PRIMARY_KEY])
    individual_df_with_cluster = individual_df_with_cluster.drop(columns=[args.CHILD_PRIMARY_KEY, args.PARENT_PRIMARY_KEY])

    # NOTE(review): clustering above receives args.PARENT_NAME/args.CHILD_NAME, but
    # training below uses the literals 'household'/'individual'. Confirm that
    # child_training does not derive the cluster column name from these.
    individual_result = child_training(
        individual_df_with_cluster,
        individual_domain_dict,
        'household',
        'individual',
        args.individual_steps,
        args.classifier_steps,
        batch_size,
        model_type,
        gaussian_loss_type,
        num_timesteps,
        scheduler,
        lr,
        weight_decay,
        test_num_samples
    )

    household_result = child_training(
        household_df_with_cluster,
        household_domain_dict,
        None,
        'household',
        args.household_steps,
        args.classifier_steps,
        batch_size,
        model_type,
        gaussian_loss_type,
        num_timesteps,
        scheduler,
        lr,
        weight_decay,
        test_num_samples
    )

    # NOTE(review): --handle_size1 is parsed but never used; size1_len is always 0.
    size1_len = 0

    _, household_generated = sample_from_diffusion(
        df=household_df_with_cluster,
        df_info=household_result['df_info'],
        diffusion=household_result['diffusion'],
        dataset=household_result['dataset'],
        label_encoders=household_result['label_encoders'],
        sample_size=int(args.SAMPLE_SCALE * (len(household_df) - size1_len)),
        model_params=household_result['model_params'],
        T_dict=household_result['T_dict'],
        test_batch_size=100000,
    )

    # Position, within the generated parent table, of the column the child model
    # is conditioned on (the child's y_col).
    individual_parent_label_index = household_result['column_orders'].index(
        individual_result['df_info']['y_col']
    )

    _, individual_generated, individual_sampled_group_sizes = conditional_sampling_by_group_size(
        df=individual_df_with_cluster,
        df_info=individual_result['df_info'],
        dataset=individual_result['dataset'],
        label_encoders=individual_result['label_encoders'],
        classifier=individual_result['classifier'],
        diffusion=individual_result['diffusion'],
        group_labels=household_generated.values[:, individual_parent_label_index].astype(float).astype(int).tolist(),
        group_lengths_prob_dicts=group_lengths_prob_dicts,
        sample_batch_size=100000,
        is_y_cond='none',
        classifier_scale=args.CLASSIFIER_SCALE,
    )

    # save the generated data
    household_generated.to_csv(os.path.join(save_dir, 'parent_generated.csv'), index=False)
    individual_generated.to_csv(os.path.join(save_dir, 'child_generated.csv'), index=False)

    # Parent i gets key i; each child row inherits the key of the parent whose
    # sampled group it belongs to (group sizes come from the conditional sampler).
    parent_keys = list(range(len(household_generated)))
    individual_generated_household_ids_arr = np.repeat(
        parent_keys, individual_sampled_group_sizes, axis=0
    ).reshape(-1, 1)

    individual_generated_individual_ids_arr = np.arange(
        len(individual_generated)
    ).reshape(-1, 1)

    individual_generated_final_arr = np.concatenate(
        [
            individual_generated_individual_ids_arr,
            individual_generated.to_numpy()[:, :-1],  # remove the cluster column
            individual_generated_household_ids_arr,
        ],
        axis=1
    )

    # assumes the sampler emits numerical columns before categorical ones — TODO confirm
    individual_final_df = pd.DataFrame(
        individual_generated_final_arr,
        columns=[args.CHILD_PRIMARY_KEY] + individual_num_cols + individual_cat_cols + [args.PARENT_PRIMARY_KEY]
    )

    individual_final_df = individual_final_df[individual_df.columns]

    household_generated_final_arr = np.concatenate(
        [
            np.array(parent_keys).reshape(-1, 1),
            household_generated.to_numpy()[:, :-1],  # remove the cluster column
        ],
        axis=1
    )

    household_final_df = pd.DataFrame(
        household_generated_final_arr,
        columns=[args.PARENT_PRIMARY_KEY] + household_num_cols + household_cat_cols + [household_result['df_info']['y_col']]
    )

    # Reorder (and restrict) to the original parent table's columns.
    household_final_df = household_final_df[household_df.columns]

    individual_final_df.to_csv(os.path.join(save_dir, 'child_final.csv'), index=False)
    household_final_df.to_csv(os.path.join(save_dir, 'parent_final.csv'), index=False)

    if args.child_splitted == 1:
        _recover_splitted_table(
            'child',
            save_dir,
            args.path_prefix,
            leading_keys=[args.CHILD_PRIMARY_KEY],
            trailing_keys=[args.PARENT_PRIMARY_KEY],
            base=args.BASE,
        )

    if args.parent_splitted == 1:
        _recover_splitted_table(
            'parent',
            save_dir,
            args.path_prefix,
            leading_keys=[args.PARENT_PRIMARY_KEY],
            trailing_keys=[],
            base=args.BASE,
        )


if __name__ == "__main__":
    # Run the pipeline only when executed as a script, never on import.
    main()
