from typing import List, Tuple
import re
import ast

def extract_trajectory_agent1(response_text: str) -> List[Tuple[int, ...]]:
    """Extract agent 1's trajectory from a model response.

    Searches ``response_text`` (case-insensitively, tolerating extra
    whitespace and line breaks around the tokens) for the first
    ``<trajectory for agent1> = [`` marker and parses the bracketed list
    that follows as a Python literal. The end of the list is located by
    bracket-depth counting, so points may be written as tuples
    (``[(1, 2), (3, 4)]``) or as nested lists (``[[1, 2], [3, 4]]``).

    Args:
        response_text: Raw text to search.

    Returns:
        A list of two-element tuples whose components were converted
        with ``int``. An empty list literal yields ``[]``.

    Raises:
        ValueError: If no marker followed by a balanced ``[...]`` is found,
            if the bracketed text is not a valid Python literal, or if the
            parsed value is not a sequence of two-element points whose
            components can be converted with ``int``.
    """
    # The lookahead keeps the opening bracket out of the match so that
    # ``marker.end()`` points exactly at the start of the list literal.
    marker = re.search(
        r'<\s*trajectory\s+for\s+agent1\s*>\s*=\s*(?=\[)',
        response_text,
        flags=re.IGNORECASE,
    )
    if not marker:
        raise ValueError("未找到 `<trajectory for agent1> = [...]` 结构")

    # Walk forward until the bracket opened at ``start`` is closed again.
    # A plain "[^\]]*" pattern would stop at the first inner "]" and cut
    # list-of-lists trajectories in half.
    start = marker.end()
    end = -1
    depth = 0
    for pos in range(start, len(response_text)):
        ch = response_text[pos]
        if ch == '[':
            depth += 1
        elif ch == ']':
            depth -= 1
            if depth == 0:
                end = pos + 1
                break
    if end < 0:
        raise ValueError("未找到 `<trajectory for agent1> = [...]` 结构")

    coord_str = response_text[start:end].strip()
    try:
        # literal_eval only accepts literals, so model output is never executed.
        coords = ast.literal_eval(coord_str)
    except Exception as e:
        raise ValueError(f"坐标列表解析失败：{e}") from e

    if not isinstance(coords, (list, tuple)) or not all(
        isinstance(p, (list, tuple)) and len(p) == 2 for p in coords
    ):
        raise ValueError("解析结果不是二维坐标的列表")

    try:
        return [tuple(int(v) for v in p) for p in coords]
    except (TypeError, ValueError, OverflowError) as e:
        # Components such as None or "a" are not coordinates.
        raise ValueError("解析结果不是二维坐标的列表") from e

