import sys, os, random
import torch
import numpy as np
import json
import argparse
from datasets import tabular_data_loaders
from pathlib import Path
import copy

# Models
import tabular_deep_smote.models as models

# Validation - Optuna
import optuna
import joblib
from datasets.dataset_utils import imbalance_preserving_Kfold

# Evaluations
from experiments.experiment_utils import experiment, load_and_evaluate

# Visualization
from visualization import visualizer


###############################################################################
## MAIN
###############################################################################

def seed_everything(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    all CUDA devices). Exports ``PYTHONHASHSEED``, which only affects child
    processes started afterwards. Finally, configures cuDNN to select
    deterministic kernels.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    # Seeds every visible GPU; silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode autotunes kernels at runtime and may select
    # non-deterministic algorithms, defeating `deterministic = True` above.
    torch.backends.cudnn.benchmark = False


def oversample_and_evaluate(args, ae, train_data, test_data):
    """Oversample with a trained model, then score a downstream classifier.

    The model is put in eval mode and ``ae.oversample`` is called on the
    training pair. ``args.new_minority_save_path`` is passed to it as the
    location for the new minority samples. That same path is then given to
    ``load_and_evaluate`` together with the train/test pairs and classifier
    settings.

    Returns:
        tuple: ``(AP, oversample_results)`` — the first metric returned by
        ``load_and_evaluate`` and the raw return value of ``ae.oversample``.
    """
    features, labels = train_data
    ae.eval()
    oversample_results = ae.oversample(
        data=(features, labels),
        new_minority_save_path=args.new_minority_save_path,
        oversample_ratio=args.oversample_ratio,
    )
    # load_and_evaluate yields four values; only the average precision is used here.
    average_precision, _roc_auc, _f1, _unused = load_and_evaluate(
        args.new_minority_save_path,
        train_data,
        test_data,
        args.categorical_features,
        args.classifier_type,
        args.classifier_seed,
    )
    return average_precision, oversample_results


def _build_tdsmote(args, hparams):
    """Construct a ``models.TDSMOTE`` from the CLI ``args`` and ``hparams``.

    ``hparams`` supplies the values Optuna may tune: ``lambda_metric_learn``,
    ``label_smoothing`` and ``latent_dim_ratio``. Every other setting comes
    from ``args``. Each ``no_*`` CLI flag is negated into the positive keyword
    the model takes.
    """
    return models.TDSMOTE(
        dataset_name=args.dataset_name,
        categorical_features=args.categorical_features,
        lambda_metric_learn=hparams['lambda_metric_learn'],
        label_smoothing=hparams['label_smoothing'],
        metric_learn_type=args.metric_learn_type,
        mode_specific_normalization=not args.no_mode_specific_normalization,
        reweight_loss=not args.no_reweight_loss,
        rec_reweight_loss=not args.no_rec_reweight_loss,
        balance_b4_train=not args.no_balance_b4_train,
        latent_dim_ratio=hparams['latent_dim_ratio'],
        enc_hidden_dims=args.enc_hidden_dims,
        dec_hidden_dims=args.dec_hidden_dims,
        swap_prob=args.swap_prob,
        batch_size=args.batch_size,
        lr=args.lr,
        lr_decay=args.lr_decay,
        epochs=args.epochs,
        train_on=args.train_on,
        early_stop_no_limit=args.early_stop_no_limit,
        early_stop_train=args.early_stop_train,
        early_stop_val=args.early_stop_val,
        early_stop_val_type=args.early_stop_val_type,
        smote_algo_type=args.smote_algo_type,
        m_neighbors=args.m_neighbors,
        k_neighbors=args.k_neighbors,
        knn_algorithm=args.knn_algorithm,
        importance_oversampling=not args.no_importance_oversampling,
        filter_margin=args.filter_margin,
        classifier_type=args.classifier_type,
        gen_visuals=args.gen_visuals,
        verbose=args.verbose,
        device=args.device,
    )


