from langchain import hub
from langchain.agents import AgentExecutor, create_react_agent, tool
from langchain_openai import OpenAI, ChatOpenAI, AzureChatOpenAI
from tools import ToolKit
import ast
import sys
import os
import json
import re
from io import StringIO
from captioning import Captioning
from segment_feature import SegmentFeature
from tracking import Tracking
from reid import ReID
from multiprocessing import Process
import socket
from omegaconf import OmegaConf
import openai
import time
import numpy as np
# from evaluation_module import EvaluationModule
import json
from tqdm import tqdm
import traceback

# Module-level usage counters, accumulated across every call made through
# track_llm_call(): estimated output tokens, seconds spent in the LLM, and
# the number of calls.
total_output_tokens, total_llm_time, llm_call_count = 0, 0.0, 0

def count_tokens_simple(text):
    """Roughly estimate how many tokens ``text`` contains.

    Uses the rule of thumb of about four characters per token, i.e. the
    length of ``text`` floor-divided by 4 (inputs shorter than four
    characters therefore count as 0 tokens).
    """
    chars_per_token = 4
    token_estimate, _remainder = divmod(len(text), chars_per_token)
    return token_estimate

def track_llm_call(llm, prompt, call_description="LLM调用"):
    """Invoke ``llm`` on ``prompt`` and record simple usage statistics.

    Prints a start/end banner around the call. It then adds the estimated
    number of output tokens (via ``count_tokens_simple``) to the module-level
    ``total_output_tokens``. The elapsed seconds go into ``total_llm_time``,
    and ``llm_call_count`` is incremented.

    Args:
        llm: Object with an ``invoke(prompt)`` method whose result exposes a
            ``.content`` attribute (e.g. a LangChain chat model).
        prompt: Prompt passed unchanged to ``llm.invoke``.
        call_description: Label printed in the banners and per-call stats.

    Returns:
        The response object returned by ``llm.invoke``, unchanged.
    """
    global total_output_tokens, total_llm_time, llm_call_count

    print(f"\n=== {call_description} 开始 ===")
    # perf_counter() is monotonic and high-resolution. time.time() follows the
    # system clock, which can be adjusted mid-call and yield wrong durations.
    start_time = time.perf_counter()

    response = llm.invoke(prompt)

    call_time = time.perf_counter() - start_time

    # Only output tokens are counted. Chat-model content is not guaranteed to
    # be a str (it may be a list of content blocks). Estimate from its text
    # form rather than from the number of list items.
    content = response.content
    if not isinstance(content, str):
        content = str(content)
    output_tokens = count_tokens_simple(content)

    total_output_tokens += output_tokens
    total_llm_time += call_time
    llm_call_count += 1

    print(f"本次调用输出token数: {output_tokens}")
    print(f"本次调用用时: {call_time:.2f}秒")
    print(f"=== {call_description} 结束 ===\n")

    return response

# Placeholder Azure OpenAI settings. setdefault() only fills these in when the
# variables are not already set, so real credentials exported in the
# environment are no longer overwritten by the placeholders at import time.
os.environ.setdefault("AZURE_OPENAI_ENDPOINT", "YOUR_AZURE_OPENAI_ENDPOINT")
os.environ.setdefault("AZURE_OPENAI_API_KEY", "YOUR_AZURE_OPENAI_API_KEY")


