# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# from . import gsm8k, math, prime_math, prime_code
import math
import re
from collections import Counter



def calculate_range(number, alpha=0.2):
    """Return the relative tolerance band ``number -/+ alpha * number``.

    Args:
        number: Centre of the band (callers in this file pass the target
            budget parsed from the completion).
        alpha: Relative half-width of the band; the default ``0.2`` means
            20 % on either side.

    Returns:
        A ``(number - number * alpha, number + number * alpha)`` tuple. The
        first element is not clamped at zero, and it is the larger of the
        two when ``number`` or ``alpha`` is negative.
    """
    # Only the relative band is returned; an absolute +/-20 variant used to
    # be computed alongside it but never influenced the result.
    offset = number * alpha
    return number - offset, number + offset

def cos_fn(t, T, eta_min, eta_max):
    """Cosine interpolation between ``eta_max`` and ``eta_min``.

    Evaluates ``eta_min + 0.5 * (eta_max - eta_min) * (1 + cos(pi * t / T))``,
    which equals ``eta_max`` at ``t == 0`` and ``eta_min`` at ``t == T``.

    Args:
        t: Current position along the schedule.
        T: Length of the schedule.
        eta_min: Value reached at ``t == T``.
        eta_max: Value at ``t == 0``.

    Returns:
        The interpolated value. When ``T`` is zero the schedule has no
        extent, so the value at its starting point is returned instead of
        raising ``ZeroDivisionError``.
    """
    # A zero-length schedule is evaluated at phase 0 (cos == 1), i.e. at its
    # starting point, rather than dividing by zero.
    phase = t * math.pi / T if T != 0 else 0.0
    return eta_min + 0.5 * (eta_max - eta_min) * (1 + math.cos(phase))

def compute_token_entropy(text):
    """Return the normalised Shannon entropy of whitespace-separated tokens.

    The entropy (base 2) of the token frequency distribution is divided by
    ``log2(number of tokens)``. Text without tokens, or with a single token,
    yields ``0.0``.

    Any exception raised while processing ``text`` (for example when it is
    not a string) is reported on stdout and ``0.0`` is returned instead.
    """
    try:
        words = text.strip().split()
        n_words = len(words)
        if n_words == 0:
            return 0.0

        # Relative frequency of each distinct token, in first-seen order.
        probabilities = (
            occurrences / n_words for occurrences in Counter(words).values()
        )
        raw_entropy = -sum(p * math.log2(p) for p in probabilities)

        # log2(1) == 0 for a single token, which would make the ratio undefined.
        ceiling = math.log2(n_words)
        if ceiling > 0:
            return raw_entropy / ceiling
        return 0.0
    except Exception as e:
        print(f"token entropy error: {str(e)}")
        return 0.0
    
def clip(value, min_val, max_val):
    """Clamp ``value`` to ``[min_val, max_val]``.

    The upper bound is applied first and the lower bound second, so when
    ``min_val`` exceeds ``max_val`` the result is ``min_val``.
    """
    capped = max_val if max_val < value else value
    return min_val if min_val > capped else capped


# Layout the budget-aware data sources expect from a completion: a
# ``<budget>...</budget>`` element immediately followed by a
# ``<solution>...</solution>`` element, together spanning the whole string.
_BUDGET_SOLUTION_RE = re.compile(
    r"^<budget>([^<]*(?:<(?!/?budget>)[^<]*)*)<\/budget><solution>([\s\S]*?)<\/solution>$",
    re.DOTALL,
)


def _split_budget_and_solution(solution_str):
    """Return the stripped ``(budget, solution)`` texts, or ``None`` on a layout mismatch."""
    match = _BUDGET_SOLUTION_RE.search(solution_str)
    if match is None:
        return None
    return match.group(1).strip(), match.group(2).strip()


def _parse_target_budget(budget_text):
    """Read the integer announced inside ``<budget>``; fall back to 0 if unreadable."""
    # Try the whole text first, then only its first space-separated word so
    # that e.g. "120 tokens" still yields 120.
    for candidate in (budget_text, budget_text.split(' ')[0]):
        try:
            return int(candidate)
        except ValueError:
            continue
    return 0


