from data_configs import DATASETS
import argparse
import numpy as np
import json
from tqdm import tqdm
import os
import re
import pickle
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import random
import math

# Process-wide memo of process_vision_info() results, populated by
# cached_process_vision_info(); entries are never evicted.
VIDEO_INFO_CACHE = dict()


def get_args():
    """Build the command-line parser for this script and return the parsed options.

    Note: ``--batch_size`` is accepted but not read by the code in this file.
    """
    cli = argparse.ArgumentParser(
        description='Evaluation for training-free video temporal grounding (Single GPU Version)')
    # Which annotation set to evaluate.
    cli.add_argument('--dataset', type=str, default='charades', help='Specify the dataset.')
    cli.add_argument('--split', type=str, default='default', help='Specify the split.')
    # Model location and batching.
    cli.add_argument("--model_base", default="/path/to/qwen-model", type=str)
    cli.add_argument("--batch_size", default=1, type=int, help="Batch size")
    # Checkpoint / resume handling.
    cli.add_argument("--checkpoint_dir", default="checkpoints", type=str, help="Directory to save checkpoints")
    cli.add_argument("--resume", help="Resume from checkpoint", action="store_true")
    # Where the model is placed.
    cli.add_argument("--device", default="cuda:0", type=str, help="GPU device to use")
    parsed = cli.parse_args()
    return parsed


def calc_iou(candidates, gt):
    """Temporal IoU of each candidate span against one ground-truth span.

    Args:
        candidates: 2-D array indexed as ``candidates[:, 0]`` (starts) and
            ``candidates[:, 1]`` (ends).
        gt: ``(start, end)`` of the reference span.

    Returns:
        Array of IoU values, one per candidate; non-overlapping spans give 0.
    """
    gt_start, gt_end = gt[0], gt[1]
    cand_start = candidates[:, 0]
    cand_end = candidates[:, 1]

    overlap = np.minimum(cand_end, gt_end) - np.maximum(cand_start, gt_start)
    overlap = np.maximum(overlap, 0)  # negative overlap means disjoint spans
    span = np.maximum(cand_end, gt_end) - np.minimum(cand_start, gt_start)
    return overlap / span


def cached_process_vision_info(messages, return_video_kwargs=False):
    """Memoize ``process_vision_info`` results in ``VIDEO_INFO_CACHE``.

    The cache key is built from every vision element in *messages*, meaning
    each content dict carrying an ``'image'``, ``'image_url'`` or ``'video'``
    key. That includes its options, such as ``total_pixels`` / ``min_pixels``.
    ``return_video_kwargs`` is also part of the key. Text parts are not part of
    the key, which keeps the original assumption that only vision elements
    affect the result.

    Messages with no vision element bypass the cache entirely. Previously they
    all shared a single ``"None_<flag>"`` key and could receive a stale result.

    NOTE(review): the cache is unbounded. Each entry keeps whatever
    ``process_vision_info`` returned (presumably decoded frame tensors), so
    memory grows with the number of distinct videos.

    Args:
        messages: chat-style list of ``{"role": ..., "content": [...]}`` dicts.
        return_video_kwargs: forwarded to ``process_vision_info``.

    Returns:
        Whatever ``process_vision_info`` returns for these arguments.
    """
    vision_keys = ('image', 'image_url', 'video')
    vision_elements = [
        content
        for msg in messages
        for content in msg.get('content', [])
        if isinstance(content, dict) and any(key in content for key in vision_keys)
    ]
    if not vision_elements:
        # Nothing meaningful to key on; don't risk returning another request's result.
        return process_vision_info(messages, return_video_kwargs=return_video_kwargs)

    # sorted(items) makes the key independent of dict insertion order; dict keys
    # are unique, so the values themselves are never compared while sorting.
    cache_key = (
        tuple(repr(sorted(element.items())) for element in vision_elements),
        return_video_kwargs,
    )
    if cache_key in VIDEO_INFO_CACHE:
        return VIDEO_INFO_CACHE[cache_key]

    result = process_vision_info(messages, return_video_kwargs=return_video_kwargs)
    VIDEO_INFO_CACHE[cache_key] = result
    return result


