'''
Behavioral Topology (BeTop): https://arxiv.org/abs/2409.18031
'''
'''
Pipeline developed upon planTF: 
https://arxiv.org/pdf/2309.10443
'''
import time
from typing import List, Optional, Type
import gc
import math
import numpy as np
import torch
import logging

from nuplan.common.actor_state.ego_state import EgoState
from nuplan.planning.simulation.observation.observation_type import DetectionsTracks, Observation
from nuplan.planning.simulation.planner.abstract_planner import PlannerInitialization, PlannerInput, PlannerReport
from nuplan.planning.simulation.planner.planner_report import MLPlannerReport
from nuplan.planning.simulation.trajectory.abstract_trajectory import AbstractTrajectory
from nuplan.planning.simulation.trajectory.interpolated_trajectory import InterpolatedTrajectory
from nuplan.planning.training.modeling.torch_module_wrapper import TorchModuleWrapper
from nuplan.planning.simulation.trajectory.trajectory_sampling import TrajectorySampling

# Project-specific imports (make sure the paths are correct)
from src.feature_builders.common.utils import rotate_round_z_axis
from .planner_utils import global_trajectory_to_states, load_checkpoint
from src.planners.pdm_planner.abstract_pdm_planner import AbstractPDMPlanner
from src.planners.pdm_planner.observation.pdm_observation_pred import PDMObservationPred
from src.planners.pdm_planner.simulation.pdm_simulator import PDMSimulator
from src.planners.pdm_planner.scoring.pdm_scorer import PDMScorer
from src.planners.pdm_planner.utils.pdm_emergency_brake import PDMEmergencyBrake
from src.planners.pdm_planner.observation.pdm_observation_utils import get_drivable_area_map
from src.planners.pdm_planner.utils.pdm_path import PDMPath

def wrap_to_pi(theta):
    """Wrap an angle (or an array/tensor of angles) in radians into ``[-pi, pi)``."""
    full_turn = 2 * math.pi
    shifted = theta + math.pi
    return shifted % full_turn - math.pi

