import torch

from dae.utils.generic_utils import ModulesRegister

# Name -> class registries. Entries are added with the ``@<REGISTRY>.register("name")``
# decorators in this file and looked up by indexing, e.g. ``TRAIN_T_SAMPLERS[name]``.
# NOTE(review): the exact meaning of ``lower=True`` (presumably case-insensitive
# keys) and ``default=`` (presumably the fallback entry name) is defined by
# ModulesRegister in dae.utils.generic_utils — confirm there.
TRAIN_T_SAMPLERS = ModulesRegister("TRAIN_T_SAMPLERS", lower=True, default="uniform")
FLOW_TRAINERS = ModulesRegister("FLOW_TRAINERS", lower=True, default="flow_matching")

#####################
### Time samplers ###
#####################


@TRAIN_T_SAMPLERS.register("uniform")
class TimeSamplerUniform:
    """Sample training timesteps from the uniform distribution on ``[0, 1)``."""

    def __call__(self, batch_size, device):
        """Return ``torch.rand`` draws with shape ``batch_size`` on ``device``."""
        samples = torch.rand(batch_size, device=device)
        return samples


@TRAIN_T_SAMPLERS.register("logit_normal")
class TimeSamplerLogitNormal:
    """Logit-normal timestep sampler: ``t = sigmoid(t_mean + t_std * z)``, ``z ~ N(0, 1)``.

    Args:
        t_mean: Location of the Gaussian before the sigmoid.
        t_std: Scale of the Gaussian before the sigmoid.
    """

    def __init__(self, t_mean=0, t_std=1.0):
        self.t_mean = t_mean
        self.t_std = t_std

    def __call__(self, batch_size, device):
        """Return ``batch_size`` timesteps in ``(0, 1)`` on ``device``."""
        logits = self.t_mean + self.t_std * torch.randn(batch_size, device=device)
        return logits.sigmoid()


@TRAIN_T_SAMPLERS.register("heavy_tails")
class TimeSamplerHeavyTails:
    """"Mode sampling with heavy tails" timestep distribution (SD3, arXiv:2403.03206).

    Draws ``u ~ U[0, 1)`` and maps it through

        f(u) = 1 - u - t_scale * (cos^2(pi / 2 * u) - 1 + u)

    which sends ``[0, 1]`` onto ``[1, 0]`` and is monotone for
    ``-1 <= t_scale <= 2 / (pi - 2)`` (about 1.75). ``t_scale = 0`` reduces to
    uniform sampling; positive values concentrate samples at mid-range timesteps.

    Args:
        t_scale: The ``s`` parameter of the mapping above.
    """

    def __init__(self, t_scale=1.29):
        self.t_scale = t_scale

    def __call__(self, batch_size, device):
        """Return ``batch_size`` timesteps in ``[0, 1]`` on ``device``."""
        u = torch.rand(batch_size, device=device)
        # The cosine argument must be pi/2 * u. With pi * u the map gives
        # f(0.5) = 0.5 + 0.5 * t_scale (> 1 for t_scale > 1) and f(1) = -t_scale,
        # i.e. timesteps outside the [0, 1] interval the trainer expects.
        return 1 - u - self.t_scale * (torch.cos(u * torch.pi / 2) ** 2 - 1 + u)


#############################
### Flow Matching Trainer ###
#############################


