"""
统一的编辑处理类，封装四种编辑类型（deletion, replacement, repositioning, insertion）的共同逻辑。
"""
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING

import numpy as np
import cv2

if TYPE_CHECKING:
    from .clip_config import ClipConfig


@dataclass(init=True, repr=True, eq=True)
class EditConfig:
    """Edit configuration extracted from a clip's ``validation_edit_info``.

    Every id is stored as a string, or ``None`` when that edit type is absent.
    """

    # Id of the object to delete.
    deletion_origin_id: Optional[str] = None
    # Id of the object being replaced (its id is kept for the replacement).
    replacement_origin_id: Optional[str] = None
    # Candidate id whose model replaces the original object.
    replacement_new_id: Optional[str] = None
    # Id of the object to move.
    reposition_origin_id: Optional[str] = None
    # Id (and model name) of the object to insert.
    insertion_new_id: Optional[str] = None
    # Whether a hole should be carved for the inserted object.
    insertion_do_hole: bool = True
    # Normalized (stripped, lower-case) repositioning action, e.g. 'up'/'down'/'left'/'right'.
    action_for_reposition: Optional[str] = None


class EditProcessor:
    """Unified edit processor encapsulating the logic shared by the four edit
    types: deletion, replacement, repositioning and insertion."""

    REPOSITION_DISTANCE_M = 3.0  # default translation distance for repositioning (meters)

    # action -> (column of the object-to-world matrix used as direction, sign)
    _ACTION_AXES = {
        'up': (0, 1.0),
        'down': (0, -1.0),
        'left': (1, 1.0),
        'right': (1, -1.0),
    }

    # Sign pattern of the 8 OBB corners along the local (l, w, h) axes.
    # The first four rows lie on the -h face, the last four on the +h face.
    _OBB_CORNER_SIGNS = np.array([
        [1, 1, -1],
        [1, -1, -1],
        [-1, -1, -1],
        [-1, 1, -1],
        [1, 1, 1],
        [1, -1, 1],
        [-1, -1, 1],
        [-1, 1, 1],
    ], dtype=np.float32)

    # Textual flag values (stripped, case-insensitive) interpreted as False.
    _FALSE_STRINGS = frozenset({'false', '0', 'no', 'off', ''})

    @staticmethod
    def _as_opt_str(value: Any) -> Optional[str]:
        """Return ``str(value)``, or ``None`` when *value* is ``None``."""
        return None if value is None else str(value)

    @staticmethod
    def _to_bool(value: Any) -> bool:
        """Convert a config flag to bool.

        Strings such as "false"/"0"/"no"/"off" map to False. A plain ``bool()``
        would treat any non-empty string as True. Other values use ``bool()``.
        """
        if isinstance(value, str):
            return value.strip().lower() not in EditProcessor._FALSE_STRINGS
        return bool(value)

    @staticmethod
    def _extract_candidate_id(cfg: Any) -> Optional[str]:
        """Return the single candidate id of a replacement/insertion config dict.

        Supports the current ``candidate_id`` key and, for backward compatibility,
        the legacy ``candidate_ids`` list/tuple (its first element is used).
        Returns None when *cfg* is not a dict or no candidate is present.
        """
        if not isinstance(cfg, dict):
            return None
        if "candidate_id" in cfg:
            candidate = cfg.get("candidate_id")
        else:
            legacy = cfg.get("candidate_ids", []) or []
            candidate = legacy[0] if isinstance(legacy, (list, tuple)) and len(legacy) > 0 else None
        return EditProcessor._as_opt_str(candidate)

    @staticmethod
    def extract_edit_config(clip_config: 'ClipConfig') -> 'EditConfig':
        """Extract the edit configuration from ``clip_config.validation_edit_info``.

        Recognized keys of that dict: "deletion", "replacement", "repositioning",
        "insertion" (each a dict) and "action_for_reposition". Sub-configs that
        are missing or not dicts are ignored. If ``validation_edit_info`` itself
        is not a dict, an empty ``EditConfig`` is returned.
        """
        info = getattr(clip_config, "validation_edit_info", {}) or {}
        if not isinstance(info, dict):
            return EditConfig()

        deletion_cfg = info.get("deletion", {}) or {}
        replacement_cfg = info.get("replacement", {}) or {}
        reposition_cfg = info.get("repositioning", {}) or {}
        insertion_cfg = info.get("insertion", {}) or {}

        def _origin_id(cfg: Any) -> Optional[str]:
            # "origin_id" of a sub-config as a string, or None.
            return EditProcessor._as_opt_str(cfg.get("origin_id")) if isinstance(cfg, dict) else None

        insertion_do_hole = (
            EditProcessor._to_bool(insertion_cfg.get("do_hole", True))
            if isinstance(insertion_cfg, dict)
            else True
        )

        # Normalize the action; empty / whitespace-only values become None.
        action = info.get("action_for_reposition", None)
        action_norm = str(action).strip().lower() if action else ""

        return EditConfig(
            deletion_origin_id=_origin_id(deletion_cfg),
            replacement_origin_id=_origin_id(replacement_cfg),
            replacement_new_id=EditProcessor._extract_candidate_id(replacement_cfg),
            reposition_origin_id=_origin_id(reposition_cfg),
            insertion_new_id=EditProcessor._extract_candidate_id(insertion_cfg),
            insertion_do_hole=insertion_do_hole,
            action_for_reposition=action_norm or None,
        )

    @staticmethod
    def get_object_ids_to_remove(clip_config: 'ClipConfig') -> List[str]:
        """Return the ids of objects to remove (deletion origin, then replacement origin)."""
        edit_config = EditProcessor.extract_edit_config(clip_config)
        candidates = (edit_config.deletion_origin_id, edit_config.replacement_origin_id)
        return [obj_id for obj_id in candidates if obj_id is not None]

    @staticmethod
    def get_objects_to_add(clip_config: 'ClipConfig', ply_root: Path) -> List[Tuple[str, Path]]:
        """Return ``(object_id, ply_path)`` pairs to add (insertion, replacement, repositioning).

        - insertion:     new id, model ``<insertion_new_id>.ply``
        - replacement:   original id, model taken from the candidate ``<replacement_new_id>.ply``
        - repositioning: original id and its own model ``<reposition_origin_id>.ply``
        """
        edit_config = EditProcessor.extract_edit_config(clip_config)
        objects_to_add: List[Tuple[str, Path]] = []

        if edit_config.insertion_new_id is not None:
            objects_to_add.append(
                (edit_config.insertion_new_id, ply_root / f"{edit_config.insertion_new_id}.ply")
            )

        if edit_config.replacement_origin_id is not None and edit_config.replacement_new_id is not None:
            objects_to_add.append(
                (edit_config.replacement_origin_id, ply_root / f"{edit_config.replacement_new_id}.ply")
            )

        if edit_config.reposition_origin_id is not None:
            objects_to_add.append(
                (edit_config.reposition_origin_id, ply_root / f"{edit_config.reposition_origin_id}.ply")
            )

        return objects_to_add

    @staticmethod
    def shift_object_tfm_by_action(
        tfm: np.ndarray,
        action: Optional[str],
        distance_m: float = REPOSITION_DISTANCE_M,
    ) -> np.ndarray:
        """Translate an object-to-world matrix along the object's own axes.

        - up:    along local +x (column 0 of *tfm*)
        - down:  along local -x
        - left:  along local +y (column 1 of *tfm*)
        - right: along local -y
        Any other action is ignored. For ``None`` the input is returned
        unchanged; for other unknown actions the (array-converted) input is
        returned. A shifted copy is returned; the input is not modified.
        """
        if action is None:
            return tfm
        if not isinstance(tfm, np.ndarray):
            tfm = np.array(tfm, dtype=np.float32)

        axis_sign = EditProcessor._ACTION_AXES.get(str(action).strip().lower())
        if axis_sign is None:
            return tfm
        axis, sign = axis_sign

        if not np.issubdtype(tfm.dtype, np.floating):
            # An integer matrix would silently truncate the fractional translation.
            tfm = tfm.astype(np.float32)

        dir_vec = sign * tfm[:3, axis]
        dir_unit = dir_vec / (np.linalg.norm(dir_vec) + 1e-8)
        shifted = tfm.copy()
        shifted[:3, 3] = shifted[:3, 3] + dir_unit * float(distance_m)
        return shifted

    @staticmethod
    def get_obb_vertices(lwh: np.ndarray, scale_factor: float) -> np.ndarray:
        """Return the 8 OBB corners in the box's local frame as (8, 4) homogeneous float32.

        *lwh* holds the full box extents (any array-like with at least 3
        entries); they are scaled by *scale_factor* and halved around the origin.
        """
        half_dims = (np.asarray(lwh) * float(scale_factor)) / 2.0
        vertices = (EditProcessor._OBB_CORNER_SIGNS * half_dims[:3]).astype(np.float32)
        return np.hstack([vertices, np.ones((8, 1), dtype=np.float32)])

    @staticmethod
    def compute_hole_mask_for_object(
        object_to_world: np.ndarray,
        lwh: np.ndarray,
        w2c: np.ndarray,
        K: np.ndarray,
        H: int,
        W: int,
        scale_factor: float = 1.1,
    ) -> Tuple[Optional[np.ndarray], float, Optional[np.ndarray]]:
        """Compute the hole mask, minimum depth and projected polygon of one object.

        The (scaled) OBB corners are transformed by *object_to_world* and *w2c*
        (presumably 4x4 matrices), and the corners in front of the camera
        (depth > 1e-6) are projected with *K*. The filled convex hull of the
        projected points is rasterized into an (H, W) uint8 mask with value 255.

        Returns:
            (mask_2d, min_depth, polygon_vertices_2d):
            - mask_2d: (H, W) uint8 array, or None if no hull could be formed
            - min_depth: smallest positive corner depth (0.0 if nothing is in front)
            - polygon_vertices_2d: projected 2D corner points, or None
        """
        vertices_local_h = EditProcessor.get_obb_vertices(lwh, scale_factor)
        vertices_world_h = (object_to_world @ vertices_local_h.T).T
        vertices_cam = (w2c @ vertices_world_h.T).T[:, :3]
        depths = vertices_cam[:, 2]

        front_mask = depths > 1e-6
        if not np.any(front_mask):
            return None, 0.0, None
        # NOTE(review): corners behind the camera are dropped rather than clipped
        # against a near plane, so boxes straddling the camera get a smaller hull.
        min_depth = float(np.min(depths[front_mask]))

        front = vertices_cam[front_mask]
        normalized = front / front[:, 2:3]
        # Assumes K is a 3x3 intrinsic matrix with last row [0, 0, 1]: no
        # further perspective divide is applied after this multiplication.
        projected = (K @ normalized.T).T
        polygon_vertices_2d = projected[:, :2]

        if polygon_vertices_2d.shape[0] < 3:
            return None, min_depth, None

        # cv2.convexHull only accepts contiguous CV_32F / CV_32S point sets;
        # projections from float64 matrices would otherwise fail an OpenCV assertion.
        hull_input = np.ascontiguousarray(polygon_vertices_2d, dtype=np.float32)
        shadow_polygon = cv2.convexHull(hull_input, returnPoints=True)
        if shadow_polygon is None or shadow_polygon.shape[0] < 3:
            return None, min_depth, None

        remove_mask_2d = np.zeros((H, W), dtype=np.uint8)
        cv2.drawContours(remove_mask_2d, [np.round(shadow_polygon).astype(np.int32)], -1, 255, cv2.FILLED)

        return remove_mask_2d, min_depth, polygon_vertices_2d

