import argparse
from typing import Callable, Any, Optional

import numpy as np
import torch
from tqdm import tqdm

from src.models.models.edm import create_edm_model
from src.utils.edm import get_sigmas
from src.utils.load import load_from_state_dict
from src.utils.utils import create_one_hot
from torch_utils.utils import get_default_device
from utils.logger.logger import Logger
from utils.numpy.load import load
from utils.numpy.save import save
from utils.numpy.stats import get_numpy_stats
from utils.utils import get_object_name, get_class_name


@torch.inference_mode()
def multi_step_euler_sampler(
        num_steps: int,
        time_steps: list[int],
        net: torch.nn.Module,
        noises: np.ndarray,
        class_labels: Optional[np.ndarray] = None,
        randn_like: Callable[[np.ndarray], np.ndarray] = lambda x: np.random.randn(*x.shape),
        sigma_min: float = 0.002,
        sigma_max: float = 80,
        rho: float = 7,
        s_churn: float = 0,
        s_min: float = 0,
        s_max: float = float('inf'),
        s_noise: float = 1,
        multiply_noises: bool = True,
        device: Optional[str] = None
) -> np.ndarray:
    """Run Euler steps between selected indices of a sigma schedule.

    ``get_sigmas(num_steps, sigma_min, sigma_max, rho)`` builds the schedule.
    ``time_steps`` selects which schedule indices are visited, and one Euler
    step is taken between each consecutive pair of selected indices.

    When ``s_churn > 0`` and the current sigma lies in ``[s_min, s_max]``,
    noise is first added to raise the noise level from ``t_cur`` to
    ``t_hat = (1 + gamma) * t_cur``.

    Args:
        num_steps: Number of steps passed to ``get_sigmas``; must equal
            ``time_steps[-1]``.
        time_steps: Schedule indices to visit. Must start at 0 and end at
            ``num_steps``.
        net: Called as ``net(x, sigma, class_labels)``. Its output is used as
            the denoised estimate of ``x`` in the Euler derivative.
        noises: Initial noise batch. Its first axis is used as the batch size.
        class_labels: Optional conditioning array forwarded to ``net`` as a
            tensor. When ``None``, ``None`` is forwarded (assumes ``net``
            supports unconditional calls; confirm for the model in use).
        randn_like: Returns standard-normal noise shaped like its argument.
        sigma_min: Forwarded to ``get_sigmas``.
        sigma_max: Forwarded to ``get_sigmas``.
        rho: Forwarded to ``get_sigmas``.
        s_churn: Churn strength; ``gamma = min(s_churn / num_steps, sqrt(2) - 1)``.
        s_min: Lower sigma bound for applying churn.
        s_max: Upper sigma bound for applying churn.
        s_noise: Scale applied to the churn noise.
        multiply_noises: If True, ``noises`` is scaled by ``sigmas[0]`` first.
        device: Torch device for network calls. Defaults to
            ``get_default_device()``.

    Returns:
        The final sample after the last step, as a float64 array.

    Raises:
        AssertionError: If ``time_steps`` does not start at 0 or does not end
            at ``num_steps``.
    """
    Logger.debug(
        f'{get_class_name(multi_step_euler_sampler)} - '
        f'num_steps: {num_steps}, '
        f'time_steps: {time_steps}, '
        f'net: {get_object_name(net)}, '
        f'noises: {get_numpy_stats(noises)}, '
        f'class_labels: {get_numpy_stats(class_labels) if class_labels is not None else None}, '
        f'multiply_noises: {multiply_noises}, '
        f'sigma_min: {sigma_min}, '
        f'sigma_max: {sigma_max}, '
        f'rho: {rho}, '
        f's_churn: {s_churn}, '
        f's_min: {s_min}, '
        f's_max: {s_max}, '
        f's_noise: {s_noise}, '
        f'device: {device}'
    )
    assert time_steps[0] == 0, 'the first time step must be 0'
    assert time_steps[-1] == num_steps, 'the last time step must be the number of steps'

    if device is None:
        device = get_default_device()

    sigmas: np.ndarray = get_sigmas(
        num_steps=num_steps,
        sigma_min=sigma_min,
        sigma_max=sigma_max,
        rho=rho
    )

    batch_size: int = noises.shape[0]

    # Labels are identical for every step, so convert and move them only once.
    # Unconditional sampling forwards None instead of calling
    # torch.from_numpy(None), which raises.
    labels_tensor: Optional[torch.Tensor] = (
        torch.from_numpy(class_labels).to(device) if class_labels is not None else None
    )

    # Work in float64 throughout, matching the cast applied to the network output.
    x_next: np.ndarray = noises.astype(np.float64)
    if multiply_noises:
        x_next = x_next * sigmas[0]

    for step_index, next_step_index in zip(time_steps[:-1], time_steps[1:]):
        t_cur = sigmas[step_index]
        t_next = sigmas[next_step_index]

        x_cur: np.ndarray = x_next

        # Churn: temporarily raise the noise level to t_hat = (1 + gamma) * t_cur.
        # randn_like is called even when gamma == 0, so any RNG it uses advances
        # on every step.
        gamma: float = min(s_churn / num_steps, np.sqrt(2) - 1) if s_min <= t_cur <= s_max else 0
        t_hat: np.ndarray = np.asarray(t_cur + gamma * t_cur)
        x_hat: np.ndarray = x_cur + np.sqrt(t_hat ** 2 - t_cur ** 2) * s_noise * randn_like(x_cur)

        if np.max(np.abs(x_hat - x_cur)) != 0:
            Logger.warning('warning: sampling is not deterministic.')

        denoised: np.ndarray = net(
            torch.from_numpy(x_hat).to(device),
            torch.from_numpy(np.ones(batch_size) * t_hat).to(device),
            labels_tensor
        ).to(torch.float64).detach().cpu().numpy()

        # Euler step with derivative d = (x_hat - denoised) / t_hat.
        d_cur: np.ndarray = (x_hat - denoised) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur

    return x_next


