import os
import time
import pickle
import argparse

import ray
import wandb
import torch as t
import numpy as np
import matplotlib.pyplot as plt

from copy import deepcopy
from typing import List, Dict, Any, Union
from ray.util.placement_group import placement_group

from model.robot_simple_latent_conditioned_moe.robot import (
    RobotLatentEncoderHeadEncoder,
    RobotLatentActorEncoderHeadMoE,
    RobotLatentEncoderHeadCritic,
)
from utils.train_utils import RLTrainDirs, ExceptionCatcher, count_parameters
from collector import RobotSimulationCollector
from trainer import TrainerWorker
from visualize import plot_robot_trace

from rise import *


@ray.remote(num_cpus=5)
class RobotTrain:

    def __init__(
        self,
        run_name: str,
        code_root_dir: str,
        env_vars: dict,
        env_config_path: str,
        robot_config_path: str,
        ppo_checkpoint_path: str,
        rl_dirs: RLTrainDirs,
        base_seed: int,
        voxel_size: float,
        max_torque: float,
        min_actions: int,
        observation_frequency: int,
        max_epochs: int,
        collector_num: int,
        collector_timeout: int,
        trainworker_num: int,
        rollout_num: int,
        reward_weight_config: List[float],
        movement_direction: List[float],
        stand_cos_threshold: float,
        height_threshold: float,
        actor_lr: float,
        critic_lr: float,
        ppo_batch_size: int,
        ppo_accumulate_steps: int,
        ppo_train_steps_per_worker: int,
        actor_num_experts: int,
        actor_expert_hidden1: int,
        actor_expert_hidden2: int,
        actor_encoder_hidden1: int,
        actor_encoder_hidden2: int,
        actor_encoder_hidden3: int,
        actor_gating_threshold: float,
        actor_gating_top_k: int,
        actor_routing_entropy_weight: Union[float, None],
        actor_routing_diversity_weight: Union[float, None],
        actor_placeholder_latent_dim: int,
        save_trace_img: bool,
        ppo_save_interval: int,
        hdf5_record_save_interval: int,
        global_camera_width: int,
        global_camera_height: int,
        global_camera_samples: int,
        server_port: int,
        nccl_port: int,
    ):
        """Set up logging, models and the distributed collector/trainer actors.

        All one-time setup for a training run happens here:

        1. Export ``env_vars`` into this process's environment and snapshot the
           whole environment so it can be forwarded to every worker actor.
        2. Validate the reward/movement configuration and initialize wandb
           (online only when ``WANDB_API_KEY`` is set; ``WANDB_RUN_ID`` and
           ``WANDB_RESUME`` are forwarded to ``wandb.init``), logging the
           hyper-parameters, output paths and ``.py``/``.rsc`` source files.
        3. Load the environment config and the pickled robot structure config.
           The robot config's ``voxel_size`` replaces the ``voxel_size``
           argument for everything that follows.
        4. Build the MoE actor and the critic, optionally restore their weights
           from ``ppo_checkpoint_path`` (best effort: failures are printed and
           training starts from fresh weights), and put both in the object store.
        5. Reserve ``max(collector_num, trainworker_num)`` placement groups of
           one ``{"CPU": 6, "GPU": 1}`` bundle each, start the collector and
           trainer actors on them (3 CPUs and half a GPU each), and block until
           every actor has received its parameters.

        Args:
            run_name: wandb run name; also forwarded to the collectors.
            code_root_dir: Directory whose source files are uploaded to wandb.
            env_vars: Extra environment variables set before anything else.
            env_config_path: Path of the environment config read by ``RS_Config``.
            robot_config_path: Path of the pickled robot structure config.
            ppo_checkpoint_path: ``None`` or a ``torch.save`` file holding a dict
                with ``"actor_state"`` and ``"critic_state"`` entries.
            rl_dirs: Output directories for logs, results, records, checkpoints.
            reward_weight_config: Four weights
                ``[movement, standing, height, action_cost]``.
            movement_direction: Three-component movement direction.
            actor_gating_threshold: Gating threshold; ``<= 0`` disables it.
            actor_gating_top_k: Top-k gating; ``<= 0`` disables it.
            actor_placeholder_latent_dim: Latent dimension of the actor.
            The remaining arguments are hyper-parameters stored as attributes
            and forwarded to the models, the collectors or the trainers.

        Raises:
            ValueError: If ``reward_weight_config`` does not have 4 values or
                ``movement_direction`` does not have 3 (raised inside
                ``ExceptionCatcher``).
        """
        with ExceptionCatcher() as _:
            for key, value in env_vars.items():
                os.environ[key] = value

            # Full environment snapshot, forwarded to every worker actor.
            self.common_env_vars = dict(os.environ)

            self.run_name = run_name
            self.code_root_dir = code_root_dir

            self.env_config_path = env_config_path
            self.robot_config_path = robot_config_path
            self.ppo_checkpoint_path = ppo_checkpoint_path
            self.rl_dirs = rl_dirs

            self.base_seed = base_seed
            # Overwritten below by the robot config's voxel size.
            self.voxel_size = voxel_size
            self.max_torque = max_torque
            self.min_actions = min_actions
            self.observation_frequency = observation_frequency
            self.max_epochs = max_epochs

            self.collector_num = collector_num
            self.collector_timeout = collector_timeout
            self.trainworker_num = trainworker_num

            self.rollout_num = rollout_num

            if len(reward_weight_config) != 4:
                raise ValueError(
                    "reward_weight_config must have 4 values: "
                    "[movement, standing, height, action_cost]"
                )
            if len(movement_direction) != 3:
                raise ValueError("movement_direction must have 3 values")

            self.reward_weight_config = reward_weight_config
            self.movement_direction = movement_direction
            self.stand_cos_threshold = stand_cos_threshold
            self.height_threshold = height_threshold
            self.actor_lr = actor_lr
            self.critic_lr = critic_lr
            self.ppo_batch_size = ppo_batch_size
            self.ppo_accumulate_steps = ppo_accumulate_steps
            self.ppo_train_steps_per_worker = ppo_train_steps_per_worker
            self.actor_num_experts = actor_num_experts
            self.actor_expert_hidden1 = actor_expert_hidden1
            self.actor_expert_hidden2 = actor_expert_hidden2
            self.actor_encoder_hidden1 = actor_encoder_hidden1
            self.actor_encoder_hidden2 = actor_encoder_hidden2
            self.actor_encoder_hidden3 = actor_encoder_hidden3
            self.actor_gating_threshold = actor_gating_threshold
            self.actor_gating_top_k = actor_gating_top_k
            self.actor_routing_entropy_weight = actor_routing_entropy_weight
            self.actor_routing_diversity_weight = actor_routing_diversity_weight
            self.actor_placeholder_latent_dim = actor_placeholder_latent_dim

            self.save_trace_img = save_trace_img
            self.ppo_save_interval = ppo_save_interval
            self.hdf5_record_save_interval = hdf5_record_save_interval

            self.global_camera_width = global_camera_width
            self.global_camera_height = global_camera_height
            self.global_camera_samples = global_camera_samples

            self.server_port = server_port
            self.nccl_port = nccl_port

            #################################
            # Logging utilities
            #################################
            # Initialize WandB
            if "WANDB_API_KEY" in os.environ:
                mode = "online"
            else:
                mode = "offline"
                print("Note: Wandb is offline")

            reward_component_config = {
                "reward_component_movement": reward_weight_config[0],
                "reward_component_standing": reward_weight_config[1],
                "reward_component_height": reward_weight_config[2],
                "reward_component_action_cost": reward_weight_config[3],
            }
            wandb_run_id = os.environ.get("WANDB_RUN_ID")
            wandb_resume = os.environ.get("WANDB_RESUME")

            wandb.init(
                project=os.environ.get(
                    "WANDB_PROJECT", "endoskeletal-quadruped-expert"
                ),
                name=run_name,
                mode=mode,
                id=wandb_run_id,
                resume=wandb_resume,
                config={
                    "base_seed": base_seed,
                    "voxel_size": voxel_size,
                    "max_torque": max_torque,
                    "min_actions": min_actions,
                    "observation_frequency": observation_frequency,
                    "max_epochs": max_epochs,
                    "collector_num": collector_num,
                    "collector_timeout": collector_timeout,
                    "trainworker_num": trainworker_num,
                    "rollout_num": rollout_num,
                    **reward_component_config,
                    "movement_direction": movement_direction,
                    "stand_cos_threshold": stand_cos_threshold,
                    "height_threshold": height_threshold,
                    "actor_lr": actor_lr,
                    "critic_lr": critic_lr,
                    "ppo_batch_size": ppo_batch_size,
                    "ppo_accumulate_steps": ppo_accumulate_steps,
                    "ppo_train_steps_per_worker": ppo_train_steps_per_worker,
                    "actor_num_experts": actor_num_experts,
                    "actor_expert_hidden1": actor_expert_hidden1,
                    "actor_expert_hidden2": actor_expert_hidden2,
                    "actor_encoder_hidden1": actor_encoder_hidden1,
                    "actor_encoder_hidden2": actor_encoder_hidden2,
                    "actor_encoder_hidden3": actor_encoder_hidden3,
                    "actor_gating_threshold": actor_gating_threshold,
                    "actor_gating_top_k": actor_gating_top_k,
                    "actor_routing_entropy_weight": actor_routing_entropy_weight,
                    "actor_routing_diversity_weight": actor_routing_diversity_weight,
                    # Architecture hyper-parameter: logged like the other actor_* ones.
                    "actor_placeholder_latent_dim": actor_placeholder_latent_dim,
                    "global_camera_width": global_camera_width,
                    "global_camera_height": global_camera_height,
                    "global_camera_samples": global_camera_samples,
                },
            )

            # Log directory paths as config
            wandb.config.update(
                {
                    "paths/root": rl_dirs.root_path,
                    "paths/log": rl_dirs.log_path,
                    "paths/debug_log": rl_dirs.debug_log_path,
                    "paths/results": rl_dirs.results_path,
                    "paths/records": rl_dirs.records_path,
                    "paths/ckpt": rl_dirs.ckpt_path,
                }
            )

            # Log source code using wandb.run.log_code()
            def include_fn(path: str, root: str) -> bool:
                """Include .py files and .rsc files, excluding venv directory."""
                rel_path = os.path.relpath(path, root)
                # Exclude venv directory
                if rel_path.startswith("venv/") or "/venv/" in rel_path:
                    return False
                # Include .py and .rsc files
                return path.endswith((".py", ".rsc"))

            wandb.run.log_code(
                root=code_root_dir, name=f"source-{run_name}", include_fn=include_fn
            )
            # Prefer epoch for x-axis while keeping step-based metrics available
            wandb.define_metric("epoch")
            wandb.define_metric("*", step_metric="epoch")
            wandb.define_metric("training/step_*", step_metric="training_step")

            print("WandB Initialized")

            ################################
            # Environment setup
            ################################
            self.env_config = self.load_env_config()

            ################################
            # Robot setup
            ################################
            # NOTE: pickle.load executes arbitrary code; only use trusted files.
            with open(self.robot_config_path, "rb") as f:
                self.robot_config = pickle.load(f)  # type: RS_StructureConfig
            print(f"Robot config: {self.robot_config}")

            # The robot config is the source of truth for the voxel size used
            # from here on (the wandb config above logs the CLI argument).
            self.voxel_size = self.robot_config.voxel_size

            ################################
            # RL setup
            ################################
            # Match the actor/critic architecture used in `scripts/train_basic_moe/main.py`
            grid_size = 32
            normalize_radius = self.voxel_size * 64 * 1.2 / 2
            self.actor = RobotLatentActorEncoderHeadMoE(
                normalize_radius=normalize_radius,
                grid_resolution=grid_size,
                num_experts=self.actor_num_experts,
                resolution=5,
                actor_hidden1=self.actor_expert_hidden1,
                actor_hidden2=self.actor_expert_hidden2,
                kinematic_hidden_dims=(
                    self.actor_encoder_hidden1,
                    self.actor_encoder_hidden2,
                    self.actor_encoder_hidden3,
                ),
                # Non-positive values disable the corresponding gating mode.
                gating_threshold=(
                    self.actor_gating_threshold
                    if self.actor_gating_threshold > 0
                    else None
                ),
                gating_top_k=(
                    self.actor_gating_top_k if self.actor_gating_top_k > 0 else None
                ),
                latent_dim=self.actor_placeholder_latent_dim,
            )
            self.critic = RobotLatentEncoderHeadCritic(
                RobotLatentEncoderHeadEncoder(
                    normalize_radius=normalize_radius,
                    grid_resolution=grid_size,
                )
            )

            if self.ppo_checkpoint_path is not None:
                print(f"Loading local ppo checkpoint: {self.ppo_checkpoint_path}")
                try:
                    # Map to CPU: this actor reserves no GPU, so a checkpoint
                    # saved from CUDA tensors may not deserialize otherwise.
                    # NOTE: t.load unpickles; only load trusted checkpoints.
                    checkpoint = t.load(self.ppo_checkpoint_path, map_location="cpu")
                    if not isinstance(checkpoint, dict):
                        print(f"Unknown checkpoint format: {type(checkpoint)}")
                    elif "actor_state" in checkpoint and "critic_state" in checkpoint:
                        self.actor.load_state_dict(checkpoint["actor_state"])
                        self.critic.load_state_dict(checkpoint["critic_state"])
                        print(
                            "Successfully load actor and critic model from checkpoint"
                        )
                    else:
                        print(
                            "Checkpoint is missing 'actor_state'/'critic_state', "
                            f"found keys: {list(checkpoint.keys())}"
                        )
                except Exception as e:
                    # Best effort: training continues from fresh weights.
                    print(f"Error loading checkpoint: {e}")

            actor_ref = ray.put(self.actor)
            critic_ref = ray.put(self.critic)
            # Keep refs for later collector recreation (avoid re-serializing full models).
            self.actor_ref = actor_ref
            self.critic_ref = critic_ref
            print(
                f"Actor single expert encoder size: {count_parameters(self.actor.encoders[0])}"
            )
            if self.actor_num_experts > 1:
                print(
                    f"Actor latent-only gate size (input [B, {self.actor_placeholder_latent_dim}]): {count_parameters(self.actor.gate)}"
                )
            else:
                print("Actor uses no gating (single expert)")
            print(
                f"Actor action_expert size: {count_parameters(self.actor.action_experts[0])}"
            )
            print(f"Actor size: {count_parameters(self.actor)}")
            print(f"Critic encoder size: {count_parameters(self.critic.encoder)}")
            print(f"Critic size: {count_parameters(self.critic)}")
            print("RL initialized")

            # One placement group (one CPU/GPU bundle) per GPU; a collector and
            # a trainer with the same index share a bundle.
            num_gpus = max(collector_num, trainworker_num)
            self.pgs = []
            for _ in range(num_gpus):
                pg = placement_group([{"CPU": 6, "GPU": 1}], strategy="SPREAD")
                ray.get(pg.ready())
                self.pgs.append(pg)

            self.collectors = []
            self.trainers = []
            collector_start = []
            trainer_start = []
            for c_idx in range(num_gpus):
                if c_idx < collector_num:
                    self.collectors.append(
                        RobotSimulationCollector.options(
                            num_cpus=3,
                            num_gpus=0.5,
                            placement_group=self.pgs[c_idx],
                            runtime_env={"env_vars": self.common_env_vars},
                        ).remote()
                    )
                    collector_start.append(
                        self.collectors[-1].set_parameters.remote(
                            c_idx,
                            self.rl_dirs,
                            actor_ref,
                            critic_ref,
                            self.min_actions,
                            self.observation_frequency,
                            self.voxel_size,
                            self.reward_weight_config,
                            self.movement_direction,
                            self.stand_cos_threshold,
                            self.height_threshold,
                            self.server_port,
                            self.run_name,
                        )
                    )

                if c_idx < trainworker_num:
                    self.trainers.append(
                        TrainerWorker.options(
                            num_cpus=3,
                            num_gpus=0.5,
                            placement_group=self.pgs[c_idx],
                            runtime_env={"env_vars": self.common_env_vars},
                        ).remote()
                    )
                    trainer_start.append(
                        self.trainers[-1].set_parameters.remote(
                            c_idx,
                            actor_ref,
                            critic_ref,
                            ppo_batch_size,
                            actor_lr,
                            critic_lr,
                            ppo_accumulate_steps,
                            ppo_train_steps_per_worker,
                            trainworker_num,
                            nccl_port,
                        )
                    )
            # Block until every actor is configured before training starts.
            ray.get(collector_start)
            ray.get(trainer_start)
            print("Collectors created")
            print("Trainers created")

    def load_env_config(self):
        """Read the environment config from ``self.env_config_path``.

        Returns:
            The opened ``RS_Config`` when it reports itself valid; otherwise
            prints the reason and returns ``None``.
        """
        config = RS_Config()
        config.open([self.env_config_path])
        if config.is_valid:
            return config

        print(f"Config invalid: {config.invalid_reason}")
        return None

    def train(self):
        """Run the rollout / store / PPO-update loop for ``max_epochs`` epochs.

        Each epoch collects episodes from all collectors. If more than two
        episodes were collected and at most two errors occurred during the
        rollout, the episodes are stored on the trainers, PPO is run, and
        ``_update_collector_parameters`` is called. ``system/error_count`` is
        logged every epoch, and the wandb run is finished after the last epoch.

        Raises:
            ValueError: If the environment config is invalid (raised inside
                ``ExceptionCatcher``).
        """
        with ExceptionCatcher() as _:
            env_config = self.load_env_config()
            if env_config is None:
                # load_env_config already printed the reason; a None config
                # would otherwise be shipped to every collector each epoch.
                raise ValueError(
                    f"Invalid environment config: {self.env_config_path}"
                )
            env_config_ref = ray.put(env_config)

            simulated_robot_ids = list(range(self.rollout_num))
            # One independent copy per robot. `[deepcopy(x)] * n` would alias
            # a single object n times, so every robot would be named
            # `robot_{n-1}` (and ray.put preserves that aliasing).
            simulated_robot_configs = [
                deepcopy(self.robot_config) for _ in range(self.rollout_num)
            ]
            for i, robot_config in enumerate(simulated_robot_configs):
                robot_config.name = f"robot_{i}"
            simulated_robot_configs_ref = ray.put(simulated_robot_configs)

            for epoch in range(self.max_epochs):
                # Incremented by _rollout / _store_ppo_episodes / _train_ppo.
                self.error_count = 0

                print(f"Start rollout at epoch {epoch}")
                episodes = self._rollout(
                    epoch,
                    env_config_ref,
                    simulated_robot_ids,
                    simulated_robot_configs_ref,
                )

                print(f"[MULTI-GPU] Received rollout episodes: {len(episodes)}")
                episode_length_list = [len(episode["episode"]) for episode in episodes]
                print(f"Episode length list: {episode_length_list}")

                if len(episodes) > 2 and self.error_count <= 2:
                    print(f"Finish rollout at epoch {epoch}")
                    self._store_ppo_episodes(epoch, episodes)
                    print(f"Finish store episodes at epoch {epoch}")
                    self._train_ppo(epoch)
                    print(f"Finish train ppo at epoch {epoch}")
                    self._update_collector_parameters(epoch)
                    print(f"Finish update collector parameters at epoch {epoch}")
                elif self.error_count <= 2:
                    print(
                        "Too little valid episodes collected for training, skipping ppo"
                    )
                else:
                    print("Too many errors in rollout, skipping ppo")

                wandb.log({"system/error_count": self.error_count, "epoch": epoch})

            wandb.finish()

    def _rollout(
        self,
        epoch: int,
        env_config_ref: ray.ObjectRef,
        robot_ids: List[int],
        robot_configs_ref: ray.ObjectRef,
    ):
        """Run one epoch of simulation on every collector and gather episodes.

        Robots are split into contiguous chunks of
        ``ceil(len(robot_ids) / collector_num)`` indices, one chunk per
        collector; trailing collectors may get an empty chunk. Each collector's
        result is awaited for at most ``collector_timeout`` seconds. A collector
        that times out or raises is counted in ``self.error_count``, killed,
        and replaced by a freshly configured one; its episodes are dropped.

        Args:
            epoch: Current epoch, forwarded to the collectors.
            env_config_ref: Object ref of the environment config.
            robot_ids: Ids of all simulated robots. The whole list is passed
                to every collector together with that collector's indices.
            robot_configs_ref: Object ref of the list of robot configs.

        Returns:
            The episodes of all successful collectors whose ``"episode"``
            entry is non-empty.
        """
        if self.collector_num <= 0:
            print("[MULTI-GPU] No collectors configured, skipping rollout")
            return []

        total_robots = len(robot_ids)
        sim_per_collector = int(np.ceil(total_robots / self.collector_num))
        print(f"[MULTI-GPU] Sim per collector: {sim_per_collector}")

        # Contiguous, non-shuffled allocation (only one design). Slicing past
        # the end yields an empty plan for collectors with no robots left.
        indices = list(range(total_robots))
        all_allocated_robot_config_indices = [
            indices[rank * sim_per_collector : (rank + 1) * sim_per_collector]
            for rank in range(self.collector_num)
        ]
        all_allocated_robot_ids = [
            [robot_ids[idx] for idx in plan]
            for plan in all_allocated_robot_config_indices
        ]

        print(
            f"[MULTI-GPU] Robot allocation plans: {all_allocated_robot_config_indices}"
        )
        print(
            f"[MULTI-GPU] Robot allocation plan assigned robot ids: {all_allocated_robot_ids}"
        )

        # Launch all collectors first so they simulate in parallel.
        results = [
            collector.collect.remote(
                epoch,
                env_config_ref,
                robot_ids,
                robot_configs_ref,
                all_allocated_robot_config_indices[collector_idx],
                self.actor_placeholder_latent_dim,
                self.global_camera_width,
                self.global_camera_height,
                self.global_camera_samples,
                self.hdf5_record_save_interval,
            )
            for collector_idx, collector in enumerate(self.collectors)
        ]

        episodes = []
        for rank, result in enumerate(results):
            print(f"[MULTI-GPU] Get result from collector {rank}")
            try:
                sub_episodes = ray.get(result, timeout=self.collector_timeout)
            except Exception as e:
                if isinstance(e, ray.exceptions.GetTimeoutError):
                    print(
                        f"Timeout error: cannot get result from collector {rank}: {e}"
                    )
                else:
                    print(f"Error in ray.get during rollout: {e}")

                print(f"Exception caught in collector {rank}, returned none, skipping")
                self.error_count = self.error_count + 1
                self._replace_collector(rank)
                continue

            if sub_episodes is None:
                print(
                    f"Other exception caught in collector {rank}, returned none, skipping"
                )
                continue

            episodes.extend(
                episode for episode in sub_episodes if len(episode["episode"]) > 0
            )

        return episodes

    def _replace_collector(self, rank: int):
        """Kill collector ``rank`` and start a freshly configured one in its place.

        Killing the actor also terminates its in-flight ``collect`` task.
        ``ray.cancel(force=True)`` is not supported for actor tasks, so the
        kill is the step that actually stops a stuck task. Because the old
        actor is killed explicitly, its share of the placement-group bundle is
        freed without depending on garbage collection of the dropped handle.
        """
        try:
            ray.kill(self.collectors[rank])
        except Exception as kill_error:
            # Best effort: the actor may already be dead.
            print(f"Failed to kill collector {rank}: {kill_error}")

        self.collectors[rank] = RobotSimulationCollector.options(
            num_cpus=3,
            num_gpus=0.5,
            placement_group=self.pgs[rank],
            runtime_env={"env_vars": self.common_env_vars},
        ).remote()
        ray.get(
            self.collectors[rank].set_parameters.remote(
                rank,
                self.rl_dirs,
                self.actor_ref,
                self.critic_ref,
                self.min_actions,
                self.observation_frequency,
                self.voxel_size,
                self.reward_weight_config,
                self.movement_direction,
                self.stand_cos_threshold,
                self.height_threshold,
                self.server_port,
                self.run_name,
            )
        )
        print(f"Collector {rank} recreated")

    def _store_ppo_episodes(
        self,
        epoch: int,
        episodes: List[Dict[str, Any]],
    ):
        """Hand episodes to the PPO trainers and log per-episode statistics.

        Episodes are split into contiguous, nearly equal shards (the first
        ``len(episodes) % trainworker_num`` workers get one extra) and sent to
        each trainer's ``store_episodes``; workers with an empty shard are not
        called. While the trainers ingest them, the robot traces are pickled to
        ``results_path`` and optionally plotted to wandb. Then max/mean/min
        summaries of the episode statistics are logged to wandb, and the raw
        per-episode arrays are saved as ``.npz`` in ``log_path``.

        Args:
            epoch: Current epoch, used in file names and as the wandb x-axis.
            episodes: Episode dicts from the collectors. Each must contain
                ``"episode"`` (a sized sequence, assumed non-empty), ``"reward"``,
                ``"rectified_distance"``, ``"cost"``, ``"prob"`` and ``"trace"``.
                The ``reward_component_<name>_sum`` keys are optional and
                default to 0.
        """
        with ExceptionCatcher() as _:
            if not episodes:
                # np.max/np.min below would raise on empty lists.
                print(f"No episodes to store at epoch {epoch}, skipping")
                return

            total_episodes = len(episodes)
            base_episodes_per_worker, remainder = divmod(
                total_episodes, self.trainworker_num
            )

            print(f"Number of episodes: {total_episodes}")
            print(f"Base episodes per worker: {base_episodes_per_worker}")
            print(f"Remainder episodes: {remainder}")
            print(f"Episode length: {len(episodes[-1]['episode'])}")

            # Dispatch shards first so the trainers work in parallel with the
            # statistics / plotting below.
            results_list = []
            offset = 0
            for worker_idx in range(self.trainworker_num):
                worker_episode_count = base_episodes_per_worker + (
                    1 if worker_idx < remainder else 0
                )
                worker_episodes = episodes[offset : offset + worker_episode_count]
                offset += worker_episode_count
                if worker_episodes:
                    results_list.append(
                        self.trainers[worker_idx].store_episodes.remote(worker_episodes)
                    )

            component_names = ("movement", "standing", "height", "action_cost")
            episode_lengths = []
            rewards = []
            rectified_distances = []
            costs = []
            probs = []
            per_step_rewards = []
            per_step_costs = []
            component_sums = {name: [] for name in component_names}
            component_per_step = {name: [] for name in component_names}
            robot_traces = []

            for episode in episodes:
                # _rollout drops empty episodes, so the divisions below are safe.
                length = len(episode["episode"])
                episode_lengths.append(length)
                rewards.append(episode["reward"])
                rectified_distances.append(episode["rectified_distance"])
                costs.append(episode["cost"])
                probs.append(episode["prob"])
                per_step_rewards.append(episode["reward"] / length)
                per_step_costs.append(episode["cost"] / length)
                for name in component_names:
                    component_sum = float(
                        episode.get(f"reward_component_{name}_sum", 0.0)
                    )
                    component_sums[name].append(component_sum)
                    component_per_step[name].append(component_sum / length)
                robot_traces.append(episode["trace"])

            # save robot_traces to file
            with open(
                os.path.join(
                    self.rl_dirs.results_path, f"epoch_{epoch}_robot_traces.pkl"
                ),
                "wb",
            ) as f:
                pickle.dump(robot_traces, f)

            if self.save_trace_img:
                robot_trace_fig = plot_robot_trace(
                    robot_traces=robot_traces,
                    title=f"{self.run_name} - Epoch {epoch} - Robot traces",
                )

                if robot_trace_fig:
                    wandb.log(
                        {
                            "traces/robot_movement": wandb.Image(robot_trace_fig),
                            "epoch": epoch,
                        }
                    )
                    # Release figure memory; this runs once per epoch.
                    robot_trace_fig.clf()
                    plt.close(robot_trace_fig)
                else:
                    print("No robot trace figure provided")

            try:
                ray.get(results_list)
            except Exception as e:
                print(f"Error in ray.get during store episodes: {e}")
                self.error_count = self.error_count + 1

            print(f"Finish store episodes at epoch {epoch}")

            def summarize(key_fmt: str, values: List[float]) -> Dict[str, Any]:
                """Return max/mean/min of ``values`` keyed by ``key_fmt.format(stat)``."""
                return {
                    key_fmt.format("max"): np.max(values),
                    key_fmt.format("mean"): np.mean(values),
                    key_fmt.format("min"): np.min(values),
                }

            # Log all metrics in one grouped wandb.log call
            metrics = {
                **summarize("episode/{}_length", episode_lengths),
                **summarize("reward/{}", rewards),
                **summarize("rectified_distance/{}", rectified_distances),
                **summarize("cost/{}", costs),
                **summarize("prob/{}", probs),
                **summarize("per_step_reward/{}", per_step_rewards),
                **summarize("per_step_cost/{}", per_step_costs),
            }
            for name in component_names:
                metrics.update(
                    summarize(f"reward_components/{name}_sum/{{}}", component_sums[name])
                )
                metrics.update(
                    summarize(
                        f"reward_components/per_step_{name}/{{}}",
                        component_per_step[name],
                    )
                )
            metrics["epoch"] = epoch
            wandb.log(metrics)

            np.savez(
                os.path.join(self.rl_dirs.log_path, f"epoch={epoch}-metrics.npz"),
                rewards=np.array(rewards),
                rectified_distances=np.array(rectified_distances),
                costs=np.array(costs),
                # Small epsilon avoids division by zero for zero-cost episodes.
                reward_cost_ratio=np.array(rewards) / (np.array(costs) + 1e-3),
            )

    def _train_ppo(self, epoch: int):
        """Run one PPO training round on all trainer workers and log losses.

        ``train_loop`` is dispatched on every trainer actor in parallel and
        awaited. Only the first returned result is logged, since only the
        rank 0 trainer reports training performance. If ``ray.get`` raises,
        the error is printed, ``self.error_count`` is incremented and the
        method returns without logging.

        Args:
            epoch: Current epoch index. It is logged as ``epoch`` and used to
                derive the per-step ``training_step`` counter.
        """
        with ExceptionCatcher() as _:
            pending = [worker.train_loop.remote() for worker in self.trainers]

            try:
                rank0_result = ray.get(pending)[0]
            except Exception as err:
                print(f"Error in ray.get during train ppo: {err}")
                self.error_count += 1
                return

            # Pull every field up front so a malformed result fails before
            # anything is logged.
            loss_for_epoch = rank0_result["epoch_loss"]
            actor_losses = rank0_result["step_actor_loss_list"]
            critic_losses = rank0_result["step_critic_loss_list"]

            wandb.log({"training/epoch_loss": loss_for_epoch, "epoch": epoch})

            # Global step = epoch offset + position within this epoch's steps.
            for step_idx, actor_loss in enumerate(actor_losses):
                wandb.log(
                    {
                        "training/step_actor_loss": actor_loss,
                        "training/step_critic_loss": critic_losses[step_idx],
                        "training_step": epoch * self.ppo_train_steps_per_worker
                        + step_idx,
                    }
                )

            print(f"Finish train ppo at epoch {epoch}")

    def _update_collector_parameters(self, epoch: int):
        """Sync the rank 0 trainer's weights locally and to every collector.

        Steps:
        1. Fetch actor and critic state dicts from ``self.trainers[0]``. Every
           ``self.ppo_save_interval`` epochs (including epoch 0), also fetch the
           optimizer states.
        2. Load the weights into the local ``self.actor`` / ``self.critic``.
        3. On checkpoint epochs, write a checkpoint to
           ``self.rl_dirs.ckpt_path``.
        4. Broadcast the weights to all collectors.

        The checkpoint is written before the collector broadcast, so a failure
        while updating collectors no longer drops the periodic checkpoint. Such
        a failure is printed and increments ``self.error_count``.

        Args:
            epoch: Current epoch index. It decides whether to checkpoint and
                names the checkpoint file.
        """
        save_checkpoint = epoch % self.ppo_save_interval == 0

        # Fetch parameters from the rank 0 trainer. Optimizer states are only
        # requested when they will be written to a checkpoint.
        if save_checkpoint:
            actor_state, critic_state, actor_optimizer_state, critic_optimizer_state = (
                ray.get(self.trainers[0].get_parameters.remote(save_checkpoint))
            )
        else:
            actor_state, critic_state = ray.get(
                self.trainers[0].get_parameters.remote()
            )

        # Update local actor and critic for reference.
        self.actor.load_state_dict(actor_state)
        self.critic.load_state_dict(critic_state)

        # Persist the checkpoint first: it only depends on the trainer state,
        # not on whether the collector broadcast below succeeds.
        if save_checkpoint:
            checkpoint_path = os.path.join(
                self.rl_dirs.ckpt_path, f"epoch_{epoch}_checkpoint.st"
            )
            t.save(
                {
                    "actor_state": actor_state,
                    "critic_state": critic_state,
                    "actor_optimizer_state": actor_optimizer_state,
                    "critic_optimizer_state": critic_optimizer_state,
                },
                checkpoint_path,
            )
            print(f"Checkpoint saved at epoch {epoch} to {checkpoint_path}")

        # Put the state dicts in the object store once and share the refs,
        # rather than serializing them separately for each collector.
        actor_state_ref = ray.put(actor_state)
        critic_state_ref = ray.put(critic_state)
        try:
            ranks = ray.get(
                [
                    collector.update_parameters.remote(
                        actor_state_ref, critic_state_ref
                    )
                    for collector in self.collectors
                ]
            )
        except Exception as e:
            print(f"Error in ray.get during update collector parameters: {e}")
            self.error_count = self.error_count + 1
            return

        print(f"Finish update parameters for collectors: {ranks}")