def parse_and_calculate_reward_tree_new(prompts=None, completions=None, completions_ids=None, node_depth=None, **kwargs):
    """Score completions with a waypoint-sharing ("tree") reward.

    Each completion's text is parsed with ``extract_trajectory_agent1``.
    The scoring runs in two passes over the batch:

    1. Every parsed trajectory gets a base quality from its A* reference
       path, and that base quality is accumulated on every point the
       trajectory visits.
    2. Each trajectory's score is the sum of the per-point mean base
       qualities along it, plus the weighted base quality, minus penalties
       for points beyond the ideal count and for repeated points. The
       result is passed through a sign-preserving power and clamped.

    Args:
        prompts: Unused; accepted for interface compatibility.
        completions: Iterable of completions. Each is either a list whose
            first element is a dict with a ``"content"`` key, a dict with a
            ``"content"`` key, or any object (converted with ``str``).
        completions_ids: Unused; accepted for interface compatibility.
        node_depth: Unused; accepted for interface compatibility.
        **kwargs: Optional settings:
            astar_path: Per-completion reference paths (sequences of
                points), indexed like ``completions``. When absent, every
                reference path is treated as empty.
            ideal_anchor_num (2), alpha_anchor_penalty (2.0),
            beta_loop_penalty (1.0), k_base_quality_weight (0.6),
            p_exponent (1.2), min_score_clamp (-20.0),
            max_score_clamp (20.0), parse_fail_penalty (-10.0).

    Returns:
        A list of floats, one per completion in input order. Completions
        whose trajectory cannot be parsed, or is empty, receive
        ``parse_fail_penalty``.
    """
    import collections
    import math

    IDEAL_ANCHORS = int(kwargs.get('ideal_anchor_num', 2))
    ALPHA_ANCHOR = float(kwargs.get('alpha_anchor_penalty', 2.0))
    BETA_LOOP = float(kwargs.get('beta_loop_penalty', 1.0))
    K_BASE = float(kwargs.get('k_base_quality_weight', 0.6))
    P_EXP = float(kwargs.get('p_exponent', 1.2))
    MIN_SCORE = float(kwargs.get('min_score_clamp', -20.0))
    MAX_SCORE = float(kwargs.get('max_score_clamp', 20.0))
    PARSE_FAIL_PENALTY = float(kwargs.get('parse_fail_penalty', -10.0))

    def _to_tuple(p):
        return tuple(p) if isinstance(p, (list, tuple)) else p

    def _completion_text(completion):
        # Chat-style completions carry the text under "content".
        if isinstance(completion, list) and completion and isinstance(completion[0], dict) and "content" in completion[0]:
            return completion[0]["content"]
        if isinstance(completion, dict) and "content" in completion:
            return completion["content"]
        return str(completion)

    if completions is None:
        completions = []
    astar_paths = kwargs.get('astar_path')

    point_reward_sum = collections.defaultdict(float)
    point_visit_count = collections.defaultdict(int)
    completion_coord_lists = []
    completion_base_rewards = []

    # ---- Pass 1: parse, compute base quality, accumulate per-point sums ----
    for idx, completion in enumerate(completions):
        # A missing 'astar_path' means "no reference" for every completion.
        astar_pts = astar_paths[idx] if astar_paths is not None else []
        astar_set = {_to_tuple(p) for p in astar_pts}
        astar_len = len(astar_pts)

        # Search the whole response at once: the extractor raises when the
        # marker is absent, so calling it per line would fail on the first
        # ordinary line of text.
        try:
            coord_list = extract_trajectory_agent1(_completion_text(completion))
        except (ValueError, TypeError, OverflowError):
            coord_list = []

        if not coord_list:
            # Unparseable or empty trajectory.
            completion_coord_lists.append([])
            completion_base_rewards.append(PARSE_FAIL_PENALTY)
            continue

        if astar_set and astar_len > 2 and len(coord_list) > 2:
            # Manhattan length of the trajectory vs. step count of the reference.
            path_length = sum(
                abs(x1 - x0) + abs(y1 - y0)
                for (x0, y0), (x1, y1) in zip(coord_list, coord_list[1:])
            )
            diff = path_length - max(0, astar_len - 1)
            # NOTE(review): the quadratic term penalises trajectories that are
            # shorter than the reference as much as longer ones.
            base_quality = 20.0 - 1.5 * (diff ** 2)
        elif astar_set:
            # Short trajectory or short reference: +2 per point on the
            # reference path, -2 per point off it.
            in_cnt = sum(1 for p in coord_list if _to_tuple(p) in astar_set)
            out_cnt = len(coord_list) - in_cnt
            base_quality = 2.0 * in_cnt - 2.0 * out_cnt
        else:
            base_quality = -2.0

        completion_coord_lists.append(coord_list)
        completion_base_rewards.append(base_quality)
        for point in coord_list:
            pt = _to_tuple(point)
            point_reward_sum[pt] += base_quality
            point_visit_count[pt] += 1

    point_avg_reward = {
        pt: point_reward_sum[pt] / point_visit_count[pt]
        for pt in point_reward_sum
    }

    # ---- Pass 2: combine shared point values with per-trajectory penalties ----
    final_scores = []
    for coord_list, base_quality in zip(completion_coord_lists, completion_base_rewards):
        if not coord_list:
            final_scores.append(float(PARSE_FAIL_PENALTY))
            continue

        route_sum = sum(point_avg_reward.get(_to_tuple(p), 0.0) for p in coord_list)

        extra_anchors = max(0, len(coord_list) - IDEAL_ANCHORS)
        anchor_penalty = ALPHA_ANCHOR * extra_anchors

        unique_cnt = len({_to_tuple(p) for p in coord_list})
        loop_penalty = BETA_LOOP * max(0, len(coord_list) - unique_cnt)

        raw = route_sum + K_BASE * base_quality - anchor_penalty - loop_penalty

        # Power on the magnitude only, so negative scores stay negative.
        amplified = math.copysign(abs(raw) ** P_EXP, raw)
        amplified = max(MIN_SCORE, min(MAX_SCORE, amplified))
        final_scores.append(float(amplified))

    return final_scores

