import json
import subprocess
from pathlib import Path
from typing import Union

# A sweep variant maps dotted config paths to override values; make_cmd turns
# each ``key: value`` pair into a ``--config.<key>=<value>`` flag.
Variant = dict[str, Union[str, float, int]]

default_variant: Variant = {
    "random_seed": 0,
    # Placeholders only: every entry in ``variants`` overrides both sizes.
    "experiment_kwargs.config.population.n_speakers": 0,
    "experiment_kwargs.config.population.n_listeners": 0,
    # Original training length; lower it for quick smoke runs.
    "training_steps": int(2e5),
}


def _population(n_speakers: int, n_listeners: int, num_agents_per_step: int) -> Variant:
    """Return the population-size overrides for one sweep variant."""
    return {
        "experiment_kwargs.config.population.n_speakers": n_speakers,
        "experiment_kwargs.config.population.n_listeners": n_listeners,
        "experiment_kwargs.config.population.num_agents_per_step": num_agents_per_step,
    }


# Keyed "<n_speakers>x<n_listeners>". The key is derived from the sizes so a
# label can never disagree with its configuration. Insertion order (= run
# order in main) is preserved.
variants: dict[str, Variant] = {
    f"{n_speakers}x{n_listeners}": _population(n_speakers, n_listeners, per_step)
    for n_speakers, n_listeners, per_step in (
        (1, 1, 1),
        (5, 5, 5),
        (10, 10, 10),
        (1, 10, 10),
        (10, 1, 10),
    )
}


def make_cmd(var: Variant) -> list[str]:
    """Build the argv for one ``emergent_communication_at_scale.main`` run.

    Args:
        var: Config overrides. Each ``key: value`` pair is emitted, in
            iteration order, as ``--config.<key>=<value>`` after the base
            ImageNet config flag.

    Returns:
        The command as a list of arguments (no shell quoting involved).
    """
    config_flags = [f"--config.{k}={v}" for k, v in var.items()]
    # fmt: off
    return [
        "python",
        "-m", "emergent_communication_at_scale.main",
        "--config=emergent_communication_at_scale/configs/lewis_config.py:imagenet",
        # To skip training and only run evaluation, uncomment the following:
        # "--config.skip_train",
        *config_flags,
    ]
    # fmt: on


def main() -> None:
    """Run every sweep variant in order, then re-nest each run's metadata.

    For each entry in ``variants`` this merges it over ``default_variant``,
    sets ``checkpoint_dir`` to ``./checkpoint/imagenet-<name>``, runs the
    command from ``make_cmd``, and rewrites ``<checkpoint_dir>/metadata.json``
    so its previous contents sit under ``{"metrics": {"system": ...}}``.

    Raises:
        subprocess.CalledProcessError: if a run exits non-zero; the remaining
            variants are not started.
    """
    for name, overrides in variants.items():
        # For now ImageNet is the only dataset that works; the provided CelebA
        # data is missing metadata.
        checkpoint_dir = f"./checkpoint/imagenet-{name}"
        # Same key order as before: defaults, variant-only keys, checkpoint_dir.
        config = default_variant | overrides | {"checkpoint_dir": checkpoint_dir}

        # check=True: a failed run must not fall through to rewriting a
        # missing metadata.json, or a stale one left by an earlier run.
        subprocess.run(make_cmd(config), check=True)

        metadata_path = Path(checkpoint_dir) / "metadata.json"
        data = json.loads(metadata_path.read_text(encoding="utf-8"))
        # NOTE(review): not idempotent. If metadata.json survives from a prior
        # sweep and the run does not overwrite it, it gets wrapped twice.
        metadata_path.write_text(
            json.dumps({"metrics": {"system": data}}), encoding="utf-8"
        )


# Launch the full sweep only when this file is executed as a script.
if __name__ == "__main__":
    main()
