
import os
import re
import copy
import math

from datetime import datetime
from math_verify import parse, verify
from collections import deque

def calculate_recall_only(candidate, reference):
    """Return the unigram recall of ``candidate`` with respect to ``reference``.

    Both strings are passed through ``normalize_word`` and then broken into
    single-word counts with ``split_sentence(..., 1)``.

    For every distinct reference word:
      * if it occurs in the candidate, ``min(candidate count, reference
        count)`` is added to the hits;
      * otherwise its full reference count is added to the misses.

    A reference word that occurs at least once in the candidate therefore
    never contributes to the misses, even when the candidate has fewer
    occurrences of it than the reference does.

    Returns:
        ``hits / (hits + misses)``, or ``0.0`` when both are zero (for
        example when the normalised reference contains no words).
    """
    normalized_candidate = normalize_word(candidate)
    normalized_reference = normalize_word(reference)

    candidate_counts = split_sentence(normalized_candidate, 1)
    reference_counts = split_sentence(normalized_reference, 1)

    hits = 0
    misses = 0
    for token, needed in reference_counts.items():
        if token in candidate_counts:
            hits += min(candidate_counts[token], needed)
        else:
            misses += needed

    total = hits + misses
    return hits / total if total != 0 else 0.0

def accuracy_reward(completions, solution, answer_type, **kwargs):
    """Score each completion against its reference answer.

    Args:
        completions: Sequence of completions; the text is read from
            ``completion[0]["content"]``.
        solution: Reference answers, paired positionally with ``completions``.
            If a reference contains ``<answer>...</answer>`` on a single line
            the tag content is used, otherwise the whole stripped string.
        answer_type: Sequence of question types. Only the first element is
            consulted and it is applied to every completion in the call.
        **kwargs: Ignored.

    Returns:
        One float per (completion, solution) pair:
          * ``"OPEN"``   - ``calculate_recall_only`` of the *entire* stripped
            completion (not just the ``<answer>`` part) against the reference;
          * ``"CLOSED"`` - ``1.0`` if the reference occurs (case-insensitively)
            inside the extracted answer, else ``0.0``;
          * anything else - ``0.0``.

    When the ``DEBUG_MODE`` environment variable equals ``"True"`` and
    ``LOG_PATH`` is set, each score is appended to that file on a best-effort
    basis.
    """
    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
    answer_type = answer_type[0]

    # The environment does not change inside the loop, so read it once.
    log_path = os.getenv("LOG_PATH") if os.getenv("DEBUG_MODE") == "True" else None

    for content, sol in zip(contents, solution):
        # Extract answer from solution if it has think/answer tags
        sol_match = re.search(r'<answer>(.*?)</answer>', sol)
        ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()

        # Extract answer from content if it has think/answer tags
        content_match = re.search(r'<answer>(.*?)</answer>', content, re.DOTALL)
        student_answer = content_match.group(1).strip() if content_match else content.strip()

        if answer_type == "OPEN":
            # Recall is only needed (and therefore only computed) here.
            reward = calculate_recall_only(content.strip(), ground_truth)
        elif answer_type == "CLOSED":
            reward = 1.0 if ground_truth.lower() in student_answer.lower() else 0.0
        else:
            reward = 0.0

        rewards.append(reward)

        if log_path:
            try:
                with open(log_path, "a") as f:
                    f.write(f"------------- {current_time} Accuracy reward: {reward} Question type: {answer_type} -------------\n")
                    f.write(f"Content: {content}\n")
                    # Second "Content:" line carries the extracted answer.
                    f.write(f"Content: {student_answer}\n")
                    f.write(f"Solution: {ground_truth}\n")
            except (OSError, UnicodeError):
                # Logging is diagnostic only; never let it break reward computation.
                pass

    return rewards

# Matches calls written as ``func(object_id, value)``; the value may be
# wrapped in single quotes.
_FUNC_CALL_PATTERN = re.compile(r"(\w+)\((\w+),\s*'?(\w+)'?\)")

# Matching stages, applied in order to the items left over from the previous
# stage: (key used to compare two items, credit per pair as a fraction of the
# per-item score).
_FUNC_MATCH_STAGES = (
    (lambda item: item, 1.0),                  # func, object_id and value agree
    (lambda item: item[:2], 0.5),              # func and object_id agree
    (lambda item: (item[0], item[2]), 0.5),    # func and value agree
    (lambda item: item[0], 0.25),              # only func agrees
)


def _extract_func_items(text):
    """Return ``(unique_items, repeat_penalty)`` for the calls found in ``text``.

    ``unique_items`` keeps first-occurrence order so that scoring does not
    depend on hash randomisation. ``repeat_penalty`` is the share of matches
    that were unique (``1.0`` when nothing is repeated, ``0.0`` when there are
    no matches at all).
    """
    matches = _FUNC_CALL_PATTERN.findall(text)
    if not matches:
        return [], 0.0
    unique_items = list(dict.fromkeys(matches))
    return unique_items, len(unique_items) / len(matches)


