import numpy as np
import torch 
import pandas as pd
import os 
import sys

import json
import pickle

# Metrics
from sdmetrics import load_demo
from sdmetrics.single_table import LogisticDetection

from matplotlib import pyplot as plt

import argparse
import warnings
# Suppress every warning globally for the whole run.
warnings.filterwarnings("ignore")

# Command-line options: dataset name, model name, and an optional explicit
# path to the synthetic CSV (used instead of the default location when given).
parser = argparse.ArgumentParser()
parser.add_argument('--dataname', default='adult', type=str)
parser.add_argument('--model', default='real', type=str)
parser.add_argument(
    '--path',
    default=None,
    type=str,
    help='The file path of the synthetic data',
)
args = parser.parse_args()

def reorder(real_data, syn_data, info):
    """Reorder real and synthetic tables to a numerical-then-categorical layout.

    Columns listed in ``info['num_col_idx']`` come first, followed by those in
    ``info['cat_col_idx']``. The column(s) in ``info['target_col_idx']`` are
    appended to the numerical group when ``info['task_type'] == 'regression'``
    and to the categorical group for every other task type. Both frames are
    then relabelled with consecutive integer column names ``0..k-1``.

    Args:
        real_data: DataFrame whose column labels are the integer indices
            referenced by ``info``.
        syn_data: DataFrame labelled the same way as ``real_data``.
        info: Dataset description with keys ``num_col_idx``, ``cat_col_idx``,
            ``target_col_idx``, ``task_type`` and ``metadata``. The
            ``metadata['columns']`` mapping is keyed by original column index;
            keys may be ints or their string form (as produced by
            ``json.load``).

    Returns:
        ``(new_real_data, new_syn_data, metadata)``. ``metadata`` is a shallow
        copy of ``info['metadata']`` whose ``'columns'`` entry is re-keyed by
        the new consecutive column positions.

    ``info`` itself is left unmodified.
    """
    # Copy the index lists: extending the originals with ``+=`` would mutate
    # the lists stored in ``info`` (and repeat the target on a second call).
    num_col_idx = list(info['num_col_idx'])
    cat_col_idx = list(info['cat_col_idx'])
    target_col_idx = list(info['target_col_idx'])

    if info['task_type'] == 'regression':
        num_col_idx += target_col_idx
    else:
        cat_col_idx += target_col_idx

    ordered_idx = num_col_idx + cat_col_idx

    new_real_data = real_data[ordered_idx]
    new_real_data.columns = range(len(ordered_idx))

    new_syn_data = syn_data[ordered_idx]
    new_syn_data.columns = range(len(ordered_idx))

    source_columns = info['metadata']['columns']

    def _column_meta(idx):
        # Accept both int keys and the str keys that json.load produces.
        if idx in source_columns:
            return source_columns[idx]
        return source_columns[str(idx)]

    # Shallow-copy so the caller's info['metadata'] is not overwritten.
    metadata = dict(info['metadata'])
    metadata['columns'] = {
        new_pos: _column_meta(old_idx)
        for new_pos, old_idx in enumerate(ordered_idx)
    }

    return new_real_data, new_syn_data, metadata

if __name__ == '__main__':
    # Entry point: load the real and synthetic tables for --dataname/--model,
    # harmonise their encodings, reorder them with ``reorder`` and print the
    # sdmetrics LogisticDetection score.

    # Datasets whose CSVs are read with dtype=str and then cleaned up below
    # (presumably to keep categorical codes such as "01" intact — confirm).
    string_datasets = ['canada', 'fiji', 'uk', 'rwanda', 'indonesia', 'adulta',
                       'churn', 'tcga', 'diabetes', 'tcgaa']
    # String datasets for which the 'AGE' rounding step is skipped.
    no_age_datasets = ['adulta', 'churn', 'tcga', 'tcgaa', 'diabetes']

    dataname = args.dataname
    model = args.model

    # --path, when given, replaces the default synthetic CSV location.
    syn_path = args.path if args.path else f'synthetic/{dataname}/{model}.csv'
    real_path = f'synthetic/{dataname}/real.csv'

    data_dir = f'data/{dataname}'
    print(syn_path)

    with open(f'{data_dir}/info.json', 'r', encoding='utf-8') as f:
        info = json.load(f)

    is_string_dataset = dataname in string_datasets
    if is_string_dataset:
        syn_data = pd.read_csv(syn_path, dtype=str)
        real_data = pd.read_csv(real_path, dtype=str)
    else:
        syn_data = pd.read_csv(syn_path)
        real_data = pd.read_csv(real_path)

    # Column roles are needed for every dataset (the diagnostics below use
    # them), not only the string datasets. The target joins the numerical
    # group only for regression, matching ``reorder``.
    discrete_columns = [real_data.columns[i] for i in info['cat_col_idx']]
    numerical_columns = [real_data.columns[i] for i in info['num_col_idx']]
    target_columns = [real_data.columns[i] for i in info['target_col_idx']]
    if info['task_type'] == 'regression':
        numerical_columns += target_columns
    else:
        discrete_columns += target_columns

    if is_string_dataset:
        # Apply the identical cleanup to both tables so their encodings agree.
        for frame in (real_data, syn_data):
            frame[numerical_columns] = frame[numerical_columns].astype(float)
            if dataname not in no_age_datasets:
                frame['AGE'] = frame['AGE'].round(0).astype(int)
            # Turn codes like "12.0" into "12" in the categorical columns.
            frame[discrete_columns] = frame[discrete_columns].replace(
                to_replace=r'(\d+)\.0\b', value=r'\1', regex=True)

        if dataname == 'indonesia':
            # Re-pad single-digit synthetic codes to two digits, and map "0"
            # to "000" for EDATTAIND.
            cols_to_fix = ['LANDOWN', 'HOMEFEM', 'HOMEMALE']
            syn_data[cols_to_fix] = syn_data[cols_to_fix].replace(
                {str(j): f'0{j}' for j in range(10)})
            syn_data['EDATTAIND'] = syn_data['EDATTAIND'].replace({0: "000", "0": "000"})

        # Missing values stay NaN in both tables. (A former
        # ``syn_data.fillna('nan')`` here discarded its result and was a no-op.)

    save_dir = f'eval/density/{dataname}/{model}'
    os.makedirs(save_dir, exist_ok=True)

    # Diagnostics: shape and distinct values of every categorical column.
    # Values are cast to str first because np.unique sorts its input, and
    # sorting NaN mixed with str raises TypeError.
    print(real_data[discrete_columns].shape, syn_data[discrete_columns].shape)
    for i, col in enumerate(discrete_columns):
        print(i, 'real', np.unique(real_data[col].astype(str)))
        print(i, 'synt', np.unique(syn_data[col].astype(str)))

    # Relabel columns positionally so labels equal the integer indices in info.
    real_data.columns = range(len(real_data.columns))
    syn_data.columns = range(len(syn_data.columns))

    # json.load yields string keys; convert them to the integer column indices.
    metadata = info['metadata']
    metadata['columns'] = {int(key): value for key, value in metadata['columns'].items()}

    new_real_data, new_syn_data, metadata = reorder(real_data, syn_data, info)

    score = LogisticDetection.compute(
        real_data=new_real_data,
        synthetic_data=new_syn_data,
        metadata=metadata,
    )

    print(f'{dataname}, {model}: {score}')
    print(score)