import os
import io
import requests
import json
import base64
from tqdm import tqdm
import re
import random
import time
import argparse
import torch
import hashlib
from qwen_vl_utils import process_vision_info, fetch_video
from PIL import Image
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
from functools import wraps
from typing import Optional, List, Dict, Any, Callable
from torchvision import transforms
import yaml
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer, AutoProcessor, LlavaNextVideoProcessor

# NOTE(review): presumably raises decord's retry limit when it hits EOF while
# seeking/decoding, so long or slightly truncated videos do not fail early —
# confirm against the decord documentation. Set at import time so it is in
# place before any video is read.
os.environ["DECORD_EOF_RETRY_MAX"] = "20480"

# System turn sent first in every conversation built by predict_batch. It asks
# the model for <time>/<caption>/<think> reasoning segments followed by a final
# <answer> tag, which extract_predicted_answer / extract_thinking_process parse.
SYSTEM_PROMPT = (
    "You are an expert video analyst specializing in temporal reasoning over long videos. "
    "When analyzing videos, you must construct structured reasoning traces that mirror the video's temporal structure. "
    "For each critical temporal segment, specify time intervals with `<time>start_time-end_time</time>`, "
    "describe visual evidence with `<caption>key visual elements</caption>`, "
    "and provide temporal analysis with `<think>reasoning about temporal relationships</think>`. "
    "Your reasoning should capture ordered event chains and temporal dependencies across segments. "
    "Employ natural cognitive expressions to articulate your temporal understanding process. "
    "After examining temporal traces, synthesize your findings and place the final answer in `<answer> </answer>` tags."
)

# User-turn text template. `{Question}` is filled via str.format with the
# problem text (plus an "Options:" list for multiple-choice items).
QUESTION_TEMPLATE = (
    "{Question}\n\n"
    "Analyze the video by constructing temporal reasoning traces. Identify key temporal segments and their relationships "
    "using `<time> </time>`, `<caption> </caption>`, `<think> </think>` tags. "
    "Conduct temporal analysis to derive your answer, then provide only the single option letter within `<answer> </answer>` tags."
)


def prepare_model(model_name, model_path, *, tensor_parallel_size=4):
    """Load the vLLM engine, chat processor and sampling settings for a model.

    Args:
        model_name: Display name; used only for the progress message.
        model_path: Checkpoint directory or hub id. It is passed to vLLM and
            to the Hugging Face processor/tokenizer loaders.
        tensor_parallel_size: Number of GPUs vLLM shards the model across.
            The default of 4 matches the previous hard-coded value.

    Returns:
        Tuple ``(llm, processor, sampling_params)``.
    """
    print(f"Preparing {model_name}...")

    llm = LLM(
        model=model_path,
        tensor_parallel_size=tensor_parallel_size,
        max_model_len=64000,
        gpu_memory_utilization=0.8,
        # At most one image and one video per prompt; predict_batch sends one video.
        limit_mm_per_prompt={"image": 1, "video": 1},
    )

    # Low temperature plus a tiny top_p gives near-greedy, reproducible decoding.
    sampling_params = SamplingParams(
        temperature=0.1,
        top_p=0.001,
        max_tokens=1024,
        stop_token_ids=[],
    )

    # Load from the `model_path` argument. The previous code referenced an
    # undefined global `MODEL_PATH`, which raised NameError.
    processor = AutoProcessor.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # Left padding keeps prompts right-aligned for batched generation.
    tokenizer.padding_side = "left"
    processor.tokenizer = tokenizer

    return llm, processor, sampling_params


def create_question_hash(item):
    """Return a stable MD5 hex digest identifying a question.

    The key is the problem text, the string form of its options and its
    video path, concatenated in that order. It is used to skip questions
    already present in a results file.
    """
    fingerprint = "".join((item['problem'], str(item['options']), item['path']))
    digest = hashlib.md5(fingerprint.encode())
    return digest.hexdigest()


def get_video_path(video_path_config, data_source, default_source='default'):
    """Resolve the video root directory for a data source.

    A non-dict config is a single root shared by every source and is
    returned unchanged. A dict config is looked up by ``data_source``. If
    that key is absent, it falls back to ``default_source``, and failing
    that to the dict's first key, printing a warning for each fallback.
    """
    if not isinstance(video_path_config, dict):
        return video_path_config

    if data_source in video_path_config:
        return video_path_config[data_source]

    if default_source in video_path_config:
        print(f"Warning: Data source '{data_source}' not found, using default '{default_source}'")
        return video_path_config[default_source]

    fallback_key = list(video_path_config)[0]
    print(f"Warning: Data source '{data_source}' and default '{default_source}' not found, using '{fallback_key}'")
    return video_path_config[fallback_key]