def parse_reward_weight_config(raw: str) -> List[float]:
    """Turn a comma-separated reward weight string into a list of 4 floats.

    Surrounding whitespace on each item is ignored and empty items (for
    example from a trailing comma) are skipped.

    Args:
        raw: Weights ordered as movement, standing, height, action_cost.

    Returns:
        The four weights as floats, in the same order.

    Raises:
        ValueError: If an item is not a float, or if the number of non-empty
            items is not exactly 4.
    """
    stripped_tokens = (token.strip() for token in raw.split(","))
    weights = [float(token) for token in stripped_tokens if token]
    if len(weights) == 4:
        return weights
    raise ValueError(
        "reward_weight_config must have 4 comma-separated values: "
        "movement, standing, height, action_cost"
    )


def parse_cuda_devices(raw: str) -> List[int]:
    """Parse CUDA device input into a list of device indices.

    Accepted forms:
    - a single index: ``"0"``
    - an inclusive range: ``"0-3"``
    - a comma-separated list: ``"0,2,3"``
    - any comma-separated mix of indices and ranges: ``"0-1,4,6-7"``

    Whitespace around items is ignored and empty items (e.g. a trailing comma)
    are skipped. Order is preserved and duplicates are kept.

    Args:
        raw: Device specification string.

    Returns:
        List of integer CUDA device indices.

    Raises:
        ValueError: If an item is neither an integer nor a ``start-end`` range,
            if a range is descending, or if no device index is given at all.
            An empty result would silently hide every GPU via
            ``CUDA_VISIBLE_DEVICES``.
    """
    devices: List[int] = []
    for item in raw.split(","):
        item = item.strip()
        if not item:
            continue
        if "-" in item:
            start_text, end_text = item.split("-", maxsplit=1)
            start, end = int(start_text), int(end_text)
            if end < start:
                raise ValueError(
                    f"Invalid CUDA device range '{item}': end is smaller than start"
                )
            devices.extend(range(start, end + 1))
        else:
            devices.append(int(item))
    if not devices:
        raise ValueError(f"No CUDA devices specified in {raw!r}")
    return devices


