import numpy as np
import cv2

def point_line_distance(point, start, end):
    """Return the Euclidean distance from ``point`` to the segment ``start``-``end``.

    Works in any dimension. All three arguments are expected to be NumPy
    vectors of the same length. When the segment is degenerate
    (``start == end``) the distance to that single point is returned.
    """
    offset = point - start
    if np.all(start == end):
        # Degenerate segment: there is no direction to project onto.
        return np.linalg.norm(offset)

    segment = end - start
    # Parameter of the orthogonal projection onto the infinite line,
    # clamped to [0, 1] so the closest point stays on the segment.
    ratio = np.clip(np.dot(offset, segment) / np.dot(segment, segment), 0, 1)
    nearest = start + ratio * segment
    return np.linalg.norm(point - nearest)


def rdp(points, epsilon_ratio=0.01, idxs=None):
    """
    Simplify an N-dimensional polyline with the Ramer-Douglas-Peucker algorithm.

    The tolerance is relative: it equals ``epsilon_ratio`` times the
    bounding-box diagonal of the points handed to the current call. Because
    the function recurses on sub-trajectories, the tolerance is re-derived
    from each sub-trajectory's own bounding box rather than from the full
    trajectory.

    Parameters
    ----------
    points : np.ndarray, shape (N, d)
        Ordered trajectory points.
    epsilon_ratio : float
        Relative tolerance, intended range 0~1.
    idxs : array-like of length N, optional
        Original index of every row of ``points``. Generated as
        ``np.arange(N)`` when omitted; recursive calls pass slices of it.

    Returns
    -------
    simplified : np.ndarray, shape (M, d)
        The kept points, in order. Always contains the first and last point.
    kept_idxs : np.ndarray, shape (M,)
        The entries of ``idxs`` that belong to the kept points.
        With fewer than 3 points, ``points`` and ``idxs`` are returned as-is.
    """
    if idxs is None:
        idxs = np.arange(points.shape[0])
    if points.shape[0] < 3:
        return points, idxs

    # Scale of the current (sub-)trajectory: bounding-box diagonal.
    diag_len = np.linalg.norm(points.max(axis=0) - points.min(axis=0))
    epsilon = epsilon_ratio * diag_len

    start, end = points[0], points[-1]
    # At least one interior point exists here because N >= 3.
    interior = points[1:-1]

    # Distance of every interior point to the chord start-end, computed for
    # all points at once instead of one Python call per point.
    if np.all(start == end):
        # Closed (sub-)trajectory: the chord degenerates to a single point.
        distances = np.linalg.norm(interior - start, axis=1)
    else:
        line_vec = end - start
        t = np.clip((interior - start) @ line_vec / np.dot(line_vec, line_vec), 0, 1)
        projection = start + t[:, None] * line_vec
        distances = np.linalg.norm(interior - projection, axis=1)

    max_idx = np.argmax(distances)

    if distances[max_idx] > epsilon:
        # +1 converts the interior index into an index into ``points``.
        split = max_idx + 1
        left, left_idx = rdp(points[:split + 1], epsilon_ratio, idxs=idxs[:split + 1])
        right, right_idx = rdp(points[split:], epsilon_ratio, idxs=idxs[split:])
        # The split point ends ``left`` and starts ``right``; keep it once.
        return np.vstack((left[:-1], right)), np.hstack((left_idx[:-1], right_idx))

    # Every interior point is within tolerance: keep only the endpoints.
    return np.vstack((start, end)), np.array([idxs[0], idxs[-1]])
    