@torch.inference_mode()
def inference_multi_step_euler_batch(
        num_steps: int,
        time_steps: list[int],
        model: torch.nn.Module,
        noises: np.ndarray,
        num_classes: Optional[int] = None,
        labels: Optional[np.ndarray] = None,
        randn_like: Callable[[np.ndarray], np.ndarray] = lambda x: np.random.randn(*x.shape),
        sigma_min: float = 0.002,
        sigma_max: float = 80,
        rho: float = 7,
        s_churn: float = 0,
        s_min: float = 0,
        s_max: float = float('inf'),
        s_noise: float = 1,
        multiply_noises: bool = True,
        device: Optional[str] = None
) -> np.ndarray:
    """Sample one batch, one-hot encoding ``labels`` before calling the sampler.

    Args:
        num_steps: Forwarded to ``multi_step_euler_sampler``.
        time_steps: Forwarded to ``multi_step_euler_sampler``.
        model: Network passed as ``net`` to the sampler.
        noises: Initial noise batch.
        num_classes: Number of classes given to ``create_one_hot``. Only used
            when ``labels`` is not None.
        labels: Optional labels (presumably class indices; confirm against
            ``create_one_hot``). When None, sampling is unconditional.
        randn_like: Forwarded to the sampler.
        sigma_min: Forwarded to the sampler.
        sigma_max: Forwarded to the sampler.
        rho: Forwarded to the sampler.
        s_churn: Forwarded to the sampler.
        s_min: Forwarded to the sampler.
        s_max: Forwarded to the sampler.
        s_noise: Forwarded to the sampler.
        multiply_noises: Forwarded to the sampler.
        device: Forwarded to the sampler.

    Returns:
        The sampler's output array for this batch.
    """
    Logger.debug(
        f'{get_class_name(inference_multi_step_euler_batch)} - '
        f'num_steps: {num_steps}, '
        f'time_steps: {time_steps}, '
        f'model: {get_object_name(model)}, '
        f'noises: {get_numpy_stats(noises)}, '
        f'num_classes: {num_classes}, '
        # Guard against None like the sampler does: the f-string is built eagerly,
        # so an unguarded stats call would run even for unconditional sampling.
        f'labels: {get_numpy_stats(labels) if labels is not None else None}, '
        f'multiply_noises: {multiply_noises}, '
        f'sigma_min: {sigma_min}, '
        f'sigma_max: {sigma_max}, '
        f'rho: {rho}, '
        f's_churn: {s_churn}, '
        f's_min: {s_min}, '
        f's_max: {s_max}, '
        f's_noise: {s_noise}, '
        f'device: {device}'
    )
    class_labels: Optional[np.ndarray] = (
        create_one_hot(labels, num_classes) if labels is not None else None
    )
    return multi_step_euler_sampler(
        num_steps=num_steps,
        time_steps=time_steps,
        net=model,
        noises=noises,
        class_labels=class_labels,
        randn_like=randn_like,
        sigma_min=sigma_min,
        sigma_max=sigma_max,
        rho=rho,
        s_churn=s_churn,
        s_min=s_min,
        s_max=s_max,
        s_noise=s_noise,
        multiply_noises=multiply_noises,
        device=device
    )


