import yaml
import os
import re
import fire
from ast import literal_eval
import argparse
import json
import copy


class Struct:
    """Attribute-style view over a (possibly nested) mapping.

    Keyword arguments become instance attributes; any value that is a
    ``dict`` is itself wrapped in a ``Struct``, recursively. Values are also
    reachable with ``[]`` indexing by name.
    """

    def __init__(self, **kwargs):
        # Wrap nested dicts so ``obj.a.b`` works at every depth.
        self.__dict__.update(
            {
                name: Struct(**item) if isinstance(item, dict) else item
                for name, item in kwargs.items()
            }
        )

    def todict(self):
        """Return a plain ``dict``, converting nested ``Struct`` objects back."""
        plain = {}
        for name, item in vars(self).items():
            plain[name] = item.todict() if isinstance(item, Struct) else item
        return plain

    def __getitem__(self, index):
        """Look up an attribute by name; raises ``KeyError`` if absent."""
        return vars(self)[index]


class Config:
    """Configuration loaded from a YAML file.

    Every top-level key of the YAML mapping becomes an instance attribute.
    Mapping values are wrapped in ``Struct`` (attribute and ``[]`` access);
    all other values are stored exactly as the YAML loader returned them.
    """

    def __init__(self, config_file, **kwargs):
        """Load ``config_file`` and expose its top-level keys as attributes.

        Args:
            config_file: Path to a YAML file whose top-level document is a
                mapping. An empty file yields a config with no attributes.
            **kwargs: Accepted but ignored. NOTE(review): the ``__main__``
                block forwards all CLI arguments here, so command-line
                overrides of YAML values may have been intended — confirm
                before implementing.

        Raises:
            OSError: If ``config_file`` cannot be opened. Errors raised by the
                YAML loader propagate unchanged.
        """
        # Binary mode lets PyYAML detect the encoding itself instead of relying
        # on the platform locale, and ``with`` closes the handle promptly.
        # NOTE: FullLoader can build some Python objects from YAML tags; prefer
        # yaml.safe_load if config files may come from untrusted sources.
        with open(config_file, "rb") as f:
            _config = yaml.load(f, Loader=yaml.FullLoader)
        # An empty YAML document loads as None; treat it as "no keys" rather
        # than crashing on None.items().
        if _config is None:
            _config = {}
        for key, value in _config.items():
            if isinstance(value, dict):
                self.__dict__[key] = Struct(**value)
            else:
                self.__dict__[key] = value

    def __getitem__(self, index):
        """Look up a top-level key by name; raises ``KeyError`` if absent."""
        return self.__dict__[index]

    def todict(self):
        """Return a plain ``dict``, recursively converting ``Struct`` values."""
        return {
            k: v.todict() if isinstance(v, Struct) else v
            for k, v in self.__dict__.items()
        }

    def save2yaml(self, path):
        """Write the config to ``path`` as block-style YAML."""
        with open(path, "w") as f:
            yaml.dump(self.todict(), f, default_flow_style=False)

    def __str__(self):
        """Return an indented JSON rendering of the config for logging.

        Each list is replaced by its ``str()`` plus a ``(len=N)`` suffix.
        Lists longer than 16 items are abbreviated to their first and last 8
        items around a ``"..."`` marker. Values that ``json`` cannot encode
        (for example ``datetime.date`` objects, if the loader produced any)
        are rendered with ``str()`` instead of raising.
        """

        def prepare_dict4print(dict_):
            # Work on a deep copy so the lists in the live config are untouched.
            tmp_dict = copy.deepcopy(dict_)

            def recursive_change_list_to_string(d, summarize=16):
                half = summarize // 2
                for k, v in d.items():
                    if isinstance(v, dict):
                        # Pass ``summarize`` down so nested levels use the same
                        # threshold as the top level.
                        recursive_change_list_to_string(v, summarize)
                    elif isinstance(v, list):
                        if len(v) > summarize:
                            shown = v[:half] + ["..."] + v[-half:]
                        else:
                            shown = v
                        # Overwriting an existing key while iterating is safe:
                        # the dict's size does not change.
                        d[k] = str(shown) + f" (len={len(v)})"

            recursive_change_list_to_string(tmp_dict)
            return tmp_dict

        # default=str: this output is display-only, so stringify anything json
        # cannot serialize rather than failing the whole print.
        return json.dumps(
            prepare_dict4print(self.todict()),
            indent=4,
            sort_keys=False,
            default=str,
        )


def simplest_type(s):
    """Interpret ``s`` as a Python literal if possible, else return it unchanged.

    For example ``"3"`` -> ``3``, ``"1e-3"`` -> ``0.001``, ``"[1, 2]"`` ->
    ``[1, 2]``, while ``"adam"`` (not a literal) stays ``"adam"``.

    Only the exceptions ``ast.literal_eval`` raises for malformed input are
    caught. A bare ``except`` would also swallow ``KeyboardInterrupt``,
    ``SystemExit`` and unrelated bugs.
    """
    try:
        return literal_eval(s)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        return s


if __name__ == "__main__":
    # Script entry point: parse CLI flags, build a Config from the YAML file
    # named by --config_file, and pretty-print it.
    cli = argparse.ArgumentParser()
    cli.add_argument("--config_file", type=str, default="../debug.yaml")
    cli.add_argument("--batch_size", type=int, default=32)
    cli.add_argument("--no_wandb", action="store_true")
    cli.add_argument("--lr", type=float, default=1e-3)
    cli.add_argument("--exp_name", type=str, default="debug")
    cli.add_argument("--debug", action="store_true")
    cli.add_argument("--epochs", type=int, default=100)
    # parse_known_args tolerates flags not declared above; they are collected
    # but not used further here.
    known_args, _extra = cli.parse_known_args()
    # Every parsed option is forwarded as a keyword argument to Config.
    cfg = Config(**vars(known_args))
    print(f"The config of this process is:\n{cfg}")