def ReActAgent(toolkit, question, max_iterations=3):  # Receives pre-initialized toolkit
    """Answer ``question`` about one video with a perception/reflection loop.

    Each iteration runs three steps:
      1. Verification: the LLM inspects the latest tool observation. If it
         reports task-relevant claims, they are re-checked with
         ``spatial_focus``, and confirmed information is stored in memory.
      2. Sufficiency: the LLM decides to ``terminate`` (with a final answer)
         or ``continue`` (with guidance for the next perception step).
      3. Perception: the LLM picks one of the three tools defined below, and
         the chosen tool is executed on its input.
    If no answer is produced within ``max_iterations`` rounds, a final
    reflection prompt forces one.

    Args:
        toolkit: Pre-initialized ToolKit for the video. This function reads
            ``toolkit.max_seconds`` and calls ``key_frame_selection``,
            ``frame_analysis`` and ``time_focused_analysis`` on it.
        question: Question text, inserted verbatim into the prompts.
        max_iterations: Maximum number of reflection/perception rounds before
            the forced final reflection (default 3, the former hard-coded value).

    Returns:
        ``(final_answer, log)``. ``final_answer`` is the integer parsed from
        the LLM output (0 if the final reflection yields nothing parseable).
        ``log`` is a text transcript including the full memory.
    """
    # ``vqa_tool`` is not a parameter of this function. The check used to read
    # a module-level global unconditionally and raised NameError when that
    # global did not exist (e.g. when this module is imported, not run).
    module_vqa_tool = globals().get('vqa_tool')
    if module_vqa_tool is not None:
        assert module_vqa_tool in ['videollava', 'gpt-4v']

    @tool
    def divergent_search(input_tuple):
        """Find video segments related to a query within a broad time range and generate rough descriptions.
        TOOL INPUT FORMAT: ('query_text', (start_time, end_time))
        input must be: ('man with glasses', (150.0, 315.0))
        EXAMPLE: ('person', (0.0, 90.0))
        Returns top-k most relevant segments with timestamps and rough descriptions."""
        # NOTE: the docstring above is the tool description that is inserted
        # into the perception prompt, so it is runtime text; keep it unchanged.
        try:
            # 1. Normalize string input: strip whitespace and one pair of
            #    wrapping quotes/backticks the LLM may have added.
            if isinstance(input_tuple, str):
                input_tuple = input_tuple.strip()
                if (input_tuple.startswith('"') and input_tuple.endswith('"')) or \
                   (input_tuple.startswith("'") and input_tuple.endswith("'")) or \
                   (input_tuple.startswith("`") and input_tuple.endswith("`")):
                    input_tuple = input_tuple[1:-1].strip()

                # Quote an unquoted query, e.g. "(man, (0, 10))" -> "('man', (0, 10))",
                # so that ast.literal_eval can parse it.
                if input_tuple.startswith("(") and "," in input_tuple:
                    match = re.match(r"\(\s*([^,]*),\s*(\(.*?\))\s*\)", input_tuple)
                    if match:
                        query_text = match.group(1).strip()
                        time_range = match.group(2).strip()

                        if not (query_text.startswith("'") and query_text.endswith("'")) and not (query_text.startswith('"') and query_text.endswith('"')):
                            query_text = f"'{query_text}'"

                        input_tuple = f"({query_text}, {time_range})"

            # 2. Parse the literal. If it is not valid Python, fall back to a
            #    regex that accepts a quoted or bare query and two numbers.
            try:
                input_tuple = ast.literal_eval(input_tuple)
            except (ValueError, SyntaxError):
                if isinstance(input_tuple, str):
                    pattern = r"\(\s*(?:'([^']*)'|\"([^\"]*)\"|([^,]*)),\s*\(\s*([0-9.]+)\s*,\s*([0-9.]+)\s*\)\s*\)"
                    match = re.match(pattern, input_tuple)
                    if match:
                        # The query text is in whichever of groups 1-3 matched.
                        query = match.group(1) or match.group(2) or match.group(3).strip()
                        start_time = float(match.group(4))
                        end_time = float(match.group(5))
                        input_tuple = (query, (start_time, end_time))
                    else:
                        raise ValueError(f"无法解析输入: {input_tuple}")

            # 3. Validate the shape (query, (start, end)). Lists are accepted
            #    too because LLMs often emit JSON-style brackets.
            if (not isinstance(input_tuple, (tuple, list))
                    or len(input_tuple) != 2
                    or not isinstance(input_tuple[1], (tuple, list))
                    or len(input_tuple[1]) != 2):
                return "\nInvalid input tuple! Expected format: (query, (start_time, end_time))\n"

            query = input_tuple[0]
            start_time, end_time = input_tuple[1]

            # 4. Coerce the time bounds to float.
            try:
                start_time = float(start_time)
                end_time = float(end_time)
            except (ValueError, TypeError):
                return "\nInvalid time values! Time values must be numbers.\n"

            print('处理后的input_tuple:', (query, (start_time, end_time)))

        except Exception as e:
            return f"\nInvalid input format. Expected tuple (query, (start_time, end_time)). Error: {str(e)}\n"

        answer = toolkit.key_frame_selection(query, start_time, end_time, k=5)
        return f'\n{answer}\n'

    @tool
    def spatial_focus(input_list):
        """Analyze spatial relationships and visual attributes at specific time points in the video.
        TOOL INPUT FORMAT: [('question_text1', time_point1), ('question_text2', time_point2), ...]
        EXAMPLE: [('What objects are visible in the scene?', 10.5), ('What color is the car?', 20.3)]
        NOTE: Specializes in understanding scene composition, object attributes, and spatial relationships."""
        # NOTE: the docstring above is the tool description shown to the LLM;
        # keep it unchanged.
        try:
            # Drop a seconds unit written directly after a number before the
            # closing paren, e.g. "10.5s)" -> "10.5)". A blind
            # replace('s)', ')') also corrupted question text like "(tools)".
            if isinstance(input_list, str):
                input_list = re.sub(r'(\d)\s*s(?=\s*\))', r'\1', input_list)

            try:
                input_list = ast.literal_eval(input_list)
            except (ValueError, SyntaxError) as e:
                print(f"ast.literal_eval 解析失败: {str(e)}")
                # Fallback: pull (question, time) pairs out of the first [...] span.
                if isinstance(input_list, str):
                    list_pattern = r"\[\s*(.*?)\s*\]"
                    list_match = re.search(list_pattern, input_list, re.DOTALL)
                    if list_match:
                        items_str = list_match.group(1)
                        items = []
                        # Matches ('text', number), ("text", number) or (text, number).
                        tuple_pattern = r"\(\s*(?:'([^']*)'|\"([^\"]*)\"|([^,]*)),\s*([0-9.]+)\s*\)"
                        for match in re.finditer(tuple_pattern, items_str):
                            # The question text is in whichever of groups 1-3 matched.
                            question = match.group(1) or match.group(2) or match.group(3).strip()
                            time_point = float(match.group(4))
                            items.append((question, time_point))

                        if items:
                            input_list = items
                        else:
                            raise ValueError(f"无法解析输入列表项: {items_str}")
                    else:
                        raise ValueError(f"无法解析输入: {input_list}")

            if not isinstance(input_list, list):
                return "\nInvalid input! Expected a list of (question, time_point) tuples.\n"

            # Keep only well-formed (question, time) pairs; lists are accepted
            # as pairs too. Out-of-range times are clamped to the video length.
            validated_items = []
            for item in input_list:
                if not isinstance(item, (tuple, list)) or len(item) != 2:
                    print(f"跳过无效项: {item}")
                    continue

                question, time_point = item

                try:
                    time_point = float(time_point)
                except (ValueError, TypeError):
                    print(f"跳过无效时间值: {item}")
                    continue

                if time_point < 0 or time_point > toolkit.max_seconds:
                    print(f"\nWarning: Time point {time_point}s is outside the valid range (0 to {toolkit.max_seconds:.2f}s).\n")
                    time_point = max(0, min(toolkit.max_seconds, time_point))

                validated_items.append((question, time_point))

            print('Processed input_list:', validated_items)

            if not validated_items:
                return "\nNo valid input items found. Please check your input format.\n"

        except Exception as e:
            print(f"Error processing input: {str(e)}")
            traceback.print_exc()
            return f"\nInvalid input format. Expected list of (question, time_point) tuples. Error: {str(e)}\n"

        answer = toolkit.frame_analysis(validated_items)
        return '\n' + answer + '\n'

    @tool
    def temporal_focus(input_list):
        """Identify key scenes within specific time intervals and generate captions.
        TOOL INPUT FORMAT: [(start_time1, end_time1), (start_time2, end_time2), ...]
        EXAMPLE: [(10.0, 30.0), (37.0, 47.5), (70.0, 78.0)]
        Returns timestamps and captions of the most representative scenes in the given time ranges."""
        # NOTE: the docstring above is the tool description shown to the LLM;
        # keep it unchanged.
        try:
            # Drop a seconds unit written directly after a number, on both the
            # start ("10s,") and end ("30s)") times. The old replace('s)', ')')
            # only handled the end time.
            if isinstance(input_list, str):
                input_list = re.sub(r'(\d)\s*s(?=\s*[,)])', r'\1', input_list)

            try:
                input_list = ast.literal_eval(input_list)
            except (ValueError, SyntaxError) as e:
                print(f"ast.literal_eval 解析失败: {str(e)}")
                # Fallback: pull (start, end) pairs out of the first [...] span.
                if isinstance(input_list, str):
                    list_pattern = r"\[\s*(.*?)\s*\]"
                    list_match = re.search(list_pattern, input_list, re.DOTALL)
                    if list_match:
                        items_str = list_match.group(1)
                        items = []
                        tuple_pattern = r"\(\s*([0-9.]+)\s*,\s*([0-9.]+)\s*\)"
                        for match in re.finditer(tuple_pattern, items_str):
                            start_time = float(match.group(1))
                            end_time = float(match.group(2))
                            items.append((start_time, end_time))

                        if items:
                            input_list = items
                        else:
                            raise ValueError(f"无法解析输入列表项: {items_str}")
                    else:
                        raise ValueError(f"无法解析输入: {input_list}")

            if not isinstance(input_list, list):
                return "\nInvalid input! Expected a list of (start_time, end_time) tuples.\n"

            # Keep only well-formed (start, end) pairs; lists are accepted as
            # pairs too. Ranges are clipped to [0, max_seconds], and empty
            # ranges are dropped.
            validated_items = []
            for item in input_list:
                if not isinstance(item, (tuple, list)) or len(item) != 2:
                    print(f"跳过无效项: {item}")
                    continue

                start_time, end_time = item

                try:
                    start_time = float(start_time)
                    end_time = float(end_time)
                except (ValueError, TypeError):
                    print(f"跳过无效时间值: {item}")
                    continue

                if start_time < 0 or end_time > toolkit.max_seconds or start_time >= end_time:
                    print(f"\nWarning: Time range {start_time}s to {end_time}s is invalid or outside the video range (0 to {toolkit.max_seconds:.2f}s).\n")
                    start_time = max(0, start_time)
                    end_time = min(toolkit.max_seconds, end_time)
                    if start_time >= end_time:
                        print(f"跳过无效时间范围: {start_time}s >= {end_time}s")
                        continue

                validated_items.append((start_time, end_time))

            print('Processed input_list:', validated_items)

            if not validated_items:
                return "\nNo valid input items found. Please check your input format.\n"

        except Exception as e:
            print(f"Error processing input: {str(e)}")
            traceback.print_exc()
            return f"\nInvalid input format. Expected list of (start_time, end_time) tuples. Error: {str(e)}\n"

        # Analyze each range separately and join the per-range reports.
        all_results = []
        for start_time, end_time in validated_items:
            result = toolkit.time_focused_analysis(start_time, end_time, k=5)
            all_results.append(f"Time Range ({start_time}s - {end_time}s):\n{result}")

        combined_result = "\n\n".join(all_results)
        return f'\n{combined_result}\n'

    llm = ChatOpenAI(model='', openai_api_base="", temperature=0.0, openai_api_key='')

    tools = [divergent_search, spatial_focus, temporal_focus]
    tool_names = {t.name for t in tools}

    # memory maps a key to {"action", "input", "observation"}:
    #   0              -> initial temporal_focus pass over the whole video
    #   i (1..N)       -> tool chosen by the Perception Agent in iteration i
    #   "verified_<i>" -> information confirmed by verification in iteration i
    # Verified entries get their own key. They used to share memory[i] with
    # the perception record of the same iteration, which then overwrote them.
    memory = {}

    question_info = question
    video_duration_info = f"{toolkit.max_seconds:.2f}"

    # Seed memory with a clustering pass over the entire video.
    print("\n===== Initializing with temporal_focus =====")
    full_video_input = f"[(0.0, {toolkit.max_seconds})]"
    memory[0] = {
        "action": "temporal_focus",
        "input": full_video_input,
        "observation": temporal_focus.invoke(full_video_input)
    }
    print(f"Initial temporal_focus result:\n{memory[0]['observation']}")

    # Load all prompt templates once instead of re-reading them every iteration.
    with open('prompts/prompt_perception.txt') as f:
        perception_prompt_template = f.read()
    with open('prompts/prompt_reflection_step1_verification.txt') as f:
        step1_prompt_template = f.read()
    with open('prompts/prompt_reflection_step2_sufficiency.txt') as f:
        step2_prompt_template = f.read()

    current_iteration = 0
    final_answer = None
    guidance = "Start by understanding the overall video content and identifying key events or objects relevant to the question."

    while current_iteration < max_iterations:
        current_iteration += 1
        print(f"\n===== Iteration {current_iteration}/{max_iterations} =====")

        # The observation to verify is the one produced in the previous round
        # (the initial temporal_focus pass in round 1).
        latest_observation = ""
        if current_iteration > 1 and (current_iteration - 1) in memory:
            latest_observation = memory[current_iteration - 1].get('observation', '')
        elif current_iteration == 1:
            latest_observation = memory[0].get('observation', '')

        # Working memory: every recorded entry with a non-empty observation.
        working_memory = {
            k: v for k, v in memory.items()
            if k != current_iteration and v.get('observation')
        }

        # Step 1: verify the latest observation.
        if latest_observation:
            print("\n=== Step 1: Verification Analysis ===")

            step1_prompt = step1_prompt_template.format(
                latest_observation=latest_observation,
                question=question_info,
                video_duration=video_duration_info
            )

            step1_response = track_llm_call(llm, step1_prompt, "Step 1 Verification Analysis")
            print(f"Step 1 Verification Analysis:\n{step1_response.content}")

            task_relevant = False
            verification_questions = []

            try:
                # `\**` tolerates markdown bold, e.g. "**Key Information:** YES".
                relevant_match = re.search(r"Key Information:\s*\**\s*(yes|no)", step1_response.content, re.IGNORECASE)

                if relevant_match and relevant_match.group(1).lower() == "yes":
                    task_relevant = True

                    questions_match = re.search(r"Verification Questions:\s*\**\s*\[(.*?)\]", step1_response.content, re.DOTALL)

                    if questions_match:
                        questions_str = questions_match.group(1)
                        try:
                            # Collapse whitespace/newlines so the pair regex sees one line.
                            questions_str = re.sub(r'\s+', ' ', questions_str.strip())

                            # ("question", 12.5) pairs; an optional "s" unit after the
                            # number is allowed and not captured.
                            question_pattern = r'\(\s*"([^"]+)"\s*,\s*([\d.]+)s?\s*\)'
                            matches = re.findall(question_pattern, questions_str)

                            if matches:
                                verification_questions = [(q.strip(), float(ts)) for q, ts in matches]
                                print(f"提取的验证问题: {verification_questions}")
                            else:
                                # Fallback for single-quoted or otherwise Python-literal pairs.
                                verification_questions = ast.literal_eval(f"[{questions_str}]")
                                print(f"使用备选方法提取的验证问题: {verification_questions}")

                        except Exception as e:
                            print(f"解析验证问题失败: {e}")
                            print(f"原始字符串: {repr(questions_str)}")
                            verification_questions = []
                    else:
                        print("未找到验证问题")

                else:
                    print("Latest observation不包含任务相关信息")

            except Exception as e:
                print(f"解析Step 1输出失败: {e}")

            # Re-check the claims with spatial_focus, then let the LLM distill
            # what the verification actually confirmed.
            verified_info = ""
            if task_relevant and verification_questions:
                print("\n=== 执行验证 ===")
                try:
                    verification_input = str(verification_questions)

                    verification_observation = spatial_focus.invoke(verification_input)
                    print(f"验证结果: {verification_observation}")

                    verification_analysis_prompt = f"""
                    Based on the verification results, please determine what reliable information should be extracted. 
                    
                    Original observation: {latest_observation}
                    Verification questions: {verification_questions}
                    Verification results: {verification_observation}
                    
                    CRITICAL: You MUST output ONLY the following format with NO additional text or explanations:
                    Verified Information: [confirmed reliable information or "No reliable information confirmed"]
                    """

                    verification_analysis = track_llm_call(llm, verification_analysis_prompt, "Verification Analysis")
                    print(f"Verification Analysis Output: {verification_analysis.content}")

                    verified_pattern = r"Verified Information:\s*(.+?)(?=\n\n|$)"
                    verified_match = re.search(verified_pattern, verification_analysis.content, re.DOTALL)

                    if verified_match:
                        verified_info = verified_match.group(1).strip()
                        if "no reliable information" not in verified_info.lower():
                            print(f"验证后的可信信息: {verified_info}")
                        else:
                            verified_info = ""
                            print("验证后未发现可信信息")

                except Exception as e:
                    print(f"验证执行失败: {e}")

            # Persist verified information under its own key so it survives
            # the perception step of this iteration.
            if verified_info:
                verified_key = f"verified_{current_iteration}"
                memory[verified_key] = {
                    "action": "verified_information_extraction",
                    "input": latest_observation,
                    "observation": verified_info
                }
                working_memory[verified_key] = memory[verified_key]

        # Step 2: decide whether the collected information is sufficient.
        print("\n=== Step 2: Information Sufficiency Assessment ===")

        step2_prompt = step2_prompt_template.format(
            question=question_info,
            video_duration=video_duration_info,
            working_memory=json.dumps(working_memory, ensure_ascii=False, indent=2)
        )

        step2_response = track_llm_call(llm, step2_prompt, "Step 2 Sufficiency Assessment")
        print(f"Step 2 Sufficiency Assessment:\n{step2_response.content}")

        try:
            decision_match = re.search(r"Decision:\s*\**\s*(continue|terminate)", step2_response.content, re.IGNORECASE)
            decision = decision_match.group(1).lower() if decision_match else None

            if decision == "terminate":
                answer_match = re.search(r"Final Answer:\s*\**\s*(\d+)", step2_response.content)

                if answer_match:
                    final_answer = int(answer_match.group(1))
                    print(f"Reflection Agent决定终止，最终答案: {final_answer}")
                    break
                else:
                    # Keep the previous guidance and run another round.
                    print("警告: 决定终止但未找到最终答案，继续执行")
            elif decision == "continue":
                guidance_match = re.search(r"Guidance:\s*\**\s*(.+?)(?=\n\n|$)", step2_response.content, re.DOTALL)

                if guidance_match:
                    guidance = guidance_match.group(1).strip()
                else:
                    guidance = "Continue gathering more information to answer the question."

        except Exception as e:
            print(f"解析Step 2输出失败: {e}")
            guidance = "Continue gathering more information to answer the question."

        if final_answer is not None:
            break

        # Step 3: let the Perception Agent choose and run a tool.
        perception_prompt = perception_prompt_template.format(
            memory=json.dumps(working_memory, ensure_ascii=False, indent=2),
            tools="\n".join([f"{t.name}: {t.description}" for t in tools]),
            question=question_info,
            video_duration=video_duration_info,
            guidance=guidance
        )

        perception_response = track_llm_call(llm, perception_prompt, "Perception Agent Decision")
        print(f"Perception Agent's decision:\n{perception_response.content}")

        record = {"action": "", "input": "", "observation": ""}
        memory[current_iteration] = record

        try:
            tool_pattern = r"Tool Name:\s*(.*?)\nTool Input:\s*(.*?)(?=\n\n|$)"
            tool_match = re.search(tool_pattern, perception_response.content, re.DOTALL)

            if tool_match:
                tool_name = tool_match.group(1).strip()
                tool_input = tool_match.group(2).strip()

                # Strip one pair of wrapping quotes/backticks from the tool name.
                if (tool_name.startswith('"') and tool_name.endswith('"')) or \
                   (tool_name.startswith("'") and tool_name.endswith("'")) or \
                   (tool_name.startswith("`") and tool_name.endswith("`")):
                    tool_name = tool_name[1:-1].strip()

                # Flatten multi-line tool input onto one line. This used to be
                # keyed on stale tool names and so never ran for the real tools.
                if tool_name in tool_names:
                    tool_input = re.sub(r'\s+', ' ', tool_input)

                record["action"] = tool_name
                record["input"] = tool_input

                selected_tool = next((t for t in tools if t.name == tool_name), None)

                if selected_tool is not None:
                    print(f"Executing tool: {tool_name}, input: {tool_input}")
                    observation = selected_tool.invoke(tool_input)

                    record["observation"] = observation
                    print(f"Tool execution result:\n{observation}")
                else:
                    print(f"Error: Tool {tool_name} not found")
                    record["observation"] = f"Error: Tool {tool_name} not found"

            else:
                print("Error: Unable to parse Perception Agent's output")
                record["observation"] = "Error: Unable to parse Perception Agent's output"

        except Exception as e:
            print(f"Error executing tool: {str(e)}")
            traceback.print_exc()
            record["observation"] = f"Error executing tool: {str(e)}"

    # Iteration budget exhausted without an answer: force a final decision.
    if final_answer is None:
        print("\n===== Final Reflection =====")
        final_reflection_prompt = f"""You are a Reflection Agent responsible for analyzing the current information state and providing a final answer. 

IMPORTANT: You have reached the maximum number of iterations. You MUST terminate and provide a final answer based on the available information.

Current question: {question_info}
Video duration: {video_duration_info} seconds

Current working memory:
{json.dumps(memory, ensure_ascii=False, indent=2)}

IMPORTANT Notes:
1. The segment captions with prefix '#C' refer to the camera wearer, while those with prefix '#O' refer to someone other than the camera wearer.
2. You MUST make a decision based on the available information, even if it's not perfect.
3. You MUST provide a final answer between 0-4.

Please analyze the current working memory and provide your final decision:

Output format:
Chain of Thought: [detailed analysis of the available information and reasoning for your answer choice]
Decision: terminate
Reasoning: [explain your reasoning for the chosen answer based on available evidence]
Final Answer: [number between 0-4 based on the collected evidence]"""

        final_reflection_response = track_llm_call(llm, final_reflection_prompt, "Final Reflection")
        print(f"Final Reflection Agent's analysis:\n{final_reflection_response.content}")

        # Parse the answer with progressively looser patterns.
        answer_pattern = r"Final Answer:\s*\**\s*(\d+)\s*\**"
        answer_match = re.search(answer_pattern, final_reflection_response.content)

        if answer_match:
            final_answer = int(answer_match.group(1))
        else:
            loose_patterns = [
                r"(?:Final Answer|Answer)\s*:?\s*\**\s*(\d+)\s*\**",
                r"Decision:\s*terminate.*?\**(\d+)\**",
                r"Decision:\s*terminate.*?(\d+)"
            ]

            for pattern in loose_patterns:
                match = re.search(pattern, final_reflection_response.content, re.DOTALL | re.IGNORECASE)
                if match:
                    final_answer = int(match.group(1))
                    break

            if final_answer is None:
                # Last resort: the last standalone digit 0-4 in the response.
                number_matches = re.findall(r"\b([0-4])\b", final_reflection_response.content)
                if number_matches:
                    final_answer = int(number_matches[-1])
                    print(f"使用备用解析方案，最终答案: {final_answer}")
                else:
                    print("Warning: Unable to extract valid answer from Final Reflection, using default answer 0")
                    final_answer = 0

    log = f"Question: {question_info}\n\n"
    log += f"Video duration: {video_duration_info} seconds\n"
    log += f"Iterations: {current_iteration}/{max_iterations}\n\n"
    log += "Interaction history:\n" + json.dumps(memory, ensure_ascii=False, indent=2) + "\n\n"
    log += f"Final Answer: {final_answer}\n"

    return final_answer, log


