import torch
import numpy as np
from Model.Trainer import Trainer
# Experiment configuration handed to Model.Trainer for every run of the sweep
# below. Keys without a comment are consumed by Model.Trainer, whose use of
# them is not visible in this file.
model_config = {
    'dataset_name': '13_fraud',  # NOTE(review): presumably selects the dataset under data_dir — confirm in Trainer
    'data_dim': 29,  # NOTE(review): presumably the input feature count for this dataset — confirm
    'epochs': 150,  # passed to trainer.training() in the __main__ sweep
    'learning_rate': 0.01,  # overwritten per configuration by the __main__ sweep
    'sche_gamma': 0.98,
    'mask_num': 15,
    'lambda': 1,  # overwritten per configuration by the __main__ sweep
    'device': 'cuda:0',
    'data_dir': './Data/',
    'runs': 1,  # Trainer runs per configuration; metrics are averaged over them
    'batch_size': 512, 
    'en_nlayers': 3,
    'de_nlayers': 3,
    'hidden_dim': 256,
    'z_dim': 128,
    'mask_nlayers': 3,
    'random_seed': 42,  # seeds torch (CPU and CUDA) and NumPy before each configuration
    'num_workers': 0  # > 0 switches torch.multiprocessing to the 'spawn' start method
}

def _seed_everything(seed):
    """Seed the PyTorch (CPU and CUDA) and NumPy global RNGs with ``seed``.

    Called before every hyperparameter configuration so each one starts
    from the same random state and the results are comparable.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)


def _evaluate_config(config):
    """Train and evaluate ``config['runs']`` Trainers for one configuration.

    Args:
        config: The model configuration dict passed to ``Trainer``.

    Returns:
        Tuple ``(mean_auc_roc, mean_auc_pr, mean_f1)`` averaged over the runs.
    """
    runs = config['runs']
    mse_rauc, mse_ap, mse_f1 = np.zeros(runs), np.zeros(runs), np.zeros(runs)
    for i in range(runs):
        trainer = Trainer(run=i, model_config=config)
        trainer.training(config['epochs'])
        # NOTE(review): evaluate() is expected to write run i's metrics into
        # these arrays in place (its return value is not used) — confirm in Trainer.
        trainer.evaluate(mse_rauc, mse_ap, mse_f1)
    return np.mean(mse_rauc), np.mean(mse_ap), np.mean(mse_f1)


if __name__ == "__main__":
    # Hyperparameter grid searched by the sweep below.
    learning_rates = [1e-2]
    lambdas = [10, 1, 0.1]

    # The multiprocessing start method may be set only once per process;
    # setting it inside the sweep loop raised RuntimeError on the 2nd config.
    if model_config['num_workers'] > 0:
        torch.multiprocessing.set_start_method('spawn')

    # Start below any real AUC so the first evaluated config is recorded.
    best_auc, best_prc = float('-inf'), 0
    best_lr, best_lam = 0, 0
    for lr in learning_rates:
        for lam in lambdas:
            model_config['lambda'] = lam
            model_config['learning_rate'] = lr
            _seed_everything(model_config['random_seed'])
            mean_auc, mean_pr, mean_f1 = _evaluate_config(model_config)
            # Report this configuration's own results, not the best so far.
            print("lr: %g  lambda: %g  ->  average AUC-ROC: %.4f  "
                  "average AUC-PR: %.4f  average f1: %.4f"
                  % (lr, lam, mean_auc, mean_pr, mean_f1))
            # Selection is by AUC-ROC; AUC-PR is the one from that same config.
            if mean_auc > best_auc:
                best_auc, best_prc = mean_auc, mean_pr
                best_lr, best_lam = lr, lam

    # Report the best configuration once, after the whole sweep.
    print('##########################################################################')
    print("mse: average AUC-ROC: %.4f  average AUC-PR: %.4f"
          % (best_auc, best_prc))
    print("learning rate: %g  lambda: %g"
          % (best_lr, best_lam))
