import argparse
import torch

from dassl.utils import setup_logger, set_random_seed, collect_env_info
from dassl.config import get_cfg_default
from dassl.engine import build_trainer


def print_args(args, cfg):
    """Print the parsed arguments (sorted by name) followed by the config.

    Args:
        args: Namespace-like object; every entry of ``args.__dict__`` is
            printed as ``name: value``, ordered alphabetically by name.
        cfg: Config object; printed as-is via ``print``.
    """
    for banner_line in ("***************", "** Arguments **", "***************"):
        print(banner_line)

    arg_values = args.__dict__
    for name in sorted(arg_values):
        print("{}: {}".format(name, arg_values[name]))

    for banner_line in ("************", "** Config **", "************"):
        print(banner_line)
    print(cfg)


def reset_cfg(cfg, args):
    """Copy command-line overrides from ``args`` onto ``cfg`` in place.

    Two groups of options are handled:

    * Optional overrides (dataset root, output/resume directories, domains,
      transforms, trainer/backbone/head names, ``nocls``, learning rate,
      batch size and epochs) are written only when the argument is truthy,
      so empty strings, ``None`` and ``0`` leave the config value untouched.
    * ``seed`` and ``adddims`` are written whenever they are not ``None``,
      because ``0`` is a meaningful value for both: ``main`` treats any
      ``cfg.SEED >= 0`` as a fixed seed, and the launcher explicitly sets
      ``adddims`` to ``0`` for the ``Vanilla`` trainer.
    * The ADA / TTA / TEST options are always copied unconditionally.

    Args:
        cfg: Mutable config node; must already contain every section that is
            assigned to below (``DATASET``, ``OSDG``, ``ADA.TTA.OPTIM``, ...).
        args: Parsed command-line namespace providing the attributes read
            below.

    Returns:
        None. ``cfg`` is modified in place.
    """
    if args.root:
        cfg.DATASET.ROOT = args.root

    if args.output_dir:
        cfg.OUTPUT_DIR = args.output_dir

    if args.resume:
        cfg.RESUME = args.resume

    # Compare against None rather than truthiness so that an explicit
    # ``--seed 0`` is honoured instead of being silently ignored.
    if args.seed is not None:
        cfg.SEED = args.seed

    if args.source_domains:
        cfg.DATASET.SOURCE_DOMAINS = args.source_domains

    if args.target_domains:
        cfg.DATASET.TARGET_DOMAINS = args.target_domains

    if args.transforms:
        cfg.INPUT.TRANSFORMS = args.transforms

    if args.trainer:
        cfg.TRAINER.NAME = args.trainer

    if args.backbone:
        cfg.MODEL.BACKBONE.NAME = args.backbone

    if args.head:
        cfg.MODEL.HEAD.NAME = args.head

    # for mire (OSDG options)
    if args.nocls:
        cfg.OSDG.NO_CLS = args.nocls
    # ``adddims == 0`` is a valid request and must reach the config, so only
    # skip the override when the value is genuinely absent.
    if args.adddims is not None:
        cfg.OSDG.ADD_DIMS = args.adddims

    # Basic optimisation settings.
    if args.lr:
        cfg.OPTIM.LR = args.lr
    if args.bs:
        cfg.DATALOADER.TRAIN_X.BATCH_SIZE = args.bs
    if args.epochs:
        cfg.OPTIM.MAX_EPOCH = args.epochs

    # for ADA: these are always taken from the command line.
    cfg.ADA.bp_grl = args.bp_grl
    cfg.ADA.mining_grl = args.mining_grl
    cfg.ADA.topk = args.topk
    cfg.ADA.mining_th = args.mining_th

    cfg.ADA.fda_loss_coef = args.fda_loss_coef
    cfg.ADA.ua_loss_coef = args.ua_loss_coef
    cfg.ADA.ua_loss_coef1 = args.ua_loss_coef1
    cfg.ADA.penalty_coef = args.penalty_coef
    cfg.ADA.smooth_coef = args.smooth_coef
    cfg.ADA.warmup_epoch = args.warmup_epoch
    cfg.ADA.adv_grl = args.adv_grl
    cfg.TEST.EVALUATOR = args.evaluator

    # Test-time adaptation (TTA) settings.
    cfg.ADA.TTA.OPTIM.LR = args.tta_lr
    cfg.ADA.TTA.epoch = args.tta_epoch
    cfg.ADA.TTA.tta_ce_coef = args.tta_ce_coef
    cfg.ADA.TTA.tta_sc_coef = args.tta_sc_coef
    cfg.ADA.TTA.steps = args.tta_steps

    cfg.TEST.NO_TEST = args.only_train
    cfg.TEST.PER_CLASS_RESULT = args.per_class_result
    cfg.ADA.DSPATH = args.dspath


