import json
from tqdm import tqdm
import re
import os
from pprint import pprint
import random
import cv2
import numpy as np

import re


import cv2
import numpy as np

def visualize_detections(image, boxes):
    """
    Render detection boxes onto a BGR copy of the given image.

    Args:
        image: Input picture. It is passed through ``np.array`` and its
            channels are swapped from RGB to BGR (the original code comment
            describes it as a PIL image).
        boxes: Iterable of ``(x1, y1, x2, y2)`` boxes; every coordinate is
            truncated with ``int`` before drawing.

    Returns:
        The OpenCV array with one green, 2-pixel-wide rectangle per box.
    """
    box_color = (0, 255, 0)
    line_width = 2

    # OpenCV draws in BGR order, so convert the RGB input first.
    canvas = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    for detection in boxes:
        left, top, right, bottom = map(int, detection)
        cv2.rectangle(canvas, (left, top), (right, bottom), box_color, line_width)

    return canvas

def extract_fields(text):
    """
    Decode the first ```json fenced block found in ``text``.

    Returns:
        The parsed JSON value, or None when no fenced block exists or when
        its content is not valid JSON (a debug line is printed in the latter
        case).
    """
    fenced = re.search(r'```json\s*(.*?)\s*```', text, re.DOTALL)
    if fenced is None:
        return None

    try:
        return json.loads(fenced.group(1))
    except json.JSONDecodeError:
        print("## DEBUG: Invalid JSON format")
        return None

def extract_fields_2(text):
    """
    Extract the last JSON object embedded in free-form ``text``.

    The previous implementation cut candidates with the non-greedy regex
    ``\\{.*?\\}``, which stops at the first closing brace. Nested objects
    (``{"a": {"b": 1}}``) and braces inside string values were therefore
    truncated and rejected as invalid. This version walks the text and lets
    the JSON decoder determine where each object really ends.

    Args:
        text: Arbitrary text that may contain one or more JSON objects.

    Returns:
        The last top-level JSON object as a dict, or None when
        * the text contains no ``{...}`` span at all (silent), or
        * no object can be decoded, or an undecodable ``{...}`` span follows
          the last decodable object ("Invalid JSON format" is printed) --
          the final block is the one that counts, as before.
    """
    brace_block = re.compile(r'\{.*?\}', re.DOTALL)
    if not brace_block.search(text):
        return None

    decoder = json.JSONDecoder()
    json_data = None
    parsed_end = 0  # index just past the last successfully decoded object
    cursor = 0
    while True:
        start = text.find('{', cursor)
        if start == -1:
            break
        try:
            # raw_decode tolerates trailing text and reports how much it used.
            candidate, consumed = decoder.raw_decode(text[start:])
        except json.JSONDecodeError:
            cursor = start + 1
            continue
        json_data = candidate
        parsed_end = start + consumed
        # Skip the whole object so nested objects are not reported separately.
        cursor = parsed_end

    # A trailing brace block that did not decode means the final answer is
    # malformed; do not fall back to an earlier object.
    if json_data is None or brace_block.search(text, parsed_end):
        print("Invalid JSON format")
        return None
    return json_data

def get_img_crop(img, bbox):
    """
    Cut the region described by ``bbox`` out of ``img``.

    Args:
        img: Either a NumPy array (indexed ``[row, col]``) or any object
            exposing ``size`` as ``(width, height)`` and a ``crop`` method,
            such as a PIL image.
        bbox: ``[x1, y1, x2, y2]``; the bottom-right corner is inclusive.

    Returns:
        The array slice, or whatever ``img.crop`` returns.
    """
    left, top, right, bottom = bbox

    # Clamp the top-left corner to the image origin.
    left, top = max(int(left), 0), max(int(top), 0)

    is_array = isinstance(img, np.ndarray)
    if is_array:
        height, width = img.shape[0], img.shape[1]
    else:
        width, height = img.size

    # Clamp the bottom-right corner to the last valid pixel.
    right = min(width - 1, int(right))
    bottom = min(height - 1, int(bottom))

    if is_array:
        return img[top:bottom + 1, left:right + 1]
    # PIL-style crop expects (left, top, right, bottom) with exclusive ends.
    return img.crop((left, top, right + 1, bottom + 1))

def iou(box1, box2):
    """
    Compute the intersection-over-union of two axis-aligned boxes.

    Both boxes are ``[x1, y1, x2, y2]`` and a box's area is
    ``(x2 - x1) * (y2 - y1)``; the intersection uses the same convention.

    Two defects of the earlier version are fixed here:
    * a zero union (e.g. two ``[0, 0, 0, 0]`` boxes) raised
      ZeroDivisionError; it now yields 0.0;
    * overlaps no wider or taller than one unit were counted as empty, so
      identical 1-pixel boxes (or boxes in normalized 0-1 coordinates) scored
      0. Any strictly positive overlap now counts.

    Returns:
        The IoU as a float; 0.0 when the boxes do not overlap or the union
        is not positive.
    """
    inter_w = min(box1[2], box2[2]) - max(box1[0], box2[0])
    inter_h = min(box1[3], box2[3]) - max(box1[1], box2[1])
    inter = inter_w * inter_h if inter_w > 0 and inter_h > 0 else 0

    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union = area1 + area2 - inter

    # Degenerate (zero-area) inputs have no meaningful ratio.
    if union <= 0:
        return 0.0
    return float(inter) / union

