import os
import json
import time
import pprint
import pickle
import argparse

import cma
from matplotlib import pyplot as plt
import ray
from ray.util.placement_group import placement_group
import wandb
import torch as t
import numpy as np

from PIL import Image
from typing import List, Dict, Any, Union

from model.vae_old.star_vae import StarVAE
from model.robot_simple_latent_conditioned_moe.robot import (
    RobotLatentEncoderHeadEncoder,
    RobotLatentActorEncoderHeadMoE,
    RobotLatentEncoderHeadCritic,
)
from sim.create.create_floor import (
    generate_2d_wave_floor, 
    scale_floor_height, 
    clear_floor_center, 
    update_config_with_array_floor
)
from utils.plot import plot_floor
from utils.train_utils import RLTrainDirs, ExceptionCatcher, count_parameters
from utils.robot_stats_utils import (
    population_segment_count_stats,
    population_mass_stats,
    population_bounding_box_stats,
    population_connectivity_stats,
    population_longest_path_stats,
    diversity_cv,
)
from scripts.evolve_floor_latent_conditioned_moe.collector import (
    RobotSimulationCollector,
    build_robot,
)
from scripts.evolve_floor_latent_conditioned_moe.trainer import (
    TrainerWorker,
)
from scripts.evolve_floor_latent_conditioned_moe.visualize import plot_robot_trace

from rise import *


