import re
import json
from scipy.optimize import linear_sum_assignment
import numpy as np


def extract_think_description(output_text):
    """Return the bodies of the first ``<think>`` and ``<description>`` tags.

    Each body must be non-empty and must not contain a ``<`` character;
    otherwise that tag is treated as absent.

    Args:
        output_text: Raw model output to search.

    Returns:
        ``(think_text, description_text)``; either element is ``""`` when the
        corresponding tag is missing or does not match.
    """
    think_found = re.search(r'<think>([^<]+)</think>', output_text)
    description_found = re.search(r'<description>([^<]+)</description>', output_text)

    think_text = think_found.group(1) if think_found else ""
    description_text = description_found.group(1) if description_found else ""
    return think_text, description_text

def dr2seg_format_reward(predict_str: str) -> float:
    """Score the structural format of a model response.

    The score is the sum of two parts:

    * thinking format (0 or 1): 1.0 when the *whole* string is
      ``<think>...</think>`` followed (optionally after whitespace) by
      ``<answer>...</answer>``.
    * segmentation format (0 to 2): the body of the first
      ``<answer>...</answer>`` is parsed as JSON. When it is a non-empty list,
      each item earns 1.0 for a ``bbox_2d`` list of length 4 and 1.0 for a
      ``point_2d`` list of length 2. The per-item scores are averaged over
      the number of items.

    Args:
        predict_str: Raw model output.

    Returns:
        A value in ``[0.0, 3.0]``.
    """
    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
    match = re.fullmatch(pattern, predict_str, re.DOTALL)
    thinking_format_reward = 1.0 if match else 0.0

    def segmentation_format(text: str) -> float:
        """Average per-item field-format score of the JSON list in <answer>."""
        json_match = re.search(r'<answer>\s*(.*?)\s*</answer>', text, re.DOTALL)
        if not json_match:
            return 0.0
        try:
            data = json.loads(json_match.group(1))
        except (ValueError, RecursionError):
            # Unparseable JSON earns no segmentation-format credit.
            return 0.0
        if not isinstance(data, list) or not data:
            return 0.0

        data_cnt = len(data)
        segmentation_format_reward = 0.0
        for item in data:
            # A non-object item scores 0 but still counts toward the average.
            # The result therefore no longer depends on where a malformed
            # item appears in the list.
            if not isinstance(item, dict):
                continue
            cur_reward = 0.0

            bbox_2d = item.get('bbox_2d')
            if isinstance(bbox_2d, list) and len(bbox_2d) == 4:
                cur_reward += 1.0

            point_2d = item.get('point_2d')
            if isinstance(point_2d, list) and len(point_2d) == 2:
                cur_reward += 1.0

            segmentation_format_reward += cur_reward / data_cnt
        return segmentation_format_reward

    segmentation_format_reward = segmentation_format(predict_str)

    return thinking_format_reward + segmentation_format_reward


    

def _dr2seg_split_objects(objects):
    """Split a list of ``{"bbox_2d": ..., "point_2d": ...}`` dicts into ``(bboxes, points)``."""
    bboxes = [item['bbox_2d'] for item in objects]
    points = [item['point_2d'] for item in objects]
    return bboxes, points


def _dr2seg_parse_ground_truth(ground_truth):
    """Return ``(gt_bboxes, gt_points)`` from any supported ground-truth form.

    Supported forms:
    * a mapping with parallel ``"bbox_2d"`` / ``"point_2d"`` lists;
    * a list of per-object dicts;
    * a JSON string encoding either of the above.
    """
    if isinstance(ground_truth, str):
        ground_truth = json.loads(ground_truth)
    if isinstance(ground_truth, list):
        return _dr2seg_split_objects(ground_truth)
    return ground_truth["bbox_2d"], ground_truth["point_2d"]


