import argparse
import math
from argparse import ArgumentParser

import numpy as np
import torch
from tqdm import tqdm

from external.edm.training.networks import EDMPrecond
from src.models.models.edm import create_edm_model
from src.utils.load import load_from_state_dict
from src.utils.utils import create_one_hot
from torch_utils.utils import get_default_device
from utils.logger.logger import Logger
from utils.numpy.load import load
from utils.numpy.save import save
from utils.utils import get_class_name

from src.fid.multi_step import multi_step_sampler


def run(
        model_name: str,
        model_load_path: str,
        num_steps: int,
        time_steps: list[int],
        image_height: int,
        image_width: int,
        image_channels: int,
        outputs_folder: str,
        noises_folder: str,
        labels_folder: str,
        num_classes: int = None,
        model_load_keys: list[str] = None,
        noises_start_index: int = 0,
        labels_start_index: int = 0,
        outputs_start_index: int = 0,
        num_samples: int = 50,
        batch_size: int = 50,
        sigma_min: float = 0.002,
        sigma_max: float = 80,
        rho: float = 7,
        s_churn: float = 0,
        s_min: float = 0,
        s_max: float = float('inf'),
        s_noise: float = 1,
        multiply_noises: bool = True,
        same_noise: bool = True,
        num_processes: int = 8,
        device: str = None
) -> None:
    """Run ``multi_step_sampler`` over stored noises in batches and save the outputs.

    The model is built with ``create_edm_model(model_name)``, its weights are
    restored with ``load_from_state_dict`` and it is switched to eval mode on
    ``device``. For every batch the noises (and, for conditional runs, the
    labels) are loaded from disk, passed to ``multi_step_sampler`` and the
    result is written to ``outputs_folder``.

    Args:
        model_name: Name passed to ``create_edm_model``.
        model_load_path: Checkpoint path passed to ``load_from_state_dict``.
        num_steps: Forwarded unchanged to ``multi_step_sampler``.
        time_steps: Forwarded unchanged to ``multi_step_sampler``.
        image_height: Only logged; not otherwise used by this function.
        image_width: Only logged; not otherwise used by this function.
        image_channels: Only logged; not otherwise used by this function.
        outputs_folder: Folder passed to ``save`` for the sampler outputs.
        noises_folder: Folder passed to ``load`` for the input noises.
        labels_folder: Folder passed to ``load`` for the labels; only read
            when ``num_classes`` is not ``None``.
        num_classes: When not ``None`` the run is conditional: labels are
            loaded and converted with ``create_one_hot(labels, num_classes)``.
        model_load_keys: Passed to ``load_from_state_dict``.
        noises_start_index: Offset added to each batch start when loading noises.
        labels_start_index: Offset added to each batch start when loading labels.
        outputs_start_index: Offset added to each batch start when saving outputs.
        num_samples: Total number of samples to process.
        batch_size: Maximum number of samples per batch; the last batch may
            be smaller. Must be positive.
        sigma_min: Forwarded unchanged to ``multi_step_sampler``.
        sigma_max: Forwarded unchanged to ``multi_step_sampler``.
        rho: Forwarded unchanged to ``multi_step_sampler``.
        s_churn: Forwarded unchanged to ``multi_step_sampler``.
        s_min: Forwarded unchanged to ``multi_step_sampler``.
        s_max: Forwarded unchanged to ``multi_step_sampler``.
        s_noise: Forwarded unchanged to ``multi_step_sampler``.
        multiply_noises: Forwarded unchanged to ``multi_step_sampler``.
        same_noise: Forwarded unchanged to ``multi_step_sampler``.
        num_processes: Passed as ``n_processes`` to ``load`` and ``save``.
        device: Target device; ``get_default_device()`` is used when ``None``.

    Raises:
        ValueError: If ``batch_size`` is not a positive number.
    """
    # Fail fast: a zero batch size would otherwise surface as a bare
    # ZeroDivisionError, and a negative one would yield zero batches and
    # still report 'done' without producing any output.
    if batch_size <= 0:
        raise ValueError(f'batch_size must be a positive integer, got {batch_size}')

    Logger.debug(
        f'{get_class_name(run)} - '
        f'model_name: {model_name}, '
        f'model_load_path: {model_load_path}, '
        f'num_steps: {num_steps}, '
        f'time_steps: {time_steps}, '
        f'image_height: {image_height}, '
        f'image_width: {image_width}, '
        f'image_channels: {image_channels}, '
        f'outputs_folder: {outputs_folder}, '
        f'noises_folder: {noises_folder}, '
        f'labels_folder: {labels_folder}, '
        f'num_classes: {num_classes}, '
        f'model_load_keys: {model_load_keys}, '
        f'noises_start_index: {noises_start_index}, '
        f'labels_start_index: {labels_start_index}, '
        f'outputs_start_index: {outputs_start_index}, '
        f'num_samples: {num_samples}, '
        f'batch_size: {batch_size}, '
        f'sigma_min: {sigma_min}, '
        f'sigma_max: {sigma_max}, '
        f'rho: {rho}, '
        f's_churn: {s_churn}, '
        f's_min: {s_min}, '
        f's_max: {s_max}, '
        f's_noise: {s_noise}, '
        f'multiply_noises: {multiply_noises}, '
        f'same_noise: {same_noise}, '
        f'num_processes: {num_processes}, '
        f'device: {device}'
    )

    if device is None:
        device = get_default_device()

    # The presence of a class count is what switches on label conditioning.
    conditional: bool = num_classes is not None
    Logger.debug(f'conditional: {conditional}')

    model: EDMPrecond = load_from_state_dict(
        create_edm_model(model_name), model_load_path, model_load_keys
    ).eval().to(device)

    num_batches: int = math.ceil(num_samples / batch_size)
    Logger.debug(f'num_batches: {num_batches}')

    # NOTE(review): sampling is not wrapped in torch.no_grad(). If
    # multi_step_sampler does not disable autograd itself, an unused graph is
    # built for every batch -- confirm against the sampler before changing.
    for batch_index in tqdm(range(num_batches)):
        start_index: int = batch_index * batch_size
        # Clamp so the final batch does not run past num_samples.
        end_index: int = min(start_index + batch_size, num_samples)
        Logger.debug(f'start_index: {start_index}, end_index: {end_index}')
        current_batch_size: int = end_index - start_index

        noises: np.ndarray = load(
            folder=noises_folder,
            n_samples=current_batch_size,
            start_index=noises_start_index + start_index,
            n_processes=num_processes
        )
        labels = None
        if conditional:
            labels = load(
                folder=labels_folder,
                n_samples=current_batch_size,
                start_index=labels_start_index + start_index,
                n_processes=num_processes
            )

        noises_tensor = torch.from_numpy(noises).to(device)
        class_labels = None
        if conditional:
            class_labels = torch.from_numpy(create_one_hot(labels, num_classes)).to(device)

        outputs: np.ndarray = multi_step_sampler(
            num_steps=num_steps,
            time_steps=time_steps,
            net=model,
            noises=noises_tensor,
            class_labels=class_labels,
            sigma_min=sigma_min,
            sigma_max=sigma_max,
            rho=rho,
            s_churn=s_churn,
            s_min=s_min,
            s_max=s_max,
            s_noise=s_noise,
            multiply_noises=multiply_noises,
            same_noise=same_noise,
            device=device
        ).detach().cpu().numpy()

        save(
            folder=outputs_folder,
            data=outputs,
            start_index=outputs_start_index + start_index,
            n_processes=num_processes
        )

    print('done')


