import base64
import copy
import json
import math
import os
import re
import shutil
import sys
import tempfile
import time
from datetime import datetime
from typing import Tuple

import cv2
import decord
import openai
from decord import VideoReader


# Frame-sampling configuration read by smart_nframes().
# Lower bound on the number of sampled frames.
FPS_MIN_FRAMES = 4
# Upper bound on the number of sampled frames.
FPS_MAX_FRAMES = 64
# Sampled frame counts are rounded to a multiple of this value.
FRAME_FACTOR = 2
# Target sampling rate: sampled frames per second of source video.
FPS = 2

# System prompt placed at the start of every agent conversation. It defines the
# <think>/<answer> output format and the <tool_call> JSON that
# run_agent_with_sandbox() parses to extract video clips.
PREFIX_PROMPT= """You are a helpful assistant.

Think step-by-step before providing your final answer.

Enclose your entire reasoning process within <think> and </think> tags. Enclose your final answer within <answer> and </answer> tags.

If analyzing a specific video segment is necessary to answer the question, you may use the following tool to extract a clip from `[start_time]` to `[end_time]`:

<tool_call>{\"name\":\"get_video_clip_frame\",\"arguments\":[{\"start_time\":[start_time],\"end_time\":[end_time]}]}</tool_call>

Use the insights from the clip to inform your reasoning and construct the final answer."""

# Text appended to the tool feedback when every requested clip was produced.
CROP_SUCCESS_PROMPT = """Tool execution successful. Analyze the visual information from the provided video clip to answer the user's question."""
# Text appended to the tool feedback when at least one requested clip failed.
CROP_FAIL_PROMPT = """Tool execution failed. Please continue your analysis based on your existing knowledge and the information from the conversation so far."""


def _get_video_info(video_path: str) -> Tuple[float, int, int, int, float]:
    """
    Read basic metadata of a video file with Decord (internal utility method).

    Args:
        video_path (str): Path to the video file.

    Returns:
        Tuple[float, int, int, int, float]: Average frame rate (fps), width,
        height, total frame count, and total duration (``total_frames / fps``).

    Raises:
        FileNotFoundError: If ``video_path`` does not exist.
        RuntimeError: If the file cannot be opened or decoded, or if its
            metadata is invalid. The underlying exception is attached as
            ``__cause__``.
    """
    # Fail early with a precise error type when the path is wrong.
    if not os.path.exists(video_path):
        raise FileNotFoundError(f"Video file not found: {video_path}")

    try:
        vr = VideoReader(video_path)
        fps = vr.get_avg_fps()  # Average frame rate
        total_frames = len(vr)  # Total frame count

        # Check the frame count before decoding frame 0, so an empty stream is
        # reported as invalid metadata instead of an indexing failure.
        if total_frames <= 0 or fps <= 0:
            raise ValueError(f"Invalid video metadata for {video_path}")

        # Frame shape is (height, width, ...); only the first two are needed.
        height, width = vr[0].shape[:2]
        total_duration = total_frames / fps

        if width <= 0 or height <= 0 or total_duration <= 0:
            raise ValueError(f"Invalid video metadata for {video_path}")

        return fps, width, height, total_frames, total_duration

    except Exception as e:
        # Keep the single RuntimeError contract for callers, but chain the
        # original exception so the root cause stays visible in tracebacks.
        raise RuntimeError(f"Error reading video file {video_path}: {e}") from e