class BeTopImitationPlanner(AbstractPDMPlanner):
    """
    Long-term IL-based trajectory planner, with short-term RL-based trajectory tracker.
    Improved with: 
    1. Hybrid Prediction (Model + CV) with Time-0 Alignment.
    2. Safety Veto Strategy for high progress and drivable area compliance.
    """

    requires_scenario: bool = False

    def __init__(
        self,
        planner: TorchModuleWrapper,
        planner_ckpt: str = None,
        replan_interval: int = 1,
        use_gpu: bool = True,
        map_radius: float = 50,
        simulation_metric: str = 'closed_loop_reactive_agents',
    ) -> None:
        """
        Set up the learned model wrapper, bookkeeping state and the PDM components.

        :param planner: model wrapper; its first feature builder is used to build model inputs.
        :param planner_ckpt: optional checkpoint path, loaded later during initialization.
        :param replan_interval: number of simulation steps between two planning calls.
        :param use_gpu: when True, run on CUDA if it is available; otherwise run on CPU.
        :param map_radius: forwarded to the base class and to the PDM observation.
        :param simulation_metric: stored and consulted when choosing the final proposal.
        """
        super().__init__(map_radius)

        # Only query CUDA availability when the GPU was actually requested.
        cuda_requested_and_present = use_gpu and torch.cuda.is_available()
        self.device = torch.device("cuda" if cuda_requested_and_present else "cpu")

        self._planner = planner
        if hasattr(planner, 'get_list_of_required_feature'):
            feature_builders = planner.get_list_of_required_feature()
        else:
            feature_builders = planner.feature_builders
        self._planner_feature_builder = feature_builders[0]

        self._planner_ckpt = planner_ckpt
        self._initialization: Optional[PlannerInitialization] = None

        # Horizon [s] and step [s] of the planned trajectory.
        self._future_horizon = 8.0
        self._step_interval = 0.1

        # Replanning bookkeeping.
        self._replan_interval = replan_interval
        self._last_plan_elapsed_step = replan_interval  # force plan at first step
        self._global_trajectory = None
        self._start_time = None

        # Runtime statistics consumed by generate_planner_report().
        self._feature_building_runtimes: List[float] = []
        self._inference_runtimes: List[float] = []
        self.brake = False

        self.simulation_metric = simulation_metric

        # PDM components: the observation is built with both samplings, simulator and
        # scorer with the 40-pose sampling, the emergency brake with the 80-pose one.
        long_sampling = TrajectorySampling(num_poses=80, interval_length=0.1)
        short_sampling = TrajectorySampling(num_poses=40, interval_length=0.1)
        self._observation = PDMObservationPred(long_sampling, short_sampling, map_radius)
        self._simulator = PDMSimulator(short_sampling)
        self._scorer = PDMScorer(short_sampling)
        self._emergency_brake = PDMEmergencyBrake(long_sampling)
    
    def model_initialize(self):
        """
        Prepare the model for inference: load the optional checkpoint, switch to eval
        mode and move the model to ``self.device``.

        Gradient tracking is switched off with ``torch.set_grad_enabled(False)`` and is
        not restored afterwards.
        """
        logger = logging.getLogger(__name__)
        torch.set_grad_enabled(False)
        if self._planner_ckpt is not None:
            self._planner.load_state_dict(load_checkpoint(self._planner_ckpt))
            # Report success only after the state dict has really been loaded.
            logger.info("Loaded planner checkpoint: %s", self._planner_ckpt)
        else:
            logger.info("No planner checkpoint given; keeping the model weights as provided")
        self._planner.eval()
        self._planner = self._planner.to(self.device)
    
    def step_initialize(self, initialization: PlannerInitialization):
        """
        Reset the per-scenario state (iteration counter, map, route) without reloading
        the model.

        :param initialization: scenario initialization providing the map API and route.
        """
        self._iteration = 0
        self._initialization = initialization
        self._map_api = initialization.map_api
        self._load_route_dicts(initialization.route_roadblock_ids)
        gc.collect()

        # Dummy call on a single zero point so the rotation helper is compiled up front.
        warmup_points = np.zeros((1, 2), dtype=np.float64)
        rotate_round_z_axis(warmup_points, 0.0)

    def initialize(self, initialization: PlannerInitialization) -> None:
        """
        Load the model for inference and reset the per-scenario state.

        Gradient tracking is switched off with ``torch.set_grad_enabled(False)``.

        :param initialization: scenario initialization providing the map API and route.
        """
        torch.set_grad_enabled(False)

        checkpoint_path = self._planner_ckpt
        if checkpoint_path is not None:
            state_dict = load_checkpoint(checkpoint_path)
            self._planner.load_state_dict(state_dict)

        self._planner.eval()
        self._planner = self._planner.to(self.device)

        # Per-scenario state.
        self._initialization = initialization
        self._iteration = 0
        self._map_api = initialization.map_api
        self._load_route_dicts(initialization.route_roadblock_ids)
        gc.collect()

        # Trigger numba compile with a dummy single-point input.
        warmup_points = np.zeros((1, 2), dtype=np.float64)
        rotate_round_z_axis(warmup_points, 0.0)

    def name(self) -> str:
        """Return the planner name, i.e. the name of the concrete class."""
        return type(self).__name__

    def observation_type(self) -> Type[Observation]:
        """Return the observation type this planner consumes (``DetectionsTracks``)."""
        return DetectionsTracks 

    def _planning(self, current_input: PlannerInput):
        """
        Run the model once and return its ``"output_trajectory"`` for the first batch
        element as a float64 NumPy array.

        Also restarts ``self._start_time`` and records the feature-building runtime.

        :param current_input: planner input of the current simulation step.
        :return: the model's output trajectory (still in the model's output frame).
        """
        self._start_time = time.perf_counter()

        builder = self._planner_feature_builder
        raw_feature = builder.get_features_from_simulation(current_input, self._initialization)
        tensor_feature = raw_feature.to_feature_tensor().to_device(self.device)
        batched_feature = raw_feature.collate([tensor_feature])
        self._feature_building_runtimes.append(time.perf_counter() - self._start_time)

        model_output = self._planner.forward(batched_feature.data)
        # Some models return a tuple; the first entry is the output dictionary.
        if isinstance(model_output, tuple):
            model_output = model_output[0]

        first_sample = model_output["output_trajectory"][0]
        return first_sample.cpu().numpy().astype(np.float64)
    
    def _update_proposal_manager(self, ego_state: EgoState):
        """
        Build the centerline path on the first iteration; do nothing afterwards.

        :param ego_state: current ego state used to look up the starting lane.
        """
        # NOTE(review): the starting lane is looked up on every call although it is only
        # consumed on iteration 0; kept unconditional in case the lookup has side effects.
        starting_lane = self._get_starting_lane(ego_state)
        if self._iteration != 0:
            return
        self._get_proposal_paths(starting_lane)
    
    def _get_proposal_paths(self, current_lane):
        """
        Store the discrete centerline starting at ``current_lane`` as ``self._centerline``.

        :param current_lane: lane object from which the centerline is derived.
        """
        self._centerline = PDMPath(self._get_discrete_centerline(current_lane))
    
    def pdm_plan(self, current_input: PlannerInput):
        """
        Plan one trajectory by scoring the model's proposals with the PDM pipeline.

        Steps:
          1. Build the features and run the model (feature-building time is recorded).
          2. Keep the (at most 10) proposals with the highest ``'probability'`` and
             transform them into the global frame.
          3. Assemble one forecast per neighbour agent (model prediction where available,
             constant velocity otherwise) and update the PDM observation with them.
          4. Simulate and score the proposals; if the emergency brake returns a
             trajectory, use it.
          5. Otherwise choose a proposal from the PDM scores and the model probabilities.

        :param current_input: planner input of the current simulation step.
        :return: ``(trajectory, brake)``. With ``brake`` True the trajectory is whatever
            the emergency brake returned; otherwise it is the chosen global-frame
            proposal array.
        """
        self._start_time = time.perf_counter()
        planner_feature = self._planner_feature_builder.get_features_from_simulation(
            current_input, self._initialization
        )
        planner_feature_torch = planner_feature.collate(
            [planner_feature.to_feature_tensor().to_device(self.device)]
        )
        self._feature_building_runtimes.append(time.perf_counter() - self._start_time)

        planner_result = self._planner.forward(planner_feature_torch.data)
        if isinstance(planner_result, tuple):
            planner_out, mask = planner_result
        else:
            planner_out, mask = planner_result, None

        # Kept on the instance for inspection by callers/debugging tools.
        self.input_data = planner_feature_torch.data
        self.plan_output = planner_out

        ego_state, observation = current_input.history.current_state
        # Looked up once instead of on every loop iteration.
        anchor_ego_state = current_input.history.ego_states[-1]

        # NOTE(review): the original code branched on whether 'full_trajectory' is a tuple
        # but took element [0] in both branches; confirm whether a tuple output needs an
        # additional batch index.
        proposals_local = planner_out['full_trajectory'][0]

        mode_scores = planner_out['probability']
        top_k = min(10, mode_scores.shape[-1])
        top_scores, score_idx = torch.topk(mode_scores, k=top_k, dim=-1)
        proposals_local = proposals_local[score_idx[0], :].cpu().numpy().astype(np.float64)

        proposals_array = np.stack(
            [
                self._get_global_trajectory(proposal, anchor_ego_state)
                for proposal in proposals_local
            ],
            axis=0,
        )

        # Hybrid prediction (model + constant velocity), aligned to a common horizon.
        pred_input, token = self._assemble_hybrid_predictions(
            planner_feature_torch.data, planner_out, mask, anchor_ego_state
        )

        self._observation.update(
            ego_state,
            observation,
            current_input.traffic_light_data,
            self._route_lane_dict,
            pred=pred_input,
            token=token,
            behave_occ=None,
        )
        self._update_proposal_manager(ego_state)

        simulated_proposals_array = self._simulator.simulate_proposals(
            proposals_array, ego_state
        )
        proposal_scores = self._scorer.score_proposals(
            simulated_proposals_array,
            ego_state,
            self._observation,
            self._centerline,
            self._route_lane_dict,
            self._drivable_area_map,
            self._map_api,
        )

        emergency_trajectory = self._emergency_brake.brake_if_emergency(
            ego_state, proposal_scores, self._scorer
        )
        if emergency_trajectory is not None:
            return emergency_trajectory, True

        if planner_out['conti_plan']:
            plan_probs = top_scores[0].softmax(-1).cpu().numpy()
        else:
            # NOTE(review): 'max_score' is not re-indexed with the top-k order used for
            # the proposals above; confirm that it lines up with proposals_array.
            plan_probs = planner_out['max_score'].softmax(-1)[0].cpu().numpy()

        final_idx = self._pick_proposal_index(proposal_scores, plan_probs)
        return proposals_array[final_idx], False

    def _assemble_hybrid_predictions(self, data, planner_out, mask, anchor_ego_state: EgoState):
        """
        Build one global-frame forecast per neighbour token.

        Agents selected by ``mask`` use the model's most likely predicted mode, prefixed
        with their current state; every other agent gets a constant-velocity rollout.
        All forecasts have 81 frames.

        :param data: collated feature dictionary that was fed to the model.
        :param planner_out: model output dictionary.
        :param mask: agent indices returned by the model next to its output, or None.
        :param anchor_ego_state: ego state defining the local frame of the features.
        :return: ``(pred, token)`` arrays, or ``(None, None)`` when no forecast exists.
        """
        # NOTE(review): presumably index 20 on the time axis is the current frame and
        # agent 0 is the ego vehicle — confirm against the feature builder.
        CURRENT_FRAME = 20
        TARGET_FRAMES = 81

        agent_data = data["agent"]
        all_agent_pos = agent_data["position"][:, 1:, CURRENT_FRAME]
        all_curr_angle = agent_data["heading"][:, 1:, CURRENT_FRAME]
        all_agent_vel = agent_data["velocity"][:, 1:, CURRENT_FRAME]
        all_neighbor_tokens = data['agent_token'][0][1:]
        num_valid_tokens = len(all_neighbor_tokens)

        forecasts = []
        forecast_tokens = []
        model_covered = set()

        if mask is not None:
            mask_cpu = mask[0].cpu().numpy()
            # Only indices that map to an existing token count as covered by the model.
            model_covered = set(mask_cpu[mask_cpu < num_valid_tokens].tolist())

            batch_indices = torch.arange(all_agent_pos.shape[0], device=self.device)[:, None]
            selected_pos = all_agent_pos[batch_indices, mask]
            selected_angle = all_curr_angle[batch_indices, mask]

            # Predictions are offsets relative to each agent's current position and
            # heading; channels 2 and 3 are passed to atan2 as (x=2, y=3).
            prediction = planner_out['prediction']
            pred_xy = prediction[..., :2] + selected_pos[:, :, None, None, :2]
            relative_angle = torch.atan2(prediction[..., 3], prediction[..., 2])
            full_angle = wrap_to_pi(relative_angle + selected_angle[:, :, None, None])

            # First batch element only; indexed below as (agent, mode, ...).
            model_preds = torch.cat([pred_xy, full_angle[..., None]], dim=-1)[0]

            if 'pred_probability' in planner_out:
                max_mode = torch.argmax(planner_out['pred_probability'], dim=-1)[0]
                # Index tensor on the indexed tensor's device (was implicitly CPU).
                k_range = torch.arange(model_preds.shape[0], device=model_preds.device)
                best_preds = model_preds[k_range, max_mode]
            else:
                best_preds = model_preds[:, 0]

            best_preds_np = best_preds.cpu().numpy().astype(np.float64)

            # Time-0 alignment: prepend the current state to every predicted track.
            curr_pos_np = selected_pos[0].cpu().numpy().astype(np.float64)[:, None, :]
            curr_ang_np = selected_angle[0].cpu().numpy().astype(np.float64)[:, None, None]
            curr_state_np = np.concatenate([curr_pos_np, curr_ang_np], axis=-1)
            merged_preds = np.concatenate([curr_state_np, best_preds_np], axis=1)

            # Horizon alignment: truncate, or pad by holding the last predicted state.
            current_len = merged_preds.shape[1]
            if current_len > TARGET_FRAMES:
                merged_preds = merged_preds[:, :TARGET_FRAMES, :]
            elif current_len < TARGET_FRAMES:
                held_tail = np.repeat(
                    merged_preds[:, -1:, :], TARGET_FRAMES - current_len, axis=1
                )
                merged_preds = np.concatenate([merged_preds, held_tail], axis=1)

            for k, token_idx in enumerate(mask_cpu):
                if token_idx < num_valid_tokens:
                    forecasts.append(
                        self._get_global_trajectory(merged_preds[k], anchor_ego_state)
                    )
                    forecast_tokens.append(all_neighbor_tokens[token_idx])

        # Constant-velocity fallback, in ascending index order for reproducibility.
        missing_indices = [i for i in range(num_valid_tokens) if i not in model_covered]
        if missing_indices:
            missing_indices_torch = torch.tensor(missing_indices, device=self.device)
            cv_pos = all_agent_pos[0, missing_indices_torch]
            cv_vel = all_agent_vel[0, missing_indices_torch]
            cv_ang = all_curr_angle[0, missing_indices_torch]

            dt = 0.1
            time_steps = torch.arange(0, TARGET_FRAMES, device=self.device).float() * dt

            pred_cv_xy = cv_pos[:, None, :] + cv_vel[:, None, :] * time_steps[None, :, None]
            pred_cv_heading = cv_ang[:, None, None].repeat(1, len(time_steps), 1)

            cv_preds_np = (
                torch.cat([pred_cv_xy, pred_cv_heading], dim=-1).cpu().numpy().astype(np.float64)
            )

            for row, token_idx in enumerate(missing_indices):
                forecasts.append(
                    self._get_global_trajectory(cv_preds_np[row], anchor_ego_state)
                )
                forecast_tokens.append(all_neighbor_tokens[token_idx])

        if not forecasts:
            return None, None
        return np.array(forecasts, dtype=np.float64), np.array(forecast_tokens)

    def _pick_proposal_index(self, proposal_scores, plan_probs):
        """
        Choose the index of the proposal to execute.

        For ``'open_loop_boxes'`` the most probable proposal wins. Otherwise the PDM score
        plus a weighted model probability is maximised, falling back to the best PDM
        score if the hybrid choice is worse than it by more than the threshold.

        :param proposal_scores: PDM scores, one per proposal.
        :param plan_probs: model probabilities used as a bonus.
        :return: index into the proposal array.
        """
        sim_metric = getattr(self, 'simulation_metric', 'closed_loop_reactive_agents')
        if sim_metric == 'open_loop_boxes':
            return np.argmax(plan_probs)

        AGGRESSIVE_WEIGHT = 0.3
        SAFETY_THRESHOLD = 4.0

        safe_idx = np.argmax(proposal_scores)
        hybrid_idx = np.argmax(proposal_scores + AGGRESSIVE_WEIGHT * plan_probs)

        # NOTE(review): because hybrid_idx maximises score + 0.3 * prob and softmax
        # probabilities lie in [0, 1], the gap below can never exceed 0.3, so with a
        # threshold of 4.0 this veto never fires. The threshold is left unchanged;
        # confirm the intended value.
        score_gap = proposal_scores[safe_idx] - proposal_scores[hybrid_idx]
        if score_gap > SAFETY_THRESHOLD:
            return safe_idx
        return hybrid_idx

    def compute_planner_trajectory(self, current_input: PlannerInput) -> AbstractTrajectory:
        """
        Return the trajectory for the current simulation step.

        A new plan is computed when the replan interval has elapsed or the previous
        step ended in an emergency brake; otherwise the previous global trajectory is
        reused with its first pose dropped.

        :param current_input: planner input of the current simulation step.
        :return: the emergency-brake trajectory as returned by ``pdm_plan`` when braking,
            otherwise an ``InterpolatedTrajectory`` built from the global trajectory.
        """
        # Timing origin for steps that do not replan; pdm_plan() resets it when it runs.
        # Without this, a reused plan would be timed from the previous planning call.
        self._start_time = time.perf_counter()

        # NOTE(review): garbage collection is switched off here and never switched back
        # on in this class; confirm that this process-wide effect is intended.
        gc.disable()
        ego_state, _ = current_input.history.current_state
        if self._iteration == 0:
            self._route_roadblock_correction(ego_state)

        self._drivable_area_map = get_drivable_area_map(
            self._map_api, ego_state, self._map_radius
        )

        if self._last_plan_elapsed_step >= self._replan_interval or self.brake:
            self._global_trajectory, self.brake = self.pdm_plan(current_input)
            self._last_plan_elapsed_step = 0
        else:
            # Reuse the previous plan, shifted by one step.
            self._global_trajectory = self._global_trajectory[1:]

        if self.brake:
            trajectory = self._global_trajectory
        else:
            trajectory = InterpolatedTrajectory(
                trajectory=global_trajectory_to_states(
                    global_trajectory=self._global_trajectory,
                    ego_history=current_input.history.ego_states,
                    future_horizon=len(self._global_trajectory) * self._step_interval,
                    step_interval=self._step_interval,
                )
            )

        self._inference_runtimes.append(time.perf_counter() - self._start_time)
        self._iteration += 1
        self._last_plan_elapsed_step += 1

        return trajectory

    def generate_planner_report(self, clear_stats: bool = True) -> PlannerReport:
        """
        Package the collected runtime statistics into an ``MLPlannerReport``.

        :param clear_stats: when True, replace the three runtime lists with fresh empty
            lists after the report has been built.
        :return: report referencing the runtime lists collected so far.
        """
        collected = dict(
            compute_trajectory_runtimes=self._compute_trajectory_runtimes,
            feature_building_runtimes=self._feature_building_runtimes,
            inference_runtimes=self._inference_runtimes,
        )
        report = MLPlannerReport(**collected)

        if not clear_stats:
            return report

        # Rebind rather than clear in place: the report keeps the lists it was given.
        self._compute_trajectory_runtimes = []
        self._feature_building_runtimes = []
        self._inference_runtimes = []
        return report

    def _get_global_trajectory(self, local_trajectory: np.ndarray, ego_state: EgoState):
        """
        Transform a local ``(x, y, heading)`` trajectory into the global frame anchored
        at the ego rear axle.

        :param local_trajectory: array whose last axis starts with x, y, heading.
        :param ego_state: ego state whose rear axle supplies the origin and heading.
        :return: array of global ``(x, y, heading)`` rows; the heading is not wrapped.
        """
        rear_axle = ego_state.rear_axle
        ego_xy = rear_axle.array
        ego_heading = rear_axle.heading

        # The rotation helper receives a contiguous float64 copy of the positions.
        local_xy = np.ascontiguousarray(local_trajectory[..., :2]).astype(np.float64)
        global_xy = rotate_round_z_axis(local_xy, -ego_heading) + ego_xy
        global_heading = local_trajectory[..., 2] + ego_heading

        return np.concatenate([global_xy, global_heading[..., None]], axis=1)