def inference_multi_step_euler(
        num_steps: int,
        time_steps: list[int],
        model: torch.nn.Module,
        noises_folder: str,
        labels_folder: str,
        output_folder: str,
        num_samples: int,
        batch_size: int,
        num_classes: Optional[int] = None,
        start_noise_index: int = 0,
        start_label_index: int = 0,
        start_output_index: int = 0,
        randn_like: Callable[[np.ndarray], np.ndarray] = lambda x: np.random.randn(*x.shape),
        sigma_min: float = 0.002,
        sigma_max: float = 80,
        rho: float = 7,
        s_churn: float = 0,
        s_min: float = 0,
        s_max: float = float('inf'),
        s_noise: float = 1,
        multiply_noises: bool = True,
        n_processes: int = 8,
        device: str = None
) -> None:
    """Sample ``num_samples`` outputs batch by batch and save them to disk.

    The sample range ``[0, num_samples)`` is cut into chunks of at most
    ``batch_size``. For the chunk starting at offset ``o``:

    - noises are loaded from ``noises_folder`` starting at ``start_noise_index + o``;
    - labels are loaded from ``labels_folder`` starting at
      ``start_label_index + o``, but only when ``num_classes`` is given;
    - the sampled batch is saved to ``output_folder`` starting at
      ``start_output_index + o``.

    All sampling parameters are forwarded to
    ``inference_multi_step_euler_batch``. ``n_processes`` is forwarded to the
    ``load`` and ``save`` helpers.
    """
    header = f'{get_class_name(inference_multi_step_euler)} - '
    settings = (
        ('num_steps', num_steps),
        ('time_steps', time_steps),
        ('model', get_object_name(model)),
        ('noises_folder', noises_folder),
        ('labels_folder', labels_folder),
        ('output_folder', output_folder),
        ('num_samples', num_samples),
        ('batch_size', batch_size),
        ('num_classes', num_classes),
        ('start_noise_index', start_noise_index),
        ('start_label_index', start_label_index),
        ('start_output_index', start_output_index),
        ('sigma_min', sigma_min),
        ('sigma_max', sigma_max),
        ('rho', rho),
        ('s_churn', s_churn),
        ('s_min', s_min),
        ('s_max', s_max),
        ('s_noise', s_noise),
        ('multiply_noises', multiply_noises),
        ('n_processes', n_processes),
        ('device', device),
    )
    print(header + ', '.join(f'{key}: {value}' for key, value in settings))

    # Everything the per-batch sampler needs except the batch data itself.
    shared_kwargs = dict(
        num_steps=num_steps,
        time_steps=time_steps,
        model=model,
        num_classes=num_classes,
        randn_like=randn_like,
        sigma_min=sigma_min,
        sigma_max=sigma_max,
        rho=rho,
        s_churn=s_churn,
        s_min=s_min,
        s_max=s_max,
        s_noise=s_noise,
        multiply_noises=multiply_noises,
        device=device,
    )

    total_batches: int = int(np.ceil(num_samples / batch_size))
    for batch_number in tqdm(range(total_batches)):
        offset = batch_number * batch_size
        # The final chunk may be shorter than batch_size.
        count = min(offset + batch_size, num_samples) - offset

        noises = load(
            folder=noises_folder,
            n_samples=count,
            start_index=offset + start_noise_index,
            n_processes=n_processes
        )

        labels = None
        if num_classes is not None:
            labels = load(
                folder=labels_folder,
                n_samples=count,
                start_index=offset + start_label_index,
                n_processes=n_processes
            )

        samples = inference_multi_step_euler_batch(noises=noises, labels=labels, **shared_kwargs)

        save(
            folder=output_folder,
            data=samples,
            start_index=offset + start_output_index,
            n_processes=n_processes
        )


