# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Shariq Farooq Bhat

import json
import os

from zoedepth.utils.easydict import EasyDict as edict

from zoedepth.utils.arg_utils import infer_type
import pathlib
import platform

# Parent of the directory that contains this file, as a resolved path.
# Model config JSON files are looked up under ROOT / "models".
ROOT = pathlib.Path(__file__).parent.parent.resolve()

# Base directory that every dataset location in this module is joined onto.
HOME_DIR = os.path.expanduser("./data")

# Settings shared by all models and modes; get_config() starts from these.
COMMON_CONFIG = dict(
    save_dir=os.path.expanduser("./depth_anything_finetune"),
    project="ZoeDepth",
    tags='',
    notes="",
    gpu=None,
    root=".",
    uid=None,
    print_losses=False,
)

# Per-dataset default settings, keyed by dataset name.
#
# get_config() merges the selected entry *underneath* the model config (keys
# already present in the config win), while change_dataset() overwrites a
# config with the selected entry. Every dataset location is built by joining
# onto HOME_DIR; the "filenames_file*" values are relative paths starting
# with "./" (presumably resolved against the working directory -- confirm in
# the data loaders).
DATASETS_CONFIG = {
    # Entries with train *and* eval paths / file lists.
    "kitti": {
        "dataset": "kitti",
        "min_depth": 0.001,
        "max_depth": 80,
        "data_path": os.path.join(HOME_DIR, "Kitti/raw_data"),
        "gt_path": os.path.join(HOME_DIR, "Kitti/data_depth_annotated_zoedepth"),
        "filenames_file": "./train_test_inputs/kitti_eigen_train_files_with_gt.txt",
        "input_height": 352,
        "input_width": 1216,  # 704
        "data_path_eval": os.path.join(HOME_DIR, "Kitti/raw_data"),
        "gt_path_eval": os.path.join(HOME_DIR, "Kitti/data_depth_annotated_zoedepth"),
        "filenames_file_eval": "./train_test_inputs/kitti_eigen_test_files_with_gt.txt",

        "min_depth_eval": 1e-3,
        "max_depth_eval": 80,

        "do_random_rotate": True,
        "degree": 1.0,
        "do_kb_crop": True,
        "garg_crop": True,
        "eigen_crop": False,
        "use_right": False
    },
    # Same values as "kitti" (its "dataset" field is also "kitti") except that
    # "do_random_rotate" is False.
    "kitti_test": {
        "dataset": "kitti",
        "min_depth": 0.001,
        "max_depth": 80,
        "data_path": os.path.join(HOME_DIR, "Kitti/raw_data"),
        "gt_path": os.path.join(HOME_DIR, "Kitti/data_depth_annotated_zoedepth"),
        "filenames_file": "./train_test_inputs/kitti_eigen_train_files_with_gt.txt",
        "input_height": 352,
        "input_width": 1216,
        "data_path_eval": os.path.join(HOME_DIR, "Kitti/raw_data"),
        "gt_path_eval": os.path.join(HOME_DIR, "Kitti/data_depth_annotated_zoedepth"),
        "filenames_file_eval": "./train_test_inputs/kitti_eigen_test_files_with_gt.txt",

        "min_depth_eval": 1e-3,
        "max_depth_eval": 80,

        "do_random_rotate": False,
        "degree": 1.0,
        "do_kb_crop": True,
        "garg_crop": True,
        "eigen_crop": False,
        "use_right": False
    },
    "nyu": {
        "dataset": "nyu",
        "avoid_boundary": False,
        "min_depth": 1e-3,   # originally 0.1
        "max_depth": 10,
        "data_path": os.path.join(HOME_DIR, "nyu"),
        "gt_path": os.path.join(HOME_DIR, "nyu"),
        "filenames_file": "./train_test_inputs/nyudepthv2_train_files_with_gt.txt",
        "input_height": 480,
        "input_width": 640,
        "data_path_eval": os.path.join(HOME_DIR, "nyu"),
        "gt_path_eval": os.path.join(HOME_DIR, "nyu"),
        "filenames_file_eval": "./train_test_inputs/nyudepthv2_test_files_with_gt.txt",
        "min_depth_eval": 1e-3,
        "max_depth_eval": 10,
        "min_depth_diff": -10,
        "max_depth_diff": 10,

        "do_random_rotate": True,
        "degree": 1.0,
        "do_kb_crop": False,
        "garg_crop": False,
        "eigen_crop": True
    },
    # The remaining entries only carry a "<name>_root" location, crop flags and
    # depth ranges -- no train file lists.
    # NOTE(review): these look evaluation-only; confirm against the loaders.
    "ibims": {
        "dataset": "ibims",
        "ibims_root": os.path.join(HOME_DIR, "iBims1/m1455541/ibims1_core_raw/"),
        "eigen_crop": True,
        "garg_crop": False,
        "do_kb_crop": False,
        "min_depth_eval": 0,
        "max_depth_eval": 10,
        "min_depth": 1e-3,
        "max_depth": 10
    },
    "sunrgbd": {
        "dataset": "sunrgbd",
        "sunrgbd_root": os.path.join(HOME_DIR, "SUNRGB-D"),
        "eigen_crop": True,
        "garg_crop": False,
        "do_kb_crop": False,
        "min_depth_eval": 0,
        "max_depth_eval": 8,
        "min_depth": 1e-3,
        "max_depth": 10
    },
    "diml_indoor": {
        "dataset": "diml_indoor",
        "diml_indoor_root": os.path.join(HOME_DIR, "DIML/indoor/sample/testset/"),
        "eigen_crop": True,
        "garg_crop": False,
        "do_kb_crop": False,
        "min_depth_eval": 0,
        "max_depth_eval": 10,
        "min_depth": 1e-3,
        "max_depth": 10
    },
    "diml_outdoor": {
        "dataset": "diml_outdoor",
        "diml_outdoor_root": os.path.join(HOME_DIR, "DIML/outdoor/test/LR"),
        "eigen_crop": False,
        "garg_crop": True,
        "do_kb_crop": False,
        "min_depth_eval": 2,
        "max_depth_eval": 80,
        "min_depth": 1e-3,
        "max_depth": 80
    },
    "diode_indoor": {
        "dataset": "diode_indoor",
        "diode_indoor_root": os.path.join(HOME_DIR, "DIODE/val/indoors/"),
        "eigen_crop": True,
        "garg_crop": False,
        "do_kb_crop": False,
        "min_depth_eval": 1e-3,
        "max_depth_eval": 10,
        "min_depth": 1e-3,
        "max_depth": 10
    },
    "diode_outdoor": {
        "dataset": "diode_outdoor",
        "diode_outdoor_root": os.path.join(HOME_DIR, "DIODE/val/outdoor/"),
        "eigen_crop": False,
        "garg_crop": True,
        "do_kb_crop": False,
        "min_depth_eval": 1e-3,
        "max_depth_eval": 80,
        "min_depth": 1e-3,
        "max_depth": 80
    },
    # NOTE(review): "max_depth_eval" (80) differs from "max_depth" (10) here --
    # confirm this is intended.
    "hypersim_test": {
        "dataset": "hypersim_test",
        "hypersim_test_root": os.path.join(HOME_DIR, "HyperSim/"),
        "eigen_crop": True,
        "garg_crop": False,
        "do_kb_crop": False,
        "min_depth_eval": 1e-3,
        "max_depth_eval": 80,
        "min_depth": 1e-3,
        "max_depth": 10
    },
    "vkitti": {
        "dataset": "vkitti",
        "vkitti_root": os.path.join(HOME_DIR, "shortcuts/datasets/vkitti_test/"),
        "eigen_crop": False,
        "garg_crop": True,
        "do_kb_crop": True,
        "min_depth_eval": 1e-3,
        "max_depth_eval": 80,
        "min_depth": 1e-3,
        "max_depth": 80
    },
    "vkitti2": {
        "dataset": "vkitti2",
        "vkitti2_root": os.path.join(HOME_DIR, "vKitti2/"),
        "eigen_crop": False,
        "garg_crop": True,
        "do_kb_crop": True,
        "min_depth_eval": 1e-3,
        "max_depth_eval": 80,
        "min_depth": 1e-3,
        "max_depth": 80,
    },
    "ddad": {
        "dataset": "ddad",
        "ddad_root": os.path.join(HOME_DIR, "shortcuts/datasets/ddad/ddad_val/"),
        "eigen_crop": False,
        "garg_crop": True,
        "do_kb_crop": True,
        "min_depth_eval": 1e-3,
        "max_depth_eval": 80,
        "min_depth": 1e-3,
        "max_depth": 80,
    },
}

