# =============== PRECONFIG ===============
import os
import sys
import platform

# Absolute path of the directory that contains this script.
# These statements run before the project imports below (settings, define, lib),
# so the working directory and module search path are already adjusted for them.
curr_filepath = os.path.dirname(os.path.abspath(__file__))
if platform.system() != 'Windows':
    # On non-Windows hosts, switch the working directory to the parent of this
    # script's directory; on Windows the working directory is left untouched.
    # NOTE(review): the reason for skipping Windows is not visible here — confirm.
    os.chdir(os.path.abspath(curr_filepath + "/../"))  # change workdir
# Replace (not prepend) the first module search entry with this script's
# directory, presumably so the sibling project modules imported below resolve
# regardless of how the script was launched — verify.
sys.path[0] = os.path.abspath(curr_filepath)  # change module search path
# =============== PRECONFIG ===============

import argparse
import logging
import shutil
import traceback
import settings
import importlib
from ast import literal_eval
from define import NECESSARY_KEY_IN_CFG_FILE, default_train_cfg, default_environ_cfg
from lib.utils import UnifyConfig, init_all


def _merge_exposed_default_cfg(sub_cfg, module_name, cls_name):
    """Merge ``module_name.cls_name.expose_default_cfg`` into ``sub_cfg``.

    Classes without an ``expose_default_cfg`` attribute contribute nothing.
    The merge passes ``ignore_same_key=True``; presumably this keeps values
    already present in ``sub_cfg`` (e.g. from the YAML file) — confirm against
    ``UnifyConfig.update``.
    """
    target_cls = getattr(importlib.import_module(module_name), cls_name)
    sub_cfg.update(getattr(target_cls, 'expose_default_cfg', {}), ignore_same_key=True)


def _coerce_cli_value(current, raw):
    """Convert the command-line string ``raw`` for a key whose value is ``current``.

    * ``str`` values take ``raw`` verbatim.
    * ``None`` values (no type to copy) take ``literal_eval(raw)`` when ``raw``
      is a Python literal, otherwise ``raw`` itself.
    * Any other value is rebuilt as ``type(current)(literal_eval(raw))``,
      e.g. ``"3"`` becomes ``3.0`` when ``current`` is a float.
    """
    if type(current) is str:
        return raw
    if current is None:
        # type(None)(...) raises TypeError, which made None defaults impossible
        # to override from the CLI; infer the type from the literal instead.
        try:
            return literal_eval(raw)
        except (ValueError, SyntaxError):
            return raw
    return type(current)(literal_eval(raw))


def _apply_cli_overrides(cfg, unknown_args):
    """Apply the ``--dotted.key value`` pairs left over by argparse to ``cfg``.

    Raises:
        ValueError: odd number of tokens, a name not starting with ``--``, or
            a key that does not already exist in ``cfg``.
    """
    # Explicit raises instead of asserts: asserts vanish under `python -O`,
    # which would silently accept malformed command lines.
    if len(unknown_args) % 2 != 0:
        raise ValueError(f"invalid params: {unknown_args}")
    # Tokens alternate name, value, name, value, ...
    for flag, raw in zip(unknown_args[0::2], unknown_args[1::2]):
        if not flag.startswith("--"):
            raise ValueError(f"invalid param name: {flag}, expected '--key value' pairs")
        key = flag[2:]
        if not cfg.dot_contains(key):
            raise ValueError(f"Param:{key} is not valid")
        cfg.dot_set(key, _coerce_cli_value(cfg.dot_get(key), raw))


