import argparse
from typing import Any, Optional

from scripts.plots.scripts.grid.main import run as run_main
from utils.logger.logger import Logger
from utils.utils import get_class_name


def run(
        name: str,
        seed_entropy_noises: Optional[int] = None,
        seed_entropy_labels: Optional[int] = None
) -> None:
    """Plot the sample grid for the fixed 'edm-cifar10-32x32-cond-vp' preset.

    Thin preset around ``scripts.plots.scripts.grid.main.run``. Only the run
    name and the two seed entropies are configurable. Every other argument
    (checkpoint paths, image geometry, sample/step counts, batch size, number
    of processes) is hard-coded below.

    Args:
        name: Forwarded unchanged as ``name``.
        seed_entropy_noises: Forwarded unchanged as ``seed_entropy_noises``.
            ``None`` is passed through as-is; its meaning is defined by the
            grid script.
        seed_entropy_labels: Forwarded unchanged as ``seed_entropy_labels``,
            with the same ``None`` pass-through.
    """
    Logger.debug(
        f'{get_class_name(run)} - '
        f'name: {name}, seed_entropy_noises: {seed_entropy_noises}, seed_entropy_labels: {seed_entropy_labels}'
    )
    # NOTE(review): the checkpoint paths start with a literal 'PATH/' prefix,
    # presumably a placeholder for the real storage root. Confirm before running.
    run_main(
        name=name,
        model_name='edm-cifar10-32x32-cond-vp',
        edm_model_path='PATH/state_dicts/edm-cifar10-32x32-cond-vp.pth',
        vgg_model_path='PATH/results/lines/lines-cifar10-32x32-2-loss-lpips-vgg-64-bs-256-br-2-22074219-5353-4ac4-bd27-214df672cda5/checkpoints/last.converted/ema/beta_0.999_update_every_1_update_after_step_100_inv_gamma_1.0_power_0.66667.pth',
        gan_model_path='PATH/results/lines-gan-cifar10-32x32-12-bs-32-br-16-two-step-f7471df7-9f31-4bb3-be81-ec83a3796772.0/checkpoints/10000/ema/beta_0.999_update_every_1_update_after_step_100_inv_gamma_1.0_power_0.75.pth',
        save_folder='samples-plot-grid',
        image_height=32,
        image_width=32,
        image_channels=3,
        num_samples=500,
        num_steps=18,
        two_time_steps=[0, 1],
        batch_size=50,
        num_classes=10,
        seed_entropy_noises=seed_entropy_noises,
        seed_entropy_labels=seed_entropy_labels,
        num_processes=8
    )
    Logger.debug(f'{get_class_name(run)} - done')


def run_from_config(config: dict[str, Any]) -> None:
    """Call ``run`` with each entry of *config* as a keyword argument.

    Keys must match ``run``'s parameter names (``name``, ``seed_entropy_noises``,
    ``seed_entropy_labels``). Any other key makes the call raise ``TypeError``.
    """
    keyword_arguments = config
    run(**keyword_arguments)


def parse_args(argv: Optional[list[str]] = None) -> argparse.Namespace:
    """Parse the command-line flags for this script.

    Flags:
        --name (str, required)
        --seed_entropy_noises (int, default None)
        --seed_entropy_labels (int, default None)

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, matching the previous behavior. Passing an
            explicit list lets callers and tests parse without touching
            ``sys.argv``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--name', type=str, required=True)
    parser.add_argument('--seed_entropy_noises', type=int, default=None)
    parser.add_argument('--seed_entropy_labels', type=int, default=None)
    return parser.parse_args(argv)


def get_config_from_args(args: argparse.Namespace) -> dict[str, Any]:
    """Expose the parsed namespace as a keyword dictionary.

    The result is the namespace's own attribute dict (``vars``), not a copy.
    Mutating it therefore also mutates *args*.
    """
    attributes = vars(args)
    return attributes


def main() -> None:
    """Command-line entry point: parse flags, turn them into a config, run it."""
    run_from_config(get_config_from_args(parse_args()))


if __name__ == '__main__':
    # Only run the CLI when executed as a script, not when imported.
    main()