def train_oversample_evaluate(args, path_best_ae, train_data, test_data, hparams):
    """Train a fresh TDSMOTE on ``train_data``, oversample it, and evaluate on ``test_data``.

    No validation set is passed to ``ae.fit`` (its second argument is ``None``).

    Returns:
        tuple: ``(train_results, AP, oversample_results, ae)``.
    """
    ae = _build_tdsmote(args, hparams)
    train_results = ae.fit(train_data, None, path_best_ae)
    AP, oversample_results = oversample_and_evaluate(args, ae, train_data, test_data)
    return train_results, AP, oversample_results, ae


def optuna_train_oversample_evaluate(args, path_best_ae, train_data, validation_data, hparams):
    """Train on one CV fold and return its validation AP (used by the Optuna objective).

    ``validation_data`` is passed to ``ae.fit`` (for early stopping) and is
    also the evaluation set for the downstream classifier.

    Returns:
        tuple: ``(AP, train_results)``.
    """
    ae = _build_tdsmote(args, hparams)
    train_results = ae.fit(train_data, validation_data, path_best_ae)  # use validation for early stop
    AP = oversample_and_evaluate(args, ae, train_data, validation_data)[0]
    return AP, train_results


def objective(trial: optuna.Trial, grid_search, args, path_best_ae, train_data, hparams, optuna_search_space):
    """Optuna objective: score one hyper-parameter configuration by validation AP.

    The sampled values are written into ``hparams`` in place. A model is then
    trained and evaluated on each fold from
    ``imbalance_preserving_Kfold(n_splits=5)``, or only on the first fold when
    ``args.single_fold_validation`` is set.

    Args:
        trial: the current Optuna trial.
        grid_search: True when the study uses a GridSampler. The values then
            come from the sampler's grid, and the ``suggest_float`` range below
            is only a placeholder.
        args: parsed CLI namespace.
        path_best_ae: model path forwarded to ``ae.fit``.
        train_data: ``(x, y)`` pair that is split into train/validation folds.
        hparams: dict of model hyper-parameters, updated in place.
        optuna_search_space: maps each name to its grid values (grid search) or
            to keyword arguments for ``trial.suggest_*`` (other samplers).

    Returns:
        The mean validation AP over the evaluated folds. The study is created
        with ``direction="maximize"``, so the objective must be a score, not a
        loss. The mean combined training loss is recorded in the
        ``combined_loss`` user attribute instead.
    """
    if grid_search:
        # The -5, 5, 0.05 range does not affect the grid values; Optuna just
        # requires a distribution. Grid values must stay inside it.
        suggest_dict = {
            name: trial.suggest_float(name, -5, 5, step=0.05) for name in optuna_search_space
        }
    else:
        suggest_dict = {
            name: trial.suggest_int(name, **values) if name == "epochs" else trial.suggest_float(name, **values)
            for name, values in optuna_search_space.items()
        }
    hparams.update(suggest_dict)

    kf = imbalance_preserving_Kfold(n_splits=5)
    num_validation_iter = 1 if args.single_fold_validation else kf.n_splits
    folds_avg_ap = 0.0
    folds_avg_loss = 0.0
    num_epochs = 0.0
    for x_train, y_train, x_val, y_val in kf.split(*train_data):
        fold_ap, train_results = optuna_train_oversample_evaluate(args, path_best_ae,
                                                                 (x_train, y_train),
                                                                 (x_val, y_val), hparams)
        folds_avg_ap += fold_ap / num_validation_iter
        folds_avg_loss += train_results.best_losses['combined_loss'] / num_validation_iter
        num_epochs += train_results.best_epoch / num_validation_iter
        if args.single_fold_validation:
            break
    trial.set_user_attr("num_epochs", int(num_epochs))
    trial.set_user_attr("combined_loss", float(folds_avg_loss))
    # Return the score being maximized, not the loss.
    return folds_avg_ap


