import numpy as np
import time
import logging

import robosuite.utils.transform_utils as T
from robosuite.controllers.interpolators.base_interpolator import Interpolator
import toppra as toppra
from toppra.interpolator import SplineInterpolator
from toppra.algorithm import TOPPRA
from toppra.constraint import JointVelocityConstraint, JointAccelerationConstraint, JointTorqueConstraint

logger = logging.getLogger("toppra_interpolator")


class TOPPRAInterpolator(Interpolator):
    """
    Interpolator that uses TOPPRA to generate a time-optimal trajectory.

    A cubic spline is fitted through the current state and the waypoints handed to
    set_goal(), with consecutive waypoints placed 1 / policy_freq apart on the path
    parameter. TOPPRA then retimes that path under the configured limits and the
    result is sampled once per controller step.

    Args:
        ndim (int): Number of dimensions to interpolate
        controller_freq (float): Frequency (Hz) of the controller
        policy_freq (float): Frequency (Hz) of the policy providing the waypoints
        torque_limits (None or array-like): Torque limits; only used when an inverse dynamics
            function is passed to set_goal().
            NOTE(review): presumably stacked as (low, high) rows -- confirm against the caller
        vel_limits (None or array-like): Velocity limits for each dimension (applied symmetrically)
        accel_limits (None or array-like): Acceleration limits for each dimension (applied symmetrically)
        ori_interpolate (None or str): If set, assumes that we are interpolating angles (orientation)
            Specified string determines assumed type of input:
                `'euler'`: Euler orientation inputs
                any other string: Quaternion inputs
    """

    def __init__(
        self,
        ndim,
        controller_freq,
        policy_freq,
        torque_limits=None,
        vel_limits=None,
        accel_limits=None,
        ori_interpolate=None,
    ):
        self.dim = ndim
        self.ori_interpolate = ori_interpolate
        self.controller_freq = controller_freq
        self.policy_freq = policy_freq
        # Controller steps per policy step; only used by the PD fallback in set_goal()
        self.total_steps = int(np.ceil(controller_freq / policy_freq))
        self.torque_limits = None if torque_limits is None else np.array(torque_limits)
        self.vel_limits = None if vel_limits is None else np.array(vel_limits)
        self.accel_limits = None if accel_limits is None else np.array(accel_limits)

        # Index of the next sample handed out by get_interpolated_goal()
        self.step = 0
        # Waypoints of the last successful plan that have not been targeted yet
        self.future_waypoints = []

        self.set_states(dim=ndim, ori=ori_interpolate)

    def set_states(self, dim=None, ori=None):
        """
        Updates self.dim and self.ori_interpolate.

        Initializes self.start and self.goal with correct dimensions.

        Args:
            dim (None or int): Number of dimensions to interpolate. The current value is kept if None.

            ori (None or str): If set, assumes that we are interpolating angles (orientation).
                The current value is kept if None. Specified string determines assumed type of input:

                    `'euler'`: Euler orientation inputs
                    `'quat'`: Quaternion inputs
        """
        # Update self.dim and self.ori_interpolate
        if dim is not None:
            self.dim = dim
        if ori is not None:
            self.ori_interpolate = ori

        # Set start and goal states
        if self.ori_interpolate is None:
            self.start = np.zeros(self.dim)
        elif self.ori_interpolate == "euler":
            self.start = np.zeros(3)
        else:  # quaternions
            self.start = np.array((0, 0, 0, 1))
        self.goal = np.array(self.start)

    def toppra_instance(self, waypoints, start, start_vel, inv_dyn=None, num_gridpoints_per_waypoint=None):
        """
        Builds the TOPPRA problem for a spline through @start followed by @waypoints.

        Args:
            waypoints (array-like): Waypoints to pass through, one row per waypoint
            start (np.ndarray): State the path starts from; prepended to @waypoints
            start_vel (np.ndarray): First derivative of the spline imposed at the start of the path
                (the end of the path uses the "natural" boundary condition)
            inv_dyn (None or callable): Inverse dynamics function handed to JointTorqueConstraint.
                A torque constraint is only added if this and self.torque_limits are both set
            num_gridpoints_per_waypoint (None or int): If set, the path is discretized uniformly with this
                many grid intervals per waypoint segment; otherwise gridpoints=None is passed to TOPPRA

        Returns:
            TOPPRA: The configured (not yet solved) TOPPRA instance
        """
        num_waypoints = len(waypoints)
        # The path parameter is expressed in seconds of nominal policy time
        horizon = num_waypoints / self.policy_freq

        path = SplineInterpolator(
            np.linspace(0, horizon, num_waypoints + 1, endpoint=True),  # timestamps
            np.concatenate((start[None, :], waypoints), axis=0),  # waypoints
            bc_type=((1, start_vel), "natural"),
        )

        constraints = []
        if self.vel_limits is not None:
            vlim = np.vstack([-self.vel_limits, self.vel_limits]).T
            constraints.append(JointVelocityConstraint(vlim))
        if self.accel_limits is not None:
            alim = np.vstack([-self.accel_limits, self.accel_limits]).T
            constraints.append(JointAccelerationConstraint(alim))
        if inv_dyn is not None and self.torque_limits is not None:
            tau_lim = np.vstack(self.torque_limits).T
            constraints.append(JointTorqueConstraint(inv_dyn, tau_lim, fs_coef=np.zeros(self.dim)))

        gridpoints = None
        if num_gridpoints_per_waypoint is not None:
            gridpoints = np.linspace(
                0, horizon, num_waypoints * num_gridpoints_per_waypoint + 1, endpoint=True
            )

        return TOPPRA(
            constraints,
            path,
            parametrizer="ParametrizeConstAccel",
            gridpoints=gridpoints,
        )

    def set_goal(self, waypoints, start=None, start_vel=None, inv_dyn=None, execution_length=1, last_vel=0):
        """
        Takes a list of waypoints and computes the time-optimal trajectory.

        On success the trajectory is sampled once per controller period and stored for
        get_interpolated_goal(). If TOPPRA fails, the remainder of the previous plan is reused
        when it still covers @execution_length waypoints; otherwise each waypoint is simply held
        for self.total_steps controller steps with zero velocity and acceleration.

        Args:
            waypoints (array-like): Upcoming waypoints, one row per waypoint
            start (None or array-like): State to start from. Defaults to the previous goal
            start_vel (None or np.ndarray): Velocity at the start of the path. Defaults to zeros
            inv_dyn (None or callable): Inverse dynamics function enabling the torque constraint
            execution_length (int): Number of waypoints that will be executed before replanning
            last_vel (float): Path velocity requested at the end of the trajectory

        Returns:
            5-tuple:

                - (bool) True if TOPPRA computed a new trajectory
                - (bool) True if a TOPPRA plan (new or previous) is being followed, False for the
                    hold-waypoint fallback
                - (int) Index into the stored samples associated with waypoint @execution_length
                - (TOPPRA) The TOPPRA instance that was built for this call
                - (None or trajectory) The TOPPRA trajectory, None if TOPPRA failed
        """
        logger.info(f"TOPPRA received {len(waypoints)} waypoints. Starting trajectory computation.")
        start_time = time.time()
        logger.debug(f"TOPPRA set_goal received with {len(waypoints)} waypoints.")
        self.start = np.array(self.goal) if start is None else np.array(start)
        self.goal = waypoints[execution_length - 1]  # replan after reaching first waypoint
        if start_vel is None:
            start_vel = np.zeros(self.dim)

        instance = self.toppra_instance(waypoints, self.start, start_vel, inv_dyn)
        traj = instance.compute_trajectory(1, last_vel, sd_end_min=0)

        if traj is None:
            logger.warning("TOPPRA failed to compute a trajectory.")
            logger.info("Trying to use previous trajectory plan as fallback...")

            enough_waypoints = len(self.future_waypoints) >= execution_length
            # The previous plan may already be fully consumed; slicing it would leave nothing
            # to execute (and nothing to search for the target waypoint in).
            plan_left = enough_waypoints and len(self.qs) > self.step

            if not plan_left:
                if enough_waypoints:
                    logger.warning("Previous trajectory plan is exhausted, using default PD control.")
                else:
                    logger.warning("Not enough future waypoints, using default PD control.")
                self.qs = np.repeat(waypoints, self.total_steps, axis=0)
                self.qds = np.zeros_like(self.qs)
                self.qdds = np.zeros_like(self.qs)

                # NOTE(review): this is a sample count, whereas the other return paths
                # yield an index of an existing sample -- confirm how callers use it
                execution_end_idx = self.total_steps * execution_length
                self.future_waypoints = []
                self.step = 0

                logger.info(f"TOPPRA set_goal (PD fallback) took {time.time() - start_time:.4f} seconds.")
                return False, False, execution_end_idx, instance, None

            # use previous trajectory plan from TOPPRA
            logger.warning("Falling back to using previous trajectory plan.")
            self.qs = self.qs[self.step:]
            self.qds = self.qds[self.step:]
            self.qdds = self.qdds[self.step:]

            # Remaining sample closest to the waypoint we are expected to reach next
            target = self.future_waypoints[execution_length - 1]
            execution_end_idx = int(np.argmin(np.linalg.norm(self.qs - target, axis=1)))
            self.future_waypoints = self.future_waypoints[execution_length:]
            self.step = 0

            logger.info(f"TOPPRA set_goal (previous plan fallback) took {time.time() - start_time:.4f} seconds.")
            return False, True, execution_end_idx, instance, None

        logger.info(f"TOPPRA trajectory computed successfully. Duration: {traj.duration:.4f}s")
        # One sample per controller period. N steps need N + 1 sample times (including t = 0),
        # otherwise the spacing is stretched to N / (N - 1) periods and the plan is played back
        # faster than TOPPRA timed it. At least one step is kept so the plan is never empty.
        num_steps = max(1, int(np.ceil(traj.duration * self.controller_freq)))
        ts_samples = np.arange(num_steps + 1) / self.controller_freq
        # The last period may end after the trajectory does; evaluate the end point instead
        # of extrapolating beyond the trajectory
        ts_samples = np.minimum(ts_samples, traj.duration)
        self.qs = traj(ts_samples, 0)[1:]  # 0th point is current state
        self.qds = traj(ts_samples, 1)[1:]
        self.qdds = traj(ts_samples, 2)[1:]

        # determine the time corresponding to the target waypoint using TOPPRA's s->t mapping
        s_grid = instance.problem_data.gridpoints
        t_grid = traj._ts
        # use the path's own waypoint parameter value
        ss_waypoints, _ = instance.path.waypoints
        s_target = ss_waypoints[execution_length]
        # interpolate time for the target path position
        t_target = np.interp(s_target, s_grid, t_grid)
        # find the first controller sample at or after t_target, aligned with qs which uses ts_samples[1:]
        ts_samples_q = ts_samples[1:]  # 0th point is current state
        execution_end_idx = int(np.searchsorted(ts_samples_q, t_target, side='left'))
        execution_end_idx = min(execution_end_idx, len(ts_samples_q) - 1)

        logger.debug(f"  -> Trajectory has {len(self.qs)} steps. Execution end index: {execution_end_idx}")

        self.step = 0
        self.future_waypoints = waypoints[execution_length:]

        logger.info(f"TOPPRA set_goal (success) took {time.time() - start_time:.4f} seconds.")
        return True, True, execution_end_idx, instance, traj

    def get_interpolated_goal(self):
        """
        Provides the next step in interpolation.

        Returns:
            3-tuple: position, velocity and acceleration of the next stored sample

        Raises:
            IndexError: If more steps are requested than the stored plan contains
        """
        idx = self.step
        sample = (self.qs[idx], self.qds[idx], self.qdds[idx])
        self.step = idx + 1
        return sample

    def get_interpolated_trajectory(self):
        """
        Provides the full stored plan.

        Returns:
            3-tuple: positions, velocities and accelerations of all stored samples
        """
        return self.qs, self.qds, self.qdds
    
