from functools import partial

import numpy as np
import torch

from .utils import expand_tensor_dims_as, get_register_fn

# Registry of available losses, keyed by name; `dispatch_loss` looks entries up
# here and calls them with no arguments.
_LOSSES = {}
# Used below as a decorator factory: `@register_loss(name="...")`.
# NOTE(review): presumably stores the decorated class in `_LOSSES` under
# `name` -- confirm against `get_register_fn` in `.utils`.
register_loss = get_register_fn(_LOSSES)


class LossBase:
    """Base class for weighted denoising losses.

    Subclasses provide the noise-level distribution (``sample_sigma``) and the
    per-noise-level weight (``compute_weight``). Calling an instance perturbs
    the clean batch with Gaussian noise, runs the network on the noisy batch,
    and returns the weighted squared error against the clean batch.
    """

    def __init__(self):
        pass

    def __call__(self, net, x0, model_kwargs=None):
        """Compute the element-wise weighted denoising loss for a batch.

        Args:
            net (callable): Invoked as ``net(x0 + noise, sigma, **model_kwargs)``.
            x0 (torch.Tensor): Clean batch; ``len(x0)`` is used as batch size.
            model_kwargs (dict, optional): Extra keyword arguments forwarded to
                ``net``. ``None`` (the default) means no extra arguments.

        Returns:
            torch.Tensor: ``weight * (net_output - x0) ** 2``, not reduced.
        """
        # A literal ``{}`` default would be one dict shared across all calls;
        # use None as the sentinel and build a fresh dict per call instead.
        if model_kwargs is None:
            model_kwargs = {}
        sigma = self.sample_sigma(len(x0), x0.device)  # Implemented by subclasses
        weight = self.compute_weight(sigma)  # Implemented by subclasses
        # NOTE(review): expand_tensor_dims_as presumably reshapes a per-sample
        # tensor so it broadcasts against x0 -- confirm in `.utils`.
        weight = expand_tensor_dims_as(weight, x0)
        net_forward = partial(net, **model_kwargs)
        y = x0
        # Standard normal noise scaled by each sample's noise level.
        n = torch.randn_like(y) * expand_tensor_dims_as(sigma, y)
        D_yn = net_forward(y + n, sigma)
        loss = weight * ((D_yn - y) ** 2)
        return loss

    def sample_sigma(self, B, device):
        """Sample a batch of noise levels.

        Args:
            B (int): Batch size.
            device (str or torch.device): Device to generate samples on.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError

    def compute_weight(self, sigma):
        """Compute the weight for each noise level.

        Args:
            sigma (torch.Tensor): Noise levels.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError


@register_loss(name="vp")
class VPLoss(LossBase):
    """Loss registered as ``"vp"``.

    Noise levels follow ``sigma(t) = sqrt(exp(0.5 * beta_d * t^2 + beta_min * t) - 1)``
    with ``t = 1 + u * (epsilon_t - 1)`` for uniform ``u``; the weight is
    ``1 / sigma^2``.
    """

    def __init__(self, beta_d=19.9, beta_min=0.1, epsilon_t=1e-5):
        super().__init__()
        self.beta_d, self.beta_min, self.epsilon_t = beta_d, beta_min, epsilon_t

    def sample_sigma(self, B, device):
        """Draw ``B`` noise levels on ``device`` by mapping uniform samples through ``sigma``."""
        u = torch.rand(B, device=device)
        t = 1 + u * (self.epsilon_t - 1)
        return self.sigma(t)

    def compute_weight(self, sigma):
        """Return the inverse squared noise level."""
        sigma_sq = sigma**2
        return 1 / sigma_sq

    def sigma(self, t):
        """Map time value(s) ``t`` (tensor or anything ``torch.as_tensor`` accepts) to noise level(s)."""
        t = torch.as_tensor(t)
        exponent = 0.5 * self.beta_d * (t**2) + self.beta_min * t
        return (exponent.exp() - 1).sqrt()


@register_loss(name="ve")
class VELoss(LossBase):
    """Loss registered as ``"ve"``.

    Noise levels are log-uniform between ``sigma_min`` and ``sigma_max``; the
    weight is ``1 / sigma^2``.
    """

    def __init__(self, sigma_min=0.02, sigma_max=100):
        # Run the base initializer like every other loss in this module, so
        # any state added to LossBase later is also set up here.
        super().__init__()
        self.log_sigma_min = np.log(sigma_min)
        self.log_sigma_max = np.log(sigma_max)

    def sample_sigma(self, B, device):
        """Sample ``B`` noise levels on ``device``.

        Interpolates uniformly between ``log(sigma_min)`` and
        ``log(sigma_max)`` and exponentiates the result.
        """
        rnd_uniform = torch.rand(B, device=device)
        return (self.log_sigma_min + rnd_uniform * (self.log_sigma_max - self.log_sigma_min)).exp()

    def compute_weight(self, sigma):
        """Return the inverse squared noise level."""
        return 1 / sigma ** 2