def _greedy_pair(pred_queue, sol_queue, key):
    """Pair off items whose ``key`` agrees, removing them; return the pair count."""
    candidates = [(p, s) for p in pred_queue for s in sol_queue if key(p) == key(s)]
    paired = 0
    for p, s in candidates:
        # Either side may already have been consumed by an earlier pair.
        if p in pred_queue and s in sol_queue:
            pred_queue.remove(p)
            sol_queue.remove(s)
            paired += 1
    return paired


def _score_func_items(pred_text, sol_text):
    """Score predicted calls against reference calls; the result lies in [0, 1]."""
    pred_list, repeat_penalty = _extract_func_items(pred_text)
    sol_list, _ = _extract_func_items(sol_text)
    if not pred_list or not sol_list:
        return 0.0

    # Dividing by the longer list penalises both missing and extra calls.
    item_score = repeat_penalty / max(len(pred_list), len(sol_list))

    pred_queue = deque(pred_list)
    sol_queue = deque(sol_list)
    reward = 0.0
    for key, weight in _FUNC_MATCH_STAGES:
        reward += _greedy_pair(pred_queue, sol_queue, key) * item_score * weight
    return reward


def func_accuracy_reward(completions, solution, **kwargs):
    """Reward completions for reproducing the reference ``func(object, value)`` calls.

    Calls are extracted from the ``<answer>...</answer>`` part of each
    completion (or from the whole completion when the tags are absent) and
    from the matching entry of ``solution``. Items are then paired greedily in
    four stages: exact match (full credit), same func and object (half), same
    func and value (half), same func only (quarter). Credit per item is
    ``repeat_penalty / max(#predicted, #reference)`` where ``repeat_penalty``
    is the fraction of predicted calls that were not duplicates.

    Args:
        completions: Sequence of completions; the text is read from
            ``completion[0]["content"]``.
        solution: Reference strings, paired positionally with ``completions``.
        **kwargs: Ignored.

    Returns:
        One float in ``[0.0, 1.0]`` per pair. ``0.0`` is returned when either
        side contains no recognisable call or the inputs are not strings.

    When the ``DEBUG_MODE`` environment variable equals ``"True"`` and
    ``LOG_PATH`` is set, each score is appended to that file on a best-effort
    basis.
    """
    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
    log_path = os.getenv("LOG_PATH") if os.getenv("DEBUG_MODE") == "True" else None

    for content, sol in zip(contents, solution):
        answer_text = None
        try:
            # re.DOTALL lets a multi-line answer be extracted instead of
            # falling back to the whole completion (reasoning included).
            answer_match = re.search(r'<answer>(.*?)</answer>', content, re.DOTALL)
            answer_text = answer_match.group(1).strip() if answer_match else content.strip()
            # Each pair is scored exactly once.
            reward = _score_func_items(answer_text, sol)
        except (TypeError, AttributeError):
            # Non-string completion or solution: nothing can be scored.
            reward = 0.0

        rewards.append(reward)

        if log_path:
            try:
                with open(log_path, "a") as f:
                    f.write(f"------------- {current_time} Func accuracy reward: {reward} -------------\n")
                    f.write(f"Content: {content}\n")
                    f.write(f"Answer: {answer_text}\n")
                    f.write(f"Solution: {sol}\n")
            except (OSError, UnicodeError):
                # Logging is diagnostic only; never let it break reward computation.
                pass

    return rewards


def format_reward(completions, **kwargs):
    """Return 1.0 for each completion laid out as think-then-answer, else 0.0.

    The whole completion text (``completion[0]["content"]``) must consist of a
    ``<think>...</think>`` section, optional whitespace, and an
    ``<answer>...</answer>`` section, with nothing before or after. Both
    sections may span multiple lines.
    """
    layout = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)
    texts = [completion[0]["content"] for completion in completions]

    scores = []
    for text in texts:
        if layout.fullmatch(text):
            scores.append(1.0)
        else:
            scores.append(0.0)
    return scores

