from typing import Optional, Callable

import numpy as np
import torch.nn
from torch import inference_mode

from src.utils.edm import get_sigmas
from src.utils.utils import create_one_hot_torch
from torch_utils.stats import get_torch_stats
from torch_utils.utils import get_default_device
from utils.logger.logger import Logger
from utils.utils import get_class_name, get_object_name


@torch.inference_mode()
def multi_step_sampler(
        num_steps: int,
        time_steps: list[int],
        net: torch.nn.Module,
        noises: torch.Tensor,
        class_labels: torch.Tensor = None,
        randn_like: Callable[[torch.Tensor], torch.Tensor] = torch.randn_like,
        sigma_min: float = 0.002,
        sigma_max: float = 80,
        rho: float = 7,
        s_churn: float = 0,
        s_min: float = 0,
        s_max: float = float('inf'),
        s_noise: float = 1,
        multiply_noises: bool = True,
        same_noise: bool = True,
        device: str = None
) -> torch.Tensor:
    """Sample by alternating denoising and re-noising on selected sigma levels.

    A sigma schedule is built with ``get_sigmas``. It is indexed at
    ``num_steps``, so it must hold at least ``num_steps + 1`` entries;
    presumably it is the EDM/Karras schedule (confirm in ``src.utils.edm``).
    The sampler walks consecutive pairs ``(time_steps[i], time_steps[i + 1])``
    of indices into that schedule. For each pair it:

    1. Optionally "churns": when ``s_min <= sigma <= s_max``, it inflates sigma
       by ``1 + gamma`` (with ``gamma = min(s_churn / num_steps, sqrt(2) - 1)``)
       and adds matching Gaussian noise scaled by ``s_noise``. A warning is
       logged whenever this perturbs the sample.
    2. Calls ``net(sample, sigma_per_sample, class_labels)`` and casts the
       result to float64.
    3. Re-noises the denoised estimate at the next sigma, using either the
       original ``noises`` (``same_noise``) or a fresh ``randn_like`` draw.

    Args:
        num_steps: Number of steps the sigma schedule is built for.
        time_steps: Schedule indices to visit. Must start at 0 and end at
            ``num_steps``.
        net: Denoiser called as ``net(x, sigmas, class_labels)``.
        noises: Initial noise. Its first dimension is the batch size.
        class_labels: Conditioning passed through to ``net`` unchanged.
        randn_like: Noise generator used for churn and fresh re-noising.
        sigma_min, sigma_max, rho: Forwarded to ``get_sigmas``.
        s_churn, s_min, s_max, s_noise: Stochastic churn controls.
        multiply_noises: Scale ``noises`` by the first sigma before sampling.
        same_noise: Re-use ``noises`` for every re-noising step. Only allowed
            together with ``multiply_noises``.
        device: Device for the sigma scalars and the per-sample sigma batch.
            Defaults to ``get_default_device()``.

    Returns:
        The float64 sample after the last visited sigma.

    Raises:
        AssertionError: On inconsistent ``multiply_noises``/``same_noise`` or a
            malformed ``time_steps`` (note: asserts are stripped under ``-O``).
    """
    header = f'{get_class_name(multi_step_sampler)} - '
    details = ', '.join((
        f'num_steps: {num_steps}',
        f'time_steps: {time_steps}',
        f'net: {get_object_name(net)}',
        f'noises: {get_torch_stats(noises)}',
        f'class_labels: {get_torch_stats(class_labels) if class_labels is not None else None}',
        f'multiply_noises: {multiply_noises}',
        f'sigma_min: {sigma_min}',
        f'sigma_max: {sigma_max}',
        f'rho: {rho}',
        f's_churn: {s_churn}',
        f's_min: {s_min}',
        f's_max: {s_max}',
        f's_noise: {s_noise}',
        f'same_noise: {same_noise}',
        f'device: {device}',
    ))
    Logger.debug(header + details)

    assert multiply_noises or not same_noise, 'same_noise is only relevant when multiply_noises is True'
    assert time_steps[0] == 0, 'the first time step must be 0'
    assert time_steps[-1] == num_steps, 'the last time step must be the number of steps'

    if device is None:
        device = get_default_device()

    schedule = torch.from_numpy(get_sigmas(
        num_steps=num_steps,
        sigma_min=sigma_min,
        sigma_max=sigma_max,
        rho=rho
    ))
    n_samples = noises.shape[0]

    # Work in float64 throughout. Optionally start at the top noise level.
    sample = noises.to(torch.float64)
    if multiply_noises:
        sample = sample * schedule[0]

    for idx_from, idx_to in zip(time_steps[:-1], time_steps[1:]):
        sigma_from = schedule[idx_from].to(device)
        sigma_to = schedule[idx_to].to(device)
        previous = sample

        if s_min <= sigma_from <= s_max:
            churn = min(s_churn / num_steps, np.sqrt(2) - 1)
        else:
            churn = 0
        sigma_hat = torch.as_tensor(sigma_from + churn * sigma_from)
        # randn_like is drawn even when churn == 0 (the factor is then zero),
        # so the RNG stream consumed per step does not depend on the churn.
        perturbed = previous + torch.sqrt(sigma_hat ** 2 - sigma_from ** 2) * s_noise * randn_like(previous)

        if torch.max(torch.abs(perturbed - previous)) != 0:
            Logger.warning('warning: sampling is not deterministic.')

        sigma_batch = torch.ones(n_samples).to(device) * sigma_hat
        denoised = net(perturbed, sigma_batch, class_labels).to(torch.float64)

        # The fresh draw (if any) happens after the network call, as before.
        renoise = noises if same_noise else randn_like(previous)
        sample = denoised + sigma_to * renoise

    return sample


