# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the NVIDIA Source Code License [see LICENSE for details].

from enum import Enum
from typing import List, Tuple

import numpy as np
from rlbench.backend.robot import Robot
from rlbench.backend.task import Task
from rlbench.backend.waypoints import Waypoint

from failgen.fail_instance import IFailure
import copy


class RotationState(Enum):
    """Lifecycle of a RotationFailure between resets."""

    # Perturbation not yet injected; on_start will apply it.
    IDLE = 0
    # Perturbation already injected; on_start is a no-op until on_reset.
    APPLIED = 1


class RotationFailure(IFailure):
    """Inject a rotation error at a single waypoint of a task.

    When enabled, :meth:`on_start` replaces the failing waypoint in the task's
    waypoint list with three copies of it:

    1. an "error" waypoint whose orientation component along ``rotation_axis``
       is offset by a random value drawn uniformly from ``rotation_range``;
    2. an "after error" waypoint moved to a neighbouring waypoint's pose;
    3. a "recover" waypoint that keeps the original pose.
    """

    FAILURE_TYPE = "rotation"

    # Position of each supported axis name in the orientation triple returned
    # by ``get_orientation()``.
    _AXIS_INDEX = {"x": 0, "y": 1, "z": 2}

    def __init__(
        self,
        robot: Robot,
        name: str,
        waypoints_indices: List[int],
        rotation_axis: str = "x",
        rotation_range: Tuple[float, float] = (-0.1 * np.pi, 0.1 * np.pi),
    ):
        """Create the failure.

        Args:
            robot: Robot the failure is attached to (forwarded to IFailure).
            name: Name of this failure instance (forwarded to IFailure).
            waypoints_indices: Waypoint indices (forwarded to IFailure).
            rotation_axis: Orientation component to perturb: "x", "y" or "z".
                Any other value leaves the orientation unchanged.
            rotation_range: ``(low, high)`` bounds of the uniform random offset
                added to that component. Presumably radians, following the
                PyRep ``get_orientation`` convention; confirm.
        """
        super().__init__(
            robot=robot, name=name, waypoints_indices=waypoints_indices
        )

        self._failure_type = RotationFailure.FAILURE_TYPE
        self._axis = rotation_axis
        self._range = rotation_range
        self._state = RotationState.IDLE
        # Task whose waypoints were last rewritten by on_start.
        self.task = None
        # Waypoint copies injected by on_start; removed again in on_reset.
        self.error_waypoint = None
        self.aft_error_waypoint = None
        self.recover_waypoint = None

    @staticmethod
    def _waypoint_index_from_name(waypoint_name: str) -> int:
        """Return the trailing integer of a waypoint name ("waypoint12" -> 12).

        Raises:
            ValueError: If the name does not end with a digit.
        """
        stem = waypoint_name.rstrip("0123456789")
        digits = waypoint_name[len(stem):]
        if not digits:
            raise ValueError(
                "RotationFailure::on_start >> cannot parse waypoint index "
                f"from name {waypoint_name!r}"
            )
        return int(digits)

    @staticmethod
    def _clone_waypoint(waypoint: Waypoint) -> Waypoint:
        """Return a deep copy of ``waypoint`` backed by its own scene object."""
        clone = copy.deepcopy(waypoint)
        # The deep-copied wrapper presumably still refers to the same scene
        # object, so duplicate that object too; the copy can then be posed
        # independently and removed in on_reset.
        clone._waypoint = clone._waypoint.copy()
        return clone

    def on_start(self, task: Task, sub_tasks=None) -> None:
        """Rewrite ``task``'s waypoint list to inject the rotation error.

        Does nothing when the failure is disabled, or when it has already been
        applied since the last :meth:`on_reset`. Otherwise the waypoint named
        by ``self._waypoint_fail_name`` is replaced in ``task._waypoints`` by
        the error, after-error and recover waypoints, in that order. The
        waypoints before and after it are kept unchanged.

        Args:
            task: Task whose ``_waypoints`` list is replaced.
            sub_tasks: Unused; accepted for interface compatibility.

        Raises:
            AssertionError: If the task has no waypoints loaded.
            ValueError: If the failing waypoint name has no trailing index.
            IndexError: If that index is outside the task's waypoint list.
        """
        if not self._enabled:
            return
        assert (
            task._waypoints is not None
        ), "RotationFailure::on_start >> Must have waypoints loaded"

        if self._state != RotationState.IDLE:
            return
        self._state = RotationState.APPLIED

        waypoints = task.get_waypoints()
        fail_index = self._waypoint_index_from_name(self._waypoint_fail_name)
        num_waypoints = len(waypoints)
        if fail_index >= num_waypoints:
            raise IndexError(
                f"RotationFailure::on_start >> waypoint index {fail_index} "
                f"out of range for {num_waypoints} waypoints"
            )

        original = waypoints[fail_index]
        self.error_waypoint = self._clone_waypoint(original)
        self.aft_error_waypoint = self._clone_waypoint(original)
        self.recover_waypoint = self._clone_waypoint(original)

        # Offset one orientation component of the error waypoint by a uniform
        # random value. The draw happens even for an unknown axis, which keeps
        # the random stream identical regardless of configuration.
        delta = np.random.uniform(self._range[0], self._range[1])
        orientation = self.error_waypoint._waypoint.get_orientation()
        axis_index = self._AXIS_INDEX.get(self._axis)
        if axis_index is not None:
            orientation[axis_index] += delta
        self.error_waypoint._waypoint.set_orientation(orientation)

        # Move the after-error waypoint to the next waypoint's pose. When the
        # failing waypoint is the last one, use the previous waypoint instead.
        if fail_index >= num_waypoints - 1:
            neighbour = waypoints[fail_index - 1]
        else:
            neighbour = waypoints[fail_index + 1]
        self.aft_error_waypoint._waypoint.set_pose(neighbour._waypoint.get_pose())

        # Strip an "open_gripper()" extension from the error waypoint,
        # presumably so the gripper is not opened at the erroneous pose, then
        # tag it as the injected error.
        if "open_gripper()" in self.error_waypoint.get_ext():
            self.error_waypoint.clear_ext()
        self.error_waypoint._ext += "error"

        task._waypoints = [
            *waypoints[:fail_index],
            self.error_waypoint,
            self.aft_error_waypoint,
            self.recover_waypoint,
            *waypoints[fail_index + 1:],
        ]
        self.task = task

    def on_reset(self) -> None:
        """Re-arm the failure and delete the scene copies made by on_start.

        Note: this does not restore ``task._waypoints``.
        """
        self._state = RotationState.IDLE
        if self.error_waypoint is not None:
            self.error_waypoint._waypoint.remove()
            self.aft_error_waypoint._waypoint.remove()
            self.recover_waypoint._waypoint.remove()
            self.error_waypoint = None
            self.aft_error_waypoint = None
            self.recover_waypoint = None

    def on_step(self) -> None:
        """No per-step behaviour; the failure is injected in on_start."""
        ...

    def on_waypoint(self, point: Waypoint) -> None:
        """No per-waypoint behaviour; the failure is injected in on_start."""
        ...

