import collections
import dataclasses
import logging
import math
import pathlib
from typing import Tuple

import imageio
from libero.libero import benchmark
from libero.libero import get_libero_path
from libero.libero.envs import OffScreenRenderEnv
import numpy as np
from openpi_client import image_tools
from openpi_client import websocket_client_policy as _websocket_client_policy
import tqdm
import tyro

# Action stepped during the initial wait period of each episode: six zeros followed by -1.0.
# NOTE(review): presumably "no end-effector motion, gripper command held at -1" -- confirm against the controller.
LIBERO_DUMMY_ACTION = [0.0] * 6 + [-1.0]
# Camera height and width handed to the environment; resolution used to render training data.
LIBERO_ENV_RESOLUTION = 256


@dataclasses.dataclass
class Args:
    """Configuration for the LIBERO rollout evaluation performed by ``eval_libero``."""

    #################################################################################################################
    # Model server parameters
    #################################################################################################################
    host: str = "127.0.0.1"  # Host of the policy server the websocket client connects to
    port: int = 8000  # Port of the policy server
    resize_size: int = 224  # Side length both camera images are padded/resized to before being sent to the policy
    replan_steps: int = 5  # Number of actions executed from each predicted chunk before the policy is queried again

    #################################################################################################################
    # LIBERO environment-specific parameters
    #################################################################################################################
    # Task suites to evaluate, in order. Comment entries in or out to select suites.
    # Options: libero_spatial, libero_object, libero_goal, libero_10, libero_90
    task_suite_name: Tuple[str, ...] = (
        # "libero_spatial",
        # "libero_object",
        "libero_goal",
        # "libero_10",
    )
    num_steps_wait: int = 10  # Number of steps to wait for objects to stabilize in sim
    num_trials_per_task: int = 50  # Number of rollouts per task

    #################################################################################################################
    # Utils
    #################################################################################################################
    video_out_path: str = "data/libero/libero_10_5_videos"  # Path to save videos

    seed: int = 7  # Random Seed (for reproducibility)
    eval_dir: str = "data/libero/eval_results/libero_5_pi05_libero_spatial_branch_rollout"  # Path to save eval results
    save_data: bool = False  # If True, also save per-step rollout records as .npy files under <eval_dir>/data
    task_id: int = 0  # Index of the single task to evaluate within each suite; -1 evaluates every task


def _preprocess_camera_images(obs, resize_size):
    """Return the (agent-view, wrist) camera frames prepared for the policy.

    Each frame is rotated 180 degrees, padded/resized to ``resize_size`` x ``resize_size``
    and converted to uint8.
    """
    frames = []
    for key in ("agentview_image", "robot0_eye_in_hand_image"):
        # IMPORTANT: rotate 180 degrees to match train preprocessing
        rotated = np.ascontiguousarray(obs[key][::-1, ::-1])
        frames.append(image_tools.convert_to_uint8(image_tools.resize_with_pad(rotated, resize_size, resize_size)))
    return frames[0], frames[1]


def _proprio_state(obs):
    """Concatenate end-effector position, axis-angle orientation and gripper joint positions."""
    return np.concatenate(
        (
            obs["robot0_eef_pos"],
            _quat2axisangle(obs["robot0_eef_quat"]),
            obs["robot0_gripper_qpos"],
        )
    )


def _success_rate(successes, episodes):
    """Return ``successes / episodes`` as a float, or 0.0 when no episode was run."""
    return float(successes) / float(episodes) if episodes else 0.0


