import argparse
import os.path
import time
from typing import Any

import numpy as np

from scripts.plots.create_labels import run as run_create_labels
from scripts.plots.create_noises import run as run_create_noises
from torch_utils.utils import get_default_device
from utils.logger.logger import Logger
from utils.utils import get_class_name
from scripts.plots.sample_edm import run as run_sample_edm
from scripts.plots.sample_multi_step import run as run_sample_multi_step
from scripts.plots.create_grid import run as run_create_grid


def run(
        name: str,
        model_name: str,
        edm_model_path: str,
        vgg_model_path: str,
        gan_model_path: str,
        save_folder: str,
        image_height: int,
        image_width: int,
        image_channels: int,
        num_samples: int,
        num_steps: int,
        two_time_steps: list[int],
        batch_size: int = 50,
        edm_model_keys: list[str] = None,
        vgg_model_keys: list[str] = None,
        gan_model_keys: list[str] = None,
        num_classes: int = None,
        seed_entropy_noises: int = None,
        seed_entropy_labels: int = None,
        num_processes: int = 8,
        device: str = None
) -> None:
    """Sample one shared noise set with the EDM, VGG and GAN checkpoints and plot a grid.

    Steps, in order:

    1. ``run_create_noises`` writes into ``noises-{name}-{seed_entropy_noises}``
       inside ``save_folder``. When ``num_classes`` is given,
       ``run_create_labels`` writes into ``labels-{name}-{seed_entropy_labels}``.
    2. ``run_sample_edm`` samples from ``edm_model_path`` into ``edm-{name}``.
    3. ``run_sample_multi_step`` runs twice for ``vgg_model_path`` and twice
       for ``gan_model_path``: once with ``time_steps=[0, num_steps]``
       (folders ``vgg-1-{name}`` / ``gan-1-{name}``) and once with
       ``time_steps=[*two_time_steps, num_steps]`` (folders ``vgg-2-{name}`` /
       ``gan-2-{name}``).
    4. ``run_create_grid`` combines the five output folders into
       ``grid-{name}.png`` inside ``save_folder``.

    Args:
        name: Tag embedded in every folder and file name created here.
        model_name: Forwarded unchanged to every sampler call.
        edm_model_path: Checkpoint path handed to ``run_sample_edm``.
        vgg_model_path: Checkpoint path for the first pair of multi-step runs.
        gan_model_path: Checkpoint path for the second pair of multi-step runs.
        save_folder: Parent directory of all generated folders and the grid.
        image_height: Forwarded to the noise generator and every sampler.
        image_width: Forwarded to the noise generator and every sampler.
        image_channels: Forwarded to the noise generator and every sampler.
        num_samples: Forwarded to the generators, the samplers and the grid.
        num_steps: Forwarded to every sampler; also the last ``time_steps``
            entry of each multi-step run.
        two_time_steps: Exactly two entries, placed before ``num_steps`` in the
            ``time_steps`` of the two-step runs.
        batch_size: Forwarded to the generators and samplers.
        edm_model_keys: Forwarded as ``model_load_keys`` for the EDM checkpoint.
        vgg_model_keys: Forwarded as ``model_load_keys`` for the VGG checkpoint.
        gan_model_keys: Forwarded as ``model_load_keys`` for the GAN checkpoint.
        num_classes: When not ``None`` the run is conditional and labels are
            generated; otherwise ``labels_folder`` is passed as ``None``.
        seed_entropy_noises: Seed entropy for the noises; drawn from a fresh
            ``np.random.SeedSequence`` when ``None``.
        seed_entropy_labels: Seed entropy for the labels; drawn from a fresh
            ``np.random.SeedSequence`` when ``None``.
        num_processes: Forwarded to the generators and samplers.
        device: Forwarded to the samplers; ``get_default_device()`` is used
            when ``None``.

    Raises:
        ValueError: If ``two_time_steps`` does not contain exactly two entries.
    """
    Logger.debug(
        f'{get_class_name(run)} - '
        f'name: {name}, '
        f'model_name: {model_name}, '
        f'edm_model_path: {edm_model_path}, '
        f'vgg_model_path: {vgg_model_path}, '
        f'gan_model_path: {gan_model_path}, '
        f'save_folder: {save_folder}, '
        f'image_height: {image_height}, '
        f'image_width: {image_width}, '
        f'image_channels: {image_channels}, '
        f'num_samples: {num_samples}, '
        f'num_steps: {num_steps}, '
        f'two_time_steps: {two_time_steps}, '
        f'batch_size: {batch_size}, '
        f'edm_model_keys: {edm_model_keys}, '
        f'vgg_model_keys: {vgg_model_keys}, '
        f'gan_model_keys: {gan_model_keys}, '
        f'num_classes: {num_classes}, '
        f'seed_entropy_noises: {seed_entropy_noises}, '
        f'seed_entropy_labels: {seed_entropy_labels}, '
        f'num_processes: {num_processes}, '
        f'device: {device}'
    )
    if device is None:
        device = get_default_device()

    conditional: bool = num_classes is not None
    Logger.debug(f'conditional: {conditional}')

    # Explicit raise instead of ``assert``: asserts are stripped under
    # ``python -O`` and a wrong-length list would silently reach the samplers.
    if len(two_time_steps) != 2:
        raise ValueError(f'two_time_steps must have length 2: {two_time_steps}')

    if seed_entropy_noises is None:
        seed_entropy_noises = np.random.SeedSequence().entropy
    Logger.debug(f'seed_entropy_noises: {seed_entropy_noises}')
    # NOTE(review): pause kept from the original implementation. Each
    # SeedSequence() call draws fresh OS entropy, so it is presumably not
    # needed for the two seeds to differ - confirm before removing.
    time.sleep(0.1)
    if seed_entropy_labels is None:
        seed_entropy_labels = np.random.SeedSequence().entropy
    Logger.debug(f'seed_entropy_labels: {seed_entropy_labels}')

    noises_folder: str = os.path.join(save_folder, f'noises-{name}-{seed_entropy_noises}')
    # Unconditional runs have no labels; samplers then receive ``None``.
    labels_folder = os.path.join(save_folder, f'labels-{name}-{seed_entropy_labels}') if conditional else None
    edm_folder: str = os.path.join(save_folder, f'edm-{name}')
    vgg_1_step_folder: str = os.path.join(save_folder, f'vgg-1-{name}')
    vgg_2_step_folder: str = os.path.join(save_folder, f'vgg-2-{name}')
    gan_1_step_folder: str = os.path.join(save_folder, f'gan-1-{name}')
    gan_2_step_folder: str = os.path.join(save_folder, f'gan-2-{name}')
    image_save_path: str = os.path.join(save_folder, f'grid-{name}.png')

    run_create_noises(
        height=image_height,
        width=image_width,
        channels=image_channels,
        num_samples=num_samples,
        save_folder=noises_folder,
        batch_size=batch_size,
        n_processes=num_processes,
        seed_entropy=seed_entropy_noises
    )
    if conditional:
        run_create_labels(
            num_classes=num_classes,
            num_samples=num_samples,
            save_folder=labels_folder,
            batch_size=batch_size,
            n_processes=num_processes,
            seed_entropy=seed_entropy_labels
        )

    # Keyword arguments that are identical for every sampler invocation below.
    shared_sampler_kwargs: dict[str, Any] = dict(
        model_name=model_name,
        num_steps=num_steps,
        image_height=image_height,
        image_width=image_width,
        image_channels=image_channels,
        noises_folder=noises_folder,
        labels_folder=labels_folder,
        num_classes=num_classes,
        num_samples=num_samples,
        batch_size=batch_size,
        num_processes=num_processes,
        device=device
    )

    run_sample_edm(
        model_load_path=edm_model_path,
        outputs_folder=edm_folder,
        model_load_keys=edm_model_keys,
        **shared_sampler_kwargs
    )

    # Order matters for reproducing the original call sequence:
    # vgg 1-step, vgg 2-step, gan 1-step, gan 2-step.
    multi_step_models = (
        (vgg_model_path, vgg_model_keys, vgg_1_step_folder, vgg_2_step_folder),
        (gan_model_path, gan_model_keys, gan_1_step_folder, gan_2_step_folder),
    )
    for model_path, model_keys, one_step_folder, two_step_folder in multi_step_models:
        # Fresh lists per call so a callee mutating ``time_steps`` cannot
        # affect later invocations.
        schedules = (
            (one_step_folder, [0, num_steps]),
            (two_step_folder, [*two_time_steps, num_steps]),
        )
        for outputs_folder, time_steps in schedules:
            run_sample_multi_step(
                model_load_path=model_path,
                time_steps=time_steps,
                outputs_folder=outputs_folder,
                model_load_keys=model_keys,
                **shared_sampler_kwargs
            )

    run_create_grid(
        image_folders=[
            edm_folder,
            vgg_1_step_folder,
            vgg_2_step_folder,
            gan_1_step_folder,
            gan_2_step_folder
        ],
        save_path=image_save_path,
        num_samples=num_samples
    )

    Logger.debug(f'{get_class_name(run)} - done')