def rdp_fix_frames(points, n_frames):
    """
    Pick ``n_frames`` key points from ``points`` using RDP simplification.

    Strategy, in order:

    1. Binary-search ``epsilon_ratio`` in [0, 1] until ``rdp`` returns
       exactly ``n_frames`` points.
    2. If the search converges without a hit, retry 10 times with the lower
       bound perturbed by Gaussian noise (std 1e-3). These retries use
       ``np.random`` and are therefore not deterministic.
    3. Otherwise take the retry result with the most points and keep its
       first ``n_frames - 1`` points plus its last point.
    4. If that result has fewer than ``n_frames`` points (or ``n_frames``
       is below 1), fall back to uniformly spaced indices.

    Parameters
    ----------
    points : np.ndarray, shape (N, d)
        Trajectory points; requires ``N >= n_frames``.
    n_frames : int
        Number of points to return.

    Returns
    -------
    selected_points : np.ndarray, shape (n_frames, d)
    selected_idx : np.ndarray, shape (n_frames,)
        Indices of the selected points in ``points``.
    """
    assert points.shape[0] >= n_frames
    # Binary search epsilon s.t. len(rdp(points, epsilon)) == n_frames.
    # A larger epsilon keeps fewer points, so too few points lowers the
    # upper bound and too many raises the lower bound.
    mx_epsilon = 1
    mn_epsilon = 0
    while mx_epsilon - mn_epsilon > 1e-10:
        test_eps = (mx_epsilon + mn_epsilon) / 2
        rdp_points, idx = rdp(points, epsilon_ratio=test_eps)
        if rdp_points.shape[0] == n_frames:
            return rdp_points, idx
        elif rdp_points.shape[0] < n_frames:
            mx_epsilon = test_eps
        else:
            mn_epsilon = test_eps

    # No epsilon produced exactly n_frames points. Probe around the lower
    # bound and remember the richest candidate.
    # NOTE(review): a negative perturbed epsilon makes rdp keep every point,
    # and that candidate then wins the "most points" comparison — confirm
    # this is the intended tie-breaker.
    best_points, best_idx = None, None
    for _ in range(10):
        test_eps = mn_epsilon + np.random.normal(0, 1e-3)
        rdp_points, idx = rdp(points, epsilon_ratio=test_eps)
        if rdp_points.shape[0] == n_frames:
            return rdp_points, idx
        if best_idx is None or len(idx) > len(best_idx):
            best_points, best_idx = rdp_points, idx

    if n_frames < 1 or len(best_idx) < n_frames:
        # Worst case: the candidate cannot supply n_frames distinct points,
        # so return a uniform sampling instead of indexing out of range.
        print("WARNING: RDP failed, using uniform sampling")
        idx = np.linspace(0, len(points) - 1, n_frames).astype(int)
        return points[idx], idx

    # The candidate has more than n_frames points: keep the leading
    # n_frames - 1 of them plus the final point.
    selected = list(range(n_frames - 1)) + [len(best_points) - 1]
    return best_points[selected], best_idx[selected]