def dr2seg_accuracy_reward(predict_str: str, ground_truth: str) -> float:
    """Score predicted boxes/points against the ground truth via optimal matching.

    The JSON list inside the first ``<answer>...</answer>`` of ``predict_str``
    is read as predicted objects, each with ``bbox_2d`` ``[x1, y1, x2, y2]``
    and ``point_2d`` ``[x, y]``. Every (prediction, ground-truth) pair earns
    up to 3 points:

    * IoU > 0.5;
    * mean absolute box-coordinate difference < 10;
    * point distance < 30, counted only when the predicted point lies inside
      its own predicted box.

    A Hungarian assignment (``linear_sum_assignment``) picks the pairing with
    the highest total. That total is divided by ``max(M, N)``, so missing or
    extra objects lower the score. At most 120 objects per side are
    considered.

    Args:
        predict_str: Model output containing an ``<answer>`` JSON list.
        ground_truth: A mapping with parallel ``"bbox_2d"`` / ``"point_2d"``
            lists, a list of per-object dicts, or a JSON string of either.

    Returns:
        A score in ``[0.0, 3.0]``. Any missing or malformed prediction or
        ground truth yields 0.0.
    """
    max_objects = 120  # cap on objects considered from either side

    try:
        gt_bboxes, gt_points = _dr2seg_parse_ground_truth(ground_truth)

        json_match = re.search(r'<answer>\s*(.*?)\s*</answer>', predict_str, re.DOTALL)
        if not json_match:
            return 0.0
        pred_bboxes, pred_points = _dr2seg_split_objects(json.loads(json_match.group(1)))

        # Slicing truncates only when a side exceeds the cap; otherwise it is a no-op.
        pred_bboxes = np.array(pred_bboxes[:max_objects])  # (M, 4)
        pred_points = np.array(pred_points[:max_objects])  # (M, 2)
        gt_bboxes = np.array(gt_bboxes[:max_objects])      # (N, 4)
        gt_points = np.array(gt_points[:max_objects])      # (N, 2)
        if len(pred_bboxes) == 0 or len(gt_bboxes) == 0:
            return 0.0

        # Pairwise metrics, all (M, N) except points_in_box which is (M,).
        iou_matrix = batch_iou(pred_bboxes, gt_bboxes)
        l1_matrix = batch_l1_distance(pred_bboxes, gt_bboxes)
        points_dist_matrix = batch_points_distance(pred_points, gt_points)
        points_in_box = batch_points_in_box(pred_points, pred_bboxes)

        # 0/1 reward per criterion per pair.
        iou_reward = (iou_matrix > 0.5).astype(float)
        bbox_l1_reward = (l1_matrix < 10).astype(float)
        point_reward = ((points_dist_matrix < 30) & points_in_box[:, np.newaxis]).astype(float)

        # Hungarian minimises cost, so cost = max per-pair reward (3) - reward.
        cost_matrix = 3.0 - (iou_reward + bbox_l1_reward + point_reward)
        row_indices, col_indices = linear_sum_assignment(cost_matrix)

        total_reward = len(row_indices) * 3.0 - cost_matrix[row_indices, col_indices].sum()

        # Normalise by the larger side so unmatched objects count against the score.
        return float(total_reward / max(len(pred_bboxes), len(gt_bboxes)))
    except Exception:
        # Deliberate best-effort: a malformed rollout scores 0 instead of
        # crashing the reward computation.
        return 0.0

# def dr2seg_accuracy_reward(predict_str: str, ground_truth: str) -> float:
#     max_accuracy_reward = 0.0
#     MAX_OBJECTS = 120  # 设置上限
    
#     try:
#         gt_data = json.loads(ground_truth)
#         gt_bboxes = [item['bbox_2d'] for item in gt_data]
#         gt_points = [item['point_2d'] for item in gt_data]
            
#         #json_match = re.search(r'```json\s*(.*?)\s*```', predict_str, re.DOTALL)
#         json_match = re.search(r'<answer>\s*(.*?)\s*</answer>', predict_str, re.DOTALL)
#         if json_match:
#             data = json.loads(json_match.group(1))
#             pred_bboxes = [item['bbox_2d'] for item in data]
#             pred_points = [item['point_2d'] for item in data]
            
#             # 只有当预测或真实值超过上限时才截断
#             if len(pred_bboxes) > MAX_OBJECTS:
#                 pred_bboxes = pred_bboxes[:MAX_OBJECTS]
#                 pred_points = pred_points[:MAX_OBJECTS]
            
#             if len(gt_bboxes) > MAX_OBJECTS:
#                 gt_bboxes = gt_bboxes[:MAX_OBJECTS]
#                 gt_points = gt_points[:MAX_OBJECTS]
            
#             # 预处理数据为numpy数组
#             pred_bboxes = np.array(pred_bboxes)  # (M,4)
#             pred_points = np.array(pred_points)  # (M,2)
#             gt_bboxes = np.array(gt_bboxes)    # (N,4)
#             gt_points = np.array(gt_points)     # (N,2)
            