def run_optuna_study(args, hp_sampler, grid_search, path_best_ae, train_data,
                     hparams, optuna_search_space, n_trials, study_save_file: Path):
    """Run an Optuna search with ``hp_sampler`` and pickle the finished study.

    Each trial calls ``objective`` with the given data and search space.
    Higher objective values are treated as better (AP / F1 / ROC-AUC ...).
    The study is saved to ``study_save_file`` with ``joblib.dump``.
    """
    def _score_trial(trial):
        return objective(trial, grid_search, args, path_best_ae, train_data,
                         hparams, optuna_search_space)

    study = optuna.create_study(direction="maximize", sampler=hp_sampler)
    study.optimize(_score_trial, n_trials=n_trials)
    joblib.dump(study, study_save_file)


def gen_visuals_of_latent_space(args, data, ae, x_all, y_all, type):
    """Plot the model's latent space before and after oversampling.

    Writes three figures into ``experiments/{PCA|UMAP}/<dataset>/latent_space``:
    the encoded training set; the encoded training set plus the test set
    (test labels recoloured as 2 = majority, 3 = minority); and the encoded
    oversampled set ``x_all``/``y_all``.

    Args:
        args: parsed CLI namespace (uses ``dataset_name``, ``plots_suffix``, ``device``).
        data: loaded dataset (uses ``x_train_total``, ``y_train_total``, ``x_test``, ``y_test``).
        ae: trained model exposing ``encode``.
        x_all: oversampled features (tensor moved to ``args.device`` for encoding).
        y_all: oversampled labels (tensor; converted with ``.numpy()``).
        type: ``'pca'`` or ``'param_umap'``.

    Raises:
        ValueError: if ``type`` is not a supported projection.
    """
    if type == 'pca':
        result_dir = Path(f"experiments/PCA/{args.dataset_name}/latent_space")
    elif type == 'param_umap':
        result_dir = Path(f"experiments/UMAP/{args.dataset_name}/latent_space")
    else:
        raise ValueError(f"Unsupported visualization type {type!r}; expected 'pca' or 'param_umap'")
    visualizer.configure(dir=result_dir)
    suffix = '' if args.plots_suffix is None else '_' + args.plots_suffix
    is_pca = type == 'pca'

    ## Before Oversample ##
    with torch.no_grad():
        encoded = ae.encode(data.x_train_total.to(args.device))
    X = encoded.cpu().numpy()
    Y = data.y_train_total
    if is_pca:
        visualizer.plot_pca(X, Y, "Original_PCA" + suffix)
    else:
        visualizer.plot_param_umap(X, Y, "Original_Param_UMAP" + suffix)

    ## Orig Data + Test (w/o oversample) ##
    all_x = torch.cat([data.x_train_total, data.x_test])
    with torch.no_grad():
        X = ae.encode(all_x.to(args.device)).cpu()
    # color 2 = majority samples of test, color 3 = minority samples of test
    y_test_ = (data.y_test == 0) * 2 + (data.y_test == 1) * 3
    Y = torch.cat((data.y_train_total, y_test_), 0)
    if is_pca:
        visualizer.visualize_oversampled_pca(X, Y, "WITH_TEST_PCA" + suffix)
    else:
        visualizer.visualize_oversampled_param_umap(X, Y, "WITH_TEST_Param_UMAP" + suffix)

    ## After Oversample ##
    with torch.no_grad():
        X = ae.encode(x_all.to(args.device)).cpu()
    if is_pca:
        visualizer.visualize_oversampled_pca(X.numpy(), y_all.numpy(), "Deep_SMOTE_PCA" + suffix)
    else:
        visualizer.visualize_oversampled_param_umap(X.numpy(), y_all.numpy(), "Deep_SMOTE_Param_UMAP" + suffix)