def extract_bbox_answer(content):
    """
    Pull the first ``[a, b, c, d]`` numeric quadruple out of ``content``.

    Numbers may be negative, may contain a decimal point and may be padded
    with whitespace.

    Returns:
        A list of four floats, or ``[0, 0, 0, 0]`` when nothing matches.
    """
    bbox_pattern = r'\[(\s*-?\d*\.?\d+\s*),\s*(\s*-?\d*\.?\d+\s*),\s*(\s*-?\d*\.?\d+\s*),\s*(\s*-?\d*\.?\d+\s*)\]'
    found = re.search(bbox_pattern, content)
    if found is None:
        return [0, 0, 0, 0]
    # float() ignores the surrounding whitespace captured by each group.
    return [float(number) for number in found.groups()]

def extract_bbox_answer_rl(content):
    """
    Read an integer bbox from the ``<answer>...</answer>`` section of ``content``.

    The answer body must contain a ``{ ... [a, b, c, d] ... }`` structure with
    non-negative integer coordinates. Because the leading ``.*`` is greedy,
    the last such quadruple inside the braces is the one returned.

    Returns:
        ``[x1, y1, x2, y2]`` as ints, or ``[0, 0, 0, 0]`` if the tags or the
        bbox are missing.
    """
    answer = re.search(r'<answer>(.*?)</answer>', content, re.DOTALL)
    if answer is None:
        return [0, 0, 0, 0]

    inner = answer.group(1).strip()
    coords = re.search(r'\{.*\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)]\s*.*\}', inner, re.DOTALL)
    if coords is None:
        return [0, 0, 0, 0]

    return [int(digits) for digits in coords.groups()]


def has_overlap(box1, box2, threshold=0.01):
    """Return True when the IoU of the two bboxes is strictly above ``threshold``."""
    return threshold < iou(box1, box2)

def bbox_transform(bbox, w, h, w1, h1):
    """
    Rescale a bbox from one image size to another.

    Args:
        bbox: ``[x1, y1, x2, y2]`` in the original image.
        w, h: Width and height of the original image.
        w1, h1: Width and height of the target image.

    Returns:
        ``[x1, y1, x2, y2]`` in the target image, each value truncated to int.
    """
    def rescale(value, new_extent, old_extent):
        # Multiply before dividing, then truncate toward zero.
        return int(value * new_extent / old_extent)

    left, top, right, bottom = bbox
    return [
        rescale(left, w1, w),
        rescale(top, h1, h),
        rescale(right, w1, w),
        rescale(bottom, h1, h),
    ]

def merge_bboxes(boxes):
    """
    Build the smallest box enclosing every non-empty input box.

    Boxes equal to ``[0, 0, 0, 0]`` are treated as "no detection" and ignored.

    Returns:
        ``[x_min, y_min, x_max, y_max]`` with every value cast to int, or
        ``[0, 0, 0, 0]`` when ``boxes`` is empty or holds only empty boxes.
    """
    if not boxes:
        return [0, 0, 0, 0]

    kept = [candidate for candidate in boxes if candidate != [0, 0, 0, 0]]
    if not kept:
        return [0, 0, 0, 0]

    lefts = [candidate[0] for candidate in kept]
    tops = [candidate[1] for candidate in kept]
    rights = [candidate[2] for candidate in kept]
    bottoms = [candidate[3] for candidate in kept]

    return [int(min(lefts)), int(min(tops)), int(max(rights)), int(max(bottoms))]

def find_overlapping_boxes(boxes, threshold=0.01):
    """
    Partition box indices into groups connected by pairwise overlap.

    Two boxes are linked when ``has_overlap`` reports an overlap above
    ``threshold``; groups are the transitive closure of that relation, grown
    breadth-first from the lowest unvisited index. Boxes equal to
    ``[0, 0, 0, 0]`` are skipped and never appear in any group.

    Returns:
        A list of index lists; a box that overlaps nothing forms a group of
        its own.
    """
    total = len(boxes)
    seen = set()
    groups = []

    for seed in range(total):
        if seed in seen or boxes[seed] == [0, 0, 0, 0]:
            continue

        seen.add(seed)
        members = [seed]

        # ``members`` doubles as the BFS queue: ``cursor`` marks the next
        # member whose neighbours still have to be explored.
        cursor = 0
        while cursor < len(members):
            current = members[cursor]
            cursor += 1
            for other in range(total):
                if other in seen:
                    continue
                if boxes[other] != [0, 0, 0, 0] and has_overlap(boxes[current], boxes[other], threshold):
                    seen.add(other)
                    members.append(other)

        groups.append(members)

    return groups

