import os.path
from experiments.experiment_utils import experiment
import numpy as np
import csv
from collections import OrderedDict
import tabular_deep_smote.tabular_deep_smote_wrapper as tabular_deep_smote_wrapper
import experiments.experiment_settings as experiment_settings

np.set_printoptions(precision=4)

def _summary_rows(score_rows):
    """Build the 'AP Mean', 'Avg Rank', 'Num Best' and 'Num Worst' rows of a score table.

    Scores are rounded to 3 decimals before ranking. Within a dataset, methods get a
    dense rank over the distinct rounded scores (0 = lowest score, higher rank = better).
    Datasets where every method has the same rounded score are left out of the rank,
    best and worst statistics, but still count towards the mean.

    Returns a list of four rows (label followed by one value per method), or an empty
    list when ``score_rows`` holds no datasets.
    """
    scores = np.asarray(score_rows, dtype=float)
    if scores.ndim != 2 or scores.shape[0] == 0:
        return []  # nothing to summarise

    ranks, is_best, is_worst = [], [], []
    for rounded in np.around(scores, 3):
        # np.unique returns the sorted distinct values; the inverse index of each score
        # is therefore its dense rank among the methods of this dataset.
        distinct, dense_rank = np.unique(rounded, return_inverse=True)
        if distinct.size <= 1:
            continue  # all methods tie: the dataset carries no ranking information
        ranks.append(dense_rank)
        is_best.append(dense_rank == distinct.size - 1)
        is_worst.append(dense_rank == 0)

    num_methods = scores.shape[1]
    if ranks:
        avg_ranks = np.around(np.mean(ranks, axis=0), 3)
        num_best = np.sum(is_best, axis=0)
        num_worst = np.sum(is_worst, axis=0)
    else:
        # Every dataset was a tie: no rank is defined and nobody is best or worst.
        avg_ranks = np.full(num_methods, np.nan)
        num_best = num_worst = np.zeros(num_methods, dtype=int)
    ap_means = np.around(scores.mean(axis=0), 3)

    # .tolist() converts to native Python numbers so the CSV cells are plain literals.
    return [
        ['AP Mean'] + ap_means.tolist(),
        ['Avg Rank'] + avg_ranks.tolist(),
        ['Num Best'] + num_best.astype(int).tolist(),
        ['Num Worst'] + num_worst.astype(int).tolist(),
    ]