def smart_nframes(
    total_frames: int,
    video_fps: int | float,
) -> int:
    """calculate the number of frames for video used for model inputs.

    Args:
        ele (dict): a dict contains the configuration of video.
            support either `fps` or `nframes`:
                - nframes: the number of frames to extract for model inputs.
                - fps: the fps to extract frames for model inputs.
                    - min_frames: the minimum number of frames of the video, only used when fps is provided.
                    - max_frames: the maximum number of frames of the video, only used when fps is provided.
        total_frames (int): the original total number of frames of the video.
        video_fps (int | float): the original fps of the video.

    Raises:
        ValueError: nframes should in interval [FRAME_FACTOR, total_frames].

    Returns:
        int: the number of frames for video used for model inputs.
    """
    def ceil_by_factor(number: int, factor: int) -> int:
        """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
        return math.ceil(number / factor) * factor

    def floor_by_factor(number: int, factor: int) -> int:
        """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
        return math.floor(number / factor) * factor

    fps = FPS
    min_frames = ceil_by_factor(FPS_MIN_FRAMES, FRAME_FACTOR)
    max_frames = floor_by_factor(min(FPS_MAX_FRAMES, total_frames), FRAME_FACTOR)
    nframes = total_frames / video_fps * fps
    if nframes > total_frames:
        print(f"Warning: smart_nframes: nframes[{nframes}] > total_frames[{total_frames}]")
    nframes = min(min(max(nframes, min_frames), max_frames), total_frames)
    nframes = floor_by_factor(nframes, FRAME_FACTOR)
    if not (FRAME_FACTOR <= nframes and nframes <= total_frames):
        raise ValueError(f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.")
    return nframes

def _crop_video(
        input_path: str,
        output_dir: str,
        start_time: float,
        end_time: float
    ) -> str:
    """Cut the segment ``[start_time, end_time]`` out of a video into a down-sampled MP4.

    The requested range is clamped to the duration of the source video. The
    number of frames to keep is chosen by ``smart_nframes`` and the kept
    frames are spread evenly over the segment. The writer frame rate is set
    to ``nframes / clip_duration`` so that writing exactly ``nframes`` frames
    reproduces the duration of the segment.

    Args:
        input_path (str): Path of the source video.
        output_dir (str): Directory under which a timestamped sub-directory
            holding the clip is created.
        start_time (float): Segment start in seconds; must be >= 0.
        end_time (float): Segment end in seconds; must be > ``start_time``.

    Returns:
        str: Path of the generated ``.mp4`` on success. Ordinary errors are
        not raised: on failure the function returns a message starting with
        ``"Video processing error: "``, so callers have to check whether the
        returned string is an existing file.
    """
    output_path = None
    try:
        # Validate timestamp
        if start_time < 0 or end_time <= start_time:
            raise ValueError(f"Invalid timestamp: start={start_time}, end={end_time}")

        # Get original video information
        orig_fps, orig_width, orig_height, _total_frames, orig_duration = _get_video_info(input_path)

        # Clamp the segment to the video and compute its duration
        start_time = min(max(0, start_time), orig_duration)
        end_time = min(end_time, orig_duration)
        clip_duration = end_time - start_time
        if clip_duration <= 0:
            raise ValueError(
                f"Requested clip lies outside the video: start={start_time}, "
                f"end={end_time}, video duration={orig_duration}"
            )

        # Number of source frames in the segment; also bounds the read loop.
        max_frames = int(round(clip_duration * orig_fps))

        # Decide the frame budget BEFORE creating any file, so a rejected
        # segment (smart_nframes raises) leaves nothing behind on disk.
        nframes = smart_nframes(max_frames, orig_fps)
        crop_video_fps = nframes / clip_duration
        # Evenly spaced offsets inside the segment. nframes <= max_frames, so
        # these are nframes distinct values in [0, max_frames). A fixed stride
        # of max_frames // nframes could select more than nframes frames,
        # which would not match the writer fps computed above.
        keep_offsets = {(k * max_frames) // nframes for k in range(nframes)}

        # Create the output file inside a timestamped folder under output_dir
        clip_dir = os.path.join(output_dir, datetime.now().strftime("%Y%m%d_%H%M%S"))
        os.makedirs(clip_dir, exist_ok=True)
        temp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False, dir=clip_dir)
        output_path = temp_file.name
        temp_file.close()

        cap = cv2.VideoCapture(input_path)
        out = None
        try:
            if not cap.isOpened():
                raise RuntimeError(f"Unable to open the input video: {input_path}")

            # Configure encoder
            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
            out = cv2.VideoWriter(
                output_path,
                fourcc,
                fps=crop_video_fps,
                frameSize=(orig_width, orig_height)
            )
            if not out.isOpened():
                raise RuntimeError(f"Unable to open the video writer for: {output_path}")

            # Locate the starting frame; fall back to sequential reads when
            # the backend cannot seek.
            start_frame = int(start_time * orig_fps)
            if not cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame):
                print(f"Warning: Seeking to start frame failed, reading frame-by-frame...")
                for _ in range(start_frame):
                    ret, _skipped = cap.read()
                    if not ret:
                        raise RuntimeError(f"Unable to reach the starting frame (the original video is too short).")

            # Read the segment and write only the selected frames
            frames_written = 0
            for offset in range(max_frames):
                ret, frame = cap.read()
                if not ret:
                    print(f"Warning: Reached end of video early. Expected {max_frames} frames, got {offset}")
                    break
                if offset in keep_offsets:
                    out.write(frame)
                    frames_written += 1
                    if frames_written == nframes:
                        break  # Budget reached; the remaining frames are not needed
        finally:
            # Always release the handles, also on errors, so the file is
            # flushed and no capture/writer is leaked. No HighGUI window is
            # created here, so there is nothing to destroy.
            cap.release()
            if out is not None:
                out.release()

        print(f"Video processing completed. Output: {output_path}")
        # Validate output file
        if not os.path.exists(output_path):
            raise RuntimeError(f"The output file has not been generated: {output_path}")
        file_size = os.path.getsize(output_path)
        if file_size < 1024:
            raise RuntimeError(f"The output file is too small ({file_size} bytes), and there is no valid frame data.")
        return output_path

    except Exception as e:
        # Best-effort removal of a partial or invalid clip; the caller only
        # receives the error message below.
        if output_path is not None and os.path.exists(output_path):
            try:
                os.remove(output_path)
            except OSError:
                pass
        return f"Video processing error: {str(e)}"



