import torch

from dae.utils.generic_utils import ModulesRegister

# Registry of timestep schedules; the schedule classes below register into it
# by name (e.g. "linear", "log_map", "pow_shifted").
# NOTE(review): `lower=True` presumably makes name lookup case-insensitive and
# `default` presumably names the entry used when none is requested — confirm
# against ModulesRegister.
T_SCHEDULERS = ModulesRegister("T_SCHEDULERS", lower=True, default="linear")
# Registry of samplers; FMEulerSampler below registers itself as "euler".
FLOW_SAMPLERS = ModulesRegister("FLOW_SAMPLERS", lower=True, default="euler")

##########################
### Sampling schedules ###
##########################


@T_SCHEDULERS.register("linear")
class SamplingScheduleLinear:
    """Evenly spaced timesteps running from t=1 down to t=0 (both inclusive)."""

    def __call__(self, n_steps, device):
        """Return a 1-D tensor of ``n_steps + 1`` uniformly spaced values from 1 to 0.

        Args:
            n_steps: number of integration steps; the schedule has one more point
                than steps so that consecutive pairs bound each step.
            device: device on which the schedule tensor is allocated.
        """
        num_points = n_steps + 1
        return torch.linspace(start=1, end=0, steps=num_points, device=device)


@T_SCHEDULERS.register("log_map")
class SamplingScheduleLogMap(SamplingScheduleLinear):
    """Warp the linear 1 -> 0 schedule through a logarithmic map.

    Each linear value ``u`` becomes ``1 - log(u * (1 - n) + n) / log(n)``. The
    endpoints stay fixed (``u=1 -> 1`` and ``u=0 -> 0``), and the map is monotone
    for any ``n > 0`` with ``n != 1``.
    """

    def __init__(self, n=100):
        """
        Args:
            n: base of the logarithmic warp. It must be positive and different
                from 1: ``n == 1`` makes ``log(n) == 0`` (NaN schedule), and
                ``n <= 0`` makes the logarithms undefined.

        Raises:
            ValueError: if ``n`` is not positive or equals 1.
        """
        # `not (n > 0)` also rejects NaN, which `n <= 0` would let through.
        if not (n > 0) or n == 1:
            raise ValueError(f"log_map schedule requires n > 0 and n != 1, got n={n}")
        self.n = n

    def __call__(self, n_steps, device):
        """Return the warped schedule of ``n_steps + 1`` values from 1 to 0."""
        t_lin = super().__call__(n_steps, device)
        # Match the schedule's dtype/device so all arithmetic stays on one device.
        n = torch.as_tensor(self.n, dtype=t_lin.dtype, device=t_lin.device)
        return 1 - torch.log(t_lin * (1 - n) + n) / torch.log(n)


@T_SCHEDULERS.register("pow_shifted")
class SamplingSchedulePowShifted(SamplingScheduleLinear):
    """Raise the linear 1 -> 0 schedule to a fixed power.

    The endpoints 1 and 0 are preserved for any positive exponent. With
    ``spacing > 1`` the intermediate values shrink, so timesteps cluster near
    t=0. With ``0 < spacing < 1`` they cluster near t=1.
    """

    def __init__(self, spacing=2.0):
        """
        Args:
            spacing: exponent applied to the linear schedule; must be positive.
                ``spacing == 0`` would give a schedule stuck at 1 (``0**0 == 1``),
                and a negative value maps t=0 to ``inf``.

        Raises:
            ValueError: if ``spacing`` is not positive.
        """
        # `not (spacing > 0)` also rejects NaN.
        if not (spacing > 0):
            raise ValueError(f"pow_shifted schedule requires spacing > 0, got spacing={spacing}")
        self.spacing = spacing

    def __call__(self, n_steps, device):
        """Return ``linear_schedule ** spacing`` with ``n_steps + 1`` values."""
        t_steps = super().__call__(n_steps, device)
        return t_steps**self.spacing


###############
### Sampler ###
###############


@FLOW_SAMPLERS.register("euler")
class FMEulerSampler:
    """Fixed-step flow-matching sampler driven by a timestep schedule.

    Starting from noise, it walks consecutive pairs of the schedule. At each pair
    it asks ``fm_trainer.get_prediction`` for a prediction and hands it to
    ``fm_trainer.step`` to produce the next state. The actual update rule lives in
    ``fm_trainer.step``; presumably an Euler step, given the registered name.
    """

    def __init__(self, t_scheduler="pow_shifted", steps=None, guidance=1.0):
        """
        Args:
            t_scheduler: spec passed to ``T_SCHEDULERS.build`` to construct the
                timestep schedule callable ``(n_steps, device) -> tensor``.
            steps: default number of steps used when ``sample`` gets no ``n_steps``.
            guidance: default guidance value forwarded to ``fm_trainer.get_prediction``.
        """
        self.t_scheduler = T_SCHEDULERS.build(t_scheduler)
        self.default_steps = steps
        self.default_guidance = guidance

    @torch.compiler.disable(recursive=False)
    def sample(
        self,
        fn,
        fm_trainer,
        shape,
        n_steps=None,
        fn_kwargs=None,
        uncond_fn_kwargs=None,
        guidance=None,
        noise=None,
        device=None,
    ):
        """Generate samples by integrating along the timestep schedule without gradients.

        Args:
            fn: the model passed through to ``fm_trainer.get_prediction``. When
                ``device`` is omitted, its first parameter's device is used.
            fm_trainer: object providing ``get_prediction`` and ``step``.
            shape: shape of the initial Gaussian noise (ignored if ``noise`` is given).
            n_steps: number of steps; falls back to ``self.default_steps``.
            fn_kwargs / uncond_fn_kwargs: forwarded to ``fm_trainer.get_prediction``.
            guidance: forwarded guidance value; falls back to ``self.default_guidance``.
            noise: optional starting tensor used instead of freshly sampled noise.
            device: device for the noise and schedule. When omitted, it is taken
                from ``fn``'s parameters, or from ``noise`` if ``fn`` has none.

        Returns:
            The final state after ``n_steps`` updates (``noise``/initial sample
            unchanged when ``n_steps == 0``).

        Raises:
            ValueError: if no step count is available, or the device cannot be inferred.
        """
        if n_steps is None:
            if self.default_steps is None:
                raise ValueError("n_steps must be specified or default_steps must be set in the sampler")
            n_steps = self.default_steps
        if guidance is None:
            guidance = self.default_guidance

        if device is None:
            try:
                device = next(fn.parameters()).device
            except (AttributeError, StopIteration) as err:
                # `fn` is not a module, or has no parameters: fall back to the
                # device of the provided noise instead of crashing opaquely.
                if noise is None:
                    raise ValueError(
                        "device must be given when it cannot be inferred from fn's parameters or from noise"
                    ) from err
                device = noise.device
        x_t = torch.randn(shape, device=device) if noise is None else noise
        # Schedule has n_steps + 1 points; step i goes from t_steps[i] to t_steps[i + 1].
        t_steps = self.t_scheduler(n_steps, device=device)

        with torch.no_grad():
            for i in range(n_steps):
                t_cur, t_next = t_steps[i], t_steps[i + 1]
                # The model receives one timestep per batch element.
                t = t_cur.repeat(x_t.shape[0])
                neg_v = fm_trainer.get_prediction(
                    fn,
                    x_t=x_t,
                    t=t,
                    fn_kwargs=fn_kwargs,
                    uncond_fn_kwargs=uncond_fn_kwargs,
                    guidance=guidance,
                )
                x_t = fm_trainer.step(x_t, neg_v, t_cur, t_next)
        return x_t