def _build_conversation(example, video_path_config, nframes):
    """Turn one dataset example into a system + user chat message list."""
    video_root = get_video_path(video_path_config, example.get('data_source', 'default'))

    if example["problem_type"] == 'multiple choice':
        # Problem text, an "Options:" header, then one option per line.
        question = example['problem'] + "\nOptions:\n" + "".join(option + "\n" for option in example["options"])
    else:
        question = example['problem']

    system_turn = {
        "role": "system",
        "content": [{"type": "text", "text": SYSTEM_PROMPT}],
    }
    video_part = {
        "type": example['data_type'],
        "video": os.path.join(video_root, example['path']),
        "nframes": nframes,
        "max_pixels": 128 * 28 * 28,
    }
    text_part = {
        "type": "text",
        "text": QUESTION_TEMPLATE.format(Question=question),
    }
    return [system_turn, {"role": "user", "content": [video_part, text_part]}]


def predict_batch(llm, processor, sampling_params, examples, video_path_config, nframes=16):
    """Generate one response per example with a single batched vLLM call.

    Each example becomes a system + user conversation. The user turn
    carries the video (``nframes`` frames, at most ``128 * 28 * 28`` pixels)
    and the question rendered through ``QUESTION_TEMPLATE``.

    Args:
        llm: vLLM engine returned by ``prepare_model``.
        processor: Chat processor used to render the prompt text.
        sampling_params: vLLM sampling settings.
        examples: Dataset records (``problem``, ``problem_type``, ``path``,
            ``data_type`` and, for multiple choice, ``options``).
        video_path_config: Root path or per-source mapping for ``get_video_path``.
        nframes: Number of frames sampled from each video.

    Returns:
        The first completion's text for each example, in input order.
    """
    conversations = [_build_conversation(example, video_path_config, nframes) for example in examples]

    prompts = [
        processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        for conversation in conversations
    ]

    # Image inputs are unused; each conversation holds exactly one video, so
    # position i in the video outputs lines up with prompt i.
    _, videos, video_kwargs = process_vision_info(
        conversations,
        return_video_kwargs=True,
    )

    llm_inputs = [
        {
            "prompt": prompt,
            "multi_modal_data": {"video": videos[position]},
            "mm_processor_kwargs": {name: per_video[position] for name, per_video in video_kwargs.items()},
        }
        for position, prompt in enumerate(prompts)
    ]

    generations = llm.generate(llm_inputs, sampling_params=sampling_params)
    return [generation.outputs[0].text for generation in generations]


def extract_answer_from_solution(solution_text):
    """Return the ground-truth option letter from a ``<answer>X</answer>`` solution.

    Matching is case-insensitive and tolerates whitespace inside the tag. The
    letter is returned upper-cased.

    Args:
        solution_text: The dataset's ``solution`` field. It may be ``None``
            (or another non-string) when the field is missing or malformed.

    Returns:
        A single upper-case letter, or ``None`` when no single-letter answer
        tag is present or ``solution_text`` is not a string. Previously a
        ``None`` input raised TypeError and aborted the whole evaluation
        instead of letting the caller skip the question.
    """
    if not isinstance(solution_text, str):
        return None
    pattern = r'<answer>\s*([A-Z])\s*</answer>'
    match = re.search(pattern, solution_text, re.IGNORECASE)
    if match:
        return match.group(1).upper()
    return None


def extract_predicted_answer(predicted_text):
    """Pull the predicted option letter (A-O) out of a model response.

    Search region, in priority order:
      1. the content of the last ``<answer>...</answer>`` tag;
      2. otherwise everything after a case-insensitive ``Final answer:``;
      3. otherwise the whole response.

    Within that region, a leading standalone letter (case-insensitive) wins.
    Otherwise the first of A..O that appears as ``X:``, ``[X`` or ``X ``
    (case-sensitive) is returned. ``'WRONG'`` means nothing matched.

    NOTE(review): in case 3 the whole reasoning text is scanned, so a
    response that begins with the word "I" or "a", or contains "A " anywhere
    (e.g. "A man walks"), yields a letter that was not the model's answer.
    This can bias scores when the model omits the answer tag.
    """
    tagged = re.findall(r'<answer>\s*(.*?)\s*</answer>', predicted_text, re.DOTALL | re.IGNORECASE)
    if tagged:
        candidate = tagged[-1].strip()
    else:
        marker = re.search(r'Final answer:', predicted_text, re.IGNORECASE)
        candidate = predicted_text[marker.end():].strip() if marker else predicted_text

    leading = re.match(r'\s*([A-Z])\b', candidate, re.IGNORECASE)
    if leading:
        return leading.group(1).upper()

    for letter in 'ABCDEFGHIJKLMNO':
        if any(probe in candidate for probe in (f'{letter}:', f'[{letter}', f'{letter} ')):
            return letter

    return 'WRONG'