def clip_and_resize(images, H, W):
    """
    Letterbox a clip to ``(H, W)`` and pad its length to ``8x + 1`` frames.

    Each frame is resized with its aspect ratio preserved so that it fits
    inside ``(H, W)``, then zero-padded (centred) up to exactly ``(H, W)``.
    The ``8x + 1`` frame count is required by CogVideoX; missing frames are
    filled by repeating the last frame.

    Parameters
    ----------
    images : np.ndarray, shape (N, H0, W0, C)
        Input frames; ``N`` must be at least 9.
    H, W : int
        Target frame height and width.

    Returns
    -------
    np.ndarray, shape (N', H, W, C)
        A new array, where ``N' % 8 == 1`` and ``N' >= N``.

    Raises
    ------
    ValueError
        If fewer than 9 frames are given.
    """
    N_frame = images.shape[0]
    if N_frame < 9:
        raise ValueError("N_frame must be greater than 8")

    if N_frame % 8 != 1:
        # NOTE(review): when N is a multiple of 8 this pads to N + 9 frames
        # although N + 1 would already satisfy 8x + 1 — kept as-is so the
        # output length of existing callers does not change.
        N_frame = (N_frame // 8 + 1) * 8 + 1
        images = np.pad(images, ((0, N_frame - images.shape[0]), (0, 0), (0, 0), (0, 0)), mode='edge')

    ori_H, ori_W = images.shape[1], images.shape[2]
    scale = min(H / ori_H, W / ori_W)
    new_H, new_W = int(ori_H * scale), int(ori_W * scale)

    if (new_H, new_W) == (ori_H, ori_W):
        # Frames already have the scaled size; resizing would only copy them.
        resized = images
    else:
        frames = []
        for img in images:
            frame = cv2.resize(img, (new_W, new_H))
            if frame.ndim == 2:
                # cv2.resize returns a 2-D array for single-channel frames;
                # restore the channel axis so the batch stays 4-D.
                frame = frame[:, :, None]
            frames.append(frame)
        resized = np.array(frames)

    # Centre the content. Integer widths are required by np.pad, and the
    # second side takes the remainder so odd deficits still reach (H, W).
    pad_h = max(H - new_H, 0)
    pad_w = max(W - new_W, 0)
    if pad_h or pad_w:
        top, left = pad_h // 2, pad_w // 2
        resized = np.pad(
            resized,
            ((0, 0), (top, pad_h - top), (left, pad_w - left), (0, 0)),
            mode='constant',
            constant_values=0,
        )
    return np.array(resized)


def resize_video_to_length(video, target_len=81):
    """
    Resample a (T, H, W, C) video to ``target_len`` frames by linear
    interpolation along the time axis.

    Parameters
    ----------
    video : np.ndarray, shape (T, H, W, C)
        Input video.
    target_len : int
        Number of output frames, 81 by default.

    Returns
    -------
    out : np.ndarray, shape (target_len, H, W, C)
        Interpolated video with the same dtype as ``video``. For integer
        dtypes the interpolated values are rounded to the nearest integer
        before the cast (a plain cast would truncate toward zero).

    Raises
    ------
    ValueError
        If ``video`` is not 4-dimensional.
    """
    if video.ndim != 4:
        raise ValueError("video 必须是 (T,H,W,C) 形状")

    T = video.shape[0]
    if T == target_len:
        return video.copy()
    if T == 1:
        return np.repeat(video, target_len, axis=0)

    # Target sample positions, evenly spaced over the source range [0, T-1].
    tgt_idx = np.linspace(0, T - 1, target_len)

    # Neighbouring source frames on either side of every sample position.
    left_idx = np.floor(tgt_idx).astype(int)
    right_idx = np.ceil(tgt_idx).astype(int)
    right_idx = np.clip(right_idx, 0, T - 1)

    # Interpolation weight of the right neighbour, shape (target_len, 1, 1, 1).
    w = (tgt_idx - left_idx)[:, None, None, None]

    out = (1 - w) * video[left_idx] + w * video[right_idx]
    if np.issubdtype(video.dtype, np.integer):
        # Round instead of truncating so integer frames are not biased low.
        out = np.rint(out)
    return out.astype(video.dtype)

def resize_state_to_length(state, target_len=81):
    """
    Resample a (T, D) state sequence to ``target_len`` steps by linear
    interpolation along the first axis.

    Parameters
    ----------
    state : np.ndarray, shape (T, D)
        Input sequence; must be 2-dimensional.
    target_len : int
        Number of output steps, 81 by default.

    Returns
    -------
    out : np.ndarray, shape (target_len, D)
        Interpolated sequence with the same dtype as ``state``. For integer
        dtypes the interpolated values are rounded to the nearest integer
        before the cast (a plain cast would truncate toward zero).
    """
    T, _ = state.shape
    if T == target_len:
        return state.copy()
    if T == 1:
        return np.repeat(state, target_len, axis=0)

    # Target sample positions, evenly spaced over the source range [0, T-1].
    tgt_idx = np.linspace(0, T - 1, target_len)
    left_idx = np.floor(tgt_idx).astype(int)
    right_idx = np.ceil(tgt_idx).astype(int)
    right_idx = np.clip(right_idx, 0, T - 1)

    # Interpolation weight of the right neighbour, shape (target_len, 1).
    w = (tgt_idx - left_idx)[:, None]
    out = (1 - w) * state[left_idx] + w * state[right_idx]
    if np.issubdtype(state.dtype, np.integer):
        # Round instead of truncating so integer states are not biased low.
        out = np.rint(out)
    return out.astype(state.dtype)
    




if __name__ == "__main__":
    # Demo: simplify seven 2-D points, report what was kept, and save a
    # scatter plot of the original points ('o') and the kept points ('x').
    pts = np.array([
        (0, 0),
        (1, 0.1),
        (2, -0.1),
        (3, 5),
        (4, 6.2),
        (5, 6.9),
        (6, 8),
    ])
    simplified, idx = rdp(pts, epsilon_ratio=0.1, idxs=np.arange(len(pts)))

    print("原始点数:", len(pts))
    print("简化后点数:", len(simplified))
    print(simplified)
    print("保留点索引:", idx)

    # Imported here so matplotlib is only needed when running the demo.
    from matplotlib import pyplot as plt
    for data, marker in ((pts, 'o'), (simplified, 'x')):
        plt.plot(data[:, 0], data[:, 1], marker)
    plt.savefig("rdp_example.png")