import argparse
import torch
import time
import os

from dassl.utils import setup_logger, set_random_seed, collect_env_info
from dassl.config import get_cfg_default
from dassl.engine import build_trainer

# custom
import datasets.oxford_pets
import datasets.oxford_flowers
import datasets.fgvc_aircraft
import datasets.dtd
import datasets.eurosat
import datasets.stanford_cars
import datasets.food101
import datasets.sun397
import datasets.caltech101
import datasets.ucf101
import datasets.imagenet

import datasets.imagenet_sketch
import datasets.imagenetv2
import datasets.imagenet_a
import datasets.imagenet_r

import trainers.coop
import trainers.kgcoop
import trainers.kgcoop_coop_LMC
import trainers.promptkd
import trainers.promptkd_LMC
import trainers.mma
import trainers.mma_LMC


def print_args(args, cfg):
    """Print the parsed arguments and the configuration to stdout.

    Args:
        args: object whose ``__dict__`` holds the parsed command-line
            options (e.g. an ``argparse.Namespace``). Options are listed
            one per line as ``name: value`` in sorted name order.
        cfg: configuration object; printed as-is via ``print``.
    """
    for line in ("***************", "** Arguments **", "***************"):
        print(line)

    options = vars(args)
    for name in sorted(options):
        print(f"{name}: {options[name]}")

    for line in ("************", "** Config **", "************"):
        print(line)
    print(cfg)


def reset_cfg(cfg, args):
    """Copy command-line overrides from ``args`` into ``cfg`` in place.

    Each option is written to its config entry only when its value is
    truthy; empty strings, ``None`` and ``0`` leave the existing entry
    untouched. The single exception is ``args.resume_coop``, which is
    always written to ``cfg.RESUME_COOP``.

    NOTE: because the check is a truthiness test, ``args.seed == 0`` does
    not modify ``cfg.SEED``, whereas a negative seed (truthy) does.

    Args:
        cfg: mutable config object addressed by attribute access.
        args: parsed command-line options.
    """

    def assign(path, value, always=False):
        # Resolve the dotted path lazily so that config nodes are only
        # touched when a value is actually going to be written.
        if not (always or value):
            return
        *parents, leaf = path.split(".")
        node = cfg
        for part in parents:
            node = getattr(node, part)
        setattr(node, leaf, value)

    assign("DATASET.ROOT", args.root)
    assign("OUTPUT_DIR", args.output_dir)
    assign("RESUME", args.resume)
    assign("RESUME_COOP", args.resume_coop, always=True)
    assign("SEED", args.seed)
    assign("DATASET.SOURCE_DOMAINS", args.source_domains)
    assign("DATASET.TARGET_DOMAINS", args.target_domains)
    assign("INPUT.TRANSFORMS", args.transforms)
    assign("TRAINER.NAME", args.trainer)
    assign("MODEL.BACKBONE.NAME", args.backbone)
    assign("MODEL.HEAD.NAME", args.head)


def extend_cfg(cfg):
    """Register the extra config entries this script relies on.

    Adds the ``TRAINER.COOP``, ``TRAINER.COCOOP``, ``TRAINER.COOP_CLIP``
    and ``TRAINER.SAM`` nodes plus a top-level ``LOSS`` node, and sets
    ``DATASET.SUBSAMPLE_CLASSES`` and ``TEST.FINAL_MODEL``. ``cfg`` is
    modified in place; nothing is returned.

    To add a further group, follow the same pattern, e.g.::

        cfg.TRAINER.MY_MODEL = make_node(PARAM_A=1., PARAM_B=0.5)
    """
    from yacs.config import CfgNode as CN

    def make_node(**fields):
        # Build a fresh node and populate it in keyword order.
        node = CN()
        for key, value in fields.items():
            setattr(node, key, value)
        return node

    cfg.TRAINER.COOP = make_node(
        ALPHA=1.0,
        N_CTX=16,  # number of context vectors
        CSC=False,  # class-specific context
        CTX_INIT=False,  # initialization words
        W=8.0,
        PREC="amp",  # fp16, fp32, amp
        CLASS_TOKEN_POSITION="end",  # 'middle' or 'end' or 'front'
        LOSS_TYPE="cosine",  # 'cosine' or 'l2'
        W_LMC=1.0,
        NUM_SAMPLES=5,
        COOP_LMC=True,
    )

    cfg.TRAINER.COCOOP = make_node(
        N_CTX=16,  # number of context vectors
        CTX_INIT=False,  # initialization words
        PREC="amp",  # fp16, fp32, amp
    )

    cfg.TRAINER.COOP_CLIP = make_node(
        N_CTX=4,
        CTX_INIT=False,
    )

    cfg.DATASET.SUBSAMPLE_CLASSES = "all"  # all, base or new

    # Loss-related settings.
    cfg.LOSS = make_node(
        GM=False,
        NAME="",
        ALPHA=0.,
        T=1.,
        LAMBDA=1.,
    )

    cfg.TRAINER.SAM = make_node(
        RHO=0.05,  # SAM's rho parameter
        ADAPTIVE=False,  # Whether to use adaptive SAM
    )

    cfg.TEST.FINAL_MODEL = "best_val"


