import torch
from custommodules.functional.pos_embed import get_sincos_1d_from_seqlen
from custommodules.init import init_xavier_uniform_zero_bias
from torch import nn

from models.base.single_model_base import SingleModelBase


class TimestepConditioner(SingleModelBase):
    """Turn timestep indices into conditioning vectors via a precomputed lookup table and an MLP.

    Follows https://github.com/facebookresearch/DiT/blob/main/models.py#L27C1-L64C21 but more performant:
    the sincos embedding of all timesteps is built once in ``__init__`` and stored as a buffer, so
    ``forward`` only indexes into that buffer before applying the MLP.
    """

    def __init__(self, dim, **kwargs):
        """Build the timestep embedding buffer and the MLP.

        Args:
            dim: Embedding dimension; also the output dimension of the MLP. It is published to
                other modules as ``static_ctx["condition_dim"]``.
            **kwargs: Forwarded unchanged to ``SingleModelBase.__init__``.
        """
        super().__init__(**kwargs)
        # number of distinct timesteps is queried from the dataset
        self.num_total_timesteps = self.data_container.get_dataset().getdim_timestep()
        self.dim = dim
        self.static_ctx["condition_dim"] = dim
        # buffer/modules
        # presumably a (num_total_timesteps, dim) table with one row per timestep -- TODO confirm
        self.register_buffer(
            "timestep_embed",
            get_sincos_1d_from_seqlen(seqlen=self.num_total_timesteps, dim=dim),
        )
        # NOTE: unlike the referenced DiT embedder, a second SiLU is applied after the last linear layer
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim),
            nn.SiLU(),
        )
        # init
        self.reset_parameters()

    def reset_parameters(self):
        """Apply ``init_xavier_uniform_zero_bias`` to this module and all of its submodules."""
        self.apply(init_xavier_uniform_zero_bias)

    def forward(self, timestep, velocity):
        """Look up the embedding of every timestep and pass it through the MLP.

        Args:
            timestep: Tensor of timestep indices with one entry per sample (e.g. shape ``(batch,)``
                or ``(batch, 1)``), or a single entry (0-d or one element) which is repeated to the
                batch size taken from ``velocity``. It indexes the embedding buffer, so it is
                presumably an integer tensor -- TODO confirm against callers.
            velocity: Tensor with one element per sample. Only its number of elements is used
                (as batch size when ``timestep`` holds a single entry).

        Returns:
            The MLP output for the embedding row selected by each timestep.

        Raises:
            AssertionError: If ``timestep`` or ``velocity`` has more than one element per sample.
        """
        # checks + preprocess
        # len() raises a TypeError on 0-d tensors -> accept scalars explicitly via the ndim guard
        assert timestep.ndim == 0 or timestep.numel() == len(timestep), \
            "timestep must have exactly one element per sample"
        assert velocity.ndim == 0 or velocity.numel() == len(velocity), \
            "velocity must have exactly one element per sample"
        # flatten also converts a 0-d timestep into a 1-d tensor with one element
        timestep = timestep.flatten()
        # for rollout timestep is simply initialized as 0 -> repeat to batch dimension
        if timestep.numel() == 1:
            timestep = timestep.repeat(velocity.numel())
        # embed
        return self.mlp(self.timestep_embed[timestep])