def visualize_problem(path, waypoints):
    """Visualize the path and waypoints when TOPPRA fails.

    Plots every dimension of the geometric path against the path parameter, marks the
    waypoints on top and writes the figure to "toppra_problem_failed.png" in the current
    working directory.

    Args:
        path: Geometric path; must be callable on a path position and expose `path_interval`
        waypoints (np.ndarray): Waypoints, one row per waypoint. They are drawn evenly spaced
            over the path interval
    """
    import matplotlib.pyplot as plt

    s_begin, s_end = path.path_interval[0], path.path_interval[1]

    # Sample points along the path since it's a SplineInterpolator
    s_grid = np.linspace(s_begin, s_end, len(waypoints) * 50)
    path_points = np.array([path(s) for s in s_grid])
    waypoint_s = np.linspace(s_begin, s_end, len(waypoints))

    # One subplot per joint; squeeze=False keeps `axes` two-dimensional even for a single joint
    num_joints = path_points.shape[1]
    fig, axes = plt.subplots(num_joints, 1, figsize=(10, 3 * num_joints), squeeze=False)
    try:
        fig.suptitle("TOPPRA Failed - Joint Trajectories Over Time")

        # Plot each joint trajectory
        for i, ax in enumerate(axes[:, 0]):
            ax.plot(s_grid, path_points[:, i], label="Path")
            ax.plot(waypoint_s, waypoints[:, i], "x", label="Waypoints", markersize=10)
            ax.set_xlabel("Normalized Time")
            ax.set_ylabel(f"Joint {i+1} Position")
            ax.grid(True)
            ax.legend()

        fig.tight_layout()
        fig.savefig("toppra_problem_failed.png")
    finally:
        # Close exactly the figure created here, even if plotting failed, and leave any
        # figure the caller has open untouched
        plt.close(fig)