def run(
        num_steps: int,
        time_steps: list[int],
        model_name: str,
        model_load_path: str,
        noises_folder: str,
        labels_folder: str,
        output_folder: str,
        num_samples: int,
        batch_size: int,
        num_classes: Optional[int] = None,
        start_noise_index: int = 0,
        start_label_index: int = 0,
        start_output_index: int = 0,
        sigma_min: float = 0.002,
        sigma_max: float = 80,
        rho: float = 7,
        s_churn: float = 0,
        s_min: float = 0,
        s_max: float = float('inf'),
        s_noise: float = 1,
        model_load_keys: list[str] = None,
        multiply_noises: bool = True,
        n_processes: int = 8,
        device: str = None
) -> None:
    """Build a model from a checkpoint and run batched multi-step Euler sampling.

    The model is created with ``create_edm_model(model_name)``. Weights are
    loaded through ``load_from_state_dict`` using ``model_load_path`` and
    ``model_load_keys``. The model is then put in eval mode and moved to
    ``device``, which defaults to ``get_default_device()``. All remaining
    arguments are forwarded to ``inference_multi_step_euler``.
    """
    header = f'{get_class_name(run)} - '
    settings = (
        ('num_steps', num_steps),
        ('time_steps', time_steps),
        ('model_name', model_name),
        ('model_load_path', model_load_path),
        ('noises_folder', noises_folder),
        ('labels_folder', labels_folder),
        ('output_folder', output_folder),
        ('num_samples', num_samples),
        ('batch_size', batch_size),
        ('num_classes', num_classes),
        ('start_noise_index', start_noise_index),
        ('start_label_index', start_label_index),
        ('start_output_index', start_output_index),
        ('sigma_min', sigma_min),
        ('sigma_max', sigma_max),
        ('rho', rho),
        ('s_churn', s_churn),
        ('s_min', s_min),
        ('s_max', s_max),
        ('s_noise', s_noise),
        ('model_load_keys', model_load_keys),
        ('multiply_noises', multiply_noises),
        ('n_processes', n_processes),
        ('device', device),
    )
    print(header + ', '.join(f'{key}: {value}' for key, value in settings))

    device = device if device is not None else get_default_device()

    conditional: bool = num_classes is not None
    Logger.debug(f'conditional: {conditional}')

    Logger.debug('creating model')
    edm_model: torch.nn.Module = load_from_state_dict(
        create_edm_model(model_name),
        load_path=model_load_path,
        load_keys=model_load_keys,
    )
    edm_model.eval()
    edm_model.to(device)

    inference_multi_step_euler(
        model=edm_model,
        num_steps=num_steps,
        time_steps=time_steps,
        noises_folder=noises_folder,
        labels_folder=labels_folder,
        output_folder=output_folder,
        num_samples=num_samples,
        batch_size=batch_size,
        num_classes=num_classes,
        start_noise_index=start_noise_index,
        start_label_index=start_label_index,
        start_output_index=start_output_index,
        sigma_min=sigma_min,
        sigma_max=sigma_max,
        rho=rho,
        s_churn=s_churn,
        s_min=s_min,
        s_max=s_max,
        s_noise=s_noise,
        multiply_noises=multiply_noises,
        n_processes=n_processes,
        device=device,
    )


