"""Utils for evaluating policies in LIBERO simulation environments."""

import math
import os

import imageio
import numpy as np
import tensorflow as tf
import time
import robosuite.utils.transform_utils as T
from scipy.spatial.transform import Rotation as R

# Timestamps captured once at import time; save_rollout_video uses them to name rollout
# folders and files. Both strings are formatted from the same struct_time so DATE is always
# the date prefix of DATE_TIME (two separate strftime() calls could straddle midnight).
_IMPORT_TIME = time.localtime()
DATE = time.strftime("%Y_%m_%d", _IMPORT_TIME)
DATE_TIME = time.strftime("%Y_%m_%d-%H_%M_%S", _IMPORT_TIME)

def quat_to_rot6d(quat: np.ndarray, scalar_first: bool = False) -> np.ndarray:
    """
    Converts a batch of quaternion rotations to 6D rotation representation.

    Args:
        quat (np.ndarray): Array of shape (4,) or (N, 4). By default each quaternion is in
            scalar-last (x, y, z, w) order, which is what scipy's `Rotation.from_quat` expects.
        scalar_first (bool): If True, interpret each quaternion as (w, x, y, z) instead.

    Returns:
        np.ndarray: Array of shape (N, 6), with N == 1 for a single (4,) input. Each row
        holds the first two columns of the rotation matrix R, flattened row by row:
        [R00, R01, R10, R11, R20, R21].
    """
    # Promote a single quaternion to a batch of one, so the output is always 2D
    # (callers index [0] to get the single row back).
    quat = np.atleast_2d(quat)
    if scalar_first:
        # Reorder (w, x, y, z) -> (x, y, z, w) for scipy.
        quat = quat[..., [1, 2, 3, 0]]

    # Rotation matrices, shape (N, 3, 3).
    rot_matrices = R.from_quat(quat).as_matrix()

    # Slice the first two columns (N, 3, 2) and flatten each in C order. The two columns
    # therefore end up interleaved row by row, not concatenated column by column.
    rot_6d = rot_matrices[..., :2].reshape(*rot_matrices.shape[:-2], 6)
    return rot_6d

def get_libero_state(obs, absolute=False):
    """Build a flat proprioceptive state vector from an observation dict.

    Args:
        obs: Observation dict; reads "robot0_eef_pos", "robot0_eef_quat" and
            "robot0_gripper_qpos".
        absolute: If True, encode orientation with quat_to_rot6d and keep only the
            first gripper joint position. Otherwise encode orientation with
            robosuite's quat2axisangle and keep every gripper joint position.

    Returns:
        np.ndarray: end-effector position, orientation encoding and gripper
        position(s), concatenated in that order.
    """
    eef_pos = obs['robot0_eef_pos']
    eef_quat = obs["robot0_eef_quat"]
    gripper_qpos = obs['robot0_gripper_qpos']

    if not absolute:
        return np.concatenate([eef_pos, T.quat2axisangle(eef_quat), gripper_qpos])

    # quat_to_rot6d always returns a batch (at least 2D); take its single row.
    orientation = quat_to_rot6d(eef_quat)[0]
    return np.concatenate([eef_pos, orientation, gripper_qpos[0:1]])

def get_libero_env(task, model_family, resolution=256, seed=0):
    """Initializes and returns the LIBERO environment, along with the task description.

    Args:
        task: LIBERO task object; this reads its `language`, `problem_folder` and
            `bddl_file` attributes.
        model_family: Unused here; kept for call-site compatibility.
        resolution: Camera height and width, in pixels, passed to the env.
        seed: Value passed to `env.seed()`. Defaults to 0, the previously hard-coded seed.

    Returns:
        tuple: (env, task_description), where task_description is `task.language`.
    """
    # Imported inside the function (presumably so this module imports without LIBERO installed).
    from libero.libero import get_libero_path
    from libero.libero.envs import OffScreenRenderEnv

    task_description = task.language
    task_bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
    env_args = {"bddl_file_name": task_bddl_file, "camera_heights": resolution, "camera_widths": resolution}
    env = OffScreenRenderEnv(**env_args)
    env.seed(seed)  # IMPORTANT: seed seems to affect object positions even when using fixed initial state
    return env, task_description


def get_libero_dummy_action(model_family: str):
    """Return a no-op action, used to step the simulation while the robot stays idle.

    `model_family` is accepted for interface compatibility but is not consulted.
    A new list is built on every call, so callers may mutate the result freely.
    """
    # Six zeros followed by -1 (the last entry is presumably the gripper command; confirm
    # against the env's action space).
    return [0] * 6 + [-1]


