import ast
import math
import re

from Levenshtein import ratio as levenshtein_ratio


def determine_swipe_direction(coords, min_distance=1e-5):
    """Classify a two-point swipe as "Up", "Down", "Left" or "Right".

    Args:
        coords: a pair of 2-element points ``[[x1, y1], [x2, y2]]`` (start, end).
        min_distance: swipes shorter than this are treated as no movement.

    Returns:
        "123" if ``coords`` is not exactly two 2-element points (sentinel kept
        for compatibility with existing callers), "None" (a string) if the
        points are closer than ``min_distance``, otherwise the dominant axis
        direction. A positive dy maps to "Down", i.e. y is assumed to grow
        downward as in screen coordinates. Exact diagonals (|dx| == |dy|) are
        classified as horizontal.
    """
    if len(coords) != 2 or any(len(point) != 2 for point in coords):
        return "123"

    (x1, y1), (x2, y2) = coords
    dx = x2 - x1
    dy = y2 - y1

    if math.hypot(dx, dy) < min_distance:
        return "None"

    # Vertical only when |dy| strictly dominates; ties fall through to horizontal.
    if abs(dy) > abs(dx):
        return "Down" if dy > 0 else "Up"
    return "Right" if dx > 0 else "Left"


def calculate_f1_score(predicted_str, ground_truth_str):
    """Return a bag-of-words F1 score between two strings.

    Square brackets are removed, text is lower-cased and split on whitespace,
    and the resulting tokens are compared as *sets* (duplicates are ignored).
    If each side has exactly one token and either token is a substring of the
    other, the score is 1 outright.

    NOTE(review): that substring rule is very lenient (e.g. "a" vs "search"
    scores 1); confirm it is intended when used as a training reward.

    Returns:
        0 when there is no overlap, 1 for the single-token substring case,
        otherwise the float harmonic mean of precision and recall.
    """
    def _tokens(text):
        return set(text.replace("[", "").replace("]", "").lower().split())

    pred = _tokens(predicted_str)
    gold = _tokens(ground_truth_str)

    if len(pred) == 1 and len(gold) == 1:
        (only_pred,), (only_gold,) = pred, gold
        if only_pred in only_gold or only_gold in only_pred:
            return 1

    overlap = len(pred & gold)
    precision = overlap / len(pred) if pred else 0
    recall = overlap / len(gold) if gold else 0

    denominator = precision + recall
    if denominator == 0:
        return 0
    return 2 * (precision * recall) / denominator


def format_reward_func(completions, **kwargs):
    """Reward completions shaped as ``<thinking>...</thinking>...<answer>...</answer>``.

    Each completion's text is read from ``item[0]['content']``. A match (with
    ``.`` spanning newlines) earns 0.1, anything else 0.0. As a side effect,
    every completion's text is printed to stdout prefixed with ``prediction==``.
    """
    layout = re.compile(r"^<thinking>.*?</thinking>.*?<answer>.*?</answer>$", re.DOTALL)
    rewards = [0.1 if layout.match(msg[0]['content']) else 0.0 for msg in completions]
    for msg in completions:
        print('prediction==' + msg[0]['content'])
    return rewards


def thinking_reward_func(completions, **kwargs):
    """Reward completions whose thinking contains analysis, reasoning and instruction sections.

    A completion (text at ``item[0]['content']``) earns 0.1 when, in order, it has
    ``<analysis>``, ``<reasoning>`` and ``<instruction>`` blocks inside
    ``<thinking>...</thinking>``, followed by an ``<answer>...</answer>`` that
    ends the text; otherwise 0.0. ``.`` spans newlines.
    """
    section_regex = re.compile(r'^<thinking>.*?<analysis>.*?</analysis>.*?<reasoning>.*?</reasoning>.*?<instruction>.*?</instruction>.*?</thinking>.*?<answer>.*?</answer>$', re.DOTALL)
    scores = []
    for entry in completions:
        found = section_regex.match(entry[0]['content'])
        scores.append(0.1 if found else 0.0)
    return scores