def run_agent_with_sandbox(client: openai.Client, model_name: str, user_prompt: str, user_video_path: str):
    """
    Core logic to run the video agent with the clip-extraction tool.

    The model is called at most MAX_ITERATIONS times. After each call the
    response is either a final answer (loop ends), a ``<tool_call>`` asking
    for video clips (clips are cut with ``_crop_video`` and fed back as a new
    user message), or neither (loop ends).

    :param client: OpenAI client instance.
    :param model_name: Name of the model to use.
    :param user_prompt: The question provided by the user.
    :param user_video_path: Initial path to the user's video.
    :return: Tuple ``(conv_history, final_answer)``. ``conv_history`` is the
        list of assistant response texts in order. ``final_answer`` is the
        content of the ``<answer>`` tag of the last assistant response, or
        that whole response if it has no ``<answer>`` tag, or ``""`` if the
        model never responded.
    """
    MAX_ITERATIONS = 3  # Set maximum number of iterations to prevent infinite loops

    # --- 1. Set up the context for this run ---
    # Output directory for the clips generated during this run
    run_timestamp = int(time.time())
    output_dir = os.path.join("eval/agent_demo_runs_test", f"run_{run_timestamp}")
    os.makedirs(output_dir, exist_ok=True)
    print(f"📂 Intermediate files for this run will be saved in: {os.path.abspath(output_dir)}")

    # --- 2. Prepare initial messages ---
    # Conversation history for ongoing interaction with the model
    conversation_history = [
        {"role": "system", "content": PREFIX_PROMPT},
        {
            "role": "user",
            "content": [
                {"type": "video", "video": user_video_path},
                {"type": "text", "text": user_prompt},
            ],
        },
    ]

    print("\n" + "="*20 + " Agent is running " + "="*20)
    print(f"🤔 Question: {user_prompt}")
    print(f"🎬 Video: {user_video_path}")

    # --- 3. Core agent loop ---
    for i in range(MAX_ITERATIONS):

        print("\n" + f"--- Iteration {i + 1}/{MAX_ITERATIONS} ---")

        # --- Call the model ---
        print("🧠 Calling model for reasoning...")
        try:
            response = client.chat.completions.create(
                model=model_name,
                messages=conversation_history,
                temperature=0.1,
                max_tokens=4096,
                stop=["</answer>","<|im_end|>"]
            )
            # Treat a missing content field as an empty response
            generated_text = response.choices[0].message.content or ""
            print(f"🤖 Model Response:\n{generated_text}")
        except Exception as e:
            print(f"❌ API call failed: {e}")
            break # If API fails, break the loop

        # Append the model's response to the conversation history
        conversation_history.append({"role": "assistant", "content": [{"type": "text", "text": generated_text}]})

        # --- Check if the model's response contains the final answer ---
        if "</answer>" in generated_text:
            print("\n✅ Found final answer, task completed.")
            break

        # --- Check if the model's response contains video clipping ---
        timestamp_match = re.search(r"<tool_call>(.*?)</tool_call>", generated_text, re.DOTALL)
        if timestamp_match:
            try:
                tool_call = json.loads(timestamp_match.group(1).strip())
                if tool_call['name'] != "get_video_clip_frame":
                    raise NotImplementedError(f"Unsupported tool call: {tool_call}")
                clip_timestamps = [
                    [float(timestamp['start_time']), float(timestamp['end_time'])]
                    for timestamp in tool_call['arguments']
                ]
            except Exception as e:
                print(f"⚠️ Timestamp conversion error: {e}")
                break

            print(f"\n🐍 Found video clipping timestamps, preparing to execute video clipping:\n---\n{clip_timestamps}\n---")

            # --- Call video clipping method ---
            video_save_dir = os.path.join(output_dir, f"iteration_{i+1}_videos")
            os.makedirs(video_save_dir, exist_ok=True)

            cliped_videos = []
            error_info = []
            for start_time, end_time in clip_timestamps:
                processed_path = _crop_video(
                    user_video_path,
                    video_save_dir,
                    start_time,
                    end_time
                )
                # _crop_video returns an error message instead of a path on failure
                if os.path.exists(processed_path):
                    cliped_videos.append(processed_path)
                else:
                    error_info.append(processed_path)

            # --- Prepare feedback for video clipping ---
            feedback = []
            if len(error_info) > 0:
                for error in error_info:
                    feedback.append({"type": "text", "text": error})
                feedback.append({"type": "text", "text": CROP_FAIL_PROMPT})
            else:
                print("✅ Video clipping executed successfully.")
                for path in cliped_videos:
                    feedback.append({"type": "video", "video": path})
                feedback.append({"type": "text", "text": CROP_SUCCESS_PROMPT})

            # Append the tool results as a new 'user' message to the history, allowing the model to continue processing
            conversation_history.append({"role": "user", "content": feedback})
        elif "<answer>" in generated_text:
            # "</answer>" is a stop sequence, so the server may cut the text
            # right before the closing tag; an opened answer is still final.
            print("\n✅ Found final answer, task completed.")
            break
        else:
            print("⚠️ The model did not provide an answer or generate clipping timestamps, it may be stuck. Terminating loop.")
            break

    # --- 4. Loop end, extract final answer ---
    print("\n" + "="*20 + " Agent run ended " + "="*20)
    conv_history = [
        message['content'][0]['text']
        for message in conversation_history
        if message['role'] == 'assistant'
    ]

    # Read the answer from the last ASSISTANT message. The last message in the
    # history may be the initial user message or tool feedback, whose first
    # content item can be a video entry without a 'text' key.
    final_response_text = conv_history[-1] if conv_history else ""
    # Accept a missing closing tag (removed by the stop sequence).
    answer_match = re.search(r"<answer>(.*?)(?:</answer>|\Z)", final_response_text, re.DOTALL)

    if answer_match:
        final_answer = answer_match.group(1).strip()
    else:
        final_answer = final_response_text  # If no <answer> tag is found, return the raw response

    return conv_history, final_answer