def parse_and_calculate_reward_original(prompts=None, completions=None, completions_ids=None, node_depth=None, **kwargs):
    """Score each completion's agent-1 trajectory against its A* reference.

    The first line of the completion text containing
    ``<trajectory for agent1> =`` is located, and the text between the first
    and second ``=`` on that line is parsed as a Python literal.

    Scoring per completion:
      * Parse failure, missing marker, or a value that is not a list/tuple
        of two-number points: ``-6``.
      * Reference and trajectory both longer than two points:
        ``max(0, 10 - (manhattan_length - (len(reference) - 1)))``.
      * Otherwise, with a non-empty reference: ``+1`` for each trajectory
        point on the reference path and ``-1`` for each point off it.
      * Otherwise: ``0``.

    Args:
        prompts: Unused; accepted for interface compatibility.
        completions: Iterable of completions. Each is either a list whose
            first element is a dict with a ``"content"`` key, a dict with a
            ``"content"`` key, or any object (converted with ``str``).
        completions_ids: Unused; accepted for interface compatibility.
        node_depth: Unused; accepted for interface compatibility.
        **kwargs: Must contain ``astar_path``: per-completion reference
            paths (sequences of points), indexed like ``completions``.

    Returns:
        A list of floats, one per completion in input order.

    Raises:
        KeyError: If ``astar_path`` is not supplied.
    """
    PARSE_FAIL_REWARD = -6.0
    MARKER = '<trajectory for agent1> ='

    def _to_tuple(p):
        return tuple(p) if isinstance(p, (list, tuple)) else p

    def _completion_text(completion):
        # Chat-style completions carry the text under "content".
        if isinstance(completion, list) and completion and isinstance(completion[0], dict) and "content" in completion[0]:
            return completion[0]["content"]
        if isinstance(completion, dict) and "content" in completion:
            return completion["content"]
        return str(completion)

    def _is_point(p):
        return (
            isinstance(p, (list, tuple))
            and len(p) == 2
            and all(isinstance(v, (int, float)) for v in p)
        )

    def _parse_coords(text):
        """Return the parsed point list, or None when the text is unusable."""
        if not isinstance(text, str):
            return None
        coord_line = next((ln for ln in text.strip().split('\n') if MARKER in ln), None)
        if coord_line is None:
            return None
        try:
            # SECURITY: this text is model output. It used to go through
            # eval(), which would execute arbitrary expressions (and could
            # read local variables such as the reference path). literal_eval
            # accepts plain literals only.
            coords = ast.literal_eval(coord_line.split('=')[1].strip())
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            return None
        # Reject scalars, dicts, strings and malformed points up front so the
        # scoring code below cannot raise on them.
        if not isinstance(coords, (list, tuple)) or not all(_is_point(p) for p in coords):
            return None
        return coords

    if completions is None:
        completions = []

    scores = []
    for idx, completion in enumerate(completions):
        astar_pts = kwargs['astar_path'][idx]
        astar_set = {_to_tuple(p) for p in astar_pts}
        astar_len = len(astar_pts)

        coord_list = _parse_coords(_completion_text(completion))
        if coord_list is None:
            scores.append(PARSE_FAIL_REWARD)
            continue

        reward = 0
        if astar_set and astar_len > 2 and len(coord_list) > 2:
            # Manhattan length of the trajectory vs. step count of the reference.
            path_length = sum(
                abs(curr[0] - prev[0]) + abs(curr[1] - prev[1])
                for prev, curr in zip(coord_list, coord_list[1:])
            )
            # NOTE(review): a trajectory shorter than the reference scores
            # above 10 here; confirm that is intended.
            reward = max(0, 10 - (path_length - (astar_len - 1)))
        elif astar_set:
            for point in coord_list:
                reward += 1 if _to_tuple(point) in astar_set else -1

        scores.append(float(reward))

    return scores