def draw_bbox(image_path, model_answers, final_answer, ground_truth, show_img_path):
    """
    Draw predicted and ground-truth boxes on an image and save the result.

    Args:
        image_path: Path of the image to load with ``cv2.imread``.
        model_answers: Sequence of per-inference ``[x1, y1, x2, y2]`` boxes;
            entries equal to ``[0, 0, 0, 0]`` are skipped. Each drawn box is
            blue (BGR ``(255, 0, 0)``), 2 px wide, and labelled with its
            1-based position in the sequence.
        final_answer: Final predicted box, drawn in red with a 5 px border.
        ground_truth: Reference box, drawn in green with a 5 px border.
        show_img_path: Destination path passed to ``cv2.imwrite``.

    Raises:
        FileNotFoundError: If ``cv2.imread`` cannot load ``image_path``
            (it returns None instead of raising).

    All coordinates are truncated to int before drawing, so float boxes are
    accepted for every argument rather than only for ``ground_truth``.
    """
    img = cv2.imread(image_path)
    if img is None:
        # imread signals failure by returning None; fail with a clear message
        # instead of an AttributeError further down.
        raise FileNotFoundError(f"Could not read image: {image_path}")

    # Draw each individual model answer with thin blue lines
    for i, bbox in enumerate(model_answers):
        if bbox != [0, 0, 0, 0]:
            x_min, y_min, x_max, y_max = map(int, bbox)
            cv2.rectangle(img, (x_min, y_min), (x_max, y_max), color=(255, 0, 0), thickness=2)
            # Add small label with inference number
            cv2.putText(img, str(i+1), (x_min, y_min-5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)

    # Draw the final model answer in red
    x_min, y_min, x_max, y_max = map(int, final_answer)
    cv2.rectangle(img, (x_min, y_min), (x_max, y_max), color=(0, 0, 255), thickness=5)

    # Draw ground truth in green
    x_min, y_min, x_max, y_max = map(int, ground_truth)
    cv2.rectangle(img, (x_min, y_min), (x_max, y_max), color=(0, 255, 0), thickness=5)

    print(f"Saving image to: {show_img_path}")
    cv2.imwrite(show_img_path, img)







from collections import Counter
from typing import List, Dict
from difflib import SequenceMatcher

def exact_match(pred: str, targets: List[str]) -> bool:
    """
    Tell whether the prediction equals any ground-truth answer.

    Comparison ignores case and leading/trailing whitespace.
    """
    wanted = pred.lower().strip()
    candidates = {target.lower().strip() for target in targets}
    return wanted in candidates

def partial_match(pred: str, targets: List[str], threshold: float = 0.7) -> bool:
    """
    Tell whether the prediction is similar enough to any ground-truth answer.

    Similarity is ``difflib.SequenceMatcher(...).ratio()`` on the lowercased,
    stripped strings; a target matches when the ratio is at least
    ``threshold``.
    """
    needle = pred.lower().strip()
    for target in targets:
        similarity = SequenceMatcher(None, needle, target.lower().strip()).ratio()
        if similarity >= threshold:
            return True
    return False

def normalize_answer(answer: str) -> str:
    """
    Canonicalize an answer string for lenient comparison.

    Steps, in order: lowercase; replace the standalone words ``a``, ``an``
    and ``the`` with a space; drop every character that is not an ASCII
    letter, a digit or whitespace; collapse whitespace runs to single spaces
    and trim both ends.
    """
    lowered = answer.lower()
    without_articles = re.sub(r'\b(a|an|the)\b', ' ', lowered)
    alnum_only = re.sub(r'[^a-zA-Z0-9\s]', '', without_articles)
    return ' '.join(alnum_only.split())

def normalized_match(pred: str, targets: List[str]) -> bool:
    """
    Tell whether the prediction equals any ground-truth answer after both
    sides have gone through ``normalize_answer``.
    """
    return normalize_answer(pred) in list(map(normalize_answer, targets))

def get_majority_answer(answers: List[str]) -> str:
    """
    Return the most frequent answer after normalization.

    Every answer is normalized with ``normalize_answer`` before counting.
    """
    tally = Counter(normalize_answer(raw) for raw in answers)
    top_entries = tally.most_common(1)
    return top_entries[0][0]

def evaluate_with_majority(pred: str, targets: List[str]) -> float:
    """
    Check the normalized prediction against the majority ground-truth answer.

    The result of the equality test is returned as-is, i.e. a bool.
    """
    consensus = get_majority_answer(targets)
    candidate = normalize_answer(pred)
    return candidate == consensus