import collections
import dataclasses
from dataclasses import field
import json
import logging
import math
import pathlib

import imageio
from robosuite import load_controller_config

from libero.libero import benchmark
from libero.libero import get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from robosuite_patches import apply_patches
from typing import List
import numpy as np
from openpi_client import image_tools
from openpi_client import websocket_client_policy as _websocket_client_policy
from scipy.spatial.transform import Rotation as R
import tqdm
import tyro
from utils import (
    merge_delta_actions, normalize_action, put_text, split_action, interpolate_pos_quat,
    unnormalize_action, compute_eef_error, snapshot, restore
)
# Phase thresholds. They carry the same values as the `Args.manipulation_threshold`
# and `Args.movement_threshold` defaults declared further down in this file.
MANIPULATION_THRESHOLD = 0.002
MOVEMENT_THRESHOLD = 0.05

# POS_SPEED_THRESHOLD = 0.005  # position speed threshold
# POS_SPEED_VERY_SLOW_RATIO = 0.1  # fraction of frames that count as very slow
# NOTE(review): the original note here said 0.005 turned out to be too low on
# libero-10 and was changed to 0.01, yet the value below is still 0.005.
# Confirm which value is intended.
POS_SPEED_THRESHOLD = 0.005
POS_SPEED_VERY_SLOW_RATIO = 0.1

ROT_SPEED_THRESHOLD = 0.02  # rotation speed threshold
GRIPPER_CHANGE_THRESHOLD = 0.02  # gripper change threshold
# CONTROL_FREQ is handed to the environment as `control_freq` and is also the
# fps of the saved replay videos.
CONTROL_FREQ = 20
ORIGINAL_CONTROL_FREQ = 20
# Integer ratio between the two frequencies; the evaluation scales its step
# budgets by this factor (1 with the values above).
REPEATS = CONTROL_FREQ // ORIGINAL_CONTROL_FREQ

# 7-element action (six zeros followed by -1.0) stepped while waiting for the
# scene to settle at the start of an episode.
LIBERO_DUMMY_ACTION = [0.0] * 6 + [-1.0]

LIBERO_ENV_RESOLUTION = 256  # resolution used to render training data
# Input/output limits of robosuite's default OSC_POSE controller config; they
# are passed to normalize_action / unnormalize_action when actions are merged.
controller_config = load_controller_config(default_controller="OSC_POSE")
output_max = np.array(controller_config["output_max"])
output_min = np.array(controller_config["output_min"])
input_max = np.array(controller_config["input_max"])
input_min = np.array(controller_config["input_min"])

from client import AdaDSClient

@dataclasses.dataclass
class Args:
    """Command-line configuration for the AdaDS LIBERO evaluation."""

    # ---------------------------------------------------------------------
    # Policy (model) server
    # ---------------------------------------------------------------------
    host: str = "0.0.0.0"
    port: int = 8000
    resize_size: int = 224

    # ---------------------------------------------------------------------
    # Server that predicts the downsample rate k
    # ---------------------------------------------------------------------
    # Address of the k-prediction server ("localhost" when it runs on this
    # machine, otherwise the server's IP).
    safeiql_host: str = "localhost"
    safeiql_port: int = 8888  # Port of the k-prediction server

    # ---------------------------------------------------------------------
    # LIBERO environment
    # ---------------------------------------------------------------------
    # Task suite. Options: libero_spatial, libero_object, libero_goal, libero_10, libero_90
    task_suite_name: str = "libero_spatial"
    num_steps_wait: int = 10  # Number of steps to wait for objects to stabilize in sim
    num_trials_per_task: int = 50  # Number of rollouts per task

    seed: int = 7  # Random seed (for reproducibility)

    # ---------------------------------------------------------------------
    # AdaDS
    # ---------------------------------------------------------------------
    # Downsample rates to try
    downsample_rates: List[int] = dataclasses.field(default_factory=lambda: [1, 2])
    movement_threshold: float = 0.05  # Threshold for movement phase
    manipulation_threshold: float = 0.002  # Threshold for manipulation phase

    # Run name; used for the video directory and the results JSON file name.
    name: str = "safe_iql"


@dataclasses.dataclass
class _EpisodeResult:
    """Outcome of a single AdaDS rollout."""

    success: bool  # True when the env reported done right after a policy action
    effective_steps: int  # number of policy actions stepped in the environment
    total_steps: int  # effective steps plus the initial stabilization steps
    ds_choices: list  # downsample rate selected at every replanning step
    replay_images: list  # annotated frames for the replay video