@FLOW_TRAINERS.register("flow_matching")
class FlowMatchingTrainer:
    """Linear-interpolation flow-matching objective and Euler sampling step.

    The noising path is::

        x_t      = alpha(t) * xt_scale * x + sigma(t) * noise
        alpha(t) = 1 - t
        sigma(t) = sigma_min + t * (1 - sigma_min)

    so ``t = 1`` is pure noise and ``t = 0`` is the scaled data plus
    ``sigma_min * noise``. The network regresses ``-dx_t/dt``, i.e.
    ``A(t) * xt_scale * x + B(t) * noise``. Optionally both the network input
    ``x_t`` and the target are divided by the factor from :meth:`get_xt_rescale`.
    """

    def __init__(
        self,
        *,
        timescale: float = 1_000,
        sigma_min: float = 0.0,
        xt_norm: "bool | str" = False,
        xt_scale: float = 1.0,
        t_sampler="logit_normal",
        t_sampler_args=None,
    ):
        """Configure the path, optional ``x_t`` rescaling and the training t-sampler.

        Args:
            timescale: Factor applied to ``t`` before it is passed to the model.
            sigma_min: Noise coefficient at ``t = 0`` (``sigma(0)``).
            xt_norm: Rescaling mode: ``False`` (none), ``True`` (per-sample std
                of ``x_t``), ``"paper"`` or ``"manual"``; see :meth:`get_xt_rescale`.
            xt_scale: Multiplier applied to the clean sample ``x`` in the path.
            t_sampler: Name of a sampler registered in ``TRAIN_T_SAMPLERS``.
            t_sampler_args: Keyword arguments for the sampler constructor.
        """
        # Not read anywhere in this class; presumably kept for external callers.
        self.prediction_type = None
        self.t_sampler = TRAIN_T_SAMPLERS[t_sampler](**(t_sampler_args or {}))

        # Args
        self.timescale = timescale
        self.sigma_min = sigma_min
        self.xt_norm = xt_norm
        self.xt_scale = xt_scale

    def alpha(self, t):
        """Coefficient of the clean sample along the path."""
        return 1.0 - t

    def sigma(self, t):
        """Coefficient of the noise along the path."""
        return self.sigma_min + t * (1.0 - self.sigma_min)

    def A(self, t):
        """Coefficient of ``xt_scale * x`` in the target; equals ``-d alpha / dt``."""
        return 1.0

    def B(self, t):
        """Coefficient of ``noise`` in the target; equals ``-d sigma / dt``."""
        return -(1.0 - self.sigma_min)

    def get_xt_rescale(self, t, x_t):
        """Return the factor ``x_t`` (and the target) are divided by.

        * falsy ``xt_norm``: ``1.0`` (no rescaling, regardless of ``xt_scale``).
        * ``True``: per-sample standard deviation of ``x_t`` over all non-batch dims.
        * ``"paper"``: ``1 + (1 - t) * (xt_scale - 1)``.
        * ``"manual"``: ``sqrt((1 - t)^2 * xt_scale^2 + t^2)``, the std of ``x_t``
          assuming unit-variance ``x``, independent unit-variance noise and
          ``sigma_min == 0``. NOTE(review): ignores ``sigma_min``.

        Tensor results have shape ``(batch, 1, ..., 1)`` so they broadcast over ``x_t``.

        Raises:
            ValueError: if ``xt_norm`` is truthy but not one of the modes above.
        """
        # https://arxiv.org/pdf/2301.10972
        # https://arxiv.org/pdf/2410.04081
        # Previously a falsy xt_norm combined with xt_scale != 1 fell through to
        # the ValueError below; input scaling without normalisation is valid.
        if not self.xt_norm:
            return 1.0
        bcast = (-1,) + (1,) * (x_t.dim() - 1)
        if self.xt_norm is True:
            return x_t.std(dim=tuple(range(1, x_t.dim())), keepdim=True)
        if self.xt_norm == "paper":
            t_arr = t.reshape(bcast)
            return 1 + (1 - t_arr) * (self.xt_scale - 1)
        if self.xt_norm == "manual":
            t_arr = t.reshape(bcast)
            return ((1 - t_arr) ** 2 * self.xt_scale**2 + t_arr**2) ** 0.5
        raise ValueError(f"Unknown xt_norm value: {self.xt_norm}")

    def _noise_and_scale(self, x, t, noise=None):
        """Build the rescaled noisy sample; return ``(x_t, noise, scale)``.

        ``x_t`` is already divided by ``scale``. ``scale`` is returned so the
        regression target can be divided by the same factor as the input.
        """
        noise = torch.randn_like(x) if noise is None else noise
        shape = [x.shape[0]] + [1] * (x.dim() - 1)
        x_t = self.alpha(t).view(*shape) * x * self.xt_scale + self.sigma(t).view(*shape) * noise
        scale = self.get_xt_rescale(t, x_t)
        return x_t / scale, noise, scale

    def add_noise(self, x, t, noise=None):
        """Return ``(x_t, noise)`` for clean batch ``x`` at per-sample times ``t``.

        ``t`` is expected to have one entry per batch element. Fresh Gaussian
        noise is drawn when ``noise`` is None.
        """
        x_t, noise, _ = self._noise_and_scale(x, t, noise=noise)
        return x_t, noise

    def loss(self, fn, x, t=None, fn_kwargs=None, noise=None):
        """Mean squared error between ``fn``'s prediction and ``-dx_t/dt``.

        Args:
            fn: Model called as ``fn(x_t, t=t * timescale, **fn_kwargs)``.
            x: Clean batch.
            t: Per-sample times; drawn with the configured ``t_sampler`` if None.
            fn_kwargs: Extra keyword arguments for ``fn``.
            noise: Optional fixed noise (Gaussian is drawn otherwise).

        Returns:
            ``(loss, (x_t, noise, t, v_pred))``; the loss is computed in float32.
        """
        if fn_kwargs is None:
            fn_kwargs = {}

        if t is None:
            # Honour the t_sampler chosen in __init__ rather than always uniform.
            t = self.sample_t(x.shape[0], x.device)
        x_t, noise, scale = self._noise_and_scale(x, t, noise=noise)

        v_pred = fn(x_t, t=t * self.timescale, **fn_kwargs)

        target = self.A(t) * self.xt_scale * x + self.B(t) * noise  # -dxt/dt
        # Use the factor computed on the *un-normalised* x_t. Recomputing it from
        # the already-normalised x_t gives ~1 for xt_norm=True, leaving the target
        # in a different frame from the network input.
        target = target / scale

        loss = ((v_pred.float() - target.float()) ** 2).mean()
        return loss, (x_t, noise, t, v_pred)

    def sample_t(self, batch_size, device):
        """Draw ``batch_size`` training timesteps with the configured sampler."""
        return self.t_sampler(batch_size, device=device)

    def get_prediction(
        self,
        fn,
        x_t,
        t,
        fn_kwargs=None,
        uncond_fn_kwargs=None,
        guidance=1.0,
    ):
        """Velocity prediction with classifier-free guidance.

        Returns ``uncond + guidance * (cond - uncond)``. The conditional call is
        skipped when ``guidance == 0`` and the unconditional call when
        ``guidance == 1``.
        """
        # Conditional prediction
        if guidance != 0.0:
            v_pred = fn(x_t, t=t * self.timescale, **(fn_kwargs or {}))
        else:
            v_pred = 0.0

        # Unconditional v_prediction with CFG guidance
        if guidance != 1.0:
            uncond_v_pred = fn(x_t, t=t * self.timescale, **(uncond_fn_kwargs or {}))
            v_pred = uncond_v_pred + guidance * (v_pred - uncond_v_pred)

        return v_pred

    def step(self, x_t, v_pred, cur_t, next_t=0):
        """One Euler step from ``cur_t`` to ``next_t`` in the rescaled frame.

        ``v_pred`` approximates ``-dx_t/dt`` (rescaled), so moving towards
        smaller ``t`` adds ``v_pred * (cur_t - next_t)``. The result is then
        converted from the ``cur_t`` rescale frame to the ``next_t`` frame.

        ``cur_t`` and ``next_t`` may be scalars or per-sample tensors.
        """
        if not isinstance(v_pred, torch.Tensor):
            v_pred = torch.tensor(v_pred, device=x_t.device)
        bcast = (-1,) + (1,) * (x_t.dim() - 1)
        cur_t = torch.as_tensor(cur_t, device=x_t.device).reshape(bcast)
        # next_t is often the Python int 0; make it a tensor so it can be
        # reshaped for broadcasting and passed to get_xt_rescale.
        next_t = torch.as_tensor(next_t, dtype=cur_t.dtype, device=x_t.device).reshape(bcast)
        next_xt = x_t + v_pred * (cur_t - next_t)
        prev_scale = self.get_xt_rescale(cur_t, x_t)
        # Must be evaluated at next_t; using cur_t made the ratio 1 for the
        # time-dependent ("paper"/"manual") modes.
        next_scale = self.get_xt_rescale(next_t, next_xt)

        return next_xt * prev_scale / next_scale