def _run_episode(env, client, args, task_description, initial_state, max_steps):
    """Roll out a single episode starting from ``initial_state``.

    Args:
        env: LIBERO environment for the current task.
        client: Policy client; ``client.infer`` must return a dict with an ``"actions"`` chunk.
        args: Evaluation configuration (uses ``num_steps_wait``, ``resize_size``, ``replan_steps``, ``save_data``).
        task_description: Language prompt sent to the policy.
        initial_state: Simulator state passed to ``env.set_init_state``.
        max_steps: Policy step budget (the wait steps are added on top).

    Returns:
        ``(success, replay_images, records)``. ``success`` is True only if the rollout kept running for
        the full post-success window after the environment first reported ``done``. ``records`` is empty
        unless ``args.save_data`` is set.
    """
    # Number of extra environment steps executed after the first `done` before the episode counts as a success.
    post_success_steps = 20

    records = []
    replay_images = []
    action_plan = collections.deque()
    first_done_step = None  # None (rather than 0) so that a `done` at t == 0 is not mistaken for "not yet done"

    env.reset()
    obs = env.set_init_state(initial_state)

    t = 0
    while t < max_steps + args.num_steps_wait:
        try:
            # IMPORTANT: Do nothing for the first few timesteps because the simulator drops objects
            # and we need to wait for them to fall
            if t < args.num_steps_wait:
                obs, _, _, _ = env.step(LIBERO_DUMMY_ACTION)
                t += 1
                continue

            img, wrist_img = _preprocess_camera_images(obs, args.resize_size)

            # Save preprocessed image for replay video
            replay_images.append(img)

            if not action_plan:
                # Finished executing previous action chunk -- compute new chunk
                element = {
                    "observation/image": img,
                    "observation/wrist_image": wrist_img,
                    "observation/state": _proprio_state(obs),
                    "prompt": str(task_description),
                    "num_steps": 3,  # forwarded to the policy server as-is; its meaning is defined server-side
                }
                action_chunk = client.infer(element)["actions"]
                # Explicit raise instead of `assert` so the check also runs under `python -O`.
                if len(action_chunk) < args.replan_steps:
                    raise ValueError(
                        f"We want to replan every {args.replan_steps} steps, but policy only predicts {len(action_chunk)} steps."
                    )
                action_plan.extend(action_chunk[: args.replan_steps])

            action = action_plan.popleft()

            # Save data at each step (not just when getting new action chunk)
            if args.save_data:
                records.append(
                    {
                        'image': img,
                        'wrist_image': wrist_img,
                        'state': _proprio_state(obs),
                        'prompt': str(task_description),
                        'action': action.tolist(),  # Save single action to be executed
                        'sim_state': env.get_sim_state(),
                    }
                )

            # Execute action in environment
            obs, _, done, _ = env.step(action.tolist())
            if done and first_done_step is None:
                first_done_step = t
            if first_done_step is not None and t - first_done_step == post_success_steps:
                if args.save_data:
                    # Terminal record: the observation after the last executed action, paired with a zero action.
                    img, wrist_img = _preprocess_camera_images(obs, args.resize_size)
                    records.append(
                        {
                            'image': img,
                            'wrist_image': wrist_img,
                            'state': _proprio_state(obs),
                            'prompt': str(task_description),
                            # Stored as a list so every record has the same 'action' type.
                            'action': np.zeros_like(action).tolist(),
                            'sim_state': env.get_sim_state(),
                        }
                    )
                return True, replay_images, records
            t += 1

        except Exception as e:
            # Best effort: a failure inside one rollout ends that episode (counted as a failure), not the whole run.
            logging.exception(f"Caught exception: {e}")
            break

    return False, replay_images, records