# 与预测结果对比，计算准确率
def calculate_accuracy_and_score(predictions, correct_answers):
    """Return the fraction of ``predictions`` equal to the paired answer.

    Predictions and answers are paired positionally with ``zip``, so unpaired
    trailing items never count as hits. The denominator is always
    ``len(predictions)``. An empty ``predictions`` yields 0.
    """
    total = len(predictions)
    if total <= 0:
        return 0
    hits = sum(1 for pred, ans in zip(predictions, correct_answers) if pred == ans)
    return hits / total


def main(video_question_list, video_path_list, base_dir, vqa_tool, use_reid, openai_api_key, caption_model='lavila', mini_test=True, mini_test_number=10, demo_mode=False, dataset_dir='./EgoSchema'):
    """Run the video-QA agent in demo mode or over the EgoSchema question subset.

    Args:
        video_question_list: Demo-mode questions, paired index-wise with
            ``video_path_list``. Unused when ``demo_mode`` is False.
        video_path_list: Demo-mode video files. Unused when ``demo_mode`` is False.
        base_dir, vqa_tool, use_reid, openai_api_key, caption_model: Forwarded
            unchanged to ``ToolKit`` for every video.
        mini_test: In benchmark mode, only process the first
            ``mini_test_number`` questions when True.
        mini_test_number: How many questions to keep when ``mini_test`` is True.
        demo_mode: True to answer the given custom videos/questions; False to
            run over ``<dataset_dir>/subset_questions.json`` and report accuracy.
        dataset_dir: Root of the EgoSchema data (question file, videos, answer
            logs). Defaults to the previously hard-coded './EgoSchema'.
    """
    # Load the heavy models once; every per-video ToolKit reuses them.
    shared_models = ToolKit.preload_models()

    def create_agent(video_path, question):
        """Build a ToolKit for one video and run the ReAct agent on ``question``.

        Returns whatever ``ReActAgent`` returns; callers unpack it as
        ``(answer, log)``.
        """
        toolkit = ToolKit(video_path=video_path,
                          base_dir=base_dir,
                          vqa_tool=vqa_tool,
                          use_reid=use_reid,
                          openai_api_key=openai_api_key,
                          shared_models=shared_models,
                          caption_model=caption_model)
        return ReActAgent(toolkit, question)

    if demo_mode:
        _run_demo(create_agent, video_path_list, video_question_list)
    else:
        _run_egoschema(create_agent, dataset_dir, mini_test, mini_test_number)