def gen_compare_table(AP_per_dataset, method_names, dataset_names, json_logs_per_dataset, table_name,
                      file_prefix=None, include_summary=False):
    """Write a per-dataset score table to ``experiments/results/<prefix>_<table_name>.csv``.

    The first row is an empty cell followed by ``method_names``. Each following row is a
    dataset name followed by that dataset's scores exactly as passed in (not rounded).

    Args:
        AP_per_dataset: one sequence of scores per dataset, ordered like ``method_names``.
        method_names: list of column headers (must be a list, it is concatenated).
        dataset_names: row labels, indexed in step with ``AP_per_dataset``.
        json_logs_per_dataset: one log entry per dataset. The entries are not written
            out, but rows are paired with them, so the table stops at the shorter of
            ``AP_per_dataset`` and this list.
        table_name: suffix of the output file name.
        file_prefix: prefix of the output file name. When ``None`` the module-level
            ``config_name`` global is used, which is what existing callers rely on.
        include_summary: when ``True``, append the 'AP Mean', 'Avg Rank', 'Num Best'
            and 'Num Worst' rows computed by ``_summary_rows``. Off by default, so the
            default output is the plain score table.

    Raises:
        NameError: if ``file_prefix`` is ``None`` and no ``config_name`` global exists.
        IndexError: if ``dataset_names`` is shorter than the rows being written.
    """
    # Fall back to the script-level global so five-argument calls keep working.
    prefix = config_name if file_prefix is None else file_prefix
    out_path = f'experiments/results/{prefix}_{table_name}.csv'
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    # newline='' is what the csv module asks for; without it, platforms that translate
    # line endings produce an extra blank line after every row.
    with open(out_path, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow([''] + method_names)
        for idx, (dataset_APs, _json_log) in enumerate(zip(AP_per_dataset, json_logs_per_dataset)):
            writer.writerow([dataset_names[idx]] + list(dataset_APs))
        if include_summary:
            writer.writerows(_summary_rows(AP_per_dataset))

########
## MAIN
########
# Make sure the output directory exists. makedirs also creates a missing parent
# ('experiments'), where a bare os.mkdir would fail.
os.makedirs('experiments/results', exist_ok=True)
RUN_BASELINERS = True
epochs = 100
enc_hidden_dims_default = '32x 16x'
dec_hidden_dims_default = '8x 16x 32x'
TEST_NAME = 'TEST'
seed_list = [42]
# config name -> one AP per dataset, in experiment_settings.DATASETS order
APs_per_config = OrderedDict()
for seed in seed_list:
    # NOTE(review): `seed` only ends up in config_name below; it is not forwarded to the
    # wrapper arguments, and experiment() is called with a fixed seed=123. Confirm intended.
    AP_per_numeric_datasets = []  # basis for numeric datasets table
    ROC_per_numeric_datasets = []  # basis for numeric datasets table
    F1_per_numeric_datasets = []  # basis for numeric datasets table
    numeric_dataset_names = []
    numeric_json_logs = []
    AP_per_cat_datasets = []  # basis for categorical datasets table
    ROC_per_cat_datasets = []  # basis for categorical datasets table
    F1_per_cat_datasets = []  # basis for categorical datasets table
    cat_dataset_names = []
    cat_json_logs = []
    # One AP per dataset, appended in DATASETS iteration order so that it lines up
    # with the header row of summary.csv.
    deep_smote_AP_per_dataset = []
    print("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
    # config_name is read as a module-level global by gen_compare_table.
    config_name = f'{TEST_NAME}_{seed}'
    print(f'Config: {config_name}')
    for dataset in experiment_settings.DATASETS:
        # 'webpage' gets explicit, wider layer sizes; every other dataset uses the defaults.
        if dataset.name == 'webpage':
            enc_hidden_dims = '1024 512'
            dec_hidden_dims = '256 512 1024'
        else:
            enc_hidden_dims = enc_hidden_dims_default
            dec_hidden_dims = dec_hidden_dims_default
        print(f'{dataset.name}')
        args = [f'--dataset_name={dataset.name}',
                f'--train_path={dataset.train_pt}',
                f'--test_path={dataset.test_pt}',
                f'--new_minority_save_path={dataset.new_minority_pt}',
                f'--enc_hidden_dims={enc_hidden_dims}',
                f'--dec_hidden_dims={dec_hidden_dims}',
                f'--epochs={epochs}',
                '--no_run_optuna',
                '--no_compare_to_baseliners',
                '--no_gen_visuals',
                f'--classifier_type={experiment_settings.CLASSIFIER_TYPE}',
                '--metric_learn_type=normalized_softmax',
                '--smote_algo_type=orig',
                '--reweight_loss',
                f'--lambda_metric_learn={1.0}',
                '--importance_oversampling',
                '--filter_margin=2',
                '--no_verbose',
                ]
        if dataset.cat_columns_argument:
            args.append(f'--categorical_features={dataset.cat_columns_argument}')
        json_log = tabular_deep_smote_wrapper.main(args)
        print(f"rec_loss = {json_log['best_losses']['model_loss']} ; best_epoch = {json_log['best_epoch']}")

        if RUN_BASELINERS:
            eval_results, oversampled_data = experiment(dataset.train_pt, dataset.test_pt,
                                                        dataset.new_minority_pt,
                                                        categorical_features=dataset.cat_columns,
                                                        seed=123,
                                                        m_neighbors=10,  ## the number of neighbors to detect border points
                                                        k_neighbors=5,  ## the number of neighbors border points will interpolate with
                                                        classifier_type=experiment_settings.CLASSIFIER_TYPE,
                                                        verbose=False)
            # One score per evaluated method, in eval_results order.
            APs = [values['AP'] for values in eval_results.values()]
            ROCs = [values['ROC_AUC'] for values in eval_results.values()]
            F1s = [values['F1'] for values in eval_results.values()]
            print(APs)
            print(ROCs)
            print(F1s)
            # Datasets with categorical columns and purely numeric ones go to separate tables.
            if dataset.cat_columns:
                AP_per_cat_datasets.append(APs)
                ROC_per_cat_datasets.append(ROCs)
                F1_per_cat_datasets.append(F1s)
                cat_methods = list(eval_results.keys())
                cat_dataset_names.append(dataset.name)
                cat_json_logs.append(json_log)
            else:
                AP_per_numeric_datasets.append(APs)
                ROC_per_numeric_datasets.append(ROCs)
                F1_per_numeric_datasets.append(F1s)
                numeric_methods = list(eval_results.keys())
                numeric_dataset_names.append(dataset.name)
                numeric_json_logs.append(json_log)
            # The last entry is taken as the deep_smote AP for the cross-configuration
            # summary. Recording it here, rather than rebuilding the list as
            # numeric + categorical afterwards, keeps the values in DATASETS order.
            deep_smote_AP_per_dataset.append(APs[-1])
        else:
            deep_smote_AP_per_dataset.append(json_log['AP'])
    if RUN_BASELINERS:
        # generate tables for this specific configuration
        if numeric_dataset_names:
            for metric_rows, table_name in ((AP_per_numeric_datasets, 'numeric_AP'),
                                            (ROC_per_numeric_datasets, 'numeric_ROC'),
                                            (F1_per_numeric_datasets, 'numeric_F1')):
                gen_compare_table(metric_rows, numeric_methods, numeric_dataset_names, numeric_json_logs, table_name)
        if cat_dataset_names:
            for metric_rows, table_name in ((AP_per_cat_datasets, 'cat_AP'),
                                            (ROC_per_cat_datasets, 'cat_ROC'),
                                            (F1_per_cat_datasets, 'cat_F1')):
                gen_compare_table(metric_rows, cat_methods, cat_dataset_names, cat_json_logs, table_name)
    APs_per_config[config_name] = deep_smote_AP_per_dataset

# Cross-configuration summary: one row per configuration, one column per dataset.
dataset_names = [dataset.name for dataset in experiment_settings.DATASETS]
# newline='' is required by the csv module to avoid blank lines between rows on
# platforms that translate line endings.
with open('experiments/results/summary.csv', 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow([''] + dataset_names)
    for config, APs in APs_per_config.items():
        writer.writerow([config] + list(APs))
