import itertools
import os.path
from typing import Any

from ...utils.config import get_config as get_config_util


def get_config(
        model_load_path: str,
        time_steps: list[int],
        save_path: str = None,
        model_load_keys: list[str] = None,
        multiply_noises: bool = True,
        same_noise: bool = True
) -> dict[str, Any]:
    """Build the config for the class-conditional EDM ImageNet 64x64 model.

    The dataset/model settings are fixed here: 40 sampling steps, 1000
    classes, 64x64x3 images, 50,000 samples in batches of 50, and 8
    processes. They are combined with the caller's arguments and forwarded
    as keyword arguments to the shared ``get_config`` utility, whose result
    is returned unchanged.

    Args:
        model_load_path: Forwarded as ``model_load_path``.
        time_steps: Forwarded as ``time_steps``.
        save_path: Forwarded as ``save_path`` (may be None).
        model_load_keys: Forwarded as ``model_load_keys`` (may be None).
        multiply_noises: Forwarded as ``multiply_noises``.
        same_noise: Forwarded as ``same_noise``.

    Returns:
        Whatever ``get_config_util`` returns for these settings.
    """
    # Keys are listed in the same order as the original keyword call so the
    # utility receives its arguments identically.
    settings: dict[str, Any] = {
        'model_name': 'edm-imagenet-64x64-cond-adm',
        'model_load_path': model_load_path,
        'num_steps': 40,
        'time_steps': time_steps,
        'num_classes': 1000,
        'image_height': 64,
        'image_width': 64,
        'image_channels': 3,
        # NOTE(review): 'PATH/' looks like a placeholder prefix to be
        # replaced with a real location before running -- TODO confirm.
        'reference_path': 'PATH/imagenet-64x64-edm.npz',
        'num_samples': 50_000,
        'batch_size': 50,
        'save_path': save_path,
        'model_load_keys': model_load_keys,
        'multiply_noises': multiply_noises,
        'same_noise': same_noise,
        'n_processes': 8,
    }
    return get_config_util(**settings)


def get_configs(
        model_load_path: str,
        save_folder: str = None,
        model_load_keys: list[str] = None,
        multiply_noises: bool = True,
        same_noise: bool = True
) -> list[dict[str, Any]]:
    """Build the sweep of configs over one or two intermediate time steps.

    The result holds one config per single intermediate step ``t`` in 1..19
    (``time_steps=[0, t, 40]``). These are followed by one config per pair
    ``t1 < t2`` drawn from 1..19 (``time_steps=[0, t1, t2, 40]``), in
    lexicographic order. That gives 19 + 171 = 190 configs in total.

    Args:
        model_load_path: Forwarded unchanged to ``get_config``.
        save_folder: Root directory for results. Each config saves to
            ``<save_folder>/t_<t>/index.pickle`` or
            ``<save_folder>/t_<t1>_<t2>/index.pickle``. If None, every
            config is built with ``save_path=None`` instead of failing.
        model_load_keys: Forwarded unchanged to ``get_config``.
        multiply_noises: Forwarded unchanged to ``get_config``.
        same_noise: Forwarded unchanged to ``get_config``.

    Returns:
        The single-step configs followed by the two-step configs.
    """
    # Final time step; matches num_steps=40 used by get_config -- keep in sync.
    last_step = 40
    intermediate_steps = range(1, 20)

    def _save_path(run_name: str):
        # os.path.join(None, ...) raises TypeError, so an absent save folder
        # must map to "no save path" rather than being joined.
        if save_folder is None:
            return None
        return os.path.join(save_folder, run_name, 'index.pickle')

    def _make(time_steps: list[int], run_name: str) -> dict[str, Any]:
        # Build one config, forwarding the shared options unchanged.
        return get_config(
            model_load_path=model_load_path,
            time_steps=time_steps,
            save_path=_save_path(run_name),
            model_load_keys=model_load_keys,
            multiply_noises=multiply_noises,
            same_noise=same_noise
        )

    configs: list[dict[str, Any]] = [
        _make([0, t, last_step], f't_{t}') for t in intermediate_steps
    ]
    # combinations() yields (t1, t2) with t1 < t2 in lexicographic order,
    # i.e. the same order as nested loops over t1 and t2 > t1.
    configs.extend(
        _make([0, t1, t2, last_step], f't_{t1}_{t2}')
        for t1, t2 in itertools.combinations(intermediate_steps, 2)
    )
    return configs


def main() -> None:
    """Print the default config sweep, then the number of configs in it.

    ``get_configs`` is called with an empty model path and no other
    arguments, so every other option takes its default.
    """
    sweep: list[dict[str, Any]] = get_configs(model_load_path='')
    for item in (sweep, len(sweep)):
        print(item)


if __name__ == '__main__':
    main()