def _max_steps_for_suite(task_suite_name):
    """Return the per-episode policy step budget of a suite, scaled by REPEATS.

    Raises:
        ValueError: if the suite has no configured step budget.
    """
    base_steps = {
        "libero_spatial": 220,
        "libero_object": 280,
        "libero_goal": 300,
        "libero_10": 520,
        "libero_90": 400,
    }
    if task_suite_name not in base_steps:
        raise ValueError(f"Unknown task suite: {task_suite_name}")
    return base_steps[task_suite_name] * REPEATS


def _preprocess_image(raw_image, size):
    """Rotate a camera frame by 180 degrees, pad-resize it to size x size and cast to uint8."""
    rotated = np.ascontiguousarray(raw_image[::-1, ::-1])
    return image_tools.convert_to_uint8(image_tools.resize_with_pad(rotated, size, size))


def _plan_actions(action_chunk, selected_k):
    """Convert a predicted action chunk into the sequence of actions to execute.

    With ``selected_k == 1`` the chunk is used unchanged. Otherwise every group
    of ``selected_k`` consecutive actions (the last group may be shorter) is
    passed through normalize_action -> merge_delta_actions -> unnormalize_action
    and replaced by the single merged action. When ``REPEATS > 1`` each
    resulting action is repeated ``REPEATS`` times.

    Raises:
        ValueError: if ``selected_k`` is smaller than 1. A non-positive rate
            would otherwise yield no actions and stall the rollout loop.
    """
    if selected_k < 1:
        raise ValueError(f"Downsample rate must be >= 1, got {selected_k}")

    if selected_k == 1:
        actions = action_chunk
    else:
        actions = []
        for start in range(0, len(action_chunk), selected_k):
            group = action_chunk[start:start + selected_k]
            group_raw = normalize_action(group, input_max, input_min, output_max, output_min)
            merged_raw = merge_delta_actions(group_raw)
            merged = unnormalize_action(merged_raw[None, :], input_max, input_min, output_max, output_min)[0]
            actions.append(merged)

    if REPEATS > 1:
        actions = [action for action in actions for _ in range(REPEATS)]
    return actions


def _run_episode(env, initial_state, task_description, client, adads_client, args,
                 max_steps, num_steps_wait, record_video):
    """Roll out one episode and return an ``_EpisodeResult``.

    The environment is reset to ``initial_state`` and stepped with
    ``LIBERO_DUMMY_ACTION`` for ``num_steps_wait`` steps. After that, the policy
    server is queried for an action chunk, the AdaDS server for a downsample
    rate, and the resulting actions are executed. This repeats until the
    environment reports ``done`` or ``max_steps + num_steps_wait`` environment
    steps have been taken.

    Raises:
        RuntimeError: if the policy server returns an empty action chunk.
    """
    env.reset()
    obs = env.set_init_state(initial_state)

    step_limit = max_steps + num_steps_wait
    t = 0
    effective_steps = 0
    done = False
    success = False
    ds_choices = []
    replay_images = []

    while t < step_limit and not done:
        # Wait for objects to stabilize before querying the policy.
        if t < num_steps_wait:
            obs, _, done, _ = env.step(LIBERO_DUMMY_ACTION)
            t += 1
            continue

        img = _preprocess_image(obs["agentview_image"], args.resize_size)
        wrist_img = _preprocess_image(obs["robot0_eye_in_hand_image"], args.resize_size)
        state = np.concatenate(
            (
                obs["robot0_eef_pos"],
                _quat2axisangle(obs["robot0_eef_quat"]),
                obs["robot0_gripper_qpos"],
            )
        )
        element = {
            "observation/image": img,
            "observation/wrist_image": wrist_img,
            "observation/state": state,
            "prompt": str(task_description),
        }

        action_chunk = client.infer(element)["actions"]
        selected_k = adads_client.query_predict_k(state, action_chunk)
        actions_to_execute = _plan_actions(action_chunk, selected_k)
        if len(actions_to_execute) == 0:
            # Nothing to execute would leave `t` unchanged and loop forever.
            raise RuntimeError("Policy server returned an empty action chunk")

        for action in actions_to_execute:
            obs, _, done, _ = env.step(action.tolist())
            effective_steps += 1
            t += 1
            if record_video:
                frame = _preprocess_image(obs["agentview_image"], args.resize_size)
                frame = put_text(frame, f"DS: {selected_k}, Step: {effective_steps}", font_size=0.5, resize=False)
                replay_images.append(frame)
            if done:
                success = True
                break
            if t >= step_limit:
                # Do not let the remainder of a chunk overrun the step budget.
                break
        ds_choices.append(selected_k)

    return _EpisodeResult(
        success=success,
        effective_steps=effective_steps,
        total_steps=t,
        ds_choices=ds_choices,
        replay_images=replay_images,
    )