def accuracy_reward_type(completions, solution, **kwargs):
    """Score how closely each predicted ``action_type`` matches the ground truth.

    For each (completion, solution) pair, only the text after the last
    ``</thinking>`` is considered. The span from the first ``<answer>`` to the
    last ``</answer>`` must hold a single-brace Python dict literal. The reward
    is ``0.3 * levenshtein_ratio`` of the lower-cased ``action_type`` values of
    prediction and ground truth. It is 0.0 when the answer is missing,
    malformed, not a dict, or has no usable ``action_type``.

    Args:
        completions: sequence whose items hold the model text at ``item[0]['content']``.
        solution: sequence of ground-truth strings aligned with ``completions``.
        **kwargs: unused.

    Returns:
        list of floats, one per (completion, solution) pair.

    Raises:
        Exceptions from ``ast.literal_eval`` when a ground-truth answer cannot
        be parsed and the prediction contains answer tags. Bad ground-truth data
        is surfaced rather than silently scored.
    """
    start_tag = '<answer>'
    end_tag = '</answer>'
    res = []
    for completion, sol_text in zip(completions, solution):
        # split()[-1] is the whole string when there is no '</thinking>', so an
        # <answer> that only appears inside the reasoning section is ignored.
        t = completion[0]['content'].split('</thinking>')[-1]
        if start_tag not in t or end_tag not in t:
            res.append(0.0)
            continue
        start_index = t.find(start_tag)
        end_index = t.rfind(end_tag)
        # If '</answer>' only precedes '<answer>', this slice is empty and the
        # brace check below rejects it.
        extracted_answer = t[start_index + len(start_tag):end_index].strip()

        sol_answer = sol_text.split('</thinking>')[-1]
        sol_body = sol_answer[sol_answer.find(start_tag) + len(start_tag):sol_answer.rfind(end_tag)].strip()
        # literal_eval parses data without executing code (eval would run it).
        sol = ast.literal_eval(sol_body)

        # Require exactly one level of braces: '{...}' but not '{{...}}'.
        if (extracted_answer.startswith('{{') or extracted_answer.endswith('}}')
                or not extracted_answer.startswith('{') or not extracted_answer.endswith('}')):
            res.append(0.0)
            continue

        try:
            # Model output is untrusted: literal_eval never executes code, unlike eval.
            parsed = ast.literal_eval(extracted_answer)
            if isinstance(parsed, dict):
                score = levenshtein_ratio(parsed['action_type'].lower(), sol['action_type'].lower()) * 0.3
            else:
                score = 0.0
        except Exception:  # malformed literal, missing key, non-string action_type, ...
            score = 0.0
        res.append(score)

    return res


def _scroll_reward(pred_info, sol_info):
    """Return 0.5 if a predicted SCROLL goes the same way as the ground truth, else 0.0.

    ``sol_info`` is converted with ``determine_swipe_direction``. ``pred_info`` is
    converted the same way when every element is a list (two points); otherwise
    it is compared as-is, e.g. a direction string.
    """
    # A flat list of numbers (or an empty value) is not accepted as a scroll.
    if all(isinstance(x, (int, float)) for x in pred_info):
        return 0.0
    direction_sol = determine_swipe_direction(sol_info)
    if all(isinstance(x, list) for x in pred_info):
        direction_answer = determine_swipe_direction(pred_info)
    else:
        direction_answer = pred_info
    # NOTE(review): exact, case-sensitive comparison. determine_swipe_direction
    # returns capitalized names ("Up", "Left", ...), so e.g. "up" does not match.
    return 0.5 if direction_sol == direction_answer else 0.0


def _click_reward(pred_info, sol_box):
    """Return 0.5 if the predicted (x, y) point lies strictly inside ``sol_box``, else 0.0.

    ``sol_box`` is indexed as ``[[x1, y1], [x2, y2]]`` and the test is
    ``x1 < x < x2 and y1 < y < y2``. An unordered box (x1 > x2 or y1 > y2)
    therefore never matches.
    """
    if not all(isinstance(x, (int, float)) for x in pred_info):
        return 0.0
    sol_x1, sol_y1 = sol_box[0][0], sol_box[0][1]
    sol_x2, sol_y2 = sol_box[1][0], sol_box[1][1]
    answer_x = float(pred_info[0])
    answer_y = float(pred_info[1])
    return 0.5 if sol_x1 < answer_x < sol_x2 and sol_y1 < answer_y < sol_y2 else 0.0


