# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the NVIDIA Source Code License [see LICENSE for details].

from enum import Enum
from typing import List, Tuple

import numpy as np
from rlbench.backend.robot import Robot
from rlbench.backend.task import Task
from rlbench.backend.waypoints import Waypoint

from failgen.fail_instance import IFailure
import copy

class TranslationState(Enum):
    """Whether a translation failure has been injected in the current episode."""

    # Nothing has been injected yet; the next enabled on_start may apply it.
    IDLE = 0
    # The failure has been applied; on_reset sets the state back to IDLE.
    APPLIED = 1


class TranslationFailure(IFailure):
    """Failure that shifts one task waypoint by a random offset along one axis.

    On the first enabled ``on_start`` after construction or ``on_reset``, the
    waypoint selected by ``self._waypoint_fail_name`` (presumably set by
    ``IFailure`` -- TODO confirm) is replaced in ``task._waypoints`` by three
    copies, in this order:

    1. ``error_waypoint`` -- the original pose translated by a delta drawn from
       ``np.random.uniform(*translation_range)`` along ``translation_axis``;
       its extension string gets ``"error"`` appended.
    2. ``aft_error_waypoint`` -- posed like the next waypoint, or like the
       previous one when the selected waypoint is the last in the list.
    3. ``recover_waypoint`` -- an unmodified copy of the original waypoint.
    """

    FAILURE_TYPE = "translation"

    # Component of the position vector returned by ``get_position()`` that is
    # offset for each supported axis name.
    _AXIS_TO_INDEX = {"x": 0, "y": 1, "z": 2}

    def __init__(
        self,
        robot: Robot,
        name: str,
        waypoints_indices: List[int],
        translation_axis: str = "x",
        translation_range: Tuple[float, float] = (-0.1, 0.1),
    ):
        """Configure the failure.

        Args:
            robot: robot handle, forwarded to ``IFailure``.
            name: failure name, forwarded to ``IFailure``.
            waypoints_indices: forwarded to ``IFailure``; presumably used there
                to choose ``self._waypoint_fail_name`` -- TODO confirm.
            translation_axis: ``"x"``, ``"y"`` or ``"z"``.
            translation_range: ``(low, high)`` passed to ``np.random.uniform``
                to draw the offset, in the units of ``get_position()``.

        Raises:
            ValueError: if ``translation_axis`` is not ``"x"``, ``"y"`` or ``"z"``.
        """
        # Reject unknown axes up front: otherwise on_start would still splice
        # in the extra waypoints but apply no offset at all.
        if translation_axis not in self._AXIS_TO_INDEX:
            raise ValueError(
                f"TranslationFailure: unsupported translation_axis "
                f"{translation_axis!r}; expected one of 'x', 'y', 'z'"
            )
        super().__init__(
            robot=robot, name=name, waypoints_indices=waypoints_indices
        )

        self._failure_type = TranslationFailure.FAILURE_TYPE
        self._axis = translation_axis
        self._range = translation_range
        self._state = TranslationState.IDLE
        # Task most recently modified by on_start.
        self.task = None
        # Waypoint copies created by on_start; their dummies are removed by on_reset.
        self.error_waypoint = None
        self.aft_error_waypoint = None
        self.recover_waypoint = None

    @staticmethod
    def _parse_waypoint_index(waypoint_name: str) -> int:
        """Return the integer formed by the trailing digits of ``waypoint_name``.

        ``"waypoint12"`` yields 12. Reading only the last character would
        yield 2 and perturb the wrong waypoint.

        Raises:
            ValueError: if ``waypoint_name`` does not end in a digit.
        """
        stem = waypoint_name.rstrip("0123456789")
        digits = waypoint_name[len(stem):]
        if not digits:
            raise ValueError(
                f"TranslationFailure: cannot parse a waypoint index from "
                f"{waypoint_name!r}"
            )
        return int(digits)

    @staticmethod
    def _clone_waypoint(waypoint: Waypoint) -> Waypoint:
        """Deep-copy ``waypoint`` and give the copy its own ``_waypoint`` via ``.copy()``.

        ``copy.deepcopy`` alone leaves the clone's ``_waypoint`` referring to the
        same scene object as the original (presumably via a shared handle --
        TODO confirm). Calling ``.copy()`` on it separates the two, so moving
        or removing the clone does not affect the original.
        """
        clone = copy.deepcopy(waypoint)
        clone._waypoint = clone._waypoint.copy()
        return clone

    def on_start(self, task: Task, sub_tasks=None) -> None:
        """Splice the perturbed waypoints into ``task``, once per episode.

        Does nothing when the failure is disabled or has already been applied
        since the last ``on_reset``.

        Args:
            task: task whose ``_waypoints`` list is replaced.
            sub_tasks: unused here.

        Raises:
            ValueError: if the failing waypoint name has no trailing index.
            IndexError: if that index is outside the task's waypoint list.
        """
        if not self._enabled:
            return
        assert (
            task._waypoints is not None
        ), "TranslationFailure::on_start >> Must have waypoints loaded"

        if self._state != TranslationState.IDLE:
            return
        self._state = TranslationState.APPLIED

        waypoints = task.get_waypoints()
        waypoint_num = len(waypoints)
        fail_index = self._parse_waypoint_index(self._waypoint_fail_name)
        if fail_index >= waypoint_num:
            raise IndexError(
                f"TranslationFailure::on_start >> waypoint index {fail_index} "
                f"out of range for {waypoint_num} waypoints"
            )
        original = waypoints[fail_index]

        # Each clone is fully independent before it is stored, so on_reset
        # never removes the original waypoint's dummy, even if a later step raises.
        self.error_waypoint = self._clone_waypoint(original)
        self.aft_error_waypoint = self._clone_waypoint(original)
        self.recover_waypoint = self._clone_waypoint(original)

        # Error waypoint: translate along the configured axis.
        delta = np.random.uniform(self._range[0], self._range[1])
        position = self.error_waypoint._waypoint.get_position()
        position[self._AXIS_TO_INDEX[self._axis]] += delta
        self.error_waypoint._waypoint.set_position(position)

        # After-error waypoint: take the pose of the next waypoint, or of the
        # previous one when the failing waypoint is the last.
        neighbour_index = (
            fail_index - 1 if fail_index == waypoint_num - 1 else fail_index + 1
        )
        self.aft_error_waypoint._waypoint.set_pose(
            waypoints[neighbour_index]._waypoint.get_pose()
        )

        # If the extension requests "open_gripper()", clear it on the error
        # copy; then mark the copy with "error".
        if "open_gripper()" in self.error_waypoint.get_ext():
            self.error_waypoint.clear_ext()
        self.error_waypoint._ext += "error"

        # Replace the failing waypoint with error -> after-error -> recover.
        task._waypoints = [
            *waypoints[:fail_index],
            self.error_waypoint,
            self.aft_error_waypoint,
            self.recover_waypoint,
            *waypoints[fail_index + 1:],
        ]
        self.task = task

    def on_reset(self) -> None:
        """Return to IDLE and remove any waypoint dummies created by on_start."""
        self._state = TranslationState.IDLE
        # Remove whichever copies exist; a partially completed on_start may
        # have left only some of them set.
        for waypoint in (
            self.error_waypoint,
            self.aft_error_waypoint,
            self.recover_waypoint,
        ):
            if waypoint is not None:
                waypoint._waypoint.remove()
        self.error_waypoint = None
        self.aft_error_waypoint = None
        self.recover_waypoint = None

    def on_step(self) -> None:
        """No per-step behaviour for this failure."""
        ...

    def on_waypoint(self, point: Waypoint) -> None:
        """No-op: the translation is applied to copied waypoints in on_start."""
        ...