def parse_and_calculate_reward_tree(prompts=None, completions=None, completions_ids=None, node_depth=None, **kwargs):
    """Score completions by summing batch-averaged per-point rewards.

    The first line of each completion containing
    ``<trajectory for agent1> =`` is located, and the text between the first
    and second ``=`` on that line is parsed as a Python literal.

    Pass 1 gives every parsed trajectory an individual reward:
      * Reference and trajectory both longer than two points:
        ``max(0, 10 - (manhattan_length - (len(reference) - 1)))``.
      * Otherwise, with a non-empty reference: ``+1`` for each trajectory
        point on the reference path and ``-1`` for each point off it.
      * Otherwise: ``0``.
    That reward is accumulated on every point the trajectory visits.

    Pass 2 scores each trajectory as the sum, over its points, of the mean
    reward of all trajectories in the batch that visit that point.

    Args:
        prompts: Unused; accepted for interface compatibility.
        completions: Iterable of completions. Each is either a list whose
            first element is a dict with a ``"content"`` key, a dict with a
            ``"content"`` key, or any object (converted with ``str``).
        completions_ids: Unused; accepted for interface compatibility.
        node_depth: Unused; accepted for interface compatibility.
        **kwargs: Must contain ``astar_path``: per-completion reference
            paths (sequences of points), indexed like ``completions``.

    Returns:
        A list of floats, one per completion in input order. Completions
        with a missing marker, an unparseable or malformed trajectory, or
        an empty trajectory receive ``-6.0``.

    Raises:
        KeyError: If ``astar_path`` is not supplied.
    """
    import collections

    PARSE_FAIL_REWARD = -6.0
    MARKER = '<trajectory for agent1> ='

    def _to_tuple(p):
        return tuple(p) if isinstance(p, (list, tuple)) else p

    def _completion_text(completion):
        # Chat-style completions carry the text under "content".
        if isinstance(completion, list) and completion and isinstance(completion[0], dict) and "content" in completion[0]:
            return completion[0]["content"]
        if isinstance(completion, dict) and "content" in completion:
            return completion["content"]
        return str(completion)

    def _is_point(p):
        return (
            isinstance(p, (list, tuple))
            and len(p) == 2
            and all(isinstance(v, (int, float)) for v in p)
        )

    def _parse_coords(text):
        """Return the parsed point list, or None when the text is unusable."""
        if not isinstance(text, str):
            return None
        coord_line = next((ln for ln in text.strip().split('\n') if MARKER in ln), None)
        if coord_line is None:
            return None
        try:
            # SECURITY: this text is model output. It used to go through
            # eval(), which would execute arbitrary expressions (and could
            # read local variables such as the reference path). literal_eval
            # accepts plain literals only.
            coords = ast.literal_eval(coord_line.split('=')[1].strip())
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            return None
        # The container itself must be a list/tuple: a scalar cannot be
        # iterated and a dict of tuple keys cannot be indexed by position.
        if not isinstance(coords, (list, tuple)) or not all(_is_point(p) for p in coords):
            return None
        return coords

    if completions is None:
        completions = []

    point_reward_sum = collections.defaultdict(float)
    point_visit_count = collections.defaultdict(int)
    completion_coord_lists = []

    # ---- Pass 1: individual rewards and per-point accumulation ----
    for idx, completion in enumerate(completions):
        astar_pts = kwargs['astar_path'][idx]
        astar_set = {_to_tuple(p) for p in astar_pts}
        astar_len = len(astar_pts)

        coord_list = _parse_coords(_completion_text(completion))
        if coord_list is None:
            # An empty entry marks this completion as failed for pass 2.
            completion_coord_lists.append([])
            continue

        reward = 0
        if astar_set and astar_len > 2 and len(coord_list) > 2:
            # Manhattan length of the trajectory vs. step count of the reference.
            path_length = sum(
                abs(curr[0] - prev[0]) + abs(curr[1] - prev[1])
                for prev, curr in zip(coord_list, coord_list[1:])
            )
            # NOTE(review): a trajectory shorter than the reference scores
            # above 10 here; confirm that is intended.
            reward = max(0, 10 - (path_length - (astar_len - 1)))
        elif astar_set:
            for point in coord_list:
                reward += 1 if _to_tuple(point) in astar_set else -1

        completion_coord_lists.append(coord_list)
        for point in coord_list:
            pt = _to_tuple(point)
            point_reward_sum[pt] += reward
            point_visit_count[pt] += 1

    point_avg_reward = {
        pt: point_reward_sum[pt] / point_visit_count[pt]
        for pt in point_reward_sum
    }

    # ---- Pass 2: sum the shared per-point averages along each trajectory ----
    final_scores = []
    for coord_list in completion_coord_lists:
        if not coord_list:
            final_scores.append(PARSE_FAIL_REWARD)
            continue
        total = sum(point_avg_reward.get(_to_tuple(p), 0.0) for p in coord_list)
        final_scores.append(float(total))

    return final_scores

