import re
import json


def extract_solution(solution_str, method='strict'):
    """Parse and validate a GraphQA action JSON object from a model response.

    The JSON text is taken from the first fenced code block (```json ... ``` or a
    bare ``` ... ```) whose content starts with ``{``; when no such block exists,
    the whole string is parsed as JSON.

    Args:
        solution_str: the raw model response.
        method: 'strict' or 'flexible'. The value is only validated; both choices
            currently behave identically.

    Returns:
        A ``(ok, errors, answer)`` tuple.
        On success: ``(True, [], {"Action": action, "Objects": objects})`` -- any
        other keys (e.g. ``"Thought"``) are dropped and ``action`` keeps the casing
        used in the response.
        On failure: ``(False, [message], None)`` with a single error message.

    Raises:
        AssertionError: if ``method`` is not 'strict' or 'flexible'.
    """
    assert method in ['strict', 'flexible']

    # Step 0: locate the JSON text in the response and decode it.
    if not isinstance(solution_str, str):
        return False, ["solution_str is not a string"], None
    # Lazy match: stops at the first '}' that is followed by a closing fence, so
    # nested objects inside the block are still captured whole.
    match = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", solution_str, re.DOTALL)
    json_text = match.group(1) if match else solution_str
    try:
        parsed = json.loads(json_text)
    except json.JSONDecodeError:
        return False, ["solution_str is a string but not valid JSON"], None

    # Step 1: the decoded value must be a JSON object.
    if not isinstance(parsed, dict):
        return False, ["Parsed content is not a JSON object"], None

    # Step 2: both 'Action' and 'Objects' are required. After this check the
    # lookups below cannot fail, so no exception handling is needed.
    if "Action" not in parsed or "Objects" not in parsed:
        return False, ["Missing 'Action' or 'Objects' field"], None
    action = parsed["Action"]
    objects = parsed["Objects"]

    # Step 3: Action must be one of the known actions (case-insensitive).
    valid_actions = ("Explore Entity", "Choose Relation", "Finish")
    valid_actions_lower = {a.lower() for a in valid_actions}
    if not isinstance(action, str) or action.lower() not in valid_actions_lower:
        return False, ["Invalid Action"], None

    # Step 4: Objects must be a list. Its items are NOT type-checked here.
    if not isinstance(objects, list):
        return False, ["Objects should be a list"], None

    return True, [], {"Action": action, "Objects": objects}



def compute_score(solution_str, ground_truth, method='strict', format_score=0., score=1.,
                  *, reward_mode='prm'):
    """The scoring function for GraphQA: reward one model response for one step.

    Reward tiers (fixed values):
      * 0   -- the response fails ``extract_solution`` (format error).
      * 0.1 -- well-formed response with no ground-truth object in its answer
               (and not fully correct).
      * 0.2 -- well-formed response sharing at least one object with the ground
               truth, but not fully correct (the action is not considered here).
      * 1.0 -- the action matches (case-insensitive) and every ground-truth object
               appears in the answer; extra answer objects are not penalized.

    Args:
        solution_str: the raw model response text.
        ground_truth: dict with an ``"Action"`` string and an ``"Object"`` list of
            strings. Note the singular ``"Object"`` key, unlike ``"Objects"`` in
            responses.
        method: extraction method forwarded to ``extract_solution``
            ('strict' or 'flexible').
        format_score: unused; accepted for interface compatibility.
        score: unused; accepted for interface compatibility.
        reward_mode: 'prm' (default) scores every step with the tiers above.
            'orm' applies the same tiers only when the ground-truth action is
            "Finish" and returns 0.0 for any other (non-final) step.

    Returns:
        The reward as a float.

    Raises:
        ValueError: if ``reward_mode`` is not 'prm' or 'orm'.
    """
    if reward_mode not in ('prm', 'orm'):
        raise ValueError(f"reward_mode must be 'prm' or 'orm', got {reward_mode!r}")

    result, _, answer = extract_solution(solution_str=solution_str, method=method)
    print(answer)

    print("******************** Response ********************")
    print(solution_str)
    print("******************** Ground Truth ********************")
    print(ground_truth)

    # A separate local (instead of reusing the ``score`` parameter) avoids
    # silently shadowing the argument.
    reward = 0
    # Format error -> 0; a well-formed response earns a base credit of 0.1.
    if not result:
        print(f"格式错误, ******************** score: {reward} ********************")
        return 0.0
    reward = 0.1

    gt_action = ground_truth["Action"].lower()
    gt_obj = [obj.lower() for obj in ground_truth["Object"]]
    # extract_solution guarantees a str Action and a list of Objects, but the list
    # items are unchecked: any non-string item discards the whole answer list.
    try:
        ans_action = answer["Action"].lower()
    except (KeyError, TypeError, AttributeError):
        ans_action = " "
    try:
        ans_obj = [obj.lower() for obj in answer["Objects"]]
    except (KeyError, TypeError, AttributeError):
        ans_obj = []

    gt_set = set(gt_obj)
    ans_set = set(ans_obj)
    overlap = ans_set & gt_set
    print(overlap)

    # Outcome-reward mode: only the final ("Finish") step is rewarded.
    # gt_action is already lowercased, so compare against "finish".
    if reward_mode == 'orm' and gt_action != "finish":
        print("非最终 step, 无奖励。")
        return 0.0

    if ans_action == gt_action and gt_set <= ans_set:
        reward = 1.0
        print(f"******************** 答案正确, score: {reward} ********************")
        return reward
    if overlap:
        reward = 0.2  # partially correct: at least one ground-truth object found
        print(f"******************** 答案部分正确, score: {reward} ********************")
    else:
        print(f"******************** 答案错误, score: {reward} ********************")
    return reward



if __name__ == "__main__":
    # Demo: a well-formed "Finish" response scored against a "Choose Relation"
    # ground-truth step (the ground truth uses the singular "Object" key).
    # The action differs and no object overlaps, so the default scoring yields
    # only the 0.1 format credit.
    demo_response = """
    {
        "Thought": "The question asks what is in front of the man. From the current now_state, we already have the triple [name: net; ..., in front of, name: man; ...], which directly tells us that the net is in front of the man. No further exploration is necessary.",
        "Action": "Finish",
        "Objects": ["net"]
    }
    """

    demo_ground_truth = dict(
        Action="Choose Relation",
        Object=[
            "name: net; (x,y,w,h): (0, 153, 499, 206), in front of, name: man; (x,y,w,h): (124, 1, 265, 359)",
        ],
    )

    print(compute_score(demo_response, demo_ground_truth))
    
    

   
    

