import math
import json
import re
import ast
import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
import numpy as np
from typing import List, Tuple, Union, Optional
import cv2


def get_seq(video_info, num_fr=16):
    """
    Split the first video of ``video_info`` into consecutive clips of ``num_fr`` frames.

    Args:
        video_info: 4-tuple ``(image_inputs, video_inputs, real_idx, video_kwargs)``.
            Only ``video_inputs[0]`` is split, along its first dimension.
        num_fr: number of frames per clip; the last clip may be shorter.

    Returns:
        list of 4-tuples with the same layout as ``video_info``, where the second
        element is a one-item list holding the clip. ``image_inputs``, ``real_idx``
        and ``video_kwargs`` are passed through unchanged (same objects).
    """
    image_inputs, video_inputs, real_idx, video_kwargs = video_info
    frames = video_inputs[0]
    clip_count = math.ceil(frames.shape[0] / num_fr)
    return [
        (image_inputs, [frames[k * num_fr:(k + 1) * num_fr]], real_idx, video_kwargs)
        for k in range(clip_count)
    ]


def parse_json_loose(raw):
    """
    Extract a JSON value from a loosely formatted string.

    The input may:
    - be a quoted Python/JSON string literal wrapping the real payload
      (surrounding whitespace allowed); it is unescaped first;
    - include Markdown code fences such as ```json ... ```;
    - contain extra text before/after the JSON object.

    Strategies are tried in order: a direct parse of the (unescaped, stripped)
    text, each fenced code block, then the span from the first '{' to the last '}'.

    Args:
        raw: the string to parse.

    Returns:
        The parsed Python object (typically a dict).

    Raises:
        ValueError: if no strategy yields valid JSON.
    """
    s = raw.strip()

    # 1. Unwrap a quoted string literal. The result is kept only if it is itself
    #    a string: e.g. '"a", "b"' evaluates to a tuple and must be ignored.
    if len(s) >= 2 and s[0] == s[-1] and s[0] in "'\"":
        try:
            unescaped = ast.literal_eval(s)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            unescaped = None  # not a valid literal; keep s as-is
        if isinstance(unescaped, str):
            s = unescaped.strip()

    fence_re = re.compile(r"```(?:\w+)?\s*(.*?)```", re.DOTALL)

    def _candidates():
        # 2. The whole text.
        yield s
        # 3. Each fenced code block ```[lang]? ... ```.
        for block in fence_re.findall(s):
            yield block.strip()
        # 4. Fallback: from the first '{' to the last '}'.
        start = s.find('{')
        end = s.rfind('}')
        if start != -1 and end > start:
            yield s[start:end + 1].strip()

    for candidate in _candidates():
        try:
            return json.loads(candidate)
        except (ValueError, RecursionError):  # JSONDecodeError subclasses ValueError
            continue

    raise ValueError("No valid JSON object found in the provided string.")




def process_video_with_bbox(video_tensor, bbox, original_size=None):
    """
    Crop a video tensor to a bounding box and resize the crop back to a target size.

    Args:
        video_tensor (torch.Tensor): input of shape (..., H, W), e.g. (T, C, H, W)
            or a single (C, H, W) frame. Only the last two dims are cropped and
            resized; leading dims must be acceptable to torchvision's ``resize``.
        bbox (sequence): pixel box (x_min, y_min, x_max, y_max). Values are
            truncated with ``int()`` and clipped to the frame; x_max / y_max are
            exclusive slice ends.
        original_size (tuple, optional): (height, width) to resize the crop to.
            Defaults to the input's (H, W).

    Returns:
        torch.Tensor: the crop resized with antialiased bicubic interpolation,
        cast to float32. No value normalization is applied.

    Raises:
        ValueError: if the input has fewer than 2 dims or the clipped box is empty.
    """
    if video_tensor.dim() < 2:
        raise ValueError("video_tensor must have at least 2 dimensions (..., H, W)")
    H, W = video_tensor.shape[-2:]

    # Default target size: the input's spatial size.
    if original_size is None:
        original_size = (H, W)

    x_min, y_min, x_max, y_max = bbox

    # Clip the box to the frame.
    x_min = max(0, int(x_min))
    y_min = max(0, int(y_min))
    x_max = min(W, int(x_max))
    y_max = min(H, int(y_max))

    if x_max <= x_min or y_max <= y_min:
        raise ValueError("无效的边界框坐标")

    cropped_video = video_tensor[..., y_min:y_max, x_min:x_max]

    resized_video = transforms.functional.resize(
        cropped_video,
        original_size,
        interpolation=InterpolationMode.BICUBIC,
        antialias=True
    ).float()

    return resized_video


def _bbox_to_pixel_coords(bbox, width, height, normalized=None):
    """Convert one [x1, y1, x2, y2] box to integer pixel coords clipped to the frame."""
    x1, y1, x2, y2 = (float(v) for v in bbox)
    if normalized is None:
        # Heuristic: a box whose coordinates all lie in [0, 1] is treated as normalized.
        normalized = all(0.0 <= v <= 1.0 for v in (x1, y1, x2, y2))
    if normalized:
        x1, x2 = x1 * width, x2 * width
        y1, y2 = y1 * height, y2 * height
    x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
    x1 = max(0, min(x1, width - 1))
    y1 = max(0, min(y1, height - 1))
    x2 = max(0, min(x2, width - 1))
    y2 = max(0, min(y2, height - 1))
    return x1, y1, x2, y2