def calibration_reward(completions, solution, confidence, answer_type, **kwargs):
    """Compute a calibration reward: ``acc * conf - (1 - acc) * conf``.

    ``acc`` is obtained per completion as follows, using only the first
    element of ``answer_type`` for the whole call:
      * ``"OPEN"``   - ``calculate_recall_only`` of the entire stripped
        completion against the reference answer;
      * ``"CLOSED"`` - ``1.0`` if the reference occurs (case-insensitively)
        in the extracted ``<answer>`` text, else ``0.0``;
      * anything else - ``0.0``.

    ``conf`` is ``float(confidence[i].item())``, so each element of
    ``confidence`` must expose ``.item()``.
    NOTE(review): the confidence is used unscaled here, whereas the previous
    description of this function divided it by 10 - confirm the expected
    range with the caller.

    Returns:
        One float per (completion, solution) pair. Any exception raised while
        scoring a pair yields ``0.0`` for that pair.
    """
    answer_tag = r'<answer>(.*?)</answer>'
    responses = [completion[0]["content"] for completion in completions]
    question_kind = answer_type[0]

    rewards = []
    for position, (response, sol) in enumerate(zip(responses, solution)):
        try:
            truth_hit = re.search(answer_tag, sol)
            if truth_hit:
                ground_truth = truth_hit.group(1).strip()
            else:
                ground_truth = sol.strip()

            answer_hit = re.search(answer_tag, response, re.DOTALL)
            if answer_hit:
                student_answer = answer_hit.group(1).strip()
            else:
                student_answer = response.strip()

            if question_kind == "OPEN":
                acc = calculate_recall_only(response.strip(), ground_truth)
            elif question_kind == "CLOSED":
                acc = 1.0 if ground_truth.lower() in student_answer.lower() else 0.0
            else:
                acc = 0.0

            conf = float(confidence[position].item())
            score = acc * conf - (1 - acc) * conf
        except Exception:
            score = 0.0
        rewards.append(score)

    return rewards


# Helper functions required by the recall computation
from collections import defaultdict
import re
import math
def split_sentence(sentence, n):
    """Count the lower-cased, whitespace-separated ``n``-grams of ``sentence``.

    Returns a ``defaultdict(int)`` mapping each n-gram (its words joined by a
    single space) to the number of times it occurs. Sentences with fewer than
    ``n`` words yield an empty mapping.
    """
    tokens = sentence.lower().strip().split()
    counts = defaultdict(int)
    for start in range(len(tokens) - n + 1):
        gram = " ".join(tokens[start:start + n])
        if not gram:
            continue
        counts[gram] += 1
    return counts

import re

# Lookup table used by normalize_word: a word found as a key is replaced by
# the corresponding value (mostly restoring a missing apostrophe).
# NOTE: normalize_word lower-cases its input before consulting this table, so
# the keys that contain capital letters ("Id've", "I'dve", "Im", "Ive") can
# never match as written.
contractions = {
    "aint": "ain't",
    "arent": "aren't",
    "cant": "can't",
    "couldve": "could've",
    "couldnt": "couldn't",
    "couldn'tve": "couldn't've",
    "couldnt've": "couldn't've",
    "didnt": "didn't",
    "doesnt": "doesn't",
    "dont": "don't",
    "hadnt": "hadn't",
    "hadnt've": "hadn't've",
    "hadn'tve": "hadn't've",
    "hasnt": "hasn't",
    "havent": "haven't",
    "hed": "he'd",
    "hed've": "he'd've",
    "he'dve": "he'd've",
    "hes": "he's",
    "howd": "how'd",
    "howll": "how'll",
    "hows": "how's",
    "Id've": "I'd've",
    "I'dve": "I'd've",
    "Im": "I'm",
    "Ive": "I've",
    "isnt": "isn't",
    "itd": "it'd",
    "itd've": "it'd've",
    "it'dve": "it'd've",
    "itll": "it'll",
    "let's": "let's",
    "maam": "ma'am",
    "mightnt": "mightn't",
    "mightnt've": "mightn't've",
    "mightn'tve": "mightn't've",
    "mightve": "might've",
    "mustnt": "mustn't",
    "mustve": "must've",
    "neednt": "needn't",
    "notve": "not've",
    "oclock": "o'clock",
    "oughtnt": "oughtn't",
    "ow's'at": "'ow's'at",
    "'ows'at": "'ow's'at",
    "'ow'sat": "'ow's'at",
    "shant": "shan't",
    "shed've": "she'd've",
    "she'dve": "she'd've",
    "she's": "she's",
    "shouldve": "should've",
    "shouldnt": "shouldn't",
    "shouldnt've": "shouldn't've",
    "shouldn'tve": "shouldn't've",
    # NOTE(review): this entry removes the apostrophe, the opposite direction
    # of every other entry - looks reversed; confirm before relying on it.
    "somebody'd": "somebodyd",
    "somebodyd've": "somebody'd've",
    "somebody'dve": "somebody'd've",
    "somebodyll": "somebody'll",
    "somebodys": "somebody's",
    "someoned": "someone'd",
    "someoned've": "someone'd've",
    "someone'dve": "someone'd've",
    "someonell": "someone'll",
    "someones": "someone's",
    "somethingd": "something'd",
    "somethingd've": "something'd've",
    "something'dve": "something'd've",
    "somethingll": "something'll",
    "thats": "that's",
    "thered": "there'd",
    "thered've": "there'd've",
    "there'dve": "there'd've",
    "therere": "there're",
    "theres": "there's",
    "theyd": "they'd",
    "theyd've": "they'd've",
    "they'dve": "they'd've",
    "theyll": "they'll",
    "theyre": "they're",
    "theyve": "they've",
    "twas": "'twas",
    "wasnt": "wasn't",
    "wed've": "we'd've",
    "we'dve": "we'd've",
    "weve": "we've",
    "werent": "weren't",
    "whatll": "what'll",
    "whatre": "what're",
    "whats": "what's",
    "whatve": "what've",
    "whens": "when's",
    "whered": "where'd",
    "wheres": "where's",
    "whereve": "where've",
    "whod": "who'd",
    "whod've": "who'd've",
    "who'dve": "who'd've",
    "wholl": "who'll",
    "whos": "who's",
    "whove": "who've",
    "whyll": "why'll",
    "whyre": "why're",
    "whys": "why's",
    "wont": "won't",
    "wouldve": "would've",
    "wouldnt": "wouldn't",
    "wouldnt've": "wouldn't've",
    "wouldn'tve": "wouldn't've",
    "yall": "y'all",
    "yall'll": "y'all'll",
    "y'allll": "y'all'll",
    "yall'd've": "y'all'd've",
    "y'alld've": "y'all'd've",
    "y'all'dve": "y'all'd've",
    "youd": "you'd",
    "youd've": "you'd've",
    "you'dve": "you'd've",
    "youll": "you'll",
    "youre": "you're",
    "youve": "you've",
}

