import ast
from data_configs import DATASETS
import argparse
import numpy as np
import json
from tqdm import tqdm
import os
import re
import pickle
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info, smart_resize
import random
from PIL import Image


# @title Parsing JSON output
def parse_json(json_output):
    """Strip a Markdown ```json fence from model output.

    If some line is exactly ```json, return the text after that line up to
    (but not including) the next ``` marker; the first such line wins.
    Otherwise the input is returned unchanged.
    """
    lines = json_output.splitlines()
    try:
        fence_idx = lines.index("```json")
    except ValueError:
        # No opening fence: assume the payload is already bare.
        return json_output
    after_fence = "\n".join(lines[fence_idx + 1:])
    return after_fence.split("```", 1)[0]


def calculate_iou(box1, box2):
    """Return the intersection-over-union of two axis-aligned boxes.

    Each box is an ``[x1, y1, x2, y2]`` sequence; every coordinate is
    truncated with ``int()`` before comparison.

    Returns 0 (after printing a message) when either box cannot be
    interpreted: not iterable, non-numeric, non-finite or fewer than four
    values. Also returns 0 when the union area is 0.
    """
    try:
        box1 = [int(coordinate) for coordinate in box1]
        box2 = [int(coordinate) for coordinate in box2]

        x1_inter = max(box1[0], box2[0])
        y1_inter = max(box1[1], box2[1])
        x2_inter = min(box1[2], box2[2])
        y2_inter = min(box1[3], box2[3])
    except (TypeError, ValueError, IndexError, OverflowError):
        # Only malformed-box errors are treated as "no match"; a bare except
        # here would also swallow KeyboardInterrupt/SystemExit.
        print("Invalid box coordinates:", box1, box2)
        return 0

    inter_area = max(0, x2_inter - x1_inter) * max(0, y2_inter - y1_inter)

    box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
    box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])

    union_area = box1_area + box2_area - inter_area

    return inter_area / union_area if union_area != 0 else 0


# refer to LEO: embodied-generalist
# https://github.com/embodied-generalist/embodied-generalist/blob/477dc44b8b18dbfbe6823c307436d896ec8b062e/evaluator/scanqa_eval.py#L41-L50
def answer_match(pred, gts):
    """Score a predicted answer against a collection of reference answers.

    Returns an ``(em, refined_em)`` pair:
      * ``(0, 0)`` for an empty prediction;
      * ``(1, 1)`` when ``pred`` is one of ``gts`` exactly;
      * ``(0, 1)`` when, ignoring all whitespace, ``pred`` is contained in
        some reference or some reference is contained in ``pred``;
      * ``(0, 0)`` otherwise.
    """
    if len(pred) == 0:
        return 0, 0
    if pred in gts:
        return 1, 1

    def _squash(text):
        # Remove every whitespace run so "a b" and "ab" compare equal.
        return ''.join(text.split())

    loosely_matches = any(
        _squash(pred) in _squash(gt) or _squash(gt) in _squash(pred)
        for gt in gts
    )
    return (0, 1) if loosely_matches else (0, 0)


