from typing import List
import ray
import os
import numpy as np
import torch as t
import torch.nn as nn
import torch.optim as optim
from copy import deepcopy
from utils.train_utils import RLTrainDirs, ExceptionCatcher

from rl.algorithms import PPO
from model.robot_simple_motor_head_moe.robot import (
    RobotMotorHeadCritic,
    RobotActorMoterHeadMoE,
)
from sim.env.env_for_latent_conditioned_moe import (
    RiseEnvForLatentConditionedMoE,
)
from sim.env.env_for_latent_conditioned_mode_hieght_reward import (
    RiseEnvForLatentConditionedMoEHeightReward,
)
from rise import *
from utils.misc import cuda_to_cpu
from reward import (
    reward_component_action_cost,
    reward_component_height,
    reward_component_movement,
    reward_component_standing,
)


@ray.remote
class RobotSimulationCollector:
    """Ray actor that runs simulation batches and converts the recorded robot
    rollouts into PPO training episodes.

    ``update_parameters`` and ``collect`` rely on state created by
    ``set_parameters`` (actor, critic, PPO wrapper, reward settings), so
    ``set_parameters`` must be called first.
    """

    # Spacing along the y axis between robots of one batch (described as
    # 10 meters by the placement logic). Robot ``i`` of a batch is shifted by
    # ``i * ROBOT_Y_SPACING``; recorded traces subtract the same offset.
    ROBOT_Y_SPACING = 10

    def __init__(self):
        self.rank = None
        self.rl_dirs = None
        self.actor = None
        self.critic = None
        self.min_actions = None
        self.observation_frequency = 10
        self.voxel_size = None
        self.ppo = None
        self.reward_weight_config: List[float] = []
        self.movement_direction = np.array([1.0, 0.0, 0.0], dtype=np.float32)
        self.stand_cos_threshold = 0.5
        self.height_threshold = 0.0
        self.server_port = 8080
        self.run_name = None
        # Device used for the actor/critic and passed to the environment.
        self.device = "cuda:0"

    def set_parameters(
        self,
        rank: int,
        rl_dirs: RLTrainDirs,
        actor: RobotActorMoterHeadMoE,
        critic: RobotMotorHeadCritic,
        min_actions: int,
        observation_frequency: int,
        voxel_size: float,
        reward_weight_config: List[float],
        movement_direction: List[float],
        stand_cos_threshold: float,
        height_threshold: float,
        server_port: int,
        run_name: str,
        device: str = "cuda:0",
    ):
        """Configure this collector; must be called before ``collect``.

        Args:
            rank: Index of this collector; used in sim ids and record names.
            rl_dirs: Training directories; ``records_path`` receives HDF5 records.
            actor: Policy network, moved to ``device``.
            critic: Value network, moved to ``device``.
            min_actions: Minimum number of observations a robot rollout needs
                to be turned into an episode.
            observation_frequency: Passed to the movement reward component.
            voxel_size: Stored on the collector.
            reward_weight_config: Exactly four weights, in the order
                ``[movement, standing, height, action_cost]``.
            movement_direction: Target movement direction for the movement reward.
            stand_cos_threshold: Threshold for the standing reward component.
            height_threshold: Threshold for the height reward component.
            server_port: Port used in the viewer web-server URL.
            run_name: Shown in the viewer window title.
            device: Torch device for the models and the environment.

        Raises:
            ValueError: If ``reward_weight_config`` does not have 4 values.
        """
        with ExceptionCatcher() as _:
            # Validate before mutating any state so a bad config cannot leave
            # the collector half-configured.
            if len(reward_weight_config) != 4:
                raise ValueError(
                    "reward_weight_config must have 4 values: "
                    "[movement, standing, height, action_cost]"
                )
            self.rank = rank
            self.rl_dirs = rl_dirs
            self.device = device
            self.actor = actor.to(device)
            self.critic = critic.to(device)
            self.min_actions = min_actions
            self.observation_frequency = observation_frequency
            self.voxel_size = voxel_size
            # Build PPO from the device-placed models stored on self.
            self.ppo = PPO(self.actor, self.critic, optim.AdamW, nn.MSELoss())
            self.reward_weight_config = list(map(float, reward_weight_config))
            self.movement_direction = np.array(movement_direction, dtype=np.float32)
            self.stand_cos_threshold = float(stand_cos_threshold)
            self.height_threshold = float(height_threshold)
            self.server_port = server_port
            self.run_name = run_name

    def update_parameters(self, actor_state_dict, critic_state_dict):
        """Load new actor/critic weights in place and return this collector's rank."""
        with ExceptionCatcher() as _:
            self.actor.load_state_dict(actor_state_dict)
            self.critic.load_state_dict(critic_state_dict)
            return self.rank

    @ray.method(retry_exceptions=True)
    def collect(
        self,
        epoch,
        env_config: RS_Config,
        all_robot_ids: List[int],
        all_robot_configs_batch: list[RS_StructureConfig],
        allocated_robot_config_indices: List[int],
        actor_placeholder_latent_dim: int,
        global_camera_width: int,
        global_camera_height: int,
        global_camera_samples: int,
        hdf5_record_save_interval: int = 10,
    ):
        """Simulate the allocated robots once and return their episodes.

        Args:
            epoch: Current epoch; used in the record file name, the viewer
                title, and to decide whether this batch is recorded.
            env_config: Simulation config. On recording epochs an HDF5
                recorder is appended to its ``recorder_configs``.
            all_robot_ids: Robot ids, parallel to ``all_robot_configs_batch``.
            all_robot_configs_batch: All robot configs; only the allocated
                ones are simulated. They are deep-copied, not modified.
            allocated_robot_config_indices: Indices into both lists above of
                the robots this collector simulates.
            actor_placeholder_latent_dim: Length of the zero latent vector
                passed when each robot is added.
            global_camera_width: Forwarded to the environment.
            global_camera_height: Forwarded to the environment.
            global_camera_samples: Forwarded to the environment.
            hdf5_record_save_interval: Record every this-many epochs; a value
                ``<= 0`` disables recording.

        Returns:
            The result of ``cuda_to_cpu`` applied to a list with one dict per
            robot whose rollout has at least ``max(min_actions, 2)``
            observations (see ``_build_episode`` for the dict layout).
        """
        with ExceptionCatcher() as _:
            robot_ids = [all_robot_ids[i] for i in allocated_robot_config_indices]
            robot_configs_batch = [
                all_robot_configs_batch[i] for i in allocated_robot_config_indices
            ]

            # Guard the modulo: a non-positive interval disables recording
            # instead of raising ZeroDivisionError.
            if hdf5_record_save_interval > 0 and epoch % hdf5_record_save_interval == 0:
                # NOTE(review): this appends to env_config in place. Presumably
                # each Ray call receives its own deserialized copy; if RS_Config
                # is a pybind11 binding whose vector member converts by value,
                # the append would be lost — confirm the recorder is attached.
                env_config.recorder_configs.append(
                    self._make_hdf5_recorder_config(epoch)
                )

            env = RiseEnvForLatentConditionedMoEHeightReward(
                self.ppo,
                self.device,
                env_config,
                angles=[-1.4, -0.7, 0, 0.7, 1.4],
                rank=self.rank,
                sim_id=f"sim-{self.rank}",
                sim_name=f"Simulation {self.rank}",
                global_camera_width=global_camera_width,
                global_camera_height=global_camera_height,
                global_camera_samples=global_camera_samples,
                title=f"Rise Simulation Viewer - {self.run_name} - epoch {epoch}",
                webserver_url=f"XXXX:{self.server_port}",
            )

            # Start the environment, add this batch's robots, then block until
            # the simulation ends.
            env.run()
            robot_info = self._add_robots(
                env, robot_configs_batch, actor_placeholder_latent_dim
            )
            env.wait_for_end()

            # Loop-invariant across robots.
            reward_weights = np.asarray(self.reward_weight_config, dtype=np.float32)

            episodes = []
            for robot_name, robot_index in robot_info.items():
                # Fetch every robot's data, including robots skipped below.
                observation, action, reward_state = env.get_and_clear_robot_info(
                    robot_name
                )
                # An episode needs at least one transition (two observations);
                # fewer would index an empty rollout or average an empty list.
                if len(observation) < max(self.min_actions, 2):
                    continue
                episodes.append(
                    self._build_episode(
                        robot_ids[robot_index],
                        robot_name,
                        robot_index,
                        observation,
                        action,
                        reward_state,
                        reward_weights,
                    )
                )

            return cuda_to_cpu(episodes)

    def _make_hdf5_recorder_config(self, epoch):
        """Build a recorder config writing ``sim_<rank>_epoch_<epoch>.hdf5`` to ``rl_dirs.records_path``."""
        hdf5_rec_config = RS_HDF5RecorderConfig()
        hdf5_rec_config.path = os.path.join(
            self.rl_dirs.records_path, f"sim_{self.rank}_epoch_{epoch}.hdf5"
        )
        rec_config = RS_RecorderConfig()
        rec_config.type = RSE_RecorderType.RSE_RECORDER_HDF5
        rec_config.config = hdf5_rec_config
        return rec_config

    def _add_robots(self, env, robot_configs_batch, latent_dim):
        """Add each robot to ``env`` at its batch offset; return ``{sim_name: batch_index}``."""
        robot_info = {}
        for robot_index, robot_config in enumerate(robot_configs_batch):
            # Work on a copy so the caller's config is left untouched.
            sim_config = deepcopy(robot_config)

            origin = sim_config.origin_position
            voxel_size = sim_config.voxel_size
            first_body = sim_config.bodies[0]
            center_x = first_body.x_voxels * voxel_size / 2
            center_y = first_body.y_voxels * voxel_size / 2
            y_offset = robot_index * self.ROBOT_Y_SPACING

            # Center the robot's first body on the origin, shift it along y by
            # its batch offset, and raise z by 0.1.
            sim_config.origin_position = RVec3rf(
                origin.x - center_x,
                origin.y + y_offset - center_y,
                origin.z + 0.1,
            )
            sim_config.orientation = RQuat3rf(0, 0, 0, 1)

            # Name unique within this simulation; the simulation record uses it.
            sim_name = f"sim_index_{robot_index}_name_{robot_config.name}"
            sim_config.name = sim_name

            env.add_robot(sim_config, t.zeros([latent_dim]))
            robot_info[sim_name] = robot_index
        return robot_info

    @staticmethod
    def _plane_bounds(voxel_positions, y_offset):
        """Return ``(x_min, x_max, y_min, y_max)`` of the voxel positions with ``y_offset`` removed from y."""
        xs = voxel_positions[:, 0]
        ys = voxel_positions[:, 1]
        return (xs.min(), xs.max(), ys.min() - y_offset, ys.max() - y_offset)

    @staticmethod
    def _legacy_action_change_cost(action_values, last_action_values):
        """Fraction of action entries that changed since the previous step.

        On the first step (``last_action_values is None``) an entry counts as
        changed when it differs from 1. An empty action yields 0.0 instead of
        dividing by zero.
        """
        total = len(action_values)
        if total == 0:
            return 0.0
        if last_action_values is None:
            changed = sum(1 for value in action_values if value != 1)
        else:
            changed = sum(
                1
                for last_value, value in zip(last_action_values, action_values)
                if value != last_value
            )
        return changed / total

    def _build_episode(
        self,
        robot_id,
        robot_name,
        robot_index,
        observation,
        action,
        reward_state,
        reward_weights,
    ):
        """Turn one robot's recorded rollout into an episode summary dict.

        Each step ``k`` pairs ``observation[k]``/``action[k]`` with a reward
        computed from ``reward_state[k]`` and ``reward_state[k + 1]``; the last
        recorded observation therefore produces no step. ``action[k]`` is
        indexed as ``(action_tensor, log_prob)``.
        """
        y_offset = robot_index * self.ROBOT_Y_SPACING
        world_up_vector = np.array([0.0, 0.0, 1.0], dtype=np.float32)

        trace = {
            "robot_id": robot_id,
            "robot_name": robot_name,
            "start_robot_positions": self._plane_bounds(
                reward_state[0]["voxel_positions"], y_offset
            ),
            "end_robot_positions": self._plane_bounds(
                reward_state[-1]["voxel_positions"], y_offset
            ),
            "com_trace": [],
        }

        episode = []
        rewards = []
        costs = []
        probs = []
        movement_rewards = []
        standing_rewards = []
        height_rewards = []
        action_cost_rewards = []
        prev_action = None  # float32 copy fed to reward_component_action_cost
        last_raw_action = None  # raw action for the legacy change-cost metric
        last_step = len(observation) - 2

        for step, (o, a, r) in enumerate(
            zip(observation[:-1], action[:-1], reward_state[:-1])
        ):
            com = r["com"]
            next_com = reward_state[step + 1]["com"]

            # Reward components for this step from motion, pose, and action.
            movement_reward = reward_component_movement(
                movement_direction=self.movement_direction,
                com_prev=com,
                com_current=next_com,
                observation_frequency=self.observation_frequency,
            )
            standing_reward = reward_component_standing(
                body_vector=r.get("body_up_vector", None),
                world_up_vector=world_up_vector,
                cos_threshold=self.stand_cos_threshold,
            )
            seg_pos = r.get("second_segment_position", com)
            height_reward = reward_component_height(
                body_height=float(seg_pos[2]),
                height_threshold=self.height_threshold,
            )
            current_action = np.asarray(a[0].cpu().numpy(), dtype=np.float32)
            action_cost_reward = reward_component_action_cost(
                current_action=current_action,
                prev_action=prev_action,
            )
            prev_action = current_action

            reward_components = np.array(
                [movement_reward, standing_reward, height_reward, action_cost_reward],
                dtype=np.float32,
            )
            reward = float(np.sum(reward_components * reward_weights))

            trace["com_trace"].append((com[0], com[1] - y_offset))

            action_cost = self._legacy_action_change_cost(a[0], last_raw_action)
            last_raw_action = a[0]

            movement_rewards.append(float(movement_reward))
            standing_rewards.append(float(standing_reward))
            height_rewards.append(float(height_reward))
            action_cost_rewards.append(float(action_cost_reward))

            rewards.append(reward)
            costs.append(float(action_cost))
            probs.append(t.exp(a[1]).cpu().numpy())
            episode.append(
                {
                    "state": o,
                    "action": {"action": a[0]},
                    "log_prob": a[1],
                    "reward": reward,
                    "terminal": step == last_step,
                }
            )

        # Planar (x, y) distance between the first and last recorded COM.
        total_movement = reward_state[-1]["com"] - reward_state[0]["com"]
        plane_distance = float(np.linalg.norm(total_movement[:2]))

        return {
            "id": robot_id,
            "robot_name": robot_name,
            "reward": np.sum(rewards),
            "reward_component_movement_sum": float(np.sum(movement_rewards)),
            "reward_component_standing_sum": float(np.sum(standing_rewards)),
            "reward_component_height_sum": float(np.sum(height_rewards)),
            "reward_component_action_cost_sum": float(np.sum(action_cost_rewards)),
            "rectified_distance": plane_distance,
            "cost": np.sum(costs),
            "prob": np.mean(probs),
            "trace": trace,
            "episode": episode,
        }