def resize_image(img, resize_size):
    """
    Takes numpy array corresponding to a single image and returns resized image as numpy array.

    NOTE (Moo Jin): To make input images in distribution with respect to the inputs seen at training time, we follow
                    the same resizing scheme used in the Octo dataloader, which OpenVLA uses for training.

    Args:
        img: Single image. It must be a uint8 (H, W, C) array, because it is passed
            to tf.image.encode_jpeg.
        resize_size: Target size as a (height, width) tuple, or an int for a square output.

    Returns:
        np.ndarray: The resized image as uint8.

    Raises:
        TypeError: If resize_size is neither an int nor a tuple.
    """
    # Accept a bare int for square outputs, matching get_libero_image's convention.
    if isinstance(resize_size, int):
        resize_size = (resize_size, resize_size)
    # Raise explicitly rather than assert, so validation survives `python -O`.
    if not isinstance(resize_size, tuple):
        raise TypeError(f"resize_size must be an int or a tuple, got {type(resize_size).__name__}")

    # JPEG round-trip reproduces the compression artifacts of the RLDS dataset builder.
    img = tf.image.encode_jpeg(img)
    img = tf.io.decode_image(img, expand_animations=False, dtype=tf.uint8)  # Immediately decode back
    # Resize to the image size expected by the model; the output is float, so round and clip back to uint8.
    img = tf.image.resize(img, resize_size, method="lanczos3", antialias=True)
    img = tf.cast(tf.clip_by_value(tf.round(img), 0, 255), tf.uint8)
    return img.numpy()


def get_libero_image(obs, resize_size):
    """Extract the agent-view and wrist-camera frames from an observation, rotated 180 degrees.

    Args:
        obs: Observation dict; reads "agentview_image" and "robot0_eye_in_hand_image".
        resize_size: int or tuple. It is only type-checked here. No resizing is applied
            (resize_image is not called), so frames keep their rendered size.

    Returns:
        tuple: (agentview frame, eye-in-hand frame), each reversed along its first two axes.
    """
    assert isinstance(resize_size, (int, tuple))

    def _rotate_180(frame):
        # IMPORTANT: rotate 180 degrees to match train preprocessing
        return frame[::-1, ::-1]

    third_person = _rotate_180(obs["agentview_image"])
    wrist = _rotate_180(obs["robot0_eye_in_hand_image"])
    return third_person, wrist


def save_rollout_video(rollout_images, idx, success, task_description, log_file=None, folder=None):
    """Saves an MP4 replay of an episode.

    Args:
        rollout_images: Iterable of frames, each passed to the imageio writer's append_data.
        idx: Episode index, embedded in the file name.
        success: Episode outcome, embedded in the file name.
        task_description: Task text. For the file name it is lower-cased, spaces, newlines,
            dots and slashes become "_", and it is truncated to 50 characters.
        log_file: Optional writable text file; the saved path is also logged there.
        folder: Optional sub-folder. When given, the video goes under
            ./rollouts/<DATE>/<folder>/videos; otherwise it goes under ./rollouts/<DATE>.

    Returns:
        str: Path of the written MP4 file.
    """
    if folder is None:
        rollout_dir = f"./rollouts/{DATE}"
    else:
        rollout_dir = f"./rollouts/{DATE}/{folder}/videos"
    os.makedirs(rollout_dir, exist_ok=True)
    processed_task_description = (
        task_description.lower()
        .replace(" ", "_")
        .replace("\n", "_")
        .replace(".", "_")
        .replace("/", "_")  # a "/" would otherwise be read as a path separator
    )[:50]
    mp4_path = f"{rollout_dir}/{DATE_TIME}--episode={idx}--success={success}--task={processed_task_description}.mp4"
    # The context manager closes the writer (and the file) even if a frame fails to encode.
    with imageio.get_writer(mp4_path, fps=30) as video_writer:
        for img in rollout_images:
            video_writer.append_data(img)
    message = f"Saved rollout MP4 at path {mp4_path}"
    print(message)
    if log_file is not None:
        log_file.write(f"{message}\n")
    return mp4_path


def quat2axisangle(quat):
    """
    Adapted from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55

    Converts quaternion to axis-angle format.
    Returns a unit vector direction scaled by its angle in radians.

    Unlike the robosuite original, this does not modify the caller's array. The scalar
    part is clipped into [-1, 1] on a local value instead of being written back into
    `quat`. Array-likes such as plain lists are also accepted.

    Args:
        quat (np.array): quaternion in (x, y, z, w) order

    Returns:
        np.array: (ax,ay,az) axis-angle exponential coordinates
    """
    quat = np.asarray(quat)
    # Clip the scalar part so sqrt/acos stay inside their domains despite float round-off.
    w = np.clip(quat[3], -1.0, 1.0)

    den = np.sqrt(1.0 - w * w)
    # NOTE: math.isclose with its default abs_tol=0 only matches an exact 0.0 here. Tiny
    # nonzero den values fall through to the division, where acos(w)/den stays finite.
    if math.isclose(den, 0.0):
        # Zero degree rotation, immediately return
        return np.zeros(3)

    return (quat[:3] * 2.0 * math.acos(w)) / den