import asyncio
import json
import math
import re
from functools import partial, update_wrapper
from typing import Callable, Dict, Optional
import logging

# from oai_generator import OpenAIGenerator, OpenAIGeneratorMultiTurn, REASONING_CUSTOM_MODELS
import logging, re, json
import string


# from latex2sympy2_extended import NormalizationConfig
# from math_verify import LatexExtractionConfig, parse, verify

# from prompts import *

def format_reward(completions, **kwargs):
    """Score 1.0 for each completion that follows the <think>/<answer> layout, else 0.0.

    The expected layout is::

        <think>
        ...reasoning...
        </think>
        <answer>
        ...answer...
        </answer>

    Args:
        completions: one entry per sample; the text is read from ``completion[0]["content"]``.

    Returns:
        list[float]: 1.0 for completions matching the layout, 0.0 otherwise.

    NOTE(review): because the pattern is compiled with re.MULTILINE, ``$`` also matches
    right before any newline, so extra lines after the closing ``</answer>`` still match.
    """
    layout = re.compile(
        r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$",
        re.DOTALL | re.MULTILINE,
    )
    rewards = []
    for completion in completions:
        text = completion[0]["content"]
        rewards.append(1.0 if layout.match(text) is not None else 0.0)
    return rewards

def tag_count_reward(completions, **kwargs) -> list[float]:
    """Give partial credit (0.25 each) for every tag marker of `format_reward()` that appears exactly once.

    The four markers checked are ``"<think>\\n"``, ``"\\n</think>\\n"``, ``"\\n<answer>\\n"``
    and ``"\\n</answer>"``, so a completion scores between 0.0 and 1.0.

    Adapted from: https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb#file-grpo_demo-py-L90
    """
    markers = ("<think>\n", "\n</think>\n", "\n<answer>\n", "\n</answer>")
    texts = [completion[0]["content"] for completion in completions]
    # sum() of booleans counts the markers seen exactly once; 0.25 * int keeps the result a float.
    return [0.25 * sum(text.count(marker) == 1 for marker in markers) for text in texts]

def reasoning_steps_reward(completions, **kwargs):
    r"""Reward function that checks for clear step-by-step reasoning.

    Regex pattern:
        Step \d+: - matches "Step 1:", "Step 2:", etc.
        ^\d+\. - matches numbered lists like "1.", "2.", etc. at start of line
        \n- - matches bullet points with hyphens
        \n\* - matches bullet points with asterisks
        First,|Second,|Next,|Finally, - matches transition words

    The pattern is compiled with re.MULTILINE so ``^`` anchors at the start of every
    line, as the numbered-list rule intends. Without it, only a list item at the
    very beginning of the completion could ever match.

    Args:
        completions: one entry per sample; the text is read from ``completion[0]["content"]``.

    Returns:
        list[float]: ``min(1.0, matches / 3)`` per completion.
    """
    pattern = re.compile(
        r"(Step \d+:|^\d+\.|\n-|\n\*|First,|Second,|Next,|Finally,)",
        re.MULTILINE,
    )
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [len(pattern.findall(content)) for content in completion_contents]

    # Magic number 3 to encourage 3 steps and more, otherwise partial reward
    return [min(1.0, count / 3) for count in matches]

def len_reward(completions: list[Dict[str, str]], solution: list[str], **kwargs) -> list[float]:
    """Compute length-based rewards to discourage overthinking and promote token efficiency.

    Taken from the Kimi 1.5 tech report: https://arxiv.org/abs/2501.12599

    Args:
        completions: List of model completions
        solution: List of ground truth solutions

    Returns:
        List of rewards where:
        - For correct answers: reward = 0.5 - (len - min_len)/(max_len - min_len)
        - For incorrect answers: reward = min(0, 0.5 - (len - min_len)/(max_len - min_len))

    Raises:
        NotImplementedError: always. The function is still a stub; it needs a per-sample
            answer-correctness check before the formula above can be applied. Previously
            it silently returned None, which only failed later inside the trainer.
    """
    # TODO: implement this!!
    # Original source: https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py#L135
    raise NotImplementedError(
        "len_reward ('length') is not implemented yet: it needs a per-sample "
        "correctness check to apply the Kimi 1.5 length formula. Remove 'length' "
        "from reward_funcs or implement it (see open-r1 src/open_r1/rewards.py)."
    )
    
    