def extend_cfg(cfg):
    """
    Hook for registering additional config variables on ``cfg``.

    Currently a no-op: ``cfg`` is left untouched and ``None`` is returned.

    Example of what an extension could look like:
        from yacs.config import CfgNode as CN
        cfg.TRAINER.MY_MODEL = CN()
        cfg.TRAINER.MY_MODEL.PARAM_A = 1.
        cfg.TRAINER.MY_MODEL.PARAM_B = 0.5
        cfg.TRAINER.MY_MODEL.PARAM_C = False
    """
    return None


def setup_cfg(args):
    """Assemble the frozen run configuration from defaults, files and CLI.

    Sources are applied in this order:
        1. defaults from ``get_cfg_default()`` plus ``extend_cfg``,
        2. the dataset config file (if given),
        3. the method config file (if given),
        4. explicit command-line arguments via ``reset_cfg``,
        5. free-form ``KEY VALUE`` overrides in ``args.opts``.

    Args:
        args: Parsed command-line namespace.

    Returns:
        The resulting config object after ``freeze()`` has been called.
    """
    cfg = get_cfg_default()
    extend_cfg(cfg)

    # Dataset file first, then method file; empty paths are skipped.
    for cfg_path in (args.dataset_config_file, args.config_file):
        if cfg_path:
            cfg.merge_from_file(cfg_path)

    # Named command-line arguments, then the trailing optional overrides.
    reset_cfg(cfg, args)
    cfg.merge_from_list(args.opts)

    cfg.freeze()
    return cfg


def main(args):
    """Build the config from ``args``, create the trainer and run it.

    Steps: build the config, optionally fix the random seed (when
    ``cfg.SEED >= 0``), set up logging to ``cfg.OUTPUT_DIR``, print the
    arguments/config and environment info, then either evaluate a loaded
    model (``args.eval_only``) or train (unless ``args.no_train``).

    Args:
        args: Parsed command-line namespace.
    """
    cfg = setup_cfg(args)

    seed = cfg.SEED
    if seed >= 0:
        print("Setting fixed seed: {}".format(seed))
        set_random_seed(seed)

    setup_logger(cfg.OUTPUT_DIR)

    cuda_enabled = torch.cuda.is_available() and cfg.USE_CUDA
    if cuda_enabled:
        torch.backends.cudnn.benchmark = True

    print_args(args, cfg)
    print("Collecting env info ...")
    print("** System info **\n{}\n".format(collect_env_info()))

    # The run uses the parameters stored in cfg, so args and cfg must be
    # kept consistent with each other.
    trainer = build_trainer(cfg)

    if args.eval_only:
        # Evaluation mode: load the requested checkpoint and test only.
        trainer.load_model(args.model_dir, epoch=args.load_epoch)
        trainer.test()
    elif not args.no_train:
        trainer.train()


