# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the NVIDIA Source Code License [see LICENSE for details].

from enum import Enum
from typing import List

from rlbench.backend.robot import Robot
from rlbench.backend.task import Task
from rlbench.backend.waypoints import Waypoint

from failgen.fail_instance import IFailure
import copy


class GraspState(Enum):
    """Injection state tracked by ``GraspFailure``.

    ``IDLE``: the next enabled ``on_start`` call may inject the failure.
    ``FAIL``: the failure has already been injected; ``on_reset`` returns the
    state to ``IDLE``.
    """

    IDLE = 0  # nothing injected since the last reset
    FAIL = 1  # waypoints have already been rewritten


class GraspFailure(IFailure):
    """Failure injector that replaces one waypoint with an erroneous grasp sequence.

    When enabled, ``on_start`` removes the selected waypoint from the task's
    waypoint list and splices in three clones of it, in this order:

    1. an "error" waypoint whose extension string is cleared and set to
       ``"error"`` (its pose is unchanged; presumably the extension string is
       interpreted elsewhere to produce the faulty grasp -- TODO confirm),
    2. an "after-error" waypoint moved to the pose of the next waypoint (or the
       previous one when the selected waypoint is the last),
    3. a "recover" waypoint that is an unmodified clone of the original.

    ``on_reset`` calls ``remove()`` on the cloned ``_waypoint`` objects.
    """

    FAILURE_TYPE = "grasp"

    def __init__(
        self,
        robot: Robot,
        name: str,
        waypoint_indices: List[int],
    ):
        """Create the failure.

        Args:
            robot: Robot passed through to ``IFailure``.
            name: Name of this failure instance.
            waypoint_indices: Candidate waypoint indices at which the failure
                may be injected (forwarded as ``waypoints_indices``).
        """
        super().__init__(
            robot=robot, name=name, waypoints_indices=waypoint_indices
        )

        self._failure_type = GraspFailure.FAILURE_TYPE
        self._state = GraspState.IDLE
        # NOTE(review): never assigned by this class; kept for compatibility.
        self.original_waypoints = None
        # Task whose waypoint list was rewritten by the last on_start call.
        self.task = None
        # Clones spliced into the task's waypoint list (see class docstring).
        self.error_waypoint = None
        self.aft_error_waypoint = None
        self.recover_waypoint = None

    @staticmethod
    def _parse_waypoint_index(waypoint_name: str) -> int:
        """Return the trailing integer of a waypoint name (``"waypoint12"`` -> 12).

        Raises:
            ValueError: if the name does not end in at least one digit.
        """
        prefix = waypoint_name.rstrip("0123456789")
        digits = waypoint_name[len(prefix):]
        if not digits:
            raise ValueError(
                f"GraspFailure: waypoint name {waypoint_name!r} has no trailing index"
            )
        return int(digits)

    @staticmethod
    def _clone_waypoint(waypoint: Waypoint) -> Waypoint:
        """Deep-copy ``waypoint`` and replace the copy's ``_waypoint`` with its ``copy()``.

        The second step gives each clone its own ``_waypoint`` object, so it
        can be re-posed and later ``remove()``d without touching the original.
        """
        clone = copy.deepcopy(waypoint)
        clone._waypoint = clone._waypoint.copy()
        return clone

    def on_start(self, task: Task, sub_tasks=None) -> None:
        """Inject the grasp failure into ``task``'s waypoint list.

        Does nothing when the failure is disabled or has already been
        injected. This holds until ``on_reset`` puts the state back to IDLE.

        Args:
            task: Task whose ``_waypoints`` list is replaced.
            sub_tasks: Unused; accepted for interface compatibility.

        Raises:
            AssertionError: if the task has no waypoints loaded.
            ValueError: if the selected waypoint name has no trailing index.
            IndexError: if the parsed index is outside the waypoint list.
        """
        if not self._enabled:
            return
        assert (
            task._waypoints is not None
        ), "GraspFailure::on_start >> Must have waypoints loaded"

        if self._state != GraspState.IDLE:
            return
        self._state = GraspState.FAIL

        # self._waypoint_fail_name is the waypoint chosen to receive the
        # failure; self._waypoints_indices holds all candidate waypoints.
        waypoints = task.get_waypoints()
        fail_index = self._parse_waypoint_index(self._waypoint_fail_name)
        num_waypoints = len(waypoints)

        original = waypoints[fail_index]
        self.error_waypoint = self._clone_waypoint(original)
        self.aft_error_waypoint = self._clone_waypoint(original)
        self.recover_waypoint = self._clone_waypoint(original)

        # Tag the error waypoint: drop its original extension, mark it "error".
        self.error_waypoint.clear_ext()
        self.error_waypoint._ext += "error"

        # Move the after-error waypoint away from the grasp: to the next
        # waypoint's pose, or the previous one if the failure is on the last.
        # NOTE(review): an earlier comment said this waypoint's _ext should get
        # "aft_error(0)" appended; that is not implemented here.
        if fail_index >= num_waypoints - 1:
            neighbour = waypoints[fail_index - 1]
        else:
            neighbour = waypoints[fail_index + 1]
        self.aft_error_waypoint._waypoint.set_pose(neighbour._waypoint.get_pose())

        # The three clones replace the original waypoint in place.
        new_waypoints = list(waypoints[:fail_index])
        new_waypoints.extend(
            [self.error_waypoint, self.aft_error_waypoint, self.recover_waypoint]
        )
        new_waypoints.extend(waypoints[fail_index + 1:])

        task._waypoints = new_waypoints
        self.task = task

    def on_reset(self) -> None:
        """Return to IDLE and ``remove()`` any cloned waypoint objects.

        NOTE(review): ``self.task._waypoints`` is not restored here; confirm the
        task reloads its waypoints before the next ``on_start``.
        """
        self._state = GraspState.IDLE
        # Remove each clone independently so that a partially completed
        # on_start (e.g. an exception after the first clone) is cleaned up.
        for clone in (
            self.error_waypoint,
            self.aft_error_waypoint,
            self.recover_waypoint,
        ):
            if clone is not None:
                clone._waypoint.remove()
        self.error_waypoint = None
        self.aft_error_waypoint = None
        self.recover_waypoint = None

    def on_step(self) -> None:
        """No per-step behaviour for this failure (placeholder)."""

    def on_waypoint(self, point: Waypoint) -> None:
        """No per-waypoint behaviour for this failure (placeholder)."""