# Maps "none" and the number words "zero" through "ten" to digit strings;
# normalize_word substitutes these so that, e.g., "two" and "2" compare equal.
manual_map = {
    "none": "0",
    **{
        number_word: str(value)
        for value, number_word in enumerate(
            (
                "zero",
                "one",
                "two",
                "three",
                "four",
                "five",
                "six",
                "seven",
                "eight",
                "nine",
                "ten",
            )
        )
    },
}
# Words dropped by normalize_word.
articles = ["a", "an", "the"]

# Raw strings keep the patterns byte-identical while avoiding the invalid
# escape sequences (\d, \., \,) that newer Python versions warn about.
#
# NOTE(review): "(?!<=\d)" is a negative look-ahead for the literal text "<="
# followed by a digit, which can never start at a "." - so this pattern
# effectively matches any period that is not followed by a digit. It looks
# like a negative look-behind "(?<!\d)" may have been intended; it is left
# unchanged because altering it would change normalisation results.
period_strip = re.compile(r"(?!<=\d)(\.)(?!\d)")
# A comma sitting between two digits, e.g. the one in "1,000".
comma_strip = re.compile(r"(\d)(\,)(\d)")

# Punctuation marks that normalize_word removes or replaces with a space.
punct = [
    ";",
    r"/",
    "[",
    "]",
    '"',
    "{",
    "}",
    "(",
    ")",
    "=",
    "+",
    "\\",
    "_",
    "-",
    ">",
    "<",
    "@",
    "`",
    ",",
    "?",
    "!",
]


def normalize_word(token):
    """Normalise an answer string for word-level comparison.

    Steps, in order:
      1. Each mark in ``punct`` is deleted when the *original* input contains
         it next to a space or contains a digit-comma-digit sequence anywhere
         (``comma_strip``); otherwise the mark is replaced by a space.
      2. Every period matched by ``period_strip`` is removed.
      3. The text is lower-cased and split on whitespace; each word is mapped
         through ``manual_map`` (number words to digits) and words listed in
         ``articles`` are dropped.
      4. Remaining words found in ``contractions`` are replaced.
      5. Words are re-joined with single spaces and any leftover commas are
         removed.

    Args:
        token: The string to normalise (may contain several words).

    Returns:
        The normalised string.
    """
    # The decision depends only on the raw input, so evaluate it once.
    has_digit_comma = comma_strip.search(token) is not None

    cleaned = token
    for p in punct:
        if has_digit_comma or (p + " ") in token or (" " + p) in token:
            cleaned = cleaned.replace(p, "")
        else:
            cleaned = cleaned.replace(p, " ")

    # No third positional argument: that slot of Pattern.sub is ``count``,
    # and passing a flag there would cap the number of periods removed.
    cleaned = period_strip.sub("", cleaned)

    words = []
    for word in cleaned.lower().split():
        # .get() keeps the shared lookup table read-only; inserting every
        # unseen word would make it grow for the life of the process.
        word = manual_map.get(word, word)
        if word not in articles:
            words.append(word)

    words = [contractions.get(word, word) for word in words]
    return " ".join(words).replace(",", "")