# refer to LEO: embodied-generalist
# https://github.com/embodied-generalist/embodied-generalist/blob/477dc44b8b18dbfbe6823c307436d896ec8b062e/data/data_utils.py#L322-L379
def clean_answer(data):
    """Normalize a free-form answer string before exact-match scoring.

    Port of LEO's ``clean_answer`` (linked above). Steps, in order:
    lower-case and tidy spaces; drop every character that is not an ASCII
    letter, digit, whitespace, comma, apostrophe, hyphen or colon; fix a few
    common typos; spell out selected numbers (and map "none" to "zero");
    strip a single trailing digit glued to a word ("mat2" -> "mat"); drop
    "a"/"an"/"the" in front of a word; and map "backwards" to "backward".

    Substitutions run in exactly the reference order so that scores stay
    comparable with LEO.
    """
    data = data.lower()
    data = re.sub(r'[ ]+$', '', data)
    data = re.sub(r'^[ ]+', '', data)
    data = re.sub(r' {2,}', ' ', data)

    # Raw strings: the non-raw forms of '\.', '\s' and '\-' are invalid escape
    # sequences (SyntaxWarning on Python 3.12+). The compiled patterns are the same.
    data = re.sub(r'\.[ ]{2,}', '. ', data)
    data = re.sub(r"[^a-zA-Z0-9,'\s\-:]+", '', data)
    # NOTE(parity): the filter above already deletes 'ç' and '’', so the next
    # two substitutions never fire. Likewise, the '. ' rule above cannot match
    # once runs of spaces have been collapsed. They are kept as-is to
    # reproduce the LEO reference output.
    data = re.sub('ç', 'c', data)
    data = re.sub('’', "'", data)

    typo_fixes = (
        (r'\bletf\b', 'left'),
        (r'\blet\b', 'left'),
        (r'\btehre\b', 'there'),
        (r'\brigth\b', 'right'),
        (r'\brght\b', 'right'),
        (r'\bbehine\b', 'behind'),
        (r'\btv\b', 'TV'),
        (r'\bchai\b', 'chair'),
        (r'\bwasing\b', 'washing'),
        (r'\bwaslked\b', 'walked'),
        (r'\boclock\b', "o'clock"),
        (r"\bo'[ ]+clock\b", "o'clock"),
    )
    for pattern, replacement in typo_fixes:
        data = re.sub(pattern, replacement, data)

    # Digit to word, only for answers. The list is the reference's as-is
    # (21 and 22 are absent there too).
    number_words = (
        ('0', 'zero'), ('none', 'zero'),
        ('1', 'one'), ('2', 'two'), ('3', 'three'), ('4', 'four'),
        ('5', 'five'), ('6', 'six'), ('7', 'seven'), ('8', 'eight'),
        ('9', 'nine'), ('10', 'ten'), ('11', 'eleven'), ('12', 'twelve'),
        ('13', 'thirteen'), ('14', 'fourteen'), ('15', 'fifteen'),
        ('16', 'sixteen'), ('17', 'seventeen'), ('18', 'eighteen'),
        ('19', 'nineteen'), ('20', 'twenty'), ('23', 'twenty-three'),
    )
    for token, word in number_words:
        data = re.sub(rf'\b{token}\b', word, data)

    # misc: "no1", "mat2", etc. -> drop the trailing digit
    data = re.sub(r'\b([a-zA-Z]+)([0-9])\b', r'\g<1>', data)
    for article in ('a', 'an', 'the'):
        data = re.sub(rf'\b{article}\b ([a-zA-Z]+)', r'\g<1>', data)

    data = re.sub(r'\bbackwards\b', 'backward', data)

    return data


# Memo used by cached_process_vision_info: maps the key
# f"{video_path}_{return_video_kwargs}" to the process_vision_info result.
# Nothing in this file evicts entries, so it grows with the number of distinct keys.
VIDEO_INFO_CACHE = {}


def get_args():
    """Define the command-line interface and return the parsed namespace.

    Note: ``--batch_size`` is not read anywhere in this file, and ``--resume``
    is forwarded to ``process_work_items`` but not acted on there.
    """
    parser = argparse.ArgumentParser(
        description='Evaluation for training-free video temporal grounding (Single GPU Version)')
    add = parser.add_argument
    add('--dataset', type=str, default='charades', help='Specify the dataset.')
    add('--split', type=str, default='default', help='Specify the split.')
    add("--model_base", type=str, default="/path/to/qwen-model")
    add("--batch_size", type=int, default=1, help="Batch size")
    add("--checkpoint_dir", type=str, default="checkpoints", help="Directory to save checkpoints")
    add("--resume", action="store_true", help="Resume from checkpoint")
    add("--device", type=str, default="cuda:0", help="GPU device to use")
    parsed = parser.parse_args()
    return parsed


def calc_iou(candidates, gt):
    """Return the 1-D (temporal) IoU of each candidate span against ``gt``.

    Args:
        candidates: ``(N, 2)`` array-like of ``[start, end]`` spans. A single
            ``[start, end]`` pair is also accepted and treated as ``N == 1``.
        gt: ``[start, end]`` reference span.

    Returns:
        Float array of shape ``(N,)``. Spans whose combined extent is empty
        (zero or negative union) get 0 instead of NaN, matching
        ``calculate_iou``'s zero-union convention.
    """
    candidates = np.atleast_2d(np.asarray(candidates, dtype=float))
    start, end = candidates[:, 0], candidates[:, 1]
    s, e = gt[0], gt[1]
    inter = np.minimum(end, e) - np.maximum(start, s)
    # Hull of the two spans; equals the true union whenever they overlap, and
    # for disjoint spans the clipped intersection is 0 anyway.
    union = np.maximum(end, e) - np.minimum(start, s)
    return np.divide(inter.clip(min=0), union, out=np.zeros_like(inter), where=union > 0)