def _draw_bbox_label(canvas, text, left, top, color, font_scale, font_thickness):
    """Draw ``text`` just above (left, top), or just below it when there is no room."""
    font = cv2.FONT_HERSHEY_SIMPLEX
    (_, text_h), baseline = cv2.getTextSize(text, font, font_scale, font_thickness)
    # putText's origin is the bottom-left corner of the text.
    if top - baseline - text_h >= 0:
        y = top - baseline
    else:
        y = top + text_h + baseline
    cv2.putText(canvas, text, (left, y), font, font_scale, color,
                font_thickness, cv2.LINE_AA)


def draw_bboxes_on_video_tensor(
    video_tensor: torch.Tensor,
    bboxes: List[List[float]],
    frame_indices: Optional[List[int]] = None,
    colors: Optional[List[Tuple[int, int, int]]] = None,
    thickness: int = 2,
    labels: Optional[List[str]] = None,
    font_scale: float = 0.5,
    font_thickness: int = 1,
    *,
    normalized: Optional[bool] = None,
) -> torch.Tensor:
    """
    Draw bounding boxes (and optional labels) onto a copy of a video tensor.

    Args:
        video_tensor: tensor of shape [T, C, H, W]. If its max is <= 1.0 it is
            treated as [0, 1] data and scaled by 255. The input is not modified.
        bboxes: boxes as [x1, y1, x2, y2], either pixel coordinates or
            normalized 0-1 coordinates (see ``normalized``).
        frame_indices: frames to draw on; None means all frames. Indices >= T
            are skipped.
        colors: per-box colors, cycled if shorter than ``bboxes``. They are passed
            to OpenCV as-is, i.e. in the tensor's channel order (B, G, R only if
            the tensor channels are BGR). None or empty uses a default palette.
        thickness: rectangle line thickness.
        labels: optional per-box text drawn near each box's top-left corner;
            missing or empty entries are skipped.
        font_scale: label font scale.
        font_thickness: label stroke thickness.
        normalized: True/False forces coordinate interpretation; None (default)
            treats a box as normalized when all its coordinates are in [0, 1].

    Returns:
        uint8 tensor of shape [T, C, H, W] with the boxes drawn. With C >= 3 the
        boxes are drawn on the first three channels; otherwise on channel 0 in
        grayscale. Other channels are left untouched.
    """
    T, C, H, W = video_tensor.shape

    # Work on a copy so the caller's tensor is untouched.
    video_out = video_tensor.clone()
    if video_out.numel() == 0:
        return video_out.to(torch.uint8)

    # Heuristic: data with max <= 1 is assumed to be in [0, 1].
    if video_out.max() <= 1.0:
        video_out = video_out * 255.0
    # Clamp first: out-of-range values would otherwise wrap in the uint8 cast.
    video_out = video_out.clamp(0, 255).to(torch.uint8)

    if len(bboxes) == 0:
        return video_out

    # Default palette (BGR order).
    default_colors = [
        (0, 255, 0),    # green
        (255, 0, 0),    # blue
        (0, 0, 255),    # red
        (255, 255, 0),  # cyan
        (255, 0, 255),  # magenta
        (0, 255, 255),  # yellow
        (128, 0, 128),  # purple
        (0, 165, 255),  # orange
    ]
    if colors is None or len(colors) == 0:
        colors = [default_colors[i % len(default_colors)] for i in range(len(bboxes))]

    if frame_indices is None:
        frame_indices = range(T)

    # Box geometry is the same for every frame, so compute it once.
    pixel_boxes = [_bbox_to_pixel_coords(b, W, H, normalized) for b in bboxes]

    for frame_idx in frame_indices:
        if frame_idx >= T:
            continue

        frame = video_out[frame_idx]  # [C, H, W]

        # OpenCV needs a C-contiguous HWC array; a permuted view is not.
        if C >= 3:
            canvas = np.ascontiguousarray(frame[:3].permute(1, 2, 0).cpu().numpy())
        else:
            gray = np.ascontiguousarray(frame[0].cpu().numpy())
            canvas = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)

        for i, (x1, y1, x2, y2) in enumerate(pixel_boxes):
            color = colors[i % len(colors)]
            cv2.rectangle(canvas, (x1, y1), (x2, y2), color, thickness)
            if labels is not None and i < len(labels) and labels[i]:
                _draw_bbox_label(canvas, str(labels[i]), min(x1, x2), min(y1, y2),
                                 color, font_scale, font_thickness)

        # Write back only the channels that were drawn on.
        if C >= 3:
            video_out[frame_idx, :3] = torch.from_numpy(canvas).permute(2, 0, 1)
        else:
            gray_out = cv2.cvtColor(canvas, cv2.COLOR_BGR2GRAY)
            video_out[frame_idx, 0] = torch.from_numpy(gray_out)

    return video_out