class RotationXFailure(RotationFailure):
    """RotationFailure whose perturbation is applied to the "x" component."""

    FAILURE_TYPE = "rotation_x"

    def __init__(
        self,
        robot: Robot,
        name: str,
        waypoints_indices: List[int],
        rotation_range: Tuple[float, float] = (-0.1 * np.pi, 0.1 * np.pi),
    ):
        # The axis is fixed by this subclass; everything else passes through.
        super().__init__(robot, name, waypoints_indices, "x", rotation_range)

        # Replace the generic tag assigned by RotationFailure.__init__.
        self._failure_type = RotationXFailure.FAILURE_TYPE


class RotationYFailure(RotationFailure):
    """RotationFailure whose perturbation is applied to the "y" component."""

    FAILURE_TYPE = "rotation_y"

    def __init__(
        self,
        robot: Robot,
        name: str,
        waypoints_indices: List[int],
        rotation_range: Tuple[float, float] = (-0.1 * np.pi, 0.1 * np.pi),
    ):
        # The axis is fixed by this subclass; everything else passes through.
        super().__init__(robot, name, waypoints_indices, "y", rotation_range)

        # Replace the generic tag assigned by RotationFailure.__init__.
        self._failure_type = RotationYFailure.FAILURE_TYPE


class RotationZFailure(RotationFailure):
    """RotationFailure whose perturbation is applied to the "z" component."""

    FAILURE_TYPE = "rotation_z"

    def __init__(
        self,
        robot: Robot,
        name: str,
        waypoints_indices: List[int],
        rotation_range: Tuple[float, float] = (-0.1 * np.pi, 0.1 * np.pi),
    ):
        # The axis is fixed by this subclass; everything else passes through.
        super().__init__(robot, name, waypoints_indices, "z", rotation_range)

        # Replace the generic tag assigned by RotationFailure.__init__.
        self._failure_type = RotationZFailure.FAILURE_TYPE
