import copy
from pathlib import Path
from typing import Literal, Optional
from data_utils.functional._misc import flatten_dict
import yaml

from cyclopts import App

# cyclopts CLI application; ``main`` is registered on it as the default command.
app = App()


@app.default
def main(
    folder: Path = Path("./configurations"),
    results_folder: Optional[Path] = None,
    device: Optional[Literal["cpu", "cuda", "mps"]] = None,
    acc_batches: Optional[int] = None,
):
    """Generate experiment configurations from the YAML files in a folder.

    Every .yaml file directly inside FOLDER is processed in sorted order:
    the command-line overrides below are applied, hessian configurations are
    written to FOLDER/hessian, and each configuration produced by
    flatten_dict is written to FOLDER/<exp_name>.

    Parameters
    ----------
    folder
        Directory holding the base YAML configurations.
    results_folder
        Overrides "results_folder" in every configuration when given.
    device
        Overrides "device" in every configuration when given.
    acc_batches
        Overrides "acc_batches" in every configuration when given.
    """
    # Only options that were actually supplied override values from the YAML files.
    overrides = {
        "results_folder": (
            results_folder.as_posix() if results_folder is not None else None
        ),
        "device": device,
        "acc_batches": acc_batches,
    }
    overrides = {k: v for k, v in overrides.items() if v is not None}

    # Sort so the output is reproducible: all base files write into the shared
    # FOLDER/hessian directory, so the processing order decides which file's
    # content survives when two of them map to the same hessian file name.
    config_files = sorted(folder.glob("*.yaml"))
    if not config_files:
        print(f"No .yaml configuration files found in {folder}.")
        return

    for config_file in config_files:
        config = load_config(config_file)
        config.update(overrides)
        create_hessian_configs(config, folder)

        for cfg_instance in flatten_dict(config):
            cfg_out = folder / cfg_instance["exp_name"]  # type: ignore
            dump_config_instance(cfg_instance, cfg_out, config_file.stem)

    print("Configurations created.")


def load_config(path: Path):
    """Parse the YAML file at ``path`` and return its top-level mapping.

    Args:
        path: Path to a YAML configuration file (read as UTF-8).

    Returns:
        The parsed top-level mapping as a ``dict``.

    Raises:
        ValueError: If the file parses to nothing (e.g. it is empty) or its
            top level is not a mapping, since callers update it like a dict.
    """
    with open(path, "r", encoding="utf-8") as stream:
        config = yaml.safe_load(stream)
    if config is None:
        raise ValueError(f"Config file {path} is faulty.")
    if not isinstance(config, dict):
        raise ValueError(
            f"Config file {path} is faulty: expected a mapping at the top "
            f"level, got {type(config).__name__}."
        )
    return config


def create_hessian_configs(config: dict, folder: Path):
    """Write "hessian" variants of ``config`` into ``folder / "hessian"``.

    A deep copy of ``config`` has ``bits`` set to 0 and ``strategy`` set to an
    empty list. Every configuration that ``flatten_dict`` yields from that copy
    is written with the file prefix ``"hessian"``. ``config`` itself is not
    modified.

    The output file name encodes only bits, network, nbatches, strategy and
    dataset. Variants that agree on those fields therefore overwrite one another.
    """
    hessian_base = copy.deepcopy(config)
    hessian_base.update(bits=0, strategy=[])
    hessian_dir = folder / "hessian"
    for variant in flatten_dict(hessian_base):
        # Reset once more after expansion. Presumably flatten_dict does not hand
        # the empty list back unchanged — TODO confirm against its implementation.
        variant["strategy"] = []
        dump_config_instance(variant, hessian_dir, "hessian")


def dump_config_instance(cfg_instance: dict, cfg_out: Path, out_file_prefix: str):
    """Write a single configuration to a YAML file inside ``cfg_out``.

    The mapping is deep-copied first, so the caller's dict is not modified.
    In the written copy, ``strategy`` is wrapped in a list if it is not one
    already, and ``exp_name`` gets ``_<nbatches>`` appended.

    The output file is named
    ``<out_file_prefix>_b=<bits>_<network>_<nbatches>_<strategy>_<train_dataset>.yaml``.
    ``train_dataset`` falls back to ``dataset`` when that key is absent.

    Args:
        cfg_instance: Configuration containing ``bits``, ``strategy``,
            ``nbatches``, ``network``, ``exp_name``, and either
            ``train_dataset`` or ``dataset``.
        cfg_out: Target directory; created (including parents) if missing.
        out_file_prefix: First component of the output file name.

    Raises:
        KeyError: If a required key is missing.
        ValueError: If ``strategy`` is ``None``.
    """
    cfg_instance = copy.deepcopy(cfg_instance)
    b = cfg_instance["bits"]
    strategy = cfg_instance["strategy"]
    # Explicit check instead of ``assert`` so it is not stripped under ``python -O``.
    if strategy is None:
        raise ValueError("Config instance has no strategy (strategy is None).")
    nbatches = cfg_instance["nbatches"]
    network = cfg_instance["network"]
    # Fall back to ``dataset`` only when ``train_dataset`` is absent. The previous
    # ``dict.get(..., cfg_instance["dataset"])`` evaluated the default eagerly and
    # raised KeyError whenever ``dataset`` was missing, even with ``train_dataset`` set.
    if "train_dataset" in cfg_instance:
        train_dataset = cfg_instance["train_dataset"]
    else:
        train_dataset = cfg_instance["dataset"]
    cfg_instance["strategy"] = strategy if isinstance(strategy, list) else [strategy]
    cfg_instance["exp_name"] = f"{cfg_instance['exp_name']}_{nbatches}"  # type: ignore
    # NOTE(review): this uses the value *before* list-wrapping. A list strategy
    # therefore puts its str() form (brackets, quotes, commas) into the file name,
    # e.g. "_[]" for an empty list.
    strategy_str = f"_{strategy}"

    # parents=True so an exp_name containing path separators still works.
    cfg_out.mkdir(parents=True, exist_ok=True)
    out_path = (
        cfg_out
        / f"{out_file_prefix}_b={b}_{network}_{nbatches}{strategy_str}_{train_dataset}.yaml"
    )
    with open(out_path, "w", encoding="utf-8") as stream:
        yaml.dump(cfg_instance, stream)


if __name__ == "__main__":
    app()