# def get_cosine_scaled_reward(
#     min_value_wrong: float = -1.0,
#     max_value_wrong: float = -0.5,
#     min_value_correct: float = 0.5,
#     max_value_correct: float = 1.0,
#     max_len: int = 1000,
# ):
#     def cosine_scaled_reward(completions, solution, **kwargs):
#         """Reward function that scales based on completion length using a cosine schedule.

#         Shorter correct solutions are rewarded more than longer ones.
#         Longer incorrect solutions are penalized less than shorter ones.

#         Args:
#             completions: List of model completions
#             solution: List of ground truth solutions

#         This function is parameterized by the following arguments:
#             min_value_wrong: Minimum reward for wrong answers
#             max_value_wrong: Maximum reward for wrong answers
#             min_value_correct: Minimum reward for correct answers
#             max_value_correct: Maximum reward for correct answers
#             max_len: Maximum length for scaling
#         """
#         contents = [completion[0]["content"] for completion in completions]
#         rewards = []

#         for content, sol in zip(contents, solution):
#             gold_parsed = parse(sol, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
#             if len(gold_parsed) == 0:
#                 rewards.append(1.0)  # Skip unparseable examples
#                 print("Failed to parse gold solution: ", sol)
#                 continue

#             answer_parsed = parse(
#                 content,
#                 extraction_config=[
#                     LatexExtractionConfig(
#                         normalization_config=NormalizationConfig(
#                             nits=False,
#                             malformed_operators=False,
#                             basic_latex=True,
#                             equations=True,
#                             boxed=True,
#                             units=True,
#                         ),
#                         boxed_match_priority=0,
#                         try_extract_without_anchor=False,
#                     )
#                 ],
#                 extraction_mode="first_match",
#             )

#             is_correct = verify(answer_parsed, gold_parsed)
#             gen_len = len(content)

#             # Apply cosine scaling based on length
#             progress = gen_len / max_len
#             cosine = math.cos(progress * math.pi)

#             if is_correct:
#                 min_value = min_value_correct
#                 max_value = max_value_correct
#             else:
#                 # Swap min/max for incorrect answers
#                 min_value = max_value_wrong
#                 max_value = min_value_wrong

#             reward = min_value + 0.5 * (max_value - min_value) * (1.0 + cosine)
#             rewards.append(float(reward))

#         return rewards

#     return cosine_scaled_reward

def get_repetition_penalty_reward(ngram_size: int, max_penalty: float):
    """
    Build an N-gram repetition penalty reward, as described in Appendix C.2 of https://arxiv.org/abs/2502.03373.
    Reference implementation from: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py

    Args:
        ngram_size: size of the (lower-cased, whitespace-split) word n-grams; must be >= 1.
        max_penalty: penalty (<= 0) reached when every n-gram is a repeat. Each
            completion gets ``max_penalty * (1 - unique_ngrams / total_ngrams)``.

    Returns:
        A reward function ``repetition_penalty_reward(completions, **kwargs) -> list[float]``.

    Raises:
        ValueError: if ``max_penalty`` is positive or ``ngram_size`` is smaller than 1
            (which would otherwise produce zero n-grams and a ZeroDivisionError at reward time).
    """
    if max_penalty > 0:
        raise ValueError(f"max_penalty {max_penalty} should not be positive")
    if ngram_size < 1:
        raise ValueError(f"ngram_size {ngram_size} should be at least 1")

    def zipngram(text: str, ngram_size: int):
        words = text.lower().split()
        return zip(*[words[i:] for i in range(ngram_size)])

    def repetition_penalty_reward(completions, **kwargs) -> list[float]:
        """
        Reward function that penalizes repeated n-grams.
        Ref implementation: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py

        Args:
            completions: List of model completions; text is read from ``completion[0]["content"]``.

        Returns:
            One reward per completion: 0.0 for empty or too-short text, otherwise a value in
            ``[max_penalty, 0]``.
        """
        contents = [completion[0]["content"] for completion in completions]
        rewards = []
        for completion in contents:
            # Too few words to form a single n-gram: nothing can repeat.
            if completion == "" or len(completion.split()) < ngram_size:
                rewards.append(0.0)
                continue

            ngrams = list(zipngram(completion, ngram_size))
            # Fraction of n-grams that duplicate an earlier one (0.0 when all are unique).
            scaling = 1 - len(set(ngrams)) / len(ngrams)
            rewards.append(scaling * max_penalty)
        return rewards

    return repetition_penalty_reward