def _text_reward(pred_info, sol_info):
    """Return 0.5 if the token F1 of the lower-cased texts exceeds 0.5, else 0.0."""
    return 0.5 if calculate_f1_score(pred_info.lower(), sol_info.lower()) > 0.5 else 0.0


def _action_reward(pred, sol):
    """Score one parsed prediction against the parsed ground-truth action dict.

    Callers are expected to turn any exception raised here into 0.0.
    """
    if not isinstance(pred, dict) or pred['action_type'] != sol['action_type']:
        return 0.0
    action_type = pred['action_type']
    pred_info = pred['action_info']
    sol_info = sol['action_info']
    if action_type == 'SCROLL':
        return _scroll_reward(pred_info, sol_info)
    if action_type == 'CLICK' and all(isinstance(x, list) for x in sol_info):
        return _click_reward(pred_info, sol_info)
    # Any other action (or a CLICK whose ground truth is not a list of points)
    # is compared as free text.
    return _text_reward(pred_info, sol_info)


def accuracy_reward_action(completions, solution, **kwargs):
    """Score whether each predicted action (type + arguments) matches the ground truth.

    Only text after the last ``</thinking>`` is considered. The span from the
    first ``<answer>`` to the last ``</answer>`` must hold a single-brace Python
    dict literal with ``action_type`` and ``action_info``. When ``action_type``
    equals the ground truth, the reward is 0.5 if:

    * SCROLL: the swipe direction matches;
    * CLICK (ground truth given as a list of points): the predicted point lies
      strictly inside the ground-truth box;
    * anything else: the token F1 of the lower-cased ``action_info`` exceeds 0.5.

    Otherwise, and for any missing, malformed or mismatched answer, the reward
    is 0.0.

    Args:
        completions: sequence whose items hold the model text at ``item[0]['content']``.
        solution: sequence of ground-truth strings aligned with ``completions``.
        **kwargs: unused.

    Returns:
        list of floats, one per (completion, solution) pair.

    Raises:
        Exceptions from ``ast.literal_eval`` when a ground-truth answer cannot
        be parsed and the prediction contains answer tags. Bad ground-truth data
        is surfaced rather than silently scored.
    """
    start_tag = '<answer>'
    end_tag = '</answer>'
    res = []
    for completion, sol_text in zip(completions, solution):
        # split()[-1] is the whole string when there is no '</thinking>'.
        t = completion[0]['content'].split('</thinking>')[-1]
        sol_answer = sol_text.split('</thinking>')[-1]
        if start_tag not in t or end_tag not in t:
            res.append(0.0)
            continue
        start_index = t.find(start_tag)
        end_index = t.rfind(end_tag)
        # If '</answer>' only precedes '<answer>', this slice is empty and the
        # brace check below rejects it.
        extracted_answer = t[start_index + len(start_tag):end_index].strip()

        sol_body = sol_answer[sol_answer.find(start_tag) + len(start_tag):sol_answer.rfind(end_tag)].strip()
        # literal_eval parses data without executing code (eval would run it).
        sol = ast.literal_eval(sol_body)

        # Require exactly one level of braces: '{...}' but not '{{...}}'.
        if (extracted_answer.startswith('{{') or extracted_answer.endswith('}}')
                or not extracted_answer.startswith('{') or not extracted_answer.endswith('}')):
            res.append(0.0)
            continue

        try:
            # Model output is untrusted: literal_eval never executes code, unlike eval.
            res.append(_action_reward(ast.literal_eval(extracted_answer), sol))
        except Exception:  # malformed literal, missing keys, wrong value types, ...
            res.append(0.0)

    return res