def run_from_config(config: dict[str, Any]) -> None:
    """Call ``run`` with the entries of ``config`` as keyword arguments.

    Every key must be a parameter name of ``run``. An unknown or missing key
    surfaces as the ``TypeError`` raised by the call itself.
    """
    return run(**config)


def _str_to_bool(value: str) -> bool:
    """Convert a command-line string such as ``True``/``false``/``1``/``no`` to a bool.

    ``argparse`` with ``type=bool`` turns every non-empty string, including
    ``"False"``, into ``True``. This converter is used instead so the flag can
    actually be switched off.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognized boolean.
    """
    normalized = value.strip().lower()
    if normalized in {'true', 't', 'yes', 'y', '1'}:
        return True
    if normalized in {'false', 'f', 'no', 'n', '0'}:
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def parse_args() -> argparse.Namespace:
    """Parse the sampling command-line arguments from ``sys.argv``.

    Each option name matches a keyword parameter of ``run``, so
    ``vars(parse_args())`` can be passed to it directly.
    """
    parser: argparse.ArgumentParser = argparse.ArgumentParser()
    parser.add_argument('--num_steps', type=int, required=True)
    parser.add_argument('--time_steps', type=int, nargs='+', required=True)
    parser.add_argument('--model_name', type=str, required=True)
    parser.add_argument('--model_load_path', type=str, required=True)
    parser.add_argument('--noises_folder', type=str, required=True)
    parser.add_argument('--labels_folder', type=str, required=True)
    parser.add_argument('--output_folder', type=str, required=True)
    parser.add_argument('--num_samples', type=int, required=True)
    parser.add_argument('--batch_size', type=int, required=True)
    parser.add_argument('--num_classes', type=int, default=None)
    parser.add_argument('--start_noise_index', type=int, default=0)
    parser.add_argument('--start_label_index', type=int, default=0)
    parser.add_argument('--start_output_index', type=int, default=0)
    parser.add_argument('--sigma_min', type=float, default=0.002)
    parser.add_argument('--sigma_max', type=float, default=80)
    parser.add_argument('--rho', type=float, default=7)
    parser.add_argument('--s_churn', type=float, default=0)
    parser.add_argument('--s_min', type=float, default=0)
    parser.add_argument('--s_max', type=float, default=float('inf'))
    parser.add_argument('--s_noise', type=float, default=1)
    parser.add_argument('--model_load_keys', type=str, nargs='+', default=None)
    # type=bool would map "--multiply_noises False" to True; parse the text instead.
    parser.add_argument('--multiply_noises', type=_str_to_bool, default=True)
    parser.add_argument('--n_processes', type=int, default=8)
    parser.add_argument('--device', type=str, default=None)
    return parser.parse_args()


def get_config_from_args(args: argparse.Namespace) -> dict[str, Any]:
    """Return the parsed arguments as a plain ``{name: value}`` dict.

    A shallow copy is returned. Mutating the config therefore does not
    silently change the ``Namespace`` it came from, which ``vars(args)`` alone
    would do, since it returns the namespace's own ``__dict__``.
    """
    return dict(vars(args))


def main() -> None:
    """Command-line entry point: parse ``sys.argv`` and run sampling with it."""
    run_from_config(get_config_from_args(parse_args()))


if __name__ == '__main__':
    main()