def get_cfg_obj():
    """Build the run configuration from the command line, ``settings`` and YAML.

    Steps, in order:
      1. Parse ``--dataset``/``--cls`` (required), ``--cfg_file`` and the
         ``--data_cfg.datafmt_cls`` / ``--model_cfg.model_cls`` shortcuts.
      2. Start from the ``settings`` module, record DATASET/ENTRYPOINT_CLS and
         merge ``<CFG_FOLDER_PATH>/<dataset>/<cfg_file or cls>.yaml``.
      3. Check every key in ``NECESSARY_KEY_IN_CFG_FILE`` is present and wrap
         top-level dict values in ``UnifyConfig``.
      4. Fill environ/train defaults, apply the class-name shortcuts, then merge
         the datafmt/model classes' ``expose_default_cfg``.
      5. Apply remaining ``--dotted.key value`` pairs, converting each value to
         the type of the existing entry.

    Returns:
        The populated ``UnifyConfig``.

    Raises:
        ValueError: a necessary key is missing from the YAML file, or the extra
            command-line overrides are malformed / refer to unknown keys.
    """
    # resolve arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', '-dt', type=str, help='dataset name', dest='dataset', required=True)
    parser.add_argument('--cls', type=str, help='class name', dest='entrypoint_cls', required=True)
    parser.add_argument('--cfg_file', '-f', type=str, help='config filename', dest='cfg_file', default=None)
    parser.add_argument('--data_cfg.datafmt_cls', type=str, help='datafmt class name', dest='datafmt_cls', default=None)
    parser.add_argument('--model_cfg.model_cls', type=str, help='model class name', dest='model_cls', default=None)
    args, unknown_args = parser.parse_known_args()

    # read config file
    cfg = UnifyConfig.from_py_module(settings)
    cfg.DATASET = args.dataset
    cfg.ENTRYPOINT_CLS = args.entrypoint_cls

    yaml_file_name = args.cfg_file or cfg.ENTRYPOINT_CLS
    yaml_path = f"{cfg.CFG_FOLDER_PATH}/{cfg.DATASET}/{yaml_file_name}.yaml"
    cfg.update(UnifyConfig.from_yml_file(yaml_path))

    # check and convert
    for key in NECESSARY_KEY_IN_CFG_FILE:
        if not cfg.dot_contains(key):
            raise ValueError(f"please ensure keys:{NECESSARY_KEY_IN_CFG_FILE} in {yaml_path}")

    # Iterate over a snapshot of the keys because the loop assigns into cfg.
    for key in list(cfg):
        if isinstance(cfg[key], dict):
            cfg[key] = UnifyConfig(cfg[key])

    # update default cfg
    cfg.environ_cfg.update(default_environ_cfg, ignore_same_key=True)
    cfg.train_cfg.update(default_train_cfg, ignore_same_key=True)
    # The CLI shortcuts take priority over the YAML class names.
    cfg.data_cfg['datafmt_cls'] = args.datafmt_cls or cfg.data_cfg['datafmt_cls']
    cfg.model_cfg['model_cls'] = args.model_cls or cfg.model_cfg['model_cls']
    _merge_exposed_default_cfg(cfg.data_cfg, "datafmt", cfg.data_cfg['datafmt_cls'])
    _merge_exposed_default_cfg(cfg.model_cfg, "models", cfg.model_cfg['model_cls'])

    # resolve unknown_args (after the merges, so class defaults can be overridden)
    _apply_cli_overrides(cfg, unknown_args)

    return cfg


def main(cfg):
    """Run the task described by ``cfg`` and log its progress.

    Logs a banner with the run identity (ID, dataset, entrypoint class, argv)
    and the formatted configuration, and saves the configuration as
    ``cfg.json`` inside ``cfg.temp_folder_path``. It then looks up
    ``cfg.ENTRYPOINT_CLS`` in the ``entrypoint`` module, constructs it with
    ``cfg`` and calls ``start()`` on the instance.
    """
    log = cfg.logger
    rule = "====" * 15

    log.info(rule)
    log.info(f"[ID]: {cfg.ID}")
    log.info(f"[DATASET]: {cfg.DATASET}")
    log.info(f"[ENTRYPOINT_CLS]: {cfg.ENTRYPOINT_CLS}")
    log.info(f"[ARGV]: {sys.argv}")
    log.info(f"[ALL_CFG]: \n{cfg.dump_fmt()}")
    # Persist the fully resolved configuration into the run's temp folder.
    cfg.dump_file(f"{cfg.temp_folder_path}/cfg.json")
    log.info(rule)

    entrypoint_module = importlib.import_module("entrypoint")
    task_runner = getattr(entrypoint_module, cfg.ENTRYPOINT_CLS)(cfg)
    task_runner.start()
    log.info(f"Task: {cfg.ID} Completed!")


if __name__ == '__main__':
    cfg = get_cfg_obj()
    # NOTE(review): init_all presumably sets up cfg.logger and the temp/save
    # folder paths used below — confirm in lib.utils.
    init_all(cfg, dataset=cfg.DATASET, task=cfg.ENTRYPOINT_CLS)
    try:
        main(cfg)
        # Flush and close all logging handlers before moving the run folder;
        # presumably log files are written inside temp_folder_path.
        logging.shutdown()
        shutil.move(cfg.temp_folder_path, cfg.save_folder_path)
    except Exception:
        # Record the traceback in the run log, then re-raise the original
        # exception unchanged (bare `raise` keeps its traceback intact). On
        # failure the temp folder is not moved to save_folder_path.
        cfg.logger.error(traceback.format_exc())
        raise
