import av
import os
import csv
import cv2
import torch
import pickle
import numpy as np
from tqdm import tqdm
# from moviepy.editor import *
import multiprocessing as mp
from functools import partial
import matplotlib.pyplot as plt
from moviepy.video.io.VideoFileClip import VideoFileClip
# from pytorch3d.transforms import euler_angles_to_matrix
from natsort import natsorted  # 自然排序（episode0, episode1,...）
from torch.cuda.amp import autocast
import sys
sys.path.append('/projects/colosseum/SAM2Act/sam2Act_COLOSSEUM')
from GLC.slowfast.datasets import utils

import sys
sys.path.append('/projects/colosseum/SAM2Act/sam2Act_COLOSSEUM')
import sam2act_colosseum.mvt.utils as mvt_utils
import sam2act_colosseum.rvt.rvt_utils as rvt_utils
from sam2act_colosseum.configs.config import SCENE_BOUNDS
from point_renderer.rvt_renderer import RVTBoxRenderer as BoxRenderer
from sam2act_colosseum.utils.tools import stack_on_channel, _norm_rgb, get_stored_demo, CAMERAS

def visualize_projection(video_path, csv_path, frame_index=0, output_img_path="projection_visualization.png"):
    """
    Draw the projected point for a single video frame and save it as an image (pandas-free).

    Frame ``frame_index`` of ``video_path`` is paired with CSV row ``frame_index``
    (by row order, not by timestamp). The row's normalized ``x_norm``/``y_norm``
    coordinates are scaled to pixels. A green circle, the coordinate label and the
    frame number are then drawn on a copy of the frame.

    :param video_path: path to the video file.
    :param csv_path: path to a CSV with ``x_norm`` and ``y_norm`` columns, one row per frame.
    :param frame_index: 0-based index of the frame to draw (default 0).
    :param output_img_path: where the annotated image is written.
    :raises ValueError: if ``frame_index`` is negative, the video cannot be opened,
        the frame cannot be read, or the CSV has no row for ``frame_index``.
    :raises OSError: if OpenCV fails to write ``output_img_path``.
    """
    if frame_index < 0:
        # A negative index would silently select a CSV row counted from the end,
        # which does not correspond to the seeked video frame.
        raise ValueError(f"frame_index must be non-negative, got {frame_index}")

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError("无法打开视频文件")
    try:
        # Seek to the requested frame.
        # NOTE(review): frame-accurate seeking via CAP_PROP_POS_FRAMES depends on
        # the codec/backend — verify for the videos used here.
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
        ret, frame = cap.read()
    finally:
        # The frame is already in memory; release the capture now so it is not
        # leaked if anything below (CSV parsing, drawing, writing) raises.
        cap.release()
    if not ret:
        raise ValueError("无法读取指定帧")

    # Read the normalized coordinates; row i corresponds to video frame i.
    coords = []
    with open(csv_path, 'r', newline='') as f:
        for row in csv.DictReader(f):
            coords.append((float(row['x_norm']), float(row['y_norm'])))

    if frame_index >= len(coords):
        raise ValueError("帧索引超出CSV数据范围")

    x_norm, y_norm = coords[frame_index]

    # Normalized [0, 1] coordinates -> pixel coordinates.
    height, width = frame.shape[:2]
    x_pixel = int(x_norm * width)
    y_pixel = int(y_norm * height)

    print(f"归一化坐标: ({x_norm}, {y_norm})")
    print(f"像素坐标: ({x_pixel}, {y_pixel})")
    print(f"视频帧尺寸: {width}x{height}")

    # OpenCV drawing functions modify the image in place, so draw on a copy.
    marked_frame = frame.copy()
    cv2.circle(marked_frame, (x_pixel, y_pixel), radius=3,
               color=(0, 255, 0),  # BGR: green
               thickness=2)

    # Coordinate label near the point, clamped so it stays inside the frame.
    text = f"({x_norm:.2f}, {y_norm:.2f})"
    text_size = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
    text_x = max(10, min(x_pixel - text_size[0] // 2, width - text_size[0] - 10))
    text_y = max(30, min(y_pixel - 20, height - 10))
    cv2.putText(marked_frame, text, (text_x, text_y),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

    # Frame number, fixed in the top-left corner.
    cv2.putText(marked_frame, f"Frame: {frame_index}", (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

    # cv2.imwrite reports failure via its return value rather than raising.
    if not cv2.imwrite(output_img_path, marked_frame):
        raise OSError(f"Failed to write image to {output_img_path}")
    print(f"可视化结果已保存到: {output_img_path}")


def visualize_projection_video(video_path, csv_path, output_video_path="output_tracking.mp4", fps=None,
                               max_trajectory_length=20):
    """
    Render a copy of a video that shows the per-frame projected point and its recent trajectory.

    Each output frame contains:
      * a polyline through the last ``max_trajectory_length`` projected points,
        coloured from red (oldest segment) to green (newest segment);
      * a green circle at the current point;
      * the current normalized coordinates;
      * the frame number.

    Frames and CSV rows are paired by position. Only the first
    ``min(frame_count, number_of_csv_rows)`` frames are processed.

    :param video_path: input video file path.
    :param csv_path: CSV with ``x_norm`` and ``y_norm`` columns, one row per frame.
    :param output_video_path: output video path (written with the ``mp4v`` fourcc).
    :param fps: output frame rate; falls back to the input video's FPS when falsy.
    :param max_trajectory_length: number of most recent points kept in the trail (default 20).
    :raises ValueError: if the input video cannot be opened or the output writer cannot be created.
    """
    from collections import deque

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError("无法打开视频文件")

    out = None
    try:
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        input_fps = cap.get(cv2.CAP_PROP_FPS)

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(
            output_video_path,
            fourcc,
            fps if fps else input_fps,
            (width, height),
            isColor=True,
        )
        # Without this check, a writer that failed to open silently drops every frame.
        if not out.isOpened():
            raise ValueError(f"Cannot open video writer for {output_video_path}")

        # Normalized coordinates; row i corresponds to video frame i.
        x_coords, y_coords = [], []
        with open(csv_path, 'r', newline='') as f:
            for row in csv.DictReader(f):
                x_coords.append(float(row['x_norm']))
                y_coords.append(float(row['y_norm']))

        min_length = min(frame_count, len(x_coords))
        print(f"视频总帧数: {frame_count} | CSV数据点数: {len(x_coords)}")
        print(f"将处理前 {min_length} 帧")

        # Bounded deque: appending beyond maxlen drops the oldest point automatically.
        trajectory = deque(maxlen=max_trajectory_length)

        for frame_idx in range(min_length):
            ret, frame = cap.read()
            if not ret:
                break

            x_norm = x_coords[frame_idx]
            y_norm = y_coords[frame_idx]
            x_pixel = int(x_norm * width)
            y_pixel = int(y_norm * height)
            trajectory.append((x_pixel, y_pixel))

            # All overlays are drawn onto the same image (OpenCV draws in place).
            # The frame just read is not reused, so it can be annotated directly.
            marked_frame = frame

            # Trajectory segments. BGR colour goes from red (oldest) to green (newest).
            points = list(trajectory)
            n_points = len(points)
            for seg_idx, (start, end) in enumerate(zip(points, points[1:]), start=1):
                t = seg_idx / n_points
                color = (0, int(255 * t), int(255 * (1 - t)))
                cv2.line(marked_frame, start, end, color, 2)

            # Current point in green, drawn after the trail so it sits on top.
            cv2.circle(marked_frame, (x_pixel, y_pixel), radius=3, color=(0, 255, 0), thickness=2)

            # Frame number, fixed in the top-left corner.
            cv2.putText(marked_frame, f"Frame: {frame_idx}", (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

            # Coordinate label near the point, clamped so it stays inside the frame.
            text = f"Pos: ({x_norm:.2f}, {y_norm:.2f})"
            text_size = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
            text_x = max(10, min(x_pixel - text_size[0] // 2, width - text_size[0] - 10))
            text_y = max(30, min(y_pixel - 20, height - 10))
            cv2.putText(marked_frame, text, (text_x, text_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

            out.write(marked_frame)

            if frame_idx % 100 == 0:
                print(f"已处理 {frame_idx}/{min_length} 帧")
    finally:
        # Release both handles even when reading or drawing fails midway.
        cap.release()
        if out is not None:
            out.release()

    print(f"视频生成完成，保存至: {output_video_path}")


def _gaussian_heatmap(shape, center, kernel_size=19, sigma=-1):
    """
    Return a 2-D Gaussian map of size ``shape`` = (h, w), min-max scaled to [0, 1].

    :param shape: (height, width) of the output grid; both must be positive.
    :param center: (x, y) peak position in grid coordinates; clamped onto the grid.
    :param kernel_size: used to derive ``sigma = kernel_size / 6`` when ``sigma`` is negative.
    :param sigma: explicit standard deviation in grid cells; negative means "derive from kernel_size".
    :return: float ndarray of shape (h, w). If the map is numerically flat, a
        uniform map whose entries sum to 1 is returned instead.
    :raises ValueError: if either dimension is not positive.
    """
    h, w = shape
    if h <= 0 or w <= 0:
        raise ValueError("Cannot generate gaussian map for empty array")

    sigma = kernel_size / 6 if sigma < 0 else sigma

    # Clamp the centre onto the grid so an out-of-range label still yields a peak.
    cx = max(0, min(w - 1, center[0]))
    cy = max(0, min(h - 1, center[1]))

    xs = np.arange(w, dtype=float)
    ys = np.arange(h, dtype=float)[:, np.newaxis]  # column vector -> broadcasts to (h, w)
    gauss = np.exp(-((xs - cx) ** 2 + (ys - cy) ** 2) / (2 * sigma ** 2))

    lo, hi = gauss.min(), gauss.max()
    if hi - lo < 1e-6:
        # Degenerate (flat) map: fall back to a uniform distribution.
        return np.ones_like(gauss) / gauss.size
    return (gauss - lo) / (hi - lo)


def get_heatmap(visualize_root, visualize_task, visualize_epoch, visualize_stage, frame_index):
    """
    Build and save a Gaussian gaze heatmap for one frame of an episode video.

    Steps:
      1. Decode frame ``frame_index`` of
         ``<root>/full_scale.gaze/<task>/ep_<epoch>_<stage>.mp4``. The frame is
         only used for its size.
      2. Read the gaze label for that frame. The first existing file among
         ``<root>/gaze_frame_label/<task>/ep_<epoch>_<stage>_frame_label.csv`` and
         ``<root>/gaze/<task>/ep_<epoch>_<stage>.csv`` is used. The header row is
         skipped and columns 1-2 of data row ``frame_index`` are read.
      3. Draw a Gaussian at that point on a grid one quarter of the frame resolution.
      4. Save it as ``<task>_ep<epoch>_<stage>_frame<frame_index>.png`` in the CWD.

    :return: the heatmap as a float ndarray of shape (frame_h // 4, frame_w // 4), min 1x1.
    :raises ValueError: if the frame does not exist in the video.
    :raises FileNotFoundError: if neither label file exists.
    :raises IndexError: if the label file has no data row for ``frame_index``.
    :raises RuntimeError: if the heatmap cannot be generated from the label.
    """
    GAUSSIAN_KERNEL = 19

    # 1. Decode frames sequentially until the requested one.
    video_path = os.path.join(
        visualize_root, "full_scale.gaze",
        visualize_task, f"ep_{visualize_epoch}_{visualize_stage}.mp4"
    )
    target_frame = None
    with av.open(video_path) as container:
        for i, frame in enumerate(container.decode(video=0)):
            if i == frame_index:
                target_frame = frame.to_ndarray(format='rgb24')
                break
    if target_frame is None:
        raise ValueError(f"Frame {frame_index} not found in {video_path}")

    # 2. Locate and read the gaze label.
    label_paths = [
        os.path.join(visualize_root, "gaze_frame_label", visualize_task,
                     f"ep_{visualize_epoch}_{visualize_stage}_frame_label.csv"),
        os.path.join(visualize_root, "gaze", visualize_task,
                     f"ep_{visualize_epoch}_{visualize_stage}.csv"),
    ]
    label_path = next((p for p in label_paths if os.path.exists(p)), None)
    if label_path is None:
        raise FileNotFoundError(f"Label not found at paths: {label_paths}")

    with open(label_path, "r", newline="") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip header row
        # Only columns 1-2 are used. They are assumed to hold the normalized
        # (x, y) gaze point — TODO confirm against the label writer. Blank
        # lines are skipped.
        rows = [[float(v) for v in row[1:3]] for row in reader if row]

    if not 0 <= frame_index < len(rows):
        raise IndexError(
            f"frame_index {frame_index} out of range: {label_path} has {len(rows)} data rows"
        )
    label = rows[frame_index]

    # 3. Heatmap at quarter resolution (at least 1x1).
    h = max(1, target_frame.shape[0] // 4)
    w = max(1, target_frame.shape[1] // 4)
    try:
        gaze_x, gaze_y = label[0] * w, label[1] * h
        heatmap = _gaussian_heatmap((h, w), (gaze_x, gaze_y), GAUSSIAN_KERNEL)
    except Exception as e:
        raise RuntimeError(f"Failed to generate heatmap: {str(e)}") from e

    # 4. Save as a borderless image: the axes fill the whole figure.
    output_path = f"{visualize_task}_ep{visualize_epoch}_{visualize_stage}_frame{frame_index}.png"

    plt.figure(figsize=(5, 5), dpi=300, frameon=False)
    ax = plt.Axes(plt.gcf(), [0., 0., 1., 1.])
    ax.set_axis_off()
    plt.gcf().add_axes(ax)
    plt.imshow(heatmap, cmap='inferno', aspect='auto')
    plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=300)
    plt.close()

    print(f"Heatmap saved to {output_path}")
    return heatmap

def save_heatmap_with_image(inputs, heatmaps, save_dir='debug_vis', num_samples=3):
    """
    Save side-by-side debug figures (input image | heatmap) for the first few batch samples.

    :param inputs: tensor of shape [B, C, H, W] or [B, C, T, H, W]. For the 5-D
        case only the first time step is shown.
    :param heatmaps: tensor of shape [B, 1, H, W] (channel 0 is shown) or [B, H, W].
    :param save_dir: output directory (created if missing); files are ``sample_<i>.png``.
    :param num_samples: maximum number of samples to save.

    Plotting errors are printed and skipped (best-effort debug helper).
    """
    os.makedirs(save_dir, exist_ok=True)

    # Detach from the graph and move to CPU before converting to numpy.
    inputs = inputs.detach().cpu().float()
    heatmaps = heatmaps.detach().cpu().float()

    for i in range(min(num_samples, inputs.size(0))):
        # [C, H, W]: 4-D input as-is; otherwise take the first time step.
        img = inputs[i] if inputs.dim() == 4 else inputs[i, :, 0]

        if img.size(0) == 3:
            # Reverse channel order. This assumes 3-channel inputs are BGR and
            # converts them to RGB for matplotlib — TODO confirm the input order.
            img = img[[2, 1, 0]]

        # Per-sample min-max scaling to [0, 1] for display; epsilon avoids 0/0.
        img = (img - img.min()) / (img.max() - img.min() + 1e-8)

        img_np = img.permute(1, 2, 0).numpy()  # [H, W, C]
        if img_np.shape[2] == 1:
            # imshow rejects (H, W, 1); show single-channel inputs as grayscale.
            img_np = img_np[:, :, 0]
        hm_np = heatmaps[i, 0].numpy() if heatmaps.dim() == 4 else heatmaps[i].numpy()

        fig = None
        try:
            fig = plt.figure(figsize=(12, 6))
            plt.subplot(1, 2, 1)
            plt.imshow(img_np, cmap='gray' if img_np.ndim == 2 else None)
            plt.subplot(1, 2, 2)
            plt.imshow(hm_np, cmap='inferno')
            plt.savefig(os.path.join(save_dir, f'sample_{i}.png'))
        except Exception as e:
            print(f"Visualization failed: {str(e)}")
        finally:
            # Always close, otherwise figures accumulate when savefig/imshow fails.
            if fig is not None:
                plt.close(fig)

def save_raw_frame(video_path, frame_index=0, output_img_path="raw_frame.png"):
    """
    Save a single video frame unmodified (no markers or text).

    :param video_path: path to the video file.
    :param frame_index: 0-based index of the frame to save (default 0).
    :param output_img_path: output image path; the format follows its extension.
    :raises ValueError: if the video cannot be opened or the frame cannot be read.
    :raises OSError: if OpenCV fails to write ``output_img_path``.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError("无法打开视频文件")
    try:
        # NOTE(review): frame-accurate seeking via CAP_PROP_POS_FRAMES depends on
        # the codec/backend — verify for the videos used here.
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
        ret, frame = cap.read()
    finally:
        cap.release()
    if not ret:
        raise ValueError("无法读取指定帧")

    # cv2.imwrite reports failure via its return value rather than raising.
    if not cv2.imwrite(output_img_path, frame):
        raise OSError(f"Failed to write image to {output_img_path}")
    print(f"原始帧已保存到: {output_img_path}")

def load_GLC_config(args=None, cfg_file="configs/MVIT_B_16x4_CONV.yaml", output_dir="checkpoints/",
                    return_model=False):
    """
    Build the GLC (SlowFast) config and model, and load the latest checkpoint into the model.

    Checkpoints are looked up in ``<output_dir>/checkpoints``. Every entry whose
    name contains "checkpoint" is considered, and the lexicographically last one
    is loaded. That equals the latest epoch when names are zero-padded.

    :param args: unused; kept for backward compatibility with existing callers.
    :param cfg_file: YAML config merged into the SlowFast defaults (relative to CWD).
    :param output_dir: value assigned to ``cfg.OUTPUT_DIR``.
    :param return_model: when True, return ``(cfg, model)`` instead of just ``cfg``.
    :return: ``cfg``, or ``(cfg, model)`` with the model in eval mode if ``return_model``.
    :raises FileNotFoundError: if no checkpoint is found.
    """
    from slowfast.config.defaults import get_cfg
    from slowfast.models import build_model
    from slowfast.utils.env import checkpoint_pathmgr as pathmgr
    from slowfast.utils.checkpoint import load_checkpoint

    # Reference command-line overrides used originally:
    # TRAIN.BATCH_SIZE 16 TEST.ENABLE False NUM_GPUS 8 TRAIN.CHECKPOINT_FILE_PATH checkpoints/MViT_Ego4D_ckpt.pyth
    cfg = get_cfg()
    cfg.merge_from_file(cfg_file)
    cfg.OUTPUT_DIR = output_dir

    model = build_model(cfg)

    ckpt_dir = os.path.join(cfg.OUTPUT_DIR, "checkpoints")
    names = pathmgr.ls(ckpt_dir) if pathmgr.exists(ckpt_dir) else []
    names = [n for n in names if "checkpoint" in n]
    if not names:
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        raise FileNotFoundError("No checkpoints found in '{}'.".format(ckpt_dir))
    last_checkpoint = os.path.join(ckpt_dir, sorted(names)[-1])
    load_checkpoint(last_checkpoint, model, cfg.NUM_GPUS > 1)

    model.eval()

    # The former forward pass here referenced an undefined `img` and always raised
    # NameError. Running the model is the caller's job (use return_model=True).
    if return_model:
        return cfg, model
    return cfg


def main():
    """
    Script entry point: produce debug visualizations for one frame of one gaze episode.

    For the configured task/episode/stage and frame it writes:
      * ``raw_frame.png`` — the untouched video frame;
      * ``output_visualization.png`` — the frame with its projected gaze point;
      * a Gaussian heatmap image for the same frame (named by ``get_heatmap``).

    The preprocessing pipeline that produced the inputs is kept below as
    commented-out reference.
    """
    # --- Preprocessing pipeline (already run; kept for reference) ---
    # RENDER = True
    # path_to_rlbench = '/fs-computility/efm/shared/datasets/Official_Manipulation_Data/sim/colosseum/rlbench/future'
    #
    # save_path = f'{path_to_rlbench}/clips.gaze'
    # untracked_csv = None
    #
    # # pre-save rendered dynamic images/videos
    # root_train = "/fs-computility/efm/shared/datasets/Official_Manipulation_Data/sim/colosseum/rlbench/train"
    # dst_gaze = f'{path_to_rlbench}/full_scale.gaze'
    # # os.makedirs(dst_gaze, exist_ok=True)  # DONE
    # # if RENDER:
    # #     batch_convert_render(root_train, dst_gaze)
    # # else:
    # #     batch_convert(root_train, dst_gaze)  # DONE
    #
    # # clip_videos(source_path=dst_gaze, save_path=save_path, untrack_csv=untracked_csv)  # DONE
    #
    # output_dir = f'{path_to_rlbench}/gaze'  # where the per-episode CSVs are saved
    # os.makedirs(output_dir, exist_ok=True)  # DONE
    # if RENDER:
    #     batch_process_poses_parallel(root_train, output_dir, dst_gaze)
    # else:
    #     batch_process_poses(root_train, output_dir, dst_gaze)  # DONE

    # --- Visualization of a single frame ---
    root = "/fs-computility/efm/shared/datasets/Official_Manipulation_Data/sim/colosseum/rlbench/future"
    task = "insert_onto_square_peg"
    episode = "0"
    stage = "st2"
    frame_idx = 158

    video_path = f"{root}/full_scale.gaze/{task}/ep_{episode}_{stage}.mp4"
    csv_path = f"{root}/gaze/{task}/ep_{episode}_{stage}.csv"

    save_raw_frame(video_path, frame_index=frame_idx, output_img_path="raw_frame.png")
    visualize_projection(video_path, csv_path, frame_index=frame_idx,
                         output_img_path="output_visualization.png")
    get_heatmap(root, task, episode, stage, frame_index=frame_idx)

    # --- Alternative: full tracking video for another episode ---
    # visualize_task = "place_shape_in_shape_sorter"
    # visualize_epoch = "12"
    # visualize_root = "/fs-computility/efm/shared/datasets/Official_Manipulation_Data/sim/colosseum/rlbench/future"
    # video_path = f"{visualize_root}/full_scale.gaze/{visualize_task}/ep_{visualize_epoch}_st2.mp4"
    # csv_path = f"{visualize_root}/gaze/{visualize_task}/ep_{visualize_epoch}_st2.csv"
    # visualize_projection_video(video_path, csv_path,
    #     output_video_path=f"{visualize_task}_ep{visualize_epoch}_st2.mp4", fps=30  # optional
    # )
    # get_rlbench_frame_label(data_path=path_to_rlbench, save_path=f'{path_to_rlbench}/gaze_frame_label')
    # generate_train_csv("/fs-computility/efm/shared/datasets/Official_Manipulation_Data/sim/colosseum/rlbench/future/clips.gaze", output_file="data/train_rlbench_gaze.csv")


if __name__ == '__main__':
    main()


