import torch
from custommodules.functional.pos_embed import get_sincos_1d_from_seqlen
from custommodules.init import init_xavier_uniform_zero_bias
from torch import nn

from models.base.single_model_base import SingleModelBase


class TimestepVelocityConditioner(SingleModelBase):
    """
    Encodes a (timestep, velocity) pair into one conditioning vector of size `dim`.

    Follows https://github.com/facebookresearch/DiT/blob/main/models.py#L27C1-L64C21 but the
    sincos timestep embeddings are created once in the constructor and stored as a buffer,
    so the forward pass only does a lookup. Additionally a velocity value is encoded.

    Supported modes:
        "add":    the raw timestep embedding and a `dim`-sized linear projection of the
                  velocity are summed.
        "concat": timestep embedding and velocity are each projected to `dim // 2` and
                  concatenated (requires an even `dim`).
    """

    def __init__(self, dim: int, mode: str = "add", norm: bool = True, **kwargs):
        """
        Args:
            dim: size of the produced conditioning vector (also published as
                `static_ctx["condition_dim"]`).
            mode: how timestep and velocity embeddings are combined ("add" or "concat").
            norm: if True, a LayerNorm is applied to the combined embedding before the MLP,
                otherwise the combined embedding is passed through unchanged.
            **kwargs: forwarded to `SingleModelBase.__init__`.

        Raises:
            NotImplementedError: if `mode` is neither "add" nor "concat".
        """
        super().__init__(**kwargs)
        self.num_total_timesteps = self.data_container.get_dataset().getdim_timestep()
        self.dim = dim
        self.mode = mode
        self.static_ctx["condition_dim"] = dim
        # buffer/modules
        # lookup table indexed by timestep in forward
        # NOTE(review): presumably of shape (num_total_timesteps, dim) -- confirm against
        # get_sincos_1d_from_seqlen
        self.register_buffer(
            "timestep_embed",
            get_sincos_1d_from_seqlen(seqlen=self.num_total_timesteps, dim=dim),
        )
        if mode == "concat":
            assert dim % 2 == 0, f"mode='concat' requires an even dim (got dim={dim})"
            self.timestep_proj = nn.Linear(dim, dim // 2)
            self.velocity_proj = nn.Linear(1, dim // 2)
        elif mode == "add":
            self.timestep_proj = nn.Identity()
            self.velocity_proj = nn.Linear(1, dim)
        else:
            # fail at construction: without this branch the projections would be missing and
            # forward would crash with an AttributeError before reaching its own mode check
            raise NotImplementedError(f"unsupported mode '{mode}' (use 'add' or 'concat')")
        self.norm = nn.LayerNorm(dim, eps=1e-6) if norm else nn.Identity()
        # the output of the second linear layer is also passed through a SiLU
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim),
            nn.SiLU(),
        )
        # init
        self.reset_parameters()

    def reset_parameters(self):
        """Applies `init_xavier_uniform_zero_bias` to this module and all of its submodules."""
        self.apply(init_xavier_uniform_zero_bias)

    def forward(self, timestep: torch.Tensor, velocity: torch.Tensor) -> torch.Tensor:
        """
        Embeds timestep and velocity into a conditioning vector.

        Args:
            timestep: tensor with one timestep per sample (all dimensions except the first
                have to be of size 1); used as index into the timestep embedding table.
                A tensor with a single element is repeated for every velocity value.
            velocity: tensor with one velocity value per sample (all dimensions except the
                first have to be of size 1); converted to float.

        Returns:
            The output of the MLP applied to the (optionally normalized) combined embedding.
            NOTE(review): presumably of shape (batch_size, dim) -- depends on the shape of
            the timestep embedding table.

        Raises:
            NotImplementedError: if `self.mode` is neither "add" nor "concat".
        """
        # checks + preprocess
        assert timestep.numel() == len(timestep), (
            f"expected one timestep per sample, got shape {tuple(timestep.shape)}"
        )
        assert velocity.numel() == len(velocity), (
            f"expected one velocity per sample, got shape {tuple(velocity.shape)}"
        )
        timestep = timestep.flatten()
        velocity = velocity.view(-1, 1).float()
        # for rollout timestep is simply initialized as 0 -> repeat to batch dimension
        if timestep.numel() == 1:
            timestep = timestep.repeat(velocity.numel())
        # embed
        timestep_embed = self.timestep_proj(self.timestep_embed[timestep])
        velocity_embed = self.velocity_proj(velocity)
        if self.mode == "concat":
            x = torch.concat([timestep_embed, velocity_embed], dim=1)
        elif self.mode == "add":
            x = timestep_embed + velocity_embed
        else:
            # only reachable if self.mode was changed after construction
            raise NotImplementedError
        embed = self.mlp(self.norm(x))
        return embed