def run_from_config(config: dict) -> None:
    """Call ``run`` with every entry of ``config`` passed as a keyword argument.

    Args:
        config: Mapping from ``run`` parameter names to their values.
    """
    keyword_arguments: dict = dict(config)
    run(**keyword_arguments)


def _parse_bool(value: str) -> bool:
    """Convert a command-line token into a boolean.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is ``True``
    because any non-empty string is truthy, which makes it impossible to turn
    a flag off from the command line.

    Args:
        value: Raw string received from the command line.

    Returns:
        ``True`` for ``true/t/yes/y/1`` and ``False`` for ``false/f/no/n/0``
        (case-insensitive). The empty string maps to ``False``, matching what
        ``bool('')`` returned before.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised token.
    """
    normalized: str = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def parse_args() -> argparse.Namespace:
    """Parse the command-line arguments of the sampling script.

    Every option name matches a parameter of ``run``, so the resulting
    namespace can be expanded directly into that call.

    Returns:
        The namespace produced by ``ArgumentParser.parse_args()``.
    """
    parser: ArgumentParser = ArgumentParser()

    # Required arguments.
    parser.add_argument('--model_name', type=str, required=True)
    parser.add_argument('--model_load_path', type=str, required=True)
    parser.add_argument('--num_steps', type=int, required=True)
    parser.add_argument('--time_steps', type=int, nargs='+', required=True)
    parser.add_argument('--image_height', type=int, required=True)
    parser.add_argument('--image_width', type=int, required=True)
    parser.add_argument('--image_channels', type=int, required=True)
    parser.add_argument('--outputs_folder', type=str, required=True)
    parser.add_argument('--noises_folder', type=str, required=True)
    parser.add_argument('--labels_folder', type=str, required=True)

    # Optional arguments; defaults mirror those of ``run``.
    parser.add_argument('--num_classes', type=int, default=None)
    parser.add_argument('--model_load_keys', type=str, nargs='+', default=None)
    parser.add_argument('--noises_start_index', type=int, default=0)
    parser.add_argument('--labels_start_index', type=int, default=0)
    parser.add_argument('--outputs_start_index', type=int, default=0)
    parser.add_argument('--num_samples', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=50)
    parser.add_argument('--sigma_min', type=float, default=0.002)
    parser.add_argument('--sigma_max', type=float, default=80)
    parser.add_argument('--rho', type=float, default=7)
    parser.add_argument('--s_churn', type=float, default=0)
    parser.add_argument('--s_min', type=float, default=0)
    parser.add_argument('--s_max', type=float, default=float('inf'))
    parser.add_argument('--s_noise', type=float, default=1)
    # Booleans need an explicit converter; see _parse_bool for the reason.
    parser.add_argument('--multiply_noises', type=_parse_bool, default=True)
    parser.add_argument('--same_noise', type=_parse_bool, default=True)
    parser.add_argument('--num_processes', type=int, default=8)
    parser.add_argument('--device', type=str, default=None)
    return parser.parse_args()


def get_config_from_args(args: argparse.Namespace) -> dict:
    """Expose the parsed arguments as a dictionary.

    Args:
        args: Namespace of parsed command-line arguments.

    Returns:
        The attribute dictionary of ``args`` as returned by ``vars`` (the
        namespace's own dictionary, not a copy).
    """
    config: dict = vars(args)
    return config


def main() -> None:
    """Script entry point: parse the CLI arguments, convert them to a config and run."""
    run_from_config(get_config_from_args(parse_args()))


# Only start sampling when the file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