def _run_demo(create_agent, video_path_list, video_question_list):
    """Answer each custom (video, question) pair and print the result."""
    print("Running in demo mode with custom video and question...")
    if len(video_path_list) != len(video_question_list):
        print(f"Warning: {len(video_path_list)} videos but {len(video_question_list)} questions; "
              f"only the first {min(len(video_path_list), len(video_question_list))} pairs will be processed")

    for video_file, question in zip(video_path_list, video_question_list):
        print(f"Processing video {video_file}, question: {question}")
        try:
            start_time = time.time()
            answer, _log = create_agent(video_file, question)
            elapsed = time.time() - start_time

            print("\n===== Demo Result =====")
            print(f"Question: {question}")
            print(f"Answer: {answer}")
            print(f"Inference time: {elapsed:.2f}s")
        except Exception as e:
            # Best-effort demo: report the full traceback and move on to the next video.
            print(f"Error processing video {video_file}: {e}")
            print("Complete error information:")
            print(traceback.format_exc())


def _run_egoschema(create_agent, dataset_dir, mini_test, mini_test_number):
    """Answer the EgoSchema subset questions, log each answer, and report accuracy."""
    with open(os.path.join(dataset_dir, 'subset_questions.json'), encoding='utf-8') as f:
        subset_questions = json.load(f)

    # Optionally limit how many questions are processed.
    if mini_test:
        subset_questions = subset_questions[:mini_test_number]

    answers_dir = os.path.join(dataset_dir, 'answers')
    os.makedirs(answers_dir, exist_ok=True)

    total_time = 0.0
    predictions = []
    correct_answers = []

    for q in subset_questions:
        video_id = q["q_uid"]
        video_file = os.path.join(dataset_dir, 'videos', 'videos', f"{video_id}.mp4")

        if not os.path.exists(video_file):
            print(f"视频 {video_id} 不存在, 跳过处理")
            continue

        # Question text followed by its five numbered options.
        question_content = f"Question for video {video_id}: {q['question']}\n"
        question_content += "".join(f'{j}: {q[f"option {j}"]}\n' for j in range(5))
        correct_answer = q["correct_answer"]

        print(f"Processing video {video_file}, question: {question_content}")

        answer_path = os.path.join(answers_dir, f"question_{video_id}_log.txt")
        try:
            start_time = time.time()
            answer, log = create_agent(video_file, question_content)
            prediction = int(answer)
            # Record prediction and ground truth together so the two lists stay
            # index-aligned even when some videos fail and are skipped.
            predictions.append(prediction)
            correct_answers.append(correct_answer)

            print(f"预测结果: {answer}, 真实答案: {correct_answer}")

            with open(answer_path, 'w', encoding='utf-8') as f:
                f.write(log)
                f.write(f"真实答案: {correct_answer}\n\n")  # append the ground truth to the log

            elapsed = time.time() - start_time
            print(f"推理时间：{np.round(elapsed, 2)}秒")
            total_time += elapsed
        except Exception as e:
            # Skip this video but keep evaluating the rest; show the full traceback.
            print(f"处理视频 {video_file} 时出错: {e}")
            print("完整错误信息:")
            print(traceback.format_exc())
            continue

    if not predictions:
        return

    accuracy = calculate_accuracy_and_score(predictions, correct_answers)
    print(f"模型预测的准确率：{accuracy * 100:.2f}%")
    avg_time_per_sample = total_time / len(predictions)
    print(f"每个样例的平均处理时间：{avg_time_per_sample:.2f}秒")

    accuracy_log_path = os.path.join(answers_dir, 'accuracy_log.txt')
    with open(accuracy_log_path, 'w', encoding='utf-8') as result_file:
        result_file.write(f"Accuracy: {accuracy * 100:.2f}%\n")
        result_file.write(f"每个样例的平均处理时间：{avg_time_per_sample:.2f}秒\n")
            