def parse_and_calculate_reward_tree_normal_grpo(prompts=None, completions=None, completions_ids=None, node_depth=None, **kwargs):
    """Score bare-list completions with a waypoint-sharing ("tree") reward.

    The marker ``<trajectory for agent1> =`` is prepended to each
    completion's text before parsing, so a completion is expected to start
    directly with the coordinate list. The text between the first and
    second ``=`` on the first line is parsed as a Python literal.

    The scoring runs in two passes over the batch:

    1. Every parsed trajectory gets a base quality from its A* reference
       path, and that base quality is accumulated on every point the
       trajectory visits.
    2. Each trajectory's score is the sum of the per-point mean base
       qualities along it, plus the weighted base quality, minus penalties
       for points beyond the ideal count and for repeated points. The
       result is passed through a sign-preserving power and clamped.

    Args:
        prompts: Unused; accepted for interface compatibility.
        completions: Iterable of completions. Each is either a list whose
            first element is a dict with a ``"content"`` key, a dict with a
            ``"content"`` key, or any object (converted with ``str``).
        completions_ids: Unused; accepted for interface compatibility.
        node_depth: Unused; accepted for interface compatibility.
        **kwargs: Optional settings:
            astar_path: Per-completion reference paths (sequences of
                points), indexed like ``completions``. When absent, every
                reference path is treated as empty.
            ideal_anchor_num (2), alpha_anchor_penalty (2.0),
            beta_loop_penalty (1.0), k_base_quality_weight (0.6),
            p_exponent (1.2), min_score_clamp (-20.0),
            max_score_clamp (20.0), parse_fail_penalty (-10.0).

    Returns:
        A list of floats, one per completion in input order. Completions
        whose trajectory cannot be parsed, is malformed, or is empty
        receive ``parse_fail_penalty``.
    """
    import collections
    import math

    IDEAL_ANCHORS = int(kwargs.get('ideal_anchor_num', 2))
    ALPHA_ANCHOR = float(kwargs.get('alpha_anchor_penalty', 2.0))
    BETA_LOOP = float(kwargs.get('beta_loop_penalty', 1.0))
    K_BASE = float(kwargs.get('k_base_quality_weight', 0.6))
    P_EXP = float(kwargs.get('p_exponent', 1.2))
    MIN_SCORE = float(kwargs.get('min_score_clamp', -20.0))
    MAX_SCORE = float(kwargs.get('max_score_clamp', 20.0))
    PARSE_FAIL_PENALTY = float(kwargs.get('parse_fail_penalty', -10.0))
    MARKER = '<trajectory for agent1> ='

    def _to_tuple(p):
        return tuple(p) if isinstance(p, (list, tuple)) else p

    def _completion_text(completion):
        # Chat-style completions carry the text under "content".
        if isinstance(completion, list) and completion and isinstance(completion[0], dict) and "content" in completion[0]:
            return completion[0]["content"]
        if isinstance(completion, dict) and "content" in completion:
            return completion["content"]
        return str(completion)

    def _is_point(p):
        return (
            isinstance(p, (list, tuple))
            and len(p) == 2
            and all(isinstance(v, (int, float)) for v in p)
        )

    def _parse_coords(text):
        """Return the parsed point list, or None when the text is unusable."""
        if not isinstance(text, str):
            return None
        # The completion carries only the list, so the marker is supplied here.
        text = MARKER + text
        coord_line = next((ln for ln in text.strip().split('\n') if MARKER in ln), None)
        if coord_line is None:
            return None
        try:
            # SECURITY: this text is model output. It used to go through
            # eval(), which would execute arbitrary expressions (and could
            # read local variables such as the reference path). literal_eval
            # accepts plain literals only.
            coords = ast.literal_eval(coord_line.split('=')[1].strip())
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            return None
        # The container itself must be a list/tuple: a scalar cannot be
        # iterated and a dict of tuple keys cannot be indexed by position.
        if not isinstance(coords, (list, tuple)) or not all(_is_point(p) for p in coords):
            return None
        return coords

    if completions is None:
        completions = []
    astar_paths = kwargs.get('astar_path')

    point_reward_sum = collections.defaultdict(float)
    point_visit_count = collections.defaultdict(int)
    completion_coord_lists = []
    completion_base_rewards = []

    # ---- Pass 1: parse, compute base quality, accumulate per-point sums ----
    for idx, completion in enumerate(completions):
        # A missing 'astar_path' means "no reference" for every completion.
        astar_pts = astar_paths[idx] if astar_paths is not None else []
        astar_set = {_to_tuple(p) for p in astar_pts}
        astar_len = len(astar_pts)

        coord_list = _parse_coords(_completion_text(completion))
        if coord_list is None:
            completion_coord_lists.append([])
            completion_base_rewards.append(PARSE_FAIL_PENALTY)
            continue

        if astar_set and astar_len > 2 and len(coord_list) > 2:
            # Manhattan length of the trajectory vs. step count of the reference.
            path_length = sum(
                abs(x1 - x0) + abs(y1 - y0)
                for (x0, y0), (x1, y1) in zip(coord_list, coord_list[1:])
            )
            diff = path_length - max(0, astar_len - 1)
            # NOTE(review): the quadratic term penalises trajectories that are
            # shorter than the reference as much as longer ones.
            base_quality = 20.0 - 1.5 * (diff ** 2)
        elif astar_set and len(coord_list) > 0:
            # Short trajectory or short reference: +2 per point on the
            # reference path, -2 per point off it.
            in_cnt = sum(1 for p in coord_list if _to_tuple(p) in astar_set)
            out_cnt = len(coord_list) - in_cnt
            base_quality = 2.0 * in_cnt - 2.0 * out_cnt
        else:
            base_quality = -2.0

        completion_coord_lists.append(coord_list)
        completion_base_rewards.append(base_quality)
        for point in coord_list:
            pt = _to_tuple(point)
            point_reward_sum[pt] += base_quality
            point_visit_count[pt] += 1

    point_avg_reward = {
        pt: point_reward_sum[pt] / point_visit_count[pt]
        for pt in point_reward_sum
    }

    # ---- Pass 2: combine shared point values with per-trajectory penalties ----
    final_scores = []
    for coord_list, base_quality in zip(completion_coord_lists, completion_base_rewards):
        if not coord_list:
            # Covers parse failures and successfully parsed empty lists.
            final_scores.append(float(PARSE_FAIL_PENALTY))
            continue

        route_sum = sum(point_avg_reward.get(_to_tuple(p), 0.0) for p in coord_list)

        extra_anchors = max(0, len(coord_list) - IDEAL_ANCHORS)
        anchor_penalty = ALPHA_ANCHOR * extra_anchors

        unique_cnt = len({_to_tuple(p) for p in coord_list})
        loop_penalty = BETA_LOOP * max(0, len(coord_list) - unique_cnt)

        raw = route_sum + K_BASE * base_quality - anchor_penalty - loop_penalty

        # Power on the magnitude only, so negative scores stay negative.
        amplified = math.copysign(abs(raw) ** P_EXP, raw)
        amplified = max(MIN_SCORE, min(MAX_SCORE, amplified))
        final_scores.append(float(amplified))

    return final_scores