def _evaluate_task(task_suite, task_id, args, client, adads_client, max_steps, num_steps_wait, video_dir):
    """Run all trials of one task, save replay videos and return the task's result dict."""
    task = task_suite.get_task(task_id)
    initial_states = task_suite.get_task_init_states(task_id)
    env, task_description = _get_libero_env(task, LIBERO_ENV_RESOLUTION, args.seed)

    downsample_rates = args.downsample_rates
    save_video = True
    task_episodes, task_successes = 0, 0
    task_effective_steps_list = []  # effective step counts of successful episodes only
    task_ds_choices_all = []

    try:
        for episode_idx in tqdm.tqdm(range(args.num_trials_per_task)):
            logging.info(f"\nTask: {task_description}")
            logging.info(f"Starting episode {task_episodes+1} with AdaDS...")

            result = _run_episode(
                env,
                initial_states[episode_idx],
                task_description,
                client,
                adads_client,
                args,
                max_steps,
                num_steps_wait,
                save_video,
            )

            task_episodes += 1
            if result.success:
                task_successes += 1
                task_effective_steps_list.append(result.effective_steps)
            task_ds_choices_all.append(result.ds_choices)

            if save_video and result.replay_images:
                suffix = "success" if result.success else "failure"
                video_filename = f"task_{task_id:02d}_{episode_idx:02d}_{suffix}.mp4"
                imageio.mimwrite(
                    video_dir / video_filename,
                    [np.asarray(x) for x in result.replay_images],
                    fps=CONTROL_FREQ,
                )
                logging.info(f"Saved video: {video_filename}")

            logging.info(f"Success: {result.success}")
            if result.success:
                logging.info(
                    f"Task completed in {result.effective_steps} effective steps (total steps: {result.total_steps})"
                )
            else:
                logging.info(
                    f"Task failed after {result.effective_steps} effective steps (total steps: {result.total_steps})"
                )

            if result.ds_choices:
                ds_avg = np.mean(result.ds_choices)
                ds_counts = {ds: result.ds_choices.count(ds) for ds in downsample_rates}
                logging.info(f"Average DS: {ds_avg:.2f}, DS distribution: {ds_counts}")
    finally:
        # One environment is created per task; release it even when a rollout fails.
        env.close()

    task_success_rate = float(task_successes) / float(task_episodes) if task_episodes > 0 else 0.0
    task_avg_steps = float(np.mean(task_effective_steps_list)) if task_effective_steps_list else 0.0
    task_avg_steps = round(task_avg_steps, 2)

    all_ds_choices = [ds for ep_ds in task_ds_choices_all for ds in ep_ds]
    avg_ds = float(np.mean(all_ds_choices)) if all_ds_choices else 0.0
    ds_distribution = {ds: all_ds_choices.count(ds) for ds in downsample_rates}

    logging.info(f"Task success rate: {task_success_rate:.4f}")
    logging.info(f"Task average steps: {task_avg_steps:.2f}")
    logging.info(f"Task average DS: {avg_ds:.2f}")

    return {
        "task_description": task_description,
        "success_rate": task_success_rate,
        "avg_effective_steps": task_avg_steps,
        "avg_ds": round(avg_ds, 2),
        "ds_distribution": ds_distribution,
        "num_episodes": task_episodes,
        "num_successes": task_successes,
    }