@inference_mode()
def inference_multi_step(
        model: torch.nn.Module,
        num_steps: int,
        time_steps: list[int],
        noises: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        num_classes: Optional[int] = None,
        sigma_min: float = 0.002,
        sigma_max: float = 80,
        rho: float = 7,
        s_churn: float = 0,
        s_min: float = 0,
        s_max: float = float('inf'),
        s_noise: float = 1,
        multiply_noises: bool = True,
        same_noise: bool = True,
        device: Optional[str] = None
) -> torch.Tensor:
    """Run ``multi_step_sampler`` with ``model`` on one resolved device.

    Integer ``labels`` (if given) are turned into class conditioning with
    ``create_one_hot_torch(labels, num_classes)``. Presumably this yields a
    one-hot tensor of width ``num_classes``; confirm in ``src.utils.utils``.
    ``noises`` and the conditioning are moved to ``device``, and the same
    ``device`` is forwarded to the sampler. Every tensor the sampler builds
    therefore lives on the same device.

    Args:
        model: Denoiser passed to the sampler as ``net``.
        num_steps: Number of steps in the sigma schedule.
        time_steps: Schedule indices to visit (first 0, last ``num_steps``).
        noises: Initial noise tensor.
        labels: Optional class labels. Requires ``num_classes``.
        num_classes: Number of classes for the one-hot encoding.
        sigma_min, sigma_max, rho: Sigma schedule parameters.
        s_churn, s_min, s_max, s_noise: Stochastic churn controls.
        multiply_noises: Scale ``noises`` by the first sigma.
        same_noise: Re-use ``noises`` for every re-noising step.
        device: Target device. Defaults to ``get_default_device()``.

    Returns:
        The sampler's output tensor.

    Raises:
        AssertionError: If ``labels`` is given without ``num_classes``.
    """
    assert labels is None or num_classes is not None, 'labels must be None or num_classes must not be None'
    if device is None:
        device = get_default_device()

    class_labels = create_one_hot_torch(labels, num_classes).to(device) if labels is not None else None

    return multi_step_sampler(
        num_steps=num_steps,
        time_steps=time_steps,
        net=model,
        noises=noises.to(device),
        class_labels=class_labels,
        sigma_min=sigma_min,
        sigma_max=sigma_max,
        rho=rho,
        s_churn=s_churn,
        s_min=s_min,
        s_max=s_max,
        s_noise=s_noise,
        multiply_noises=multiply_noises,
        same_noise=same_noise,
        # Forward the resolved device. Without it, the sampler falls back to
        # get_default_device() and places its sigma tensors on a device that
        # may differ from the one noises/class_labels were just moved to.
        device=device
    )
