import argparse
import torch
import os
import yaml
from utils import *
import sys

def read_config(args):
    fileNamePath = os.path.split(os.path.realpath(__file__))[0]
    yamlPath = os.path.join(fileNamePath, 'config/config.yaml')
    with open(yamlPath, 'r', encoding='utf-8') as f:
        cont = f.read()
        config_dict = yaml.safe_load(cont)[args.dataset]

    print(config_dict)
    for key, value in config_dict.items():
        args.__setattr__(key, value)
    return args

def mprint(*arg, **kwargs):
    if VERBOSE:  # 仅在 arg.verbose 为 True 时输出
        print(*arg, **kwargs)
        

parser = argparse.ArgumentParser(description='mine Arguments.')

# dataset
# parser.add_argument('--dataset', type=str, default='pokec')
# parser.add_argument('--inid', type=str, default='_z')
# parser.add_argument('--outid', type=str, default='_n')
# parser.add_argument('--dataset', type=str, default='bailA')
# parser.add_argument('--inid', type=str, default='_2')
# parser.add_argument('--outid', type=str, default='_1')
parser.add_argument('--dataset', type=str, default='syn')
parser.add_argument('--inid', type=str, default='-2')
parser.add_argument('--outid', type=str, default='-1')
# parser.add_argument('--dataset', type=str, default='german')
# parser.add_argument('--inid', type=str, default='_2')
# parser.add_argument('--outid', type=str, default='_1') 
parser.add_argument('--lr', type=float, default=0.004)
parser.add_argument('--lr2_reg', type=float, default=0.001)
parser.add_argument('--lreg', type=float, default=1)
parser.add_argument('--ureg', type=float, default=0.1)
parser.add_argument('--train_epochs', type=int, default=500)
parser.add_argument('--adaption_epochs', type=int, default=200)
parser.add_argument('--pre_train_epochs', type=int, default=500)
parser.add_argument('--dropout', type=float, default=0.5)
parser.add_argument('--tau', type=float, default=1)
parser.add_argument('--train_step', type=int, default=50)

# network
parser.add_argument('--n_layers', type=int, default=3, help='the number of layers')
parser.add_argument('--inter_encoder', type=str, choices=['GCN', 'GAT', 'SAGE', 'MLP', "vanilla"], default='GCN', help='GNN bachbone')
parser.add_argument('--sens_encoder', type=str, choices=['GCN', 'GAT', 'SAGE', 'MLP', "vanilla"], default='GCN', help='GNN bachbone')
parser.add_argument('--pre_train_encoder', type=str, choices=['GCN', 'GAT', 'SAGE', 'MLP'], default='GCN', help='bachbone')
parser.add_argument('--hidden_dim', type=int, default=32)
parser.add_argument("--device_id", type=str, default="2", help="device id for gpu")

# others
parser.add_argument('--log_path', type=str, default='logs/log.txt')
parser.add_argument('--pre_train_path', type=str, default=r'param/pre_train_{}_{}{}.pth')
parser.add_argument('--sens_train_path', type=str, default=r'param/sens_train_{}_{}{}.pth')
parser.add_argument('--inter_train_path', type=str, default=r'param/inter_train_{}_{}{}.pth')
parser.add_argument('--seed', type=int, default=1111)
parser.add_argument('--runs', type=int, default=5)
parser.add_argument('--tune', action='store_true', help='if tune')
parser.set_defaults(tune=True)
parser.add_argument('--verbose', action='store_true', help = "verbose for training")
parser.add_argument('--overwrite', action='store_true', help = "overwrite model storage")
parser.add_argument('--lambda_bma', type=float, default=0.1)
parser.add_argument('--lambda_cons', type=float, default=0.1)
parser.add_argument('--lambda_dis', type=float, default=0.1)
parser.add_argument('--y_pseudo_threshold', type=float, default=0.9)
parser.add_argument('--s_pseudo_threshold', type=float, default=0.9)
parser.add_argument('--top_k_mi', type=int, default=3)

args = parser.parse_args()
sys.stdout = Logger(args.log_path)
    
VERBOSE = args.verbose
if int(args.device_id) >= 0 and torch.cuda.is_available():
    args.device = torch.device("cuda:{}".format(args.device_id))
    mprint("using gpu:{} to train the model".format(args.device_id))
else:
    args.device = torch.device("cpu")
    mprint("using cpu to train the model")
    
if args.tune:
    args = read_config(args)
if args.outid == "all":
    args.outid = ""
print(args)
