import os
import json
import random
import torch
import argparse
import pandas as pd
import numpy as np
from Model.Trainer import Trainer


# Experiment configuration handed unchanged to Trainer on every run.
# This script itself only reads: random_seed (RNG seeding), num_workers
# (decides whether the 'spawn' start method is set), runs and epochs (loop
# control), and dataset_name / model / basis_type / attention_type (used to
# build the results directory). The remaining keys are consumed elsewhere;
# their exact semantics are defined by Trainer, not here.
model_config = dict(
    device="cuda:0",
    data_dir="Data/",
    dataset_name="cardio",
    data_dim=21,
    epochs=200,
    batch_size=512,
    learning_rate=1e-5,
    model="npt",
    basis_type="OT",
    attention_type="OT",
    basis_num=5,
    prototype_num=5,
    runs=1,
    random_seed=42,
    num_workers=0,
)



if __name__ == "__main__":
    # Script entry point: seed the RNGs, train and evaluate a Trainer once per
    # configured run, print the metrics averaged over runs, and append the
    # same summary line to a per-configuration results file.

    # Seed every random source this script has access to so repeated
    # launches with the same config start from the same state.
    seed = model_config.get('random_seed')
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Only switch the multiprocessing start method when worker processes are
    # actually requested.
    if model_config['num_workers'] > 0:
        torch.multiprocessing.set_start_method('spawn')

    runs = model_config['runs']
    # One slot per run for each metric.
    mse_rauc, mse_ap, mse_f1 = np.zeros(runs), np.zeros(runs), np.zeros(runs)
    for i in range(runs):
        trainer = Trainer(run=i, model_config=model_config)
        trainer.training(model_config['epochs'])
        # NOTE(review): evaluate() returns nothing that is used here, so it
        # presumably writes this run's scores into the arrays in place
        # (likely at index `i`) -- confirm against Trainer.evaluate.
        trainer.evaluate(mse_rauc, mse_ap, mse_f1)
    mean_mse_auc, mean_mse_pr, mean_mse_f1 = np.mean(mse_rauc), np.mean(mse_ap), np.mean(mse_f1)

    print('##########################################################################')
    print("mse: average AUC-ROC: %.4f  average AUC-PR: %.4f"
          % (mean_mse_auc, mean_mse_pr))
    print("mse: average f1: %.4f" % (mean_mse_f1))

    dataset = model_config['dataset_name']
    model_name = model_config['model']
    basis_type = model_config['basis_type']
    attention_type = model_config['attention_type']

    results_name = f'./Save/{dataset}/{model_name}/{basis_type}_{attention_type}/'
    # Make sure the output directory exists: open(..., 'a') creates the file
    # but not missing parent directories, and failing here would discard the
    # summary of a completed training run.
    os.makedirs(results_name, exist_ok=True)

    # Append so that successive launches accumulate in the same file.
    with open(os.path.join(results_name, 'results.txt'), 'a') as file:
        file.write("mse: average AUC-ROC: %.4f  average AUC-PR: %.4f average f1: %.4f" % (
            mean_mse_auc, mean_mse_pr, mean_mse_f1))
        file.write('\n')