# Names of DATASETS_CONFIG entries grouped as indoor / outdoor datasets.
ALL_INDOOR = [
    "nyu",
    "ibims",
    "sunrgbd",
    "diode_indoor",
    "hypersim_test",
]
ALL_OUTDOOR = [
    "kitti",
    "diml_outdoor",
    "diode_outdoor",
    "vkitti2",
    "ddad",
]
# Indoor names first, then outdoor names.
ALL_EVAL_DATASETS = [*ALL_INDOOR, *ALL_OUTDOOR]

# Defaults that get_config() lays over COMMON_CONFIG (it does so for every
# mode, not only "train").
COMMON_TRAINING_CONFIG = dict(
    dataset="nyu",
    distributed=True,
    workers=16,
    clip_grad=0.1,
    use_shared_dict=False,
    shared_dict=None,
    use_amp=False,

    aug=True,
    random_crop=False,
    random_translate=False,
    translate_prob=0.2,
    max_translation=100,

    validate_every=0.25,
    log_images_every=0.1,
    prefetch=False,
)


def flatten(config, except_keys=('bin_conf',)):
    """Flatten a nested config dict into a single-level dict.

    Nested dicts are walked recursively and their leaf ``key: value`` pairs
    are hoisted to the top level. When the same key occurs more than once,
    the pair visited last wins. Keys listed in ``except_keys`` are always
    kept in the result with their original value; if that value is itself a
    dict, its contents are still flattened as well.

    Args:
        config (dict): Possibly nested configuration. A non-dict input
            produces an empty dict.
        except_keys (str or collection of str): Key name(s) whose values are
            preserved as-is. A bare string is treated as one key name.
            Defaults to ``('bin_conf',)``.

    Returns:
        dict: The flattened configuration.
    """
    # A bare string would turn ``key in except_keys`` into a substring test
    # (e.g. "n" in "bin_conf" is True) and would raise TypeError for
    # non-string keys, so normalise it to a one-element tuple.
    if isinstance(except_keys, str):
        except_keys = (except_keys,)

    def recurse(inp):
        if isinstance(inp, dict):
            for key, value in inp.items():
                if key in except_keys:
                    yield (key, value)
                if isinstance(value, dict):
                    yield from recurse(value)
                else:
                    yield (key, value)

    return dict(recurse(config))


