import torch


class LinearPointRobot(object):
    """Point robot with linear double-integrator (position/velocity) dynamics.

    The state is ``[positions (u_dim), velocities (u_dim)]`` and the control
    acts like an acceleration, discretized with time step ``dt``::

        x_{k+1} = A x_k + B u_k
        A = [[I, dt*I], [0, I]],   B = [[0.5*dt^2*I], [dt*I]]

    Args:
        x0: Initial state (tensor or array-like of length ``x_dim``).
        x_dim: State dimension; must equal ``2 * u_dim``.
        u_dim: Control dimension (number of position axes).
        pos_lims: Position limits. Stored only; not used by any method of this class.
        vel_lims: Velocity limits. Stored only; not used by any method of this class.
        dt: Integration time step.
        tensor_kwargs: ``device``/``dtype`` used for all internal tensors.
            Defaults to ``{"device": "cpu", "dtype": torch.float32}``.

    Raises:
        ValueError: If ``x_dim != 2 * u_dim``.
    """

    def __init__(self, x0, x_dim=4, u_dim=2, pos_lims=None, vel_lims=None, dt=0.1,
                 tensor_kwargs=None):
        # Build a fresh dict per instance (a dict literal default would be shared by all instances).
        if tensor_kwargs is None:
            tensor_kwargs = {"device": "cpu", "dtype": torch.float32}
        if x_dim != 2 * u_dim:
            raise ValueError("x_dim must equal 2 * u_dim (positions + velocities), "
                             "got x_dim={}, u_dim={}".format(x_dim, u_dim))
        self.x_dim = x_dim
        self.u_dim = u_dim
        self.pos_lims = pos_lims
        self.vel_lims = vel_lims
        self.dt = dt
        self.tensor_kwargs = dict(tensor_kwargs)

        self.x = self._as_state(x0)

        # A = [[I, dt*I], [0, I]]: each position integrates its own velocity.
        self.A = torch.eye(self.x_dim, **self.tensor_kwargs)
        for i in range(self.u_dim):
            self.A[i, self.u_dim + i] = self.dt
        # B = [[0.5*dt^2*I], [dt*I]]: control enters positions and velocities.
        eye_u = torch.eye(self.u_dim, **self.tensor_kwargs)
        self.B = torch.cat((0.5 * self.dt ** 2 * eye_u, self.dt * eye_u), dim=0)

        # Cache of stacked rollout matrices, (re)built lazily by _create_batch_matrices.
        self.T = None
        self.batch_A, self.batch_B = None, None

    def _as_state(self, x0):
        """Return a detached copy of ``x0`` as a tensor with ``self.tensor_kwargs``."""
        x = x0.clone().detach() if torch.is_tensor(x0) else torch.tensor(x0)
        return x.to(**self.tensor_kwargs)

    def reset(self, x0):
        """Replace the current state with a detached copy of ``x0``."""
        self.x = self._as_state(x0)

    def step(self, u):
        """Advance the current state ``self.x`` one time step under control ``u``.

        Args:
            u: Control of shape ``(u_dim,)``, tensor or array-like. It is cast to
                ``self.tensor_kwargs`` so e.g. float64 input does not cause a dtype
                mismatch in ``matmul``.
        """
        u = torch.as_tensor(u, **self.tensor_kwargs)
        self.x = self._linear_model(self.x, u)

    def rollout(self, u):
        """Roll out the dynamics from the current state for a control sequence.

        The whole horizon is computed in closed form as
        ``[x_1; ...; x_T] = batch_A @ x_0 + batch_B @ u_flat`` (see
        ``_create_batch_matrices``). This is equivalent to calling
        ``_linear_model`` T times. ``self.x`` is not modified.

        Args:
            u: Controls of shape ``(N, T, u_dim)`` or ``(T, u_dim)`` (tensor or
                array-like; cast to ``self.tensor_kwargs``).

        Returns:
            Trajectory including the initial state, shape ``(N, T + 1, x_dim)``.
            When N == 1 the result is reshaped to ``(T + 1, x_dim)``; this also
            applies to a batched input with a single sample.
        """
        u = torch.as_tensor(u, **self.tensor_kwargs)
        # Check rank first so size(-1) below is always valid.
        assert u.ndim == 3 or u.ndim == 2, "Control input must be shape (N, T, D) or (T, D)."
        assert u.size(-1) == self.u_dim, "Control input has wrong number of dimensions, {}".format(u.size(-1))

        N = u.size(0) if u.ndim == 3 else 1
        T = u.size(-2)
        bA, bB = self._create_batch_matrices(T)

        x_0 = self.x.clone().detach().view(1, self.x_dim)
        # reshape (not view) so non-contiguous control tensors are accepted.
        u_flat = u.reshape(N, T * self.u_dim)
        # (1, T*x_dim) + (N, T*x_dim) broadcasts to (N, T*x_dim): stacked states x_1..x_T.
        traj = torch.matmul(x_0, bA.T) + torch.matmul(u_flat, bB.T)
        traj = torch.cat([x_0.view(-1).repeat(N, 1, 1), traj.reshape(N, T, self.x_dim)], dim=1)
        if N == 1:
            traj = traj.reshape(T + 1, self.x_dim)

        return traj

    def _linear_model(self, x, u):
        """Return the one-step successor ``A @ x + B @ u``."""
        return torch.matmul(self.A, x) + torch.matmul(self.B, u)

    def _create_batch_matrices(self, T):
        """Build (or return the cached) stacked matrices for a T-step rollout.

        Returns ``(batch_A, batch_B)`` such that the stacked states
        ``[x_1; ...; x_T]`` equal ``batch_A @ x_0 + batch_B @ u_flat``, where
        ``u_flat`` is the ``(T * u_dim,)`` flattening of the controls:

        * ``batch_A = [A; A^2; ...; A^T]``, shape ``(T * x_dim, x_dim)``.
        * ``batch_B`` is block lower-triangular, shape ``(T * x_dim, T * u_dim)``,
          with block ``(i, j) = A^(i-j) B`` for ``j <= i``.

        The cache is keyed on ``T`` only. It is not invalidated if ``A``, ``B``
        or ``dt`` are changed after construction.
        """
        if self.T == T:
            return self.batch_A, self.batch_B

        batch_A = torch.cat([torch.linalg.matrix_power(self.A, i + 1) for i in range(T)], dim=0)
        # Fill the strict upper block-triangle of an identity with transposed powers:
        # block (i, j) = (A^(j-i))^T for j > i. Its transpose is the lower-triangular
        # transition matrix with block (i, j) = A^(i-j).
        transition = torch.eye(self.x_dim * T, **self.tensor_kwargs)
        for i in range(T - 1):
            transition[i * self.x_dim:(i + 1) * self.x_dim,
                       (i + 1) * self.x_dim:] = batch_A.T[:, :(T - i - 1) * self.x_dim]
        # Right-multiplying by blockdiag(B, ..., B) turns each A^(i-j) into A^(i-j) B.
        batch_B = transition.T.mm(torch.block_diag(*[self.B] * T))

        # Update the cache only once both matrices were built successfully.
        self.batch_A, self.batch_B, self.T = batch_A, batch_B, T
        return self.batch_A, self.batch_B
