# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py

from typing import List, Optional

import torch
from math_verify import parse, verify


def last_boxed_only_string(string):
    """Return the last ``\\boxed{...}`` span in *string*, braces included.

    Searches for the last occurrence of ``\\boxed``. Only when there is no
    ``\\boxed`` at all does it fall back to the last ``\\fbox``. Starting at
    that command, characters are scanned while tracking brace depth. The span
    ends at the first closing brace that brings the depth back to zero.

    Args:
        string: Text to search (e.g. a model response).

    Returns:
        The substring from the command up to and including its matching
        closing brace. Returns ``None`` if no command is found or its braces
        never balance (for example, a truncated response).
    """
    start = string.rfind("\\boxed")
    if start < 0:
        start = string.rfind("\\fbox")
        if start < 0:
            return None

    depth = 0
    for pos in range(start, len(string)):
        char = string[pos]
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return string[start:pos + 1]

    # Reached the end without the braces balancing.
    return None

def remove_boxed(s):
    """Strip a ``\\boxed{...}`` or ``\\fbox{...}`` wrapper and return its contents.

    Args:
        s: A string such as ``"\\boxed{42}"``, typically produced by
            ``last_boxed_only_string``. ``None`` (or any non-string) is
            accepted and yields ``None``.

    Returns:
        The text between the wrapper's opening brace and the final closing
        brace. Returns ``None`` if *s* is not a string of the form
        ``\\boxed{...}`` or ``\\fbox{...}``.
    """
    # Explicit checks instead of ``assert`` inside a bare ``except``. Asserts
    # are stripped under ``python -O``, which would silently accept any input.
    if not isinstance(s, str) or not s.endswith("}"):
        return None
    # last_boxed_only_string falls back to \fbox, so accept that wrapper too.
    for left in ("\\boxed{", "\\fbox{"):
        if s.startswith(left):
            return s[len(left):-1]
    return None


def extract_boxed_answer(solution: str) -> Optional[str]:
    """Extract the answer from inside the last LaTeX \\boxed{} command.

    Args:
        solution: Text to search, e.g. the tail of a model response.

    Returns:
        The unwrapped contents of the last boxed expression. Returns ``None``
        if no complete boxed expression is found or it cannot be unwrapped.
    """
    boxed = last_boxed_only_string(solution)
    if boxed is None:
        # Nothing boxed: don't depend on remove_boxed tolerating None.
        return None
    return remove_boxed(boxed)


def reward_func(queries, prompts, labels, **kwargs):
    """
    Compute rule-based rewards by checking boxed answers with math_verify.

    For each (query, label) pair, the last boxed answer is extracted from the
    tail of the query text and verified against the label:

    - no extractable boxed answer: reward -1.0, score 0.0, format 0.0
    - boxed answer that fails verification: reward -0.5, score 0.0, format 1.0
    - boxed answer that passes verification: reward 1.0, score 1.0, format 1.0

    Args:
        queries: Complete text sequences containing prompts and responses.
            Each item is sliced and searched as a string.
        prompts: Input prompt sequences. Not used by this function.
        labels: Ground-truth answer strings, paired with queries via ``zip``.
            Extra items in the longer sequence are ignored.
        **kwargs: Additional optional parameters (ignored).

    Returns:
        dict: A dictionary containing the following key-value pairs:
            - rewards: Reward values used for calculating advantage function
            - scores: Reward values in range [0,1] used for dynamic filtering
            - extra_logs: Additional information to be logged in wandb
              (``formats``: 1.0 if a boxed answer was found, else 0.0)
    """
    # Only the tail of each query is searched, for efficiency. The longest
    # answer in MATH-500 has 159 characters. A boxed answer that starts before
    # this window is not found and is scored as a format failure.
    tail_chars = 200

    rewards = []
    scores = []
    formats = []

    for i, (query, label) in enumerate(zip(queries, labels)):
        solution_str = query[-tail_chars:]

        solution = extract_boxed_answer(solution_str)
        if solution is None:
            # Format failure: no complete boxed answer in the tail.
            rewards.append(-1.0)
            scores.append(0.0)
            formats.append(0.0)
            continue

        formats.append(1.0)
        # Parse each side both raw and wrapped in $...$ and concatenate the
        # candidates. Presumably this covers answers that math_verify only
        # recognizes with LaTeX math delimiters.
        gold = parse(label) + parse(f"${label}$")
        answer = parse(solution) + parse(f"${solution}$")
        result = verify(gold, answer)
        if result:
            rewards.append(1.0)
            scores.append(1.0)
        else:
            rewards.append(-0.5)
            scores.append(0.0)

        # Debug output for the first sample only, and only when it was
        # well-formatted (format failures `continue` above).
        if i == 0:
            print(f"gold: {gold}, answer: {answer}, result: {result}")

    rewards = torch.tensor(rewards)
    scores = torch.tensor(scores)
    formats = torch.tensor(formats)

    return {
        "rewards": rewards,  # Rewards for advantage calculation
        "scores": scores,  # Scores for dynamic filtering (0-1 reward)
        "extra_logs": {"formats": formats},  # Additional logging info for wandb
    }