def gen_visuals_of_orig_space(args, data, oversampled_data, type):
    """Plot the original feature space before oversampling and for each oversampling method.

    Figures go to ``experiments/{PCA|UMAP}/<dataset>/orig_space``:
    the raw training set; the training set plus the test set (test labels
    recoloured as 2 = majority, 3 = minority); and one figure per entry of
    ``oversampled_data``.

    Args:
        args: parsed CLI namespace (uses ``dataset_name``, ``plots_suffix``).
        data: loaded dataset (uses ``x_train_total``, ``y_train_total``, ``x_test``, ``y_test``).
        oversampled_data: mapping of method name -> ``(x_all, y_all)``.
        type: ``'pca'`` or ``'param_umap'``.

    Raises:
        ValueError: if ``type`` is not a supported projection.
    """
    if type == 'pca':
        projection_dir = 'PCA'
    elif type == 'param_umap':
        projection_dir = 'UMAP'
    else:
        raise ValueError(f"Unsupported visualization type {type!r}; expected 'pca' or 'param_umap'")
    visualizer.configure(dir=Path(f"experiments/{projection_dir}/{args.dataset_name}/orig_space"))
    if type == 'pca':
        plot, plot_grouped = visualizer.plot_pca, visualizer.visualize_oversampled_pca
    else:
        plot, plot_grouped = visualizer.plot_param_umap, visualizer.visualize_oversampled_param_umap
    suffix = '' if args.plots_suffix is None else '_' + args.plots_suffix

    # Before Oversample
    plot(data.x_train_total, data.y_train_total, "before_oversample" + suffix)
    # Before Oversample + Test
    X = torch.cat([data.x_train_total, data.x_test])
    # color 2 = majority samples of test, color 3 = minority samples of test
    y_test_ = (data.y_test == 0) * 2 + (data.y_test == 1) * 3
    Y = torch.cat((data.y_train_total, y_test_), 0)
    plot_grouped(X, Y, "before_oversample_with_test" + suffix)
    # Oversampled
    for method, (x_all, y_all) in oversampled_data.items():
        plot_grouped(x_all, y_all, f"{method}_oversample" + suffix)

def gen_visuals_of_orig_space_V2(args, data, oversampled_data, type):
    """Plot each oversampled dataset in the original feature space, alone and with the test set.

    For each ``method -> (x_all, y_all)`` in ``oversampled_data``, two figures
    are written to ``experiments/{PCA|UMAP}/<dataset>/orig_space``: the
    oversampled set, and the oversampled set plus the test set (test labels
    recoloured as 2 = majority, 3 = minority).

    Raises:
        ValueError: if ``type`` is not ``'pca'`` or ``'param_umap'``.
    """
    if type == 'pca':
        projection_dir = 'PCA'
    elif type == 'param_umap':
        projection_dir = 'UMAP'
    else:
        raise ValueError(f"Unsupported visualization type {type!r}; expected 'pca' or 'param_umap'")
    visualizer.configure(dir=Path(f"experiments/{projection_dir}/{args.dataset_name}/orig_space"))
    if type == 'pca':
        plot, plot_grouped = visualizer.plot_pca, visualizer.visualize_oversampled_pca
    else:
        plot, plot_grouped = visualizer.plot_param_umap, visualizer.visualize_oversampled_param_umap
    suffix = '' if args.plots_suffix is None else '_' + args.plots_suffix

    # The test-label recolouring does not depend on the method, so compute it once.
    # color 2 = majority samples of test, color 3 = minority samples of test
    y_test_ = (data.y_test == 0) * 2 + (data.y_test == 1) * 3
    for method, (x_all, y_all) in oversampled_data.items():
        plot(x_all, y_all, f"{method}_oversample" + suffix)
        # After Oversample + Test
        X = torch.cat([x_all, data.x_test])
        Y = torch.cat((y_all, y_test_), 0)
        plot_grouped(X, Y, f"{method}_oversample_with_test" + suffix)

