from yacs.config import CfgNode as CN
import os


def load_cfg(config_file=None):
    """Build the default experiment config and optionally overlay a YAML file.

    Args:
        config_file: Path (str or path-like) to a YAML file whose values
            override the defaults declared below. ``None`` or ``""`` means
            "use the defaults only". A non-empty path that does not point to
            an existing file is ignored and a ``UserWarning`` is emitted, so a
            mistyped path does not silently fall back to the defaults.

    Returns:
        A yacs ``CfgNode`` holding the defaults, merged with ``config_file``
        when that file exists. The node is not frozen, so callers may still
        modify it.

    Raises:
        Whatever ``CfgNode.merge_from_file`` raises for a bad file; yacs
        rejects keys in the file that are not declared in these defaults.
    """
    import warnings  # local import: only needed for the missing-file warning

    cfg = CN()

    # --- Dataset selection and preprocessing ---------------------------------
    cfg.dataset = CN()
    cfg.dataset.source_name = "cora"
    cfg.dataset.target_name = "cora"
    # NOTE(review): 7 matches the Cora class count; presumably must be
    # overridden in config_file for other datasets.
    cfg.dataset.source_label_num = 7
    cfg.dataset.target_label_num = 7
    cfg.dataset.use_text = True
    cfg.dataset.seed = 0
    cfg.dataset.drop_edge_ratio = 0.0
    cfg.dataset.drop_node_ratio = 0.0
    cfg.dataset.text_mask_ratio = 0.0

    # --- Text encoder and LLM output locations -------------------------------
    cfg.llm = CN()
    cfg.llm.model_name = "sentence-transformers/all-MiniLM-L6-v2"
    cfg.llm.max_neighbors = 5
    cfg.llm.max_length = 512
    cfg.llm.batch_size = 256
    cfg.llm.output_file = "prompts.json"
    # NOTE(review): this path hard-codes "llama3.2" and "cora"; it does not
    # follow llm.prompt_generator.model_name or dataset.*_name when those are
    # overridden, so it likely needs overriding alongside them.
    cfg.llm.save_dir = "core/llm_llama3.2_response/cora"

    # --- LLM prompt-generation client ----------------------------------------
    cfg.llm.prompt_generator = CN()
    cfg.llm.prompt_generator.model_name = "llama3.2"
    # Empty by default; presumably must be supplied via config_file — confirm.
    cfg.llm.prompt_generator.api_url = ""
    # NOTE(review): there is no `api_key` default. yacs' merge_from_file
    # rejects keys that are not declared here, so a config file cannot set
    # `llm.prompt_generator.api_key` until a default is added for it.
    cfg.llm.prompt_generator.max_workers = 1
    cfg.llm.prompt_generator.max_tokens = 512
    cfg.llm.prompt_generator.temperature = 0.2
    cfg.llm.prompt_generator.top_k = 40
    cfg.llm.prompt_generator.top_p = 0.85
    cfg.llm.prompt_generator.do_sample = True
    cfg.llm.prompt_generator.timeout = 120
    cfg.llm.prompt_generator.retry_delay = 1
    cfg.llm.prompt_generator.json_mode = True
    cfg.llm.prompt_generator.stream = False

    # --- Model selection -----------------------------------------------------
    cfg.model = CN()
    cfg.model.type = "gat"
    cfg.model.node_type = "gcn"
    cfg.model.checkpoint_dir = "./saved_models"

    cfg.model.sage = CN()
    cfg.model.sage.aggregator = "mean"
    cfg.model.sage.normalize = "batchnorm"
    # NOTE(review): 384 presumably matches the embedding width of
    # llm.model_name (and gnn/gcn out_channels) — keep them in sync.
    cfg.model.sage.source_in_channels = 384
    cfg.model.sage.target_in_channels = 384
    cfg.model.sage.hidden_channels = 256
    cfg.model.sage.out_channels = 256
    cfg.model.sage.num_layers = 3
    cfg.model.sage.dropout = 0.5

    # --- Graph encoders (top-level sections, unlike model.sage/model.gat) ----
    cfg.gnn = CN()
    # NOTE(review): 1433 matches the Cora bag-of-words feature size;
    # presumably must be overridden for other datasets.
    cfg.gnn.source_in_channels = 1433
    cfg.gnn.target_in_channels = 1433
    cfg.gnn.hidden_channels = 512
    cfg.gnn.out_channels = 384
    cfg.gnn.heads = 3
    cfg.gnn.num_layers = 5
    # NOTE(review): "saved_model" here vs "./saved_models" in
    # model.checkpoint_dir — confirm the singular/plural difference is intended.
    cfg.gnn.model_path = "saved_model/GAT_neighbor"
    cfg.gnn.vocab_size = 50000

    cfg.gcn = CN()
    cfg.gcn.in_channels = 1433
    cfg.gcn.hidden_channels = 512
    cfg.gcn.out_channels = 384

    cfg.model.gat = CN()
    cfg.model.gat.heads = 3
    cfg.model.gat.concat = True

    # --- Training ------------------------------------------------------------
    cfg.training = CN()
    cfg.training.warmup = 20
    cfg.training.lr = 1e-5
    cfg.training.weight_decay = 5e-8
    cfg.training.batch_size = 128
    cfg.training.reg = 0.1
    cfg.training.seqlen = 128
    cfg.training.temperature = 0.07
    cfg.training.neg_k = 1
    cfg.training.fixed_length = 20
    cfg.training.epochs = 30
    cfg.training.edgebatch_size = 256
    cfg.training.device = "cuda"
    cfg.training.shot_mode = "zero-shot"

    # --- Evaluation ----------------------------------------------------------
    cfg.testing = CN()
    cfg.testing.task = "node"

    # None / "" mean "defaults only". os.path.isfile(None) would raise
    # TypeError, so it is never called with a falsy value.
    if config_file:
        if os.path.isfile(config_file):
            cfg.merge_from_file(config_file)
        else:
            # Keep the original best-effort fallback to defaults, but make a
            # mistyped or missing path visible instead of silently ignoring it.
            warnings.warn(
                f"Config file {config_file!r} not found; using default settings.",
                stacklevel=2,
            )
    return cfg