def _default_compute_score(data_source, solution_str, ground_truth, extra_info=None):
    """Score one model completion with the rule selected by ``data_source``.

    Args:
        data_source: Dataset identifier that selects the scoring rule.
        solution_str: Raw completion text to score.
        ground_truth: Reference answer handed to the selected scorer.
        extra_info: Per-sample data whose layout depends on ``data_source``:

            * ``'openai/gsm8k_cosine'``: a single number that is compared
              with the upper budget bound (presumably the actual response
              length; confirm with the caller).
            * ``'openai/gsm8k_l1'`` / ``'STILL-3-l1'``:
              ``(actual_len, para_ques, ori_budget, reward_type)`` where
              ``reward_type`` is ``"l1_max"`` or ``"l1_exact"``.
            * the ``*_corMatch`` sources:
              ``(actual_budget, para_ques, ori_budget, alpha)``.

            The remaining data sources do not read it.

    Returns:
        The score as a float. The ``gsm8k_cosine`` and ``*_corMatch`` sources
        return ``-1.0`` when the completion does not follow the
        ``<budget>..</budget><solution>..</solution>`` layout.

    Raises:
        NotImplementedError: If ``data_source`` is unknown, or if the
            ``reward_type`` of an l1 source is not recognised.
    """
    if data_source == 'openai/gsm8k':
        from . import gsm8k
        res = gsm8k.compute_score(solution_str, ground_truth)
    elif data_source in ['lighteval/MATH', 'DigitalLearningGmbH/MATH-lighteval']:
        # Binds the package-local scorer to a function-local name; the
        # stdlib ``math`` module is not used inside this function.
        from . import math
        res = math.compute_score(solution_str, ground_truth)
    elif data_source in [
            'numina_aops_forum', 'numina_synthetic_math', 'numina_amc_aime', 'numina_synthetic_amc', 'numina_cn_k12',
            'numina_olympiads'
    ]:
        from . import prime_math
        res = prime_math.compute_score(solution_str, ground_truth)
    elif data_source in ['codecontests', 'apps', 'codeforces', 'taco']:
        from . import prime_code
        res = prime_code.compute_score(solution_str, ground_truth, continuous=True)
    elif data_source in ['hiyouga/geometry3k']:
        from . import geo3k
        res = geo3k.compute_score(solution_str, ground_truth)
    elif data_source == 'openai/gsm8k_cosine':
        from . import gsm8k
        correctness = gsm8k.compute_score(solution_str, ground_truth)

        parsed = _split_budget_and_solution(solution_str)
        if parsed is None:
            # A completion that breaks the required layout gets the lowest reward.
            return -1.0
        budget_text, _ = parsed

        actual_budget = extra_info
        target_budget = _parse_target_budget(budget_text)
        _, L_max = calculate_range(target_budget)

        r_exceed = -0.95
        if actual_budget > L_max:
            # Overshooting the allowed band is penalised whether or not the
            # answer is correct.
            return r_exceed

        # As actual_budget goes from 0 to L_max, correct answers move from
        # 0.1 to 1.0 and wrong answers from -0.9 to -0.1.
        if correctness:
            eta_min, eta_max = 1.0, 0.1
        else:
            eta_min, eta_max = -0.1, -0.9
        return cos_fn(actual_budget, L_max, eta_min, eta_max)
    elif data_source in ['openai/gsm8k_l1', 'STILL-3-l1']:
        actual_len, para_ques, ori_budget, reward_type = extra_info
        from .util import extract_math_answer, is_equiv
        pred = extract_math_answer(para_ques, solution_str, "")
        correctness = is_equiv(ground_truth, pred)
        cor_score = 1.0 if correctness else 0.0

        if reward_type == "l1_max":
            # Correct answers are scaled by a length factor clipped to
            # [0, 1]; the factor is 0.5 when actual_len equals ori_budget.
            len_score = clip(0.0003 * (ori_budget - actual_len) + 0.5, 0.0, 1.0)
            return cor_score * len_score
        elif reward_type == "l1_exact":
            # Any deviation from ori_budget, in either direction, is subtracted.
            len_score = 0.0003 * abs(ori_budget - actual_len)
            return cor_score - len_score
        else:
            raise NotImplementedError
    elif data_source in ['openai/gsm8k_corMatch', 'DigitalLearningGmbH/MATH-lighteval_corMatch', 'STILL-3-Preview-RL-Data-HF_corMatch']:
        parsed = _split_budget_and_solution(solution_str)
        if parsed is None:
            # Retry once with the opening tag prepended, in case the
            # completion starts right after "<budget>".
            parsed = _split_budget_and_solution("<budget>" + solution_str)
            if parsed is None:
                return -1.0
        budget_text, pred_solution = parsed

        actual_budget, para_ques, ori_budget, alpha = extra_info
        from .util import extract_math_answer, is_equiv
        pred = extract_math_answer(para_ques, pred_solution, "")
        correctness = is_equiv(ground_truth, pred)

        target_budget = _parse_target_budget(budget_text)
        L_min, L_max = calculate_range(target_budget, alpha)

        exceed_penalty = -0.5
        # Announcing a budget larger than a (positive) original budget costs extra.
        if ori_budget > 0 and ori_budget < target_budget:
            budget_penalty = -0.4
        else:
            budget_penalty = 0.0

        outside_band = actual_budget > L_max or actual_budget < L_min
        if correctness:
            if outside_band:
                return 1.0 + exceed_penalty + budget_penalty
            # Inside the band: 1.0 at L_min, decaying to 1.0 + exceed_penalty at L_max.
            return cos_fn(actual_budget - L_min, L_max - L_min, 1.0 + exceed_penalty, 1.0) + budget_penalty
        if outside_band:
            return exceed_penalty + budget_penalty
        # Inside the band: exceed_penalty at L_min, rising to 0.0 at L_max.
        return cos_fn(L_max - actual_budget, L_max - L_min, exceed_penalty, 0.0) + budget_penalty
    else:
        raise NotImplementedError

    # Scorers may return a bare number or a sequence whose first item is the score.
    if isinstance(res, (int, float, bool)):
        return float(res)
    else:
        return float(res[0])
