from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2VLForConditionalGeneration, AutoProcessor
from typing import Dict, Any, Union
from trl.data_utils import maybe_apply_chat_template
import torch
from src.open_r1.vlm_modules import VLMBaseModule
from openai import OpenAI
import re
import random
import os


class Qwen2VLModule(VLMBaseModule):
    def __init__(self):
        """Initialise the module; all setup is delegated to the parent constructor."""
        super().__init__()

    def get_vlm_key(self):
        return "qwen"

    def get_model_class(self, model_id: str, model_init_kwargs: dict):
        if "Qwen2-VL" in model_id:
            model_cls = Qwen2VLForConditionalGeneration
        elif "Qwen2.5-VL" in model_id:
            model_cls = Qwen2_5_VLForConditionalGeneration
        else:
            raise ValueError(f"Unsupported model: {model_id}")
        return model_cls
    
    def post_model_init(self, model, processing_class):
        pass
    
    def get_processing_class(self):
        """Return the processor class used with this model (``AutoProcessor``)."""
        processor_cls = AutoProcessor
        return processor_cls
    
    def get_vision_modules_keywords(self):  
        return ['visual']
    
    def get_custom_multimodal_keywords(self):
        return ['pixel_values', 'image_grid_thw']

    def get_non_generate_params(self):
        return []
    
    def get_custom_processing_keywords(self):
        return ['max_pixels', 'min_pixels']
    
    def prepare_prompt(self, processing_class, inputs: dict[str, Union[torch.Tensor, Any]]):
        prompts_text = [maybe_apply_chat_template(example, processing_class)["prompt"] for example in inputs]
        return prompts_text
    
    def prepare_model_inputs(self, processing_class, prompts_text, images, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False):
        # FIXME
        # This could only process pure-multimodal or pure-text inputs
        if len(images) > 0:
            prompt_inputs = processing_class(
                text=prompts_text,
                images=images,
                return_tensors=return_tensors,
                padding=padding,
                padding_side=padding_side,
                add_special_tokens=add_special_tokens)
        else:
            prompt_inputs = processing_class(
                text=prompts_text,
                return_tensors=return_tensors,
                padding=padding,
                padding_side=padding_side,
                add_special_tokens=add_special_tokens)
        return prompt_inputs
    
    @staticmethod
    def get_question_template(task_type: str):
        # match task_type:
        if task_type == "rec":
            #return "{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags. Output the final answer in JSON format."
            #return "{Question} Report the bbox coordinates in JSON format."
            return "{Question} First output the thinking process then summarize the answer in <think> </think> tags and then output the final answer in <answer> </answer> tags. Output the final answer in JSON format."
        elif task_type == "ovd":
            return "{Question} First output the thinking process then summarize the answer in <think> </think> tags and then output the final answer in <answer> </answer> tags. Output the final answer in JSON format."

        else:
            #return "{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags. Output the final answer in JSON format."
            #return "{Question} Report the bbox coordinates in JSON format."
            return "{Question} First output the thinking process then summarize the answer in <think> </think> tags and then output the final answer in <answer> </answer> tags. Output the final answer in JSON format."
            #return "{Question} First output the thinking process then summarize the answer and review the question to check the answer in <think> </think> tags. And then output the final answer in <answer> </answer> tags. Output the final answer in JSON format."


    @staticmethod
    def format_reward_rec(completions, **kwargs):
        """Check if the Qwen model output matches a specific format."""
        import re
        pattern = r"<think>.*?</think>\s*<answer>.*?\{.*\[\d+,\s*\d+,\s*\d+,\s*\d+\].*\}.*?</answer>"
        completion_contents = [completion[0]["content"] for completion in completions]
        matches = [re.search(pattern, content, re.DOTALL) is not None for content in completion_contents]
        return [1.0 if match else 0.0 for match in matches]
    
    def format_reward(completions, **kwargs):
        pattern = r"<think>.*?</think>\s*<answer>.*?\[.*?{\"bbox_2d\":\s*\[\s*\d+,\s*\d+,\s*\d+,\s*\d+\s*\]\s*,\s*\"label\":\s*\".*?\"\s*}.*?\].*?</answer>"
        completion_contents = [completion[0]["content"] for completion in completions]
        matches = [re.search(pattern, content, re.DOTALL) is not None for content in completion_contents]
        return [1.0 if match else 0.0 for match in matches]
    
    @staticmethod
    def format_reward_tongyong(completions, **kwargs):
        pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
        completion_contents = [completion[0]["content"] for completion in completions]
        matches = [re.search(pattern, content, re.DOTALL) is not None for content in completion_contents]
        return [0.5 if match else 0.0 for match in matches]

    @staticmethod
    def format_reward_TAC(completions, **kwargs):
        pattern = r"<think>.*?\[\d+,\s*\d+,\s*\d+,\s*\d+\].*?</think>\s*<answer>.*?\{.*\[\d+,\s*\d+,\s*\d+,\s*\d+\].*\}.*?</answer>"
        completion_contents = [completion[0]["content"] for completion in completions]
        matches = [re.search(pattern, content, re.DOTALL) is not None for content in completion_contents]
        return [0.5 if match else 0.0 for match in matches]
    
    

    
    def evaluate_answer_similarity(student_answer, ground_truth):
        """Use llm to evaluate answer similarity."""
        try:
            response = client.chat.completions.create(
                model="qwen2.5:7b",
                messages=[
                    {
                        "role": "user",
                        "content": "You are a evaluation expert. First, analyze the student's response to identify and extract their final answer. Then, compare the extracted answer with the correct solution. Output ONLY '1.0' if the extracted answer matches the correct solution in meaning, or '0.0' if the student's response does not contain a clear or correct answer. No other output is allowed."
                    },
                    {
                        "role": "user",
                        "content": f"Student's response: {student_answer}\nCorrect solution: {ground_truth}\nOutput only 1.0 or 0.0:"
                    }
                ],
                temperature=0
            )
            result = response.choices[0].message.content.strip()
            return float(result)
        
        except Exception as e:
            print(f"Error in GPT evaluation: {e}")
            # If API call fails, fall back to simple text matching
            return 1.0 if student_answer ==ground_truth else 0.0


    @staticmethod
    def iou_reward(completions, solution, **kwargs):
        """Calculate IoU reward between predicted bounding box from Qwen model and ground truth bounding box."""
        import re
        import os
        from datetime import datetime
        def iou(box1, box2):
            inter_x1 = max(box1[0], box2[0])
            inter_y1 = max(box1[1], box2[1])
            inter_x2 = min(box1[2]-1, box2[2]-1)
            inter_y2 = min(box1[3]-1, box2[3]-1)
            if inter_x1 < inter_x2 and inter_y1 < inter_y2:
                inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1)
            else:
                inter = 0
            union = (box1[2]-box1[0])*(box1[3]-box1[1]) + (box2[2]-box2[0])*(box2[3]-box2[1]) - inter
            return float(inter)/union
        contents = [completion[0]["content"] for completion in completions]
        rewards = []
        current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
        answer_tag_pattern = r'<answer>(.*?)</answer>'
        bbox_pattern = r'\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)]'
        for content, sol in zip(contents, solution):
            reward = 0.0
            # Try symbolic verification first
            try:
                content_answer_match = re.search(answer_tag_pattern, content, re.DOTALL)
                if content_answer_match:
                    content_answer = content_answer_match.group(1).strip()
                    bbox_match = re.search(bbox_pattern, content_answer)
                    if bbox_match:
                        bbox = [int(bbox_match.group(1)), int(bbox_match.group(2)), int(bbox_match.group(3)), int(bbox_match.group(4))]
                        # if iou(bbox, sol) > 0.5:
                        #     reward = 1.0
                        reward = iou(bbox, sol)
            except Exception:
                pass  # Continue to next verification method if this fails
                    
            rewards.append(reward)
            if os.getenv("DEBUG_MODE") == "true":
                log_path = os.getenv("LOG_PATH")
                # local_rank = int(os.getenv("LOCAL_RANK", 0))
                with open(log_path, "a", encoding='utf-8') as f:
                    f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
                    f.write(f"Content: {content}\n")
                    f.write(f"Solution: {sol}\n")
        return rewards


    @staticmethod
    def iou_TAC_reward(completions, solution, **kwargs):
        """Calculate IoU reward between predicted bounding box from Qwen model and ground truth bounding box.
        When multiple bboxes exist in thinking process, use the last one as it's likely the final decision.
        Multiply the IoU from think tag with the IoU from answer tag to encourage consistency."""
        import re
        import os
        from datetime import datetime
        def iou(box1, box2):
            inter_x1 = max(box1[0], box2[0])
            inter_y1 = max(box1[1], box2[1])
            inter_x2 = min(box1[2]-1, box2[2]-1)
            inter_y2 = min(box1[3]-1, box2[3]-1)
            if inter_x1 < inter_x2 and inter_y1 < inter_y2:
                inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1)
            else:
                inter = 0
            union = (box1[2]-box1[0])*(box1[3]-box1[1]) + (box2[2]-box2[0])*(box2[3]-box2[1]) - inter
            return float(inter)/union if union > 0 else 0.0
            
        contents = [completion[0]["content"] for completion in completions]
        rewards = []
        current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
        think_tag_pattern = r'<think>(.*?)</think>'
        answer_tag_pattern = r'<answer>(.*?)</answer>'
        bbox_pattern = r'\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]'
        
        for content, sol in zip(contents, solution):
            think_reward = 0.0
            answer_reward = 0.0
            
            try:
                think_match = re.search(think_tag_pattern, content, re.DOTALL)
                if think_match:
                    think_content = think_match.group(1).strip()
                    think_bbox_matches = re.findall(bbox_pattern, think_content)
                    if think_bbox_matches:
                        last_bbox_match = think_bbox_matches[-1]
                        think_bbox = [int(last_bbox_match[0]), int(last_bbox_match[1]), 
                                int(last_bbox_match[2]), int(last_bbox_match[3])]
                        
                        think_reward = iou(think_bbox, sol)
            except Exception as e:
                if os.getenv("DEBUG_MODE") == "true":
                    log_path = os.getenv("LOG_PATH")
                    with open(log_path, "a", encoding='utf-8') as f:
                        f.write(f"------------- {current_time} Error in think part: {str(e)} -------------\n")
                pass
            
            try:
                answer_match = re.search(answer_tag_pattern, content, re.DOTALL)
                if answer_match:
                    answer_content = answer_match.group(1).strip()
                    answer_bbox_match = re.search(bbox_pattern, answer_content)
                    if answer_bbox_match:
                        answer_bbox = [int(answer_bbox_match.group(1)), int(answer_bbox_match.group(2)), 
                                    int(answer_bbox_match.group(3)), int(answer_bbox_match.group(4))]
                        
                        answer_reward = iou(answer_bbox, sol)
            except Exception as e:
                if os.getenv("DEBUG_MODE") == "true":
                    log_path = os.getenv("LOG_PATH")
                    with open(log_path, "a", encoding='utf-8') as f:
                        f.write(f"------------- {current_time} Error in answer part: {str(e)} -------------\n")
                pass
            
            # final_reward = think_reward * answer_reward if answer_reward > 0 else 0.0
            final_reward = think_reward
                    
            rewards.append(final_reward)
            
            if os.getenv("DEBUG_MODE") == "true":
                log_path = os.getenv("LOG_PATH")
                with open(log_path, "a", encoding='utf-8') as f:
                    f.write(f"------------- {current_time} TAC Reward -------------\n")
                    f.write(f"Think IoU: {think_reward:.4f}, Answer IoU: {answer_reward:.4f}\n")
                    f.write(f"Final reward: {final_reward:.4f}\n")
                    f.write(f"Content: {content[:200]}...\n")
                    f.write(f"Solution: {sol}\n")
        
        return rewards
   

    
    


    @staticmethod
    def TAC_reward(completions, solution, **kwargs):
        import re
        import os
        from datetime import datetime

        def iou3(box1, box2, box3):
            if not box1 or not box2 or not box3:
                return 0.0
            
            inter_x1 = max(box1[0], box2[0], box3[0])
            inter_y1 = max(box1[1], box2[1], box3[1])
            inter_x2 = min(box1[2]-1, box2[2]-1, box3[2]-1)
            inter_y2 = min(box1[3]-1, box2[3]-1, box3[3]-1)
            
            if inter_x1 < inter_x2 and inter_y1 < inter_y2:
                inter_ABC = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1)
            else:
                inter_ABC = 0
            
            area1 = (box1[2]-box1[0])*(box1[3]-box1[1])
            area2 = (box2[2]-box2[0])*(box2[3]-box2[1])
            area3 = (box3[2]-box3[0])*(box3[3]-box3[1])
            
            # A∩B
            inter_AB_x1 = max(box1[0], box2[0])
            inter_AB_y1 = max(box1[1], box2[1])
            inter_AB_x2 = min(box1[2]-1, box2[2]-1)
            inter_AB_y2 = min(box1[3]-1, box2[3]-1)
            if inter_AB_x1 < inter_AB_x2 and inter_AB_y1 < inter_AB_y2:
                inter_AB = (inter_AB_x2-inter_AB_x1+1)*(inter_AB_y2-inter_AB_y1+1)
            else:
                inter_AB = 0
            
            # A∩C
            inter_AC_x1 = max(box1[0], box3[0])
            inter_AC_y1 = max(box1[1], box3[1])
            inter_AC_x2 = min(box1[2]-1, box3[2]-1)
            inter_AC_y2 = min(box1[3]-1, box3[3]-1)
            if inter_AC_x1 < inter_AC_x2 and inter_AC_y1 < inter_AC_y2:
                inter_AC = (inter_AC_x2-inter_AC_x1+1)*(inter_AC_y2-inter_AC_y1+1)
            else:
                inter_AC = 0
            
            # B∩C
            inter_BC_x1 = max(box2[0], box3[0])
            inter_BC_y1 = max(box2[1], box3[1])
            inter_BC_x2 = min(box2[2]-1, box3[2]-1)
            inter_BC_y2 = min(box2[3]-1, box3[3]-1)
            if inter_BC_x1 < inter_BC_x2 and inter_BC_y1 < inter_BC_y2:
                inter_BC = (inter_BC_x2-inter_BC_x1+1)*(inter_BC_y2-inter_BC_y1+1)
            else:
                inter_BC = 0
            
            # |A∪B∪C| = |A| + |B| + |C| - |A∩B| - |A∩C| - |B∩C| + |A∩B∩C|
            union = area1 + area2 + area3 - inter_AB - inter_AC - inter_BC + inter_ABC
            
            return float(inter_ABC)/union if union > 0 else 0.0

        def extract_last_bbox(text):
            bbox_pattern = r'\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]'
            all_matches = re.findall(bbox_pattern, text)
            
            if all_matches:
                last_match = all_matches[-1]
                return [int(last_match[0]), int(last_match[1]), int(last_match[2]), int(last_match[3])]
            return None

        contents = [completion[0]["content"] for completion in completions]
        rewards = []
        current_time = datetime.now().strftime("%d-%H-%M-%S-%f")


        for i, (content, sol) in enumerate(zip(contents, solution)):
            think_match = re.search(r'<think>(.*?)</think>', content, re.DOTALL)
            think_text = think_match.group(1) if think_match else ""
            think_bbox = extract_last_bbox(think_text)

            answer_match = re.search(r'<answer>(.*?)</answer>', content, re.DOTALL)
            answer_text = answer_match.group(1) if answer_match else ""
            answer_bbox = extract_last_bbox(answer_text)


            if think_bbox and answer_bbox and sol:
                reward = iou3(think_bbox, answer_bbox, sol)
            else:
                reward = 0.0
            
            rewards.append(reward)


        return rewards

    @staticmethod
    def odLength_reward(content, sol, **kwargs):
        """
        Calculate reward for object detection task with length penalty.
        
        Args:
            content (str): Model's predicted answer containing bounding box annotations
            sol (str): Ground truth answer containing bounding box annotations
            **kwargs: Additional keyword arguments
        
        Returns:
            float: Reward score between 0 and 1 based on mAP and length penalty
        """
        # Pattern to extract content between <answer> tags
        from open_r1.utils.pycocotools.coco import COCO
        from open_r1.utils.pycocotools.cocoeval import COCOeval

        def calculate_map(pred_bbox_list, gt_bbox_list, score_type=0):
            # Calculate mAP

            # Initialize COCO object for ground truth
            gt_json = {"annotations": [], "images": [], "categories": []}
            gt_json["images"] = [{
                "id": 0,
                "width": 2048,
                "height": 2048,
                "file_name": "image_0.jpg"
            }]

            gt_json["categories"] = []

            cats2id = {}
            cat_count = 0
            for idx, gt_bbox in enumerate(gt_bbox_list):
                if gt_bbox["label"] not in cats2id:
                    cats2id[gt_bbox["label"]] = cat_count
                    gt_json["categories"].append({
                        "id": cat_count,
                        "name": gt_bbox["label"]
                    })
                    cat_count += 1

                gt_json["annotations"].append({
                    "id": idx+1,
                    "image_id": 0,
                    "category_id": cats2id[gt_bbox["label"]],
                    "bbox": [gt_bbox["bbox_2d"][0], gt_bbox["bbox_2d"][1], gt_bbox["bbox_2d"][2] - gt_bbox["bbox_2d"][0], gt_bbox["bbox_2d"][3] - gt_bbox["bbox_2d"][1]],
                    "area": (gt_bbox["bbox_2d"][2] - gt_bbox["bbox_2d"][0]) * (gt_bbox["bbox_2d"][3] - gt_bbox["bbox_2d"][1]),
                    "iscrowd": 0
                })
            coco_gt = COCO(gt_json)

            dt_json = []
            for idx, pred_bbox in enumerate(pred_bbox_list):
                try:
                    dt_json.append({
                        "image_id": 0,
                        "category_id": cats2id[pred_bbox["label"]],
                        "bbox": [pred_bbox["bbox_2d"][0], pred_bbox["bbox_2d"][1], pred_bbox["bbox_2d"][2] - pred_bbox["bbox_2d"][0], pred_bbox["bbox_2d"][3] - pred_bbox["bbox_2d"][1]],
                        "score": 1.0,
                        "area": (pred_bbox["bbox_2d"][2] - pred_bbox["bbox_2d"][0]) * (pred_bbox["bbox_2d"][3] - pred_bbox["bbox_2d"][1])
                    })
                except:
                    pass

            if len(dt_json) == 0:
                return 0.0

            coco_dt = coco_gt.loadRes(dt_json)
            coco_eval = COCOeval(coco_gt, coco_dt, "bbox")

            coco_eval.evaluate()
            coco_eval.accumulate()
            coco_eval.summarize()
            return coco_eval.stats[score_type]



        def map_reward(content, sol, length_reward=False, score_type=0, **kwargs):
            """
            Calculate mean average precision (mAP) reward between predicted and ground truth bounding boxes.
            
            Args:
                content (str): String containing predicted bounding boxes in JSON format
                sol (str): String containing ground truth bounding boxes in JSON format
                length_reward (bool, optional): Whether to include length penalty in reward calculation. Defaults to False.
                score_type (int, optional): Type of COCO evaluation metric to use. Defaults to 0 (mAP).
                **kwargs: Additional keyword arguments
                
            Returns:
                float: mAP reward score between 0 and 1. If length_reward is True, the score is multiplied by a length penalty factor.
            """
            # Extract JSON content between ```json tags
            import json

            pattern = r'```json(.*?)```'
            json_match = re.findall(pattern, sol, re.DOTALL)
            bbox_json = json_match[-1].strip() if json_match else None
            # Parse ground truth JSON to get bbox list
            gt_bbox_list = []
            if bbox_json:
                bbox_data = json.loads(bbox_json)
                gt_bbox_list = [item for item in bbox_data]
            
            # Parse predicted JSON to get bbox list
            pred_bbox_list = []
            json_match = re.findall(pattern, content, re.DOTALL)
            if json_match:
                try:
                    bbox_data = json.loads(json_match[-1].strip())
                    pred_bbox_list = [item for item in bbox_data]
                except:
                    # Return empty list if JSON parsing fails
                    pred_bbox_list = []

            # Calculate mAP if both prediction and ground truth exist
            if len(pred_bbox_list) > 0 and len(gt_bbox_list) > 0:
                bbox_reward = calculate_map(pred_bbox_list, gt_bbox_list, score_type=score_type)
            elif len(pred_bbox_list) == 0 and len(gt_bbox_list) == 0:
                bbox_reward = 1.0
            else:
                bbox_reward = 0.0
            
            if length_reward:
                # Calculate length penalty based on ratio of ground truth to predicted bounding boxes
                gt_length = len(gt_bbox_list)
                pred_length = len(pred_bbox_list)
                # Full score if prediction has fewer boxes than ground truth, otherwise penalize proportionally
                length_score = 1.0 if gt_length >= pred_length else gt_length/pred_length
                return bbox_reward * length_score
            else:
                return bbox_reward


        match_pattern = r'<answer>(.*?)</answer>'

        # Extract ground truth answer
        sol_match = re.search(match_pattern, sol, re.DOTALL)
        ground_truth = sol_match.group(1).strip() if sol_match else None
        # Extract predicted answer (using last match if multiple)
        content_match = re.findall(match_pattern, content, re.DOTALL)
        student_answer = content_match[-1].strip() if content_match else None

        # Return 0 if no prediction
        if student_answer is None:
            return 0.0
        # Return 1 if both prediction and ground truth are None
        elif ground_truth == "None" and student_answer == "None":
            return 1.0
        # Calculate mAP with length penalty
        else:
            bbox_reward = map_reward(student_answer, ground_truth, length_reward=True, score_type=0)
            return bbox_reward
        

    def accuracy_reward(completions, solution, **kwargs):
        
        def odLength_reward(content, sol, **kwargs):
            """
            Calculate reward for object detection task with length penalty.
            
            Args:
                content (str): Model's predicted answer containing bounding box annotations
                sol (str): Ground truth answer containing bounding box annotations
                **kwargs: Additional keyword arguments
            
            Returns:
                float: Reward score between 0 and 1 based on mAP and length penalty
            """
            # Pattern to extract content between <answer> tags
            from src.open_r1.utils.pycocotools.coco import COCO
            from src.open_r1.utils.pycocotools.cocoeval import COCOeval

            def calculate_map(pred_bbox_list, gt_bbox_list, score_type=0):
                # Calculate mAP

                # Initialize COCO object for ground truth
                gt_json = {"annotations": [], "images": [], "categories": []}
                gt_json["images"] = [{
                    "id": 0,
                    "width": 2048,
                    "height": 2048,
                    "file_name": "image_0.jpg"
                }]

                gt_json["categories"] = []

                cats2id = {}
                cat_count = 0
                for idx, gt_bbox in enumerate(gt_bbox_list):
                    if gt_bbox["label"] not in cats2id:
                        cats2id[gt_bbox["label"]] = cat_count
                        gt_json["categories"].append({
                            "id": cat_count,
                            "name": gt_bbox["label"]
                        })
                        cat_count += 1

                    gt_json["annotations"].append({
                        "id": idx+1,
                        "image_id": 0,
                        "category_id": cats2id[gt_bbox["label"]],
                        "bbox": [gt_bbox["bbox_2d"][0], gt_bbox["bbox_2d"][1], gt_bbox["bbox_2d"][2] - gt_bbox["bbox_2d"][0], gt_bbox["bbox_2d"][3] - gt_bbox["bbox_2d"][1]],
                        "area": (gt_bbox["bbox_2d"][2] - gt_bbox["bbox_2d"][0]) * (gt_bbox["bbox_2d"][3] - gt_bbox["bbox_2d"][1]),
                        "iscrowd": 0
                    })
                coco_gt = COCO(gt_json)

                dt_json = []
                for idx, pred_bbox in enumerate(pred_bbox_list):
                    try:
                        dt_json.append({
                            "image_id": 0,
                            "category_id": cats2id[pred_bbox["label"]],
                            "bbox": [pred_bbox["bbox_2d"][0], pred_bbox["bbox_2d"][1], pred_bbox["bbox_2d"][2] - pred_bbox["bbox_2d"][0], pred_bbox["bbox_2d"][3] - pred_bbox["bbox_2d"][1]],
                            "score": 1.0,
                            "area": (pred_bbox["bbox_2d"][2] - pred_bbox["bbox_2d"][0]) * (pred_bbox["bbox_2d"][3] - pred_bbox["bbox_2d"][1])
                        })
                    except:
                        pass

                if len(dt_json) == 0:
                    return 0.0

                coco_dt = coco_gt.loadRes(dt_json)
                coco_eval = COCOeval(coco_gt, coco_dt, "bbox")

                coco_eval.evaluate()
                coco_eval.accumulate()
                coco_eval.summarize()
                return coco_eval.stats[score_type]



            def map_reward(content, sol, length_reward=False, score_type=0, **kwargs):
                """
                Calculate mean average precision (mAP) reward between predicted and ground truth bounding boxes.
                
                Args:
                    content (str): String containing predicted bounding boxes in JSON format
                    sol (str): String containing ground truth bounding boxes in JSON format
                    length_reward (bool, optional): Whether to include length penalty in reward calculation. Defaults to False.
                    score_type (int, optional): Type of COCO evaluation metric to use. Defaults to 0 (mAP).
                    **kwargs: Additional keyword arguments
                    
                Returns:
                    float: mAP reward score between 0 and 1. If length_reward is True, the score is multiplied by a length penalty factor.
                """
                # Extract JSON content between ```json tags
                import json


                pattern = r'```json(.*?)```'
                json_match = re.findall(pattern, sol, re.DOTALL)
                bbox_json = json_match[-1].strip() if json_match else None
                # Parse ground truth JSON to get bbox list
                
                gt_bbox_list = []
                if bbox_json:
                    bbox_data = json.loads(bbox_json)
                    gt_bbox_list = [item for item in bbox_data]
                
                # Parse predicted JSON to get bbox list
                pred_bbox_list = []
                json_match = re.findall(pattern, content, re.DOTALL)
                if json_match:
                    try:
                        bbox_data = json.loads(json_match[-1].strip())
                        pred_bbox_list = [item for item in bbox_data]
                    except:
                        # Return empty list if JSON parsing fails
                        pred_bbox_list = []

                # Calculate mAP if both prediction and ground truth exist
                if len(pred_bbox_list) > 0 and len(gt_bbox_list) > 0:
                    bbox_reward = calculate_map(pred_bbox_list, gt_bbox_list, score_type=score_type)
                elif len(pred_bbox_list) == 0 and len(gt_bbox_list) == 0:
                    bbox_reward = 1.0
                else:
                    bbox_reward = 0.0
                
                if length_reward:
                    # Calculate length penalty based on ratio of ground truth to predicted bounding boxes
                    gt_length = len(gt_bbox_list)
                    pred_length = len(pred_bbox_list)
                    # Full score if prediction has fewer boxes than ground truth, otherwise penalize proportionally
                    length_score = 1.0 if gt_length >= pred_length else gt_length/pred_length
                    return bbox_reward * length_score
                else:
                    return bbox_reward


            match_pattern = r'<answer>(.*?)</answer>'

            # Extract ground truth answer

            ground_truth = sol
            # Extract predicted answer (using last match if multiple)
            content_match = re.findall(match_pattern, content, re.DOTALL)
            student_answer = content_match[-1].strip() if content_match else None

            # Return 0 if no prediction
            if student_answer is None:
                return 0.0
            # Return 1 if both prediction and ground truth are None
            elif ground_truth == "None" and student_answer == "None":
                return 1.0
            # Calculate mAP with length penalty
            else:
                bbox_reward = map_reward(student_answer, ground_truth, length_reward=True, score_type=0)
                return bbox_reward


        def TAC_reward(completion, solution, **kwargs):
            import re
            import os
            from datetime import datetime

            def iou3(box1, box2, box3):
                if not box1 or not box2 or not box3:
                    return 0.0
                
                inter_x1 = max(box1[0], box2[0], box3[0])
                inter_y1 = max(box1[1], box2[1], box3[1])
                inter_x2 = min(box1[2]-1, box2[2]-1, box3[2]-1)
                inter_y2 = min(box1[3]-1, box2[3]-1, box3[3]-1)
                
                if inter_x1 < inter_x2 and inter_y1 < inter_y2:
                    inter_ABC = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1)
                else:
                    inter_ABC = 0
                
                area1 = (box1[2]-box1[0])*(box1[3]-box1[1])
                area2 = (box2[2]-box2[0])*(box2[3]-box2[1])
                area3 = (box3[2]-box3[0])*(box3[3]-box3[1])
                
                # A∩B
                inter_AB_x1 = max(box1[0], box2[0])
                inter_AB_y1 = max(box1[1], box2[1])
                inter_AB_x2 = min(box1[2]-1, box2[2]-1)
                inter_AB_y2 = min(box1[3]-1, box2[3]-1)
                if inter_AB_x1 < inter_AB_x2 and inter_AB_y1 < inter_AB_y2:
                    inter_AB = (inter_AB_x2-inter_AB_x1+1)*(inter_AB_y2-inter_AB_y1+1)
                else:
                    inter_AB = 0
                
                # A∩C
                inter_AC_x1 = max(box1[0], box3[0])
                inter_AC_y1 = max(box1[1], box3[1])
                inter_AC_x2 = min(box1[2]-1, box3[2]-1)
                inter_AC_y2 = min(box1[3]-1, box3[3]-1)
                if inter_AC_x1 < inter_AC_x2 and inter_AC_y1 < inter_AC_y2:
                    inter_AC = (inter_AC_x2-inter_AC_x1+1)*(inter_AC_y2-inter_AC_y1+1)
                else:
                    inter_AC = 0
                
                # B∩C
                inter_BC_x1 = max(box2[0], box3[0])
                inter_BC_y1 = max(box2[1], box3[1])
                inter_BC_x2 = min(box2[2]-1, box3[2]-1)
                inter_BC_y2 = min(box2[3]-1, box3[3]-1)
                if inter_BC_x1 < inter_BC_x2 and inter_BC_y1 < inter_BC_y2:
                    inter_BC = (inter_BC_x2-inter_BC_x1+1)*(inter_BC_y2-inter_BC_y1+1)
                else:
                    inter_BC = 0
                
                # |A∪B∪C| = |A| + |B| + |C| - |A∩B| - |A∩C| - |B∩C| + |A∩B∩C|
                union = area1 + area2 + area3 - inter_AB - inter_AC - inter_BC + inter_ABC
                
                return float(inter_ABC)/union if union > 0 else 0.0

            def extract_last_bbox(text):
                bbox_pattern = r'\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]'
                all_matches = re.findall(bbox_pattern, text)
                
                if all_matches:
                    last_match = all_matches[-1]
                    return [int(last_match[0]), int(last_match[1]), int(last_match[2]), int(last_match[3])]
                return None

            content = completion[0]["content"]

            think_match = re.search(r'<think>(.*?)</think>', content, re.DOTALL)
            think_text = think_match.group(1) if think_match else ""
            think_bbox = extract_last_bbox(think_text)

            answer_match = re.search(r'<answer>(.*?)</answer>', content, re.DOTALL)
            answer_text = answer_match.group(1) if answer_match else ""
            answer_bbox = extract_last_bbox(answer_text)


            if think_bbox and answer_bbox and sol:
                reward = iou3(think_bbox, answer_bbox, sol)
            else:
                reward = 0.0
            


            return reward

        
        task_list = kwargs.get('task', [])
    
        
        rewards = []
    
        for i, (completion, sol, task) in enumerate(zip(completions, solution, task_list)):
            
            if task == "rec":
                sample_reward = TAC_reward(completion, sol, **kwargs)
            elif task == "ovd":
                content_text = completion["content"] if isinstance(completion, dict) else completion[0]["content"]
                sample_reward = odLength_reward(content_text, sol, **kwargs)

            
            rewards.append(sample_reward)
        
        return rewards