from yacs.config import CfgNode as CN


def reset_cfg(cfg, args):
    """
    Overwrite config entries with the values given on the command line.

    Every listed key is assigned unconditionally, so these arguments always
    win over whatever the YAML config file set for the same key.

    Args:
        cfg: config node, modified in place. It must already contain a
            ``TSPD.optim`` sub-node.
        args: parsed command-line namespace providing ``config_path``,
            ``gpu_id``, ``threshold``, ``energy``, ``temperature``, ``lr``
            and ``gumbel_lr``. ``args.lr`` is a comma-separated string
            (e.g. ``"0.05,0.01"``) that becomes a list of floats.
    """
    for name in ("config_path", "gpu_id", "threshold", "energy", "temperature"):
        setattr(cfg, name, getattr(args, name))
    # "a,b,c" -> [float(a), float(b), float(c)]
    cfg.TSPD.optim.lr = list(map(float, args.lr.split(",")))
    cfg.gumbel_lr = args.gumbel_lr
    

def extend_cfg(cfg):
    """
    Register every default configuration option on ``cfg`` (in place).

    The defaults fall into three groups:

    * top-level experiment settings (dataset, backbone, seed, flags, and the
      ``threshold`` / ``energy`` / ``temperature`` / ``gumbel_lr`` values);
    * ``cfg.TSPD``: prompt-depth and context-length settings;
    * ``cfg.TSPD.optim``: optimizer and scheduler hyper-parameters.

    yacs checks the type of each override against the registered default
    during ``merge_from_file`` / ``merge_from_list``. The type of each
    default therefore matters as well as its value.

    Args:
        cfg: a mutable ``CN`` that receives the new keys.
    """
    experiment_defaults = (
        ("dataset_root", ""),
        ("model_backbone_name", ""),
        ("input_size", (-1, -1)),
        ("prompt_template", ""),
        ("scenario", ""),
        ("dataset", ""),
        ("num_shots", -1),
        ("seed", -1),
        ("use_validation", False),
        ("load_file", ""),
        ("eval_only", False),
        # If >= 0, only the corresponding dataset is trained in MTIL.
        ("train_one_dataset", -1),
        ("zero_shot", False),
        ("MTIL_order_2", False),
        ("threshold", 0.96),
        ("energy", 0.95),
        ("temperature", 3.0),
        ("gumbel_lr", 1.0),
    )
    for key, value in experiment_defaults:
        setattr(cfg, key, value)

    optim = CN()
    optim_defaults = (
        ("batch_size", 64),
        ("name", "SGD"),
        # 11 identical entries; presumably one per task/stage -- TODO confirm.
        ("lr", [0.05] * 11),
        ("max_epoch", 10),
        ("weight_decay", 0),
        ("lr_scheduler", "cosine"),
        ("warmup_epoch", 0),
    )
    for key, value in optim_defaults:
        setattr(optim, key, value)

    tspd = CN()
    tspd.prompt_depth_vision = 1
    tspd.prompt_depth_text = 1
    tspd.n_ctx_vision = 12
    tspd.n_ctx_text = 12
    tspd.optim = optim
    tspd.batchwise_prompt = False
    cfg.TSPD = tspd


def setup_cfg(args):
    """
    Build the final experiment config from defaults and user input.

    Precedence, lowest to highest:
    registered defaults < YAML file at ``args.config_path`` <
    dedicated command-line flags < free-form ``args.opts`` KEY VALUE pairs.

    Args:
        args: parsed command-line namespace.

    Returns:
        The fully merged ``CN`` config.
    """
    config = CN()
    # Declare every known key with its default value.
    extend_cfg(config)
    # The YAML file overrides the defaults.
    config.merge_from_file(args.config_path)
    # Dedicated CLI flags overwrite the corresponding keys ...
    reset_cfg(config, args)
    # ... and generic KEY VALUE overrides are applied last.
    config.merge_from_list(args.opts)
    return config


def print_args(args, cfg):
    """
    Print the parsed arguments (sorted by name), then the config.

    Args:
        args: object whose instance attributes are listed, one
            ``name: value`` line each.
        cfg: any object; it is printed with ``print`` after the arguments.
    """
    arg_values = vars(args)
    for banner_line in ("***************", "** Arguments **", "***************"):
        print(banner_line)
    for name in sorted(arg_values):
        print("{}: {}".format(name, arg_values[name]))
    for banner_line in ("************", "** Config **", "************"):
        print(banner_line)
    print(cfg)