# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the NVIDIA Source Code License [see LICENSE for details].

from enum import Enum
from typing import List

import numpy as np
from pyrep.objects.object import Object
from rlbench.backend.robot import Robot
from rlbench.backend.task import Task
from rlbench.backend.waypoints import Waypoint

from failgen.fail_instance import IFailure
import copy


class WrongObjectState(Enum):
    """Lifecycle of a ``WrongObjectFailure`` within one episode."""

    # The failure has not modified the task's waypoints yet.
    IDLE = 0
    # ``on_start`` has already spliced the wrong-object waypoints in.
    APPLIED = 1


class WrongObjectFailure(IFailure):
    """Failure that makes the robot approach a distractor instead of the target.

    On start, the waypoint named by ``self._waypoint_fail_name`` is replaced by
    two copies:

    * an "error" waypoint, re-posed so that its pose relative to a randomly
      chosen alternative object equals the original waypoint's pose relative
      to the intended object;
    * an unmodified "recover" copy of the original waypoint.
    """

    FAILURE_TYPE = "wrong_object"

    def __init__(
        self,
        robot: Robot,
        name: str,
        waypoints_indices: List[int],
        original_name: str,
        alternatives_names: List[str],
        key_waypoint: str,
    ):
        """Configure the failure.

        Args:
            robot: Forwarded to ``IFailure``.
            name: Forwarded to ``IFailure``.
            waypoints_indices: Forwarded to ``IFailure``.
            original_name: Scene name of the object the task is meant to use.
            alternatives_names: Scene names of candidate wrong objects; one is
                drawn uniformly at random each time the failure is applied.
            key_waypoint: Stored as ``self._key_waypoint``; not read by the
                methods of this class.
        """
        super().__init__(
            robot=robot, name=name, waypoints_indices=waypoints_indices
        )

        self._failure_type = WrongObjectFailure.FAILURE_TYPE
        self._state = WrongObjectState.IDLE
        self._original_name: str = original_name
        self._alternatives_names: List[str] = alternatives_names
        self._key_waypoint: str = key_waypoint
        self.error_waypoint = None
        self.recover_waypoint = None

    @staticmethod
    def _waypoint_index_from_name(waypoint_name: str) -> int:
        """Return the full trailing integer of ``waypoint_name``.

        For example ``"waypoint12"`` -> ``12`` (reading only the last
        character would wrongly give ``2``).

        Raises:
            ValueError: if ``waypoint_name`` does not end in a digit.
        """
        stem = waypoint_name.rstrip("0123456789")
        digits = waypoint_name[len(stem):]
        if not digits:
            raise ValueError(
                f"Cannot parse a waypoint index from name {waypoint_name!r}"
            )
        return int(digits)

    def on_start(self, task: Task, sub_tasks=None) -> None:
        """Splice the wrong-object waypoints into ``task`` (once per reset).

        Does nothing if the failure is disabled or already applied.
        ``sub_tasks`` is accepted for interface compatibility and is unused.
        """
        if not self._enabled:
            return
        if self._state != WrongObjectState.IDLE:
            return

        self._state = WrongObjectState.APPLIED
        waypoints = task.get_waypoints()

        original_obj = Object.get_object(self._original_name)
        alternative_obj = Object.get_object(
            np.random.choice(self._alternatives_names)
        )
        fail_index = self._waypoint_index_from_name(self._waypoint_fail_name)
        target = waypoints[fail_index]

        # Deep-copy the waypoint wrapper, then give each copy its own duplicated
        # scene object. The deepcopy alone presumably still refers to the same
        # scene handle (assumption); the explicit copy() ensures re-posing the
        # error waypoint moves neither the original nor the recover copy.
        self.error_waypoint = copy.deepcopy(target)
        self.recover_waypoint = copy.deepcopy(target)
        self.error_waypoint._waypoint = self.error_waypoint._waypoint.copy()
        self.recover_waypoint._waypoint = self.recover_waypoint._waypoint.copy()
        self.error_waypoint._ext += "error_wrong_object"

        # Take the waypoint's pose relative to the intended object and apply it
        # relative to the alternative object, shifting the motion onto it.
        relative_pose = self.error_waypoint._waypoint.get_pose(
            relative_to=original_obj
        )
        self.error_waypoint._waypoint.set_pose(
            relative_pose, relative_to=alternative_obj
        )

        # Replace waypoints[fail_index] with [error, recover]; keep the rest.
        task._waypoints = [
            *waypoints[:fail_index],
            self.error_waypoint,
            self.recover_waypoint,
            *waypoints[fail_index + 1:],
        ]
        self.task = task

    def on_reset(self) -> None:
        """Return to IDLE and remove the scene objects created by ``on_start``.

        This method does not modify ``task._waypoints``.
        """
        self._state = WrongObjectState.IDLE
        for waypoint in (self.error_waypoint, self.recover_waypoint):
            if waypoint is not None:
                waypoint._waypoint.remove()
        self.error_waypoint = None
        self.recover_waypoint = None

    def on_step(self) -> None:
        """No per-step behavior for this failure."""
        ...

    def on_waypoint(self, point: Waypoint) -> None:
        """No per-waypoint behavior for this failure."""
        ...
