

import argparse
import torch as th
import pandas as pd
from pathlib import Path
from collections import OrderedDict
## Oversamplers
from imblearn.over_sampling import BorderlineSMOTE, SMOTE, SVMSMOTE, RandomOverSampler, SMOTENC, SMOTEN
from smote_variants import polynom_fit_SMOTE, ProWSyn, SMOTE_IPF
from .CtganOversampler import CtganOversampler

## Evaluation Related
from catboost.core import CatBoostClassifier
from sklearn.metrics import average_precision_score, roc_auc_score, f1_score
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier

## Data Loaders
from datasets import tabular_data_loaders


import os, random
import numpy as np

def seed_everything(seed):
    """Seed every random number generator this module relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and the current CUDA device), exports ``PYTHONHASHSEED`` and puts
    cuDNN into deterministic mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    # Exported for child processes; the hash seed of the already-running
    # interpreter cannot be changed after start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    th.manual_seed(seed)
    th.cuda.manual_seed(seed)
    th.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN auto-tune and pick algorithms per run, which
    # works against the deterministic flag above, so it must stay disabled.
    th.backends.cudnn.benchmark = False

def create_classifier(type, seed, class_weight=None, categorical_features=None):
    """Build an unfitted classifier of the requested kind.

    Args:
        type: One of ``"catboost"``, ``"svm"``, ``"knn"``,
            ``"logistic_regression"`` or ``"decision_tree"``. The name shadows
            the builtin but is kept because callers pass it by keyword.
        seed: Random seed forwarded to every estimator that accepts one.
        class_weight: ``None`` or ``'balanced'``. Forwarded to the estimators
            that support class weighting; for CatBoost ``'balanced'`` is mapped
            to ``auto_class_weights="Balanced"`` and any other value to
            ``None``. k-NN has no class weighting, so it is ignored there.
        categorical_features: Categorical column identifiers handed to
            CatBoost as ``cat_features``; unused by the other classifiers.

    Returns:
        The unfitted estimator.

    Raises:
        ValueError: If ``type`` is not one of the supported names.
    """
    if type == "catboost":
        auto_class_weights = "Balanced" if class_weight == 'balanced' else None
        return CatBoostClassifier(
            random_seed=seed,
            learning_rate=0.2,
            iterations=100,
            depth=6,
            verbose=False,
            auto_class_weights=auto_class_weights,
            cat_features=categorical_features,
        )
    if type == "svm":
        # probability=True fits an internal cross-validated calibration that
        # draws random numbers, so the seed is needed for reproducibility.
        return SVC(kernel="rbf", probability=True, class_weight=class_weight,
                   random_state=seed)
    if type == "knn":
        return KNeighborsClassifier()
    if type == "logistic_regression":
        # class_weight used to be dropped here, which made the re-weighting
        # baseline identical to the un-weighted run.
        return LogisticRegression(random_state=seed, class_weight=class_weight)
    if type == "decision_tree":
        return DecisionTreeClassifier(random_state=seed, class_weight=class_weight)
    raise ValueError("Classifier type not supported")


def evaluate(x_train, y_train, x_test, y_test, categorical_features, classifier_type, seed, class_weight=None):
    """Fit a classifier on the training split and score it on the test split.

    Args:
        x_train, y_train, x_test, y_test: PyTorch tensors.
        categorical_features: Column identifiers of the categorical features,
            or a falsy value when every feature is numeric.
        classifier_type: Classifier name understood by ``create_classifier``.
        seed: Random seed forwarded to ``create_classifier``.
        class_weight: Forwarded to ``create_classifier``.

    Returns:
        Tuple ``(AP, ROC_AUC, F1)``: average precision and ROC AUC computed
        from the second ``predict_proba`` column, and F1 of those scores
        thresholded at 0.5.
    """
    def _to_mixed_frame(features):
        # CatBoost needs categorical columns as int or str, and a tensor
        # cannot hold two dtypes at once, so switch to a DataFrame where the
        # categorical columns are strings and '-1' is rewritten to 'None'.
        frame = pd.DataFrame(features).astype("float")
        frame[categorical_features] = frame[categorical_features].astype(int).astype(str)
        return frame.replace(to_replace='-1', value='None')

    x_train = x_train.float().numpy()
    y_train = y_train.float().numpy()
    x_test = x_test.numpy()
    y_test = y_test.numpy()

    if categorical_features:
        x_train = _to_mixed_frame(x_train)
        x_test = _to_mixed_frame(x_test)
        y_train = pd.DataFrame(y_train).astype(int)

    classifier = create_classifier(
        type=classifier_type,
        seed=seed,
        class_weight=class_weight,
        categorical_features=categorical_features,
    )
    classifier.fit(x_train, y_train)

    scores = classifier.predict_proba(x_test)[:, 1]
    AP = average_precision_score(y_test, scores)
    ROC_AUC = roc_auc_score(y_test, scores)
    F1 = f1_score(y_test, scores > 0.5)
    return AP, ROC_AUC, F1

def assert_balanced(y):
    """Check that the binary labels in ``y`` are approximately balanced.

    The number of samples labelled 1 must lie between
    ``int(0.95 * n)`` and ``int(1.05 * n)`` (both inclusive), where ``n`` is
    the number of samples labelled 0.

    Args:
        y: Label container supporting boolean-mask indexing (e.g. a NumPy
            array or a PyTorch tensor).

    Raises:
        AssertionError: If the class counts fall outside the tolerance.
    """
    n_minority = len(y[y == 1])
    n_majority = len(y[y == 0])
    lower = int(n_majority * 0.95)
    upper = int(n_majority * 1.05)
    # The upper bound is inclusive: with an exclusive bound a perfectly
    # balanced set of fewer than 20 majority samples was rejected because
    # int(n * 1.05) == n. Raised explicitly (not via ``assert``) so the check
    # still runs under ``python -O`` while keeping the exception type.
    if not lower <= n_minority <= upper:
        raise AssertionError(
            f"Classes are not balanced: {n_minority} samples labelled 1 vs "
            f"{n_majority} labelled 0 (allowed range {lower}..{upper})"
        )

def fit_and_evaluate(model, train_data, test_data, categorical_features, classifier_type, seed, pass_cat_at_fit=False):
    """Resample the training data with ``model`` and evaluate a classifier on it.

    Args:
        model: Oversampler exposing ``fit_resample``.
        train_data: ``(x_train, y_train)`` pair.
        test_data: ``(x_test, y_test)`` pair.
        categorical_features: Forwarded to ``evaluate`` and, when
            ``pass_cat_at_fit`` is true, also to ``model.fit_resample``.
        classifier_type: Classifier name forwarded to ``evaluate``.
        seed: Random seed forwarded to ``evaluate``.
        pass_cat_at_fit: Whether ``fit_resample`` takes the categorical
            features as a third positional argument.

    Returns:
        ``(AP, ROC_AUC, F1, (x_all, y_all))`` where ``x_all``/``y_all`` are
        the resampled training data.

    Raises:
        AssertionError: From ``assert_balanced`` when the resampled labels
            are not balanced.
    """
    features, labels = train_data
    x_test, y_test = test_data

    resample_args = (features, labels)
    if pass_cat_at_fit:
        resample_args = resample_args + (categorical_features,)
    x_all, y_all = model.fit_resample(*resample_args)

    # Samplers may hand back NumPy arrays; evaluate() expects tensors.
    if isinstance(x_all, np.ndarray):
        x_all = th.from_numpy(x_all)
    if isinstance(y_all, np.ndarray):
        y_all = th.from_numpy(y_all)

    assert_balanced(y_all)
    ap, roc_auc, f1 = evaluate(x_all, y_all, x_test, y_test,
                               categorical_features, classifier_type, seed)
    return ap, roc_auc, f1, (x_all, y_all)

def fit_and_evaluate_v2(model, train_data, test_data, categorical_features, classifier_type, seed):
    """Resample via ``model.sample`` and evaluate a classifier on the result.

    Variant of ``fit_and_evaluate`` for oversamplers whose entry point is
    ``sample(X, y)`` on NumPy arrays.

    Args:
        model: Oversampler exposing ``sample``.
        train_data: ``(x_train, y_train)`` pair of tensors.
        test_data: ``(x_test, y_test)`` pair.
        categorical_features: Forwarded to ``evaluate``.
        classifier_type: Classifier name forwarded to ``evaluate``.
        seed: Random seed forwarded to ``evaluate``.

    Returns:
        ``(AP, ROC_AUC, F1, (x_all, y_all))`` where ``x_all``/``y_all`` are
        the resampled training data as tensors.
    """
    features, labels = train_data
    x_test, y_test = test_data

    resampled_x, resampled_y = model.sample(features.numpy(), labels.numpy())
    x_all = th.from_numpy(resampled_x)
    y_all = th.from_numpy(resampled_y)

    # NOTE(review): unlike fit_and_evaluate, no balance check is performed
    # here -- presumably these samplers do not guarantee a ~1:1 ratio; confirm.
    ap, roc_auc, f1 = evaluate(x_all, y_all, x_test, y_test,
                               categorical_features, classifier_type, seed)
    return ap, roc_auc, f1, (x_all, y_all)

def load_and_evaluate(file_pt, train_data, test_data, categorical_features, classifier_type, seed):
    """Append pre-generated samples to the training data and evaluate.

    Args:
        file_pt: Path to a torch-saved ``(x_gen, y_gen)`` pair.
        train_data: ``(x_train, y_train)`` pair of tensors.
        test_data: ``(x_test, y_test)`` pair.
        categorical_features: Forwarded to ``evaluate``.
        classifier_type: Classifier name forwarded to ``evaluate``.
        seed: Random seed forwarded to ``evaluate``.

    Returns:
        ``(AP, ROC_AUC, F1, (x_all, y_all))`` where ``x_all``/``y_all`` are
        the original training data followed by the loaded samples.

    Raises:
        AssertionError: From ``assert_balanced`` when the combined labels
            are not balanced.
    """
    features, labels = train_data
    x_test, y_test = test_data

    # NOTE: th.load unpickles the file; only point this at trusted files.
    generated_x, generated_y = th.load(file_pt)

    x_all = th.cat([features, generated_x])
    y_all = th.cat([labels, generated_y])
    assert_balanced(y_all)

    ap, roc_auc, f1 = evaluate(x_all, y_all, x_test, y_test,
                               categorical_features, classifier_type, seed)
    return ap, roc_auc, f1, (x_all, y_all)


def experiment(train_pt, test_pt,
               deep_smote_new_minority_pt,
               categorical_features: list,
               classifier_type,
               seed=42,
               m_neighbors=10,            ## the number of neighbors to detect border points
               k_neighbors=5,             ## the number of neighbors border points will interpolate with
               verbose=False,
):
    """Benchmark a set of imbalance-handling methods on one train/test split.

    Runs, in this order: a no-oversampling baseline; then either the
    categorical-aware samplers (SMOTEN when every column is categorical,
    SMOTENC otherwise) or the numeric-only samplers (SMOTE, Borderline-SMOTE
    1 and 2, polynom-fit SMOTE, ProWSyn, SMOTE-IPF); then random
    oversampling, class re-weighting, CTGAN oversampling and finally the
    pre-generated samples stored in ``deep_smote_new_minority_pt``.

    Args:
        train_pt: Path of the training data file.
        test_pt: Path of the test data file.
        deep_smote_new_minority_pt: Path to a torch-saved ``(x_gen, y_gen)``
            pair that is appended to the training data ("td_smote" entry).
        categorical_features: Categorical column identifiers; a falsy value
            selects the numeric-only samplers.
        classifier_type: Classifier name understood by ``create_classifier``.
        seed: Seed for the global generators, samplers and classifiers.
        m_neighbors: Neighbour count Borderline-SMOTE uses to detect border
            points.
        k_neighbors: Neighbour count used for interpolation.
        verbose: Print one formatted result line per method.

    Returns:
        ``(eval_results, oversampled_data)``, both ``OrderedDict`` keyed by
        method name. ``eval_results[name]`` is a dict with ``'AP'``,
        ``'ROC_AUC'`` and ``'F1'`` rounded to 4 decimals;
        ``oversampled_data[name]`` is the ``(x_all, y_all)`` training data
        used by that method (no entry for ``'no_os'`` and ``'rw'``, which
        train on the original data).

    Raises:
        AssertionError: If categorical features are combined with the
            ``'svm'`` classifier, or if a resampled data set fails the
            balance check.
    """
    seed_everything(seed)
    result_format = "{:<40} - AP:{:<8.3f}, ROC_AUC:{:.3f}, F1:{:.3f}"
    eval_results = OrderedDict()
    oversampled_data = OrderedDict()

    # Raised explicitly (not via ``assert``) so the guard still runs under
    # ``python -O`` while keeping the exception type.
    if categorical_features and classifier_type == 'svm':
        raise AssertionError("The 'svm' classifier cannot be used with categorical features")

    def _round(x):
        return round(x, 4)

    def _record(key, AP, ROC_AUC, F1, resampled=None, label=None):
        # Store the rounded metrics (and the resampled data, if any) under
        # ``key`` and optionally print the un-rounded metrics.
        eval_results[key] = {'AP': _round(AP), 'ROC_AUC': _round(ROC_AUC), 'F1': _round(F1)}
        if resampled is not None:
            oversampled_data[key] = resampled
        if verbose:
            print(result_format.format(key if label is None else label, AP, ROC_AUC, F1))

    def _oversample(key, runner, source, cat_features, label=None, **runner_kwargs):
        # Run one resample-and-evaluate helper on the shared splits and
        # record its outcome under ``key``.
        AP, ROC_AUC, F1, (x_all, y_all) = runner(source, total_train_data, test_data,
                                                 cat_features, classifier_type, seed,
                                                 **runner_kwargs)
        _record(key, AP, ROC_AUC, F1, resampled=(x_all, y_all), label=label)

    # Load Data
    data = tabular_data_loaders.load_tabular_data(Path(train_pt), None, Path(test_pt))

    total_train_data = (data.x_train_total, data.y_train_total)
    test_data = (data.x_test, data.y_test)

    # No oversampling
    AP, ROC_AUC, F1 = evaluate(*total_train_data,
                               *test_data,
                               categorical_features, classifier_type, seed)
    _record('no_os', AP, ROC_AUC, F1, label="none")

    if categorical_features:
        # Categorical-aware samplers only.
        if len(categorical_features) == data.x_test.shape[1]:
            # SMOTE N -- every column is categorical.
            # NOTE(review): k_neighbors is not forwarded here although it is
            # for SMOTENC below -- confirm whether that is intentional.
            _oversample('smoten', fit_and_evaluate,
                        SMOTEN(random_state=seed),
                        categorical_features)
        else:
            # SMOTE NC -- mixed categorical and numeric columns.
            _oversample('smotenc', fit_and_evaluate,
                        SMOTENC(categorical_features=categorical_features,
                                random_state=seed,
                                k_neighbors=k_neighbors),
                        categorical_features)
    else:
        # Numeric-only samplers.

        # SMOTE
        _oversample('smote', fit_and_evaluate,
                    SMOTE(random_state=seed, k_neighbors=k_neighbors),
                    None)

        # B-SMOTE 1
        _oversample('bsmote1', fit_and_evaluate,
                    BorderlineSMOTE(random_state=seed,
                                    kind='borderline-1',
                                    m_neighbors=m_neighbors,
                                    k_neighbors=k_neighbors),
                    None, label="bsmote-1")

        # B-SMOTE 2
        _oversample('bsmote2', fit_and_evaluate,
                    BorderlineSMOTE(random_state=seed,
                                    kind='borderline-2',
                                    m_neighbors=m_neighbors,
                                    k_neighbors=k_neighbors),
                    None, label="bsmote-2")

        # PolySMOTE
        _oversample('poly_fit', fit_and_evaluate_v2,
                    polynom_fit_SMOTE(random_state=seed),
                    None)

        # ProWSyn
        _oversample('prowsyn', fit_and_evaluate_v2,
                    ProWSyn(random_state=seed),
                    None)

        # SMOTE-IPF
        _oversample('smote_ipf', fit_and_evaluate_v2,
                    SMOTE_IPF(n_folds=5, random_state=seed),
                    None)

    # Random oversample
    _oversample('ros', fit_and_evaluate,
                RandomOverSampler(random_state=seed),
                categorical_features)

    # Reweight
    AP, ROC_AUC, F1 = evaluate(*total_train_data,
                               *test_data,
                               categorical_features,
                               classifier_type, seed,
                               class_weight='balanced')
    _record('rw', AP, ROC_AUC, F1)

    # CTGAN Oversampling
    _oversample('ctgan', fit_and_evaluate,
                CtganOversampler(emb_dim=32),
                categorical_features, pass_cat_at_fit=True)

    # TD-SMOTE
    _oversample('td_smote', load_and_evaluate,
                deep_smote_new_minority_pt,
                categorical_features)

    return eval_results, oversampled_data