def setup_cfg(args):
    """Assemble the frozen run configuration from defaults, files and CLI.

    Sources are applied in this order, each on top of the previous one:

    1. ``get_cfg_default()`` extended by ``extend_cfg``
    2. the dataset config file (``args.dataset_config_file``), if given
    3. the method config file (``args.config_file``), if given
    4. individual command-line options, via ``reset_cfg``
    5. the free-form ``args.opts`` list, via ``merge_from_list``

    Args:
        args: parsed command-line options.

    Returns:
        The config object after ``freeze()`` has been called on it.
    """
    cfg = get_cfg_default()
    extend_cfg(cfg)

    # Dataset file first, then method file; an empty path skips the merge.
    for attr in ("dataset_config_file", "config_file"):
        path = getattr(args, attr)
        if path:
            cfg.merge_from_file(path)

    # Dedicated command-line flags.
    reset_cfg(cfg, args)

    # Trailing free-form overrides.
    cfg.merge_from_list(args.opts)

    cfg.freeze()
    return cfg


def main(args):
    """Build the config and trainer, then evaluate or train.

    After seeding (only when ``cfg.SEED >= 0``), logger setup and an
    environment dump, exactly one of the following happens, checked in
    this order:

    * ``args.eval_only_merge_dare``: ``trainer.load_model_merge_dare`` + test
    * ``args.eval_only_merge``: ``trainer.load_model_merge`` + test
    * ``args.eval_only``: ``trainer.load_model`` + test
    * otherwise: ``trainer.train()`` unless ``args.no_train`` is set

    Args:
        args: parsed command-line options.
    """
    cfg = setup_cfg(args)

    seed = cfg.SEED
    if seed >= 0:
        print(f"Setting fixed seed: {seed}")
        set_random_seed(seed)
    setup_logger(cfg.OUTPUT_DIR)

    if torch.cuda.is_available() and cfg.USE_CUDA:
        torch.backends.cudnn.benchmark = True

    print_args(args, cfg)
    print("Collecting env info ...")
    print(f"** System info **\n{collect_env_info()}\n")

    trainer = build_trainer(cfg)

    # Pick how weights are loaded; the first matching mode wins.
    if args.eval_only_merge_dare:
        trainer.load_model_merge_dare(args.model_dir, model_name=args.model_name)
    elif args.eval_only_merge:
        trainer.load_model_merge(
            args.model_dir,
            lambda_val=args.lambda_val,
            model_prefix=args.model_prefix,
        )
    elif args.eval_only:
        trainer.load_model(args.model_dir, epoch=args.load_epoch)
    else:
        # No evaluation mode requested: train, unless explicitly disabled.
        if not args.no_train:
            trainer.train()
        return

    trainer.test()


if __name__ == "__main__":
    # Command-line interface, declared as an ordered table of
    # (name, add_argument keyword arguments). The order is kept exactly as
    # registered on the parser, so the generated help lists them this way.
    option_specs = [
        ("--root", {"type": str, "default": "", "help": "path to dataset"}),
        ("--output-dir", {"type": str, "default": "", "help": "output directory"}),
        (
            "--resume",
            {
                "type": str,
                "default": "",
                "help": "checkpoint directory (from which the training resumes)",
            },
        ),
        (
            "--resume-coop",
            {
                "type": str,
                "default": "",
                "help": "checkpoint directory (from which the training resumes)",
            },
        ),
        (
            "--seed",
            {
                "type": int,
                "default": -1,
                "help": "only positive value enables a fixed seed",
            },
        ),
        (
            "--source-domains",
            {"type": str, "nargs": "+", "help": "source domains for DA/DG"},
        ),
        (
            "--target-domains",
            {"type": str, "nargs": "+", "help": "target domains for DA/DG"},
        ),
        (
            "--transforms",
            {"type": str, "nargs": "+", "help": "data augmentation methods"},
        ),
        ("--config-file", {"type": str, "default": "", "help": "path to config file"}),
        (
            "--dataset-config-file",
            {
                "type": str,
                "default": "",
                "help": "path to config file for dataset setup",
            },
        ),
        ("--trainer", {"type": str, "default": "", "help": "name of trainer"}),
        ("--backbone", {"type": str, "default": "", "help": "name of CNN backbone"}),
        ("--head", {"type": str, "default": "", "help": "name of head"}),
        ("--eval-only", {"action": "store_true", "help": "evaluation only"}),
        (
            "--model-dir",
            {
                "type": str,
                "default": "",
                "help": "load model from this directory for eval-only mode",
            },
        ),
        (
            "--load-epoch",
            {
                "type": int,
                "help": "load model weights at this epoch for evaluation",
            },
        ),
        (
            "--no-train",
            {"action": "store_true", "help": "do not call trainer.train()"},
        ),
        # Positional catch-all: everything left over is passed on as
        # config overrides.
        (
            "opts",
            {
                "default": None,
                "nargs": argparse.REMAINDER,
                "help": "modify config options using the command-line",
            },
        ),
        # Eval only merge arguments
        ("--eval-only-merge", {"action": "store_true", "help": "evaluation only"}),
        (
            "--eval-only-merge-dare",
            {"action": "store_true", "help": "evaluation only"},
        ),
        (
            "--lambda-val",
            {
                "type": float,
                "default": None,
                "help": "lambda value for merged model (e.g., 1.6). If not specified, uses the first available model.",
            },
        ),
        (
            "--model-prefix",
            {
                "type": str,
                "default": None,
                "help": "Model file prefix. If not specified, auto-detects available pattern.",
            },
        ),
        (
            "--model-name",
            {
                "type": str,
                "default": None,
                "help": "Model file name. If not specified, auto-detects available pattern.",
            },
        ),
    ]

    parser = argparse.ArgumentParser()
    for option_name, option_kwargs in option_specs:
        parser.add_argument(option_name, **option_kwargs)

    args = parser.parse_args()
    main(args)
