import json
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

from pointcloud.toolkits.waymo_helpers import (
    load_poses_calibration,
    get_lane_shift_direction,
)
from .clip_config import ClipConfig


class NovelPoseScheduler:
    def __init__(
        self,
        waymo_root: Path,
        clip_id: str,
        camera_names: List[str],
        segment_len: int,
        lane_change_frames: int,
        accel_frames: int,
        clip_configs: List[ClipConfig],
    ) -> None:
        self.scene_dir = os.path.join(waymo_root, f"segment-{clip_id}_with_camera_labels.tfrecord")
        self.camera_names = list(camera_names)

        ego_frame_poses, ego_cam_poses, extrinsics, _intrinsics = load_poses_calibration(datadir=self.scene_dir)
        self.extrinsics = extrinsics
        self.intrinsics = _intrinsics
        self.ego_frame_poses = ego_frame_poses
        self.num_frames = int(len(ego_frame_poses))

        # 保留配置与参数以便生成不同 index 的轨迹
        self.segment_len = segment_len
        self.lane_change_frames = lane_change_frames
        self.accel_frames = accel_frames
        total_segments = self.num_frames // segment_len

        # 生成基准轨迹：全 skip
        actions_skip = ['skip'] * total_segments
        self.cam_to_c2w_base, _ = self._generate_cam_c2w_from_waymo(
            ego_frame_poses=self.ego_frame_poses,
            extrinsics=self.extrinsics,
            camera_names=self.camera_names,
            actions_by_segment=actions_skip,
            segment_len=self.segment_len,
            base_shift=0.0,
            base_longitudinal=0.0,
            lane_change_frames=self.lane_change_frames,
            accel_frames=self.accel_frames,
            use_fixed_motion=False,
        )

        # 为每个 entry_index 生成独立轨迹
        self.cam_to_c2w_all: Dict[int, Dict[str, np.ndarray]] = {}

        for config in clip_configs:
            entry_index = config.entry_index
            seg_idx_of_entry = config.seg_idx_of_entry
            action_for_seg = config.action_for_seg
            local_base_shift = config.local_base_shift
            local_base_longitudinal = config.local_base_longitudinal
            use_fixed_motion = config.use_fixed_motion

            actions_for_entry = ['skip'] * total_segments
            if 0 <= seg_idx_of_entry < total_segments:
                actions_for_entry[seg_idx_of_entry] = action_for_seg

            poses_for_entry, _ = self._generate_cam_c2w_from_waymo(
                ego_frame_poses=self.ego_frame_poses,
                extrinsics=self.extrinsics,
                camera_names=self.camera_names,
                actions_by_segment=actions_for_entry,
                segment_len=self.segment_len,
                base_shift=local_base_shift,
                base_longitudinal=local_base_longitudinal,
                lane_change_frames=self.lane_change_frames,
                accel_frames=self.accel_frames,
                use_fixed_motion=use_fixed_motion,
            )

            self.cam_to_c2w_all[entry_index] = poses_for_entry

    @staticmethod
    def _cosine_ease_curve(num_frames: int, total_value: float, duration_frames: int) -> np.ndarray:
        duration = max(1, int(duration_frames))
        t = np.arange(num_frames)
        tau = np.clip(t / duration, 0.0, 1.0)
        curve = total_value * 0.5 * (1 - np.cos(np.pi * tau))
        curve[tau >= 1.0] = total_value
        return curve.astype(np.float32)

    @staticmethod
    def _compute_lane_change_curves(num_frames: int, total_shift: float,
                                    lane_change_frames: int) -> Tuple[np.ndarray, np.ndarray]:
        duration = max(1, int(lane_change_frames))
        t = np.arange(num_frames)
        tau = np.clip(t / duration, 0.0, 1.0)
        shift_curve = total_shift * 0.5 * (1 - np.cos(np.pi * tau))
        shift_curve[tau >= 1.0] = total_shift
        yaw_peak_deg = min(5.0, 2.0 * abs(total_shift))
        yaw_peak_rad = np.deg2rad(yaw_peak_deg)
        yaw_curve = yaw_peak_rad * np.sin(np.pi * tau)
        yaw_curve[tau >= 1.0] = 0.0
        return shift_curve.astype(np.float32), yaw_curve.astype(np.float32)

    def _read_config_actions(self, config_path: Optional[str], clip_id: str, num_segments: int) -> List[str]:
        if config_path is None or (not os.path.exists(config_path)):
            return ['skip'] * num_segments
        try:
            with open(config_path, 'r') as f:
                cfg = json.load(f)
        except Exception:
            cfg = {}

        actions: List[str] = ['skip'] * num_segments

        for item in cfg:
            method = item.get('method', '').strip().lower()
            if method != 'sparse':
                continue

            clip_name = str(item.get('clip_name'))
            seg_idx = int(clip_name.rsplit('_', 1)[1])
            ego_transform = str(item.get('ego_transform', '')).strip().lower()
            action_map = {
                'right': 'right',
                'left': 'left',
                'up': 'up',
                'down': 'down',
            }
            actions[seg_idx] = action_map.get(ego_transform, 'skip')

        return actions

    def _generate_cam_c2w_from_waymo(
        self,
        ego_frame_poses: np.ndarray,
        extrinsics: np.ndarray,
        camera_names: List[str],
        actions_by_segment: List[str],
        segment_len: int,
        base_shift: float,
        base_longitudinal: float,
        lane_change_frames: int,
        accel_frames: int,
        use_fixed_motion: bool = False,
    ) -> Tuple[Dict[str, np.ndarray], int]:
        num_frames = int(len(ego_frame_poses))
        total_segments = (num_frames + segment_len - 1) // segment_len
        actions = [(actions_by_segment[i] if i < len(actions_by_segment) else 'skip') for i in range(total_segments)]
        actions = [(a or 'skip').strip().lower() for a in actions]

        name_to_idx = {
            'front': 0,
            'front_left': 1,
            'front_right': 2,
            'side_left': 3,
            'side_right': 4,
        }
        cam_to_c2w_all: Dict[str, np.ndarray] = {cam: np.zeros((num_frames, 4, 4), dtype=np.float32) for cam in camera_names}

        for seg_idx in range(total_segments):
            seg_start = seg_idx * segment_len
            seg_end = min(num_frames, seg_start + segment_len)
            seg_len = seg_end - seg_start
            if seg_len <= 0:
                continue

            action = actions[seg_idx]
            do_shift = action in ['left', 'right']
            do_accel = action in ['up', 'down']
            left_right_sign = -1.0 if action == 'left' else (1.0 if action == 'right' else 0.0)
            up_down_sign = 1.0 if action == 'up' else (-1.0 if action == 'down' else 0.0)

            if use_fixed_motion:
                # 固定值模式：从段开始即施加恒定位移（无视角变化）
                shift_curve = (np.full(seg_len, abs(base_shift), dtype=np.float32)
                               if (do_shift and base_shift != 0.0) else np.zeros(seg_len, dtype=np.float32))
                yaw_curve = np.zeros(seg_len, dtype=np.float32)
                accel_curve = (np.full(seg_len, up_down_sign * abs(base_longitudinal), dtype=np.float32)
                               if (do_accel and base_longitudinal != 0.0) else np.zeros(seg_len, dtype=np.float32))
            else:
                if do_shift and base_shift != 0.0:
                    shift_curve, yaw_curve = self._compute_lane_change_curves(seg_len, abs(base_shift), lane_change_frames)
                else:
                    shift_curve = np.zeros(seg_len, dtype=np.float32)
                    yaw_curve = np.zeros(seg_len, dtype=np.float32)
                if do_accel and base_longitudinal != 0.0:
                    accel_curve = self._cosine_ease_curve(seg_len, up_down_sign * abs(base_longitudinal), accel_frames)
                else:
                    accel_curve = np.zeros(seg_len, dtype=np.float32)

            for local_idx in range(seg_len):
                f = seg_start + local_idx
                lane_shift_direction = get_lane_shift_direction(ego_frame_poses[seg_start:], local_idx).astype(np.float32)
                yaw_delta = 0.0
                if (not use_fixed_motion) and do_shift and base_shift != 0.0:
                    yaw_delta = left_right_sign * float(yaw_curve[local_idx])
                cos_y = float(np.cos(yaw_delta))
                sin_y = float(np.sin(yaw_delta))
                Rz = np.array([[cos_y, -sin_y, 0.0],
                               [sin_y,  cos_y, 0.0],
                               [0.0,    0.0,   1.0]], dtype=np.float32)

                R_world = ego_frame_poses[f][:3, :3]
                forward_world = R_world @ np.array([1.0, 0.0, 0.0], dtype=np.float32)
                forward_world = forward_world / (np.linalg.norm(forward_world) + 1e-8)

                dtrans_world = np.zeros(3, dtype=np.float32)
                if do_shift and base_shift != 0.0:
                    dtrans_world += left_right_sign * lane_shift_direction * float(shift_curve[local_idx])
                if do_accel and base_longitudinal != 0.0:
                    dtrans_world += forward_world * float(accel_curve[local_idx])

                for cam in camera_names:
                    cam_idx = name_to_idx.get(cam, None)
                    if cam_idx is None:
                        try:
                            cam_idx = camera_names.index(cam)
                        except Exception:
                            cam_idx = 0
                    ego_pose_world = ego_frame_poses[f].copy()
                    ego_pose_world[:3, 3] = ego_pose_world[:3, 3] + dtrans_world
                    ego_pose_world[:3, :3] = ego_pose_world[:3, :3] @ Rz
                    cam_to_ego = np.asarray(extrinsics[cam_idx]).astype(np.float32)
                    c2w = (ego_pose_world @ cam_to_ego).astype(np.float32)
                    cam_to_c2w_all[cam][f] = c2w

        return cam_to_c2w_all, num_frames

    def get_pose(self, cam: str, frame_idx: int, entry_index: int) -> np.ndarray:
        return self.cam_to_c2w_all[entry_index][cam][frame_idx].astype(np.float32)

    def get_base_pose(self, cam: str, frame_idx: int) -> np.ndarray:
        return self.cam_to_c2w_base[cam][frame_idx].astype(np.float32)