def _str2bool(value):
    """Parse a CLI boolean such as ``True``/``False``/``yes``/``no``/``1``/``0``.

    argparse's ``type=bool`` maps every non-empty string to True, so
    ``--gen_visuals False`` would silently enable visuals. This parser honours
    the text instead.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _split_cli_list(value):
    """Normalise a whitespace-separated CLI list into a flat list of tokens.

    Accepts a plain string (e.g. the ``'32x 16x'`` default), a single quoted
    argument (``['32x 16x']``), or separately passed tokens
    (``['32x', '16x']``). Indexing ``[0]`` instead would turn the string
    default into ``['3']`` and drop unquoted tokens after the first.
    """
    if value is None:
        return None
    if isinstance(value, str):
        value = [value]
    return ' '.join(value).split()


def _build_arg_parser():
    """Create the CLI parser; default paths point at the ``protein_homo`` imblearn dataset."""
    parser = argparse.ArgumentParser()
    default_dataset_name = 'protein_homo'
    src = 'imblearn'  # 'imblearn' OR 'Keel/preprocessed'
    dataset_dir = f"datasets/{src}/{default_dataset_name}"
    default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Setup
    ####################
    parser.add_argument('--dataset_name', default=default_dataset_name)
    parser.add_argument('--categorical_features', nargs='*', default=None)
    parser.add_argument('--train_path', default=f"{dataset_dir}/{default_dataset_name}_train.pt")
    parser.add_argument('--test_path', default=f"{dataset_dir}/{default_dataset_name}_test.pt")
    parser.add_argument('--new_minority_save_path',
                        default=f"{dataset_dir}/{default_dataset_name}_new_minority_deep_smote.pt")
    parser.add_argument("--results_dir", default='experiments/results')
    parser.add_argument("--device", default=default_device)
    parser.add_argument("--seed", type=int, default=42)
    # Deep SMOTE
    #####################
    parser.add_argument('--lambda_metric_learn', type=float, default=1.0)
    parser.add_argument('--metric_learn_type', default='normalized_softmax')
    # Arch (whitespace-separated layer specs, quoted or not)
    parser.add_argument('--enc_hidden_dims', nargs='+', default='32x 16x')
    parser.add_argument('--dec_hidden_dims', nargs='+', default='8x 16x 32x')
    parser.add_argument('--latent_dim_ratio', type=float, default=0.75)
    # HP selection
    parser.add_argument('--run_optuna', type=_str2bool)
    parser.add_argument('--no_run_optuna', dest='run_optuna', action='store_false')
    parser.set_defaults(run_optuna=True)
    parser.add_argument('--single_fold_validation', type=_str2bool)
    parser.add_argument('--multi_fold_validation', dest='single_fold_validation', action='store_false')
    parser.set_defaults(single_fold_validation=True)
    # Train
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--lr', nargs='*', type=float, default=0.001)
    parser.add_argument('--lr_decay', nargs='*', type=float, default=0.9999)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--train_on', choices=['all', 'min'], default='all')
    parser.add_argument('--early_stop_no_limit', type=_str2bool, default=False)
    parser.add_argument('--early_stop_train', type=int, default=10)
    parser.add_argument('--early_stop_val', type=int, default=20)
    parser.add_argument('--early_stop_val_type', choices=['only_val', 'val_n_train'], default='only_val')
    parser.add_argument('--no_balance_b4_train', type=_str2bool)
    parser.add_argument('--balance_b4_train', dest='no_balance_b4_train', action='store_false')
    parser.set_defaults(no_balance_b4_train=True)
    parser.add_argument('--no_rec_reweight_loss', type=_str2bool)
    parser.add_argument('--rec_reweight_loss', dest='no_rec_reweight_loss', action='store_false')
    parser.set_defaults(no_rec_reweight_loss=True)
    parser.add_argument('--no_reweight_loss', type=_str2bool)
    parser.add_argument('--reweight_loss', dest='no_reweight_loss', action='store_false')
    parser.set_defaults(no_reweight_loss=True)
    # Oversampling (Inference)
    parser.add_argument('--oversample_ratio', type=float, default=1.0)
    parser.add_argument('--knn_algorithm', default='brute')  # for deep smote's KNN algo. 'auto'
    parser.add_argument('--m_neighbors', type=int, default=10)
    parser.add_argument('--k_neighbors', type=int, default=6)
    parser.add_argument('--smote_algo_type', default='orig')
    parser.add_argument('--filter_margin', type=float, default=None)
    parser.add_argument('--classifier_type', default='catboost')
    parser.add_argument('--no_importance_oversampling', type=_str2bool)
    parser.add_argument('--importance_oversampling', dest='no_importance_oversampling', action='store_false')
    parser.set_defaults(no_importance_oversampling=True)
    # Visualizations
    ####################
    parser.add_argument('--gen_visuals', type=_str2bool)
    parser.add_argument('--no_gen_visuals', dest='gen_visuals', action='store_false')
    parser.set_defaults(gen_visuals=True)
    parser.add_argument('--plots_suffix', default=None)
    parser.add_argument('--verbose', type=_str2bool)
    parser.add_argument('--no_verbose', dest='verbose', action='store_false')
    parser.set_defaults(verbose=True)
    # Others
    ###################
    parser.add_argument('--label_smoothing', type=float, default=0.0)
    parser.add_argument('--swap_prob', type=float, default=0.0)
    parser.add_argument('--no_mode_specific_normalization', type=_str2bool)
    parser.add_argument('--mode_specific_normalization', dest='no_mode_specific_normalization', action='store_false')
    parser.set_defaults(no_mode_specific_normalization=True)
    # Experiment
    ####################
    parser.add_argument('--compare_to_baseliners', type=_str2bool)
    parser.add_argument('--no_compare_to_baseliners', dest='compare_to_baseliners', action='store_false')
    parser.set_defaults(compare_to_baseliners=True)
    parser.add_argument('--classifier_seed', type=int, default=42)
    return parser


def main(args):
    """Tune, train and evaluate Deep SMOTE on one dataset, then log the results to JSON.

    The pipeline runs these steps in order:

    1. Parse ``args`` and seed the RNGs, then load the data.
    2. Optionally run an Optuna grid search on ``(x_train, y_train)``.
    3. Train on ``(x_train_total, y_train_total)``, oversample, and evaluate
       on the test split.
    4. Optionally compare against baselines and plot.
    5. Write ``<results_dir>/<dataset>_results.json``.

    Args:
        args: list of command-line tokens (e.g. ``sys.argv[1:]``), or ``None``
            to let argparse read ``sys.argv``.

    Returns:
        dict: the JSON log that was written to disk.
    """
    args = _build_arg_parser().parse_args(args)

    if args.categorical_features:
        args.categorical_features = [int(i) for i in _split_cli_list(args.categorical_features)]
    args.enc_hidden_dims = _split_cli_list(args.enc_hidden_dims)
    args.dec_hidden_dims = _split_cli_list(args.dec_hidden_dims)

    # Setup #########################################################
    json_log = {}
    seed_everything(args.seed)
    if args.verbose:
        print(f'device = {args.device}')
    # Both the model file handed to ae.fit and the results JSON are written here.
    Path(args.results_dir).mkdir(parents=True, exist_ok=True)
    path_best_ae = f'{args.results_dir}/{args.dataset_name}_best_ae.pth'

    # Load Data #####################################################
    data = tabular_data_loaders.load_tabular_data(args.train_path, None, args.test_path)
    train_data = (data.x_train_total, data.y_train_total)
    test_data = (data.x_test, data.y_test)
    if args.verbose:
        tabular_data_loaders.print_dataset_characteristics(args.dataset_name, data)
    dataset_info = tabular_data_loaders.get_dataset_info(data.x_test_n_train, data.y_test_n_train)

    # Optuna Trials #################################################
    # NOTE(review): this overrides --single_fold_validation / --multi_fold_validation;
    # only datasets with more than 200 minority samples use a single fold.
    args.single_fold_validation = bool(dataset_info.num_min > 200)
    hparams = {
        'lambda_metric_learn': args.lambda_metric_learn,
        'latent_dim_ratio': args.latent_dim_ratio,
        'label_smoothing': args.label_smoothing,
    }
    if args.run_optuna:
        # Continuous alternative for a TPESampler:
        #   "lambda_metric_learn": {"low": 0.1, "high": 10, "log": True},
        #   "latent_dim_ratio": {"low": 0.40, "high": 0.80, "step": 0.1},
        #   "label_smoothing": {"low": 0.0, "high": 0.15, "step": 0.05},
        optuna_search_space = {
            'lambda_metric_learn': [1],
            'latent_dim_ratio': [0.75, 1.0],
            'label_smoothing': [0.0],
        }
        hp_sampler = optuna.samplers.GridSampler(optuna_search_space)  # TPESampler(seed=args.seed)
        grid_search = isinstance(hp_sampler, optuna.samplers.GridSampler)
        optuna_dir = Path('Optuna')
        optuna_dir.mkdir(exist_ok=True)
        study_save_file = optuna_dir / "study.pkl"
        n_trials = 30
        run_optuna_study(args,
                         hp_sampler,
                         grid_search,
                         path_best_ae,
                         (data.x_train, data.y_train),
                         copy.deepcopy(hparams),  # objective() updates hparams in place
                         optuna_search_space,
                         n_trials,
                         study_save_file)
        study = joblib.load(study_save_file)
        best_params = study.best_params
        print(f"best params: {best_params}")
        json_log['best_params'] = best_params
        hparams.update(best_params)
    else:
        json_log['best_params'] = 'default'
    #################################################################

    train_results, AP, oversample_results, ae = train_oversample_evaluate(args,
                                                                          path_best_ae,
                                                                          train_data,
                                                                          test_data,
                                                                          hparams)
    json_log['AP'] = AP

    if args.gen_visuals:
        gen_visuals_of_latent_space(args, data, ae, oversample_results.x_all, oversample_results.y_all, 'pca')

    if args.compare_to_baseliners:
        eval_results, oversampled_data = experiment(args.train_path, args.test_path,
                                                    args.new_minority_save_path,
                                                    args.categorical_features,
                                                    classifier_type=args.classifier_type,
                                                    seed=args.classifier_seed,
                                                    m_neighbors=args.m_neighbors,  # num neighbors to detect border points
                                                    k_neighbors=args.k_neighbors,  # num neighbors to interpolate with
                                                    verbose=False)

        if args.gen_visuals:
            gen_visuals_of_orig_space(args, data, oversampled_data, 'pca')

        all_APs = [metrics['AP'] for metrics in eval_results.values()]
        all_ROCs = [metrics['ROC_AUC'] for metrics in eval_results.values()]
        all_F1s = [metrics['F1'] for metrics in eval_results.values()]
        print(f'AP: {all_APs}')
        print(f'F1: {all_F1s}')
        print(f'ROC: {all_ROCs}')
        json_log['eval_results'] = eval_results

    # Store Information in JSON file
    json_log['dataset'] = args.dataset_name
    json_log['num_features'] = dataset_info.num_features
    json_log['num_samples'] = dataset_info.num_samples
    json_log['imb_ratio'] = dataset_info.imb_ratio
    json_log['num_minority'] = dataset_info.num_min
    if train_results:
        json_log['best_losses'] = {k: round(v, 4) for k, v in train_results.best_losses.items()}
        json_log['best_epoch'] = train_results.best_epoch

    with open(f'{args.results_dir}/{args.dataset_name}_results.json', 'w', encoding='utf-8') as f:
        json.dump(json_log, f)
    return json_log

if __name__ == "__main__":
    # Forward the command-line tokens (without the program name) to main().
    cli_tokens = sys.argv[1:]
    main(cli_tokens)
