import argparse

from dassl.utils import setup_logger, set_random_seed, collect_env_info
from dassl.config import get_cfg_default
from ldm.modules.losses.coop import CoOp

# custom
import ldm.modules.losses.datasets.oxford_pets
import ldm.modules.losses.datasets.oxford_flowers
import ldm.modules.losses.datasets.fgvc_aircraft
import ldm.modules.losses.datasets.dtd
import ldm.modules.losses.datasets.eurosat
import ldm.modules.losses.datasets.stanford_cars
import ldm.modules.losses.datasets.food101
import ldm.modules.losses.datasets.sun397
import ldm.modules.losses.datasets.caltech101
import ldm.modules.losses.datasets.ucf101
import ldm.modules.losses.datasets.imagenet

import ldm.modules.losses.datasets.imagenet_sketch
import ldm.modules.losses.datasets.imagenetv2
import ldm.modules.losses.datasets.imagenet_a
import ldm.modules.losses.datasets.imagenet_r


def reset_cfg(cfg, args):
    """Copy command-line overrides from ``args`` onto ``cfg`` in place.

    Only truthy argument values are applied. ``None``, empty strings,
    empty lists and ``0`` leave the corresponding config entry untouched.
    For example, ``seed == 0`` does not override ``cfg.SEED``.

    Args:
        cfg: Config node (must not be frozen yet) that receives the overrides.
        args: Parsed argument namespace exposing every attribute named in
            ``overrides`` below.
    """
    # (argument attribute, dotted path of the config entry it overrides),
    # applied in exactly this order.
    overrides = (
        ("root", "DATASET.ROOT"),
        ("classname", "DATASET.CLASSNAMES"),
        ("output_dir", "OUTPUT_DIR"),
        ("resume", "RESUME"),
        ("seed", "SEED"),
        ("source_domains", "DATASET.SOURCE_DOMAINS"),
        ("target_domains", "DATASET.TARGET_DOMAINS"),
        ("transforms", "INPUT.TRANSFORMS"),
        ("trainer", "TRAINER.NAME"),
        ("backbone", "MODEL.BACKBONE.NAME"),
        ("head", "MODEL.HEAD.NAME"),
    )
    for arg_name, cfg_path in overrides:
        value = getattr(args, arg_name)
        if not value:
            continue
        # Walk down to the parent node, then set the leaf attribute.
        *parents, leaf = cfg_path.split(".")
        node = cfg
        for key in parents:
            node = getattr(node, key)
        setattr(node, leaf, value)


def extend_cfg(cfg):
    """Register the CoOp/CoCoOp-specific keys on ``cfg`` in place.

    Adds:
      * ``cfg.TRAINER.COOP`` with N_CTX, CSC, CTX_INIT, PREC and
        CLASS_TOKEN_POSITION;
      * ``cfg.TRAINER.COCOOP`` with N_CTX, CTX_INIT and PREC;
      * ``cfg.DATASET.SUBSAMPLE_CLASSES``.

    ``setup_cfg`` calls this before merging any YAML file. Presumably this
    is so those files may set these keys; yacs rejects unknown keys by
    default.

    Args:
        cfg: Unfrozen yacs ``CfgNode`` that already has ``TRAINER`` and
            ``DATASET`` sub-nodes.
    """
    from yacs.config import CfgNode as CN

    coop = CN()
    coop.N_CTX = 4  # number of context vectors
    coop.CSC = False  # class-specific context
    coop.CTX_INIT = ""  # initialization words
    coop.PREC = "fp16"  # fp16, fp32, amp
    coop.CLASS_TOKEN_POSITION = "end"  # 'middle' or 'end' or 'front'
    cfg.TRAINER.COOP = coop

    cocoop = CN()
    cocoop.N_CTX = 16  # number of context vectors
    cocoop.CTX_INIT = ""  # initialization words
    cocoop.PREC = "fp16"  # fp16, fp32, amp
    cfg.TRAINER.COCOOP = cocoop

    cfg.DATASET.SUBSAMPLE_CLASSES = "all"  # all, base or new


def setup_cfg(args):
    """Assemble and freeze the config described by ``args``.

    Sources are layered so that later ones override earlier ones:
      1. Dassl defaults plus the keys registered by ``extend_cfg``;
      2. ``args.dataset_config_file`` (if set), then ``args.config_file``
         (if set);
      3. command-line overrides applied by ``reset_cfg``.

    Args:
        args: Parsed argument namespace (see ``CoOpLossComputer``).

    Returns:
        The frozen config node.
    """
    cfg = get_cfg_default()
    extend_cfg(cfg)

    # Dataset file first, so the method file can override dataset settings.
    for attr in ("dataset_config_file", "config_file"):
        yaml_path = getattr(args, attr)
        if yaml_path:
            cfg.merge_from_file(yaml_path)

    reset_cfg(cfg, args)

    # NOTE(review): ``args.opts`` is deliberately not merged (the original
    # ``cfg.merge_from_list(args.opts)`` call is disabled). Presumably this is
    # because ``opts`` can capture arguments that belong to the host program.
    # Confirm before re-enabling.
    cfg.freeze()
    return cfg