def split_combined_args(kwargs):
    """Expand '__'-combined arguments into individual key-value pairs.

    A combined argument packs several settings into one entry: the key is
    '__' followed by the individual keys joined with '__', and the value is
    a string holding the individual values joined with ';'.
    For example, '__n_bins__lr=256;0.001' stands for n_bins=256 and lr=0.001.

    Args:
        kwargs (dict): key-value pairs of arguments, some of which may be combined in the format above.

    Returns:
        dict: A copy of kwargs (combined entries are kept) extended with one entry per split key. Split values are left as strings.
    """
    expanded = dict(kwargs)
    for combined_key, combined_value in kwargs.items():
        if not combined_key.startswith("__"):
            continue
        # The text before the leading '__' is empty, so drop that first piece.
        keys = combined_key.split("__")[1:]
        values = combined_value.split(";")
        assert len(keys) == len(
            values), f"Combined arguments should have equal number of keys and values. Keys are separated by '__' and Values are separated with ';'. For example, '__n_bins__lr=256;0.001. Given (keys,values) is ({keys}, {values})"
        expanded.update(zip(keys, values))
    return expanded


def parse_list(config, key, dtype=int):
    """Turn a comma-separated string stored under `key` into a list of `dtype` values.

    Nothing happens when `key` is absent. A string value is split on ','
    and each piece is converted with `dtype`. The resulting value (or an
    already non-string value) must then be a list whose elements are all
    instances of `dtype`. The config is modified in place; nothing is returned.
    """
    if key not in config:
        return

    raw = config[key]
    if isinstance(raw, str):
        config[key] = [dtype(piece) for piece in raw.split(',')]

    assert isinstance(config[key], list) and all(isinstance(e, dtype) for e in config[key]), \
        f"{key} should be a list of values dtype {dtype}. Given {config[key]} of type {type(config[key])} with values of type {[type(e) for e in config[key]]}."


