import torch
from mp_pytorch.mp import MPFactory
from omegaconf import OmegaConf

from ltsgns_mp.util.own_types import ConfigDict


class ProDMPPredictor(torch.nn.Module):
    """Roll out ProDMP trajectories from predicted movement-primitive parameters.

    Wraps a movement primitive created with ``MPFactory.init_mp`` from
    ``mp_config``. Time-related config values (``tau`` and, when tau is
    learned, ``min_tau`` / ``max_tau``) are given normalized to the trajectory
    length and are multiplied by ``num_time_steps`` here, because ProDMP
    expects them in the same units as the prediction times
    ``0 .. num_time_steps - 1``.
    """

    def __init__(self, num_dof, mp_config: ConfigDict, num_time_steps: int, device: str):
        """
        Args:
            num_dof: Number of degrees of freedom of the trajectory.
            mp_config: OmegaConf config holding the ``MPFactory.init_mp``
                arguments plus keys consumed by this wrapper: ``relative_pos``
                (optional, defaults to ``False``) and, when ``learn_tau`` is
                set, ``min_tau`` / ``max_tau``. The caller's config is not
                modified.
            num_time_steps: Length of the default prediction horizon.
            device: Device for the MP. CUDA devices (``"cuda"``, ``"cuda:1"``,
                ...) are used as given; anything else falls back to ``"cpu"``.
        """
        super().__init__()
        self.num_dof = num_dof
        self.num_time_steps = num_time_steps
        self._output_size = None

        # Work on a plain, resolved copy so wrapper-only keys can be stripped
        # before the config is handed to MPFactory.init_mp.
        cfg = OmegaConf.to_container(mp_config, resolve=True)

        self.learn_tau = cfg.get("learn_tau", False)
        if self.learn_tau:
            # Bounds for the learned tau, converted from normalized to step units.
            self.min_tau = cfg["min_tau"] * num_time_steps
            self.max_tau = cfg["max_tau"] * num_time_steps

        # Remove keys consumed here; ProDMP crashes if min/max tau are passed on.
        self._relative_pos = cfg.pop("relative_pos", False)
        cfg.pop("min_tau", None)
        cfg.pop("max_tau", None)

        cfg["num_dof"] = num_dof
        cfg["tau"] = cfg["tau"] * num_time_steps
        self.mp_config = OmegaConf.create(cfg)

        self._trajectory_prediction_times = torch.arange(0, num_time_steps)

        # Pass any CUDA device through unchanged (including an explicit index
        # such as "cuda:1"); everything else falls back to the CPU.
        device_name = str(device)
        mp_device = device_name if device_name.startswith("cuda") else "cpu"
        self._mp = MPFactory.init_mp(**self.mp_config, device=mp_device)

    @property
    def output_size(self):
        """Number of raw MP parameters ``forward`` expects per sample.

        This is ``num_basis`` weights per DoF, plus one extra value per DoF
        (presumably the ProDMP goal; confirm against mp_pytorch), plus one
        leading tau value when ``learn_tau`` is enabled.
        """
        if self._output_size is None:
            num_basis = self.mp_config["mp_args"]["num_basis"]
            self._output_size = self.num_dof * (num_basis + 1) + int(self.learn_tau)
        return self._output_size

    @property
    def mp(self):
        """The underlying movement primitive created by ``MPFactory.init_mp``."""
        return self._mp

    def forward(
        self,
        pos: torch.Tensor,
        vel: torch.Tensor,
        basis_weights: torch.Tensor,
        prediction_times: torch.Tensor | None = None,
        output_vel: bool = True,
        initial_time: torch.Tensor | int | None = None,
    ) -> torch.Tensor:
        """Roll out the MP from the given initial state and parameters.

        Args:
            pos: Initial position. All leading dimensions (``pos.shape[:-1]``)
                are treated as batch dimensions; presumably the last dimension
                is ``num_dof``.
            vel: Initial velocity, passed to the MP as ``init_vel``.
            basis_weights: Raw MP parameters (``output_size`` values in the
                last dimension). When ``learn_tau`` is set, entry 0 is an
                unbounded tau logit that is squashed into ``[min_tau, max_tau]``.
            prediction_times: Query times. Defaults to
                ``0 .. num_time_steps - 1`` repeated for every batch element.
            output_vel: If True, velocities are concatenated to positions along
                the last dimension.
            initial_time: Start time of the rollout. Defaults to zeros of the
                batch shape.

        Returns:
            Positions, or positions and velocities concatenated along the last
            dimension. Presumably the shape is ``(*batch, T, num_dof)`` or
            ``(*batch, T, 2 * num_dof)``.
        """
        if self.learn_tau:
            # Affine-sigmoid squashing so that tau stays within [min_tau, max_tau].
            tau = self.min_tau + (self.max_tau - self.min_tau) * torch.sigmoid(
                basis_weights[..., 0]
            )
            basis_weights = torch.cat(
                (tau.unsqueeze(-1), basis_weights[..., 1:]), dim=-1
            )

        batch_shape = pos.shape[:-1]
        device = pos.device

        if initial_time is None:
            initial_time = torch.zeros(batch_shape, device=device)

        if prediction_times is None:
            # Same default time grid for every batch element.
            prediction_times = self._trajectory_prediction_times.to(device).repeat(
                *batch_shape, 1
            )

        # In relative mode, roll out from the origin and shift by pos afterwards.
        init_pos = torch.zeros_like(pos) if self._relative_pos else pos

        self.mp.reset()
        self.mp.update_inputs(
            times=prediction_times,
            init_pos=init_pos,
            init_vel=vel,
            init_time=initial_time,
            params=basis_weights,
        )
        trajectories = self.mp.get_trajs()

        traj_pos = trajectories["pos"]
        if self._relative_pos:
            # Insert a time axis so the offset is added at every time step.
            traj_pos = traj_pos + pos[..., None, :]

        if not output_vel:
            return traj_pos
        return torch.cat((traj_pos, trajectories["vel"]), dim=-1)