@ray.remote(num_cpus=5)
class RobotEvolution:
    def __init__(
        self,
        run_name: str,
        code_root_dir: str,
        env_vars: dict,
        env_config_path: str,
        generator_path: str,
        ppo_checkpoint_path: str,
        rl_dirs: RLTrainDirs,
        env_floor_array: np.ndarray,
        env_floor_size: float,
        base_seed: int,
        voxel_size: float,
        max_torque: float,
        min_actions: int,
        max_epochs: int,
        collector_num: int,
        collector_timeout: int,
        trainworker_num: int,
        pop_size: int,
        best_k_robot_num: int,
        rollout_num: int,
        rollout_epochs: int,
        reward_type: str,
        actor_lr: float,
        critic_lr: float,
        ppo_batch_size: int,
        ppo_accumulate_steps: int,
        ppo_train_steps_per_worker: int,
        actor_num_experts: int,
        actor_expert_hidden1: int,
        actor_expert_hidden2: int,
        actor_encoder_hidden1: int,
        actor_encoder_hidden2: int,
        actor_encoder_hidden3: int,
        actor_gating_threshold: float,
        actor_gating_top_k: int,
        actor_routing_entropy_weight: Union[float, None],
        actor_routing_diversity_weight: Union[float, None],
        plot_robot_image_num: int,
        save_trace_img: bool,
        ppo_save_interval: int,
        hdf5_record_save_interval: int,
        global_camera_width: int,
        global_camera_height: int,
        global_camera_samples: int,
        server_port: int,
        nccl_port: int,
    ):
        """Set up logging, models and the Ray worker pool for one evolution run.

        The constructor performs, in order:

        1. Exports ``env_vars`` into ``os.environ`` and snapshots the resulting
           environment so it can be forwarded to every worker actor.
        2. Stores all hyper-parameters on ``self``.
        3. Initialises WandB (online only when ``WANDB_API_KEY`` is set), logs
           the hyper-parameters, the ``rl_dirs`` paths and the ``.py``/``.rsc``
           sources found under ``code_root_dir``.
        4. Seeds torch/numpy with ``base_seed``, loads the ``StarVAE`` from
           ``generator_path`` on CPU and loads (or computes and caches) the
           latent mean/variance in ``./data/vae_mean_var.data``.
        5. Builds the MoE actor and the critic, optionally restoring them from
           ``ppo_checkpoint_path``.
        6. Creates ``max(collector_num, trainworker_num)`` placement groups and
           starts the collector / trainer actors inside them.

        Selected parameters:
            env_vars: Mapping of environment variable names to string values.
            ppo_checkpoint_path: Optional checkpoint; ``None`` skips loading.
                A usable checkpoint is a dict holding ``"actor_state"`` and
                ``"critic_state"``. Loading is best-effort: failures are
                printed and training starts from fresh weights.
            actor_gating_threshold: Values ``<= 0`` are passed to the actor as
                ``None``.
            actor_gating_top_k: Values ``<= 0`` are passed to the actor as
                ``None``.
            collector_num / trainworker_num: Number of collector and trainer
                actors; worker ``i`` of each kind shares placement group ``i``.
            server_port: Forwarded to every collector.
            nccl_port: Forwarded to every trainer.

        All remaining arguments are stored unchanged on ``self`` (most are also
        logged to the WandB config) for use by the other methods of the class.
        """
        with ExceptionCatcher() as _:
            # Export caller-provided variables first so that both the WandB
            # setup below and the snapshot forwarded to workers include them.
            for key, value in env_vars.items():
                os.environ[key] = value

            self.common_env_vars = dict(os.environ)

            self.run_name = run_name
            self.code_root_dir = code_root_dir

            self.env_config_path = env_config_path
            self.generator_path = generator_path
            self.ppo_checkpoint_path = ppo_checkpoint_path
            self.rl_dirs = rl_dirs

            self.env_floor_array = env_floor_array
            self.env_floor_size = env_floor_size
            self.base_seed = base_seed
            self.voxel_size = voxel_size
            self.max_torque = max_torque
            self.min_actions = min_actions
            self.max_epochs = max_epochs

            self.collector_num = collector_num
            self.collector_timeout = collector_timeout
            self.trainworker_num = trainworker_num

            self.pop_size = pop_size
            self.best_k_robot_num = best_k_robot_num
            self.rollout_num = rollout_num
            self.rollout_epochs = rollout_epochs

            self.reward_type = reward_type
            self.actor_lr = actor_lr
            self.critic_lr = critic_lr
            self.ppo_batch_size = ppo_batch_size
            self.ppo_accumulate_steps = ppo_accumulate_steps
            self.ppo_train_steps_per_worker = ppo_train_steps_per_worker
            self.actor_num_experts = actor_num_experts
            self.actor_expert_hidden1 = actor_expert_hidden1
            self.actor_expert_hidden2 = actor_expert_hidden2
            self.actor_encoder_hidden1 = actor_encoder_hidden1
            self.actor_encoder_hidden2 = actor_encoder_hidden2
            self.actor_encoder_hidden3 = actor_encoder_hidden3
            self.actor_gating_threshold = actor_gating_threshold
            self.actor_gating_top_k = actor_gating_top_k
            self.actor_routing_entropy_weight = actor_routing_entropy_weight
            self.actor_routing_diversity_weight = actor_routing_diversity_weight

            self.plot_robot_image_num = plot_robot_image_num
            self.save_trace_img = save_trace_img
            self.ppo_save_interval = ppo_save_interval
            self.hdf5_record_save_interval = hdf5_record_save_interval

            self.global_camera_width = global_camera_width
            self.global_camera_height = global_camera_height
            self.global_camera_samples = global_camera_samples

            self.server_port = server_port
            self.nccl_port = nccl_port

            #################################
            # Logging utilities
            #################################
            # Initialize WandB; without an API key fall back to offline mode
            if "WANDB_API_KEY" in os.environ:
                mode = "online"
            else:
                mode = "offline"
                print("Note: Wandb is offline")

            wandb.init(
                project=os.environ.get("WANDB_PROJECT", "endoskeletal-v2"),
                name=run_name,
                mode=mode,
                config={
                    "base_seed": base_seed,
                    "voxel_size": voxel_size,
                    "max_torque": max_torque,
                    "min_actions": min_actions,
                    "max_epochs": max_epochs,
                    "collector_num": collector_num,
                    "collector_timeout": collector_timeout,
                    "trainworker_num": trainworker_num,
                    "pop_size": pop_size,
                    "best_k_robot_num": best_k_robot_num,
                    "rollout_num": rollout_num,
                    "rollout_epochs": rollout_epochs,
                    "reward_type": reward_type,
                    "actor_lr": actor_lr,
                    "critic_lr": critic_lr,
                    "ppo_batch_size": ppo_batch_size,
                    "ppo_accumulate_steps": ppo_accumulate_steps,
                    "ppo_train_steps_per_worker": ppo_train_steps_per_worker,
                    "actor_num_experts": actor_num_experts,
                    "actor_expert_hidden1": actor_expert_hidden1,
                    "actor_expert_hidden2": actor_expert_hidden2,
                    "actor_encoder_hidden1": actor_encoder_hidden1,
                    "actor_encoder_hidden2": actor_encoder_hidden2,
                    "actor_encoder_hidden3": actor_encoder_hidden3,
                    "actor_gating_threshold": actor_gating_threshold,
                    "actor_gating_top_k": actor_gating_top_k,
                    "actor_routing_entropy_weight": actor_routing_entropy_weight,
                    "actor_routing_diversity_weight": actor_routing_diversity_weight,
                    "global_camera_width": global_camera_width,
                    "global_camera_height": global_camera_height,
                    "global_camera_samples": global_camera_samples,
                },
            )

            # Log directory paths as config
            wandb.config.update(
                {
                    "paths/root": rl_dirs.root_path,
                    "paths/log": rl_dirs.log_path,
                    "paths/debug_log": rl_dirs.debug_log_path,
                    "paths/results": rl_dirs.results_path,
                    "paths/records": rl_dirs.records_path,
                    "paths/ckpt": rl_dirs.ckpt_path,
                }
            )

            # Log source code using wandb.run.log_code()
            def include_fn(path: str, root: str) -> bool:
                """Include .py files and .rsc files, excluding venv directory."""
                rel_path = os.path.relpath(path, root)
                # Exclude venv directory
                if rel_path.startswith("venv/") or "/venv/" in rel_path:
                    return False
                # Include .py and .rsc files
                return path.endswith(".py") or path.endswith(".rsc")

            wandb.run.log_code(
                root=code_root_dir, name=f"source-{run_name}", include_fn=include_fn
            )
            # Prefer epoch for x-axis while keeping step-based metrics available
            wandb.define_metric("epoch")
            wandb.define_metric("*", step_metric="epoch")
            wandb.define_metric("training/step_*", step_metric="training_step")

            print("WandB Initialized")

            ################################
            # VAE setup
            ################################
            t.set_printoptions(threshold=10_000)
            t.manual_seed(base_seed)
            t.cuda.manual_seed(base_seed)
            np.random.seed(base_seed)

            vae_device = "cpu"  # "cuda" if t.cuda.is_available() else "cpu"
            self.vae = StarVAE.load_from_checkpoint(
                generator_path, map_location=vae_device
            )
            self.vae.eval()
            print("VAE loaded")

            # The latent statistics are cached on disk; compute them only when
            # the cache file is missing.
            vae_stats_path = "./data/vae_mean_var.data"
            if not os.path.exists(vae_stats_path):
                self.vae_mean, self.vae_var = self.vae.generate_mean_var(128)
                # Make sure the cache directory exists before saving into it.
                os.makedirs(os.path.dirname(vae_stats_path), exist_ok=True)
                t.save((self.vae_mean, self.vae_var), vae_stats_path)
                # NOTE(review): presumably undoes a deterministic-mode switch
                # made inside generate_mean_var — confirm against the VAE code.
                t.use_deterministic_algorithms(False)
            else:
                self.vae_mean, self.vae_var = t.load(vae_stats_path)
            print("VAE mean and var loaded")
            # Latent dimension from VAE mean vector (shape [latent_dim])
            self.latent_dim = len(self.vae_mean)

            ################################
            # RL setup
            ################################
            grid_size = 32  # 32
            normalize_radius = voxel_size * 64 * 1.2 / 2  # change to 1m later
            self.actor = RobotLatentActorEncoderHeadMoE(
                normalize_radius=normalize_radius,
                grid_resolution=grid_size,
                num_experts=self.actor_num_experts,
                resolution=5,
                actor_hidden1=self.actor_expert_hidden1,
                actor_hidden2=self.actor_expert_hidden2,
                kinematic_hidden_dims=(
                    self.actor_encoder_hidden1,
                    self.actor_encoder_hidden2,
                    self.actor_encoder_hidden3,
                ),
                # Non-positive gating settings mean "disabled" (None)
                gating_threshold=(
                    self.actor_gating_threshold
                    if self.actor_gating_threshold > 0
                    else None
                ),
                gating_top_k=(
                    self.actor_gating_top_k if self.actor_gating_top_k > 0 else None
                ),
                latent_dim=self.latent_dim,
            )
            self.critic = RobotLatentEncoderHeadCritic(
                RobotLatentEncoderHeadEncoder(
                    normalize_radius=normalize_radius,
                    grid_resolution=grid_size,
                )
            )

            if self.ppo_checkpoint_path is not None:
                print(f"Loading local ppo checkpoint: {self.ppo_checkpoint_path}")
                # Restoring a checkpoint is best-effort: on any failure the
                # error is reported and the freshly initialised weights stay.
                try:
                    # This actor only requests CPUs, so tensors that were saved
                    # from a CUDA device must be remapped to CPU on load.
                    checkpoint = t.load(
                        self.ppo_checkpoint_path, map_location="cpu"
                    )
                    if not isinstance(checkpoint, dict):
                        print(f"Unknown checkpoint format: {type(checkpoint)}")
                    elif "actor_state" in checkpoint and "critic_state" in checkpoint:
                        self.actor.load_state_dict(checkpoint["actor_state"])
                        self.critic.load_state_dict(checkpoint["critic_state"])
                        print(
                            "Successfully load actor and critic model from checkpoint"
                        )
                    else:
                        # Previously this case was skipped without any message.
                        print(
                            "Checkpoint is missing 'actor_state' and/or "
                            f"'critic_state', found keys: {list(checkpoint.keys())}"
                        )
                except Exception as e:
                    print(f"Error loading checkpoint: {e}")

            # Share the initial models with all workers via the object store
            actor_ref = ray.put(self.actor)
            critic_ref = ray.put(self.critic)
            print(
                f"Actor single expert encoder size: {count_parameters(self.actor.encoders[0])}"
            )
            if self.actor_num_experts > 1:
                print(
                    f"Actor latent-only gate size (input [B, {self.latent_dim}]): {count_parameters(self.actor.gate)}"
                )
            else:
                print("Actor uses no gating (single expert)")
            print(
                f"Actor action_expert size: {count_parameters(self.actor.action_experts[0])}"
            )
            print(f"Actor size: {count_parameters(self.actor)}")
            print(f"Critic encoder size: {count_parameters(self.critic.encoder)}")
            print(f"Critic size: {count_parameters(self.critic)}")
            print("RL initialized")

            # One placement group (6 CPUs + 1 GPU) per worker slot. Collector i
            # and trainer i each take 3 CPUs + 0.5 GPU from placement group i.
            num_gpus = max(collector_num, trainworker_num)
            self.pgs = []
            for _ in range(num_gpus):
                pg = placement_group([{"CPU": 6, "GPU": 1}], strategy="SPREAD")
                ray.get(pg.ready())
                self.pgs.append(pg)

            self.collectors = []
            self.trainers = []
            collector_start = []
            trainer_start = []
            for c_idx in range(num_gpus):
                if c_idx < collector_num:
                    self.collectors.append(
                        RobotSimulationCollector.options(
                            num_cpus=3,
                            num_gpus=0.5,
                            placement_group=self.pgs[c_idx],
                            runtime_env={"env_vars": self.common_env_vars},
                        ).remote()
                    )
                    collector_start.append(
                        self.collectors[-1].set_parameters.remote(
                            c_idx,
                            self.rl_dirs,
                            actor_ref,
                            critic_ref,
                            self.min_actions,
                            self.voxel_size,
                            self.reward_type,
                            self.server_port,
                            self.run_name,
                        )
                    )

                if c_idx < trainworker_num:
                    self.trainers.append(
                        TrainerWorker.options(
                            num_cpus=3,
                            num_gpus=0.5,
                            placement_group=self.pgs[c_idx],
                            runtime_env={"env_vars": self.common_env_vars},
                        ).remote()
                    )
                    trainer_start.append(
                        self.trainers[-1].set_parameters.remote(
                            c_idx,
                            actor_ref,
                            critic_ref,
                            ppo_batch_size,
                            actor_lr,
                            critic_lr,
                            self.actor_routing_entropy_weight,
                            self.actor_routing_diversity_weight,
                            ppo_accumulate_steps,
                            ppo_train_steps_per_worker,
                            trainworker_num,
                            nccl_port,
                        )
                    )
            # Block until every worker has finished its parameter setup
            ray.get(collector_start)
            ray.get(trainer_start)
            print("Collectors created")
            print("Trainers created")

    def load_env_config(self):
        """Open the simulator config stored at ``self.env_config_path``.

        Returns:
            The opened ``RS_Config`` when it reports itself as valid;
            otherwise the invalid reason is printed and ``None`` is returned.
        """
        config = RS_Config()
        config.open([self.env_config_path])
        if config.is_valid:
            return config

        print(f"Config invalid: {config.invalid_reason}")
        return None

    def evolve(self):
        """Run the CMA-ES morphology search interleaved with PPO training.

        Each pass of the main loop is one *design iteration*:

        1. Ask CMA-ES for a population of latents and append the top-k latents
           kept from the previous iteration.
        2. Turn the latents into robots via ``_generate_robots``, then log and
           save diversity statistics and the design results.
        3. If at least half of the newly asked latents produced valid robots,
           run ``rollout_epochs`` rounds of rollout -> store episodes ->
           PPO training -> collector parameter update. ``epoch`` advances once
           per rollout round.
        4. Score every design with the negated maximum ``rectified_distance``
           observed over its rollouts (CMA-ES minimises); designs without any
           score get ``1e6``.
        5. Refresh the top-k latents, log images of the best designs, save the
           scores and report only the scores of the newly asked latents back
           to CMA-ES.

        The loop ends when CMA-ES reports a stop condition or ``epoch``
        exceeds ``self.max_epochs``; ``wandb.finish()`` is called afterwards.

        Raises:
            RuntimeError: If the environment config cannot be loaded (raised
                inside the ``ExceptionCatcher`` context).
        """
        with ExceptionCatcher() as _:
            mean = self.vae_mean.numpy()
            # Latent dimension used for gating; keep in sync with actor.
            latent_dim = len(mean)
            if latent_dim != self.latent_dim:
                print(
                    f"Warning: latent_dim ({latent_dim}) != initialized latent_dim ({self.latent_dim}), updating."
                )
                self.latent_dim = latent_dim
            std = np.sqrt(self.vae_var.numpy())
            # NOTE(review): the search starts at the origin while the bounds
            # are centred on the VAE mean — confirm the mean is close to zero.
            morphology_es = cma.CMAEvolutionStrategy(
                x0=np.zeros_like(mean),
                sigma0=1,
                inopts={
                    "seed": self.base_seed,
                    "popsize": self.pop_size,
                    "CMA_stds": std,
                    "bounds": [
                        mean - 6 * std,
                        mean + 6 * std,
                    ],
                    "verbose": 9,
                },
            )
            epoch = 0
            top_k_latents = []
            k = self.best_k_robot_num  # used for top k selection
            robot_design_iter = 0

            env_config = self.load_env_config()
            if env_config is None:
                # load_env_config() returns None for an invalid config; fail
                # here with a clear message instead of deeper in the floor
                # update below.
                raise RuntimeError(
                    f"Failed to load a valid env config from {self.env_config_path}"
                )

            # Update env config with floor array, save floor data
            with open(os.path.join(self.rl_dirs.log_path, "env_floor.data"), "wb") as f:
                pickle.dump((self.env_floor_array, self.env_floor_size), f)

            update_config_with_array_floor(
                env_config, self.env_floor_array, self.env_floor_size
            )
            floor_fig = plot_floor(self.env_floor_array)
            wandb.log(
                {
                    "floor": wandb.Image(floor_fig),
                    "epoch": 0,
                }
            )

            env_config_ref = ray.put(env_config)

            while not morphology_es.stop():
                if epoch > self.max_epochs:
                    print(f"Reached max epochs: {self.max_epochs}, stopping evolution")
                    break

                es_latents = morphology_es.ask()
                robot_design_epoch = epoch

                # Previous top-k designs are re-evaluated together with the new
                # population; their ids follow the ids of the new latents.
                es_latents_with_top_k = es_latents + top_k_latents
                (
                    valid_result_ids,
                    invalid_result_ids,
                    valid_robot_configs,
                    valid_robot_latents,
                    valid_robot_images,
                    valid_robot_structures,
                ) = self._generate_robots(
                    robot_design_epoch, robot_design_iter, es_latents_with_top_k
                )

                print(f"Valid result ids: {valid_result_ids}")
                print(f"Invalid result ids: {invalid_result_ids}")

                print(
                    f"Robots generated at epoch {robot_design_epoch} iter {robot_design_iter}"
                )
                # Only count designs that come from the freshly asked latents
                new_valid_robot_num = sum(
                    1 for robot_id in valid_result_ids if robot_id < len(es_latents)
                )
                valid_robot_ratio = new_valid_robot_num / self.pop_size
                print(
                    f"Valid robot configs: {new_valid_robot_num}, "
                    f"ratio: {valid_robot_ratio}"
                )
                self._log_design_diversity_stats(
                    valid_robot_structures,
                    valid_robot_ratio,
                    robot_design_epoch,
                    robot_design_iter,
                )

                print(
                    f"Running robot num at epoch {robot_design_epoch} iter {robot_design_iter}: {len(valid_result_ids)}"
                )
                self._save_design_result(
                    robot_design_epoch,
                    robot_design_iter,
                    valid_result_ids,
                    es_latents_with_top_k,
                    valid_robot_images,
                    valid_robot_configs,
                    valid_robot_structures,
                )

                # NOTE(review): the counter is advanced before the final score
                # print and _save_design_scores() below, so those use an iter
                # value one higher than _save_design_result() above. Kept
                # as-is; confirm whether this offset is intended.
                robot_design_iter += 1

                robot_sim_scores = [[] for _ in range(len(es_latents_with_top_k))]

                if valid_robot_ratio >= 0.5:
                    # Create duplicated copies of simulated new and old top-k robot designs
                    simulated_robot_ids = valid_result_ids * self.rollout_num
                    simulated_robot_configs = valid_robot_configs * self.rollout_num
                    simulated_robot_latents = valid_robot_latents * self.rollout_num

                    simulated_robot_configs_ref = ray.put(simulated_robot_configs)
                    simulated_robot_latents_ref = ray.put(simulated_robot_latents)
                    for rollout_epoch in range(self.rollout_epochs):
                        # Reset before each rollout; read back afterwards to
                        # decide whether the collected data is usable.
                        self.error_count = 0

                        print(
                            f"Start rollout at epoch {epoch} rollout_epoch {rollout_epoch}"
                        )
                        episodes = self._rollout(
                            epoch,
                            env_config_ref,
                            simulated_robot_ids,
                            simulated_robot_configs_ref,
                            simulated_robot_latents_ref,
                        )

                        print(f"[MULTI-GPU] Received rollout episodes: {len(episodes)}")
                        episode_length_list = [
                            len(episode["episode"]) for episode in episodes
                        ]
                        print(f"Episode length list: {episode_length_list}")

                        for episode in episodes:
                            robot_sim_scores[episode["id"]].append(
                                episode["rectified_distance"]
                            )
                            print(
                                f"Robot {episode['id']} "
                                f"name: {episode['robot_name']} "
                                f"reward: {episode['reward']} "
                                f"rectified_distance: {episode['rectified_distance']} "
                                f"prob: {episode['prob']} "
                            )

                        # Train only with enough episodes and few rollout errors
                        if len(episodes) > 2 and self.error_count <= 2:
                            print(
                                f"Finish rollout at epoch {epoch} rollout_epoch {rollout_epoch}"
                            )
                            self._store_ppo_episodes(epoch, episodes)
                            print(
                                f"Finish store episodes at epoch {epoch} rollout_epoch {rollout_epoch}"
                            )
                            self._train_ppo(epoch)
                            print(
                                f"Finish train ppo at epoch {epoch} rollout_epoch {rollout_epoch}"
                            )
                            self._update_collector_parameters(epoch)
                            print(
                                f"Finish update collector parameters at epoch {epoch} rollout_epoch {rollout_epoch}"
                            )
                        elif self.error_count <= 2:
                            print(
                                "Too little valid episodes collected for training, skipping ppo"
                            )
                        else:
                            print("Too many errors in rollout, skipping ppo")

                        wandb.log(
                            {"system/error_count": self.error_count, "epoch": epoch}
                        )

                        epoch += 1

                else:
                    print("Low valid robot config ratio, skipping rollout")

                # Designs without any simulation score keep this large penalty
                robot_report_scores = [1e6] * len(es_latents_with_top_k)

                # CMA-ES will minimize scores
                for valid_id in valid_result_ids:
                    sim_scores = robot_sim_scores[valid_id]
                    if len(sim_scores) > 0:
                        # Use negative max for not missing any good design
                        robot_report_scores[valid_id] = -np.max(sim_scores)
                        print(
                            f"Robot {valid_id}, max-score: {np.max(sim_scores)}, score: {sim_scores}"
                        )

                robot_report_scores = np.array(robot_report_scores)

                print(
                    f"{len(valid_result_ids)} running robots and {len(invalid_result_ids)} invalid robots "
                    f"were duplicated {self.rollout_num} times to update cma-es "
                    f"Final scores of design "
                    f"design epoch {robot_design_epoch} "
                    f"design iter {robot_design_iter}: "
                )

                for r_id, (r_sim_scores, r_report_score) in enumerate(
                    zip(robot_sim_scores, robot_report_scores)
                ):
                    print(
                        f"Robot id={r_id}, sim scores={r_sim_scores}, report score={r_report_score}"
                    )

                if valid_robot_ratio >= 0.5:
                    # Ascending sort: lowest (best) report scores come first
                    sorted_indices = np.argsort(robot_report_scores)
                    top_k_indices = sorted_indices[:k]
                    top_k_latents = [
                        es_latents_with_top_k[k_idx] for k_idx in top_k_indices
                    ]
                    print(f"Top {k} indices: {top_k_indices}")

                    # Images are stored per valid robot, so translate robot ids
                    # into positions within the valid_* lists.
                    robot_id_to_valid_idx_mapping = {
                        robot_id: valid_idx
                        for valid_idx, robot_id in enumerate(valid_result_ids)
                    }

                    plot_images = [
                        valid_robot_images[robot_id_to_valid_idx_mapping[robot_id]]
                        for robot_id in sorted_indices[: self.plot_robot_image_num]
                        if robot_id in robot_id_to_valid_idx_mapping
                    ]

                    if plot_images:
                        _, robot_pil_image_grid = self._generate_robot_image_grid(
                            plot_images
                        )
                        wandb.log(
                            {
                                "robots/design": wandb.Image(robot_pil_image_grid),
                                "epoch": robot_design_epoch,
                            }
                        )
                    else:
                        # The candidates come from the best
                        # plot_robot_image_num designs, not the top k.
                        print(
                            f"Warning: No valid robot images found in top {self.plot_robot_image_num} robots"
                        )

                self._save_design_scores(
                    robot_design_epoch, robot_design_iter, robot_report_scores
                )

                # Remove top_k latent scores and only tell fitness scores
                # related to morphologies generated by CMA-ES
                morphology_es.tell(es_latents, robot_report_scores[: self.pop_size])
                morphology_es.logger.add()
                morphology_es.disp()

            wandb.finish()

    def _log_design_diversity_stats(
        self,
        robot_structures: List[Any],
        valid_robot_ratio: float,
        robot_design_epoch: int,
        robot_design_iter: int,
    ) -> None:
        """Compute, print, log and persist diversity stats of one population.

        The statistics are printed, sent to WandB under the ``diversity/``
        prefix and written as JSON to
        ``<log_path>/robot_design_diversity_stats/design_iter_<iter>.json``.

        Args:
            robot_structures: Structures of the valid robots, as returned by
                ``_generate_robots``.
            valid_robot_ratio: Share of newly asked latents that were valid.
            robot_design_epoch: Epoch at which this population was designed.
            robot_design_iter: Index of the current design iteration.
        """
        mass_mean, mass_std = population_mass_stats(robot_structures)
        bbox_mean, bbox_std = population_bounding_box_stats(robot_structures)
        conn_mean, conn_std = population_connectivity_stats(robot_structures)
        path_mean, path_std = population_longest_path_stats(robot_structures)
        bones_mean, bones_std = population_segment_count_stats(robot_structures)
        diversity_score = diversity_cv(
            mass_mean,
            mass_std,
            bones_mean,
            bones_std,
            bbox_mean,
            bbox_std,
            conn_mean,
            conn_std,
            path_mean,
            path_std,
        )
        print(
            f"Mass mean/std: {mass_mean:.3f}/{mass_std:.3f} | "
            f"Bones mean/std: {bones_mean:.3f}/{bones_std:.3f} | "
            f"BBox mean/std: {bbox_mean:.3f}/{bbox_std:.3f} | "
            f"Conn mean/std: {conn_mean:.3f}/{conn_std:.3f} | "
            f"Path mean/std: {path_mean:.3f}/{path_std:.3f} | "
            f"Diversity score: {diversity_score:.3f}"
        )

        # Single source for both outputs; insertion order fixes the key order
        # of the JSON file.
        stats = {
            "mass_mean": mass_mean,
            "mass_std": mass_std,
            "bones_mean": bones_mean,
            "bones_std": bones_std,
            "bbox_mean": bbox_mean,
            "bbox_std": bbox_std,
            "connectivity_mean": conn_mean,
            "connectivity_std": conn_std,
            "longest_path_mean": path_mean,
            "longest_path_std": path_std,
            "diversity_score": diversity_score,
        }

        wandb_payload = {"design/valid_robot_ratio": valid_robot_ratio}
        wandb_payload.update(
            {f"diversity/{name}": value for name, value in stats.items()}
        )
        wandb_payload["epoch"] = robot_design_epoch
        wandb_payload["robot_design_iter"] = robot_design_iter
        wandb.log(wandb_payload)

        diversity_dir = os.path.join(
            self.rl_dirs.log_path, "robot_design_diversity_stats"
        )
        os.makedirs(diversity_dir, exist_ok=True)
        stats_path = os.path.join(
            diversity_dir, f"design_iter_{robot_design_iter}.json"
        )
        json_payload = {
            "epoch": robot_design_epoch,
            "design_iter": robot_design_iter,
            "valid_robot_ratio": valid_robot_ratio,
        }
        json_payload.update(stats)
        with open(stats_path, "w") as f:
            json.dump(json_payload, f)

    def _rollout(
        self,
        epoch: int,
        env_config_ref: ray.ObjectRef,
        robot_ids: List[int],
        robot_configs_ref: ray.ObjectRef,
        robot_latents_ref: ray.ObjectRef,
    ):
        """Simulate the given robots on the collectors and gather their episodes.

        Positions ``0..len(robot_ids)-1`` are shuffled and split into
        contiguous, non-overlapping chunks of at most
        ``ceil(len(robot_ids) / self.collector_num)`` entries, one chunk per
        collector. Collectors whose chunk is empty are not invoked.

        If the result of a collector cannot be fetched (timeout or any other
        exception), ``self.error_count`` is incremented, the collector is
        replaced by a newly created one in the same placement group, and the
        robots assigned to it are dropped for this rollout. On timeout the old
        collector is killed explicitly first.

        A summary of the rollout (best robot and per-collector metrics without
        the raw episode data) is pickled to
        ``debug_e_{epoch}_performance.log`` in the debug log directory.

        Args:
            epoch: Epoch index; forwarded to the collectors and used in the
                name of the performance log file.
            env_config_ref: Object reference forwarded to ``collect``.
            robot_ids: Ids of the robots to simulate. The allocation plan
                refers to positions in this list (presumably aligned with the
                configs/latents behind the two references; verify against
                the collector).
            robot_configs_ref: Object reference forwarded to ``collect``.
            robot_latents_ref: Object reference forwarded to ``collect``.

        Returns:
            List of episode dicts that passed the sanity filter and contain at
            least one step.
        """
        # Generate allocation plan for each collector to ensure no overlap
        total_robots = len(robot_ids)
        sim_per_collector = int(np.ceil(total_robots / self.collector_num))

        # Create randomized allocation plan
        shuffled_indices = list(range(total_robots))
        np.random.shuffle(shuffled_indices)

        all_allocated_robot_config_indices = []
        all_allocated_robot_ids = []
        for i in range(self.collector_num):
            start_idx = i * sim_per_collector
            # Slicing past the end yields an empty plan when no robots are left
            collector_indices = shuffled_indices[
                start_idx : start_idx + sim_per_collector
            ]
            all_allocated_robot_config_indices.append(collector_indices)
            all_allocated_robot_ids.append(
                [robot_ids[idx] for idx in collector_indices]
            )
        print(
            f"[MULTI-GPU] Robot allocation plans: {all_allocated_robot_config_indices}"
        )
        print(
            f"[MULTI-GPU] Robot allocation plan assigned robot ids: {all_allocated_robot_ids}"
        )

        print(f"[MULTI-GPU] Sim per collector: {sim_per_collector}")

        # Keep each future together with the rank of the collector that owns
        # it, so error recovery below always targets the right collector even
        # if a collector in the middle of the list received no work.
        pending = []
        for collector_idx, collector in enumerate(self.collectors):
            allocated_robot_config_indices = all_allocated_robot_config_indices[
                collector_idx
            ]
            if not allocated_robot_config_indices:
                continue
            pending.append(
                (
                    collector_idx,
                    collector.collect.remote(
                        epoch,
                        env_config_ref,
                        robot_ids,
                        robot_configs_ref,
                        robot_latents_ref,
                        allocated_robot_config_indices,
                        self.global_camera_width,
                        self.global_camera_height,
                        self.global_camera_samples,
                        self.hdf5_record_save_interval,
                    ),
                )
            )

        performance = {
            "best_robot_metrics": {
                "reward": -np.inf,
                "rectified_distance": -np.inf,
                "collector_rank": None,
                "robot_id": None,
                # Same schema as the entries written when a best robot is found
                "robot_name": None,
            },
            "collector_metrics": {},
        }

        episodes = []
        for rank, result in pending:
            print(f"[MULTI-GPU] Get result from collector {rank}")
            try:
                sub_episodes = ray.get(result, timeout=self.collector_timeout)
            except Exception as e:
                # Handle both timeout errors and general exceptions
                if isinstance(e, ray.exceptions.GetTimeoutError):
                    print(
                        f"Timeout error: cannot get result from collector {rank}: {e}"
                    )
                    # For timeout errors, kill the collector first
                    ray.kill(self.collectors[rank])
                else:
                    print(f"Error in ray.get during rollout: {e}")

                print(f"Exception caught in collector {rank}, returned none, skipping")
                self.error_count = self.error_count + 1

                # Recreate collector for both types of errors
                self.collectors[rank] = RobotSimulationCollector.options(
                    num_cpus=3,
                    num_gpus=0.5,
                    placement_group=self.pgs[rank],
                    runtime_env={"env_vars": self.common_env_vars},
                ).remote()
                ray.get(
                    self.collectors[rank].set_parameters.remote(
                        rank,
                        self.rl_dirs,
                        self.actor,
                        self.critic,
                        self.min_actions,
                        self.voxel_size,
                        self.reward_type,
                        self.server_port,
                        self.run_name,
                    )
                )
                print(f"Collector {rank} recreated")
                continue

            if sub_episodes is None:
                print(
                    f"Other exception caught in collector {rank}, returned none, skipping"
                )
                continue

            # Drop episodes whose reward or rectified distance exceeds the
            # accepted upper bounds, and episodes shorter than min_actions
            filtered_sub_episodes = [
                episode
                for episode in sub_episodes
                if episode["reward"] <= 400
                and episode["rectified_distance"] <= 40
                and len(episode["episode"]) >= self.min_actions
            ]
            # Per-collector metrics exclude the raw step data to keep the log small
            performance["collector_metrics"][rank] = [
                {k: v for k, v in episode.items() if k != "episode"}
                for episode in filtered_sub_episodes
            ]
            for episode in filtered_sub_episodes:
                # The best robot is the one with the largest rectified distance
                if (
                    episode["rectified_distance"]
                    > performance["best_robot_metrics"]["rectified_distance"]
                ):
                    performance["best_robot_metrics"] = {
                        "reward": episode["reward"],
                        "rectified_distance": episode["rectified_distance"],
                        "collector_rank": rank,
                        "robot_id": episode["id"],
                        "robot_name": episode["robot_name"],
                    }

                if len(episode["episode"]) > 0:
                    episodes.append(episode)

        with open(
            f"{self.rl_dirs.debug_log_path}/debug_e_{epoch}_performance.log",
            "wb",
        ) as f:
            print(f"Best performance data of epoch {epoch}")
            pprint.pprint(performance["best_robot_metrics"])
            pickle.dump(performance, f)

        return episodes

    def _store_ppo_episodes(self, epoch: int, episodes: List[Dict[str, Any]]):
        """Distribute episodes to the trainers and log their statistics.

        The episodes are split into contiguous slices over
        ``self.trainworker_num`` trainers; the first
        ``len(episodes) % self.trainworker_num`` trainers receive one extra
        episode. Trainers with an empty slice are not called.

        Besides storing, this method
        - pickles all robot traces and routing weights into the results
          directory,
        - optionally logs a trace figure to wandb (``self.save_trace_img``),
        - logs max/mean/min statistics to wandb and saves the raw reward,
          distance and cost arrays to ``epoch={epoch}-metrics.npz``.

        When ``episodes`` is empty the statistics step is skipped, because
        reductions such as ``np.max`` are undefined on empty input.

        Args:
            epoch: Epoch index used in file names and as wandb ``epoch`` value.
            episodes: Episode dicts; each must provide the keys read below
                (``episode``, ``reward``, ``rectified_distance``, ``cost``,
                ``prob``, ``routing_entropy``, ``routing_weights``, ``id``,
                ``robot_name``, ``trace``) and a non-empty ``episode`` list,
                since per-step values divide by its length.
        """
        with ExceptionCatcher() as _:
            episode_lengths = []

            rewards = []
            rectified_distances = []
            costs = []
            probs = []
            routing_entropies = []
            routing_weights = []

            per_step_rewards = []
            per_step_costs = []

            robot_traces = []

            total_episodes = len(episodes)

            base_episodes_per_worker = total_episodes // self.trainworker_num
            remainder = total_episodes % self.trainworker_num

            print(f"Number of episodes: {total_episodes}")
            print(f"Base episodes per worker: {base_episodes_per_worker}")
            print(f"Remainder episodes: {remainder}")
            print(f"Episode length: {len(episodes[-1]['episode']) if episodes else 0}")

            # Submit the store requests first; they are awaited further below
            # so the bookkeeping in between overlaps with the remote work.
            results_list = []
            offset = 0
            for worker_idx in range(self.trainworker_num):
                worker_episode_count = base_episodes_per_worker + (
                    1 if worker_idx < remainder else 0
                )
                worker_episodes = episodes[offset : offset + worker_episode_count]
                offset += worker_episode_count
                if worker_episodes:
                    results_list.append(
                        self.trainers[worker_idx].store_episodes.remote(worker_episodes)
                    )

            for episode in episodes:
                step_count = len(episode["episode"])
                episode_lengths.append(step_count)

                rewards.append(episode["reward"])
                rectified_distances.append(episode["rectified_distance"])
                costs.append(episode["cost"])
                probs.append(episode["prob"])
                routing_entropies.append(episode["routing_entropy"])
                routing_weights.append(
                    (episode["id"], episode["robot_name"], episode["routing_weights"])
                )

                per_step_rewards.append(episode["reward"] / step_count)
                per_step_costs.append(episode["cost"] / step_count)

                robot_traces.append(episode["trace"])

            # save robot_traces to file
            with open(
                f"{self.rl_dirs.results_path}/epoch_{epoch}_robot_traces.pkl", "wb"
            ) as f:
                pickle.dump(robot_traces, f)

            if self.save_trace_img:
                robot_trace_fig = plot_robot_trace(
                    robot_traces=robot_traces,
                    title=f"{self.run_name} - Epoch {epoch} - Robot traces",
                )

                if robot_trace_fig:
                    wandb.log(
                        {
                            "traces/robot_movement": wandb.Image(robot_trace_fig),
                            "epoch": epoch,
                        }
                    )
                    # Release the figure explicitly to avoid accumulating
                    # open matplotlib figures over many epochs
                    robot_trace_fig.clf()
                    plt.close(robot_trace_fig)
                else:
                    print("No robot trace figure provided")

            # save robot_weights to file
            with open(
                f"{self.rl_dirs.results_path}/epoch_{epoch}_robot_routing_weights.pkl",
                "wb",
            ) as f:
                pickle.dump(routing_weights, f)

            try:
                ray.get(results_list)
            except Exception as e:
                print(f"Error in ray.get during store episodes: {e}")
                self.error_count = self.error_count + 1

            print(f"Finish store episodes at epoch {epoch}")

            if not episodes:
                # max/mean/min are undefined for empty input
                print(f"No episodes at epoch {epoch}, skipping metric logging")
                return

            # Log all metrics in one grouped wandb.log call
            metrics = {
                "episode/max_length": np.max(episode_lengths),
                "episode/mean_length": np.mean(episode_lengths),
                "episode/min_length": np.min(episode_lengths),
                "epoch": epoch,
            }
            for metric_name, values in (
                ("reward", rewards),
                ("rectified_distance", rectified_distances),
                ("cost", costs),
                ("prob", probs),
                ("routing_entropy", routing_entropies),
                ("per_step_reward", per_step_rewards),
                ("per_step_cost", per_step_costs),
            ):
                metrics[f"{metric_name}/max"] = np.max(values)
                metrics[f"{metric_name}/mean"] = np.mean(values)
                metrics[f"{metric_name}/min"] = np.min(values)
            wandb.log(metrics)

            reward_array = np.array(rewards)
            cost_array = np.array(costs)
            np.savez(
                os.path.join(self.rl_dirs.log_path, f"epoch={epoch}-metrics.npz"),
                rewards=reward_array,
                rectified_distances=np.array(rectified_distances),
                costs=cost_array,
                # Small offset keeps the ratio finite for zero cost
                reward_cost_ratio=reward_array / (cost_array + 1e-3),
            )

    def _train_ppo(self, epoch: int):
        """Run one training loop on every trainer and log the losses to wandb.

        All trainers are started concurrently; only the result of the first
        trainer in ``self.trainers`` is used for logging. If fetching the
        results fails, the error is printed, ``self.error_count`` is
        incremented and nothing is logged.

        Args:
            epoch: Epoch index, logged as ``epoch`` and used as the base of
                the ``training_step`` counter.
        """
        with ExceptionCatcher() as _:
            pending_runs = [
                trainer.train_loop.remote() for trainer in self.trainers
            ]

            try:
                training_result = ray.get(pending_runs)[0]
            except Exception as e:
                print(f"Error in ray.get during train ppo: {e}")
                self.error_count = self.error_count + 1
                return

            # only rank 0 trainer will return training performance
            epoch_loss = training_result["epoch_loss"]
            actor_losses = training_result["step_actor_loss_list"]
            critic_losses = training_result["step_critic_loss_list"]

            wandb.log({"training/epoch_loss": epoch_loss, "epoch": epoch})

            # Per-step losses get a global step index that continues across epochs
            for step, actor_loss in enumerate(actor_losses):
                step_record = {
                    "training/step_actor_loss": actor_loss,
                    "training/step_critic_loss": critic_losses[step],
                    "training_step": epoch * self.ppo_train_steps_per_worker + step,
                }
                wandb.log(step_record)

            print(f"Finish train ppo at epoch {epoch}")

    def _update_collector_parameters(self, epoch: int):
        """Push the latest actor/critic weights from trainer 0 to all collectors.

        The weights are also loaded into the local ``self.actor`` and
        ``self.critic``. Whenever ``epoch`` is a multiple of
        ``self.ppo_save_interval`` the optimizer states are fetched as well
        and a checkpoint ``epoch_{epoch}_checkpoint.st`` is written to the
        checkpoint directory.

        A failure while updating the collectors is printed and counted in
        ``self.error_count``; it does not prevent the checkpoint from being
        written, because the checkpoint only depends on the trainer state
        that has already been fetched at that point.

        Args:
            epoch: Current epoch index.
        """
        # get parameters from rank 0 trainer
        bool_save_optimizer_state = epoch % self.ppo_save_interval == 0
        if bool_save_optimizer_state:
            actor_state, critic_state, actor_optimizer_state, critic_optimizer_state = (
                ray.get(
                    self.trainers[0].get_parameters.remote(bool_save_optimizer_state)
                )
            )
        else:
            actor_state, critic_state = ray.get(
                self.trainers[0].get_parameters.remote()
            )

        # Update local actor and critic for reference
        self.actor.load_state_dict(actor_state)
        self.critic.load_state_dict(critic_state)

        # Put the states into the object store once instead of once per collector
        actor_state_ref = ray.put(actor_state)
        critic_state_ref = ray.put(critic_state)
        ranks = None
        try:
            ranks = ray.get(
                [
                    collector.update_parameters.remote(
                        actor_state_ref, critic_state_ref
                    )
                    for collector in self.collectors
                ]
            )
        except Exception as e:
            print(f"Error in ray.get during update collector parameters: {e}")
            self.error_count = self.error_count + 1

        if bool_save_optimizer_state:
            # save checkpoint to local
            checkpoint_path = os.path.join(
                self.rl_dirs.ckpt_path, f"epoch_{epoch}_checkpoint.st"
            )
            t.save(
                {
                    "actor_state": actor_state,
                    "critic_state": critic_state,
                    "actor_optimizer_state": actor_optimizer_state,
                    "critic_optimizer_state": critic_optimizer_state,
                },
                checkpoint_path,
            )

            print(f"Checkpoint saved at epoch {epoch} to {checkpoint_path}")

        # Only report success when the collectors actually acknowledged
        if ranks is not None:
            print(f"Finish update parameters for collectors: {ranks}")

    def _generate_robots(self, epoch: int, robot_design_iter: int, es_latents):
        """Decode latents into voxel robots and build them in parallel.

        The latents are decoded by ``self.vae.generate_by_latent`` with
        deterministic torch algorithms enabled; the previous determinism
        setting is restored afterwards, also when decoding raises. One
        ``build_robot`` task is submitted per decoded robot, and the latent
        of each robot (shape ``[1, latent_dim]``, on CPU) is attached to its
        build result. A build result whose first element is ``None`` is
        treated as invalid.

        Args:
            epoch: Epoch index forwarded to ``build_robot``.
            robot_design_iter: Design iteration forwarded to ``build_robot``.
            es_latents: Sequence of latent vectors convertible by ``np.array``.

        Returns:
            Tuple ``(valid_ids, invalid_ids, configs, latents, images,
            structures)`` where the ids are positions in ``es_latents`` and
            the last four lists are aligned with ``valid_ids``.
        """
        with ExceptionCatcher() as _:
            # Remember the current setting so it can be restored instead of
            # being forced to False
            was_deterministic = t.are_deterministic_algorithms_enabled()
            t.use_deterministic_algorithms(True)
            try:
                latents = t.from_numpy(np.array(es_latents)).to(
                    device=self.vae.device, dtype=t.float32
                )
                voxels = self.vae.generate_by_latent(latents)
            finally:
                t.use_deterministic_algorithms(was_deterministic)

            robot_num = voxels.shape[0]
            print(f"Submitting task for building {robot_num} robots")
            begin = time.time()
            results = [
                build_robot.remote(
                    epoch,
                    robot_design_iter,
                    robot_id,
                    voxels[robot_id].numpy(),
                    self.max_torque,
                    self.voxel_size,
                )
                for robot_id in range(robot_num)
            ]
            print(f"Submission takes {time.time() - begin} s")

            print(f"Waiting for building {robot_num} robots")
            begin = time.time()
            results = ray.get(results)
            print(f"Building {robot_num} robots takes {time.time() - begin} s")

            # Attach corresponding VAE latent (shape [1, latent_dim]) to each build result
            results = [
                r + (latents[robot_id].unsqueeze(0).cpu(),)
                for robot_id, r in enumerate(results)
            ]

            # A build result without a config (first element None) is invalid
            valid_results_id = [i for i, r in enumerate(results) if r[0] is not None]
            invalid_results_id = [i for i, r in enumerate(results) if r[0] is None]
            valid_results = [results[i] for i in valid_results_id]

            valid_robot_configs = [r[0] for r in valid_results]
            valid_robot_images = [r[1] for r in valid_results]
            valid_robot_structures = [r[2] for r in valid_results]
            valid_robot_latents = [r[3] for r in valid_results]
            return (
                valid_results_id,
                invalid_results_id,
                valid_robot_configs,
                valid_robot_latents,
                valid_robot_images,
                valid_robot_structures,
            )

    @staticmethod
    def _generate_robot_image_grid(images: List[Image.Image]):
        """Stack the given images vertically into a single RGB image.

        The canvas is as wide as the widest image and as tall as the sum of
        all image heights; every image is pasted left-aligned below the
        previous one.

        Args:
            images: Non-empty list of PIL images.

        Returns:
            Tuple of the stacked image as a channel-first array (the canvas
            array transposed with ``(2, 0, 1)``) and the stacked PIL image.
        """
        with ExceptionCatcher() as _:
            widths, heights = zip(*(img.size for img in images))

            # Canvas large enough to hold all images on top of each other
            canvas = Image.new("RGB", (max(widths), sum(heights)))

            top = 0
            for img, height in zip(images, heights):
                canvas.paste(img, (0, top))
                top += height

            channel_first = np.asarray(canvas).transpose((2, 0, 1))
            return channel_first, canvas

    def _save_design_result(
        self,
        epoch: int,
        robot_design_iter: int,
        valid_robot_ids: List[int],
        es_latents_with_memory: List[Any],
        valid_robot_images: List[Image.Image],
        valid_robot_configs: List[Any],
        valid_robot_structures: List[Any],
    ):
        """Persist the latents and the per-robot design data of one iteration.

        Everything is written to ``robot_info_and_images`` below the log
        directory (created on demand): one pickle with all latents, and for
        every valid robot a pickle holding ``(image, config, structure)`` plus
        the image itself as PNG.

        Args:
            epoch: Epoch index used in the file names.
            robot_design_iter: Design iteration used in the file names.
            valid_robot_ids: Ids of the valid robots, used in the file names.
            es_latents_with_memory: All latents of this iteration.
            valid_robot_images: Images, index-aligned with ``valid_robot_ids``.
            valid_robot_configs: Configs, index-aligned with ``valid_robot_ids``.
            valid_robot_structures: Structures, index-aligned with
                ``valid_robot_ids``.
        """
        save_dir = os.path.join(self.rl_dirs.log_path, "robot_info_and_images")
        os.makedirs(save_dir, exist_ok=True)

        # save all latents
        with open(
            f"{save_dir}/design_e_{epoch}_it_{robot_design_iter}_latents.data",
            "wb",
        ) as latent_file:
            pickle.dump(es_latents_with_memory, latent_file)

        # Bundle image, config and structure of each robot before writing
        robot_count = len(valid_robot_ids)
        info_records = [
            (
                valid_robot_images[pos],
                valid_robot_configs[pos],
                valid_robot_structures[pos],
            )
            for pos in range(robot_count)
        ]

        for robot_id, record in zip(valid_robot_ids, info_records):
            with open(
                f"{save_dir}/design_e_{epoch}_it_{robot_design_iter}_id{robot_id}_info.data",
                "wb",
            ) as info_file:
                pickle.dump(record, info_file)

            robot_image = record[0]
            with open(
                f"{save_dir}/design_e_{epoch}_it_{robot_design_iter}_id{robot_id}.png",
                "wb",
            ) as image_file:
                robot_image.save(image_file)

        print(
            f"Saved {len(valid_robot_ids)} robot info and images, {len(es_latents_with_memory)} latents to {save_dir}"
        )

    def _save_design_scores(
        self,
        epoch: int,
        robot_design_iter: int,
        robot_report_scores: np.ndarray,
    ):
        """Pickle the design scores of one iteration into the log directory.

        Args:
            epoch: Epoch index used in the file name.
            robot_design_iter: Design iteration used in the file name.
            robot_report_scores: Scores to store.
        """
        score_path = (
            f"{self.rl_dirs.log_path}/score_e_{epoch}_it_{robot_design_iter}.data"
        )
        with open(score_path, "wb") as score_file:
            pickle.dump(robot_report_scores, score_file)