def get_model_config(model_name, model_version=None):
    """Load the JSON config file of a model, resolving training-config inheritance.

    The file is looked up at ROOT/models/{model_name}/ and is named
    config_{model_name}.json, or config_{model_name}_{model_version}.json
    when a version is given.

    Args:
        model_name (str): name of the model; also the name of its directory under models/.
        model_version (str, optional): config version suffix. Defaults to None (un-versioned file).

    Returns:
        easydict: the parsed config, or None when the file does not exist.
    """
    if model_version is None:
        config_fname = f"config_{model_name}.json"
    else:
        config_fname = f"config_{model_name}_{model_version}.json"

    config_file = os.path.join(ROOT, "models", model_name, config_fname)
    if not os.path.exists(config_file):
        return None

    with open(config_file, "r") as f:
        config = edict(json.load(f))

    # Inheritance is supported for the "train" section only: keys missing
    # from this model's train config are filled in from the train config of
    # the model named by "inherit" (its un-versioned config file).
    train_cfg = config.train
    if "inherit" in train_cfg and train_cfg.inherit is not None:
        parent_train = get_model_config(train_cfg["inherit"]).train
        for key, value in parent_train.items():
            if key not in train_cfg:
                train_cfg[key] = value

    return edict(config)


def update_model_config(config, mode, model_name, model_version=None, strict=False):
    """Lay the settings from a model's config file over `config`.

    The model config is loaded with get_model_config(). When it exists, its
    "model" section and its `mode` section are merged (the `mode` section
    wins on clashes), flattened, and placed over `config` in a new dict.

    Args:
        config (dict): current flat config; it is not modified.
        mode (str): name of the mode section to read from the model config.
        model_name (str): name of the model.
        model_version (str, optional): config version to load. Defaults to None.
        strict (bool, optional): raise instead of silently skipping when the config file is missing. Defaults to False.

    Returns:
        dict: the updated config, or `config` itself when no file was found and strict is False.

    Raises:
        ValueError: if the config file is missing and strict is True.
    """
    model_config = get_model_config(model_name, model_version)
    if model_config is None:
        if strict:
            raise ValueError(f"Config file for model {model_name} not found.")
        return config

    sections = dict(model_config.model)
    sections.update(model_config[mode])
    return {**config, **flatten(sections)}


def check_choices(name, value, choices):
    """Raise ValueError unless `value` is one of `choices`.

    Args:
        name (str): label of the checked setting, used in the error message.
        value: the value to validate.
        choices: container of accepted values.
    """
    if value in choices:
        return
    raise ValueError(f"{name} {value} not in supported choices {choices}")


# Config keys that get_config() casts to bool when present.
# Casting is not strictly necessary as their int-casted values in config are 0 or 1.
KEYS_TYPE_BOOL = [
    "use_amp",
    "distributed",
    "use_shared_dict",
    "same_lr",
    "aug",
    "three_phase",
    "prefetch",
    "cycle_momentum",
]