def inference(video_path, prompt, model, processor, max_new_tokens=2048, device="cuda:0"):
    """Ask the model *prompt* about the video at *video_path* and return its reply.

    Builds a single-turn user message (text first, then the video with a pixel
    budget). It preprocesses the video through ``cached_process_vision_info``,
    generates up to *max_new_tokens* tokens without gradients, and decodes only
    the newly generated tokens.
    """
    text_entry = {"type": "text", "text": prompt}
    video_entry = {
        "video": video_path,
        "total_pixels": 3584 * 28 * 28,
        "min_pixels": 16 * 28 * 28,
    }
    messages = [{"role": "user", "content": [text_entry, video_entry]}]

    chat_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs, video_kwargs = cached_process_vision_info(messages, return_video_kwargs=True)

    model_inputs = processor(
        text=[chat_text],
        images=image_inputs,
        videos=video_inputs,
        fps=video_kwargs['fps'],
        padding=True,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        output_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens)

    # Strip the echoed prompt tokens from each generated sequence.
    prompt_lengths = [len(row) for row in model_inputs.input_ids]
    new_tokens = [seq[n:] for seq, n in zip(output_ids, prompt_lengths)]
    decoded = processor.batch_decode(new_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    return decoded[0]


# "<start> to <end>" or "<start> and <end>", with any whitespace in between.
_TIME_RANGE_RE = re.compile(r"(\d+\.?\d*)\s+(?:to|and)\s+(\d+\.?\d*)")
# Content of an <answer>...</answer> block; DOTALL so multi-line answers match.
_ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)


def parse_timestamp_output(output_string):
    """Extract a ``(start, end)`` pair of seconds from a model response.

    The last ``<answer>...</answer>`` block is authoritative. If it contains a
    "X to Y" / "X and Y" range, the last such range in it is returned.
    Otherwise the last range anywhere in the response is used. Previously the
    whole-string search ran first, so a range written after ``</answer>`` could
    override the answer, and the ``<answer>`` fallback was unreachable.

    Args:
        output_string: raw decoded model output.

    Returns:
        ``(start, end)`` as floats, or ``(None, None)`` if no range is found.
    """
    answers = _ANSWER_RE.findall(output_string)
    if answers:
        answer_ranges = _TIME_RANGE_RE.findall(answers[-1])
        if answer_ranges:
            start, end = answer_ranges[-1]
            return float(start), float(end)

    ranges = _TIME_RANGE_RE.findall(output_string)
    if ranges:
        # The regex only captures digit/dot strings, so float() cannot fail here.
        start, end = ranges[-1]
        return float(start), float(end)
    return None, None


# GROUND_TEMPLATE = """To accurately pinpoint the event "[EVENT]" in the video, determine the precise time period of the event.

# Output your thought process within the <think> </think> tags, including analysis with either specific timestamps (xx.xx) or time ranges (xx.xx to xx.xx) in <timestep> </timestep> tags.

# Then, provide the start and end times (in seconds, precise to two decimal places) in the format "start time to end time" within the <answer> </answer> tags. For example: "12.54 to 17.83"."""

# Prompt sent to the model for every query. process_work_items replaces the
# literal "[EVENT]" placeholder with the query sentence. The model is asked to
# reason inside <think></think> and to give "start time to end time" (seconds)
# inside <answer></answer>; parse_timestamp_output extracts that range.
# NOTE(review): the text has typos ("occurance", "one decimal places"). They
# are kept verbatim because changing the prompt can shift model outputs and make
# results incomparable with earlier runs. Confirm before editing.
GROUND_TEMPLATE = """To accurately pinpoint the object described as "[EVENT]" in the video, determine the precise time period of the occurance of the object.

Output your thought process within the <think> </think> tags, including analysis with either specific timestamps (xx.x) or time ranges (xx.x to xx.x).

Then, provide the start and end times (in seconds, precise to one decimal places) in the format "start time to end time" within the <answer> </answer> tags. For example: "12.5 to 17.0"."""