def extract_thinking_process(predicted_response):
    """Return the stripped text preceding the first ``<answer>...</answer>`` tag.

    Tag matching is case-insensitive and may span lines. Without a tag, the
    whole response is returned stripped.
    """
    found = re.search(r'<answer>\s*.*?\s*</answer>', predicted_response, re.IGNORECASE | re.DOTALL)
    cut = found.start() if found else len(predicted_response)
    return predicted_response[:cut].strip()


def load_data(file_path):
    """Load evaluation records from a JSON or JSON Lines file.

    The file is first parsed as a single JSON document. If that fails, it is
    parsed line by line as JSONL. Blank lines are ignored, and lines that fail
    to parse are reported and skipped.

    A JSONL file holding exactly one record is also a valid JSON document and
    decodes to a bare object. Such an object is wrapped in a list so callers
    always receive a list of records (previously they got a dict and iterated
    its keys).

    Args:
        file_path: Path to a UTF-8 JSON / JSONL file.

    Returns:
        The decoded records.
    """
    # Read once; universal-newline mode leaves only "\n" as the line separator.
    with open(file_path, 'r', encoding='utf-8') as f:
        text = f.read()

    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        print("JSON parsing failed, trying JSONL format...")
        records = []
        # split('\n') (not splitlines) so characters like U+2028 inside JSON
        # strings are not treated as record separators.
        for line_num, line in enumerate(text.split('\n'), 1):
            line = line.strip()
            if not line:
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as e:
                print(f"Error parsing line {line_num}: {e}")
        return records

    if isinstance(data, dict):
        return [data]
    return data


def parse_video_path_config(video_path_arg):
    """Interpret the ``--video_path`` CLI value.

    An empty value yields ``'videos/'``. A value wrapped in ``{...}`` is
    decoded as a JSON mapping of data source to root path, falling back to the
    raw string (with a warning) if decoding fails. Anything else is returned
    as a single root path.
    """
    if not video_path_arg:
        return 'videos/'

    looks_like_json = video_path_arg.startswith('{') and video_path_arg.endswith('}')
    if not looks_like_json:
        return video_path_arg

    try:
        return json.loads(video_path_arg)
    except json.JSONDecodeError:
        print(f"Warning: Failed to parse video_path as JSON, treating as string path: {video_path_arg}")
        return video_path_arg

def _load_existing_results(results_path):
    """Load records saved by a previous (possibly interrupted) run.

    Returns:
        ``(records, hashes)``: the stored record list and the set of its
        non-empty ``"Question Hash"`` values, or ``([], set())`` when the file
        does not exist yet.
    """
    if not os.path.exists(results_path):
        return [], set()
    with open(results_path, "r", encoding="utf-8") as f:
        records = json.load(f)
    hashes = {r.get("Question Hash") for r in records if r.get("Question Hash")}
    return records, hashes


