# Taken from: https://github.com/JXYin24/MCM

import torch
import numpy as np
from baselines.MCM.Model.Trainer import Trainer
from baselines.MCM.options import BaseOptions

import os
import pandas as pd

import time

from baselines.MCM.utils import get_data_paths


def MCM(path, config_overrides=None):
    """Train and evaluate the MCM baseline on one dataset and log summary metrics.

    Builds the default hyper-parameter configuration, seeds torch/numpy, then
    for each of ``model_config['runs']`` runs constructs a ``Trainer``, trains
    it, and evaluates it. The averaged metrics are printed and appended to
    ``./DTE/results/<dataset_name>.txt`` (the directory is created if needed).

    Args:
        path: Dataset location, forwarded unchanged to ``Trainer``.
        config_overrides: Optional mapping whose entries replace the matching
            default ``model_config`` values (e.g. ``{'device': 'cpu'}``).
            ``None`` (the default) keeps every default.

    Returns:
        A 7-tuple ``(mse_score, mse_score2, mean_mse_auc, mean_mse_pr,
        mean_mse_f1, time_fit, time_inference)``:
        ``mse_score`` is what ``trainer.evaluate`` returned and ``mse_score2``
        what ``trainer.evaluate_train`` returned, both for the last run;
        the three means are taken over the per-run metric arrays (labelled
        AUC-ROC, AUC-PR and F1 in the printed output); ``time_fit`` (Trainer
        construction + training) and ``time_inference`` (``trainer.evaluate``)
        are elapsed seconds for the last run.

    Raises:
        ValueError: If ``model_config['runs']`` is less than 1, because no
            scores or timings would exist to return.
    """
    model_config = {
        'dataset_name': 'ADBench',
        'data_dim': 100,
        'epochs': 200,
        'learning_rate': 0.05,
        'sche_gamma': 0.98,
        'mask_num': 50,
        'lambda': 5,
        'device': 'cuda:0',
        'data_dir': 'Data/',
        'runs': 1,
        'batch_size': 256,
        'en_nlayers': 3,
        'de_nlayers': 3,
        'hidden_dim': 256,
        'z_dim': 128,
        'mask_nlayers': 3,
        'random_seed': 42,
        'num_workers': 0,
    }
    if config_overrides:
        model_config.update(config_overrides)

    runs = model_config['runs']
    if runs < 1:
        # Without at least one run there are no scores/timings to return and
        # the metric means would be NaN; fail before touching the results file.
        raise ValueError("model_config['runs'] must be >= 1, got %r" % (runs,))

    print(path)

    torch.manual_seed(model_config['random_seed'])
    torch.cuda.manual_seed(model_config['random_seed'])
    np.random.seed(model_config['random_seed'])
    if model_config['num_workers'] > 0:
        # force=True: without it a second call in the same process raises
        # RuntimeError because the start method is already fixed.
        torch.multiprocessing.set_start_method('spawn', force=True)

    # Per-run metric buffers handed to the trainer's evaluate methods —
    # presumably filled in place at index `run`; confirm against Trainer.
    mse_rauc, mse_ap, mse_f1 = np.zeros(runs), np.zeros(runs), np.zeros(runs)
    for run in range(runs):
        fit_start = time.perf_counter()
        trainer = Trainer(run=run, model_config=model_config, path=path)
        model = trainer.training(model_config['epochs'])
        time_fit = time.perf_counter() - fit_start

        mse_score2 = trainer.evaluate_train(model, mse_rauc, mse_ap, mse_f1)

        # Only the test-set evaluation is counted as inference time.
        infer_start = time.perf_counter()
        mse_score = trainer.evaluate(model, mse_rauc, mse_ap, mse_f1)
        time_inference = time.perf_counter() - infer_start

    mean_mse_auc = np.mean(mse_rauc)
    mean_mse_pr = np.mean(mse_ap)
    mean_mse_f1 = np.mean(mse_f1)

    print('##########################################################################')
    print("mse: average AUC-ROC: %.4f  average AUC-PR: %.4f"
          % (mean_mse_auc, mean_mse_pr))
    print("mse: average f1: %.4f" % (mean_mse_f1))

    results_dir = './DTE/results/'
    # Appending fails with FileNotFoundError if the directory is missing.
    os.makedirs(results_dir, exist_ok=True)
    results_name = results_dir + model_config['dataset_name'] + '.txt'

    with open(results_name, 'a', encoding='utf-8') as file:
        file.write("epochs: %d lr: %.4f gamma: %.2f masks: %d lambda: %.1f " % (
            model_config['epochs'], model_config['learning_rate'],
            model_config['sche_gamma'], model_config['mask_num'],
            model_config['lambda']))
        file.write('\n')
        file.write("de_layer: %d  hidden_dim: %d z_dim: %d mask_layer: %d" % (
            model_config['de_nlayers'], model_config['hidden_dim'],
            model_config['z_dim'], model_config['mask_nlayers']))
        file.write('\n')
        file.write("mse: average AUC-ROC: %.4f  average AUC-PR: %.4f average f1: %.4f" % (
            mean_mse_auc, mean_mse_pr, mean_mse_f1))
        file.write('\n')

    return mse_score, mse_score2, mean_mse_auc, mean_mse_pr, mean_mse_f1, time_fit, time_inference