# GROUND_TEMPLATE = """To accurately pinpoint the object described as "[EVENT]" in the video, determine the precise time period of the occurrence of the object.

# Provide the start and end times (in seconds, precise to one decimal places) in the format "start time to end time" within the <answer> </answer> tags. For example: "12.5 to 17.0"."""


def create_work_items(data, seed=None):
    """Flatten annotations into one shuffled work item per (video, sentence) pair.

    Args:
        data: mapping of video id -> annotation dict; each annotation must have
            a ``'sentences'`` list.
        seed: optional seed for a private RNG, making the order reproducible.
            ``None`` (default) shuffles with the module-level ``random`` state,
            exactly as before.

    Returns:
        List of ``{'vid', 'ann', 'sentence_idx'}`` dicts in random order.
    """
    work_items = [
        {'vid': vid, 'ann': ann, 'sentence_idx': i}
        for vid, ann in data.items()
        for i in range(len(ann['sentences']))
    ]
    # Randomly shuffle the list. Resume matches items by "<vid>_<sentence_idx>",
    # not by position, so the order does not affect checkpointing.
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(work_items)
    return work_items


def setup_model(model_base, device):
    """Load the Qwen2.5-VL model and its processor from *model_base*.

    The model is loaded in bfloat16 with FlashAttention-2 and sliding-window
    attention enabled, and is placed on *device* via ``device_map``.

    Returns:
        ``(model, processor)``.
    """
    print(f"Setting up model on device {device}")
    load_kwargs = dict(
        torch_dtype=torch.bfloat16,
        use_sliding_window=True,
        attn_implementation="flash_attention_2",
        device_map=device,
    )
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_base, **load_kwargs)
    return model, AutoProcessor.from_pretrained(model_base)