#             # 并行计算所有指标
#             iou_matrix = batch_iou(pred_bboxes, gt_bboxes)  # (M,N)
#             l1_matrix = batch_l1_distance(pred_bboxes, gt_bboxes)  # (M,N)
#             points_dist_matrix = batch_points_distance(pred_points, gt_points)  # (M,N)
#             points_in_box = batch_points_in_box(pred_points, pred_bboxes)  # (M,)
            
#             # 计算reward矩阵
#             iou_reward = (iou_matrix > 0.5).astype(float)
#             bbox_l1_reward = (l1_matrix < 10).astype(float)
#             point_reward = ((points_dist_matrix < 30) & points_in_box[:,np.newaxis]).astype(float)
            
#             # 构建最终的cost矩阵
#             cost_matrix = 3.0 - (iou_reward + bbox_l1_reward + point_reward)
            
#             # 使用匈牙利算法找最优匹配
#             row_indices, col_indices = linear_sum_assignment(cost_matrix)
            
#             # 直接从cost_matrix计算总reward
#             total_reward = len(row_indices) * 3.0 - cost_matrix[row_indices, col_indices].sum()
            
#             # 计算平均reward
#             max_length = max(len(pred_bboxes), len(gt_bboxes))
#             max_accuracy_reward = total_reward / max_length
            
#     except Exception:
#         pass
#     return max_accuracy_reward

def dr2seg_non_repeat_reward(predict_str: str, description_predict_str: str) -> float:
    """Penalise any repeated sentence across the two responses.

    Both strings are split on ``'.'``. Each fragment is whitespace-stripped,
    and empty fragments are dropped. The fragments from both strings are then
    pooled together.

    Args:
        predict_str: First response text.
        description_predict_str: Second response text.

    Returns:
        1.0 when every pooled fragment is unique, 0.0 when any fragment
        appears at least twice, whether within one string or across both.
        If splitting fails (e.g. a non-string argument), 1.0 is returned.
    """
    try:
        pooled = [
            fragment.strip()
            for text in (predict_str, description_predict_str)
            for fragment in text.split('.')
            if fragment.strip()
        ]
    except Exception:
        return 1.0

    has_duplicate = len(set(pooled)) < len(pooled)
    return 0.0 if has_duplicate else 1.0

def dr2seg_compute_score(predict_str: str, description_answers_str: str, ground_truth: str, format_weight: float = 0.1) -> tuple:
    """Combine all dr2seg reward terms for one rollout.

    ``reward`` is the unweighted sum of:

    * ``dr2seg_format_reward(predict_str)``;
    * ``dr2seg_accuracy_reward(predict_str, ground_truth)``;
    * ``dr2seg_accuracy_reward(description_answers_str, ground_truth)``;
    * ``dr2seg_non_repeat_reward(predict_str, description_answers_str)``.

    Args:
        predict_str: Primary model response.
        description_answers_str: Second response string. It is scored for
            accuracy and pooled with ``predict_str`` for the repeat check.
        ground_truth: Passed unchanged to ``dr2seg_accuracy_reward``.
        format_weight: NOTE(review): currently unused. The format reward is
            added unweighted. Kept for interface compatibility.

    Returns:
        A 3-tuple ``(reward, response_str, description_answers_str)``:

        * ``response_str`` is the ``<think>`` body concatenated with the
          ``<description>`` body of ``predict_str`` when both are found.
          Otherwise it is ``predict_str`` unchanged.
        * The returned ``description_answers_str`` is reduced the same way
          from the input string.
    """
    format_reward = dr2seg_format_reward(predict_str)
    accuracy_reward = dr2seg_accuracy_reward(predict_str, ground_truth)
    non_repeat_reward = dr2seg_non_repeat_reward(predict_str, description_answers_str)

    description_reward = dr2seg_accuracy_reward(description_answers_str, ground_truth)
    reward = format_reward + accuracy_reward + description_reward + non_repeat_reward

    # Reduce each response to its think + description text (used by callers,
    # e.g. for length computation); fall back to the raw string otherwise.
    response_think_str, response_description_str = extract_think_description(predict_str)
    if response_think_str and response_description_str:
        response_str = response_think_str + response_description_str
    else:
        response_str = predict_str

    description_answers_think_str, description_answers_description_str = extract_think_description(description_answers_str)
    if description_answers_think_str and description_answers_description_str:
        description_answers_str = description_answers_think_str + description_answers_description_str

    return reward, response_str, description_answers_str