def cached_process_vision_info(messages, return_video_kwargs=False):
    """Memoizing wrapper around ``process_vision_info`` for video messages.

    The cache key is the first ``'video'`` entry found in the messages plus
    ``return_video_kwargs``. Messages without any video (e.g. image-only
    prompts such as the ones built by ``inference``) bypass the cache
    entirely.

    NOTE(review): the key ignores every other field of the message (e.g.
    pixel budgets or fps), so two calls that differ only in those share one
    entry. Confirm that is acceptable before reusing this with varied settings.
    """
    # First video entry across all messages (stops at the first hit instead of
    # letting a later message silently override it).
    video_path = next(
        (content['video']
         for msg in messages
         for content in msg.get('content', [])
         if isinstance(content, dict) and 'video' in content),
        None,
    )

    if video_path is None:
        # Nothing video-keyed to memoize. Storing the result under a shared
        # "None_..." key would only pin the latest inputs in memory.
        return process_vision_info(messages, return_video_kwargs=return_video_kwargs)

    cache_key = f"{video_path}_{return_video_kwargs}"
    if cache_key not in VIDEO_INFO_CACHE:
        VIDEO_INFO_CACHE[cache_key] = process_vision_info(messages, return_video_kwargs=return_video_kwargs)
    return VIDEO_INFO_CACHE[cache_key]


