import time
import json
import os

from src.utils import logger, pause_to_confirm


# Create the output directories at import time. Config writes its JSON file
# under outputs/config/, and results paths point into outputs/results/.
for _output_dir in ("outputs/results", "outputs/config"):
    os.makedirs(_output_dir, exist_ok=True)
del _output_dir

# Evaluated once at import; every Config built afterwards reuses this value.
TIMESTAMP = time.strftime("%Y-%m-%d_%H-%M-%S")

class ConfigBase(dict):
    """A ``dict`` with case-insensitive string keys that are also readable as attributes.

    String keys are lower-cased on every write and lookup, so ``cfg.Run_Name``,
    ``cfg["RUN_NAME"]`` and ``cfg.run_name`` all address the same entry.
    Reading a missing key, by attribute or by ``[]``, returns ``None`` instead
    of raising.
    """

    def __init__(self, **kwargs):
        super().__init__()
        # Route each initial item through __setitem__ so its key is
        # normalized; dict.__init__ would store the keys verbatim.
        for key, value in kwargs.items():
            self[key] = value

    def __key(self, key):
        """Return the canonical key: ``""`` for ``None``, lower-cased for strings, else unchanged."""
        if key is None:
            return ""
        if isinstance(key, str):
            return key.lower()
        return key

    def __str__(self):
        return json.dumps(self)

    def __setattr__(self, key, value):
        # Attribute assignment stores into the mapping, not into __dict__.
        self[self.__key(key)] = value

    def __getattr__(self, key):
        # Dunder names are probed by Python machinery (pickle, copy, hasattr).
        # Returning None for them would make e.g. pickle call None as
        # __getstate__, so fail the way a normal object does.
        if key.startswith("__") and key.endswith("__"):
            raise AttributeError(key)
        return self.get(key)

    def __getitem__(self, key):
        # Missing keys yield None rather than KeyError (callers rely on this).
        return super().get(self.__key(key))

    def __setitem__(self, key, value):
        return super().__setitem__(self.__key(key), value)

    def __contains__(self, key):
        return super().__contains__(self.__key(key))

    def get(self, key, default=None):
        """Like ``dict.get`` but with the key normalized first."""
        return super().get(self.__key(key), default)

    def update(self, *args, **kwargs):
        """Like ``dict.update`` but every key goes through ``__setitem__`` and is normalized."""
        for key, value in dict(*args, **kwargs).items():
            self[key] = value

    def to_dict(self):
        """Return a plain ``dict`` copy of the stored items."""
        return dict(self)

class Config(ConfigBase):
    """Run configuration that saves itself at start-up, or restores an earlier run's settings.

    Built from keyword arguments. Keys read here include ``run_name``,
    ``resume``, ``llm_model``, ``embed_model``, ``rerank_model``, ``dataset``,
    ``node_count_threshold``, ``edge_count_threshold``, ``rerank_threshold``,
    ``rerank_rate``, ``top_threshold``, ``use_data``, ``debug_ids``,
    ``pre_retrieve_path`` and ``save_rerank_result``. Any absent key reads as
    ``None``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Output locations for this run. When resuming, load_from_file()
        # overwrites these with any values stored in the earlier config file.
        self.timestamp = TIMESTAMP
        self.path = f"outputs/config/config_{self.timestamp}_{self.run_name}.json"
        self.results_path = f"outputs/results/results_{self.timestamp}_{self.run_name}.jsonl"

        if self.resume:
            logger.warning(f"Resume from {self.resume}")
            # NOTE(review): fixed 10 s delay, presumably so the operator can
            # abort before the previous run's settings are applied.
            time.sleep(10)
            self.load_from_file(self.resume)  # overrides the values set above
        else:
            # ensure_ascii=False writes non-ASCII text verbatim, so the file
            # encoding must be explicit rather than the platform default.
            with open(self.path, "w", encoding="utf-8") as f:
                json.dump(self, f, ensure_ascii=False, indent=4)

        self._log_summary()
        self._check_run_options()

    def _log_summary(self):
        """Log the main settings of this run, then call pause_to_confirm before continuing."""
        logger.info(f"========== Run Name: {self.run_name} ==========")
        logger.info(f"LLM: {self.llm_model}")
        logger.info(f"Embed: {self.embed_model}")
        logger.info(f"Rerank: {self.rerank_model}")
        logger.info(f"Dataset: {self.dataset}")
        logger.info(f"Inter Thres: Node: {self.node_count_threshold}, Edge: {self.edge_count_threshold}")
        logger.info(f"Rerank Thres: {self.rerank_threshold}, Rerank Rate: {self.rerank_rate}, Top Thres: {self.top_threshold}")
        logger.info(f"Resume: {self.resume}")
        logger.info(f"Config Path: {self.path}")
        pause_to_confirm("Confirm to continue", pause_time=10)

    def _check_run_options(self):
        """Warn about debug-only options and reject a pre-retrieve path used without ``use_data="pre_retrieve"``."""
        if self.use_data == "paths":
            logger.warning("!!!! USE_DATA is `paths`, Only for debug !!!!")

        if self.debug_ids:
            logger.error(f"Debug mode, only process {len(self.debug_ids)} instances")

        if self.pre_retrieve_path and os.path.exists(self.pre_retrieve_path):
            logger.warning(f"Pre Retrieve Path: {self.pre_retrieve_path}")
            if self.use_data != "pre_retrieve":
                raise ValueError("The pre_retrieve_path is not empty, but use_data is not `pre_retrieve`")

        if self.save_rerank_result:
            logger.warning(f"Save Rerank Result: {self.save_rerank_result}")
            pause_to_confirm("Confirm to save rerank result", pause_time=3)

    def load_from_file(self, file_path):
        """Overlay this config with the settings saved in a JSON config file.

        Each key in the file replaces the current value. Keys absent from the
        file keep their current value. ``resume`` and ``run_name`` always keep
        the values of the current run.

        Args:
            file_path: Path to a JSON config file, presumably one written by an
                earlier run's ``Config.__init__``.

        Raises:
            FileNotFoundError: If *file_path* does not exist.
            TypeError: If the file does not contain a JSON object.
        """
        KEEP_KEYS = ["resume", "run_name"]
        # Explicit check instead of `assert`, which is stripped under -O.
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Config file {file_path} does not exist")
        logger.info(f"Loading config from {file_path}")
        with open(file_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        if not isinstance(config, dict):
            raise TypeError(f"Config file {file_path} must contain a JSON object")

        kept = {k: self[k] for k in KEEP_KEYS}
        # Assign item by item rather than via dict.update so every key goes
        # through __setitem__ and is lower-cased like all other keys.
        for k, v in config.items():
            self[k] = v
        for k, v in kept.items():
            self[k] = v