# -------------------------------------------------
# GLOBAL HANDLE (initially None)
# Module-level judge generator: registered via set_judge_generator() and
# fetched via _require_judge(); stays None until something registers it.
_JUDGE = None

def set_judge_generator(generator):
    """Register ``generator`` as the module-wide judge client.

    Intended to be called once per process (e.g. from grpo.py) so reward functions
    can fetch the same client through ``_require_judge()``. Calling it again
    replaces whatever generator was registered before.
    """
    global _JUDGE  # rebind the module-level handle rather than creating a local
    _JUDGE = generator
    logging.info("Judge generator registered for rewards.")

def _require_judge():
    """Return the judge registered with ``set_judge_generator()``.

    Raises:
        RuntimeError: if no judge has been registered yet.
    """
    judge = _JUDGE
    if judge is not None:
        return judge
    raise RuntimeError("Judge generator has not been initialised – "
                       "call rewards.set_judge_generator() first")
# -------------------------------------------------

# def legal_move_reward_func(completions, **kwargs):
#     """
#     Rewards if the first token (after cleaning) from the final <answer> tag is one of the allowed moves.
#     If no legal moves are available (i.e. allowed_moves is empty), then reward if the answer matches the default:
#       "None" for NL mode or "<move_null>" for special mode.
#     Expects kwargs to contain 'allowed_moves' as a list of lists (one per sample) and 'representation_mode'.
#     """
#     logging.info("Computing legal move rewards.")
#     logging.info("  Allowed moves: %s", kwargs.get("allowed_moves", None))
#     logging.info("Completions: %s", completions)
#     rewards = []
#     allowed_moves_all = kwargs.get("allowed_moves", None)
#     rep_mode = kwargs.get("representation_mode", "nl")
#     default_nl = "None"
#     default_special = "<move_null>"
#     for idx, completion in enumerate(completions):
#         # Use the final answer block.
#         answer_matches = re.findall(r"<answer>(.*?)</answer>", completion, re.DOTALL)
#         if answer_matches:
#             raw_answer = answer_matches[-1].strip()
#             answer_clean = raw_answer.translate(str.maketrans("", "", string.punctuation)).strip()
#             tokens = answer_clean.split()
#             first_token = tokens[0] if tokens else ""
#             if allowed_moves_all and isinstance(allowed_moves_all, list) and len(allowed_moves_all) > idx:
#                 allowed = allowed_moves_all[idx]
#                 # If allowed is empty, use default token based on representation mode.
#                 if not allowed:
#                     allowed = [default_nl] if rep_mode == "nl" else [default_special]
#                 reward = 1.0 if first_token in allowed else 0.0
#             else:
#                 reward = 0.0
#         else:
#             reward = 0.0
#         rewards.append(reward)
#     return rewards

def _completion_text(completion) -> str:
    """Return the raw text of one completion, whichever accepted shape it arrives in."""
    if isinstance(completion, list) and completion and isinstance(completion[0], dict):
        content = completion[0].get("content", "")
    elif isinstance(completion, dict):
        content = completion.get("content", "")
    else:
        return str(completion)
    # content=None would make re.findall raise TypeError; treat it as empty text.
    return "" if content is None else str(content)