class TranslationXFailure(TranslationFailure):
    """Translation failure whose offset is always applied along the x axis."""

    FAILURE_TYPE = "translation_x"

    def __init__(
        self,
        robot: Robot,
        name: str,
        waypoints_indices: List[int],
        translation_range: Tuple[float, float] = (-0.1, 0.1),
    ):
        """Build the failure with ``translation_axis`` fixed to ``"x"``.

        Args:
            robot: forwarded unchanged to TranslationFailure.
            name: forwarded unchanged to TranslationFailure.
            waypoints_indices: forwarded unchanged to TranslationFailure.
            translation_range: ``(low, high)`` bounds of the random offset.
        """
        axis = "x"
        super().__init__(
            robot,
            name,
            waypoints_indices,
            translation_range=translation_range,
            translation_axis=axis,
        )
        # Replace the generic type tag set by the parent constructor.
        self._failure_type = TranslationXFailure.FAILURE_TYPE


class TranslationYFailure(TranslationFailure):
    """Translation failure whose offset is always applied along the y axis."""

    FAILURE_TYPE = "translation_y"

    def __init__(
        self,
        robot: Robot,
        name: str,
        waypoints_indices: List[int],
        translation_range: Tuple[float, float] = (-0.1, 0.1),
    ):
        """Build the failure with ``translation_axis`` fixed to ``"y"``.

        Args:
            robot: forwarded unchanged to TranslationFailure.
            name: forwarded unchanged to TranslationFailure.
            waypoints_indices: forwarded unchanged to TranslationFailure.
            translation_range: ``(low, high)`` bounds of the random offset.
        """
        axis = "y"
        super().__init__(
            robot,
            name,
            waypoints_indices,
            translation_range=translation_range,
            translation_axis=axis,
        )
        # Replace the generic type tag set by the parent constructor.
        self._failure_type = TranslationYFailure.FAILURE_TYPE


class TranslationZFailure(TranslationFailure):
    """Translation failure whose offset is always applied along the z axis."""

    FAILURE_TYPE = "translation_z"

    def __init__(
        self,
        robot: Robot,
        name: str,
        waypoints_indices: List[int],
        translation_range: Tuple[float, float] = (-0.1, 0.1),
    ):
        """Build the failure with ``translation_axis`` fixed to ``"z"``.

        Args:
            robot: forwarded unchanged to TranslationFailure.
            name: forwarded unchanged to TranslationFailure.
            waypoints_indices: forwarded unchanged to TranslationFailure.
            translation_range: ``(low, high)`` bounds of the random offset.
        """
        axis = "z"
        super().__init__(
            robot,
            name,
            waypoints_indices,
            translation_range=translation_range,
            translation_axis=axis,
        )
        # Replace the generic type tag set by the parent constructor.
        self._failure_type = TranslationZFailure.FAILURE_TYPE
