from typing import List
import ray
import os
import numpy as np
import torch as t
import torch.nn as nn
import torch.optim as optim
from copy import deepcopy
from utils.train_utils import RLTrainDirs, ExceptionCatcher

from rl.algorithms import PPO
from model.robot_simple_latent_conditioned_moe.robot import (
    RobotLatentEncoderHeadCritic,
    RobotLatentActorEncoderHeadMoE,
)
from sim.env.env_for_latent_conditioned_moe import (
    RiseEnvForLatentConditionedMoE,
)
from rise import *
from utils.misc import cuda_to_cpu
from sim.builder.builder_old_for_basic_locomotion import (
    RobotBuilderOldForBasicLocomotion,
)


@ray.remote(num_cpus=2)
def build_robot(
    epoch: int,
    robot_design_iters: int,
    robot_id: int,
    voxels: np.ndarray,
    max_torque: float,
    voxel_size: float,
):
    """Build one robot from a voxel design (runs as a Ray task with 2 CPUs).

    Args:
        epoch: Current epoch; embedded in the generated robot name.
        robot_design_iters: Design-iteration counter; embedded in the robot name.
        robot_id: Identifier of the design; embedded in the robot name.
        voxels: Voxel design passed to ``RobotBuilderOldForBasicLocomotion.build``.
        max_torque: Forwarded to the builder as ``hinge_max_torque``.
        voxel_size: Voxel edge length forwarded to the builder.

    Returns:
        ``(robot_config, image, robot_structure)`` when the builder returned a
        config (``image`` comes from ``builder.visualize(interactive=False)``),
        otherwise ``(None, None, None)``.
    """
    with ExceptionCatcher() as _:
        name = f"bot_e_{epoch}_ro_{robot_id}_iter_{robot_design_iters}"
        robot_builder = RobotBuilderOldForBasicLocomotion(
            voxel_size=voxel_size,
            valid_min_rigid_ratio=0.2,
            valid_min_joint_num=2,
            valid_max_connected_components=1,
            min_rigid_volume=100,
        )
        config, structure = robot_builder.build(
            voxels,
            robot_name=name,
            hinge_max_torque=max_torque,
            print_summary=True,
        )
        if config is None:
            # No config produced (presumably the design failed the builder's
            # validity constraints above) -- nothing to render.
            return None, None, None
        image = robot_builder.visualize(interactive=False)
        return config, image, structure


def _xy_bounds(positions, y_offset):
    """Return ``(x_min, x_max, y_min - y_offset, y_max - y_offset)``.

    *positions* is indexed as ``positions[:, 0]`` (x) and ``positions[:, 1]``
    (y); *y_offset* undoes the per-robot y shift applied at placement time.
    """
    xs = positions[:, 0]
    ys = positions[:, 1]
    return (xs.min(), xs.max(), ys.min() - y_offset, ys.max() - y_offset)


def _action_change_cost(action, last_action):
    """Return the fraction of sub-actions that changed since the previous step.

    On the first step (``last_action is None``) every sub-action that differs
    from ``1`` counts as a change. An empty action vector costs 0.0 instead of
    raising ZeroDivisionError.
    """
    total = len(action)
    if total == 0:
        return 0.0
    if last_action is None:
        changed = sum(1 for sub_a in action if sub_a != 1)
    else:
        changed = sum(
            1 for last_sub_a, sub_a in zip(last_action, action) if sub_a != last_sub_a
        )
    return changed / total


def _routing_entropy(routing_weight):
    """Return the entropy ``-sum(w * log(w + 1e-8))`` of a routing weight vector."""
    return np.sum(-routing_weight * np.log(routing_weight + 1e-8))