def inference(video_path, prompt, model, processor, max_new_tokens=2048, device="cuda:0"):
    """Send one text prompt plus one image to the model and decode its reply.

    Despite the parameter name, ``video_path`` is placed in an ``"image"``
    content entry, so it is treated as an image here.

    Returns:
        ``(output_text, input_height, input_width)``. The two sizes are the
        height/width entries of ``image_grid_thw[0]`` multiplied by 14. This is
        presumably the resolution the vision encoder actually saw (14 = patch
        size); TODO confirm against the processor config. They come from
        tensor indexing, so they may be tensors rather than Python ints.
    """
    text_part = {"type": "text", "text": prompt}
    image_part = {
        "type": "image",
        "image": video_path,
        "total_pixels": 3584 * 28 * 28,
        "min_pixels": 16 * 28 * 28,
    }
    messages = [{"role": "user", "content": [text_part, image_part]}]

    chat_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs, video_kwargs = cached_process_vision_info(messages, return_video_kwargs=True)

    inputs = processor(
        text=[chat_text],
        images=image_inputs,
        videos=video_inputs,
        fps=video_kwargs['fps'],
        padding=True,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # Each returned sequence starts with its prompt tokens; keep only the rest.
    new_tokens = [seq[len(inputs.input_ids[i]):] for i, seq in enumerate(output_ids)]
    replies = processor.batch_decode(new_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True)

    grid_thw = inputs['image_grid_thw'][0]
    return replies[0], grid_thw[1] * 14, grid_thw[2] * 14


def parse_timestamp_output(output_string):
    """Extract a ``(start, end)`` time span in seconds from model output.

    A span is ``"<number> to <number>"`` or ``"<number> and <number>"``.
    If the output has an ``<answer>...</answer>`` section (the last one,
    which may span several lines) containing a span, its last span is used.
    Otherwise the last span anywhere in the output is used.

    Returns:
        ``(start, end)`` as floats, or ``(None, None)`` if no span is found.
    """
    span_pattern = r"(\d+\.?\d*) (?:to|and) (\d+\.?\d*)"

    search_texts = []
    answers = re.findall(r"<answer>(.*?)</answer>", output_string, re.DOTALL)
    if answers:
        # Prefer the explicit answer over timestamps mentioned while reasoning.
        search_texts.append(answers[-1])
    search_texts.append(output_string)

    for text in search_texts:
        spans = re.findall(span_pattern, text)
        if spans:
            start, end = spans[-1]
            # The pattern only admits digit/dot strings, so float() cannot fail.
            return float(start), float(end)
    return None, None


def _extract_bbox_2d(text):
    """Return the ``bbox_2d`` value of the first parsed entry in *text*, or None.

    *text* is expected to hold a list of dicts such as
    ``[{"bbox_2d": [x1, y1, x2, y2], "label": "..."}]``; a single bare dict is
    accepted too. Each candidate is tried with ``ast.literal_eval`` (Python
    literal syntax) and then ``json.loads`` (handles JSON true/false/null).
    """
    candidates = [text]
    # Recovery for truncated or trailing-garbage generations: keep everything
    # up to the last '"}' (the end of a dict whose last value is a string,
    # e.g. a label) and close the list.
    cut = text.rfind('"}')
    if cut != -1:
        candidates.append(text[:cut + len('"}')] + "]")

    for candidate in candidates:
        for loader in (ast.literal_eval, json.loads):
            try:
                parsed = loader(candidate)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                continue
            if isinstance(parsed, dict):
                parsed = [parsed]
            try:
                return parsed[0]['bbox_2d']
            except (IndexError, KeyError, TypeError):
                continue
    return None


def parse_bbox(content, input_width, input_height, width, height):
    """Parse the model's bbox answer and map it to original-image pixels.

    Args:
        content: raw model output; an optional ```json fence is removed with
            ``parse_json`` before parsing.
        input_width, input_height: size of the image as fed to the model,
            i.e. the coordinate space of the predicted ``bbox_2d``.
        width, height: size of the original image.

    Returns:
        ``[x1, y1, x2, y2]`` ints in original-image pixels, with x1 <= x2 and
        y1 <= y2. Returns ``[0, 0, 0, 0]`` (a zero-area box, so IoU 0) when no
        box can be parsed or its coordinates cannot be converted. Raw
        model-space coordinates are never returned, because they would be
        scored against ground truth in the wrong coordinate space.
    """
    bbox = _extract_bbox_2d(parse_json(content))
    if bbox is None:
        return [0, 0, 0, 0]

    try:
        x1, y1, x2, y2 = (float(v) for v in bbox[:4])
        # Rescale from the model's input resolution to the original image.
        abs_x1 = int(x1 / input_width * width)
        abs_y1 = int(y1 / input_height * height)
        abs_x2 = int(x2 / input_width * width)
        abs_y2 = int(y2 / input_height * height)
    except Exception as e:  # best effort: one malformed box must not abort the run
        print(f"Error converting coordinates: {e}")
        return [0, 0, 0, 0]

    # The model sometimes emits corners in the wrong order; normalize.
    abs_x1, abs_x2 = sorted((abs_x1, abs_x2))
    abs_y1, abs_y2 = sorted((abs_y1, abs_y2))
    return [abs_x1, abs_y1, abs_x2, abs_y2]


# GROUND_TEMPLATE = """To accurately pinpoint the event "[EVENT]" in the video, determine the precise time period of the event.

# Output your thought process within the <think> </think> tags, including analysis with either specific timestamps (xx.xx) or time ranges (xx.xx to xx.xx) in <timestep> </timestep> tags.

# Then, provide the start and end times (in seconds, precise to two decimal places) in the format "start time to end time" within the <answer> </answer> tags. For example: "12.54 to 17.83"."""

# GROUND_TEMPLATE = """To accurately pinpoint the object described as "[EVENT]" in the video, determine the precise time period of the occurance of the object.

# Output your thought process within the <think> </think> tags, including analysis with either specific timestamps (xx.x) or time ranges (x.xx to xx.x) in <timestep> </timestep> tags.

# Then, provide the start and end times (in seconds, precise to one decimal places) in the format "start time to end time" within the <answer> </answer> tags. For example: "12.5 to 17.0"."""

# GROUND_TEMPLATE = """To accurately pinpoint the event "[EVENT]" in the video, determine the precise time period of the event.

# Provide the start and end times (in seconds, precise to two decimal places) in the format "start time to end time" within the <answer> </answer> tags. For example: "12.54 to 17.83"."""

# Prompt sent for every work item; process_work_items replaces the literal
# '[EVENT]' with the item's description sentence. The commented-out variants
# below are alternative prompts (without the <think> instruction, or for
# functional/affordance descriptions).
QUESTION_TEMPLATE = """Outline the object according to the description "[EVENT]". Output the thinking process in <think> </think>. Outline the bbox_2d coordinates in JSON format."""
# QUESTION_TEMPLATE = """Outline the object according to the description "[EVENT]". Outline the bbox_2d coordinates in JSON format."""

# QUESTION_TEMPLATE = """Outline the functional interactive element referred to by the task description "[EVENT]". (e.g., a button affords pressing, a drawer knob affords pulling). Output the thinking process in <think> </think>. Outline the bbox_2d coordinates in JSON format."""
# QUESTION_TEMPLATE = """Outline the functional interactive element referred to by the task description "[EVENT]". (e.g., a button affords pressing, a drawer knob affords pulling). Outline the bbox_2d coordinates in JSON format."""


def create_work_items(data, shuffle=True, seed=None):
    """Flatten annotations into one work item per (id, sentence) pair.

    Args:
        data: mapping ``id -> annotation``; each annotation must have a
            ``'sentences'`` list.
        shuffle: shuffle the resulting items (the default, as before).
        seed: if given, shuffle with a private ``random.Random(seed)`` so the
            order is reproducible across runs. If ``None``, the global
            ``random`` state is used.

    Returns:
        List of ``{'vid', 'ann', 'sentence_idx'}`` dicts. ``ann`` is the
        original annotation object, not a copy.
    """
    work_items = [
        {'vid': vid, 'ann': ann, 'sentence_idx': i}
        for vid, ann in data.items()
        for i in range(len(ann['sentences']))
    ]
    if shuffle:
        rng = random if seed is None else random.Random(seed)
        rng.shuffle(work_items)
    return work_items


def setup_model(model_base, device):
    """Load the Qwen2.5-VL model and its processor from ``model_base``.

    The model is loaded in bfloat16 with ``attn_implementation=
    "flash_attention_2"`` and ``use_sliding_window=True``, and is placed via
    ``device_map=device``.

    Returns:
        ``(model, processor)``.
    """
    print(f"Setting up model on device {device}")
    load_kwargs = dict(
        torch_dtype=torch.bfloat16,
        use_sliding_window=True,
        attn_implementation="flash_attention_2",
        device_map=device,
    )
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_base, **load_kwargs)
    return model, AutoProcessor.from_pretrained(model_base)


