# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the NVIDIA Source Code License [see LICENSE for details].

from enum import Enum
from typing import List

from pyrep.objects.object import Object
from rlbench.backend.robot import Robot
from rlbench.backend.task import Task
from rlbench.backend.waypoints import Waypoint

from failgen.fail_instance import IFailure
import copy


class NoRotationState(Enum):
    """Whether the no-rotation failure has been injected in the current episode."""

    IDLE = 0  # not injected yet; the next enabled on_start will apply it
    APPLIED = 1  # waypoints already rewritten; on_reset moves back to IDLE


class NoRotationFailure(IFailure):
    """Failure where the gripper reaches a waypoint without rotating to it.

    When enabled, ``on_start`` rewrites the task's waypoint list, replacing
    the failing waypoint (named by ``self._waypoint_fail_name``, whose
    trailing digits give its index ``N``) with three copies of it:

    1. ``error_waypoint``: the original target position, but with the
       orientation of waypoint ``N - 1`` (the missed rotation). Its extension
       string gets ``"error"`` appended.
    2. ``aft_error_waypoint``: moved to the pose of waypoint ``N + 1``, or to
       the pose of waypoint ``N - 1`` when ``N`` is the last waypoint.
    3. ``recover_waypoint``: an unmodified copy of the original waypoint.

    The remaining waypoints (``N + 1`` onward) follow unchanged. The copies
    are backed by new scene objects, which ``on_reset`` removes again.
    """

    FAILURE_TYPE = "no_rotation"

    def __init__(
        self,
        robot: Robot,
        name: str,
        waypoints_indices: List[int],
    ):
        super().__init__(
            robot=robot, name=name, waypoints_indices=waypoints_indices
        )

        self._failure_type = NoRotationFailure.FAILURE_TYPE
        self._state = NoRotationState.IDLE
        # Task whose waypoints were rewritten by the last on_start.
        self.task = None
        # Waypoint copies inserted by on_start; None until then and after reset.
        self.error_waypoint = None
        self.aft_error_waypoint = None
        self.recover_waypoint = None

    @staticmethod
    def _waypoint_index_from_name(name: str) -> int:
        """Return the trailing integer of a waypoint name (``waypoint12`` -> 12).

        Raises:
            ValueError: if ``name`` does not end in at least one digit.
        """
        digits = name[len(name.rstrip("0123456789")):]
        if not digits:
            raise ValueError(
                "NoRotationFailure::on_start >> waypoint name "
                f"{name!r} has no numeric index"
            )
        return int(digits)

    @staticmethod
    def _clone_waypoint(waypoint: Waypoint) -> Waypoint:
        """Return a copy of ``waypoint`` that owns a new copy of its scene dummy."""
        clone = copy.deepcopy(waypoint)
        # Give the clone its own scene object so that re-posing it does not move
        # the original waypoint (presumably a deep-copied wrapper still points at
        # the same simulator handle -- confirm against PyRep).
        clone._waypoint = clone._waypoint.copy()
        return clone

    def on_start(self, task: Task, sub_tasks=None) -> None:
        """Inject the failure by rewriting ``task._waypoints`` (once per episode).

        Args:
            task: Task whose waypoints have already been loaded.
            sub_tasks: Unused here.

        Raises:
            AssertionError: If ``task._waypoints`` is None.
            ValueError: If the failing waypoint name has no numeric suffix.
            IndexError: If that index is outside the task's waypoint list.
        """
        if not self._enabled:
            return
        assert (
            task._waypoints is not None
        ), "NoRotationFailure::on_start >> Must have waypoints loaded"

        # Already applied this episode; on_reset re-arms the failure.
        if self._state != NoRotationState.IDLE:
            return
        self._state = NoRotationState.APPLIED

        waypoints = task.get_waypoints()
        num_waypoints = len(waypoints)
        fail_idx = self._waypoint_index_from_name(self._waypoint_fail_name)

        # NOTE(review): when fail_idx == 0 this is waypoints[-1] (the LAST
        # waypoint) because of negative indexing; confirm whether waypoint0
        # is ever configured as the failure point.
        prev_waypoint = waypoints[fail_idx - 1]

        self.error_waypoint = self._clone_waypoint(waypoints[fail_idx])
        self.aft_error_waypoint = self._clone_waypoint(waypoints[fail_idx])
        self.recover_waypoint = self._clone_waypoint(waypoints[fail_idx])

        # Failure: reach the target position but keep the previous orientation.
        self.error_waypoint._waypoint.set_orientation(
            prev_waypoint._waypoint.get_orientation()
        )

        # Keep moving after the failed waypoint: to the next waypoint's pose,
        # or back to the previous one if the failing waypoint is the last.
        if fail_idx >= num_waypoints - 1:
            follow_pose = prev_waypoint._waypoint.get_pose()
        else:
            follow_pose = waypoints[fail_idx + 1]._waypoint.get_pose()
        self.aft_error_waypoint._waypoint.set_pose(follow_pose)

        self.error_waypoint._ext += "error"

        # recover_waypoint keeps the original (correct) pose.
        new_waypoints = list(waypoints[:fail_idx])
        new_waypoints += [
            self.error_waypoint,
            self.aft_error_waypoint,
            self.recover_waypoint,
        ]
        new_waypoints += waypoints[fail_idx + 1:]
        task._waypoints = new_waypoints
        self.task = task

    def on_reset(self) -> None:
        """Re-arm the failure and delete the scene dummies created by on_start."""
        self._state = NoRotationState.IDLE
        # Check each clone separately so a partially completed on_start
        # (e.g. an exception between clones) can still be cleaned up.
        for waypoint in (
            self.error_waypoint,
            self.aft_error_waypoint,
            self.recover_waypoint,
        ):
            if waypoint is not None:
                waypoint._waypoint.remove()
        self.error_waypoint = None
        self.aft_error_waypoint = None
        self.recover_waypoint = None

    def on_step(self) -> None:
        # Do nothing for now
        ...

    def on_waypoint(self, point: Waypoint) -> None:
        """No-op: the failure is injected up-front in ``on_start``.

        An earlier version applied the failure here instead. When the failing
        waypoint was reached, it was rotated to the orientation of
        ``waypoint<N-1>``. ``on_start`` now inserts dedicated
        error/recovery waypoints, so nothing is needed per waypoint.
        """