def _dump_json_atomic(obj, path):
    """Write ``obj`` as pretty-printed JSON to ``path`` atomically.

    The data goes to a sibling temp file first and is then moved over
    ``path`` with ``os.replace``. An interruption mid-write therefore leaves
    the previous file intact rather than a truncated one that breaks resume.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2, ensure_ascii=False)
    os.replace(tmp_path, path)


def evaluate(video_path_config, json_file_path, output_path, model_name, model_path, batch_size=4, nframes=16):
    """Run resumable multiple-choice video QA evaluation and write results.

    Per-question records accumulate in ``<output_path>/Results-<model_name>.json``.
    The file is rewritten after every batch, and questions whose hash is already
    stored are skipped, so an interrupted run resumes where it stopped.
    Questions whose video file is missing are skipped with a warning.
    Questions whose solution has no parsable ``<answer>`` letter are skipped
    without being recorded. Summary statistics are written to
    ``<output_path>/Stats-<model_name>.json``.

    Accuracy covers every stored record that belongs to the current dataset,
    including records from earlier runs. A resumed or repeated run therefore
    reports the true overall score. Previously only the questions processed in
    the current run were counted, so a re-run after completion reported 0/0
    and overwrote the stats file with accuracy 0.

    Args:
        video_path_config: Root path or per-source mapping for ``get_video_path``.
        json_file_path: Benchmark file (JSON or JSONL) read by ``load_data``.
        output_path: Directory for the Results-/Stats- files (created if needed).
        model_name: Display name used in the output file names.
        model_path: Checkpoint passed to ``prepare_model``.
        batch_size: Questions per ``predict_batch`` call.
        nframes: Frames sampled per video.
    """
    llm, processor, sampling_params = prepare_model(model_name, model_path)

    os.makedirs(output_path, exist_ok=True)

    data = load_data(json_file_path)

    json_file_output = os.path.join(output_path, f"Results-{model_name}.json")
    output_process, processed_hashes = _load_existing_results(json_file_output)

    # Hashes of every question in this dataset; used to scope the final score.
    dataset_hashes = set()
    unprocessed_data = []
    for idx, item in enumerate(data):
        question_hash = create_question_hash(item)
        dataset_hashes.add(question_hash)
        if question_hash in processed_hashes:
            continue
        data_source = item.get('data_source', 'default')
        video_base_path = get_video_path(video_path_config, data_source)
        video_full_path = os.path.join(video_base_path, item.get('path'))
        if os.path.exists(video_full_path):
            unprocessed_data.append((idx, question_hash, item))
        else:
            print(f"Warning: Video file not found: {video_full_path}")

    for start in tqdm(range(0, len(unprocessed_data), batch_size), desc="Processing batches"):
        batch = unprocessed_data[start:start + batch_size]
        batch_items = [item for _, _, item in batch]

        batch_responses = predict_batch(llm, processor, sampling_params, batch_items, video_path_config, nframes)

        for (idx, question_hash, item), predicted_response in zip(batch, batch_responses):
            solution = item.get('solution')
            # A missing solution must not reach the regex helper.
            correct_answer = extract_answer_from_solution(solution) if solution else None
            if not correct_answer:
                print(f"Warning: Could not extract answer from solution for question {idx}")
                continue

            try:
                predicted_answer = extract_predicted_answer(predicted_response)
            except Exception as e:
                # Best effort: an unparsable response is scored wrong, but say why.
                print(f"Warning: Failed to parse prediction for question {idx}: {e}")
                predicted_answer = "WRONG"

            data_source = item.get('data_source', 'default')
            video_base_path = get_video_path(video_path_config, data_source)

            print(f"Question {idx}: Predicted: {predicted_answer}, Correct: {correct_answer}")

            output_process.append({
                "Question Index": idx,
                "Question Hash": question_hash,
                "Problem": item.get('problem'),
                "Problem Type": item.get('problem_type'),
                "Options": item.get('options'),
                "GT": correct_answer,
                "Predicted Answer": predicted_answer,
                "Full Response": predicted_response,
                "Thinking": extract_thinking_process(predicted_response),
                "Correct": predicted_answer == correct_answer,
                "Video Path": item.get('path'),
                "Data Source": item.get('data_source', 'Unknown'),
                "Video Base Path": video_base_path,
            })

        # Persist after every batch so progress survives interruption.
        _dump_json_atomic(output_process, json_file_output)

    # Score all stored records for this dataset, not just this run's.
    scored = [r for r in output_process if r.get("Question Hash") in dataset_hashes]
    total_count = len(scored)
    correct_count = sum(1 for r in scored if r.get("Correct"))
    overall_accuracy = correct_count / total_count if total_count > 0 else 0

    print("\nResults:")
    print(f"Correct: {correct_count}/{total_count}")
    print(f"Overall Accuracy: {overall_accuracy:.4f}")

    stats = {
        "model_name": model_name,
        "video_path_config": video_path_config,
        "total_questions": total_count,
        "correct_answers": correct_count,
        "accuracy": overall_accuracy
    }

    stats_file = os.path.join(output_path, f"Stats-{model_name}.json")
    _dump_json_atomic(stats, stats_file)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Run Evaluation")
    parser.add_argument('--model_name', default="Qwen2.5-VL-7B", type=str,
                        help='Display name; also used in the Results-/Stats- output file names')
    parser.add_argument('--model_path', default=None, type=str,
                        help='Checkpoint directory or hub id loaded by vLLM and transformers (required)')
    parser.add_argument('--video_path', default='videos/', type=str,
                        help='Video path config: can be a single path string or JSON string for multiple data sources, e.g., \'{"source1": "/path1", "source2": "/path2"}\'')
    parser.add_argument('--benchmark', default='train_data.json', type=str)
    parser.add_argument('--output_path', default='Results/', type=str)
    parser.add_argument('--batch_size', default=4, type=int, help='Batch size for processing')
    parser.add_argument('--nframes', default=16, type=int, help='Number of frames to extract from video')
    args = parser.parse_args()

    # Fail fast with a usage message. Otherwise a missing model path only
    # fails later during model loading, and batch_size 0 makes range() raise.
    if args.model_path is None:
        parser.error("--model_path is required")
    if args.batch_size < 1:
        parser.error("--batch_size must be >= 1")
    if args.nframes < 1:
        parser.error("--nframes must be >= 1")

    video_path_config = parse_video_path_config(args.video_path)

    evaluate(video_path_config, args.benchmark, args.output_path, args.model_name, args.model_path, args.batch_size,
             args.nframes)