if __name__ == "__main__":

    # Command-line arguments
    parser = argparse.ArgumentParser(description="Manual Robot Expert Training")
    parser.add_argument("--nccl_port", type=int, default=12345)
    parser.add_argument("--webserver_port", type=int, default=8002)
    parser.add_argument("--run_name", type=str, default="")
    parser.add_argument(
        "--ray_cuda_devices",
        type=str,
        default="0-7",
        help="CUDA device range or list (e.g. '0-3' or '0,2,3')",
    )
    parser.add_argument(
        "--reward_weight_config",
        type=str,
        default="10,1,1,0.05",
        help=(
            "Comma-separated reward weights in order: "
            "movement, standing, height, action_cost"
        ),
    )
    parser.add_argument("--max_epochs", type=int, default=200)
    parser.add_argument(
        "--root_path",
        type=str,
        default="./data/rl-result",
        help="Base directory for training outputs.",
    )
    parser.add_argument(
        "--checkpoint_path", type=str, default=None, help="Load local ppo checkpoint"
    )
    args = parser.parse_args()

    print(f"NCCL port: {args.nccl_port}")
    print(f"Webserver port: {args.webserver_port}")
    print(f"Run name: {args.run_name}")

    # Path configurations
    ppo_checkpoint_path = args.checkpoint_path
    env_config_path = "data/env/env_for_basic_locomotion.rsc"

    # Alternative robot config: "data/robot_config/quadruped.data"
    robot_config_path = "data/robot_config/quadruped_8dof.data"

    # Note: robot metadata is intentionally not used by this training script.

    root_path = args.root_path
    log_sub_path = "logs"
    debug_log_sub_path = "debug-logs"
    ckpt_sub_path = "ckpt"
    results_sub_path = "results"
    records_sub_path = "records"

    # Simulation and general configurations
    base_seed = 42
    voxel_size = 0.01
    max_torque = 6
    min_actions = 60
    observation_frequency = 10
    max_epochs = args.max_epochs

    # Compute configurations
    collector_num = 1
    collector_timeout = 600
    trainworker_num = 1

    # RL configurations
    # Order: movement, standing, height, action_cost.
    reward_weight_config = parse_reward_weight_config(args.reward_weight_config)
    movement_direction = [1.0, 0.0, 0.0]
    stand_cos_threshold = 0.5
    height_threshold = voxel_size * 15  # 15 voxels
    rollout_num = 16
    actor_lr = 6e-5
    critic_lr = 6e-5
    ppo_batch_size = 128
    ppo_train_steps_per_worker = 40
    ppo_accumulate_steps = 1
    actor_num_experts = 1
    actor_expert_hidden1 = 128
    actor_expert_hidden2 = 32
    actor_encoder_hidden1 = 128
    actor_encoder_hidden2 = 192
    actor_encoder_hidden3 = 256
    actor_gating_threshold = 0.0  # 0 disables threshold
    actor_gating_top_k = 0  # 0 disables top-k pruning
    actor_routing_entropy_weight = 0.0  # 0.0 disables the routing-entropy penalty
    actor_routing_diversity_weight = 0.01  # 0.0 would disable the diversity penalty
    actor_placeholder_latent_dim = 512

    # Plot and save configurations
    save_trace_img = True
    ppo_save_interval = 10  # epochs between PPO checkpoints
    hdf5_record_save_interval = 100

    # Web monitoring configurations
    global_camera_width = 1280
    global_camera_height = 720
    global_camera_samples = 4

    print("initialize ray")

    # These variables must be set before ray.init(): the whole environment is
    # snapshotted into the runtime_env below, so every Ray worker inherits them.
    cuda_devices = parse_cuda_devices(args.ray_cuda_devices)
    cuda_visible = ",".join(map(str, cuda_devices))
    os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["RAY_DEDUP_LOGS"] = "0"

    print(f"Ray visible CUDA devices: {cuda_visible}")

    ray.init(include_dashboard=False, runtime_env={"env_vars": dict(os.environ)})
    print("Ray initialized")

    rl_dirs = RLTrainDirs(
        run_name=args.run_name,
        root_path=root_path,
        log_sub_path=log_sub_path,
        debug_log_sub_path=debug_log_sub_path,
        ckpt_sub_path=ckpt_sub_path,
        results_sub_path=results_sub_path,
        records_sub_path=records_sub_path,
    )
    env_vars = dict(os.environ)

    # The code root is two directory levels above this script.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    code_root_dir = os.path.abspath(os.path.join(current_dir, os.pardir, os.pardir))
    print("generate robot evolution")
    train = RobotTrain.remote(
        run_name=args.run_name,
        code_root_dir=code_root_dir,
        env_vars=env_vars,
        env_config_path=env_config_path,
        robot_config_path=robot_config_path,
        ppo_checkpoint_path=ppo_checkpoint_path,
        rl_dirs=rl_dirs,
        base_seed=base_seed,
        voxel_size=voxel_size,
        max_torque=max_torque,
        min_actions=min_actions,
        observation_frequency=observation_frequency,
        max_epochs=max_epochs,
        collector_num=collector_num,
        collector_timeout=collector_timeout,
        trainworker_num=trainworker_num,
        rollout_num=rollout_num,
        reward_weight_config=reward_weight_config,
        movement_direction=movement_direction,
        stand_cos_threshold=stand_cos_threshold,
        height_threshold=height_threshold,
        actor_lr=actor_lr,
        critic_lr=critic_lr,
        ppo_batch_size=ppo_batch_size,
        ppo_accumulate_steps=ppo_accumulate_steps,
        ppo_train_steps_per_worker=ppo_train_steps_per_worker,
        actor_num_experts=actor_num_experts,
        actor_expert_hidden1=actor_expert_hidden1,
        actor_expert_hidden2=actor_expert_hidden2,
        actor_encoder_hidden1=actor_encoder_hidden1,
        actor_encoder_hidden2=actor_encoder_hidden2,
        actor_encoder_hidden3=actor_encoder_hidden3,
        actor_gating_threshold=actor_gating_threshold,
        actor_gating_top_k=actor_gating_top_k,
        actor_routing_entropy_weight=actor_routing_entropy_weight,
        actor_routing_diversity_weight=actor_routing_diversity_weight,
        actor_placeholder_latent_dim=actor_placeholder_latent_dim,
        save_trace_img=save_trace_img,
        ppo_save_interval=ppo_save_interval,
        hdf5_record_save_interval=hdf5_record_save_interval,
        global_camera_width=global_camera_width,
        global_camera_height=global_camera_height,
        global_camera_samples=global_camera_samples,
        server_port=args.webserver_port,
        nccl_port=args.nccl_port,
    )
    print("start training robots")
    try:
        ray.get(train.train.remote())
        # Brief pause before shutdown, presumably so remote workers can flush
        # their output.
        time.sleep(5)
    finally:
        # Always release the Ray runtime, even if training raised.
        ray.shutdown()