if __name__ == "__main__":
    # Evaluation driver: runs the agent on every Video-MME question that is
    # not yet present in the output file, appending one JSON line per finished
    # question, then rewrites the output file as a JSON list grouped by video.
    import json
    import threading
    from concurrent.futures import ThreadPoolExecutor, as_completed

    import pandas as pd
    from tqdm import tqdm

    # --- 1. Configuration ---
    # Please replace this URL with your VLLM server address
    VLLM_BASE_URL = "http://0.0.0.0:8000/v1"

    # --- 2. Connect to client ---
    print(f"🚀 Connecting to VLLM server: {VLLM_BASE_URL}")
    try:
        client = openai.Client(api_key="EMPTY", base_url=VLLM_BASE_URL)
        models = client.models.list()
        if not models.data:
            raise ValueError("The server did not return any models.")
        MODEL_NAME = models.data[0].id
        print(f"✅ Connection successful! Using model: {MODEL_NAME}")
    except Exception as e:
        print(f"\n❌ Unable to connect to VLLM server.")
        print("Please confirm:")
        print("   1. VLLM server is running successfully at the specified address and port.")
        print("   2. Network connection is stable and can access the address from the current machine.")
        print(f"   Error details: {e}\n")
        sys.exit(1)

    # Serializes appends to the shared output file across worker threads.
    lock = threading.Lock()

    def vllm_api_process_(item, output_file):
        """Run the agent on one question and append the result as a JSON line."""
        # Question followed by one option per line
        PROMPT = "\n".join([item['question'], *item['options']])
        VIDEO_PATH = item['video_path']

        # --- 3. Run agent ---
        conv_history, final_answer = run_agent_with_sandbox(
            client=client,
            model_name=MODEL_NAME,
            user_prompt=PROMPT,
            user_video_path=VIDEO_PATH
        )

        print("#"*50)
        print(f"Response: {final_answer}")
        print("#"*50)

        item['response'] = final_answer
        item['conv_history'] = conv_history
        with lock:
            with open(output_file, 'a') as f:
                f.write(json.dumps(item, ensure_ascii=False) + '\n')
                f.flush()

    data = pd.read_parquet("lmms-lab/Video-MME/videomme/test-00000-of-00001.parquet")
    output_path = "eval/video-mme/output/test.json"

    # Questions already answered by an earlier, interrupted run are skipped.
    existing_set = set()
    if os.path.exists(output_path):
        with open(output_path, "r") as f:
            try:
                for line in f:
                    if not line.strip():
                        continue
                    item = json.loads(line)
                    existing_set.add(item['video_id']+item['question_id'])
            except (json.JSONDecodeError, TypeError, KeyError) as e:
                # The last step of this script overwrites output_path with an
                # indented JSON list, which cannot be resumed line by line.
                print(f"❌ Error: '{output_path}' is not a JSON-lines progress file "
                      f"(it may already hold the aggregated results of a finished run): {e}")
                sys.exit(1)

    # Create output directory if it doesn't exist
    output_dir = os.path.dirname(output_path)
    os.makedirs(output_dir, exist_ok=True)

    input_list = []
    for i in range(len(data)):

        video_info = {
            "video_id": data['video_id'][i],
            "video_path": f"lmms-lab/Video-MME/data/{data['videoID'][i]}.mp4",
            "duration": data['duration'][i],
            "domain": data['domain'][i],
            "sub_category": data['sub_category'][i],
            "question_id": data['question_id'][i],
            "task_type": data['task_type'][i],
            "question": data['question'][i],
            "options": data['options'][i].tolist(),
            "answer": data['answer'][i],
            "response": "A",  # Placeholder, overwritten by the worker
        }
        if video_info['video_id']+video_info['question_id'] in existing_set:
            continue

        if not os.path.exists(video_info['video_path']):
            print(f"❌ Error: Video file '{video_info['video_path']}' does not exist. Please check the path.")
            sys.exit(1)

        input_list.append(video_info)

    # Leaving the `with` block waits for all workers, so no explicit shutdown
    # call is needed afterwards.
    with ThreadPoolExecutor(max_workers=10) as executor:
        futures = {
            executor.submit(vllm_api_process_, item, output_path): item
            for item in input_list
        }
        for future in tqdm(as_completed(futures), total=len(futures)):
            # result() re-raises a worker exception; report it instead of
            # dropping it silently, and keep processing the other questions.
            try:
                future.result()
            except Exception as e:
                failed = futures[future]
                print(f"❌ Processing failed for video_id={failed['video_id']} "
                      f"question_id={failed['question_id']}: {e}")

    print('All subprocesses done.')

    if not os.path.exists(output_path):
        print(f"❌ Error: No results were written to '{output_path}'.")
        sys.exit(1)

    # Group the per-question records by video.
    video_info = {}
    with open(output_path, "r") as f:
        for line in f:
            item = json.loads(line)
            video_info.setdefault(item['video_id'], []).append(item)

    formatted_data = []
    for values in video_info.values():
        first = values[0]  # Video-level fields are taken from the first record
        video = {
            'video_id': first['video_id'],
            'video_path': first['video_path'],
            'duration': first['duration'],
            'domain': first['domain'],
            'sub_category': first['sub_category'],
            'questions': [
                {
                    "question_id": record['question_id'],
                    "task_type": record['task_type'],
                    "question": record['question'],
                    "options": record['options'],
                    "answer": record['answer'],
                    "response": record['response'],
                    "conv_history": record['conv_history']
                }
                for record in values
            ]
        }
        formatted_data.append(video)

    # NOTE: this replaces the JSON-lines progress file with the aggregated list.
    with open(output_path, "w") as f:
        json.dump(formatted_data, f, indent=4)