def run_from_config(config: dict[str, Any]) -> None:
    """Call ``run`` with the entries of ``config`` as keyword arguments.

    Args:
        config: Mapping whose keys are parameter names of ``run``.
    """
    keyword_arguments: dict[str, Any] = {**config}
    run(**keyword_arguments)


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of this script from ``sys.argv``.

    ``--two_time_steps`` takes exactly two integers, matching the length
    check performed by ``run``; any other count is rejected by argparse
    before any work starts.

    Returns:
        The populated namespace, one attribute per option.
    """
    parser: argparse.ArgumentParser = argparse.ArgumentParser()
    parser.add_argument('--index', type=int, required=True)
    parser.add_argument('--model_name', type=str, required=True)
    parser.add_argument('--edm_model_path', type=str, required=True)
    parser.add_argument('--vgg_model_path', type=str, required=True)
    parser.add_argument('--gan_model_path', type=str, required=True)
    parser.add_argument('--save_folder', type=str, required=True)
    parser.add_argument('--image_height', type=int, required=True)
    parser.add_argument('--image_width', type=int, required=True)
    parser.add_argument('--image_channels', type=int, required=True)
    parser.add_argument('--num_samples', type=int, required=True)
    parser.add_argument('--num_steps', type=int, required=True)
    # Exactly two values: ``run`` refuses any other length.
    parser.add_argument('--two_time_steps', nargs=2, type=int, required=True)
    parser.add_argument('--batch_size', type=int, default=50)
    parser.add_argument('--edm_model_keys', nargs='+', type=str, default=None)
    parser.add_argument('--vgg_model_keys', nargs='+', type=str, default=None)
    parser.add_argument('--gan_model_keys', nargs='+', type=str, default=None)
    parser.add_argument('--num_classes', type=int, default=None)
    parser.add_argument('--seed_entropy_noises', type=int, default=None)
    parser.add_argument('--seed_entropy_labels', type=int, default=None)
    parser.add_argument('--num_processes', type=int, default=8)
    parser.add_argument('--device', type=str, default=None)
    return parser.parse_args()


def get_config_from_args(args: argparse.Namespace) -> dict[str, Any]:
    """Convert parsed CLI arguments into keyword arguments for ``run``.

    The command line exposes ``--index`` whereas ``run`` has no ``index``
    parameter and requires ``name``; passing the namespace through unchanged
    makes ``run(**config)`` fail with ``TypeError``. The index is therefore
    moved to ``name`` (as a string) unless a ``name`` entry already exists.

    Args:
        args: Namespace produced by the argument parser.

    Returns:
        A new dict (the namespace itself is left untouched) without an
        ``index`` key.
    """
    # Copy: ``vars`` returns the namespace's own ``__dict__``.
    config: dict[str, Any] = dict(vars(args))
    if 'index' in config:
        index = config.pop('index')
        config.setdefault('name', str(index))
    return config


def main() -> None:
    """Script entry point: parse the CLI, build the config and execute it."""
    run_from_config(get_config_from_args(parse_args()))


# Execute the pipeline only when this file is run as a script, not on import.
if __name__ == '__main__':
    main()
