# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the NVIDIA Source Code License [see LICENSE for details].

import argparse
from multiprocessing import Process
from typing import List, Optional

from failgen.env_wrapper import FailGenEnvWrapper
from failgen.fail_grasp import GraspFailure
from failgen.fail_instance import IFailure
from failgen.fail_no_rotation import NoRotationFailure
from failgen.fail_rotation import (
    RotationXFailure,
    RotationYFailure,
    RotationZFailure,
)
from failgen.fail_wrong_object import WrongObjectFailure
from failgen.fail_sequence import WrongSequenceFailure
from failgen.fail_slip import SlipFailure
from failgen.fail_translation import (
    TranslationXFailure,
    TranslationYFailure,
    TranslationZFailure,
)

# Failure types collected by default (one worker process each) when no
# ``--failtype`` override is given on the command line. Order is preserved.
# NOTE(review): WrongSequenceFailure and SlipFailure are imported by this
# module but not listed here — confirm whether that omission is intentional.
FAILURES_LIST: List[str] = [
    failure_cls.FAILURE_TYPE
    for failure_cls in (
        GraspFailure,
        RotationXFailure,
        RotationYFailure,
        RotationZFailure,
        TranslationXFailure,
        TranslationYFailure,
        TranslationZFailure,
        NoRotationFailure,
        WrongObjectFailure,
    )
]


def _collect_episodes(
    env_wrapper: FailGenEnvWrapper,
    task_name: str,
    fail_type: str,
    num_episodes: int,
    max_tries: int,
) -> int:
    """Collect up to ``num_episodes`` failure demos with the current setup.

    For each episode the environment is reset once, then ``get_failure`` is
    retried up to ``max_tries`` times. The first attempt that returns a
    non-``None`` demo flagged as successful is saved with
    ``save_error_cot_demo_to_lerobot``.

    Returns:
        The number of episodes that were actually saved.
    """
    saved = 0
    for episode_idx in range(num_episodes):
        desc = env_wrapper.reset()
        for _ in range(max_tries):
            (
                demo,
                success,
                valid_flag_list,
                waypoint_belonging_list,
                failtype,
                fail_index,
            ) = env_wrapper.get_failure()
            if demo is not None and success:
                env_wrapper.save_error_cot_demo_to_lerobot(
                    episode_idx,
                    demo,
                    valid_flag_list,
                    waypoint_belonging_list,
                    failtype,
                    fail_index,
                    desc,
                )
                saved += 1
                print(f"Saved episode {episode_idx+1} / {num_episodes}")
                break
        else:
            # Every try was exhausted (or max_tries <= 0) without a usable demo.
            print(
                f"Got an issue with task: {task_name}, failure: {fail_type}"
            )
    return saved


def run_get_failures(
    task_name: str,
    fail_type: str,
    num_episodes: int,
    max_tries: int,
    save_video: bool,
    save_path: str,
) -> None:
    """Collect failure demonstrations of one failure type for one task.

    A headless ``FailGenEnvWrapper`` is created for ``task_name``. Only the
    failure whose ``failure_type`` equals ``fail_type`` is enabled; all others
    are disabled. If no failure matches, a skip message is printed.

    For ``WrongObjectFailure``, ``num_episodes`` episodes are collected once.
    For every other failure type, ``num_episodes`` episodes are collected for
    each index in the failure's ``waypoints_indices``.

    The environment is always shut down before returning, including on the
    skip path and when an exception propagates.

    Args:
        task_name: Name of the task to load.
        fail_type: Failure type string to enable.
        num_episodes: Episodes to collect (per waypoint, where applicable).
        max_tries: Maximum ``get_failure`` attempts per episode.
        save_video: Forwarded to the wrapper as ``record``.
        save_path: Folder passed to the wrapper for saved data.
    """
    env_wrapper = FailGenEnvWrapper(
        task_name=task_name,
        headless=True,
        record=save_video,
        save_data=True,
        save_path=save_path,
    )
    try:
        # Enable only the requested failure type; disable the rest.
        target_fail_obj: Optional[IFailure] = None
        for fail_obj in env_wrapper.manager._failures:
            is_target = fail_obj.failure_type == fail_type
            fail_obj.set_enabled(is_target)
            if is_target:
                target_fail_obj = fail_obj

        if target_fail_obj is None:
            print(f"Skipping task {task_name} and fail {fail_type}")
            return

        print(
            f"Starting demo collection for task: {task_name} and fail: {fail_type}"
        )

        # Wrong-object failures are not tied to a specific waypoint.
        if fail_type == WrongObjectFailure.FAILURE_TYPE:
            saved = _collect_episodes(
                env_wrapper, task_name, fail_type, num_episodes, max_tries
            )
            print(
                f"Saved {saved} / {num_episodes} for task {task_name}, "
                f"failure: {fail_type}"
            )
            return

        for wp_idx in target_fail_obj.waypoints_indices:
            target_fail_obj.change_waypoint_fail_name(f"waypoint{wp_idx}")
            print(f"Trying to collect from waypoint {wp_idx}")
            # NOTE(review): episode indices restart at 0 for every waypoint;
            # confirm save_error_cot_demo_to_lerobot does not overwrite the
            # episodes saved for a previous waypoint.
            saved = _collect_episodes(
                env_wrapper, task_name, fail_type, num_episodes, max_tries
            )
            print(
                f"Saved {saved} / {num_episodes} for task {task_name}, "
                f"failure: {fail_type}, waypoint-index: {wp_idx}"
            )
    finally:
        # Previously skipped on the wrong-object path and on exceptions.
        env_wrapper.shutdown()


def main() -> int:
    """Parse CLI arguments and collect failure demos in parallel.

    One ``multiprocessing.Process`` running ``run_get_failures`` is started
    per failure type: either every entry of ``FAILURES_LIST``, or only the
    single type given by ``--failtype``.

    Returns:
        0 if every worker process exited with code 0, otherwise 1.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--task",
        type=str,
        default="beat_the_buzz",
        help="The name of the task to load for this example",
    )
    parser.add_argument(
        "--episodes",
        type=int,
        default=2,
        help="The number of episodes to collect",
    )
    parser.add_argument(
        "--max_tries",
        type=int,
        default=1,
        help="The maximum number of tries to test a single failure",
    )
    parser.add_argument(
        "--video",
        action="store_true",
        help="Whether or not to save video recordings of the failures",
    )
    parser.add_argument(
        "--failtype",
        type=str,
        default="",
        help="The fail type to use for data collection of single failure",
    )
    parser.add_argument(
        "--savepath",
        type=str,
        default="",
        help="The path to the folder where to save all the data",
    )

    args = parser.parse_args()

    # Use a local selection instead of rebinding the module-level constant.
    fail_types = [args.failtype] if args.failtype else list(FAILURES_LIST)

    processes = [
        Process(
            target=run_get_failures,
            args=(
                args.task,
                fail_type,
                args.episodes,
                args.max_tries,
                args.video,
                args.savepath,
            ),
        )
        for fail_type in fail_types
    ]
    for proc in processes:
        proc.start()
    for proc in processes:
        proc.join()

    # Surface worker crashes through the exit status instead of always 0.
    exit_status = 0
    for fail_type, proc in zip(fail_types, processes):
        if proc.exitcode != 0:
            print(
                f"Worker for failure {fail_type} exited with code {proc.exitcode}"
            )
            exit_status = 1
    return exit_status


if __name__ == "__main__":
    # Run the CLI and use main()'s return value as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)