def get_config(model_name, mode='train', dataset=None, **overwrite_kwargs):
    """Main entry point to get the config for the model.

    Args:
        model_name (str): name of the desired model ("zoedepth" or "zoedepth_nk").
        mode (str, optional): "train", "infer" or "eval". Defaults to 'train'.
        dataset (str, optional): If specified, the corresponding entry of DATASETS_CONFIG is loaded as well.
            In "train" mode it must be "nyu", "kitti", "mix" or None; "mix" uses the "nyu" entry. Defaults to None.

    Keyword Args: key-value pairs of arguments to overwrite the default config.
        Keys starting with '__' are expanded by split_combined_args().

    The order of precedence for overwriting the config is (Higher precedence first):
        # 1. overwrite_kwargs
        # 2. "config_version": Config file version if specified in overwrite_kwargs. The corresponding config loaded is config_{model_name}_{config_version}.json
        # 3. "version_name": Model version specific config. Taken from overwrite_kwargs if given, otherwise from the model's base config. The corresponding config loaded is config_{model_name}_{version_name}.json
        # 4. The model's base config: config_{model_name}.json
        # 5. common_config: Default config for all models specified in COMMON_CONFIG and COMMON_TRAINING_CONFIG
        # 6. The dataset entry from DATASETS_CONFIG (only fills keys that are still missing)

    Independently of the above, the function itself sets "model", "hostname",
    "dataset" (when a dataset is given) and, in "train" mode with a dataset,
    "project".

    Returns:
        easydict: The config dictionary for the model. Every value has been
        passed through infer_type(), except "hostname".

    Raises:
        ValueError: if model_name, mode or (in "train" mode) dataset is not a supported choice.
        KeyError: if "version_name" is neither passed nor present in the loaded config,
            or if `dataset` is not a key of DATASETS_CONFIG.
    """

    check_choices("Model", model_name, ["zoedepth", "zoedepth_nk"])
    check_choices("Mode", mode, ["train", "infer", "eval"])
    if mode == "train":
        check_choices("Dataset", dataset, ["nyu", "kitti", "mix", None])

    config = flatten({**COMMON_CONFIG, **COMMON_TRAINING_CONFIG})
    config = update_model_config(config, mode, model_name)

    # update with model version specific config
    # Only fall back to the config's own "version_name" when the caller did
    # not pass one; looking it up eagerly would raise KeyError for configs
    # without that key even though it is not needed.
    if "version_name" in overwrite_kwargs:
        version_name = overwrite_kwargs["version_name"]
    else:
        version_name = config["version_name"]
    config = update_model_config(config, mode, model_name, version_name)

    # update with config version if specified
    config_version = overwrite_kwargs.get("config_version", None)
    if config_version is not None:
        print("Overwriting config with config_version", config_version)
        config = update_model_config(config, mode, model_name, config_version)

    # update with overwrite_kwargs
    # Combined args are useful for hyperparameter search
    overwrite_kwargs = split_combined_args(overwrite_kwargs)
    config = {**config, **overwrite_kwargs}

    # Casting to bool   # TODO: Not necessary. Remove and test
    for key in KEYS_TYPE_BOOL:
        if key in config:
            config[key] = bool(config[key])

    # Model specific post processing of config
    parse_list(config, "n_attractors")

    # adjust n_bins for each bin configuration if bin_conf is given and n_bins is passed in overwrite_kwargs
    if 'bin_conf' in config and 'n_bins' in overwrite_kwargs:
        bin_conf = config['bin_conf']  # list of dicts
        n_bins = overwrite_kwargs['n_bins']
        new_bin_conf = []
        for conf in bin_conf:
            conf['n_bins'] = n_bins
            new_bin_conf.append(conf)
        config['bin_conf'] = new_bin_conf

    if mode == "train":
        orig_dataset = dataset
        if dataset == "mix":
            dataset = 'nyu'  # Use nyu as default for mix. Dataset config is changed accordingly while loading the dataloader
        if dataset is not None:
            config['project'] = f"MonoDepth3-{orig_dataset}"  # Set project for wandb

    if dataset is not None:
        config['dataset'] = dataset
        # Dataset entry goes first so that anything already in config wins.
        config = {**DATASETS_CONFIG[dataset], **config}

    config['model'] = model_name
    typed_config = {k: infer_type(v) for k, v in config.items()}
    # add hostname to config
    # It must be written to typed_config (the dict that is returned); it is
    # added after type inference so the host name is always kept verbatim.
    typed_config['hostname'] = platform.node()
    return edict(typed_config)


def change_dataset(config, new_dataset):
    """Overwrite `config` in place with the DATASETS_CONFIG entry for `new_dataset`.

    Args:
        config: config object to update; keys also present in the dataset entry are replaced.
        new_dataset (str): key of the entry in DATASETS_CONFIG.

    Returns:
        The same `config` object, after the update.
    """
    dataset_settings = DATASETS_CONFIG[new_dataset]
    config.update(dataset_settings)
    return config