@ray.remote
class RobotSimulationCollector:
    """Ray actor that simulates batches of robots and converts them to episodes.

    Call sequence used by the methods below: ``set_parameters`` once,
    ``update_parameters`` whenever new network weights are available, and
    ``collect`` to simulate a batch and return one episode dict per robot.
    """

    # Robots sharing one simulation are offset along +y by this amount
    # (multiplied by their batch index) so they do not overlap.
    ROBOT_Y_SPACING = 10

    def __init__(self):
        # Everything except reward_type/server_port is filled in by
        # set_parameters(); the actor is not usable before that call.
        self.rank = None
        self.rl_dirs = None
        self.actor = None
        self.critic = None
        self.min_actions = None
        self.voxel_size = None
        self.ppo = None
        self.reward_type = "stand"
        self.server_port = 8080
        self.run_name = None

    def set_parameters(
        self,
        rank: int,
        rl_dirs: RLTrainDirs,
        actor: RobotLatentActorEncoderHeadMoE,
        critic: RobotLatentEncoderHeadCritic,
        min_actions: int,
        voxel_size: float,
        reward_type: str,
        server_port: int,
        run_name: str,
    ):
        """Configure this collector; must be called before ``collect``.

        Args:
            rank: Index of this collector; used in sim ids, titles and record
                file names.
            rl_dirs: Training directories; HDF5 records go to
                ``rl_dirs.records_path``.
            actor: Policy network, moved to ``cuda:0``.
            critic: Value network, moved to ``cuda:0``.
            min_actions: Robots with fewer recorded observations than this are
                dropped from the collected episodes.
            voxel_size: Voxel edge length; normalises movement rewards and
                scales the contact-height threshold.
            reward_type: ``"stand"`` for the contact-ratio-penalised reward,
                anything else for the plain movement reward.
            server_port: Port appended to the viewer webserver URL.
            run_name: Shown in the simulation viewer title.
        """
        with ExceptionCatcher() as _:
            self.rank = rank
            self.rl_dirs = rl_dirs
            self.actor = actor.to("cuda:0")
            self.critic = critic.to("cuda:0")
            self.min_actions = min_actions
            self.voxel_size = voxel_size
            # Wrap the device-resident modules stored above, so later
            # update_parameters() calls are seen by PPO as well.
            self.ppo = PPO(self.actor, self.critic, optim.AdamW, nn.MSELoss())
            self.reward_type = reward_type
            self.server_port = server_port
            self.run_name = run_name

    def update_parameters(self, actor_state_dict, critic_state_dict):
        """Load new actor/critic weights in place and return this collector's rank."""
        with ExceptionCatcher() as _:
            self.actor.load_state_dict(actor_state_dict)
            self.critic.load_state_dict(critic_state_dict)
            return self.rank

    @ray.method(retry_exceptions=True)
    def collect(
        self,
        epoch,
        env_config: RS_Config,
        all_robot_ids: List[int],
        all_robot_configs_batch: list[RS_StructureConfig],
        all_robot_latent: list[t.Tensor],
        allocated_robot_config_indices: List[int],
        global_camera_width: int,
        global_camera_height: int,
        global_camera_samples: int,
        hdf5_record_save_interval: int = 10,
    ):
        """Simulate the robots allocated to this collector and build episodes.

        Args:
            epoch: Current epoch; selects recording epochs and labels outputs.
            env_config: Simulation config. On recording epochs an HDF5 recorder
                is appended to ``env_config.recorder_configs``.
            all_robot_ids: Ids of all robots (parallel to the next two lists).
            all_robot_configs_batch: Structure configs of all robots.
            all_robot_latent: Latent vectors of all robots.
            allocated_robot_config_indices: Indices into the three lists above
                selecting the robots simulated by this collector.
            global_camera_width: Forwarded to the environment.
            global_camera_height: Forwarded to the environment.
            global_camera_samples: Forwarded to the environment.
            hdf5_record_save_interval: Record every N-th epoch; 0 disables
                recording.

        Returns:
            A list (passed through ``cuda_to_cpu``) with one dict per robot
            whose rollout has at least ``max(min_actions, 2)`` observations.
            Keys: id, robot_name, reward, rectified_distance, cost, prob,
            routing_entropy, routing_weights, trace, episode.
        """
        with ExceptionCatcher() as _:
            robot_ids = [all_robot_ids[i] for i in allocated_robot_config_indices]
            robot_configs_batch = [
                all_robot_configs_batch[i] for i in allocated_robot_config_indices
            ]
            # Latents aligned with the allocated robot configs.
            robot_latent_batch = [
                all_robot_latent[i] for i in allocated_robot_config_indices
            ]

            if hdf5_record_save_interval and epoch % hdf5_record_save_interval == 0:
                self._attach_hdf5_recorder(env_config, epoch)

            env = self._make_env(
                env_config,
                epoch,
                global_camera_width,
                global_camera_height,
                global_camera_samples,
            )

            # Start the environment once, then add every robot of this batch.
            env.run()
            robot_info = {}  # simulation-local robot name -> index in this batch
            for robot_index, (robot_config, robot_latent) in enumerate(
                zip(robot_configs_batch, robot_latent_batch)
            ):
                sim_name, sim_config = self._placed_robot_config(
                    robot_config, robot_index
                )
                env.add_robot(sim_config, robot_latent)
                robot_info[sim_name] = robot_index

            env.wait_for_end()

            episodes = []
            for sim_name, robot_index in robot_info.items():
                robot_id = robot_ids[robot_index]
                observation, action, reward_state = env.get_and_clear_robot_info(
                    sim_name
                )
                episode = self._build_episode(
                    robot_id, sim_name, robot_index, observation, action, reward_state
                )
                if episode is not None:
                    episodes.append(episode)

            return cuda_to_cpu(episodes)

    def _attach_hdf5_recorder(self, env_config, epoch):
        """Append an HDF5 recorder (``sim_<rank>_epoch_<epoch>.hdf5``) to env_config."""
        hdf5_rec_config = RS_HDF5RecorderConfig()
        hdf5_rec_config.path = os.path.join(
            self.rl_dirs.records_path, f"sim_{self.rank}_epoch_{epoch}.hdf5"
        )
        rec_config = RS_RecorderConfig()
        rec_config.type = RSE_RecorderType.RSE_RECORDER_HDF5
        rec_config.config = hdf5_rec_config
        # NOTE(review): mutates the caller's env_config in place; under Ray each
        # call presumably receives its own deserialized copy -- confirm if this
        # is ever called directly.
        env_config.recorder_configs.append(rec_config)

    def _make_env(
        self,
        env_config,
        epoch,
        global_camera_width,
        global_camera_height,
        global_camera_samples,
    ):
        """Create the environment that drives one collection run with self.ppo."""
        return RiseEnvForLatentConditionedMoE(
            self.ppo,
            "cuda:0",
            env_config,
            angles=[-1.4, -0.7, 0, 0.7, 1.4],
            rank=self.rank,
            sim_id=f"sim-{self.rank}",
            sim_name=f"Simulation {self.rank}",
            global_camera_width=global_camera_width,
            global_camera_height=global_camera_height,
            global_camera_samples=global_camera_samples,
            title=f"Rise Simulation Viewer - {self.run_name} - epoch {epoch}",
            # webserver_url=None
            webserver_url=f"XXXX:{self.server_port}",
        )

    def _placed_robot_config(self, robot_config, robot_index):
        """Return ``(sim_name, config)``: a placed, renamed copy of robot_config.

        The copy's origin is shifted by minus half the first body's x/y extent
        (centering it), offset by ``ROBOT_Y_SPACING * robot_index`` along y,
        given orientation ``RQuat3rf(0, 0, 0, 1)``, and renamed to a
        simulation-unique name. The simulation record uses this name.
        """
        # Work on a copy so the caller's config is never modified.
        placed = deepcopy(robot_config)
        origin = placed.origin_position
        voxel_size = placed.voxel_size
        first_body = robot_config.bodies[0]
        center_x = first_body.x_voxels * voxel_size / 2
        center_y = first_body.y_voxels * voxel_size / 2
        placed.origin_position = RVec3rf(
            origin.x - center_x,
            origin.y + self.ROBOT_Y_SPACING * robot_index - center_y,
            origin.z,
        )
        placed.orientation = RQuat3rf(0, 0, 0, 1)

        sim_name = f"sim_index_{robot_index}_name_{robot_config.name}"
        placed.name = sim_name
        return sim_name, placed

    def _step_reward(self, com, next_com, start_com, state):
        """Reward for the transition ``com -> next_com``.

        Returns ``(reward, contact_ratio)``; ``contact_ratio`` is None unless
        ``self.reward_type == "stand"``.
        """
        # Direction of progress: horizontal direction from the start CoM to
        # the current CoM (component 2 zeroed -- presumably z is up); before
        # the robot has moved at all, fall back to +x.
        reward_dir = com - start_com
        reward_dir[2] = 0
        reward_norm = np.linalg.norm(reward_dir)
        if reward_norm > 0:
            reward_dir /= reward_norm
        else:
            reward_dir = np.array([1.0, 0.0, 0.0])

        movement = next_com - com
        # Progress along reward_dir in voxel units, clipped below at -0.1.
        reward = max(np.dot(movement, reward_dir) / self.voxel_size, -0.1)

        contact_ratio = None
        if self.reward_type == "stand":
            # Standing reward: penalise the fraction of voxels whose column-2
            # coordinate is below contact_ratio_min_height voxel sizes.
            contact_ratio_weight = 1
            contact_ratio_min_height = 5
            voxel_positions = state["voxel_positions"]
            contact_ratio = (
                np.sum(
                    voxel_positions[:, 2]
                    < self.voxel_size * contact_ratio_min_height
                )
                / len(voxel_positions)
            )
            reward -= contact_ratio_weight * contact_ratio

        # A (nearly) stationary robot gets a flat penalty overriding the above.
        if np.linalg.norm(movement) < 1e-3:
            reward = -10.0
        return reward, contact_ratio

    def _build_episode(
        self, robot_id, robot_name, robot_index, observation, action, reward_state
    ):
        """Convert one robot's rollout into an episode summary dict.

        Returns None when the rollout has fewer than ``self.min_actions``
        observations, or fewer than 2 (no transition to learn from and
        ``reward_state[0]`` may not exist).
        """
        if len(observation) < max(self.min_actions, 2):
            return None

        y_offset = robot_index * self.ROBOT_Y_SPACING
        trace = {
            "robot_id": robot_id,
            "robot_name": robot_name,
            "start_robot_positions": _xy_bounds(
                reward_state[0]["voxel_positions"], y_offset
            ),
            "end_robot_positions": _xy_bounds(
                reward_state[-1]["voxel_positions"], y_offset
            ),
            "com_trace": [],
        }

        start_com = reward_state[0]["com"]
        end_com = reward_state[-1]["com"]

        episode = []
        episode_rewards = []
        episode_costs = []
        episode_probs = []
        episode_routing_entropies = []
        episode_routing_weights = []
        episode_contact_ratios = []
        last_a = None
        last_step = len(observation) - 2  # index of the final transition

        # Each step pairs state `step` with state `step + 1`; the final
        # observation/action only serves as the "next" state.
        for step, (o, a, r) in enumerate(
            zip(observation[:-1], action[:-1], reward_state[:-1])
        ):
            com = r["com"]
            next_com = reward_state[step + 1]["com"]

            reward, contact_ratio = self._step_reward(com, next_com, start_com, r)
            if contact_ratio is not None:
                episode_contact_ratios.append(contact_ratio)

            trace["com_trace"].append((com[0], com[1] - y_offset))

            # a[0] is the discrete action vector, a[1] its log-probability.
            action_cost = _action_change_cost(a[0], last_a)
            last_a = a[0]

            episode_rewards.append(reward)
            episode_costs.append(action_cost)
            episode_probs.append(t.exp(a[1]).cpu().numpy())
            episode_routing_entropies.append(_routing_entropy(r["routing_weight"]))
            episode_routing_weights.append(r["routing_weight"])
            episode.append(
                {
                    "state": o,
                    "action": {"action": a[0]},
                    "log_prob": a[1],
                    "reward": reward,
                    "terminal": step == last_step,
                }
            )

        plane_distance = float(np.linalg.norm((end_com - start_com)[:2]))
        if self.reward_type == "stand":
            # Robots keeping few voxels low (mean contact ratio < 0.25) get
            # their distance scaled up by 1 / ratio (ratio floored at 1e-3).
            mean_contact_ratio = np.mean(episode_contact_ratios)
            rectified_distance = (
                plane_distance / max(mean_contact_ratio, 1e-3)
                if mean_contact_ratio < 0.25
                else plane_distance
            )
        else:
            # Basic reward: evolution uses the plain planar distance.
            rectified_distance = plane_distance

        return {
            "id": robot_id,
            "robot_name": robot_name,
            "reward": np.sum(episode_rewards),
            "rectified_distance": rectified_distance,
            "cost": np.sum(episode_costs),
            "prob": np.mean(episode_probs),
            "routing_entropy": np.mean(episode_routing_entropies),
            "routing_weights": episode_routing_weights,
            "trace": trace,
            "episode": episode,
        }