def eval_libero_with_adads(args: Args) -> None:
    """Evaluate a LIBERO task suite using adaptive downsampling (AdaDS).

    For every task in the suite, ``args.num_trials_per_task`` episodes are
    rolled out. At each replanning step the policy server supplies an action
    chunk and the AdaDS server predicts the downsample rate ``k`` used to merge
    the chunk's actions. Replay videos are written to
    ``data/libero/<task_suite_name>/<name>/`` and the aggregated statistics to
    ``data/libero/<task_suite_name>/<name>.json``.

    Args:
        args: Evaluation settings (server addresses, task suite, number of
            trials, downsample rates, thresholds, run name). ``args`` is not
            modified.

    Raises:
        ValueError: if the task suite has no configured step budget.
    """
    np.random.seed(args.seed)
    downsample_rates = args.downsample_rates

    # Initialize the LIBERO task suite.
    benchmark_dict = benchmark.get_benchmark_dict()
    task_suite = benchmark_dict[args.task_suite_name]()
    num_tasks_in_suite = task_suite.n_tasks
    logging.info(f"Task suite: {args.task_suite_name}")

    video_dir = pathlib.Path(f"data/libero/{args.task_suite_name}/{args.name}")
    video_dir.mkdir(parents=True, exist_ok=True)

    max_steps = _max_steps_for_suite(args.task_suite_name)
    # Scaled copy; the caller's Args instance is left untouched.
    num_steps_wait = args.num_steps_wait * REPEATS

    client = _websocket_client_policy.WebsocketClientPolicy(args.host, args.port)
    # Client for the server that predicts the downsample rate k.
    adads_client = AdaDSClient(host=args.safeiql_host, port=args.safeiql_port)

    task_results = {}
    total_episodes, total_successes = 0, 0
    try:
        # Apply patches for downsampling.
        apply_patches()

        for task_id in tqdm.tqdm(range(num_tasks_in_suite)):
            task_result = _evaluate_task(
                task_suite, task_id, args, client, adads_client, max_steps, num_steps_wait, video_dir
            )
            task_results[task_id] = task_result
            total_episodes += task_result["num_episodes"]
            total_successes += task_result["num_successes"]
    finally:
        # Close the AdaDS connection even if the evaluation aborts.
        adads_client.close()

    # Overall statistics: the step and DS averages are means of the per-task values.
    overall_success_rate = float(total_successes) / float(total_episodes) if total_episodes > 0 else 0.0
    all_avg_steps = (
        float(np.mean([result["avg_effective_steps"] for result in task_results.values()]))
        if task_results
        else 0.0
    )
    all_avg_ds = float(np.mean([result["avg_ds"] for result in task_results.values()])) if task_results else 0.0

    final_results = {
        "task_suite": args.task_suite_name,
        "method": "adads_adaptive",
        "downsample_rates": downsample_rates,
        "threshold": f"adaptive ({args.manipulation_threshold} for manipulation, {args.movement_threshold} for movement)",
        "num_trials_per_task": args.num_trials_per_task,
        "overall_success_rate": overall_success_rate,
        "overall_avg_effective_steps": round(all_avg_steps, 2),
        "overall_avg_ds": round(all_avg_ds, 2),
        "total_episodes": total_episodes,
        "total_successes": total_successes,
        "task_results": task_results,
    }

    # Save results to a JSON file next to the video directory.
    results_path = pathlib.Path(f"data/libero/{args.task_suite_name}/{args.name}.json")
    results_path.parent.mkdir(parents=True, exist_ok=True)
    with open(results_path, "w") as f:
        json.dump(final_results, f, indent=2)

    logging.info(f"Total success rate: {overall_success_rate:.4f}")
    logging.info(f"Overall average steps: {all_avg_steps:.2f}")
    logging.info(f"Overall average DS: {all_avg_ds:.2f}")
    logging.info(f"Results saved to: {results_path}")


def _get_libero_env(task, resolution, seed):
    """Build a seeded off-screen LIBERO environment for ``task``.

    Returns:
        A ``(env, task_description)`` tuple, where ``task_description`` is
        ``task.language``.
    """
    bddl_root = pathlib.Path(get_libero_path("bddl_files"))
    env = OffScreenRenderEnv(
        bddl_file_name=bddl_root / task.problem_folder / task.bddl_file,
        camera_heights=resolution,
        camera_widths=resolution,
        control_freq=CONTROL_FREQ,
    )
    # IMPORTANT: seed seems to affect object positions even when using fixed initial state
    env.seed(seed)
    return env, task.language



def _quat2axisangle(quat):
    """Convert a quaternion to an axis-angle vector.

    Adapted from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55

    Unlike the robosuite version, the input is not modified: the scalar
    component is clipped on a local copy, so observation arrays passed in keep
    their original values.

    Args:
        quat: Sequence of four numbers; the first three are the vector part and
            the element at index 3 is the scalar part.

    Returns:
        A length-3 array: the vector part scaled by ``2 * acos(w) / sqrt(1 - w^2)``,
        or zeros when the rotation angle is zero.
    """
    quat = np.asarray(quat)
    # Clip the scalar part into acos' domain without writing back into `quat`.
    w = min(max(float(quat[3]), -1.0), 1.0)

    den = np.sqrt(1.0 - w * w)
    if math.isclose(den, 0.0):
        # This is (close to) a zero degree rotation, immediately return
        return np.zeros(3)

    return (quat[:3] * 2.0 * math.acos(w)) / den


# Script entry point: only runs when the file is executed directly.
if __name__ == "__main__":
    # INFO level so the logging.info() progress and result messages are emitted.
    logging.basicConfig(level=logging.INFO)
    # tyro builds the command-line interface from the function's signature
    # (its single `Args` parameter) and then calls the function.
    tyro.cli(eval_libero_with_adads)