@register_loss(name="edm")
class EDMLoss(LossBase):
    """Loss registered as ``"edm"``.

    Noise levels are log-normal: ``sigma = exp(P_mean + P_std * z)`` for
    standard normal ``z``. The weight is
    ``(sigma^2 + sigma_data^2) / (sigma * sigma_data)^2``.
    """

    def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5):
        super().__init__()
        self.P_mean, self.P_std = P_mean, P_std
        self.sigma_data = sigma_data

    def sample_sigma(self, B, device):
        """Draw ``B`` log-normally distributed noise levels on ``device``."""
        z = torch.randn(B, device=device)
        log_sigma = z * self.P_std + self.P_mean
        return log_sigma.exp()

    def compute_weight(self, sigma):
        """Return the per-noise-level weight described in the class docstring."""
        numerator = sigma**2 + self.sigma_data**2
        denominator = (sigma * self.sigma_data) ** 2
        return numerator / denominator


@register_loss(name="edm_wide")
class EDMWideLoss(EDMLoss):
    """``EDMLoss`` with a larger default ``P_std`` (1.8 instead of 1.2).

    The parameters are exposed so the variant can be tuned; constructing it
    with no arguments gives the same configuration as before.
    """

    def __init__(self, P_mean=-1.2, P_std=1.8, sigma_data=0.5):
        super().__init__(P_mean=P_mean, P_std=P_std, sigma_data=sigma_data)


@register_loss(name="edm_ve")
class EDMVELoss(LossBase):
    """Loss registered as ``"edm_ve"``.

    Noise levels are log-uniform between ``sigma_min`` and ``sigma_max``; the
    weight is ``(sigma^2 + sigma_data^2) / (sigma * sigma_data)^2``.
    """

    def __init__(self, sigma_min=0.002, sigma_max=80, sigma_data=0.5):
        super().__init__()
        self.sigma_data = sigma_data
        self.log_sigma_min, self.log_sigma_max = np.log(sigma_min), np.log(sigma_max)

    def sample_sigma(self, B, device):
        """Draw ``B`` log-uniformly distributed noise levels on ``device``."""
        u = torch.rand(B, device=device)
        log_range = self.log_sigma_max - self.log_sigma_min
        log_sigma = self.log_sigma_min + u * log_range
        return log_sigma.exp()

    def compute_weight(self, sigma):
        """Return the per-noise-level weight described in the class docstring."""
        numerator = sigma**2 + self.sigma_data**2
        denominator = (sigma * self.sigma_data) ** 2
        return numerator / denominator


@register_loss(name="edm_ve_wide")
class EDMVEWideLoss(EDMVELoss):
    """``EDMVELoss`` with a smaller default ``sigma_min`` (0.0001 instead of 0.002).

    The parameters are exposed so the variant can be tuned; constructing it
    with no arguments gives the same configuration as before.
    """

    def __init__(self, sigma_min=0.0001, sigma_max=80, sigma_data=0.5):
        super().__init__(sigma_min=sigma_min, sigma_max=sigma_max, sigma_data=sigma_data)


@register_loss(name="edm_discrete")
class EDMDiscreteLoss(LossBase):
    """EDM loss with VE noise levels and discrete timesteps.

    A table of ``num_steps`` noise levels is precomputed as
    ``(sigma_max^(1/rho) + i / (num_steps - 1) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho``
    for ``i = 0 .. num_steps - 1``, so entry 0 is ``sigma_max`` and the last
    entry is ``sigma_min``. Sampling picks table entries uniformly at random.
    """

    def __init__(self, sigma_min=0.002, sigma_max=80, sigma_data=0.5, rho=7, num_steps=100):
        """Precompute the noise-level table.

        Raises:
            ValueError: If ``num_steps`` is less than 1, since an empty table
                cannot be sampled from.
        """
        super().__init__()
        if num_steps < 1:
            raise ValueError(f"num_steps must be at least 1, got {num_steps}")
        self.sigma_data = sigma_data
        step_indices = torch.arange(num_steps)
        # With a single step the original denominator (num_steps - 1) is zero
        # and 0 / 0 yields NaN; clamp it so the table is just [sigma_max].
        denom = max(num_steps - 1, 1)
        max_root = sigma_max ** (1 / rho)
        min_root = sigma_min ** (1 / rho)
        self.sigma_steps = (max_root + step_indices / denom * (min_root - max_root)) ** rho

    def sample_sigma(self, B, device):
        """Pick ``B`` table entries uniformly at random and move them to ``device``."""
        rnd = torch.randint(low=0, high=len(self.sigma_steps), size=(B,))
        return self.sigma_steps[rnd].to(device)

    def compute_weight(self, sigma):
        """Return ``(sigma^2 + sigma_data^2) / (sigma * sigma_data)^2``."""
        return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2


def dispatch_loss(loss_config):
    """Build the loss registered under ``loss_config.name``.

    Args:
        loss_config: Object exposing a ``name`` attribute; only that
            attribute is read here.

    Returns:
        The result of calling the registered entry with no arguments.

    Raises:
        ValueError: If the name is not present in the registry.
    """
    loss_name = loss_config.name
    if loss_name in _LOSSES:
        loss_factory = _LOSSES[loss_name]
        return loss_factory()
    raise ValueError(f"Unknown loss: {loss_name}")