def get_checkpoint_path(checkpoint_dir):
    """Create *checkpoint_dir* if needed and return the path of its ``checkpoint.pkl``."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_file = os.path.join(checkpoint_dir, "checkpoint.pkl")
    return checkpoint_file


def load_checkpoint(checkpoint_path):
    """Load the pickled evaluation state, or return a fresh state.

    A fresh state ``{'processed_items': set(), 'ious': [], 'recall': [0, 0, 0]}``
    is returned when the file does not exist, or when unpickling fails. In the
    failure case the error is printed.

    NOTE: ``pickle.load`` can execute arbitrary code; only point this at
    checkpoints this script wrote itself.
    """
    fresh_state = {'processed_items': set(), 'ious': [], 'recall': np.array([0, 0, 0])}
    if not os.path.exists(checkpoint_path):
        return fresh_state
    try:
        with open(checkpoint_path, 'rb') as handle:
            state = pickle.load(handle)
    except Exception as err:
        print(f"Error loading checkpoint: {err}")
        return fresh_state
    return state


def save_checkpoint(checkpoint_path, state):
    """Atomically pickle *state* to *checkpoint_path*.

    The state is written to ``<checkpoint_path>.tmp``, flushed to disk, and
    then moved over the real file with ``os.replace``. Previously the
    checkpoint was written in place. An interrupt mid-dump left a truncated
    file, which ``load_checkpoint`` then silently replaced with an empty
    state, losing all progress. If pickling fails, the previous checkpoint is
    left untouched and the temp file is removed.
    """
    tmp_path = f"{checkpoint_path}.tmp"
    try:
        with open(tmp_path, 'wb') as f:
            pickle.dump(state, f)
            f.flush()
            os.fsync(f.fileno())
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    os.replace(tmp_path, checkpoint_path)


# def compute_reward(sp, ep, gt):
#     iou = 0
#     s, e = gt
#     iou_ = (min(e, ep) - max(s, sp)) / (max(e, ep) - min(s, sp))
#     iou = max(iou_, 0)

#     return iou


def compute_reward(sp, ep, gt_frames, fps=2):
    """IoU between a predicted time span and a collection of ground-truth frame indices.

    The prediction ``[sp, ep]`` (seconds) is mapped to the half-open frame
    range ``[floor(sp * fps), ceil(ep * fps))``. Its IoU with
    ``set(gt_frames)`` is then computed. The result is identical to building
    both sets, but the prediction range is never materialized. Previously a
    runaway parse such as "0 to 99999999" would allocate a set with hundreds
    of millions of ints.

    Args:
        sp, ep: predicted start / end in seconds.
        gt_frames: iterable of ground-truth frame indices. Presumably sampled
            at the same *fps*; confirm against the annotation files.
        fps: frames per second used to convert seconds to frame indices.

    Returns:
        IoU in ``[0, 1]``; ``0.0`` when both the prediction and GT are empty.
    """
    start_frame = math.floor(sp * fps)
    end_frame = math.ceil(ep * fps)
    pred_len = max(0, end_frame - start_frame)  # reversed spans predict no frames

    gt_set = set(gt_frames)
    # A GT value matches a predicted frame only if it is integral and in range,
    # mirroring set equality (e.g. 2.0 == 2, but 2.5 matches nothing).
    inter = sum(1 for f in gt_set if start_frame <= f < end_frame and f == math.floor(f))
    union = pred_len + len(gt_set) - inter  # inclusion-exclusion
    return inter / union if union else 0.0


def _find_video_path(video_dir_path, vid, extensions=('mp4', 'mkv', 'webm')):
    """Return the first existing ``<video_dir_path>/<vid>.<ext>`` file, or None."""
    for ext in extensions:
        path = os.path.join(video_dir_path, f"{vid}.{ext}")
        if os.path.isfile(path):
            return path
    return None


def _save_progress(checkpoint_path, checkpoint_dir, processed_items, ious, recall, preds):
    """Write the resumable pickle checkpoint and the readable ``preds.json``."""
    state = {'processed_items': processed_items, 'ious': ious, 'recall': recall, 'preds': preds}
    save_checkpoint(checkpoint_path, state)
    with open(os.path.join(checkpoint_dir, 'preds.json'), 'w') as f:
        json.dump(preds, f, indent=4)


def _print_final_result(ious, recall, thresh):
    """Print mIoU and R@threshold, without dividing by zero when nothing was scored."""
    print('=== final result ===')
    if not ious:
        print('No items were scored.')
        return
    print('mIoU:', sum(ious) / len(ious))
    for th, r in zip(thresh, recall):
        print(f'R@{th}:', r / len(ious))


def process_work_items(work_items, video_dir_path, model_base, device, checkpoint_dir, resume=False):
    """Run grounding inference over *work_items* and accumulate IoU metrics.

    Each item is ``{'vid', 'ann', 'sentence_idx'}``. The code reads
    ``ann['sentences'][sentence_idx]`` as the query and
    ``ann['timestamps'][sentence_idx]`` as the ground truth passed to
    ``compute_reward``. Items whose video file is missing are skipped. Items
    whose inference/scoring raises are reported and skipped; they are not
    marked processed, so a resumed run retries them. An answer with no
    parseable range scores 0.

    Progress (processed ids, IoUs, recall counts, predictions) is saved every
    5 scored items and once more at the end. With ``resume=True`` it is
    restored from ``<checkpoint_dir>/checkpoint.pkl``.

    Returns:
        ``(ious, recall)``: the list of per-item IoUs and an array of hit
        counts at IoU thresholds 0.3 / 0.5 / 0.7. Divide by ``len(ious)`` to
        get recall rates.
    """
    thresh = np.array([0.3, 0.5, 0.7])
    ious = []
    recall = np.array([0, 0, 0])
    preds = []
    processed_items = set()

    checkpoint_path = get_checkpoint_path(checkpoint_dir)
    if resume and os.path.exists(checkpoint_path):
        checkpoint = load_checkpoint(checkpoint_path)
        processed_items = checkpoint['processed_items']
        ious = checkpoint['ious']
        recall = checkpoint['recall']
        # Older checkpoints did not store predictions; start with an empty list then.
        preds = checkpoint.get('preds', [])
        print(f"Resuming from checkpoint with {len(processed_items)} processed items")

    model, processor = setup_model(model_base, device)

    # processed_items is empty unless resuming, so a fresh run keeps every item.
    remaining_items = [item for item in work_items
                       if f"{item['vid']}_{item['sentence_idx']}" not in processed_items]

    if not remaining_items:
        print("All items already processed")
        return ious, recall

    print(f"Processing {len(remaining_items)} out of {len(work_items)} items")

    pbar = tqdm(remaining_items)
    for idx, item in enumerate(pbar):
        vid = item['vid']
        ann = item['ann']
        sentence_idx = item['sentence_idx']
        item_id = f"{vid}_{sentence_idx}"
        sentence = ann['sentences'][sentence_idx]

        video_path = _find_video_path(video_dir_path, vid)
        if video_path is None:
            print(f"Video not found for {item_id}; skipping")
            continue

        prompt = GROUND_TEMPLATE.replace('[EVENT]', sentence)
        try:
            ans = inference(video_path, prompt, model, processor, device=device)
            sp, ep = parse_timestamp_output(ans)
            gt = ann['timestamps'][sentence_idx]
            print(f"Parsed times: {sp}, {ep}")
            print(f"Ground truth: {gt}")
            print('-' * 50)
            # Default to 0 so an unparseable answer counts as a miss. Previously
            # `iou` was unset here: the first such item raised NameError and
            # later ones silently reused the previous item's IoU.
            iou = 0.0
            if sp is not None and ep is not None:
                iou = compute_reward(sp, ep, gt)
        except Exception as e:
            print(f"Error processing {item_id}: {e}")
            continue

        recall += (thresh <= iou)
        ious.append(iou)
        processed_items.add(item_id)
        preds.append({
            'video_id': vid,
            'question': sentence,
            'response': ans,
            'pred': [sp, ep],
            'gt': gt,
            'iou': iou,
        })

        if (idx + 1) % 5 == 0:
            _save_progress(checkpoint_path, checkpoint_dir, processed_items, ious, recall, preds)

        miou = sum(ious) / len(ious)
        recall_str = str(recall / len(ious))
        pbar.set_postfix({"mIoU": miou, 'recall': recall_str})

        print(ans)

    # Final save so items after the last multiple of 5 are never lost, even if
    # the very last item failed.
    _save_progress(checkpoint_path, checkpoint_dir, processed_items, ious, recall, preds)
    _print_final_result(ious, recall, thresh)
    return ious, recall


def evaluate(data, args):
    """Evaluate ``args.dataset`` on the annotations in *data*.

    Videos are read from the dataset's ``video_path`` in ``DATASETS``; the
    annotations are expanded into shuffled work items and handed to
    ``process_work_items``.

    Returns:
        ``(ious, recall)`` as produced by ``process_work_items``.
    """
    video_dir_path = DATASETS[args.dataset]['video_path']
    items = create_work_items(data)
    ious, recall = process_work_items(
        items,
        video_dir_path,
        args.model_base,
        args.device,
        args.checkpoint_dir,
        args.resume,
    )
    return ious, recall


if __name__ == '__main__':
    args = get_args()
    # Validate with explicit errors: `assert` is stripped under `python -O`,
    # which would turn a bad --dataset/--split into a confusing KeyError.
    if args.dataset not in DATASETS:
        raise SystemExit(
            f"Unknown dataset {args.dataset!r}; choose one of: {', '.join(map(str, DATASETS))}")
    dataset = DATASETS[args.dataset]
    if args.split not in dataset['splits']:
        raise SystemExit(
            f"Unknown split {args.split!r} for dataset {args.dataset!r}; "
            f"choose one of: {', '.join(map(str, dataset['splits']))}")

    print('evaluate', args.dataset, args.split)

    # Load the annotation file (JSON is UTF-8 by specification).
    with open(dataset['splits'][args.split]['annotation_file'], encoding='utf-8') as f:
        data = json.load(f)

    evaluate(data, args)
