# Taken from: https://github.com/JXYin24/MCM

import torch
import numpy as np
from baselines.MCM.Model.Trainer import Trainer
from baselines.MCM.options import BaseOptions

import os
import pandas as pd

import time

from baselines.MCM.utils import get_data_paths


def MCM(path, config_overrides=None):
    """Train and evaluate the MCM baseline on one dataset and report averaged metrics.

    Builds a fixed default configuration (optionally patched by
    ``config_overrides``) and seeds torch, CUDA and NumPy once. Then, for each of
    ``model_config['runs']`` runs, it constructs a ``Trainer``, trains it for
    ``model_config['epochs']`` epochs and evaluates it. The averaged scores are
    printed and appended to ``./results/<dataset_name>.txt``.

    Args:
        path: Dataset location forwarded unchanged to ``Trainer``. Presumably a
            single pickled dataset split -- confirm against ``Trainer``.
        config_overrides: Optional mapping whose entries replace the defaults
            below, e.g. ``{'epochs': 10, 'device': 'cpu'}``. ``None`` keeps
            the defaults.

    Returns:
        Tuple ``(mean_auc_roc, mean_auc_pr, mean_f1, time_fit, time_inference)``.
        The metrics are means over all runs. ``time_fit`` and ``time_inference``
        are the mean wall-clock seconds per run spent in training and
        evaluation respectively.

    Raises:
        ValueError: If ``model_config['runs']`` is less than 1, because there
            would be nothing to average and no timings to return.
    """
    model_config = {
        'dataset_name': 'ADBench',
        'data_dim': 100,
        'epochs': 200,
        'learning_rate': 0.05,
        'sche_gamma': 0.98,
        'mask_num': 50,
        'lambda': 5,
        # Fall back to CPU so the baseline can at least start on GPU-less hosts.
        # NOTE(review): assumes Trainer honours model_config['device'] everywhere.
        'device': 'cuda:0' if torch.cuda.is_available() else 'cpu',
        'data_dir': 'Data/',
        'runs': 1,
        'batch_size': 256,
        'en_nlayers': 3,
        'de_nlayers': 3,
        'hidden_dim': 256,
        'z_dim': 128,
        'mask_nlayers': 3,
        'random_seed': 42,
        'num_workers': 0,
    }
    if config_overrides:
        model_config.update(config_overrides)

    runs = model_config['runs']
    if runs < 1:
        raise ValueError(f"model_config['runs'] must be >= 1, got {runs}")

    print(path)

    seed = model_config['random_seed']
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)

    # set_start_method raises RuntimeError once a start method is already set,
    # and MCM() may be called repeatedly (one call per dataset) in one process.
    if (model_config['num_workers'] > 0
            and torch.multiprocessing.get_start_method(allow_none=True) != 'spawn'):
        torch.multiprocessing.set_start_method('spawn', force=True)

    mse_rauc, mse_ap, mse_f1 = np.zeros(runs), np.zeros(runs), np.zeros(runs)
    fit_times, inference_times = np.zeros(runs), np.zeros(runs)
    for i in range(runs):
        start = time.perf_counter()
        trainer = Trainer(run=i, model_config=model_config, path=path)
        trainer.training(model_config['epochs'])
        fit_times[i] = time.perf_counter() - start

        start = time.perf_counter()
        # evaluate() presumably writes run i's scores into the three arrays in
        # place; its return value was never used -- confirm in Trainer.
        trainer.evaluate(mse_rauc, mse_ap, mse_f1)
        inference_times[i] = time.perf_counter() - start

    mean_mse_auc, mean_mse_pr, mean_mse_f1 = np.mean(mse_rauc), np.mean(mse_ap), np.mean(mse_f1)
    # Average timings over runs so they are consistent with the averaged metrics.
    time_fit = float(np.mean(fit_times))
    time_inference = float(np.mean(inference_times))

    print('##########################################################################')
    print("mse: average AUC-ROC: %.4f  average AUC-PR: %.4f"
          % (mean_mse_auc, mean_mse_pr))
    print("mse: average f1: %.4f" % (mean_mse_f1))

    # Appending would fail with FileNotFoundError if the directory is missing.
    results_dir = './results'
    os.makedirs(results_dir, exist_ok=True)
    results_name = os.path.join(results_dir, model_config['dataset_name'] + '.txt')

    with open(results_name, 'a') as fh:
        fh.write("epochs: %d lr: %.4f gamma: %.2f masks: %d lambda: %.1f " % (
            model_config['epochs'], model_config['learning_rate'], model_config['sche_gamma'],
            model_config['mask_num'], model_config['lambda']))
        fh.write('\n')
        fh.write("de_layer: %d  hidden_dim: %d z_dim: %d mask_layer: %d" % (
            model_config['de_nlayers'], model_config['hidden_dim'], model_config['z_dim'],
            model_config['mask_nlayers']))
        fh.write('\n')
        fh.write("mse: average AUC-ROC: %.4f  average AUC-PR: %.4f average f1: %.4f" % (
            mean_mse_auc, mean_mse_pr, mean_mse_f1))
        fh.write('\n')

    return mean_mse_auc, mean_mse_pr, mean_mse_f1, time_fit, time_inference