import os
import pandas as pd
import json
import pickle
import argparse

# Point CUDA_PATH at a specific CUDA 11.7 installation. This runs before the
# tabsyn imports below, presumably so that CUDA-dependent code sees it at import
# time -- TODO confirm. NOTE(review): the path is machine-specific and
# unconditionally overrides any CUDA_PATH already set in the environment.
os.environ['CUDA_PATH'] = '/u3/w3pang/pkgs/cuda-11.7.0'

from tabsyn.eval.eval_quality import eval_metrics
from tabsyn.eval.eval_mle import eval_mle
from tabsyn.eval.eval_dcr import eval_dcr
from tabsyn.eval.eval_detection import eval_detection

def get_info(syn_df, domain_dict, target_col):
    """Describe the columns of ``syn_df`` by position and infer the task type.

    Args:
        syn_df: DataFrame whose column order defines the reported positions.
        domain_dict: Mapping from column name to a spec with a ``'type'`` entry
            (``'discrete'`` or anything else) and, for discrete columns that
            are used as target, a ``'size'`` entry.
        target_col: Name of the target column. A name that matches no column
            (e.g. ``''``) yields an empty target list.

    Returns:
        A dict with the keys ``'num_col_idx'``, ``'cat_col_idx'``,
        ``'target_col_idx'`` (lists of integer column positions) and
        ``'task_type'``: ``'binclass'``, ``'multiclass'`` or ``'regression'``
        when the target column is described in ``domain_dict``, otherwise the
        string ``'None'``.
    """
    info = {}
    numeric_positions, categorical_positions, target_positions = [], [], []

    for position, name in enumerate(syn_df.columns):
        described = name in domain_dict

        if name == target_col:
            target_positions.append(position)
            if described:
                spec = domain_dict[name]
                if spec['type'] != 'discrete':
                    info['task_type'] = 'regression'
                elif spec['size'] == 2:
                    info['task_type'] = 'binclass'
                else:
                    info['task_type'] = 'multiclass'
        elif described:
            # Feature column: anything that is not 'discrete' counts as numeric.
            if domain_dict[name]['type'] == 'discrete':
                categorical_positions.append(position)
            else:
                numeric_positions.append(position)

    info['num_col_idx'] = numeric_positions
    info['cat_col_idx'] = categorical_positions
    info['target_col_idx'] = target_positions
    # Note: the fallback is the string 'None', not the None singleton.
    info.setdefault('task_type', 'None')

    return info

def compute_alpha_beta(real_df, syn_df, domain_dict, sample_size=200000):
    """Run ``eval_metrics`` on equally sized random samples of both tables.

    Columns whose name contains ``'_id'`` are removed from each frame, rows
    with missing values are removed from the synthetic frame only, and both
    frames are then randomly sampled down to the same number of rows.

    Args:
        real_df: DataFrame with the real data.
        syn_df: DataFrame with the synthetic data.
        domain_dict: Column specs, forwarded to ``get_info``.
        sample_size: Upper bound on the number of rows sampled from each
            frame; capped at the size of the smaller frame.

    Returns:
        The ``(alpha, beta)`` pair produced by ``eval_metrics``.
    """
    # Drop id columns, determined independently for each frame.
    real_df = real_df.drop(columns=[col for col in real_df.columns if '_id' in col])
    syn_df = syn_df.drop(columns=[col for col in syn_df.columns if '_id' in col])

    # No target column is used here, hence the empty target name.
    info = get_info(syn_df, domain_dict, '')
    syn_df = syn_df.dropna()

    # Sampling both frames with the same capped size already guarantees equal
    # lengths, so no further re-balancing is required afterwards.
    sample_size = min(sample_size, len(syn_df), len(real_df))
    syn_df = syn_df.sample(sample_size)
    real_df = real_df.sample(sample_size)

    alpha, beta = eval_metrics(syn_df, real_df, info)
    return alpha, beta

def compute_all_mle(syn_df, test_df, domain_dict):
    """Call ``compute_mle`` once per column listed in ``domain_dict``.

    Each column in turn is used as the target column. Progress and results
    are printed as they become available.

    Returns:
        A dict mapping each column name to the value ``compute_mle`` returned
        for it.
    """
    results = {}
    for column in domain_dict:
        print('Computing MLE for column:', column)
        score = compute_mle(syn_df, test_df, domain_dict, column)
        results[column] = score
        print(f'MLE for column {column}: {score}')
    return results

def compute_mle(syn_df, test_df, domain_dict, target_col):
    """Run ``eval_mle`` on the synthetic and test tables for one target column.

    Columns whose name contains ``'_id'`` are removed from each frame. The
    column metadata is derived from the synthetic frame before its rows with
    missing values are dropped; the test frame keeps all of its rows.

    Args:
        syn_df: DataFrame with the synthetic data.
        test_df: DataFrame with the held-out data.
        domain_dict: Column specs, forwarded to ``get_info``.
        target_col: Name of the column treated as the prediction target.

    Returns:
        Whatever ``eval_mle`` returns for the two value arrays and the
        metadata dict.
    """
    def _without_id_columns(frame):
        # Strip every column whose name contains '_id'.
        return frame.drop(columns=[name for name in frame.columns if '_id' in name])

    test_df = _without_id_columns(test_df)
    syn_df = _without_id_columns(syn_df)

    info = get_info(syn_df, domain_dict, target_col)
    syn_values = syn_df.dropna().values

    return eval_mle(syn_values, test_df.values, info)