class CoOpLossComputer:
    """Build a ``CoOp`` model from the Dassl/CoOp configuration pipeline.

    Options come from a private argparse parser. Its defaults point at
    hard-coded paths under ``/data8/user/CoOp``. Any matching options in the
    parsed argument list override them, and unrecognised arguments are
    ignored.

    Attributes:
        coop_model: The ``CoOp`` instance constructed from the config.
    """

    def __init__(self, seed=None, argv=None):
        """Parse options, build the config, seed RNGs and construct ``CoOp``.

        Args:
            seed: Optional seed that replaces the ``--seed`` value before the
                config is built. It goes through ``reset_cfg``, which skips
                falsy values, so ``seed=0`` leaves ``cfg.SEED`` at its config
                value.
            argv: Argument list to parse. ``None`` (the default) parses
                ``sys.argv[1:]``. Pass ``[]`` to ignore the host process's
                command line and use only the parser defaults.
        """
        temp_parser = self._build_parser()
        # parse_known_args: tolerate options that belong to the host program.
        temp_args, _ = temp_parser.parse_known_args(argv)
        if seed is not None:
            temp_args.seed = seed
        cfg = setup_cfg(temp_args)
        if cfg.SEED >= 0:
            print("Setting fixed seed: {}".format(cfg.SEED))
            set_random_seed(cfg.SEED)
        self.coop_model = CoOp(cfg)

    @staticmethod
    def _build_parser():
        """Return the argparse parser for the CoOp/Dassl options."""
        # allow_abbrev=False: this parser may see another program's argv.
        # With abbreviations enabled, a host flag such as ``--train`` would be
        # silently taken as ``--trainer`` (or abort when no value follows).
        parser = argparse.ArgumentParser(allow_abbrev=False)
        parser.add_argument("--root", type=str, default="/data8/user/CoOp/data/", help="path to dataset")
        parser.add_argument("--output-dir", type=str, default="/data8/user/CoOp/output/", help="output directory")
        parser.add_argument(
            "--resume",
            type=str,
            default="",
            help="checkpoint directory (from which the training resumes)",
        )
        parser.add_argument(
            "--seed", type=int, default=1, help="only positive value enables a fixed seed"
        )
        parser.add_argument(
            "--source-domains", type=str, nargs="+", help="source domains for DA/DG"
        )
        parser.add_argument(
            "--target-domains", type=str, nargs="+", help="target domains for DA/DG"
        )
        parser.add_argument(
            "--transforms", type=str, nargs="+", help="data augmentation methods"
        )
        parser.add_argument(
            "--classname", type=str, default="increasing_classes_1000.txt", help="classnames"
        )
        parser.add_argument(
            "--config-file", type=str, default="/data8/user/CoOp/configs/trainers/CoOp/vit_b16_ep50.yaml", help="path to config file"
        )
        parser.add_argument(
            "--dataset-config-file",
            type=str,
            default="/data8/user/CoOp/configs/datasets/imagenet.yaml",
            help="path to config file for dataset setup",
        )
        parser.add_argument("--trainer", type=str, default="CoOp", help="name of trainer")
        parser.add_argument("--backbone", type=str, default="", help="name of CNN backbone")
        parser.add_argument("--head", type=str, default="", help="name of head")
        parser.add_argument("--eval-only", action="store_true", help="evaluation only")
        parser.add_argument(
            "--model-dir",
            type=str,
            default="",
            help="load model from this directory for eval-only mode",
        )
        parser.add_argument(
            "--load-epoch", type=int, help="load model weights at this epoch for evaluation"
        )
        parser.add_argument(
            "--no-train", action="store_true", help="do not call trainer.train()"
        )
        # Swallows trailing positional arguments; not merged into the config.
        parser.add_argument(
            "opts",
            default=None,
            nargs=argparse.REMAINDER,
            help="modify config options using the command-line",
        )
        return parser

def return_coop_model():
    """Construct a ``CoOpLossComputer`` with default arguments and return its model.

    Options are taken from the process command line via ``parse_known_args``.
    The parser's hard-coded defaults apply otherwise.
    """
    return CoOpLossComputer().coop_model