import os
import time
import torch
import json
import cv2
import pickle
from eva_clip_extractor import EVACLIPExtractor
from encoder import encode_sentences


# Model registry: maps a short model key to its display name and the local
# directory holding its weights (the path looks like a Hugging Face cache
# snapshot -- machine-specific, adjust per deployment).
model_cfgs = {
    'eva-clip-8b': dict(
        model_name='EVA-CLIP-8B',
        model_path=(
            '/root/autodl-tmp/VideoAgent-main/hf_cache/'
            'models--BAAI--EVA-CLIP-8B/snapshots/'
            '0e4dca944e8ece27eb9dfe4a488c0ed0c4644fc9/'
        ),
    ),
}

class SegmentFeature:
    """Compute and cache per-video features on disk.

    Outputs for a video are stored in ``<base_dir>/<video name>/``, where the
    video name is the file's base name with ``".mp4"`` removed.
    """

    def __init__(self, video_path, base_dir='preprocess'):
        """Remember the video location and the root output directory.

        Args:
            video_path: Path to the input video file.
            base_dir: Root directory under which the per-video output
                directory is created.
        """
        self.video_path = video_path
        self.base_dir = base_dir
        # Frames are sampled at one per second instead of the earlier
        # per-segment sampling scheme. NOTE: this attribute is not read by
        # the methods visible in this class; the rate is implied by the
        # extractor method that is called.
        self.fps_sampling = 1

    def create_visual_embedding(self, eva_clip_model=None):
        """Extract visual features for the video and pickle them to disk.

        The result is written to ``visual_embedding.pkl`` inside the video's
        output directory. If that file already exists nothing is recomputed
        and no model is loaded.

        Args:
            eva_clip_model: Optional pre-loaded extractor exposing
                ``extract_video_features_1fps(video_path)``. When ``None``,
                a new ``EVACLIPExtractor`` is constructed on demand.

        Returns:
            None. Progress and failures are reported via ``print``.
        """
        base_name = os.path.basename(self.video_path).replace(".mp4", "")
        video_dir = os.path.join(self.base_dir, base_name)
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(video_dir, exist_ok=True)

        # Check the cache BEFORE touching the model: constructing the
        # extractor is expensive and pointless when the features exist.
        visual_embedding_path = os.path.join(video_dir, 'visual_embedding.pkl')
        if os.path.exists(visual_embedding_path):
            print(f"视觉特征文件已存在: {visual_embedding_path}")
            return

        # Use the pre-loaded model when one is supplied; otherwise create one.
        if eva_clip_model is not None:
            eva_clip = eva_clip_model
            print("使用预加载的EVA-CLIP-8B模型")
        else:
            start_time = time.time()
            eva_clip = EVACLIPExtractor()
            end_time = time.time()
            print(f'time for loading EVA-CLIP-8B model: {round(end_time - start_time, 3)} seconds')

        # Extract features with the extractor's 1-frame-per-second method.
        start_time = time.time()
        visual_embeddings = eva_clip.extract_video_features_1fps(self.video_path)
        end_time = time.time()

        if visual_embeddings is None:
            print(f"Failed to extract features from video: {base_name}")
            return

        print(f"Embedding time for video {base_name}: {round(end_time - start_time, 3)} seconds")
        print(f"Extracted features shape: {visual_embeddings.shape}")

        # Write to a temporary file and move it into place so an interrupted
        # dump never leaves a truncated pickle that later runs would mistake
        # for a valid cache entry.
        tmp_path = visual_embedding_path + '.tmp'
        try:
            with open(tmp_path, 'wb') as f:
                pickle.dump(visual_embeddings, f)
            os.replace(tmp_path, visual_embedding_path)
        except Exception:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
        print(f"Visual embeddings saved to: {visual_embedding_path}")

    def run(self, eva_clip_model=None):
        """Run the textual step (when available) followed by the visual step.

        Args:
            eva_clip_model: Optional pre-loaded extractor forwarded to
                ``create_visual_embedding``.
        """
        # NOTE(review): no ``create_textual_embedding`` is defined in the
        # visible class, so an unconditional call raises AttributeError
        # before any visual features are produced. Call it only when a
        # subclass or later definition provides it.
        textual_step = getattr(self, 'create_textual_embedding', None)
        if callable(textual_step):
            textual_step()
        else:
            print("create_textual_embedding is not defined; skipping textual embedding")
        self.create_visual_embedding(eva_clip_model=eva_clip_model)