def compute_dcr(syn_df, real_df, test_df, domain_dict):
    """Run ``eval_dcr`` on the synthetic, real and test tables.

    Columns whose name contains ``'_id'`` are removed from each frame. The id
    columns are determined per frame (as in the other ``compute_*`` helpers),
    so a frame that lacks an id column present in another frame does not
    cause a ``KeyError``, and id columns unique to one frame are still
    removed.

    Args:
        syn_df: DataFrame with the synthetic data.
        real_df: DataFrame with the real (training) data.
        test_df: DataFrame with the held-out data.
        domain_dict: Column specs, forwarded to ``get_info``.

    Returns:
        Whatever ``eval_dcr`` returns.
    """
    # Drop id columns, determined independently for each frame.
    syn_df = syn_df.drop(columns=[col for col in syn_df.columns if '_id' in col])
    real_df = real_df.drop(columns=[col for col in real_df.columns if '_id' in col])
    test_df = test_df.drop(columns=[col for col in test_df.columns if '_id' in col])

    # No target column is used here, hence the empty target name.
    info = get_info(syn_df, domain_dict, '')

    dcr_score = eval_dcr(syn_df, real_df, test_df, info)
    return dcr_score

def compute_detection(syn_df, real_df, domain_dict):
    """Run ``eval_detection`` on the synthetic and real tables.

    Columns whose name contains ``'_id'`` are removed from each frame before
    the call; ``domain_dict`` is passed through unchanged.

    Returns:
        Whatever ``eval_detection`` returns.
    """
    real_id_columns = [name for name in real_df.columns if '_id' in name]
    real_without_ids = real_df.drop(columns=real_id_columns)

    syn_id_columns = [name for name in syn_df.columns if '_id' in name]
    syn_without_ids = syn_df.drop(columns=syn_id_columns)

    return eval_detection(syn_without_ids, real_without_ids, domain_dict)

if __name__ == '__main__':
    # Command-line entry point: computes (or loads from cache) the alpha/beta
    # and detection scores for one table and prints them. Results are cached
    # as pickle files under --save_path, keyed by --table_name.
    parser = argparse.ArgumentParser()
    parser.add_argument('--real_data_path', type=str)
    parser.add_argument('--syn_data_path', type=str)
    parser.add_argument('--test_data_path', type=str)
    parser.add_argument('--domain_dict_path', type=str)
    parser.add_argument('--table_name', type=str)
    parser.add_argument('--alpha_beta_sample_size', type=int, default=200000)
    parser.add_argument('--save_path')

    args = parser.parse_args()

    os.makedirs(args.save_path, exist_ok=True)

    real_data = pd.read_csv(args.real_data_path)
    syn_data = pd.read_csv(args.syn_data_path)
    # test_data = pd.read_csv(args.test_data_path)
    with open(args.domain_dict_path, 'r') as f:
        domain_dict = json.load(f)

    alpha_path = os.path.join(args.save_path, f'{args.table_name}_alpha.pkl')
    beta_path = os.path.join(args.save_path, f'{args.table_name}_beta.pkl')
    detection_path = os.path.join(args.save_path, f'{args.table_name}_detection.pkl')

    # Recompute unless both cached results exist. NOTE: cached results are
    # reused regardless of the sample size they were computed with.
    if not os.path.exists(alpha_path) or not os.path.exists(beta_path):
        # Honour --alpha_beta_sample_size (its default matches the function's).
        alpha, beta = compute_alpha_beta(
            real_data,
            syn_data,
            domain_dict,
            sample_size=args.alpha_beta_sample_size,
        )
        with open(alpha_path, 'wb') as f:
            pickle.dump(alpha, f)

        with open(beta_path, 'wb') as f:
            pickle.dump(beta, f)

    else:
        with open(alpha_path, 'rb') as f:
            alpha = pickle.load(f)
        with open(beta_path, 'rb') as f:
            beta = pickle.load(f)

    print(f'alpha: {alpha}, beta: {beta}')
    

    # if not os.path.exists(os.path.join(args.save_path, f'{args.table_name}_mles.pkl')):
    #     mles = compute_all_mle(syn_data, test_data, domain_dict)
    #     with open(os.path.join(args.save_path, f'{args.table_name}_mles.pkl'), 'wb') as f:
    #         pickle.dump(mles, f)
    # else:
    #     mles = pickle.load(open(os.path.join(args.save_path, f'{args.table_name}_mles.pkl'), 'rb'))
    

    # if not os.path.exists(os.path.join(args.save_path, f'{args.table_name}_dcr.pkl')):
    #     dcr_score = compute_dcr(syn_data, real_data, test_data, domain_dict)
    #     with open(os.path.join(args.save_path, f'{args.table_name}_dcr.pkl'), 'wb') as f:
    #         pickle.dump(dcr_score, f)
    # else:
    #     dcr_score = pickle.load(open(os.path.join(args.save_path, f'{args.table_name}_dcr.pkl'), 'rb'))

    # print(f'DCR Score: {dcr_score}')
    
    if not os.path.exists(detection_path):
        detection_score = compute_detection(syn_data, real_data, domain_dict)
        with open(detection_path, 'wb') as f:
            pickle.dump(detection_score, f)
    else:
        with open(detection_path, 'rb') as f:
            detection_score = pickle.load(f)

    print(f'Detection Score: {detection_score}')

    