# Script entry point: parse the command line, configure the run, start Ray,
# create the RobotEvolution actor and block until evolution has finished.
if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="Robot Training")
    parser.add_argument("--nccl_port", type=int, default=12345)
    parser.add_argument("--webserver_port", type=int, default=8002)
    parser.add_argument("--run_name", type=str, default="")
    parser.add_argument(
        "--ray_cuda_devices",
        type=str,
        default="0-3",
        help="CUDA device index or inclusive range for Ray (e.g. '2', '0-3' or '4-7')",
    )
    parser.add_argument(
        "--checkpoint_path", type=str, default=None, help="Load local ppo checkpoint"
    )
    args = parser.parse_args()

    print(f"NCCL port: {args.nccl_port}")
    print(f"Webserver port: {args.webserver_port}")
    print(f"Run name: {args.run_name}")

    # Path configurations
    generator_path = "data/ckpt/vae-v2-epoch=97-val_loss=0.205.ckpt"
    # Pre-trained PPO checkpoint to load; None when --checkpoint_path is not given
    ppo_checkpoint_path = args.checkpoint_path
    env_config_path = "data/env/env_for_basic_locomotion.rsc"
    root_path = "./data/rl-result"
    log_sub_path = "logs"
    debug_log_sub_path = "debug-logs"
    ckpt_sub_path = "ckpt"
    results_sub_path = "results"
    records_sub_path = "records"

    # Simulation and general configurations
    env_floor_array = clear_floor_center(
        scale_floor_height(
            -0.1, 0.1, generate_2d_wave_floor(50, 50, 3, 3)
        ),
        4, 4
    )
    env_floor_size = 5.0  # in meters
    base_seed = 42
    voxel_size = 0.01
    max_torque = 6
    min_actions = 60
    max_epochs = 800

    # Compute configurations
    collector_num = 4
    collector_timeout = 240
    trainworker_num = 4

    # Evolution configurations
    pop_size = 64
    best_k_robot_num = 4
    rollout_num = 1  # Copy times of valid designs
    rollout_epochs = 5

    # RL configurations
    reward_type = "stand"  # reward type: basic / stand
    actor_lr = 6e-5
    critic_lr = 6e-5
    ppo_batch_size = 32
    ppo_accumulate_steps = 1
    ppo_train_steps_per_worker = 40
    actor_num_experts = 4
    actor_expert_hidden1 = 128
    actor_expert_hidden2 = 32
    actor_encoder_hidden1 = 128
    actor_encoder_hidden2 = 192
    actor_encoder_hidden3 = 256
    actor_gating_threshold = 0.0  # 0 disables threshold
    actor_gating_top_k = 0  # 0 disables top-k pruning
    actor_routing_entropy_weight = 0.0  # No punishment for routing entropy with 0.0
    actor_routing_diversity_weight = (
        0.01  # No punishment for routing diversity with 0.0
    )

    # Plot and save configurations
    plot_robot_image_num = 16
    save_trace_img = True
    ppo_save_interval = 10
    hdf5_record_save_interval = 100

    # Web monitoring configurations
    global_camera_width = 1280
    global_camera_height = 720
    global_camera_samples = 4

    print("initialize ray")

    # Translate "N" or "N-M" (inclusive) into the comma separated device list
    # expected by CUDA_VISIBLE_DEVICES
    first_text, range_sep, last_text = args.ray_cuda_devices.strip().partition("-")
    try:
        start = int(first_text)
        end = int(last_text) if range_sep else start
    except ValueError:
        parser.error(
            f"--ray_cuda_devices must look like '2' or '0-3', got '{args.ray_cuda_devices}'"
        )
    if end < start:
        parser.error(
            f"--ray_cuda_devices range is empty: '{args.ray_cuda_devices}'"
        )
    cuda_visible = ",".join(map(str, range(start, end + 1)))
    # The environment must be set before ray.init so that it is part of the
    # runtime environment handed to the workers
    os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["RAY_DEDUP_LOGS"] = "0"

    print(f"Ray visible CUDA devices: {cuda_visible}")

    ray.init(include_dashboard=False, runtime_env={"env_vars": dict(os.environ)})
    print("Ray initialized")

    rl_dirs = RLTrainDirs(
        run_name=args.run_name,
        root_path=root_path,
        log_sub_path=log_sub_path,
        debug_log_sub_path=debug_log_sub_path,
        ckpt_sub_path=ckpt_sub_path,
        results_sub_path=results_sub_path,
        records_sub_path=records_sub_path,
    )
    env_vars = dict(os.environ)

    # The code root is two directory levels above this script
    current_dir = os.path.dirname(os.path.abspath(__file__))
    code_root_dir = os.path.abspath(
        os.path.join(current_dir, os.pardir, os.pardir)
    )
    print("generate robot evolution")
    evo = RobotEvolution.remote(
        run_name=args.run_name,
        code_root_dir=code_root_dir,
        env_vars=env_vars,
        env_config_path=env_config_path,
        generator_path=generator_path,
        ppo_checkpoint_path=ppo_checkpoint_path,
        rl_dirs=rl_dirs,
        env_floor_array=env_floor_array,
        env_floor_size=env_floor_size,
        base_seed=base_seed,
        voxel_size=voxel_size,
        max_torque=max_torque,
        min_actions=min_actions,
        max_epochs=max_epochs,
        collector_num=collector_num,
        collector_timeout=collector_timeout,
        trainworker_num=trainworker_num,
        pop_size=pop_size,
        best_k_robot_num=best_k_robot_num,
        rollout_num=rollout_num,
        rollout_epochs=rollout_epochs,
        reward_type=reward_type,
        actor_lr=actor_lr,
        critic_lr=critic_lr,
        ppo_batch_size=ppo_batch_size,
        ppo_accumulate_steps=ppo_accumulate_steps,
        ppo_train_steps_per_worker=ppo_train_steps_per_worker,
        actor_num_experts=actor_num_experts,
        actor_expert_hidden1=actor_expert_hidden1,
        actor_expert_hidden2=actor_expert_hidden2,
        actor_encoder_hidden1=actor_encoder_hidden1,
        actor_encoder_hidden2=actor_encoder_hidden2,
        actor_encoder_hidden3=actor_encoder_hidden3,
        actor_gating_threshold=actor_gating_threshold,
        actor_gating_top_k=actor_gating_top_k,
        actor_routing_entropy_weight=actor_routing_entropy_weight,
        actor_routing_diversity_weight=actor_routing_diversity_weight,
        plot_robot_image_num=plot_robot_image_num,
        save_trace_img=save_trace_img,
        ppo_save_interval=ppo_save_interval,
        hdf5_record_save_interval=hdf5_record_save_interval,
        global_camera_width=global_camera_width,
        global_camera_height=global_camera_height,
        global_camera_samples=global_camera_samples,
        server_port=args.webserver_port,
        nccl_port=args.nccl_port,
    )
    print("start training robots")
    ray.get(evo.evolve.remote())
    # Short grace period before tearing down the cluster
    time.sleep(5)
    ray.shutdown()
