import torch

from models.base.single_model_base import SingleModelBase


class Identity(SingleModelBase):
    """Trivial baseline model that predicts (a slice of) its own input.

    ``forward`` copies the trailing entries of the input along dim 1,
    optionally perturbed with zero-mean Gaussian noise.
    ``rollout_teacher_forced`` returns the input with its last entry along
    dim 1 removed.

    Args:
        mode: How ``forward`` builds its prediction. Only ``"last_timestep"``
            is supported; any other value makes ``forward`` raise
            ``NotImplementedError``.
        noise_std: If not ``None``, the input-derived prediction in
            ``forward`` has ``torch.randn_like(...) * noise_std`` added to it.
            ``rollout_teacher_forced`` refuses to run when this is set.
        **kwargs: Forwarded unchanged to ``SingleModelBase.__init__``.
    """

    def __init__(self, mode, noise_std=None, **kwargs):
        super().__init__(**kwargs)
        self.mode = mode
        self.noise_std = noise_std

    def forward(self, x, *_, **__):
        """Predict by copying the last ``output_shape[0]`` entries of ``x`` along dim 1.

        Extra positional and keyword arguments are accepted and ignored so the
        model can be called with the same signature as other models.

        Args:
            x: Input tensor, sliced along dim 1. NOTE(review): presumably the
                channels of several timesteps are concatenated along dim 1
                with the most recent timestep last — confirm against callers.

        Returns:
            dict with the single key ``"x_hat"`` holding the prediction.

        Raises:
            NotImplementedError: If ``self.mode`` is not ``"last_timestep"``.
        """
        if self.mode != "last_timestep":
            raise NotImplementedError(f"Identity: unsupported mode '{self.mode}'")
        # output_shape is provided by SingleModelBase; its first entry is used
        # as the number of trailing dim-1 entries to copy.
        num_channels = self.output_shape[0]
        x_hat = x[:, -num_channels:]
        if self.noise_std is not None:
            x_hat = x_hat + torch.randn_like(x_hat) * self.noise_std
        return dict(x_hat=x_hat)

    def rollout_teacher_forced(self, x, *_, **__):
        """Return ``x`` with its last entry along dim 1 dropped.

        NOTE(review): presumably dim 1 is time here, so the "prediction" for
        step t+1 is the observation at step t — confirm against callers.
        Extra positional and keyword arguments are accepted and ignored.

        Raises:
            NotImplementedError: If ``noise_std`` is set; noise is not applied
                in teacher-forced rollouts.
        """
        # Explicit check rather than ``assert`` so the guard still holds under
        # ``python -O``, which strips assert statements.
        if self.noise_std is not None:
            raise NotImplementedError(
                "Identity.rollout_teacher_forced does not support noise_std"
            )
        return x[:, :-1]