def batch_iou(boxes1, boxes2):
    """Pairwise IoU between two sets of ``[x1, y1, x2, y2]`` boxes.

    Widths and heights use the inclusive-pixel convention (``x2 - x1 + 1``).

    Args:
        boxes1: Array of shape (M, 4).
        boxes2: Array of shape (N, 4).

    Returns:
        Float array of shape (M, N). Pairs whose union area is not positive
        (degenerate boxes) get IoU 0.0 rather than ``nan``/``inf`` with a
        divide-by-zero RuntimeWarning. A positive intersection always implies
        a positive union, so ordinary pairs are unaffected.
    """
    # np.split along axis 1 yields column vectors; transposing the boxes2
    # columns to (1, N) lets broadcasting build (M, N) matrices.
    x11, y11, x12, y12 = np.split(boxes1, 4, axis=1)  # each (M, 1)
    x21, y21, x22, y22 = np.split(boxes2, 4, axis=1)  # each (N, 1)

    xA = np.maximum(x11, np.transpose(x21))  # (M, N)
    yA = np.maximum(y11, np.transpose(y21))
    xB = np.minimum(x12, np.transpose(x22))
    yB = np.minimum(y12, np.transpose(y22))

    interArea = np.maximum(0, xB - xA + 1) * np.maximum(0, yB - yA + 1)
    box1Area = (x12 - x11 + 1) * (y12 - y11 + 1)  # (M, 1)
    box2Area = (x22 - x21 + 1) * (y22 - y21 + 1)  # (N, 1)

    unionArea = box1Area + np.transpose(box2Area) - interArea
    iou = np.divide(
        interArea,
        unionArea,
        out=np.zeros(np.shape(interArea), dtype=float),
        where=unionArea > 0,
    )
    return iou

def batch_l1_distance(boxes1, boxes2):
    """Mean absolute coordinate difference for every (boxes1[i], boxes2[j]) pair.

    Args:
        boxes1: Array of shape (M, 4).
        boxes2: Array of shape (N, 4).

    Returns:
        Array of shape (M, N).
    """
    pairwise_diff = boxes1[:, None, :] - boxes2[None, :, :]  # (M, N, 4)
    return np.abs(pairwise_diff).mean(axis=2)

def batch_points_distance(points1, points2):
    """Euclidean distance for every (points1[i], points2[j]) pair.

    Args:
        points1: Array of shape (M, 2).
        points2: Array of shape (N, 2).

    Returns:
        Array of shape (M, N).
    """
    offsets = points1[:, None, :] - points2[None, :, :]  # (M, N, 2)
    squared_length = np.square(offsets).sum(axis=2)
    return np.sqrt(squared_length)

def batch_points_in_box(points, boxes):
    """Check, row by row, whether each point lies inside its paired box.

    Borders count as inside.

    Args:
        points: Array of shape (M, 2) holding ``[x, y]`` per row.
        boxes: Array of shape (M, 4) holding ``[x1, y1, x2, y2]`` per row.

    Returns:
        Boolean array of shape (M,).
    """
    px, py = points[:, 0], points[:, 1]
    left, top, right, bottom = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]

    inside_x = (left <= px) & (px <= right)
    inside_y = (top <= py) & (py <= bottom)
    return inside_x & inside_y

if __name__ == "__main__":
    # Manual smoke test of the combined score.
    predict_str = """
<answer>
[{"bbox_2d": [10, 100, 398, 423], "point_2d": [283, 169]}]
</answer>
"""
    description_str = """
<answer>
[{"bbox_2d": [416, 7, 833, 553], "point_2d": [648, 249]}]
</answer>
"""
    # dr2seg_accuracy_reward reads ground_truth["bbox_2d"] / ["point_2d"],
    # so pass parallel lists in a dict rather than a JSON string.
    ground_truth = {"bbox_2d": [[416, 7, 833, 553]], "point_2d": [[648, 249]]}
    print(predict_str)
    print(ground_truth)
    # dr2seg_compute_score requires the description string as its second argument.
    print(dr2seg_compute_score(predict_str, description_str, ground_truth))
    