import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from threading import Lock
from typing import Callable, Optional

import numpy as np
import torch

from furniture_bench.async_utils.async_ws_server import AsyncWebsocketServer
from robosuite.controllers.interpolators.toppra_interpolator import TOPPRAInterpolator

# Module-level logger. It uses the fixed name "toppra_server" (not __name__),
# so logging configuration must target that name to filter this server's output.
logger = logging.getLogger("toppra_server")


class ToppraServer(AsyncWebsocketServer):
    """Websocket server that replans TOPPRA trajectories on client messages.

    Two message kinds are handled by :meth:`on_message_received`:

    * ``{"state_update": {...}}`` -- caches ``joint_pos`` / ``joint_vel`` and,
      when ``replan`` is truthy, starts a *low-priority* replan of the
      remaining waypoints.
    * ``{"replan_request": {...}}`` -- starts a *high-priority* replan towards
      the supplied ``waypoints``.

    At most one replan task is tracked at a time. A low-priority replan is
    dropped while any replan is running. A high-priority one cancels a running
    low-priority task, or waits for a running high-priority one. Results are
    sent with ``self.broadcast``.
    """

    def __init__(
        self,
        interpolator: TOPPRAInterpolator,
        host: str = "0.0.0.0",
        port: int = 8767,
        inv_dyn: Callable = None,
        toppra_last_vel: np.ndarray = None,
    ):
        """
        Args:
            interpolator: Interpolator shared by every replan; must expose
                ``controller_freq`` and ``policy_freq``.
            host: Interface passed to the base websocket server.
            port: Port passed to the base websocket server.
            inv_dyn: Forwarded to ``interpolator.set_goal`` as ``inv_dyn``.
            toppra_last_vel: Forwarded to ``interpolator.set_goal`` as ``last_vel``.
        """
        super().__init__(host, port)
        self.interpolator = interpolator
        self.controller_freq = interpolator.controller_freq

        self.inv_dyn = inv_dyn
        self.toppra_last_vel = toppra_last_vel

        # State variables
        self._current_joint_pos = None
        self._current_joint_vel = None
        self._current_waypoints = None
        self._current_traj = None
        self._current_instance = None
        self._state_lock = Lock()
        self._is_replanning = False  # DEPRECATED, but kept for now to avoid breaking things.
        self._current_replan_task: Optional[asyncio.Task] = None
        self._current_task_is_high_prio: bool = False

        # Cancelling an asyncio task does not stop a job that is already running
        # in an executor thread, so a cancelled replan may still be inside
        # ``set_goal`` when the next replan starts. A single dedicated worker
        # serialises every access to the shared, stateful interpolator.
        self._interp_executor = ThreadPoolExecutor(
            max_workers=1, thread_name_prefix="toppra_replan"
        )

        # Replan request
        self._replan_request_waypoints = None
        self._replan_lock = Lock()

    async def on_message_received(self, websocket, message):
        """Dispatch one client message (see the class docstring for the kinds).

        ``state_update`` messages always refresh the cached joint state, even
        while a replan is running; only the replan they may request is subject
        to prioritisation.
        """
        is_high_prio_request = "replan_request" in message

        if "state_update" in message:
            state = message["state_update"]
            with self._state_lock:
                self._current_joint_pos = state.get("joint_pos", self._current_joint_pos)
                self._current_joint_vel = state.get("joint_vel", self._current_joint_vel)

            if not state.get("replan", False):
                return
            if not await self._wait_for_replan_slot(is_high_prio_request):
                return

            # Snapshot after waiting so the replan starts from the freshest state.
            with self._state_lock:
                start_pos = self._current_joint_pos
                start_vel = self._current_joint_vel
                current_waypoints = self._current_waypoints

            # The whole state update is the context echoed back to the client.
            self._start_replan_task(
                False, state, start_pos, start_vel, current_waypoints, None
            )

        elif "replan_request" in message:
            if not await self._wait_for_replan_slot(True):
                return

            replan_data = message["replan_request"]
            with self._replan_lock:
                new_waypoints_request = replan_data["waypoints"]

            with self._state_lock:
                current_waypoints = self._current_waypoints
                if "joint_pos" in replan_data and "joint_vel" in replan_data:
                    start_pos, start_vel = replan_data["joint_pos"], replan_data["joint_vel"]
                else:
                    start_pos, start_vel = self._current_joint_pos, self._current_joint_vel

            # The whole request is the context echoed back to the client.
            self._start_replan_task(
                True, replan_data, start_pos, start_vel, current_waypoints, new_waypoints_request
            )

    async def _wait_for_replan_slot(self, is_high_prio: bool) -> bool:
        """Return True once no replan task is running, or False if this request is dropped.

        Low-priority requests are dropped while a task runs. High-priority
        requests cancel a running low-priority task or wait for a high-priority
        one. The slot is re-checked after every wait, because another handler
        may have registered a new task in the meantime.
        """
        while True:
            task = self._current_replan_task
            if task is None or task.done():
                return True
            if not is_high_prio:
                logger.info("Server is busy. Dropping new low-priority request.")
                return False
            if self._current_task_is_high_prio:
                logger.info("High-priority request received while another is in progress. Waiting.")
            else:
                logger.info("High-priority request received. Cancelling current low-priority task.")
                task.cancel()
            # asyncio.wait neither re-raises the task's own exception/cancellation
            # nor swallows a cancellation of this handler.
            await asyncio.wait({task})

    def _start_replan_task(self, high_prio, context, start_pos, start_vel, current_waypoints, new_waypoints_request):
        """Launch ``run_replan_and_broadcast`` as the tracked background replan task."""
        self._current_task_is_high_prio = high_prio
        self._current_replan_task = asyncio.create_task(
            self.run_replan_and_broadcast(
                context, start_pos, start_vel, current_waypoints, new_waypoints_request
            )
        )

    async def server_task(self):
        """This task is no longer needed as progress is calculated client-side."""
        pass

    async def run_replan_and_broadcast(self, context, start_pos, start_vel, current_waypoints, new_waypoints_request):
        """Plan through the remaining waypoints and broadcast the result.

        Args:
            context: Request dict echoed back in the response. Its
                ``completed_in_segment`` (default 0) drops already-reached
                entries of ``current_waypoints``.
            start_pos: Start joint position; the replan is skipped when None.
            start_vel: Start joint velocity passed to ``set_goal``.
            current_waypoints: Waypoints of the last broadcast plan, or None.
            new_waypoints_request: Replacement waypoints from a
                ``replan_request``, or None for a ``state_update`` replan.

        Nothing is broadcast when planning fails *and* the interpolator kept
        its previous plan. Errors other than cancellation are logged rather
        than raised, since this coroutine runs as an un-awaited background task.
        """
        completed_in_segment_at_replan = context.get("completed_in_segment", 0)
        replan_id = context.get("replan_id", -1)

        try:
            if start_pos is None:
                logger.debug("Skipping replan: No valid robot state yet.")
                return

            # Determine the waypoints that still have to be reached.
            is_new_request = new_waypoints_request is not None
            if is_new_request:
                remaining_waypoints = new_waypoints_request
            elif current_waypoints is not None and completed_in_segment_at_replan > 0:
                remaining_waypoints = current_waypoints[completed_in_segment_at_replan:]
            else:
                remaining_waypoints = current_waypoints

            def plan_on_worker():
                # All interpolator access for this replan happens here, on the
                # single interpolator worker, so replans never overlap.
                if is_new_request:
                    self.interpolator.future_waypoints = []
                elif current_waypoints is not None:
                    # Update the interpolator's future_waypoints for fallback mechanism.
                    self.interpolator.future_waypoints = remaining_waypoints
                if remaining_waypoints is None or len(remaining_waypoints) == 0:
                    return None
                success, used_previous_plan, _, instance, traj = self.interpolator.set_goal(
                    remaining_waypoints,
                    start=start_pos,
                    start_vel=start_vel,
                    inv_dyn=self.inv_dyn,
                    last_vel=self.toppra_last_vel,
                )
                if not success and used_previous_plan:
                    # Planning failed and the previous plan was kept: nothing to send.
                    return success, instance, traj, None, None
                interpolated = self.interpolator.get_interpolated_trajectory()
                cs_metrics = self._controllable_set_metrics(instance) if is_new_request else None
                return success, instance, traj, interpolated, cs_metrics

            result = await asyncio.get_running_loop().run_in_executor(
                self._interp_executor, plan_on_worker
            )
            if result is None:
                logger.debug("Skipping replan: No remaining waypoints.")
                return
            success, instance, traj, interpolated, cs_metrics = result
            if interpolated is None:
                return

            qs, qds, qdds = interpolated
            response = {
                "trajectory": {"qs": qs.tolist(), "qds": qds.tolist(), "qdds": qdds.tolist()},
                # np.asarray: waypoints may arrive as plain lists rather than arrays.
                "waypoints": np.asarray(remaining_waypoints).tolist(),
                **context,  # Echo back all context fields (replan_id, completed_in_segment, state snapshot)
            }

            if is_new_request:
                response["request_type"] = "replan_request"
                (
                    response["controllable_set_size"],
                    response["initial_controllable_set_size"],
                ) = cs_metrics
            else:
                response["request_type"] = "state_update"

            if success and traj is not None and instance is not None:
                response["pd_fallback"] = False
                # Waypoint timesteps for gripper synchronization. ss_waypoints[0]
                # is the start position (t=0), so only the goal waypoints are timed.
                ss_waypoints, _ = instance.path.waypoints
                response["waypoint_timesteps"] = np.interp(
                    ss_waypoints[1:], instance.problem_data.gridpoints, traj._ts
                ).tolist()
            else:
                response["pd_fallback"] = True
                logger.warning("TOPPRA failed, using PD fallback. Sending estimated timestamps.")
                # Estimate one waypoint per policy period.
                time_per_waypoint = 1.0 / self.interpolator.policy_freq
                response["waypoint_timesteps"] = [
                    (i + 1) * time_per_waypoint for i in range(len(remaining_waypoints))
                ]

            await self.broadcast(response)

            with self._state_lock:
                self._current_waypoints = remaining_waypoints
        except asyncio.CancelledError:
            logger.info("Replanning task was cancelled.")
            raise
        except Exception:
            logger.exception("Replan %s failed.", replan_id)
        finally:
            # Release the slot only if it still belongs to this task; a newer
            # task may already have been registered by another handler.
            if self._current_replan_task is asyncio.current_task():
                self._current_replan_task = None
                self._current_task_is_high_prio = False

    @staticmethod
    def _controllable_set_metrics(instance):
        """Return ``(mean controllable-set width, controllable-set width at index 0)``.

        The mean width is the trapezoidal integral of ``K[:, 1] - K[:, 0]``
        over ``instance.problem_data.gridpoints``, divided by the last
        gridpoint. Both values are 0.0 when ``instance`` is None or
        ``compute_controllable_sets`` returns None.
        """
        if instance is None:
            return 0.0, 0.0
        K = instance.compute_controllable_sets(0, 1)
        if K is None:
            return 0.0, 0.0
        gridpoints = instance.problem_data.gridpoints
        # np.trapz is deprecated in NumPy 2.x in favour of np.trapezoid.
        trapezoid = getattr(np, "trapezoid", None) or np.trapz
        mean_width = trapezoid(K[:, 1] - K[:, 0], gridpoints) / gridpoints[-1]
        return mean_width, K[0, 1] - K[0, 0]

    def _get_completed_waypoint_count(self, current_step, traj, instance):
        """Estimate how many goal waypoints were passed after ``current_step`` control steps.

        Debug/estimation helper only; progress is tracked client-side. With a
        TOPPRA plan, it counts goal waypoints whose path position lies strictly
        before the current one, or all of them once the trajectory duration
        has elapsed. Without a plan, it assumes one waypoint per policy period
        (PD fallback timing).
        """
        if traj is not None and instance is not None:
            ss_waypoints, _ = instance.path.waypoints
            t_current = current_step / self.controller_freq
            if t_current >= traj.duration:
                # Trajectory finished: every goal waypoint (all but the start) is done.
                return len(ss_waypoints) - 1
            s_current = np.interp(t_current, traj._ts, instance.problem_data.gridpoints)
            return int(np.sum(ss_waypoints[1:] < s_current))
        if current_step > 0:
            # Estimate progress during PD fallback.
            time_per_waypoint = 1.0 / self.interpolator.policy_freq
            time_elapsed = current_step / self.controller_freq
            return int(time_elapsed / time_per_waypoint)
        return 0