def get_checkpoint_path(checkpoint_dir):
    """Ensure ``checkpoint_dir`` exists and return the checkpoint file path in it."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_file = os.path.join(checkpoint_dir, "checkpoint.pkl")
    return checkpoint_file


def load_checkpoint(checkpoint_path):
    """Load a pickled evaluation state, falling back to a fresh one.

    Best effort: if the file is missing, or fails to open/unpickle (the error
    is printed), ``{'processed_items': set(), 'ious': [], 'recall':
    np.array([0, 0, 0])}`` is returned instead.

    NOTE(review): ``pickle.load`` can execute arbitrary code. Only point this
    at checkpoints written by this script.
    """
    if os.path.exists(checkpoint_path):
        try:
            with open(checkpoint_path, 'rb') as handle:
                state = pickle.load(handle)
        except Exception as err:
            print(f"Error loading checkpoint: {err}")
        else:
            return state
    return {'processed_items': set(), 'ious': [], 'recall': np.array([0, 0, 0])}


def save_checkpoint(checkpoint_path, state):
    """Pickle ``state`` to ``checkpoint_path`` atomically.

    The state is first written to ``<checkpoint_path>.tmp`` in the same
    directory and then moved over the target with ``os.replace``. If pickling
    or writing fails (or the process is interrupted), the previous checkpoint
    is left intact. Otherwise a truncated file would remain, which
    ``load_checkpoint`` would discard, silently restarting from scratch.
    Any error is re-raised after the temporary file is removed.
    """
    tmp_path = f"{checkpoint_path}.tmp"
    try:
        with open(tmp_path, 'wb') as f:
            pickle.dump(state, f)
        os.replace(tmp_path, checkpoint_path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def process_work_items(work_items, model_base, device, checkpoint_dir, resume=False):
    """Run bbox-grounding inference over ``work_items`` and score each answer.

    For each item: build the prompt from QUESTION_TEMPLATE, query the model
    on the item's image, rescale the predicted box to the original image size
    and compute its IoU with the ground-truth box. An item counts towards
    ``em`` when IoU >= 0.25 and towards ``em_refined`` when IoU >= 0.5
    (shown as Acc@25 / Acc@50 on the progress bar).

    After every item, all predictions so far are rewritten to
    ``<checkpoint_dir>/preds.json``, so partial results survive a crash.

    Args:
        work_items: dicts with ``vid``, ``ann`` and ``sentence_idx`` keys (as
            built by ``create_work_items``). ``ann`` must provide
            ``sentences``, ``image_path`` and ``bbox_2d`` lists indexed by
            ``sentence_idx``.
        model_base: model path passed to ``setup_model``.
        device: device string passed to ``setup_model`` and ``inference``.
        checkpoint_dir: output directory for ``preds.json``; created if missing.
        resume: accepted for CLI compatibility but currently unused.
            TODO: resuming from a checkpoint is not implemented.

    Returns:
        ``(em, em_refined)``: per-item 0/1 flags in processing order.
    """
    em = []
    em_refined = []
    preds = []

    # Create the output directory up front (before the expensive model load):
    # preds.json is written into it on every iteration.
    os.makedirs(checkpoint_dir, exist_ok=True)
    preds_path = os.path.join(checkpoint_dir, 'preds.json')

    model, processor = setup_model(model_base, device)

    pbar = tqdm(work_items)
    for item in pbar:
        ann = item['ann']
        sentence_idx = item['sentence_idx']
        question = ann['sentences'][sentence_idx]
        image_path = ann['image_path'][sentence_idx]
        sol = ann['bbox_2d'][sentence_idx]

        prompt = QUESTION_TEMPLATE.replace('[EVENT]', question)

        # Only the original size is needed. The context manager releases the
        # file handle that Image.open keeps open (otherwise one leaks per item).
        with Image.open(image_path) as image:
            width, height = image.size

        content, input_height, input_width = inference(image_path, prompt, model, processor, device=device)
        print('[question]', question)
        print('[answer]', content)
        print('[gt]', sol)

        pred_box = parse_bbox(content, input_width, input_height, width, height)

        reward = calculate_iou(pred_box, sol)
        print('reward:', reward)

        em_flag = 1 if reward >= 0.25 else 0
        em_refined_flag = 1 if reward >= 0.5 else 0
        em.append(em_flag)
        em_refined.append(em_refined_flag)

        # Both lists are non-empty here, so no zero-length guard is needed.
        pbar.set_postfix({"Acc@25": sum(em) / len(em), "Acc@50": sum(em_refined) / len(em_refined)})

        preds.append({
            'image': image_path,
            'question': question,
            'response': content,
            'pred': pred_box,
            'gt': sol,
            'reward': reward,
            'em': em_flag,
            'em_refined': em_refined_flag,
        })

        with open(preds_path, 'w') as f:
            json.dump(preds, f, indent=4)

    print('=== final result ===')
    # Guard against an empty work list (previously a ZeroDivisionError).
    print('EM:', sum(em) / len(em) if em else 0)
    print('EM_refined:', sum(em_refined) / len(em_refined) if em_refined else 0)

    return em, em_refined


def evaluate(data, args):
    """Evaluate every (id, sentence) pair in ``data``.

    Reads ``dataset``, ``model_base``, ``device``, ``checkpoint_dir`` and
    ``resume`` from ``args`` (the ``get_args`` namespace) and returns the
    ``(em, em_refined)`` flag lists produced by ``process_work_items``.
    """
    # Not used further; the lookup raises KeyError for an unknown dataset.
    _dataset_config = DATASETS[args.dataset]

    em, em_refined = process_work_items(
        create_work_items(data),
        args.model_base,
        args.device,
        args.checkpoint_dir,
        args.resume,
    )
    return em, em_refined


if __name__ == '__main__':
    args = get_args()

    # Validate CLI choices explicitly: `assert` statements are stripped when
    # Python runs with -O, which would turn a bad name into a later KeyError.
    if args.dataset not in DATASETS:
        raise SystemExit(f"Unknown dataset {args.dataset!r}; available: {list(DATASETS)}")
    dataset = DATASETS[args.dataset]
    if args.split not in dataset['splits']:
        raise SystemExit(
            f"Unknown split {args.split!r} for dataset {args.dataset!r}; "
            f"available: {list(dataset['splits'])}")

    print('evaluate', args.dataset, args.split)

    # load data
    annotation_file = dataset['splits'][args.split]['annotation_file']
    with open(annotation_file, encoding='utf-8') as f:
        data = json.load(f)

    print(f"Loaded {len(data)} items from {annotation_file}")

    evaluate(data, args)