def legal_move_reward_func(completions, **kwargs):
    """
    COMMON IMPLEMENTATION FOR BOTH LEGAL AND BEST MOVE REWARDS!!! Just expects the 'allowed_moves' list to contain either the best or the legal move based on the experiment_mode

    Rewards 1.0 if the first whitespace-separated token of the LAST <answer>...</answer>
    block (tags matched case-insensitively) is one of the sample's allowed moves, else 0.0.
    The token is compared verbatim (no punctuation is stripped), so special tokens such
    as "<move_null>" can match.

    If a sample has no allowed moves (entry missing or empty), the answer must instead
    equal the default: "None" when representation_mode == "nl", "<move_null>" otherwise.

    Args:
        completions: one entry per sample: a chat-style list whose first element is a
            ``{"content": ...}`` dict, a single such dict, or any other object
            (converted with ``str()``).
        **kwargs:
            allowed_moves: per-sample allowed moves, normally a list of lists. A plain
                string entry is treated as a single move (not as a set of substrings).
            representation_mode: "nl" (default) or another value selecting the
                special-token default.

    Returns:
        list[float]: one reward (1.0 or 0.0) per completion.
    """
    logging.info("Computing legal move rewards.")
    logging.info("  Allowed moves: %s", kwargs.get("allowed_moves", None))

    completion_texts = [_completion_text(c) for c in completions]
    logging.info("Completions processed: %s", completion_texts)

    allowed_moves_all = kwargs.get("allowed_moves", None) or []
    rep_mode = kwargs.get("representation_mode", "nl")
    default_allowed = ["None"] if rep_mode == "nl" else ["<move_null>"]
    # The \s* on both sides of the group already trim the captured answer.
    answer_re = re.compile(r"<answer>\s*(.*?)\s*</answer>", re.DOTALL | re.IGNORECASE)

    rewards = []
    for idx, text in enumerate(completion_texts):
        answer_matches = answer_re.findall(text)
        logging.info("Answer matches: %s", answer_matches)
        if not answer_matches:
            rewards.append(0.0)
            continue

        tokens = answer_matches[-1].split()
        first_token = tokens[0] if tokens else ""

        allowed = allowed_moves_all[idx] if idx < len(allowed_moves_all) else []
        if not allowed:
            allowed = default_allowed
        elif isinstance(allowed, str):
            # `token in "e2e4"` would be a substring test ("e2" -> True); compare whole moves.
            allowed = [allowed]

        logging.info("First token: %s", first_token)
        logging.info("Allowed moves: %s", allowed)
        reward = 1.0 if first_token in allowed else 0.0
        logging.info("Reward with logic if first token in allowed moves: %s", reward)
        rewards.append(reward)
    logging.info("Rewards computed: %s", rewards)

    return rewards



def get_reward_funcs(script_args) -> list[Callable]:
    """Resolve the names in ``script_args.reward_funcs`` into reward callables.

    Each registry entry is a zero-argument factory that is invoked only for names
    actually requested. For example, ``script_args.repetition_n_grams`` and
    ``script_args.repetition_max_penalty`` are read (and validated) only when
    "repetition_penalty" is selected, instead of for every configuration.

    Args:
        script_args: object exposing ``reward_funcs`` (an iterable of registry names),
            plus ``repetition_n_grams`` / ``repetition_max_penalty`` when
            "repetition_penalty" is requested.

    Returns:
        list[Callable]: the reward functions, in the order they were requested.

    Raises:
        KeyError: if a requested name is not registered.
    """
    # name -> zero-argument factory returning the reward function.
    # ("cosine" stays disabled together with its commented-out implementation.)
    reward_factories = {
        "format": lambda: format_reward,
        "reasoning_steps": lambda: reasoning_steps_reward,
        "repetition_penalty": lambda: get_repetition_penalty_reward(
            ngram_size=script_args.repetition_n_grams,
            max_penalty=script_args.repetition_max_penalty,
        ),
        "length": lambda: len_reward,
        "tag_count": lambda: tag_count_reward,
        "legal_move": lambda: legal_move_reward_func,
    }

    reward_funcs = []
    for name in script_args.reward_funcs:
        try:
            factory = reward_factories[name]
        except KeyError:
            raise KeyError(
                f"Unknown reward function {name!r}; available: {sorted(reward_factories)}"
            ) from None
        reward_funcs.append(factory())
    return reward_funcs