def eval_libero(args: Args) -> None:
    """Evaluate the served policy on the configured LIBERO task suites.

    For every suite in ``args.task_suite_name`` the selected task(s) are rolled out
    ``args.num_trials_per_task`` times each. Per episode a replay video is written to
    ``args.video_out_path`` and, when ``args.save_data`` is set, the per-step records are saved under
    ``<args.eval_dir>/data``. Success rates are logged and appended to ``<args.eval_dir>/result.txt``.

    The "success"/"failure" label in file names and logs always matches the counted result.

    Args:
        args: Evaluation configuration.

    Raises:
        KeyError: If a suite name is not present in the LIBERO benchmark dictionary.
        ValueError: If a suite has no step budget defined here.
    """
    max_steps_by_suite = {
        "libero_spatial": 220,  # longest training demo has 193 steps
        "libero_object": 280,  # longest training demo has 254 steps
        "libero_goal": 300,  # longest training demo has 270 steps
        "libero_10": 520,  # longest training demo has 505 steps
        "libero_90": 400,  # longest training demo has 373 steps
    }

    for task_suite_name in args.task_suite_name:
        # Set random seed
        np.random.seed(args.seed)

        # Initialize LIBERO task suite
        benchmark_dict = benchmark.get_benchmark_dict()
        task_suite = benchmark_dict[task_suite_name]()
        num_tasks_in_suite = task_suite.n_tasks
        logging.info(f"Task suite: {task_suite_name}")
        if args.save_data:
            pathlib.Path(f"{args.eval_dir}/data").mkdir(parents=True, exist_ok=True)
        pathlib.Path(args.video_out_path).mkdir(parents=True, exist_ok=True)
        pathlib.Path(args.eval_dir).mkdir(parents=True, exist_ok=True)
        if task_suite_name not in max_steps_by_suite:
            raise ValueError(f"Unknown task suite: {task_suite_name}")
        max_steps = max_steps_by_suite[task_suite_name]

        client = _websocket_client_policy.WebsocketClientPolicy(args.host, args.port)

        # Start evaluation
        total_episodes, total_successes = 0, 0
        save_path = f"{args.eval_dir}/result.txt"

        with open(save_path, "a", encoding="utf-8") as f:
            f.write(f"\n{task_suite_name}\n")

        for task_id in tqdm.tqdm(range(num_tasks_in_suite)):
            # args.task_id == -1 selects every task; otherwise only the matching index is evaluated.
            if args.task_id not in (-1, task_id):
                continue

            task = task_suite.get_task(task_id)

            # Get default LIBERO initial states
            initial_states = task_suite.get_task_init_states(task_id)

            # Initialize LIBERO environment and task description
            env, task_description = _get_libero_env(task, LIBERO_ENV_RESOLUTION, args.seed)

            task_episodes, task_successes = 0, 0
            try:
                for episode_idx in tqdm.tqdm(range(args.num_trials_per_task)):
                    logging.info(f"\nTask: {task_description}")
                    logging.info(f"Starting episode {task_episodes+1}...")

                    # Cycle through the available initial states instead of assuming there are exactly 50.
                    success, replay_images, episode_records = _run_episode(
                        env,
                        client,
                        args,
                        task_description,
                        initial_states[episode_idx % len(initial_states)],
                        max_steps,
                    )

                    task_episodes += 1
                    total_episodes += 1
                    if success:
                        task_successes += 1
                        total_successes += 1

                    suffix = "success" if success else "failure"
                    if args.save_data:
                        np.save(
                            f"{args.eval_dir}/data/{task_suite_name}_{task_id}_{episode_idx}_{suffix}.npy",
                            episode_records,
                        )

                    # Save a replay video of the episode (skipped if the rollout failed before any frame was captured)
                    if replay_images:
                        task_segment = task_description.replace(" ", "_")
                        imageio.mimwrite(
                            pathlib.Path(args.video_out_path) / f"rollout_{task_segment}_{episode_idx}_{suffix}.mp4",
                            [np.asarray(x) for x in replay_images],
                            fps=10,
                        )
                    else:
                        logging.warning("No frames recorded for this episode; skipping replay video.")

                    # Log current results
                    logging.info(f"Success: {success}")
                    logging.info(f"# episodes completed so far: {total_episodes}")
                    logging.info(f"# successes: {total_successes} ({total_successes / total_episodes * 100:.1f}%)")
            finally:
                # Release the simulator/renderer before the next task builds a new environment.
                env.close()

            # Log final results
            task_rate = _success_rate(task_successes, task_episodes)
            logging.info(f"Current task success rate: {task_rate}")
            logging.info(f"Current total success rate: {_success_rate(total_successes, total_episodes)}")
            with open(save_path, "a", encoding="utf-8") as f:
                f.write(f"{task_description}: {task_rate} ({task_successes}/{task_episodes}). ")

        if total_episodes == 0:
            logging.warning(
                f"No episodes were run for {task_suite_name} (task_id={args.task_id}, {num_tasks_in_suite} tasks in suite)."
            )
        total_rate = _success_rate(total_successes, total_episodes)
        logging.info(f"Total success rate: {total_rate}")
        logging.info(f"Total episodes: {total_episodes}")
        with open(save_path, "a", encoding="utf-8") as f:
            f.write(f"\nTotal Suc Rate: {total_rate} ({total_successes}/{total_episodes})\n")


def _get_libero_env(task, resolution, seed):
    """Create the off-screen LIBERO environment for ``task``.

    The environment is built from the task's BDDL file with square camera images of side
    ``resolution`` and is seeded with ``seed``.

    Returns:
        A ``(env, task_description)`` pair, where ``task_description`` is ``task.language``.
    """
    description = task.language
    bddl_root = pathlib.Path(get_libero_path("bddl_files"))
    bddl_path = bddl_root / task.problem_folder / task.bddl_file
    env = OffScreenRenderEnv(
        bddl_file_name=bddl_path,
        camera_heights=resolution,
        camera_widths=resolution,
    )
    # IMPORTANT: seed seems to affect object positions even when using fixed initial state
    env.seed(seed)
    return env, description


def _quat2axisangle(quat):
    """Convert a quaternion to an axis-angle rotation vector.

    Adapted from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55

    Unlike the robosuite version, the scalar component is clipped on a local value, so the
    caller's array (e.g. an observation entry) is never modified. Plain sequences are accepted too.

    Args:
        quat: Length-4 array-like with the vector part at indices 0-2 and the scalar part at
            index 3, i.e. (x, y, z, w).

    Returns:
        Length-3 numpy array: the rotation axis scaled by the rotation angle in radians.
        All zeros when the rotation is a zero-degree rotation.
    """
    quat = np.asarray(quat)  # no copy for ndarray input; lets lists/tuples work as well

    # Clip the scalar part into acos' domain without writing back into the input.
    w = quat[3]
    if w > 1.0:
        w = 1.0
    elif w < -1.0:
        w = -1.0

    den = np.sqrt(1.0 - w * w)
    # With the default tolerances, isclose against 0.0 only matches an exact zero.
    if math.isclose(den, 0.0):
        # This is (close to) a zero degree rotation, immediately return
        return np.zeros(3)

    return (quat[:3] * 2.0 * math.acos(w)) / den


if __name__ == "__main__":
    # Show INFO-level progress logs, then hand eval_libero to tyro, which parses the command line and calls it.
    logging.basicConfig(level=logging.INFO)
    tyro.cli(eval_libero)
