import time
import os
import json
from datetime import datetime
import random

# Module-wide debug switch: when True, the code in this file prints extra
# timing and diagnostic information.
DEBUG = False
import torch
import torch.nn as nn
from typing import Tuple, Dict, List, Union, Optional
from PIL import Image
import numpy as np
from torch.cuda.amp import autocast
import matplotlib.pyplot as plt

from decord import VideoReader, cpu
# from .Process_Utils.process_video import adaptive_cutscene_detection
from .Process_Utils.process_messages import create_analysis_messages, videomme_preprocess, extract_score, lvb_preprocess
from .Process_Utils.sample_logic import generate_dynamic_samples, sample_logic
from sklearn.mixture import GaussianMixture
from .base_models.clip_model import (
    initialize_clip_model,
    encode_text_clip,
    encode_images_clip,
    compute_clip_similarity
)


class H_VPModel(nn.Module):
    def __init__(
            self,
            vl_model_name: str = "Qwen/Qwen2.5-VL-7B-Instruct",
            device_map: str = "auto",
            clip_model_name: str = "EVA02-L-14",
            attn_implementation=None,
            sub_batch: int = 24,
            min_pixels: int = 128 * 28 * 28,
            max_pixels: int = 768 * 28 * 28,
    ):
        """Pick a vision-language backend by name, load it, then load CLIP.

        The backend is selected by substring match on ``vl_model_name``, tested
        in this order: ``Qwen2.5``, ``InternVL``, ``llava-onevision``
        (case-insensitive), ``VILA``.  When nothing matches, no backend is
        loaded and ``model_type`` / ``vl_model_set`` are left unset.

        Args:
            vl_model_name: Name of the vision-language checkpoint.
            device_map: Forwarded to the backend initializer.
            clip_model_name: Name handed to the CLIP initializer.
            attn_implementation: Stored on the instance in the VILA branch
                only; the other branches ignore it.
            sub_batch: Stored as ``self.sub_batch``.
            min_pixels: Stored as ``self.min_pixels``.
            max_pixels: Stored as ``self.max_pixels``.
        """
        super().__init__()

        # Settings that do not depend on which backend is chosen.
        self.sub_batch = sub_batch
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.max_iteration = 6
        # Frame-budget cap; stays None unless the chosen backend sets one below.
        self.max_budget = None

        # Arguments shared by every backend initializer.
        loader_args = {"model_name": vl_model_name, "device_map": device_map}

        if 'Qwen2.5' in vl_model_name:
            from base_models.Qwen2_5VL import initialize_qwen_model
            self.model_type = "Qwen2.5"
            self.vl_model_set = initialize_qwen_model(**loader_args)
        elif 'InternVL' in vl_model_name:
            from base_models.InternVL3 import initialize_intern_model
            self.model_type = "InternVL"
            # The 26B and 38B checkpoints are loaded in 8-bit automatically.
            wants_8bit = any(tag in vl_model_name for tag in ("26B", "38B"))
            self.vl_model_set = initialize_intern_model(load_in_8bit=wants_8bit, **loader_args)
            self.max_budget = 64
        elif 'llava-onevision' in vl_model_name.lower():
            from base_models.Llava_OV import initialize_llava_ov_model
            self.model_type = "LLaVA-OV"
            self.vl_model_set = initialize_llava_ov_model(**loader_args)
            self.max_budget = 32
        elif 'VILA' in vl_model_name:
            from base_models.VILA1_5 import initialize_vila_model
            self.model_type = "VILA-1.5"
            self.attn_implementation = attn_implementation
            self.vl_model_set = initialize_vila_model(**loader_args)
            self.max_budget = 8

        # CLIP is loaded after the vision-language backend.
        self.clip_model_name = clip_model_name
        self._initialize_clip_model()

        # Similarity / text-feature cache state.
        self.save_similarities = False
        self.text_features = None
        self.N_candidate = 24

    def _initialize_clip_model(self, model_name: Optional[str] = None):
        """Load (or reload) the CLIP model and its companion objects.

        Args:
            model_name: Optional CLIP model name.  When truthy it replaces
                ``self.clip_model_name`` before loading.

        Sets ``self.clip_model``, ``self.clip_preprocess_or_processor``,
        ``self.clip_tokenizer`` and ``self.clip_model_type`` from the 4-tuple
        returned by ``initialize_clip_model``.
        """
        if model_name:
            self.clip_model_name = model_name
        # Use the same CUDA-else-CPU rule as the ``device`` property instead of
        # hard-coding "cuda", so loading does not fail on CPU-only machines.
        device_str = "cuda" if torch.cuda.is_available() else "cpu"
        model_info = initialize_clip_model(self.clip_model_name, device_str)
        self.clip_model, self.clip_preprocess_or_processor, self.clip_tokenizer, self.clip_model_type = model_info
    def eval(self):
        self.vl_model_set[0].eval()
        if hasattr(self, 'clip_model'):
            self.clip_model.eval()
        return self

    @property
    def device(self):
        # return self.vl_model_set[0].device
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _batch_inference(self, model_set, input_messages: list, **kwargs):
        if self.model_type == "Qwen2.5":
            from base_models.Qwen2_5VL import qwen_batch_inference, qwen_video_inference
            if (len(input_messages) == 1 and
                    any(content.get('type') == 'video'
                        for msg in input_messages[0]
                        for content in msg.get('content', [])
                        if msg.get('role') == 'user')):
                return qwen_video_inference(
                    self.vl_model_set[0],
                    self.vl_model_set[1],
                    input_messages[0],
                    fps=kwargs.get('video_fps', 1.0),
                    device=self.device,
                    **kwargs
                )
            else:
                return qwen_batch_inference(
                    self.vl_model_set[0],
                    self.vl_model_set[1],
                    input_messages,
                    device=self.device,
                    **kwargs
                )
        elif self.model_type == "InternVL":
            from base_models.InternVL3 import intern_batch_inference
            return intern_batch_inference(
                self.vl_model_set[0],
                self.vl_model_set[1],
                input_messages,
                **kwargs
            )
        elif self.model_type == "LLaVA-OV":
            from base_models.Llava_OV import llava_ov_batch_inference
            responses = llava_ov_batch_inference(
                self.vl_model_set[0],
                self.vl_model_set[1],
                input_messages,
                **kwargs
            )
            return responses
        elif self.model_type == "VILA-1.5":
            from base_models.VILA1_5 import vila_batch_inference
            responses = vila_batch_inference(
                self.vl_model_set[0],
                self.vl_model_set[1],
                input_messages,
                **kwargs
            )
            if responses and isinstance(responses[0], list):
                responses = [r[0] if isinstance(r, list) else r for r in responses]
            return responses

        else:
            raise ValueError(f"Unsupported model type: {self.model_type}")

    @torch.inference_mode()
    def forward(
            self,
            messages: list[dict] | list[list[dict]],
            videos=None,
            min_frame_distance=12,
            hisa_weights=None,
            **generation_kwargs
    ):
        """Answer video questions after query-guided frame selection.

        For every (video, conversation) pair: score frames against the question
        with CLIP, locate relevant segments, draw an initial sample, refine it
        with ``iteration_algorithm`` and finally query the vision-language
        model using only the selected frames.

        Args:
            messages: One conversation (list of message dicts) or a batch of
                conversations.  The video is read from
                ``message[1]['content'][0]['video']`` and the question from
                ``message[1]['content'][1]['text']``.
            videos: Optional list of video paths aligned with ``messages``.
                Any non-list value makes the paths be taken from ``messages``.
            min_frame_distance: Upper bound of the per-video minimum frame
                distance, which is 0.1% of the frame count clipped to
                ``[5, min_frame_distance]``.
            hisa_weights: Passed through to ``iteration_algorithm``.
            **generation_kwargs: Passed to ``iteration_algorithm`` and to the
                final inference call.

        Returns:
            The responses from ``_batch_inference``; for Qwen2.5 a list with
            one entry per processed conversation.
        """
        self.clip_model.to(self.device)
        if isinstance(messages[0], dict):
            messages = [messages]

        if videos is None or not isinstance(videos, list):
            videos = []
            for m in messages:
                try:
                    videos.append(m[1]['content'][0]['video'])
                except (IndexError, KeyError, TypeError):
                    # Best effort: report the malformed conversation and go on.
                    print(f"Pre-check: video analysis failed, message: {m}")

        start_time = time.time()
        bs = len(messages)

        final_messages_batch = []
        for video_idx, (v, message) in enumerate(zip(videos, messages)):
            full_text = message[1]['content'][1]['text']
            query_text = videomme_preprocess(full_text)

            VR = VideoReader(v, ctx=cpu(0))
            total_frames = len(VR)
            actual_fps = VR.get_avg_fps()
            video_fps = int(actual_fps)
            video_duration = total_frames / actual_fps

            # Keep the per-video value in its own name: overwriting the
            # argument would make one video's clipped value the upper bound
            # for every later video in the batch.
            frame_gap = int(np.clip(total_frames * 0.001, 5, min_frame_distance))
            adaptive_budget = self.calculate_adaptive_budget(video_duration, self.max_budget)

            # Each video has its own question, so the text embedding cached for
            # the previous video must not be reused.
            self.text_features = None
            similarities, raw_sim, frame_indices, norm_params = self.compute_video_query_similarity(
                video_path=v,
                query=query_text,
                fps=1.0,
                return_frame_indices=True,
            )

            found = self.find_relevant_segments(
                similarities,
                frame_indices,
                threshold='adaptive',
                min_segment_frames=2,
                video_fps=video_fps,
                merge_gap_seconds=3.0,
                gmm_coefficient=0.6,
            )
            # Accept both a (segments, threshold) tuple and a bare segment list.
            segments = found[0] if isinstance(found, tuple) else found

            if segments:
                event_splits = [[seg['start_frame'], seg['end_frame']] for seg in segments]
                if DEBUG:
                    print(f"Found {len(event_splits)} relevant segments")
            else:
                event_splits = []

            sampling_start = time.time()
            sampled_frames, sampled_frames_indices = self.adaptive_initial_sampling(
                VR, event_splits,
                use_sampling=True,
                max_frames_per_event=5,
                min_event_length=10,
                total_budget=adaptive_budget,
                clip_similarities=similarities,
                clip_frame_indices=frame_indices
            )

            if DEBUG:
                sampling_time = time.time() - sampling_start
                print(f"\n[初始采样耗时] {sampling_time:.2f}秒 (采样{len(sampled_frames)}帧)")

            # The image-to-text threshold is the one stored alongside the
            # normalization parameters; it is absent (None) only when no
            # similarity values were produced.
            # NOTE(review): assumes norm_params is a dict - confirm against
            # _normalize_full_video.
            suggested_threshold = norm_params.get('suggested_i2t_threshold')

            final_frames, final_frames_indices = self.iteration_algorithm(
                sampled_frames,
                sampled_frames_indices,
                VR,
                query=query_text,
                generation_kwargs=generation_kwargs,
                budget=adaptive_budget,
                min_frame_distance=frame_gap,
                video_fps=video_fps,
                i2t_norm_params=norm_params,
                hisa_weights=hisa_weights,
                clip_i2t_threshold=suggested_threshold,
                clip_i2i_threshold=0.8,
                clip_similarities=similarities,
                raw_clip_similarities=raw_sim,
                clip_frame_indices=frame_indices
            )

            # NOTE(review): the first turn carries a system-style prompt under
            # the "assistant" role; kept as-is to avoid changing model output.
            final_message = [
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": "You are a helpful assistant"}]
                },
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "video",
                            "video": final_frames,
                            "max_pixels": self.max_pixels,
                            "min_pixels": self.min_pixels,
                        },
                        {
                            "type": "text",
                            "text": full_text
                        }
                    ],
                }
            ]
            final_messages_batch.append(final_message)

            if DEBUG:
                print(f'Video {video_idx + 1}/{bs} processing complete')

        if DEBUG:
            t3 = time.time()
        if self.model_type == "Qwen2.5":
            # Qwen conversations are answered one at a time.
            response = []
            for single_message in final_messages_batch:
                single_response = self._batch_inference(
                    self.vl_model_set,
                    [single_message],
                    **generation_kwargs
                )
                response.append(single_response[0])
        else:
            response = self._batch_inference(
                self.vl_model_set,
                final_messages_batch,
                **generation_kwargs
            )
        if DEBUG:
            print('Final QA batch inference done, time cost: ', time.time() - t3)
            print('Total processing time: ', time.time() - start_time)

        return response

    def _process_video_for_clip(
            self,
            video: str,
            text_features: torch.Tensor,
            fps: float,
            batch_size: int,
            query: Union[str, List[str]]
    ):
        """Score sampled frames of one video against encoded query text with CLIP.

        Frames are taken every ``int(video_fps / fps)`` frames (at least every
        frame) and the last frame is always included.  Image features are
        cached on disk as ``.npz`` and reused when the cached frame indices
        equal the expected ones; a stale cache file is deleted and rebuilt.

        Args:
            video: Path of the video file.
            text_features: Encoded query text to compare the frames against.
            fps: Target sampling rate in frames per second.
            batch_size: Number of frames encoded per CLIP call.
            query: Original query; only used to decide whether the leading
                query axis of the result is squeezed away.

        Returns:
            Tuple ``(similarities, similarities_raw, frame_indices, stats,
            norm_params)``: normalized similarities as a NumPy array, the
            un-normalized similarities from ``compute_clip_similarity``, the
            sampled frame indices, a dict of timing/memory counters (totals
            are only filled when ``DEBUG`` is on and features were recomputed)
            and the parameters returned by ``_normalize_full_video`` extended
            with a suggested threshold and percentiles.
        """
        if hasattr(self.clip_model, 'parameters'):
            model_device = next(self.clip_model.parameters()).device
            if model_device.type != 'cuda' and torch.cuda.is_available():
                self.clip_model = self.clip_model.cuda()

        stats = {
            'batch_times': [],
            'batch_gpu_memory': [],
            'total_time': 0,
            'peak_gpu_memory': 0,
            'num_batches': 0
        }
        initial_memory = 0
        start_time = time.time()
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
            initial_memory = torch.cuda.memory_allocated() / 1024 ** 3  # GB

        VR = VideoReader(video, ctx=cpu(0))
        total_frames = len(VR)
        video_fps = VR.get_avg_fps()

        # A requested rate above the native rate would give a step of 0, which
        # range() rejects; fall back to sampling every frame.
        sample_interval = max(1, int(video_fps / fps))
        expected_frame_indices = list(range(0, total_frames, sample_interval))
        if len(expected_frame_indices) == 0 or expected_frame_indices[-1] != total_frames - 1:
            expected_frame_indices.append(total_frames - 1)

        # Strip only the final extension so "a.v1.mp4" and "a.v2.mp4" do not
        # end up sharing one cache file.
        video_name = os.path.splitext(os.path.basename(video))[0]
        cache_dir = ''
        if cache_dir:
            # os.makedirs('') raises, so only create a directory when one is set;
            # with an empty value the cache file lands in the working directory.
            os.makedirs(cache_dir, exist_ok=True)
        cache_path = os.path.join(cache_dir, f"{video_name}_frames{total_frames}_fps{fps}.npz")

        cache_valid = False
        if os.path.exists(cache_path):
            # The archive holds plain numeric arrays, so pickle support is not
            # needed; the context manager closes it before a possible delete.
            with np.load(cache_path, allow_pickle=False) as data:
                cached_frame_indices = data['frame_indices'].tolist()
                cache_valid = cached_frame_indices == expected_frame_indices
                cached_features = data['image_features'] if cache_valid else None

            if cache_valid:
                frame_indices = cached_frame_indices
                all_image_features = torch.from_numpy(cached_features).to(self.device)
                with autocast():
                    similarities_raw = compute_clip_similarity(
                        text_features, all_image_features,
                        model=self.clip_model,
                        model_type=self.clip_model_type,
                        normalization_method="none"
                    )

                if DEBUG:
                    print(f"使用有效缓存: {cache_path}")
            else:
                if DEBUG:
                    print(f"  期望indices: {len(expected_frame_indices)}个, 最后={expected_frame_indices[-1]}")
                    print(
                        f"  缓存indices: {len(cached_frame_indices)}个, 最后={cached_frame_indices[-1] if cached_frame_indices else 'N/A'}")
                # Drop the stale cache; it is rebuilt below.
                os.remove(cache_path)

        # No usable cache: encode the frames batch by batch.
        if not cache_valid:
            frame_indices = expected_frame_indices
            all_image_features = []  # kept so the features can be cached
            all_similarities = []

            num_batches = (len(frame_indices) + batch_size - 1) // batch_size
            stats['num_batches'] = num_batches

            # These model families take a ``preprocess`` callable; the others
            # take a ``processor`` object.
            uses_preprocess = self.clip_model_type in ["openai", "open_clip", "eva_clip", "long_clip"]

            for i in range(0, len(frame_indices), batch_size):
                batch_indices = frame_indices[i:i + batch_size]
                frames = VR.get_batch(batch_indices).asnumpy()
                pil_images = [Image.fromarray(frame) for frame in frames]

                with autocast():
                    image_features = encode_images_clip(
                        self.clip_model, pil_images,
                        preprocess=self.clip_preprocess_or_processor if uses_preprocess else None,
                        processor=None if uses_preprocess else self.clip_preprocess_or_processor,
                        model_type=self.clip_model_type,
                        device=self.device
                    )

                    all_image_features.append(image_features.cpu())

                    batch_similarities = compute_clip_similarity(
                        text_features, image_features,
                        model=self.clip_model,
                        model_type=self.clip_model_type,
                        normalization_method="none"
                    )

                all_similarities.append(batch_similarities)

            similarities_raw = torch.cat(all_similarities, dim=1)

            # Persist the image features for later calls.
            all_image_features = torch.cat(all_image_features, dim=0)
            np.savez(cache_path,
                     image_features=all_image_features.cpu().numpy(),
                     frame_indices=np.array(list(frame_indices), dtype=np.int32))
            if DEBUG:
                print(f"保存视频特征到缓存: {cache_path}")

            if DEBUG:
                stats['total_time'] = time.time() - start_time
                if torch.cuda.is_available():
                    stats['peak_gpu_memory'] = torch.cuda.max_memory_allocated() / 1024 ** 3 - initial_memory

        # Normalize over the whole video at once.
        similarities, norm_params = self._normalize_full_video(similarities_raw)
        similarities = similarities.cpu().numpy()

        # Whole-video similarity statistics used later for adaptive thresholds.
        if similarities.size > 0:
            sim_percentiles = np.percentile(similarities, [10, 25, 50, 75, 90])

            # NOTE(review): index 1 is the 25th percentile, although the
            # original comment called it the 75th - confirm which is intended.
            suggested_threshold = sim_percentiles[1]

            norm_params['suggested_i2t_threshold'] = float(suggested_threshold)
            norm_params['percentiles'] = {
                'p10': float(sim_percentiles[0]),
                'p25': float(sim_percentiles[1]),
                'p50': float(sim_percentiles[2]),
                'p75': float(sim_percentiles[3]),
                'p90': float(sim_percentiles[4])
            }

        # A single query yields a 1-D array.
        if isinstance(query, str) or len(query) == 1:
            similarities = similarities.squeeze(0)

        return similarities, similarities_raw, frame_indices, stats, norm_params


    @torch.inference_mode()
    def compute_video_query_similarity(
            self,
            video_path: Union[str, List[str], List[List[dict]], VideoReader],
            query: Union[str, List[str], None] = None,
            clip_model_name: Optional[str] = None,
            fps: float = 1.0,
            batch_size: int = 80,
            return_frame_indices: bool = True,
            return_stats: bool = False
    ):
        """Compute CLIP similarity between sampled video frames and a text query.

        The encoded query is cached in ``self.text_features`` and reused only
        while the query stays the same; switching the CLIP model or asking a
        different query re-encodes the text.

        Args:
            video_path: Video to score; it is handed unchanged to
                ``_process_video_for_clip``, which opens it with
                ``VideoReader``.
            query: Query text or list of texts.
            clip_model_name: When given and different from the current CLIP
                model name, the CLIP model is re-initialized first.
            fps: Frame sampling rate.
            batch_size: Number of frames per CLIP batch.
            return_frame_indices: Include frame indices and normalization
                parameters in the result.
            return_stats: Include the stats dict as well; takes precedence
                over ``return_frame_indices``.

        Returns:
            ``(similarities, raw_sim, frame_indices, stats, norm_params)`` when
            ``return_stats`` is set, ``(similarities, raw_sim, frame_indices,
            norm_params)`` when only ``return_frame_indices`` is set, otherwise
            ``(similarities, raw_sim)``.
        """
        if clip_model_name and clip_model_name != self.clip_model_name:
            self._initialize_clip_model(clip_model_name)
            # Text features from the previous CLIP model must not be compared
            # with image features from the new one.
            self.text_features = None

        # Re-encode when nothing is cached or the cache belongs to another
        # query; otherwise a batch of videos with different questions would
        # all be scored against the first question.
        cached_query = getattr(self, '_text_features_query', None)
        if self.text_features is None or query != cached_query:
            # These model families take a tokenizer; the others take a processor.
            uses_tokenizer = self.clip_model_type in ["openai", "open_clip", "eva_clip", "long_clip"]
            with autocast():
                self.text_features = encode_text_clip(
                    self.clip_model, query,
                    tokenizer=self.clip_tokenizer if uses_tokenizer else None,
                    processor=None if uses_tokenizer else self.clip_preprocess_or_processor,
                    model_type=self.clip_model_type,
                    device=self.device
                )
            self._text_features_query = query
        self.text_features = self.text_features.to(self.device)

        similarities, raw_sim, frame_indices, stats, norm_params = self._process_video_for_clip(
            video_path, self.text_features, fps, batch_size, query
        )

        if return_stats:
            return similarities, raw_sim, frame_indices, stats, norm_params
        elif return_frame_indices:
            return similarities, raw_sim, frame_indices, norm_params
        else:
            return similarities, raw_sim

    def find_relevant_segments(self, sims, frame_indices, threshold='adaptive',
                               min_segment_frames=3,  # 这里指的是采样后的帧数
                               video_fps=24.0,
                               merge_gap_seconds=1.0, gmm_coefficient=0.8):
        """
        找到相似度超过阈值的连续区间

        参数:
            similarities: 相似度数组
            frame_indices: 帧索引数组
            threshold: 相似度阈值
            min_segment_frames: 最小片段帧数
            video_fps: 视频帧率
            merge_gap_seconds: 合并间隔小于此秒数的片段

        返回:
            包含片段信息的列表
        """
        segments = []
        similarities = np.power(sims, 1)
        # 自适应阈值计算
        if threshold == 'adaptive':
            # 方案: 高斯混合模型 (适合有明显双峰分布的数据)
            gmm = GaussianMixture(n_components=2, random_state=42)
            gmm.fit(similarities.reshape(-1, 1))
            means = gmm.means_.flatten()
            stds = np.sqrt(gmm.covariances_.flatten())
            # 选择较高均值的组，取其下界
            high_mean_idx = np.argmax(means)
            threshold = means[high_mean_idx] - gmm_coefficient * stds[high_mean_idx]
            threshold = max(threshold, (means[0] + means[1]) / 2)  # 不低于中点

        # 找到所有超过阈值的帧
        above_threshold = similarities >= threshold

        if not np.any(above_threshold):
            return segments

        # 找到连续区间
        current_start = None
        current_frames = []

        for i, (is_above, sim, frame_idx) in enumerate(zip(above_threshold, similarities, frame_indices)):
            if is_above:
                if current_start is None:
                    current_start = i
                    current_frames = [i]
                else:
                    current_frames.append(i)
            else:
                # 当前帧不满足阈值，检查是否需要结束当前片段
                if current_start is not None and len(current_frames) >= min_segment_frames:
                    # 创建片段
                    segment_sims = similarities[current_frames]
                    segments.append({
                        'start_idx': current_start,
                        'end_idx': current_frames[-1],
                        'start_frame': frame_indices[current_start],
                        'end_frame': frame_indices[current_frames[-1]],
                        'start_time': frame_indices[current_start] / video_fps,
                        'end_time': frame_indices[current_frames[-1]] / video_fps,
                        'duration': (frame_indices[current_frames[-1]] - frame_indices[current_start]) / video_fps,
                        'num_frames': len(current_frames),
                        'avg_similarity': segment_sims.mean(),
                        'max_similarity': segment_sims.max(),
                        'min_similarity': segment_sims.min(),
                        'frame_indices': current_frames
                    })
                current_start = None
                current_frames = []

        # 处理最后一个片段
        if current_start is not None and len(current_frames) >= min_segment_frames:
            segment_sims = similarities[current_frames]
            segments.append({
                'start_idx': current_start,
                'end_idx': current_frames[-1],
                'start_frame': frame_indices[current_start],
                'end_frame': frame_indices[current_frames[-1]],
                'start_time': frame_indices[current_start] / video_fps,
                'end_time': frame_indices[current_frames[-1]] / video_fps,
                'duration': (frame_indices[current_frames[-1]] - frame_indices[current_start]) / video_fps,
                'num_frames': len(current_frames),
                'avg_similarity': segment_sims.mean(),
                'max_similarity': segment_sims.max(),
                'min_similarity': segment_sims.min(),
                'frame_indices': current_frames
            })

        # 合并时间间隔很近的片段
        if merge_gap_seconds > 0 and len(segments) > 1:
            merged_segments = []
            current_segment = segments[0]

            for next_segment in segments[1:]:
                gap = next_segment['start_time'] - current_segment['end_time']

                if gap <= merge_gap_seconds:
                    # 合并片段
                    all_frames = current_segment['frame_indices'] + next_segment['frame_indices']
                    all_sims = similarities[all_frames]

                    current_segment = {
                        'start_idx': current_segment['start_idx'],
                        'end_idx': next_segment['end_idx'],
                        'start_frame': current_segment['start_frame'],
                        'end_frame': next_segment['end_frame'],
                        'start_time': current_segment['start_time'],
                        'end_time': next_segment['end_time'],
                        'duration': next_segment['end_time'] - current_segment['start_time'],
                        'num_frames': len(all_frames),
                        'avg_similarity': all_sims.mean(),
                        'max_similarity': all_sims.max(),
                        'min_similarity': all_sims.min(),
                        'frame_indices': all_frames
                    }
                else:
                    merged_segments.append(current_segment)
                    current_segment = next_segment

            merged_segments.append(current_segment)
            segments = merged_segments

        # 按开始时间排序
        segments.sort(key=lambda x: x['start_time'])

        return segments, threshold

    def adaptive_initial_sampling(self, video, event_splits,
                                  use_sampling=True, max_frames_per_event=4,
                                  min_event_length=10,
                                  initial_sampling_ratio=0.5, total_budget=None,
                                  clip_similarities=None, clip_frame_indices=None):  # 新增参数
        """
        在选定的事件区间中进行自适应初始采样
        为GP迭代预留足够的空间

        参数:
            ... (原有参数)
            clip_similarities: CLIP相似度数组（新增）
            clip_frame_indices: CLIP计算时的帧索引（新增）

        返回:
            sampled_frames: 采样的帧数组 (N, H, W, C)
            sampled_frames_indices: 采样的帧索引列表
        """

        # 读取视频
        vr = video
        total_frames = len(vr)

        if not use_sampling:
            # 不采样，返回所有事件区间的所有帧
            all_indices = []
            for event in event_splits:
                start, end = int(event[0]), int(event[1])
                all_indices.extend(range(start, min(end + 1, total_frames)))
            sampled_frames = vr.get_batch(all_indices).asnumpy()
            return sampled_frames, all_indices

        initial_target = max(self.N_candidate, total_budget)

        if DEBUG:
            print(f"\n[初始采样] 采样策略:")
            print(f"  - 总Budget: {total_budget}帧")
            print(f"  - 初始采样目标: {initial_target}帧")
            print(f"  - 事件数: {len(event_splits)}")

        # 3. 准备事件数据
        event_data = []

        # 添加调试信息
        if DEBUG:
            print(f"\n[DEBUG] CLIP数据检查:")
            print(f"  - clip_similarities: {'有' if clip_similarities is not None else '无'}")
            print(f"  - clip_frame_indices: {'有' if clip_frame_indices is not None else '无'}")
        if clip_similarities is not None and DEBUG:
            print(f"  - 相似度范围: [{np.min(clip_similarities):.3f}, {np.max(clip_similarities):.3f}]")

        for event_idx, event in enumerate(event_splits):
            start_frame = int(event[0])
            end_frame = int(event[1])
            length = end_frame - start_frame + 1

            if length > 0 and end_frame < total_frames:
                # 初始化变量
                avg_similarity = 0.0
                event_sims = []  # 移到外面
                clip_count = 0

                # 计算该事件的平均CLIP相似度
                if clip_similarities is not None and clip_frame_indices is not None:
                    for i, frame_idx in enumerate(clip_frame_indices):
                        if start_frame <= frame_idx <= end_frame:
                            event_sims.append(clip_similarities[i])

                    clip_count = len(event_sims)
                    if event_sims:
                        avg_similarity = np.mean(event_sims)
                    # else:
                    #     # 如果事件内没有采样点，这是异常情况
                    #     print(f"[警告] 事件{event_idx + 1} [{start_frame}-{end_frame}] 内没有CLIP采样点!")

                event_data.append({
                    'start': start_frame,
                    'end': end_frame,
                    'length': length,
                    'avg_similarity': avg_similarity,
                    'clip_count': clip_count,
                    'allocated_frames': 1  # 默认分配
                })

        # 过滤掉没有CLIP采样点的事件（如果需要）
        if clip_similarities is not None:
            original_count = len(event_data)
            event_data = [e for e in event_data if e['clip_count'] > 0]
            # if len(event_data) < original_count:
            #     print(f"[过滤] 移除了 {original_count - len(event_data)} 个没有CLIP采样点的事件")

        if not event_data:
            # print("[WARN] 没有有效事件")
            return np.array([]), []

        # 打印事件分布统计
        # print(f"\n[事件统计]:")
        # lengths = [e['length'] for e in event_data]
        # print(f"  - 事件长度分布: 最小={min(lengths)}帧, 最大={max(lengths)}帧, 平均={np.mean(lengths):.1f}帧")
        # single_frame_events = sum(1 for e in event_data if e['length'] == 1)
        # if single_frame_events > 0:
        #     print(f"  - 单帧事件数: {single_frame_events} ({single_frame_events / len(event_data) * 100:.1f}%)")

        # 4. 如果事件数超过初始目标，需要选择性采样
        if len(event_data) > initial_target:
            # print(f"  - 注意：事件数({len(event_data)})超过初始目标({initial_target})，基于CLIP相似度选择事件")

            # 按平均CLIP相似度降序排序
            event_data.sort(key=lambda x: x['avg_similarity'], reverse=True)

            # 选择前N个事件，但至少保证有initial_target个采样点
            selected_events = []
            selected_frames = 0

            for event in event_data:
                # 每个事件至少采1帧
                frames_needed = 1

                if selected_frames + frames_needed <= initial_target:
                    selected_events.append(event)
                    selected_frames += frames_needed
                else:
                    break

            # 如果还有剩余的采样额度，可以给重要事件分配更多帧
            remaining_budget = initial_target - selected_frames
            if remaining_budget > 0:
                # 按事件长度和相似度分配剩余帧数
                for event in selected_events:
                    if remaining_budget <= 0:
                        break

                    # 根据事件长度和相似度计算额外帧数
                    if event['length'] > min_event_length:
                        extra_frames = min(
                            max_frames_per_event - 1,  # 减1因为已经分配了1帧
                            event['length'] // 10,  # 每10帧采1帧
                            remaining_budget
                        )
                        event['allocated_frames'] = 1 + extra_frames
                        remaining_budget -= extra_frames
                    else:
                        event['allocated_frames'] = 1
            else:
                for event in selected_events:
                    event['allocated_frames'] = 1

            # 打印选择结果
            # print(f"  - 选择了 {len(selected_events)} 个事件（按CLIP相似度排序）:")
            # for i, event in enumerate(selected_events[:5]):  # 只显示前5个
            #     print(f"    事件{i + 1}: 相似度={event['avg_similarity']:.3f}, "
            #           f"长度={event['length']}帧, 分配={event['allocated_frames']}帧")
            # if len(selected_events) > 5:
            #     print(f"    ... 和其他 {len(selected_events) - 5} 个事件")

            valid_events = selected_events
        else:
            # 事件数未超过目标，按CLIP相似度循环分配
            valid_events = event_data

            # 按CLIP相似度降序排序
            valid_events.sort(key=lambda x: x['avg_similarity'], reverse=True)

            # 初始化所有事件至少分配1帧
            for event in valid_events:
                event['allocated_frames'] = 1

            # 计算已分配的帧数
            allocated_count = len(valid_events)

            # 循环分配剩余帧数
            event_index = 0
            while allocated_count < initial_target:
                # 获取当前事件
                event = valid_events[event_index]

                # 检查是否还能分配更多帧
                if event['allocated_frames'] < min(event['length'], max_frames_per_event):
                    event['allocated_frames'] += 1
                    allocated_count += 1

                # 移到下一个事件（循环）
                event_index = (event_index + 1) % len(valid_events)

                # 安全检查：如果所有事件都达到上限，退出
                all_maxed = all(
                    e['allocated_frames'] >= min(e['length'], max_frames_per_event)
                    for e in valid_events
                )
                if all_maxed:
                    # print(f"  - 所有事件都已达到采样上限，停止分配")
                    break

            # 打印分配结果
            # print(f"  - 按CLIP相似度循环分配了 {allocated_count} 帧到 {len(valid_events)} 个事件:")
            # for i, event in enumerate(valid_events[:5]):
            #     print(f"    事件{i + 1}: 相似度={event['avg_similarity']:.3f}, "
            #           f"长度={event['length']}帧, 分配={event['allocated_frames']}帧")
            # if len(valid_events) > 5:
            #     print(f"    ... 和其他 {len(valid_events) - 5} 个事件")

        # 5. 执行采样
        sampled_indices = []

        for event_idx, event in enumerate(valid_events):
            start_frame = event['start']
            end_frame = event['end']
            num_frames = event['allocated_frames']  # 使用allocated_frames字段
            event_length = event['length']

            # 特别处理短事件
            if event_length <= min_event_length:
                mid_frame = (start_frame + end_frame) // 2
                if clip_frame_indices is not None:
                    # 找区间内最接近中点的采样点
                    in_range_indices = [idx for idx in clip_frame_indices
                                        if start_frame <= idx <= end_frame]
                    if in_range_indices:
                        closest_idx = min(in_range_indices,
                                          key=lambda x: abs(x - mid_frame))
                        sampled_indices.append(closest_idx)
                else:
                    sampled_indices.append(mid_frame)
                continue

            # 对于正常事件，使用CLIP相似度采样峰值
            if clip_similarities is not None and clip_frame_indices is not None:
                # 找出该事件区间内的所有采样点及其相似度
                event_points = []
                for i, frame_idx in enumerate(clip_frame_indices):
                    if start_frame <= frame_idx <= end_frame:
                        event_points.append((frame_idx, clip_similarities[i]))

                if event_points:
                    # 按相似度降序排序
                    event_points.sort(key=lambda x: x[1], reverse=True)

                    # 取前num_frames个峰值
                    selected_count = min(num_frames, len(event_points))
                    selected_indices = [point[0] for point in event_points[:selected_count]]

                    # 按帧索引排序（保持时间顺序）
                    sampled_indices.extend(sorted(selected_indices))

                    # 记录实际采样情况
                    # if selected_count < num_frames:
                    #     print(f"    [注意] 事件{event_idx + 1}可用采样点不足: "
                    #           f"需要{num_frames}帧，只有{len(event_points)}个采样点")
            else:
                # 没有CLIP相似度时的均匀采样
                if num_frames == 1:
                    mid_frame = (start_frame + end_frame) // 2
                    sampled_indices.append(mid_frame)
                else:
                    step = (event_length - 1) / (num_frames - 1) if num_frames > 1 else 0
                    positions = [start_frame + int(i * step) for i in range(num_frames)]
                    sampled_indices.extend(positions)

        # 去重并排序
        sampled_indices = sorted(list(set(sampled_indices)))

        # 6. 打印采样统计（修改版）
        if DEBUG:
            print(f"\n[初始采样] 结果统计:")
            print(f"  - 目标采样数: {initial_target}")
            print(f"  - 实际采样数: {len(sampled_indices)}")
            print(f"  - GP迭代空间: {total_budget - len(sampled_indices)}帧")

            # 打印每个事件的采样情况
            for i, event in enumerate(valid_events[:10]):  # 只显示前10个事件
                start = event['start']
                end = event['end']
                event_length = event['length']
                allocated_frames = event['allocated_frames']
                avg_sim = event['avg_similarity']

                # 找出实际采样的帧
                event_samples = [idx for idx in sampled_indices if start <= idx <= end]

                if event_length <= min_event_length:
                    event_type = f"(短事件≤{min_event_length}帧)"
                else:
                    event_type = ""

                print(f"  事件{i + 1} [{start}-{end}] (长度={event_length}帧){event_type}: "
                      f"平均相似度={avg_sim:.3f}, "
                      f"目标{allocated_frames}帧，实际{len(event_samples)}帧")

                # 显示采样的帧索引和相似度
                if clip_similarities is not None and clip_frame_indices is not None and event_samples:
                    sample_info = []
                    for sample_idx in event_samples[:5]:  # 最多显示5个
                        if sample_idx in clip_frame_indices:
                            sim_idx = clip_frame_indices.index(sample_idx)
                            sim = clip_similarities[sim_idx]
                            sample_info.append(f"{sample_idx}({sim:.3f})")

                    if sample_info:
                        display_text = ', '.join(sample_info)
                        if len(event_samples) > 5:
                            display_text += f"... ({len(event_samples) - 5}个更多)"
                        print(f"    采样帧: {display_text}")

            # 如果事件太多，显示总结
            if len(valid_events) > 10:
                print(f"  ... 还有 {len(valid_events) - 10} 个事件未显示")

                # 可选：显示一些统计信息
                total_allocated = sum(e['allocated_frames'] for e in valid_events)
                avg_sim_all = np.mean([e['avg_similarity'] for e in valid_events])
                print(f"\n  总体统计:")
                print(f"    - 总事件数: {len(valid_events)}")
                print(f"    - 平均CLIP相似度: {avg_sim_all:.3f}")
                print(f"    - 计划采样总数: {total_allocated}")

        # 获取采样的帧
        sampled_frames = vr.get_batch(sampled_indices).asnumpy()

        return sampled_frames, sampled_indices

    def iteration_algorithm(self,
                            video_frames, video_frames_indices,
                            VR, **kwargs):
        """Refine an initial frame sample with the HISA iterative loop.

        HISA (Hybrid Iterative Sampling Algorithm) repeatedly calls
        ``self.iteration_step`` to propose new frames, scores every newly
        proposed frame with CLIP image-to-text similarity, drops frames whose
        normalised similarity is below a threshold, and finally keeps at most
        ``budget`` frames.

        Args:
            video_frames: Frames of the initial sample. Not read by this
                method; kept for interface compatibility.
            video_frames_indices: Frame indices of the initial sample, either a
                flat sequence or a list of per-event lists (flattened here).
            VR: Video reader. Must support ``len(VR)`` and
                ``VR.get_batch(indices).asnumpy()``.
            **kwargs: Recognised keys:
                budget: Maximum number of frames to return. Required in
                    practice (used in integer division and comparisons).
                clip_similarities: Required. Similarities aligned with
                    ``clip_frame_indices``; written into the dense array as-is.
                clip_frame_indices: Required. Frame index of each similarity.
                raw_clip_similarities: Optional un-normalised similarities
                    (torch tensor of 1 or 2 dims, numpy array or sequence).
                i2t_norm_params: Optional dict with ``'min'`` and ``'max'``.
                    It is UPDATED IN PLACE when a new frame falls outside the
                    current bounds. When omitted, raw similarities are used.
                clip_i2t_threshold: Minimum similarity to keep a frame
                    (default 0.25).
                min_frame_distance: Forwarded to ``iteration_step``
                    (default 6).
                hisa_weights, query, generation_kwargs: Forwarded to
                    ``iteration_step``.
                video_fps: Only used in debug output (default 24.0).

        Returns:
            Tuple ``(frames, indices)``: a list of ``PIL.Image`` objects and
            the matching list of ``int`` frame indices, sorted by index.
        """
        # ---- Parameter extraction ----
        budget = kwargs.get("budget")
        min_frame_distance = kwargs.get("min_frame_distance", 6)
        video_fps = kwargs.get("video_fps", 24.0)
        i2t_norm_params = kwargs.get("i2t_norm_params", None)
        clip_i2t_threshold = kwargs.get("clip_i2t_threshold", 0.25)
        hisa_weights = kwargs.get('hisa_weights', None)
        clip_sims = kwargs['clip_similarities']
        raw_clip_sims = kwargs.get('raw_clip_similarities', None)
        clip_indices = kwargs['clip_frame_indices']

        if DEBUG:
            print(f"\n[HISA算法] 视频信息:")
            print(f"  - 帧率: {video_fps} fps")
            print(f"  - 自适应Budget: {budget} 帧")
            print(f"  - CLIP i2t过滤阈值: {clip_i2t_threshold}")
            print(f"  - HISA权重: {hisa_weights}")
        if i2t_norm_params and DEBUG:
            print(f"  - i2t归一化参数: min={i2t_norm_params['min']:.3f}, max={i2t_norm_params['max']:.3f}")

        T = len(VR)  # total number of frames in the video

        # ---- 1. Flatten the initial sample into one list of indices ----
        # len() instead of truthiness so that a numpy array of indices does not
        # raise "truth value of an array is ambiguous".
        if len(video_frames_indices) > 0 and isinstance(video_frames_indices[0], list):
            ordered_indices = [
                idx for event_indices in video_frames_indices for idx in event_indices
            ]
        else:
            ordered_indices = list(video_frames_indices)

        # NOTE(review): every artefact writer that used these directories has
        # been removed, so they stay empty. The creation is kept only to
        # preserve the existing on-disk side effect; consider dropping it.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        save_dir = f"./hisa_algorithm_output/{timestamp}"
        os.makedirs(os.path.join(save_dir, "iterations"), exist_ok=True)

        # ---- 2. Dense per-frame similarity arrays (NaN = unknown) ----
        dense_query_sim = np.full(T, np.nan)
        raw_dense_query_sim = np.full(T, np.nan)

        # Convert the raw similarities once to a numpy array. The type is
        # checked BEFORE any tensor-only method is used, so numpy arrays and
        # plain sequences are accepted as well as torch tensors.
        raw_values = None
        if raw_clip_sims is not None:
            if torch.is_tensor(raw_clip_sims):
                raw_values = raw_clip_sims.detach().float().cpu().numpy()
            else:
                raw_values = np.asarray(raw_clip_sims, dtype=float)
            if raw_values.ndim == 2:
                # 2-D input is treated as [1, N]: use the first row.
                raw_values = raw_values[0]

        if DEBUG:
            print(f"\n准备HISA dense相似度数组:")
            print(f"  - 已知CLIP点: {len(clip_indices)}")
            print(f"  - 视频总帧数: {T}")

        # Scatter the sparse similarities into the dense arrays.
        for i, idx in enumerate(clip_indices):
            if not 0 <= idx < T:
                continue
            dense_query_sim[idx] = clip_sims[i]
            if raw_values is not None:
                raw_dense_query_sim[idx] = float(raw_values[i])

        # ---- 3. HISA iterations ----
        F_probe = list(ordered_indices)  # F: every frame probed so far
        F_final = []                     # F*_final: frames kept for the output
        dense_analyzed_query_sim = np.full(T, np.nan)  # NaN = not analysed yet

        if DEBUG:
            print(f"初始帧优化后: {len(F_probe)} 帧")
            print(f"\n{'#' * 100}")
            print(f"开始HISA迭代优化")
            print(f"  - 初始帧数: {len(ordered_indices)}")
            print(f"  - 目标Budget: {budget}")
            print(f"  - 过滤阈值: i2t={clip_i2t_threshold}")
            print(f"{'#' * 100}")

        iteration = 1
        while iteration <= self.max_iteration:
            new_sampled_indices, dense_analyzed_query_sim, updated_sampled, early_stop = self.iteration_step(
                probed_indices=F_probe,
                query_sim=dense_query_sim,
                analyze_sim=dense_analyzed_query_sim,
                budget=budget,
                weights=hisa_weights,
                N_candidate=max(budget // self.max_iteration, self.N_candidate),
                VR=VR,
                query=kwargs.get('query', ''),
                generation_kwargs=kwargs.get('generation_kwargs', {}),
                low_score_threshold=2,
                min_frame_distance=min_frame_distance,
                iteration=iteration,
            )
            if early_stop:
                # The results of an early-stopped step are discarded.
                break

            # Only decode when there is something to decode; this avoids
            # asking the reader for an empty batch.
            if len(new_sampled_indices) > 0:
                new_frames = VR.get_batch(new_sampled_indices).asnumpy()
            else:
                new_frames = []

            # Score each new frame with CLIP and write it to the dense array.
            for frame_idx, frame_array in zip(new_sampled_indices, new_frames):
                assert self.text_features is not None
                pil_image = Image.fromarray(frame_array)

                # Some CLIP back-ends take a `preprocess` callable, the others
                # take a `processor`; the same object is passed either way.
                uses_preprocess = self.clip_model_type in ["openai", "open_clip", "eva_clip", "long_clip"]
                with autocast():
                    image_features = encode_images_clip(
                        self.clip_model, [pil_image],
                        preprocess=self.clip_preprocess_or_processor if uses_preprocess else None,
                        processor=None if uses_preprocess else self.clip_preprocess_or_processor,
                        model_type=self.clip_model_type,
                        device=self.device
                    )

                clip_i2t_sim = compute_clip_similarity(
                    self.text_features, image_features,
                    model=self.clip_model,
                    model_type=self.clip_model_type,
                    normalization_method="none"
                ).item()

                if i2t_norm_params:
                    # Widen the normalisation bounds if this frame falls
                    # outside them (the caller's dict is updated in place).
                    bounds_changed = False
                    if clip_i2t_sim > i2t_norm_params['max']:
                        i2t_norm_params['max'] = clip_i2t_sim
                        bounds_changed = True
                    elif clip_i2t_sim < i2t_norm_params['min']:
                        i2t_norm_params['min'] = clip_i2t_sim
                        bounds_changed = True

                    if bounds_changed:
                        min_val = i2t_norm_params['min']
                        range_val = i2t_norm_params['max'] - min_val
                        if range_val > 1e-8:  # avoid division by ~zero
                            # Re-normalise every frame with a known raw value.
                            # NOTE(review): frames scored inside this loop are
                            # not recorded in raw_dense_query_sim, so they keep
                            # the value computed with the bounds valid at the
                            # time they were scored - confirm this is intended.
                            valid_mask = ~np.isnan(raw_dense_query_sim)
                            dense_query_sim[valid_mask] = np.clip(
                                (raw_dense_query_sim[valid_mask] - min_val) / range_val, 0, 1
                            )

                    clip_i2t_sim_norm = (clip_i2t_sim - i2t_norm_params['min']) / \
                                        (i2t_norm_params['max'] - i2t_norm_params['min'] + 1e-8)
                    clip_i2t_sim_norm = np.clip(clip_i2t_sim_norm, 0, 1)
                else:
                    # No normalisation parameters supplied: use the raw score.
                    clip_i2t_sim_norm = clip_i2t_sim

                dense_query_sim[frame_idx] = clip_i2t_sim_norm

            # Keep only frames whose similarity reaches the threshold. Frames
            # with an unknown (NaN) similarity compare False and are dropped.
            filtered_indices = [
                idx for idx in new_sampled_indices if dense_query_sim[idx] >= clip_i2t_threshold
            ]
            filtered_updated_indices = [
                idx for idx in updated_sampled if dense_query_sim[idx] >= clip_i2t_threshold
            ]
            F_probe = sorted(set(F_probe) | set(filtered_updated_indices) | set(filtered_indices))
            F_final = sorted(set(F_final) | set(filtered_updated_indices))

            if not filtered_indices and DEBUG:
                print(f"\n[警告] 迭代 {iteration} 的所有新帧都被过滤掉")
            iteration += 1

        # ---- 4. Enforce the budget: keep the highest-scoring frames ----
        if len(F_final) > budget:
            final_scores = []
            for idx in F_final:
                clip_score = dense_query_sim[idx] if not np.isnan(dense_query_sim[idx]) else 0.5
                vlm_score = dense_analyzed_query_sim[idx] if not np.isnan(dense_analyzed_query_sim[idx]) else 0

                if vlm_score > 0:
                    combined = clip_score / 4 + vlm_score / 5
                else:
                    combined = clip_score
                final_scores.append((idx, combined))

            # Stable sort: ties keep their temporal order.
            final_scores.sort(key=lambda x: x[1], reverse=True)
            F_final = sorted([idx for idx, _ in final_scores[:budget]])

        # ---- 5. Decode the selected frames ----
        if not F_final:
            # Nothing survived (e.g. early stop in the first iteration).
            return [], []

        final_frames = VR.get_batch(F_final).asnumpy()

        processed_frames = []
        processed_frame_indices = []
        for idx, frame in zip(F_final, final_frames):
            processed_frames.append(Image.fromarray(frame))
            processed_frame_indices.append(int(idx))

        return processed_frames, processed_frame_indices

    def calculate_adaptive_budget(self, video_duration_seconds, max_budget):
        """Return the frame budget for a video, depending on the model type.

        'VILA-1.5' and 'LLaVA-OV' use a fixed budget (8 and 32). Every other
        model type gets 32 frames for videos of at most 60 s, 64 frames for
        videos of at least 375 s, and a linear interpolation in between that
        is snapped to the nearest multiple of 4.

        Args:
            video_duration_seconds: Duration of the video in seconds.
            max_budget: Accepted for interface compatibility; not read here.

        Returns:
            The frame budget as an int.
        """
        # Models with a fixed budget, independent of the video duration.
        fixed_budgets = {'VILA-1.5': 8, 'LLaVA-OV': 32}
        if self.model_type in fixed_budgets:
            return fixed_budgets[self.model_type]

        low_budget, high_budget = 32, 64
        short_video_s, long_video_s = 60, 375

        if video_duration_seconds <= short_video_s:
            return low_budget
        if video_duration_seconds >= long_video_s:
            return high_budget

        # Linear interpolation between the two duration anchors.
        interpolated = low_budget + (high_budget - low_budget) * (video_duration_seconds - short_video_s) / (long_video_s - short_video_s)
        # Snap to the nearest multiple of 4, then clamp to the allowed range.
        snapped = round(interpolated / 4) * 4
        return max(low_budget, min(high_budget, snapped))

    def _normalize(self, data: np.ndarray) -> np.ndarray:
        """
        Min-max normalizes a 1D numpy array to the  range.
        Handles the edge case where all values are the same.
        """
        min_val = np.min(data)
        max_val = np.max(data)
        if max_val == min_val:
            # If all values are identical, return an array of 0.5s
            return np.full_like(data, 0.5, dtype=float)
        return (data - min_val) / (max_val - min_val)

    def iteration_step(self,
                            probed_indices: list[int],
                            query_sim: np.ndarray,
                            analyze_sim: np.ndarray,
                            budget: int,
                            weights: dict = None,
                            N_candidate: int = 24,
                            # VLM分析相关参数
                            VR=None,
                            query: str = "",
                            generation_kwargs: dict = None,
                            save_dir=None,
                            low_score_threshold: int = 2,  # 低分阈值
                            min_frame_distance: int = 4,
                            iteration: int = 1,
                            ):
        """
        HISA步骤（简化版）：
        1. 选择前N个候选帧进行VLM分析
        2. 删除低分帧，从后续区间补充相同数量的帧
        3. 基于综合分数进行最终采样
        """
        if weights is None:
            weights = {'length': 0.5, 'complexity': 0.5}

        total_frames = len(query_sim)

        # --- 准备区间和评分 ---
        combined_sim = query_sim.copy()
        has_vlm_info = not np.all(np.isnan(analyze_sim))

        # 记录已分析过的帧
        already_analyzed = []

        if has_vlm_info:
            for i in range(total_frames):
                if not np.isnan(analyze_sim[i]):
                    combined_sim[i] = np.nansum(query_sim[i]) / 2 + np.nansum(analyze_sim[i] / 4.0)
                    already_analyzed.append(i)
        else:
            combined_sim = query_sim

        # 识别所有区间
        intervals = []
        if not probed_indices:
            intervals.append([0, total_frames - 1])
        else:
            probed_indices = sorted(probed_indices)
            if probed_indices[0] > 0:
                intervals.append([0, probed_indices[0]])
            for i in range(len(probed_indices) - 1):
                intervals.append([probed_indices[i], probed_indices[i + 1]])
            if probed_indices[-1] < total_frames - 1:
                intervals.append([probed_indices[-1], total_frames - 1])

        valid_intervals = [(start, end) for start, end in intervals if end > start + 1]

        if not valid_intervals:
            if DEBUG:
                # print(f"\n[HISA] 无有效区间 - sampled数量:{len(sampled_indices)}, 前10个:{sorted(sampled_indices)[:10]}")
                print(f"[HISA] intervals数量:{len(intervals)}, 前3个:{intervals[:3] if intervals else 'None'}")
            if intervals:
                print(f"[HISA] 区间长度分布:{[end - start for start, end in intervals[:5]]}")

            if DEBUG:
                print("No valid intervals found.")
            return [], analyze_sim, [], False  # 返回4个值

        all_scores = []
        all_complexity = []
        for start, end in valid_intervals:
            # 1. Calculate Length (L)
            length = float(end - start)
            # Handle edge cases for very short intervals
            if length > 2:
                interval_sim = combined_sim[start:end]
                # 2. Calculate Relevance (R)
                # This is the `np.nanmean` from your code.
                relevance = np.nanmean(interval_sim) + (
                            np.nanmean(combined_sim[start]) + np.nanmean(combined_sim[end])) / 6
                # 3. Calculate Complexity (C) using Total Variation
                # This is ∫|sim'(t)|dt, calculated as the sum of absolute differences.
                # We normalize by (length - 1) to get the *average change per frame step*,
                # making the metric less dependent on the interval's length itself.
                non_nan_sim = interval_sim[~np.isnan(interval_sim)]
                # 3.b. Check if there are enough valid points to calculate change
                if len(non_nan_sim) < 2:
                    complexity = 0.0
                else:
                    total_variation = np.sum(np.abs(np.diff(non_nan_sim)))
                    complexity = total_variation / (length - 1.0)
                # 4. Assemble the RCL Score
                # Bonus from Complexity
                complexity_bonus = 1.0 + weights['complexity'] * complexity
                # Bonus from Length (log-scaled for diminishing returns)
                length_bonus = 1.0 + weights['length'] * np.log(length)
                # Final multiplicative score
                score = relevance * complexity_bonus * length_bonus
            else:
                score = 0.01
            all_scores.append(score)
        scores = np.array(all_scores)
        sorted_interval_indices = np.argsort(scores)[::-1]

        # --- 收集所有可能的候选帧（按优先级排序）---
        all_possible_candidates = []

        for interval_idx in sorted_interval_indices:
            start, end = valid_intervals[int(interval_idx)]
            all_possible_candidates.append(start)
            all_possible_candidates.append(end)
        all_possible_candidates = sorted(set(all_possible_candidates))
        if not all_possible_candidates:
            if DEBUG:
                print("[HISA] 没有新的候选帧可分析")
            return [], analyze_sim, [], False

        initial_candidates = all_possible_candidates[:N_candidate]
        frames_needed_analyze = [f for f in initial_candidates if f not in already_analyzed]
        if len(frames_needed_analyze) > 0:  # 添加检查
            if DEBUG:
                print(f"\n[HISA] 需要VLM分析 {len(frames_needed_analyze)} 个新帧")

            candidate_frames = VR.get_batch(frames_needed_analyze).asnumpy()
            analyzed_scores = self.analyze_frames_with_query(
                candidate_frames,
                query=query,
                sub_batch_size=min(len(frames_needed_analyze), self.sub_batch),
                generation_kwargs=generation_kwargs or {},
                save_dir=save_dir
            )

            # --- 更新分析结果（注意缩进！）---
            for i, idx in enumerate(frames_needed_analyze):
                analyze_sim[idx] = analyzed_scores[i]  # 更新analyze_sim
                combined_sim[idx] = combined_sim[idx] + analyzed_scores[i] / 4.0
        else:
            if DEBUG:
                print(f"\n[HISA] 所有 {len(initial_candidates)} 个候选帧都已分析过")

        # --- 筛选高质量帧（这部分在if外面）---
        final_candidates = []
        neutral_candidates = []
        final_scores = []
        low_score_indices = []
        sampled_indices = []
        # 收集高分帧（基于所有initial_candidates，不只是新分析的）
        for idx in initial_candidates:
            if not np.isnan(analyze_sim[idx]) and analyze_sim[idx] > low_score_threshold + 1:   # 3 is neutral, unsure.
                final_candidates.append(idx)
                final_scores.append(analyze_sim[idx])
            elif not np.isnan(analyze_sim[idx]) and analyze_sim[idx] == low_score_threshold + 1:
                neutral_candidates.append(idx)  # 记录中性帧
            else:
                low_score_indices.append(idx)  # 记录低分帧
        # 判断不同的analysis情况,选择进一步补帧或者退出
        # 对第一轮迭代放宽sampling条件
        if len(final_candidates) > budget:
            return [], analyze_sim, final_candidates, True  # early stop
        elif len(final_candidates) == 0:
            if iteration == 1:
                T = sorted(neutral_candidates + low_score_indices)
                sorted_T_indices = np.argsort(combined_sim[T])[::-1]
                top_half_indices_in_T = sorted_T_indices[:budget//self.max_iteration]
                sampled_indices = sorted([T[i] for i in top_half_indices_in_T])
            else:
                sampled_indices = sorted(neutral_candidates)
        elif low_score_indices and final_candidates:
            sampled_indices = final_candidates

        # --- 最终采样逻辑 ---
        # 创建综合相似度数组
        temp_analyze_sim = np.full_like(analyze_sim, np.nan, dtype=float)
        sample_budget = budget - len(sampled_indices) + 3  # 冗余sample，确保如果sample多了最后循环外会按顺序取top-budet个。
        # 对于所有最终候选帧计算综合分数
        for idx in sampled_indices:
            # 有VLM分数的帧
            score = analyze_sim[idx]
            if np.isnan(query_sim[idx]):
                temp_analyze_sim[idx] = score
                # print(f"caution! video {len(query_sim)} {query} query_sim[{idx}] is NaN, this should not happen in HISA logic")
            else:
                temp_analyze_sim[idx] = score * 3 / 5 + np.nansum(query_sim[idx])

        if DEBUG:
            print(f"\n[HISA] 最终候选帧数: {len(final_candidates)}")
            print(f"  - 其中 {len(final_scores)} 个经过VLM验证")
            print(f"  - 其中 {len(final_candidates) - len(final_scores)} 个仅基于CLIP分数")

            print(f"\n[DEBUG] temp_analyze_sim 统计:")
            print(f"  - min={np.nanmin(temp_analyze_sim):.3f}")
            print(f"  - max={np.nanmax(temp_analyze_sim):.3f}")
            print(f"  - 非nan值数量: {np.sum(~np.isnan(temp_analyze_sim))}")
            print(f"  - budget: {budget}")

        # 可视化（只显示VLM分析过的帧）
        # if save_dir and initial_candidates:
        #     vlm_frames = VR.get_batch(initial_candidates).asnumpy()
        #     self.visualize_frame_scores(
        #         frames=vlm_frames,
        #         scores=analyzed_scores,
        #         query=query,
        #         save_path=os.path.join(save_dir, "vlm_scores.png")
        #     )

        # 采样
        new_sampled_indices = sample_logic(temp_analyze_sim,
                                           sample_budget,
                                           min_frame_length=min_frame_distance,
                                           HIGH_SIM_THRESHOLD=2)
        if DEBUG:
            print(f"\n[DEBUG] sample_logic 返回: {len(new_sampled_indices)} 个索引")

        return new_sampled_indices, analyze_sim, sampled_indices, False

    def create_frame_grid(self, frames, frame_indices, iteration_num, save_path,
                          title_prefix="HISA Iteration", max_cols=8,
                          frame_info=None, video_fps=30.0):
        """
        Draw the given frames as one grid image and save it to disk.

        Args:
            frames: Frame images (numpy arrays or PIL Images).
            frame_indices: Frame index of each entry in ``frames``.
            iteration_num: Iteration number shown in the title; ``0`` yields
                the "Initial Sampling" title.
            save_path: Where the image is written.
            title_prefix: Prefix of the figure title.
            max_cols: Maximum number of columns per row.
            frame_info: Optional per-frame dicts; when an entry contains
                ``'clip_i2t_sim_norm'`` the similarity is added to the tile
                title and the tile gets a colored border.
            video_fps: Frame rate used to convert a frame index to seconds.
        """
        import matplotlib.pyplot as plt
        import matplotlib.patches as patches
        from PIL import Image
        import numpy as np

        frame_count = len(frames)
        if frame_count == 0:
            print("没有帧可以显示")
            return

        col_count = min(frame_count, max_cols)
        row_count = (frame_count + col_count - 1) // col_count

        # Each tile is 2 wide and 2.5 high; the extra height leaves room for
        # the per-tile caption.
        figure_size = (col_count * 2, row_count * 2.5)
        # squeeze=False always yields a 2-D axes array, so a single ravel()
        # covers the 1x1, single-row, single-column and full-grid layouts.
        fig, axes = plt.subplots(row_count, col_count, figsize=figure_size, squeeze=False)
        axes_flat = axes.ravel()

        for slot, ax in enumerate(axes_flat):
            if slot >= frame_count:
                # Unused tile at the end of the last row.
                ax.axis('off')
                continue

            frame = frames[slot]
            img = Image.fromarray(frame) if isinstance(frame, np.ndarray) else frame
            ax.imshow(img)

            # Caption: frame index and timestamp.
            frame_idx = frame_indices[slot]
            time_sec = frame_idx / video_fps
            title = f"F{frame_idx}\n{time_sec:.1f}s"

            # Optional similarity annotation with a colored border.
            if frame_info and slot < len(frame_info) and 'clip_i2t_sim_norm' in frame_info[slot]:
                sim_score = frame_info[slot]['clip_i2t_sim_norm']
                title += f"\nSim:{sim_score:.3f}"

                if sim_score > 0.8:
                    border_color, border_width = 'green', 3
                elif sim_score > 0.5:
                    border_color, border_width = 'yellow', 2
                else:
                    border_color, border_width = 'red', 2

                border = patches.Rectangle((0, 0), img.width - 1, img.height - 1,
                                           linewidth=border_width, edgecolor=border_color,
                                           facecolor='none', transform=ax.transData)
                ax.add_patch(border)

            ax.set_title(title, fontsize=10)
            ax.axis('off')

        # Figure title.
        main_title = (
            f"Initial Sampling - {frame_count} frames"
            if iteration_num == 0
            else f"{title_prefix} {iteration_num} - {frame_count} frames"
        )
        plt.suptitle(main_title, fontsize=16, fontweight='bold')
        plt.tight_layout()

        # Write the image and release the figure.
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        plt.close()

        print(f"保存帧网格图: {save_path}")

    def analyze_frames_with_query(self, sampled_frames, query, sub_batch_size=12,
                                  generation_kwargs=None, debug=False, save_dir="./vlm_analysis"):
        """
        Score how relevant each frame is to the query using the VLM.

        Args:
            sampled_frames: The sampled video frames.
            query: The user's query text.
            sub_batch_size: Number of messages sent to the model per call.
            generation_kwargs: Extra keyword arguments forwarded to
                ``_batch_inference``.
            debug: When true, print every response together with its score.
            save_dir: Not used by the active code (the result dump that used
                it is disabled).

        Returns:
            A list with one score per analysis message, as returned by
            ``extract_score`` (a 1-5 rating according to the original note).
        """
        if DEBUG:
            print(f"\n[DEBUG] analyze_frames_with_query:")
            print(f"  - 当前模型类型: {self.model_type}")
            print(f"  - VL模型: {self.vl_model_set[0].__class__.__name__}")
            print(f"  - 查询: {query}")

        generation_kwargs = {} if generation_kwargs is None else generation_kwargs

        # Build the scoring prompts.
        analysis_messages = create_analysis_messages(
            sampled_frames, query, self.max_pixels, self.min_pixels, self.model_type
        )
        if DEBUG:
            print(f'分析帧与查询相关性，共 {len(analysis_messages)} 帧')

        # Run the model sub-batch by sub-batch (ceil division for the count).
        message_count = len(analysis_messages)
        num_sub_batches = (message_count + sub_batch_size - 1) // sub_batch_size

        scores = []
        all_responses = []  # raw responses, kept for the disabled result dump
        for j in range(num_sub_batches):
            t0 = time.time()
            start_idx = j * sub_batch_size
            end_idx = min(start_idx + sub_batch_size, message_count)

            batch_response = self._batch_inference(
                self.vl_model_set,
                analysis_messages[start_idx:end_idx],
                **generation_kwargs
            )
            all_responses.extend(batch_response)

            # Parse one score out of every (lower-cased) response.
            batch_scores = [extract_score(response.lower()) for response in batch_response]

            if debug:
                print(f"\n[DEBUG] 评分生成 - 批次 {j + 1}/{num_sub_batches}:")
                for idx, (response, score) in enumerate(zip(batch_response, batch_scores)):
                    print(f"  [{start_idx + idx}] Score: {score}, Response: {response}")

            scores.extend(batch_scores)
            if DEBUG:
                print(f'  批次 {j + 1}/{num_sub_batches} 完成，用时: {time.time() - t0:.2f}秒')

        return scores

    def visualize_frame_scores(self, frames, scores, query, save_path=None):
        """可视化帧和对应的评分"""
        import matplotlib.pyplot as plt
        import matplotlib.patches as patches

        n_frames = len(frames)
        cols = min(5, n_frames)
        rows = (n_frames + cols - 1) // cols

        fig, axes = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows))

        # 处理不同的axes形状
        if rows == 1 and cols == 1:
            axes = [axes]
        elif rows == 1:
            axes = list(axes)
        elif cols == 1:
            axes = list(axes)
        else:
            axes = axes.flatten()

        for i, (frame, score) in enumerate(zip(frames, scores)):
            if i < len(axes):
                ax = axes[i]

                # 显示帧
                if isinstance(frame, np.ndarray):
                    ax.imshow(frame)
                else:
                    ax.imshow(np.array(frame))

                # 根据分数设置边框颜色
                color = 'green' if score >= 4 else 'orange' if score >= 3 else 'red'
                ax.set_title(f"Score: {score}", color=color, fontsize=12, fontweight='bold')
                ax.axis('off')

                # 添加彩色边框
                rect = patches.Rectangle((0, 0), 1, 1, transform=ax.transAxes,
                                         linewidth=4, edgecolor=color, facecolor='none')
                ax.add_patch(rect)

        # 隐藏多余的子图
        for i in range(len(frames), len(axes)):
            axes[i].axis('off')
            axes[i].set_visible(False)

        plt.suptitle(f'VLM Analysis - Query: "{query[:50]}..."', fontsize=14, fontweight='bold')
        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=150, bbox_inches='tight')
            print(f"保存VLM评分可视化: {save_path}")

        plt.close()  # 避免内存泄漏

    def _normalize_full_video(self, similarity):
        """对整个视频的相似度进行归一化，并返回归一化参数"""
        if DEBUG:
            print(f"[Full Video] similarity shape: {similarity.shape}")
            print(f"[Full Video] similarity dtype: {similarity.dtype}")  # 添加调试信息

        # 确保 similarity 是 float32 类型
        if similarity.dtype not in [torch.float32, torch.float64]:
            similarity = similarity.float()  # 转换为 float32

        min_s = similarity.min()
        max_s = similarity.max()

        # 计算百分位数（更鲁棒）
        p10 = torch.quantile(similarity.flatten(), 0.1)  # 添加 flatten() 确保是1D
        p90 = torch.quantile(similarity.flatten(), 0.9)

        if DEBUG:
            print(f"[Full Video] min: {min_s:.4f}, max: {max_s:.4f}, range: {(max_s - min_s):.4f}")
            print(f"[Full Video] p10: {p10:.4f}, p90: {p90:.4f}")

        # 创建归一化参数字典
        norm_params = {
            'min': min_s.item(),
            'max': max_s.item(),
            'p10': p10.item(),
            'p90': p90.item(),
            'mean': similarity.mean().item(),
            'std': similarity.std().item()
        }

        if (max_s - min_s) < 1e-8:
            return torch.full_like(similarity, 0.5), norm_params

        # 使用min-max归一化
        normalized = (similarity - min_s) / (max_s - min_s)
        # normalized = torch.pow(normalized, 2.5)

        return normalized, norm_params

    def _extract_from_messages(self, messages):
        """从messages格式中提取视频路径和查询文本"""
        if isinstance(messages, list) and len(messages) > 0 and isinstance(messages[0], list):
            # 处理messages格式
            extracted_videos = []
            extracted_queries = []
            for message in messages:
                try:
                    video = message[1]['content'][0]['video']
                    text = message[1]['content'][1]['text']
                    # 提取问号前的内容
                    if '？' in text:
                        text = text.split('？')[0] + '？'
                    elif '?' in text:
                        text = text.split('?')[0] + '?'
                    extracted_videos.append(video)
                    extracted_queries.append(text)
                except:
                    raise ValueError(f"Invalid message format: {message}")

            video_path = extracted_videos[0] if len(extracted_videos) == 1 else extracted_videos
            query = extracted_queries[0] if len(extracted_queries) == 1 else extracted_queries
            return video_path, query
        else:
            return None, None

    @torch.inference_mode()
    def visualize_clip_similarities(
            self,
            video_path: Union[str, List[List[dict]]],
            query: Union[str, None] = None,
            clip_model_name: Optional[str] = None,
            fps: float = 1.0,
            save_path: Optional[str] = None,
            show_plot: bool = True,
            precomputed_similarities: Optional[Tuple[np.ndarray, List[int]]] = None,
            show_extreme_frames: bool = True,
            num_extreme_frames: int = 3,
            highlight_segments: Optional[List[dict]] = None,
            ground_truth: Optional[List[Tuple[float, float]]] = None
    ) -> np.ndarray:
        """
        Plot the similarity between the video frames and the query over time.

        Args:
            video_path: Path of the video file, or messages-format input from
                which the video and the query are extracted.
            query: Query text; may be None when ``video_path`` is in messages
                format.
            clip_model_name: CLIP model name forwarded to
                ``compute_video_query_similarity``.
            fps: Sampling frame rate forwarded to
                ``compute_video_query_similarity``.
            save_path: Where the similarity curve is saved. The extreme-frame
                figure is saved next to it with ``_extreme_frames`` inserted
                before the file extension.
            show_plot: Show the figures with ``plt.show()``; otherwise close
                them.
            precomputed_similarities: ``(similarities, frame_indices)`` to
                use instead of computing them.
            show_extreme_frames: Also render the lowest- and highest-similarity
                frames in a separate figure and print them.
            num_extreme_frames: How many frames to show at each extreme.
            highlight_segments: Dicts with ``'start_time'``, ``'end_time'``
                and ``'min_similarity'``; drawn as green spans plus a
                threshold line.
            ground_truth: ``(start_time, end_time)`` pairs drawn as red spans.

        Returns:
            The similarity values that were plotted.

        Raises:
            ValueError: If no query is available, or if several videos were
                given.
        """
        # Messages-format input carries both the video and the query.
        extracted_video, extracted_query = self._extract_from_messages(video_path)
        if extracted_video is not None:
            video_path = extracted_video
            query = extracted_query

        if query is None:
            raise ValueError("Query must be provided either directly or through messages format")

        if isinstance(video_path, list):
            raise ValueError("visualize_clip_similarities only supports single video")

        plt.rcParams['font.sans-serif'] = ['DejaVu Sans']
        plt.rcParams['axes.unicode_minus'] = False

        # Use the precomputed similarities or compute them now.
        if precomputed_similarities is not None:
            similarities, frame_indices = precomputed_similarities
        else:
            # The raw (second) return value is not needed for plotting.
            similarities, _raw_sim, frame_indices = self.compute_video_query_similarity(
                video_path=video_path,
                query=query,
                clip_model_name=clip_model_name,
                fps=fps
            )

        # The reader supplies the frame rate and the frames shown below.
        vr = VideoReader(video_path, ctx=cpu(0))
        video_fps = vr.get_avg_fps()

        # ========== Extreme-similarity frames ==========
        if show_extreme_frames and isinstance(video_path, str):
            # Ascending sort: lowest first; the tail reversed gives highest first.
            sorted_indices = np.argsort(similarities)
            low_sim_indices = sorted_indices[:num_extreme_frames]
            high_sim_indices = sorted_indices[-num_extreme_frames:][::-1]

            print(f"\n分析查询文本: '{query}'")
            print(f"相似度范围: [{similarities.min():.3f}, {similarities.max():.3f}]")
            print("=" * 80)

            # Separate figure for the extreme frames.
            fig_frames = plt.figure(figsize=(15, 10))

            # Lowest-similarity frames (top row).
            print(f"\n相似度最低的 {num_extreme_frames} 个帧:")
            print("-" * 60)

            for i, idx in enumerate(low_sim_indices):
                frame_idx = frame_indices[idx]
                similarity_score = similarities[idx]
                timestamp = frame_idx / video_fps

                frame = vr.get_batch([frame_idx]).asnumpy()[0]

                ax = plt.subplot(2, num_extreme_frames, i + 1)
                ax.imshow(frame)
                ax.set_title(
                    f'Low Sim #{i + 1}\nFrame {frame_idx}\nTime: {timestamp:.1f}s\nSim: {similarity_score:.3f}',
                    fontsize=10, color='red')
                ax.axis('off')

                print(
                    f"  低相似度帧 {i + 1}: 索引={frame_idx}, 时间={timestamp:.1f}秒, 相似度={similarity_score:.3f}")

            # Highest-similarity frames (bottom row).
            print(f"\n相似度最高的 {num_extreme_frames} 个帧:")
            print("-" * 60)

            for i, idx in enumerate(high_sim_indices):
                frame_idx = frame_indices[idx]
                similarity_score = similarities[idx]
                timestamp = frame_idx / video_fps

                frame = vr.get_batch([frame_idx]).asnumpy()[0]

                ax = plt.subplot(2, num_extreme_frames, num_extreme_frames + i + 1)
                ax.imshow(frame)
                ax.set_title(
                    f'High Sim #{i + 1}\nFrame {frame_idx}\nTime: {timestamp:.1f}s\nSim: {similarity_score:.3f}',
                    fontsize=10, color='green')
                ax.axis('off')

                print(
                    f"  高相似度帧 {i + 1}: 索引={frame_idx}, 时间={timestamp:.1f}秒, 相似度={similarity_score:.3f}")

            fig_frames.suptitle(f'Query: "{query}"\nTop: Lowest Similarity | Bottom: Highest Similarity',
                                fontsize=14, y=0.98)

            plt.tight_layout()

            # Save the extreme-frame figure next to the main plot.
            if save_path:
                # Insert the suffix before the extension. A plain
                # str.replace('.png', ...) leaves any other extension
                # untouched, so this figure would be written to save_path and
                # then overwritten by the similarity curve below.
                path_root, path_ext = os.path.splitext(save_path)
                extreme_frames_path = f"{path_root}_extreme_frames{path_ext}"
                plt.savefig(extreme_frames_path, dpi=150, bbox_inches='tight')
                print(f"\n极端相似度帧已保存到: {extreme_frames_path}")

            if show_plot:
                plt.show()
            else:
                plt.close()

            print("=" * 80)

            # Extra analysis: frames whose similarity is exactly 1.0 or 0.0.
            exact_one = np.where(similarities == 1.0)[0]
            exact_zero = np.where(similarities == 0.0)[0]

            if len(exact_one) > 0:
                print(f"\n发现 {len(exact_one)} 个相似度恰好为 1.000 的帧:")
                for idx in exact_one[:5]:  # show at most 5
                    frame_idx = frame_indices[idx]
                    print(f"  帧 {frame_idx}, 时间: {frame_idx / video_fps:.1f}秒")

            if len(exact_zero) > 0:
                print(f"\n发现 {len(exact_zero)} 个相似度恰好为 0.000 的帧:")
                for idx in exact_zero[:5]:  # show at most 5
                    frame_idx = frame_indices[idx]
                    print(f"  帧 {frame_idx}, 时间: {frame_idx / video_fps:.1f}秒")

        # Timestamp (seconds) of every sampled frame.
        timestamps = [idx / video_fps for idx in frame_indices]

        # ========== Similarity curve ==========
        plt.figure(figsize=(12, 6))

        # Draw the spans first so they sit behind the curve.
        if highlight_segments:
            for i, seg in enumerate(highlight_segments):
                plt.axvspan(
                    seg['start_time'],
                    seg['end_time'],
                    alpha=0.2,
                    color='green',
                    label=f"Segment {i + 1}" if i == 0 else None  # legend entry for the first span only
                )

        if ground_truth:
            print(f"正在绘制 {len(ground_truth)} 个 Ground Truth 区间")
            for i, (start, end) in enumerate(ground_truth):
                plt.axvspan(
                    start,
                    end,
                    alpha=0.15,  # semi-transparent red
                    facecolor='red',
                    edgecolor='darkred',
                    linewidth=2,
                    linestyle='--',
                    label='Ground Truth' if i == 0 else None,
                    zorder=1  # keep it at the bottom layer
                )

        plt.plot(timestamps, similarities, 'b-', linewidth=2)

        plt.xlabel('Time (seconds)', fontsize=12)
        plt.ylabel('CLIP Similarity Score', fontsize=12)
        plt.title(f'Video Frame Similarity with Query: "{query}"', fontsize=14)
        plt.grid(True, alpha=0.3)

        # Mark the highest-similarity point.
        max_idx = np.argmax(similarities)
        plt.scatter(timestamps[max_idx], similarities[max_idx],
                    color='red', s=100, zorder=5,
                    label=f'Max Score: {similarities[max_idx]:.3f}')

        # With segments, add a threshold line inferred from them: slightly
        # below the smallest per-segment minimum similarity.
        if highlight_segments and len(highlight_segments) > 0:
            threshold = min(seg['min_similarity'] for seg in highlight_segments) * 0.99
            plt.axhline(y=threshold, color='r', linestyle='--', alpha=0.5,
                        label=f'Threshold: {threshold:.3f}')

        plt.legend()

        # Save and/or show.
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"Plot saved to: {save_path}")

        if show_plot:
            plt.show()
        else:
            plt.close()

        return similarities