# Number of figures written by inspect() so far; used to give every output file its own name
num_inspect = 0


def inspect(instance, path, waypoints, traj):
    """Inspect the problem internal data.

    Draws a 4x2 grid of diagnostic plots and saves it as "kinematics_inspect_<k>.png" in the
    current working directory, where <k> is the module-level counter `num_inspect`. The
    counter is advanced after each saved figure so that successive calls do not overwrite
    each other.

    Left column: the time-parameterized trajectory (position, velocity, acceleration) and
    the path velocity / acceleration over time. Right column: the geometric path and its
    first two derivatives over the path parameter, and the feasible / controllable sets.

    Args:
        instance: TOPPRA instance the trajectory was computed with
        path: Geometric path that was parameterized
        waypoints (np.ndarray): Waypoints, one row per waypoint
        traj: Time-parameterized trajectory returned by TOPPRA
    """
    global num_inspect
    import matplotlib.pyplot as plt

    fig, axs = plt.subplots(4, 2, figsize=(12, 10), sharex=False)
    try:
        fig.suptitle("TOPPRA Trajectory Analysis", fontsize=16)

        # --- Time-parameterized Trajectory ---
        ts_sample = np.linspace(0, traj.duration, 100)
        qs_sample = traj(ts_sample)
        qds_sample = traj(ts_sample, 1)
        qdds_sample = traj(ts_sample, 2)

        # Each waypoint is drawn at the time of the trajectory sample closest to it
        waypoint_timesteps = [
            ts_sample[np.argmin(np.linalg.norm(qs_sample - wp, axis=1))] for wp in waypoints
        ]

        for i in range(path.dof):
            color = f"C{i:d}"
            axs[0, 0].plot(ts_sample, qs_sample[:, i], c=color)
            axs[0, 0].plot(waypoint_timesteps, waypoints[:, i], "x", c=color)
            axs[1, 0].plot(ts_sample, qds_sample[:, i], c=color)
            axs[2, 0].plot(ts_sample, qdds_sample[:, i], c=color)
        axs[2, 0].set_xlabel("Time (s)")
        axs[0, 0].set_ylabel("Position (rad)")
        axs[1, 0].set_ylabel("Velocity (rad/s)")
        axs[2, 0].set_ylabel("Acceleration (rad/s^2)")
        axs[0, 0].set_title("Time-Parameterized Trajectory")

        sdd_vec, sd_vec, _ = instance.compute_parameterization(
            sd_start=instance.problem_data.sd_vec[0],
            sd_end=instance.problem_data.sd_vec[-1],
            sd_end_min=0)

        t_grid = traj._ts
        # Path accelerations are plotted at the midpoint of each grid interval
        t_mid = 0.5 * (t_grid[1:] + t_grid[:-1])
        axs[3, 0].plot(t_grid, sd_vec, c='C0', label=r"Path vel $\dot{s}$")
        axs[3, 0].plot(t_mid, sdd_vec, c='C1', label=r"Path accel $\ddot{s}$")
        axs[3, 0].legend()
        axs[3, 0].set_xlabel("Time (s)")
        axs[3, 0].set_ylabel("Path vel/accel")

        # --- Geometric Path (SplineInterpolator) ---
        s_begin, s_end = path.path_interval[0], path.path_interval[1]
        s_sample = np.linspace(s_begin, s_end, 100)
        ps_sample = path(s_sample)
        pds_sample = path(s_sample, 1)
        pdds_sample = path(s_sample, 2)

        # Waypoints are drawn evenly spaced over the path interval
        ss = np.linspace(s_begin, s_end, len(waypoints), endpoint=True)

        for i in range(path.dof):
            color = f"C{i:d}"
            axs[0, 1].plot(s_sample, ps_sample[:, i], c=color)
            axs[0, 1].plot(ss, waypoints[:, i], "x", c=color)
            axs[1, 1].plot(s_sample, pds_sample[:, i], c=color)
            axs[2, 1].plot(s_sample, pdds_sample[:, i], c=color)

        axs[2, 1].set_xlabel("Path Parameter s")
        axs[0, 1].set_ylabel("Position (rad)")
        axs[1, 1].set_ylabel("First Derivative (rad/s)")
        axs[2, 1].set_ylabel("Second Derivative (rad/s^2)")
        axs[0, 1].set_title("Geometric Path (Spline)")

        # --- Path-position path-velocity plot ---
        instance.compute_feasible_sets()
        problem_data = instance.problem_data
        K = problem_data.K
        X = problem_data.X
        ax = axs[3, 1]
        if X is not None:
            ax.plot(X[:, 0], c="green", label="Feasible sets")
            ax.plot(X[:, 1], c="green")
        if K is not None:
            ax.plot(K[:, 0], "--", c="red", label="Controllable sets")
            ax.plot(K[:, 1], "--", c="red")
        if problem_data.sd_vec is not None:
            ax.plot(problem_data.sd_vec ** 2, label="Velocity profile")
        ax.set_title("Path-position path-velocity plot")
        ax.set_xlabel("Path position")
        ax.set_ylabel("Path velocity square")
        ax.legend()

        fig.tight_layout(rect=[0, 0.03, 1, 0.95])
        fig.savefig(f"kinematics_inspect_{num_inspect}.png")
        # Advance the counter so the next call writes to a new file
        num_inspect += 1
    finally:
        # Close exactly the figure created here, even if plotting failed, and leave any
        # figure the caller has open untouched
        plt.close(fig)