import argparse
from collections.abc import Sequence
from typing import Any, Optional

from scripts.plots.scripts.grid.main import run as run_main
from utils.logger.logger import Logger
from utils.utils import get_class_name


def run(
        name: str,
        seed_entropy_noises: Optional[int] = None,
        seed_entropy_labels: Optional[int] = None
) -> None:
    """Run the grid plot script with a fixed ImageNet 64x64 configuration.

    Thin wrapper around ``scripts.plots.scripts.grid.main.run`` that pins the
    model checkpoints, image geometry, sampling settings and process count,
    leaving only the run name and the two seed-entropy values configurable.

    Args:
        name: Run name, forwarded unchanged to the grid script.
        seed_entropy_noises: Optional seed entropy for the noises. ``None`` is
            forwarded as-is; how the grid script interprets ``None`` is not
            visible from this file.
        seed_entropy_labels: Optional seed entropy for the labels. ``None`` is
            forwarded as-is, the same way as ``seed_entropy_noises``.
    """
    Logger.debug(
        f'{get_class_name(run)} - '
        f'name: {name}, seed_entropy_noises: {seed_entropy_noises}, seed_entropy_labels: {seed_entropy_labels}'
    )
    # NOTE(review): every checkpoint path below starts with a literal 'PATH/'
    # placeholder; presumably it must be replaced with the real storage root
    # before this script can run - confirm.
    run_main(
        name=name,
        model_name='edm-imagenet-64x64-cond-adm',
        edm_model_path='PATH/state_dicts/edm-imagenet-64x64-cond-adm.pth',
        vgg_model_path='PATH/results/lines-imagenet-64x64-9-loss-lpips-vgg-64-bs-32-br-64-ae3da90b-fee8-423a-935c-4feb477edde9.0/checkpoints/last/ema/beta_0.999_update_every_1_update_after_step_100_inv_gamma_1.0_power_0.66667.pth',
        gan_model_path='PATH/results/lines-gan-imagenet-64x64-14-bs-16-br-32-two-step-fcfc49fd-3236-4672-a3ac-94aa64050860.0/checkpoints/last/ema/beta_0.999_update_every_1_update_after_step_100_inv_gamma_1.0_power_1.0.pth',
        save_folder='samples-plot-grid',
        image_height=64,
        image_width=64,
        image_channels=3,
        num_samples=500,
        num_steps=40,
        two_time_steps=[0, 1],
        batch_size=50,
        num_classes=1000,
        seed_entropy_noises=seed_entropy_noises,
        seed_entropy_labels=seed_entropy_labels,
        num_processes=8
    )
    Logger.debug(f'{get_class_name(run)} - done')


def run_from_config(config: dict[str, Any]) -> None:
    """Invoke ``run`` with the entries of ``config`` as keyword arguments.

    Every key in ``config`` must be the name of a parameter of ``run``.
    """
    run_kwargs = {**config}
    run(**run_kwargs)


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command-line options for this script.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which matches the
            previous behaviour. Passing an explicit list lets callers and
            tests parse arguments without touching ``sys.argv``.

    Returns:
        A namespace with ``name``, ``seed_entropy_noises`` and
        ``seed_entropy_labels`` attributes. Their names match the keyword
        parameters of ``run``, so ``vars(namespace)`` can be splatted into it.

    Raises:
        SystemExit: Raised by argparse when ``--name`` is missing or a seed
            value is not an integer.
    """
    parser = argparse.ArgumentParser(
        description='Run the grid plot script with a fixed ImageNet 64x64 configuration.'
    )
    parser.add_argument(
        '--name', type=str, required=True,
        help='run name forwarded to the grid plot script',
    )
    parser.add_argument(
        '--seed_entropy_noises', type=int, default=None,
        help='optional integer seed entropy for the noises (default: None)',
    )
    parser.add_argument(
        '--seed_entropy_labels', type=int, default=None,
        help='optional integer seed entropy for the labels (default: None)',
    )
    return parser.parse_args(argv)


def get_config_from_args(args: argparse.Namespace) -> dict[str, Any]:
    """Return the parsed options as a mapping from option name to value.

    The mapping is the namespace's own attribute dictionary (what ``vars``
    returns), not a copy, so mutating it also mutates ``args``.
    """
    namespace_attributes = vars(args)
    return namespace_attributes


def main() -> None:
    """Script entry point: parse the command line, convert it to a config
    dict and launch the run with it."""
    run_from_config(get_config_from_args(parse_args()))


# Launch only when executed as a script, not when this module is imported.
if __name__ == '__main__':
    main()