if __name__ == '__main__':
    # Configuration and API key.
    config = OmegaConf.load('config/default.yaml')
    openai_api_key = config['openai_api_key']
    use_reid = config['use_reid']
    # NOTE: ReActAgent asserts on `vqa_tool` as a module-level global (it is not
    # a parameter of ReActAgent), so this name must stay bound at module scope.
    vqa_tool = config['vqa_tool']
    base_dir = config['base_dir']
    # Caption model: read from the config when present, defaulting to 'lavila'.
    caption_model = config.get('caption_model', 'lavila')

    # Demo-mode videos and questions (paired by index); uncomment entries to try others.
    video_path_list = [
        #"sample_videos/boats.mp4"
        #"sample_videos/talking.mp4",
        #"sample_videos/books.mp4",
        "sample_videos/painting.mp4",
        #"sample_videos/kitchen.mp4"
        #"sample_videos/4882821564.mp4"
        #"sample_videos/00b9a0de-c59e-49cb-a127-6081e2fb8c8e.mp4"
        #"./EgoSchema/videos/videos/026a2f15-c454-4c28-80e0-24c85d7f4ecf.mp4"
        #"sample_videos/8604794910.mp4"
        #"sample_videos/3145698830.mp4"
        #"sample_videos/5919180502.mp4"
    ]
    video_question_list = [
        #"When does c hold the red book?"
        #"When the boats appear in the video?",
        #"How many boats are there in the video?"
        #"From what clue do you know that the woman with black spectacles at the start of the video is married?",
        #"Based on the actions observed, what could be a possible motivation or goal for what c is doing in the video?",
        #"What was the primary purpose of the cup of water in this video, and how did it contribute to the overall painting process?",
        #"Is there a microwave in the kitchen?"
        "What was the primary purpose of the cup of water in this video, and how did it contribute to the overall painting process?\n0: To provide a source of water for the paintbrush.\n1：To provide a place to store the paintbrush.\n2：To provide a place to dispose of the paintbrush.\n3：To provide a place to rest the paintbrush.\n4：To clean the paintbrush."
        #"Considering the sequence of events, what can be inferred about the importance of precision and accuracy in the character's actions, and how is this demonstrated within the video? \n0: For straight line cutting. \n1: For even, consistent cuts. \n2: For safe, efficient cutting. \n3: For correct sizing. \n4: For quick, efficient cutting."
        #"What can be deduced about c's level of expertise in the task by observing the kind of adjustments made throughout the video?\n0:C is a novice woodworker. he was not able to cut the wood to size and install it on the wall without making several adjustments.\n1：C is an expert woodworker. he was able to cut the wood to size and install it on the wall without making any adjustments.\n2: C is a professional woodworker. he was able to cut the wood to size and install it on the wall in a timely and efficient manner.\n3: C is an experienced woodworker. he was able to cut the wood to size and install it on the wall with few adjustments.\n4: C is an amateur woodworker. he was able to cut the wood to size and install it on the wall, but he took a long time to do so."
        #"how did the children protect their head from the sun?\n0: use hands.\n1：shade and cap.\n2：wear hat.\n3：wear yellow hat.\n4：cap and sunglasses."
        #"Why did the boy pick up one present from the group of them and move to the sofa?\n1: share with the girl.\n2: approach lady sitting there.\n3: unwrap it.\n4: playing with toy train.\n5: gesture something."
        #"What color are the eyes of the man with glasses?\n0: blue.\n1: green.\n2: red.\n3: yellow.\n4: black."
        #"how many people are sitting at the ledge of the swimming pool?\n0: six.\n1: four.\n2: two.\n3: eight.\n4: seven."
        #"What does the boy hold in his hand at last?\n0: toy.\n1: book.\n2: tablet computer.\n3: game controller.\n4: cell phone."
        #"Why did the boy move to the sofa?\n0: To open the gift package.\n1: To play with a remote control.\n2: To play with a teddy bear.\n3: To read a book.\n4: to take a rest."
    ]

    try:
        main(video_question_list, video_path_list, base_dir, vqa_tool, use_reid, openai_api_key,
             caption_model=caption_model, mini_test=True, mini_test_number=50, demo_mode=False)
    finally:
        # Overall LLM statistics accumulated by track_llm_call(); printed even
        # if main() raises, so a failed run still reports its LLM usage.
        print("\n" + "=" * 50)
        print("总体LLM调用统计:")
        print(f"总调用次数: {llm_call_count}")
        print(f"总输出token数: {total_output_tokens}")
        print(f"LLM总用时: {total_llm_time:.2f}秒")
        if llm_call_count > 0:
            print(f"平均每次调用输出token数: {total_output_tokens/llm_call_count:.1f}")
            print(f"平均每次调用用时: {total_llm_time/llm_call_count:.2f}秒")
        print("=" * 50)


    
