import json
import logging
from pathlib import Path

from omegaconf import OmegaConf


# Module-level logger named after this module; used by Config.pretty_print.
logger = logging.getLogger(name=__name__)

class Config:
    """Configuration loaded from a file and overridden from the command line.

    The file at ``cfg_path`` is loaded with ``OmegaConf.load`` and merged with
    user-supplied overrides. The sub-configs for the requested dotted ``keys``
    are looked up once at construction and cached in ``config_sections``.

    Attributes:
        config: The merged OmegaConf config.
        keys: The dotted keys passed to the constructor.
        config_sections: Mapping from each dotted key to its sub-config.
    """

    def __init__(self, cfg_path, options, keys):
        """Load the config file, apply overrides and extract sections.

        Args:
            cfg_path: Path of the config file, passed to ``OmegaConf.load``.
            options: Override tokens, either ``key=value`` strings or
                alternating ``key value`` tokens; ``None`` means no overrides.
            keys: Iterable of dotted keys (e.g. ``"audio_encoders.whisper"``)
                whose sub-configs are cached in ``config_sections``.

        Raises:
            ValueError: If ``options`` uses the ``key value`` form and the
                last key has no value.
        """
        self.keys = keys

        user_config = self._build_opt_list(options)
        file_config = OmegaConf.load(cfg_path)
        # Overrides are merged last so they take precedence over the file.
        self.config = OmegaConf.merge(file_config, user_config)

        self.config_sections = {
            key: self.get_config_by_key(key) for key in keys
        }

    def get_config_by_key(self, key):
        """Return the sub-config at the dotted path ``key`` (e.g. ``"a.b.c"``).

        Each dot-separated component indexes one level deeper into
        ``self.config``; a missing component propagates the lookup error
        raised by the underlying container.
        """
        section = self.config
        for part in key.split("."):
            section = section[part]
        return section

    def _convert_to_dot_list(self, opts):
        """Normalize override tokens into OmegaConf dot-list form.

        The format is decided by the first token: if it contains ``=``, all
        tokens are taken as ``key=value`` strings unchanged; otherwise tokens
        are read as alternating ``key value`` pairs and joined with ``=``.

        Args:
            opts: Sequence of override tokens, or ``None``.

        Returns:
            A list of dot-list strings (empty for ``None`` or no tokens).

        Raises:
            ValueError: If the ``key value`` form has an odd number of tokens,
                i.e. the last key has no value.
        """
        if not opts:
            return []

        if "=" in opts[0]:
            # Already in ``key=value`` form; pass through unchanged.
            return list(opts)

        # zip() would silently drop a trailing key without a value, losing the
        # user's override; fail loudly instead.
        if len(opts) % 2 != 0:
            raise ValueError(
                "Expected alternating key/value override pairs, but key "
                f"{opts[-1]!r} has no value"
            )
        return [f"{key}={value}" for key, value in zip(opts[0::2], opts[1::2])]

    def _build_opt_list(self, opts):
        """Turn override tokens into an OmegaConf config via ``from_dotlist``."""
        opts_dot_list = self._convert_to_dot_list(opts)
        return OmegaConf.from_dotlist(opts_dot_list)

    def pretty_print(self):
        """Log the run, dataset and model sections of the config as JSON.

        Assumes the config defines the top-level keys ``run``, ``datasets``,
        ``llm``, ``audio_encoders``, ``video_encoders`` and ``connectors``.
        """
        logger.info("\n=====  Running Parameters    =====")
        logger.info(self._convert_node_to_json(self.config.run))

        logger.info("\n======  Dataset Attributes  ======")
        # Use the module logger (not the root ``logging`` functions) like the
        # other sections, so this output follows the same logger configuration.
        logger.info(self._convert_node_to_json(self.config.datasets))

        logger.info("\n======  Model Attributes  ======")
        logger.info(self._convert_node_to_json(self.config.llm))
        logger.info(self._convert_node_to_json(self.config.audio_encoders))
        logger.info(self._convert_node_to_json(self.config.video_encoders))
        logger.info(self._convert_node_to_json(self.config.connectors))

    def _convert_node_to_json(self, node):
        """Serialize an OmegaConf node to indented, key-sorted JSON.

        Interpolations are resolved first so the output shows final values.
        """
        container = OmegaConf.to_container(node, resolve=True)
        return json.dumps(container, indent=4, sort_keys=True)

    def to_dict(self):
        """Return the merged config as plain Python containers.

        ``resolve`` is not passed, so interpolations are left unresolved
        (unlike ``_convert_node_to_json``).
        """
        return OmegaConf.to_container(self.config)

    def save(self, path):
        """Write the merged config to ``path``, creating parent directories."""
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        OmegaConf.save(config=self.config, f=path)
    

# # Class Test
# import argparse
# def parse_args():
#     parser = argparse.ArgumentParser(description='train parameters')
#     parser.add_argument("--modality", type=str, default='audiovideoimage', help='modality')
#     parser.add_argument("--task", type=str, default='qa', help='task')
#     parser.add_argument("--cfg-path", type=str, default='config.yaml', help='path to configuration file')
#     parser.add_argument(
#         "--options",
#         nargs="+",
#         help="override some settings in the used config, the key-value pair "
#              "in xxx=yyy format will be merged into config file (deprecate), "
#              "change to --cfg-options instead.",
#     )
#     parser.add_argument(
#         "--llms",
#         type=str,
#         nargs='+',
#         default=['llama'],
#         choices=['llama'],
#         help='models to use as large language model',
#     )
#     # Add --local_rank to support distributed training
#     parser.add_argument("--local_rank", type=int, default=-1, help="local rank for distributed training")
#     parser.add_argument("--deepspeed_config ", type=str, default='', help="local rank for distributed training")
#     return parser.parse_args()
#
# args = parse_args()
# keys = ["audio_encoders.whisper", "audio_encoders.beats", "video_encoders.internvideo2"]
# cfg = Config(args, keys)
# # for key, config_section in cfg.config_sections.items():
# #     print(f"{key}: {config_section}")
# print(cfg.config_sections.keys())


