from dataset import process_dataset, get_dataset
from models import *
from utils import *
from runner import *
from config import args
import numpy as np
import torch
import torch.nn as nn
from load_data import DataLoader, SyntheticGenerator
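

# --- Hypothetical helper (illustrative sketch, not part of the original pipeline) ---
# The script prints "parity" and "equality" values returned by runner.train(),
# but their definitions are not visible in this file. A common reading in
# fairness-aware ML code is the statistical (demographic) parity difference and
# the equal opportunity difference. The function below is a minimal sketch under
# that ASSUMPTION only. The name `fairness_gaps` and its signature are
# hypothetical, and train() may compute these metrics differently.
import numpy as np  # already imported above; repeated so this sketch is self-contained


def fairness_gaps(y_true, y_pred, sens):
    """Return (parity_gap, equality_gap) for binary labels, binary predictions
    and a binary sensitive attribute (each group must be non-empty).

    parity_gap   = |P(y_hat=1 | s=0) - P(y_hat=1 | s=1)|
    equality_gap = |P(y_hat=1 | y=1, s=0) - P(y_hat=1 | y=1, s=1)|
    """
    y_true = np.asarray(y_true).astype(bool)
    y_pred = np.asarray(y_pred).astype(bool)
    sens = np.asarray(sens).astype(bool)
    # Positive-prediction rate in each sensitive group.
    parity_gap = abs(y_pred[~sens].mean() - y_pred[sens].mean())
    # True-positive rate in each sensitive group (restricted to y = 1).
    equality_gap = abs(y_pred[~sens & y_true].mean() - y_pred[sens & y_true].mean())
    return float(parity_gap), float(equality_gap)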

if __name__ == '__main__':
    # seed_everything(args.seed)
    # data_loader = DataLoader(args)
    # dataset = data_loader.load_dataset(args.inid)

    # get_dataset(args, args.inid)
    # # Load the source data, of which only a small portion is labeled
    # source_data = data_loader.load_dataset(args.inid)
    # # Load the target data, which is unlabeled
    # target_data = data_loader.load_dataset(args.outid)

    # Load the source (args.inid) and target (args.outid) datasets
    source_data = get_dataset(args, args.inid)
    target_data = get_dataset(args, args.outid)

    # Preprocess both datasets
    print("********************process source data********************")
    source_data = process_dataset(args, source_data)
    print("********************process target data********************")
    target_data = process_dataset(args, target_data)
    # Train; train() returns per-run metric values, summarized below as mean ± std
    acc, auc_roc, parity, equality = train(args, source_data, target_data)
    print("==========={}============".format(args.outid))
    print('Acc: {:.2f} ± {:.2f}'.format(np.mean(acc), np.std(acc)))
    print('auc_roc: {:.2f} ± {:.2f}'.format(np.mean(auc_roc), np.std(auc_roc)))
    print('parity: {:.2f} ± {:.2f}'.format(np.mean(parity), np.std(parity)))
    print('equality: {:.2f} ± {:.2f}'.format(np.mean(equality), np.std(equality)))