def build_parser():
    """Create the command-line parser for this launcher script.

    Kept separate from the ``__main__`` guard so the parser can be built and
    inspected without starting a run.

    Returns:
        argparse.ArgumentParser: parser with every option of this script
        registered.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--root", type=str, default="data/PACSori/", help="path to dataset")
    parser.add_argument(
        "--output-dir", type=str, default="", help="output directory"
    )
    parser.add_argument(
        "--resume",
        type=str,
        default="",
        help="checkpoint directory (from which the training resumes)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="only positive value enables a fixed seed"
    )
    parser.add_argument(
        "--source-domains",
        type=str,
        nargs="+",
        help="source domains for DA/DG"
    )
    parser.add_argument(
        "--target-domains",
        type=str,
        nargs="+",
        help="target domains for DA/DG"
    )
    parser.add_argument(
        "--transforms", type=str, nargs="+", help="data augmentation methods"
    )
    parser.add_argument(
        "--config-file", type=str, default="configs/trainers/dg/vanilla/pacs.yaml", help="path to config file"
    )
    parser.add_argument(
        "--dataset-config-file",
        type=str,
        default="configs/datasets/dg/pacs_r18.yaml",
        help="path to config file for dataset setup",
    )
    parser.add_argument(
        "--trainer", type=str, default="Vanilla2", help="name of trainer"
    )
    parser.add_argument(
        "--backbone", type=str, default="", help="name of CNN backbone"
    )
    parser.add_argument("--head", type=str, default="", help="name of head")
    parser.add_argument(
        "--eval-only", action="store_true", help="evaluation only"
    )
    parser.add_argument(
        "--model-dir",
        type=str,
        default="",
        help="load model from this directory for eval-only mode",
    )
    parser.add_argument(
        "--load-epoch",
        type=int,
        help="load model weights at this epoch for evaluation"
    )
    parser.add_argument(
        "--no-train", action="store_true", help="do not call trainer.train()"
    )
    parser.add_argument(
        "opts",
        default=None,
        nargs=argparse.REMAINDER,
        help="modify config options using the command-line",
    )

    # for OSDG
    # NOTE: ``type=list`` would split a single command-line string into its
    # characters ("horse" -> ['h', 'o', 'r', 's', 'e']); accept one or more
    # whole class names instead, like the other list-valued options above.
    parser.add_argument(
        "--nocls", type=str, nargs="+", default=['horse', 'house', 'person'], help=""
    )
    parser.add_argument(
        "--adddims", type=int, default=1,  help=""
    )

    parser.add_argument(
        "--lr", type=float, default=1e-3, help="learning rate"
    )
    parser.add_argument(
        "--bs", type=int, default=64, help="batch size"
    )
    parser.add_argument(
        "--epochs", type=int, default=30, help="epochs"
    )

    # parameters
    parser.add_argument('--only_train', action="store_true")
    parser.add_argument('--per_class_result', action="store_true")
    parser.add_argument('--bp_grl', type=float, default=0.5, metavar='TH', help='grl adversarial weight (default: 0.5)')
    parser.add_argument('--mining_grl', type=float, default=0.2, metavar='TH', help='grad scaler (default: 0.2)')
    parser.add_argument('--topk', default=1, type=int, help='select potential unk regions, depends on number of known classes')
    parser.add_argument('--mining_th', default=1.0, type=float, metavar='TH', help='unk label')
    parser.add_argument('--fda_loss_coef', default=1.0, type=float)
    parser.add_argument('--ua_loss_coef', default=1.0, type=float)
    parser.add_argument('--ua_loss_coef1', default=0.0, type=float)
    parser.add_argument('--smooth_coef', default=1.0, type=float, help="smoothed CE, gt 1")
    parser.add_argument('--warmup_epoch', default=0, type=int)
    parser.add_argument('--penalty_coef', default=0.05, type=float)
    parser.add_argument('--evaluator', default="Classification_plain")
    parser.add_argument('--adv_grl', type=float, default=0.1, metavar='TH', help='grl adversarial weight (default: 0.1)')
    parser.add_argument('--tta_lr', default=1e-3, type=float)
    parser.add_argument('--tta_epoch', default=1, type=int)
    parser.add_argument('--tta_ce_coef', default=0.1, type=float)
    parser.add_argument('--tta_sc_coef', default=0.0, type=float)
    # argparse does not run ``type`` on non-string defaults, so the default
    # must already be an int (it used to be the float 1.0).
    parser.add_argument('--tta_steps', default=1, type=int)
    parser.add_argument('--dspath', default="")

    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()

    # Force adddims to 0 when the plain 'Vanilla' trainer is selected.
    if args.trainer == 'Vanilla':
        args.adddims = 0

    main(args)
