# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Classes and functions related to reinforcement learning and verifiers.

"""
import os
import sys
if os.getcwd() not in sys.path:
    sys.path.append(os.getcwd())

import re
import math 
import numpy as np
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from safetensors.torch import save_file

from dataclasses import dataclass
from packaging import version
from typing import List, Optional, Union, Tuple, Any, Dict
import transformers
from transformers import PreTrainedTokenizerBase
from transformers.utils import ModelOutput
if version.parse(transformers.__version__) > version.parse("4.33.0"):
    from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
else:
    from transformers.deepspeed import is_deepspeed_zero3_enabled

# pip install trl
from trl import AutoModelForCausalLMWithValueHead, create_reference_model
# pip install math_verify
import math
from math_verify import parse as mv_parse, verify as mv_verify
import matplotlib.pyplot as plt



@dataclass
class CausalLMWithValueHeadOutputWithPast(ModelOutput):
    """
    Output container for a causal language model that also carries a value head.

    Besides the usual causal-LM fields it holds `values`, the value-head prediction
    that `MyAutoModelForCausalLMWithValueHead.forward` computes from the last hidden state.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        values (`torch.FloatTensor`, *optional*):
            Value-head predictions; `forward` produces one value per position, i.e. shape
            `(batch_size, sequence_length)`.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # NOTE: field order is part of the contract; `values` sits between `loss` and `logits`.
    loss: Optional[torch.FloatTensor] = None
    values: torch.FloatTensor = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class MyAutoModelForCausalLMWithValueHead(AutoModelForCausalLMWithValueHead):
    r"""
    An autoregressive model with a value head in addition to the language model head.
    This class inherits from `~trl.AutoModelForCausalLMWithValueHead` which wraps a
    `transformers.PreTrainedModel` class. The wrapper class supports classic functions
    such as `from_pretrained`, `push_to_hub` and `generate`. To call a method of the wrapped
    model, simply manipulate the `pretrained_model` attribute of this class.

    In addition to AutoModelForCausalLMWithValueHead, this class supports
    - training of outcome reward models (ORM), process reward models (PRM), outcome value function via SFT, given value_labels.
    - contrastive learning

    This class is similar to AutoModelForSequenceClassification, except that it supports both CausalLM and Scoring.

    Class attributes:
        - **transformers_parent_class** (`transformers.PreTrainedModel`) -- The parent class of the wrapped model. This
            should be set to `transformers.AutoModelForCausalLM` for this class.
        - **lm_head_namings** (`tuple`) -- A tuple of strings that are used to identify the language model head of the
            wrapped model. This is set to `("lm_head", "embed_out")` for this class but can be changed for other models
            in the future
        - **supported_args** (`tuple`) -- A tuple of strings that are used to identify the arguments that are supported
            by the `ValueHead` class. Currently, the supported args are:
            - **summary_dropout_prob** (`float`, `optional`, defaults to `None`) -- The dropout probability for the
                `ValueHead` class.
            - **v_head_initializer_range** (`float`, `optional`, defaults to `0.2`) -- The initializer range for the
                `ValueHead` if a specific initialization strategy is selected.
            - **v_head_init_strategy** (`str`, `optional`, defaults to `None`) -- The initialization strategy for the
                `ValueHead`. Currently, the supported strategies are:
                - **`None`** -- Initializes the weights of the `ValueHead` with a random distribution. This is the default
                    strategy.
                - **"normal"** -- Initializes the weights of the `ValueHead` with a normal distribution.

    """
    def __init__(self, pretrained_model, config, **kwargs):
        """Wrap `pretrained_model` and read the value-loss settings from `config`.

        Args:
            pretrained_model: Model handed to the parent wrapper's constructor.
            config: Object that may carry `value_loss_fct_type` (defaults to
                "bce") and `value_loss_weight` (defaults to 1.0).
            **kwargs: Forwarded unchanged to the parent constructor.

        Raises:
            NotImplementedError: If `value_loss_fct_type` is not one of
                "bce", "ce", "mse", "stepwise_cl" or "samplewise_cl".
        """
        super().__init__(pretrained_model, **kwargs)

        self.value_loss_fct_type = getattr(config, "value_loss_fct_type", "bce")
        # Raise explicitly instead of asserting: an `assert` disappears under
        # `python -O` and would surface as AssertionError rather than the
        # NotImplementedError this check is meant to report.
        supported_loss_types = {"bce", "ce", "mse", "stepwise_cl", "samplewise_cl"}
        if self.value_loss_fct_type not in supported_loss_types:
            raise NotImplementedError("Loss type {} is not implemented.".format(self.value_loss_fct_type))
        self.value_loss_weight = getattr(config, "value_loss_weight", 1.0)

    @staticmethod
    def _pad_value(value, target_length):
        cur_length = value.numel()
        if cur_length < target_length:
            return torch.cat([value,] + [value[-1:]] * (target_length - cur_length), dim=0)
        else:
            return value

    def _stepwise_cl(self, values, labels, masks=None):
        if masks is not None:
            values_per_sample = [value[mask] for value, mask in zip(values, masks)]
            labels_per_sample = [label[mask] for label, mask in zip(labels, masks)]

        # pad each sample to the same number of step
        max_num_steps = max([value.numel() for value in values_per_sample])
        values_per_sample = torch.cat(  # (N, S)
            [self._pad_value(value, max_num_steps).unsqueeze(0) for value in values_per_sample], dim=0).to(torch.float32)
        labels_per_sample = torch.cat(  # (N, S)
            [self._pad_value(label, max_num_steps).unsqueeze(0) for label in labels_per_sample], dim=0).to(torch.float32)

        # we adopt a listwise loss from paper:
        # Learning to Rank: From Pairwise Approach to Listwise Approach
        values_per_step = values_per_sample.T  # (S, N)
        labels_per_step = labels_per_sample.T  # (S, N)
        labels_per_step = labels_per_step / (labels_per_step.sum(dim=1, keepdim=True) + 1e-6)
        value_loss = torch.nn.CrossEntropyLoss()(values_per_step, labels_per_step)

        return value_loss

    def _samplewise_cl(self, values, labels, masks=None):
        if masks is not None:
            values_per_sample = [value[mask] for value, mask in zip(values, masks)]
            labels_per_sample = [label[mask] for label, mask in zip(labels, masks)]

        # we adopt a listwise loss from paper:
        # Learning to Rank: From Pairwise Approach to Listwise Approach
        # pool all steps
        all_values = torch.cat(values_per_sample, dim=0).to(torch.float32)  # (S * N,)
        all_labels = torch.cat(labels_per_sample, dim=0).to(torch.float32)  # (S * N,)
        all_labels = all_labels / (all_labels.sum() + 1e-6)
        value_loss = torch.nn.CrossEntropyLoss()(all_values, all_labels)
        # value_loss = torch.nn.BCEWithLogitsLoss()(all_values, all_labels)

        return value_loss

    def get_value_loss(self, values, value_labels, group_index=None):
        shift_values = values[..., :-1].contiguous()
        shift_value_labels = value_labels[..., 1:].contiguous()
        value_mask = shift_value_labels != -100

        if self.value_loss_fct_type == "bce":
            value_loss = BCEWithLogitsLoss()(
                shift_values[value_mask].float(), shift_value_labels[value_mask].float())
        elif self.value_loss_fct_type == "ce":
            value_loss = CrossEntropyLoss()(
                shift_values[value_mask].float(), shift_value_labels[value_mask].float())
        elif self.value_loss_fct_type == "mse":
            value_loss = MSELoss()(
                shift_values[value_mask].float(), shift_value_labels[value_mask].float())
        elif self.value_loss_fct_type in {"stepwise_cl", "samplewise_cl"}:
            if group_index is None:
                group_index = torch.randint(65536, (shift_values.shape[0],), dtype=torch.int32, device=shift_values.device)

            sample_masks = [group_index == index for index in group_index.unique()]
            if self.value_loss_fct_type == "stepwise_cl":
                value_losses = [self._stepwise_cl(
                    shift_values[mask], shift_value_labels[mask], value_mask[mask]).unsqueeze(0) for mask in sample_masks]
            else:
                value_losses = [self._samplewise_cl(
                    shift_values[mask], shift_value_labels[mask], value_mask[mask]).unsqueeze(0) for mask in sample_masks]

            value_loss = torch.cat(value_losses).mean()
        else:
            raise NotImplementedError("Loss type {} is not implemented.".format(self.value_loss_fct_type))

        return value_loss

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        value_labels: Optional[torch.LongTensor] = None,
        group_index: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMWithValueHeadOutputWithPast]:
        r"""
        Run the wrapped model, score every position with the value head and combine the losses.

        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
            value_labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the value prediction loss. Entries equal to `-100` are ignored.
            group_index (`torch.Tensor`, *optional*):
                Per-sample group id forwarded to `get_value_loss` (used by the contrastive losses).

        Returns:
            A `CausalLMWithValueHeadOutputWithPast` when `return_dict` is truthy, otherwise the tuple
            `(loss?, values, lm_logits, *remaining_base_outputs)` where `loss` is present only when it was computed.
            The loss is `token_loss + value_loss_weight * value_loss`, using whichever parts are available.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        kwargs["output_hidden_states"] = True  # This must be True for self.v_head to work

        model_inputs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            return_dict=return_dict,
            **kwargs,
        )

        if self.is_peft_model:
            active_config = self.pretrained_model.active_peft_config
            peft_type = active_config[0].peft_type if isinstance(active_config, list) else active_config.peft_type
            if peft_type == "PREFIX_TUNING":
                # `past_key_values` is a named parameter, so it never lives in
                # `kwargs`; popping it from there raised KeyError. Drop it from
                # the forwarded arguments instead so it is not passed along
                # under prefix tuning.
                model_inputs.pop("past_key_values", None)

        base_model_outputs = self.pretrained_model(**model_inputs)

        # get last-layer hidden_states and token_loss
        token_loss = None
        if return_dict:
            last_hidden_state = base_model_outputs.hidden_states[-1]
            lm_logits = base_model_outputs.logits
            token_loss = base_model_outputs.loss
        else:
            # NOTE(review): this assumes hidden states directly follow the logits in
            # the tuple output, which may not hold when the base model also returns
            # `past_key_values` (original TODO) — confirm before relying on tuple mode.
            if labels is not None:
                token_loss, lm_logits, hidden_state = base_model_outputs[:3]
            else:
                lm_logits, hidden_state = base_model_outputs[:2]
            last_hidden_state = hidden_state[-1]
        # force upcast in fp32 if logits are in half-precision
        if lm_logits.dtype != torch.float32:
            lm_logits = lm_logits.float()

        # get value prediction and value loss
        if last_hidden_state.device != self.v_head.summary.weight.device:
            last_hidden_state = last_hidden_state.to(self.v_head.summary.weight.device)
        values = self.v_head(last_hidden_state).squeeze(-1)  # (N, L)
        value_loss = None
        if value_labels is not None:
            value_loss = self.get_value_loss(values, value_labels, group_index)

        # get the final loss as a weighted sum of token_loss and value_loss
        if token_loss is None:
            loss = value_loss * self.value_loss_weight if value_loss is not None else None
        elif value_loss is None:
            loss = token_loss
        else:
            loss = token_loss + value_loss * self.value_loss_weight

        if not return_dict:
            # Strip the entries already unpacked above (logits, and the loss when present).
            base_model_outputs = base_model_outputs[1:] if token_loss is None else base_model_outputs[2:]
            outputs = (values, lm_logits) + base_model_outputs
            return (loss,) + outputs if loss is not None else outputs

        return CausalLMWithValueHeadOutputWithPast(
            loss=loss,
            values=values,
            logits=lm_logits,
            past_key_values=base_model_outputs.past_key_values,
            hidden_states=base_model_outputs.hidden_states,
            attentions=base_model_outputs.attentions,
        )

    def save_pretrained(self, *args, **kwargs):
        r"""
        Save the model to a directory by delegating to the wrapped model's `save_pretrained`.

        For a PEFT model the `v_head` weights are first written on their own to
        `pytorch_model.safetensors` (when `safe_serialization` is truthy) or
        `pytorch_model.bin` inside the target directory; the remaining weights,
        with the `pretrained_model.` prefix removed, are handed to the wrapped model.

        Args:
            *args (`list`, *optional*):
                Positional arguments passed along to the underlying model's
                `save_pretrained` method. The first one is the target directory.
            **kwargs (`dict`, *optional*):
                Keyword arguments passed along to the underlying model's
                `save_pretrained` method. `save_directory` is read from here
                when no positional argument is given.
        """
        state_dict = kwargs.get("state_dict")
        if state_dict is None:
            state_dict = self.state_dict()
            kwargs["state_dict"] = state_dict
        safe_serialization = kwargs.get("safe_serialization", False)

        # if it is a peft model only save the `v_head` state_dict
        if self.is_peft_model:
            # Accept the directory by position or by keyword.
            save_dir = args[0] if args else kwargs["save_directory"]
            # The v_head file is written before the wrapped model gets a chance
            # to create the directory, so make sure it exists.
            os.makedirs(save_dir, exist_ok=True)
            v_head_state = {key: value for key, value in state_dict.items() if "v_head" in key}
            if safe_serialization:
                save_file(v_head_state, os.path.join(save_dir, "pytorch_model.safetensors"))
            else:
                torch.save(v_head_state, os.path.join(save_dir, "pytorch_model.bin"))
            # remove v_head keys and values;
            # remove "pretrained_model" prefix in the rest keys, so that the state_dict
            # can be loaded by a MoPeftModel correctly.
            remaining_state = {
                key.replace("pretrained_model.", ""): value
                for key, value in kwargs["state_dict"].items()
                if "v_head" not in key
            }
            # An empty dict is replaced by None before delegating.
            kwargs["state_dict"] = remaining_state if remaining_state else None

        return self.pretrained_model.save_pretrained(*args, **kwargs)

    def value_head(self, last_hidden_state):
        if isinstance(last_hidden_state, (list, tuple)):
            last_hidden_state = torch.cat(last_hidden_state, dim=1)  # L x (N, 1, D) --> (N, L, D)
        if last_hidden_state.device != self.v_head.summary.weight.device:
            last_hidden_state = last_hidden_state.to(self.v_head.summary.weight.device)
        values = self.v_head(last_hidden_state).squeeze(-1)  # (N, L)

        return values

    def generate_and_value(self, *args, **kwargs):
        r"""
        Call `generate` on the wrapped model and compute a value for the generated steps.

        Depending on `self.generation_config.value_by_transition_scores`, the values are either
        the transition scores of the generated tokens (computed from the generation scores with
        `normalize_logits=True`) or the output of `self.value_head` on the last-layer hidden states.

        Args:
            *args: Forwarded to the wrapped model's `generate`; the first positional argument,
                when present, is taken as `input_ids` in the transition-score mode.
            **kwargs: Forwarded to the wrapped model's `generate` after `return_dict_in_generate`
                and `output_scores` / `output_hidden_states` have been overridden.

        Returns:
            The generate output with a `values` attribute attached when the caller (or the
            generation config) asked for `return_dict_in_generate`; otherwise the tuple
            `(outputs, values)`.
        """
        # Remember what the caller wanted, then force dict output: the code below reads attributes.
        original_return_dict_in_generate = kwargs.get("return_dict_in_generate", None) or self.generation_config.return_dict_in_generate
        kwargs["return_dict_in_generate"] = True
        
        # NOTE(review): `value_by_transition_scores` is read from the generation config without a
        # default — presumably set elsewhere in the project; confirm.
        if self.generation_config.value_by_transition_scores:
            original_output_scores = kwargs.get("output_scores", None) or self.generation_config.output_scores
            kwargs["output_scores"] = True
            outputs = self.pretrained_model.generate(*args, **kwargs)

            # get values from scores
            logits = outputs.scores  # L-length tuple of (N, V)
            input_ids = args[0] if len(args) > 0 else kwargs.get("input_ids", None)
            assert isinstance(input_ids, torch.Tensor), RuntimeError("Cannot detect input_ids.")
            # Only the tokens after the prompt are scored (the prompt length is input_ids.shape[1]).
            values = self.pretrained_model.compute_transition_scores(
                sequences=outputs.sequences[:, input_ids.shape[1]:], scores=logits, normalize_logits=True)
            
            # Do not hand back scores the caller did not request.
            if not original_output_scores:
                outputs.scores = None
        else:
            original_output_hidden_states = kwargs.get("output_hidden_states", None) or self.generation_config.output_hidden_states
            # NOTE(review): `[-1]` is passed when the caller did not ask for hidden states; this is
            # presumably understood by a customised `generate` as "last layer only" — confirm.
            kwargs["output_hidden_states"] = [-1] if not original_output_hidden_states else True  # get the last hidden_states only
            outputs = self.pretrained_model.generate(*args, **kwargs)

            # get values from last_hidden_state
            # In both branches the first generation step's entry is dropped via `[1:]`.
            if original_output_hidden_states:
                # Full per-layer hidden states were returned: keep the last layer of each step.
                last_hidden_state = [hidden_states_t[-1] for hidden_states_t in outputs.hidden_states][1:]
            else:
                last_hidden_state = outputs.hidden_states[1:]
                # The caller did not request hidden states, so they are cleared from the output.
                outputs.hidden_states = None
            values = self.value_head(last_hidden_state)

        if original_return_dict_in_generate:
            setattr(outputs, "values", values)
            return outputs
        else:
            return outputs, values

    def generate(self, *args, **kwargs):
        r"""
        Generate with the wrapped model.

        The extra keyword `get_values` (default `False`) is consumed here; when truthy the
        call is routed through `generate_and_value` so that values are returned as well.
        """
        if kwargs.pop("get_values", False):
            return self.generate_and_value(*args, **kwargs)
        return self.pretrained_model.generate(*args, **kwargs)

    def value(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Score the given sequences without computing any loss.

        The wrapped model is run with `labels=None`, `return_dict=True` and
        `output_hidden_states=True`. Depending on
        `self.generation_config.value_by_transition_scores`, the result is either
        the transition scores computed from the logits or the value-head output
        on the last hidden state.

        Args:
            input_ids, attention_mask, position_ids, past_key_values, inputs_embeds:
                Forwarded to the wrapped model.
            **kwargs: Extra model arguments. `labels`, `return_dict`,
                `output_attentions`, `output_hidden_states` and `use_cache` are
                discarded because this method fixes them itself.

        Returns:
            The value tensor.
        """
        # ignore certain keys if provided.
        for key in ("labels", "return_dict", "output_attentions", "output_hidden_states", "use_cache"):
            kwargs.pop(key, None)

        model_inputs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=None,
            use_cache=None,
            output_attentions=False,
            return_dict=True,
            output_hidden_states=True,
            **kwargs,
        )

        if self.is_peft_model:
            active_config = self.pretrained_model.active_peft_config
            peft_type = active_config[0].peft_type if isinstance(active_config, list) else active_config.peft_type
            if peft_type == "PREFIX_TUNING":
                # `past_key_values` is a named parameter, so it never lives in
                # `kwargs`; popping it from there raised KeyError. Drop it from
                # the forwarded arguments instead.
                model_inputs.pop("past_key_values", None)

        base_model_outputs = self.pretrained_model(**model_inputs)

        if self.generation_config.value_by_transition_scores:
            logits = torch.unbind(base_model_outputs.logits, dim=1)  # (N, L, V) to L-length tuple of (N, V)
            # NOTE(review): the logits at position t are paired with the token at
            # position t here; if next-token scores are intended the sequences may
            # need shifting by one — confirm.
            values = self.pretrained_model.compute_transition_scores(
                sequences=input_ids, scores=logits, normalize_logits=True)
        else:
            last_hidden_state = base_model_outputs.hidden_states[-1]
            values = self.value_head(last_hidden_state)

        return values


def load_safetensors(file_path, device="cpu"):
    """Read every tensor stored in a `.safetensors` file into a dict.

    Args:
        file_path: Path to an existing file ending in ".safetensors".
        device: Device passed to `safe_open` for the loaded tensors.

    Returns:
        Mapping from tensor name to tensor.

    Raises:
        FileNotFoundError: If `file_path` is not an existing `.safetensors` file.
    """
    from safetensors import safe_open

    is_safetensors_file = os.path.isfile(file_path) and file_path.endswith(".safetensors")
    if not is_safetensors_file:
        raise FileNotFoundError("File path is not a valid safetensor path: {}".format(file_path))

    with safe_open(file_path, framework="pt", device=device) as handle:
        return {name: handle.get_tensor(name) for name in handle.keys()}


def load_rl_weights_into_model(model, ckpt_folder):
    """Load the weights saved in `ckpt_folder` and pass them to `model.post_init`.

    `pytorch_model.safetensors` is preferred; `pytorch_model.bin` is the
    fallback. Both are loaded onto the CPU. When neither file exists,
    `model.post_init` receives an empty dict and the returned path is `None`.

    Args:
        model: Object exposing `post_init(weights)`.
        ckpt_folder: Directory that may contain one of the two checkpoint files.

    Returns:
        Tuple `(model, valid_ckpt_path)` where `valid_ckpt_path` is the file
        that was loaded, or `None` if there was none.

    Raises:
        RuntimeError: If a checkpoint file exists but holds no weights.
    """
    safetensors_path = os.path.join(ckpt_folder, "pytorch_model.safetensors")
    bin_path = os.path.join(ckpt_folder, "pytorch_model.bin")

    weights = {}
    valid_ckpt_path = None
    if os.path.isfile(safetensors_path):
        weights = load_safetensors(safetensors_path, device="cpu")
        valid_ckpt_path = safetensors_path
    elif os.path.isfile(bin_path):
        # map_location keeps this consistent with the safetensors branch and lets
        # checkpoints saved from a GPU be read on a CPU-only host.
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        weights = torch.load(bin_path, map_location="cpu")
        valid_ckpt_path = bin_path

    # Raise explicitly: an `assert` would vanish under `python -O`.
    if valid_ckpt_path is not None and not weights:
        raise RuntimeError("CKPT File is empty: {}".format(valid_ckpt_path))

    model.post_init(weights)

    return model, valid_ckpt_path


def prepare_deepspeed_ref_model(model, accelerator=None):
    """Build a frozen reference model, sharded by DeepSpeed when ZeRO-3 is enabled.

    Without ZeRO-3 the reference model comes from `create_reference_model`.
    With ZeRO-3 the model is wrapped by `deepspeed.initialize` using a copy of
    the accelerator's DeepSpeed config and switched to eval mode.

    Args:
        model: The model to derive the reference model from.
        accelerator: Object exposing `state.deepspeed_plugin`. Only needed on
            the ZeRO-3 path; when omitted, a module-level `accelerator` is
            used if one exists.

    Returns:
        The reference model.

    Raises:
        RuntimeError: If ZeRO-3 is enabled and no accelerator is available.
    """
    # The helper must be *called*; testing the function object itself is always
    # true and made the non-DeepSpeed branch unreachable.
    if not is_deepspeed_zero3_enabled():
        return create_reference_model(model)

    # Adopted from: https://github.com/huggingface/trl/blob/02f5c1d8cee73045c837d01d7f1577a57779b035/trl/trainer/ppo_trainer.py#L1399
    import copy
    import deepspeed

    if accelerator is None:
        accelerator = globals().get("accelerator")
    if accelerator is None:
        raise RuntimeError(
            "prepare_deepspeed_ref_model needs an `accelerator` to read the DeepSpeed config from.")

    # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
    deepspeed_plugin = accelerator.state.deepspeed_plugin
    # Work on a copy so the edits below do not leak into the plugin's own config.
    config_kwargs = copy.deepcopy(deepspeed_plugin.deepspeed_config)
    if hasattr(model, "config"):
        hidden_size = (
            max(model.config.hidden_sizes)
            if getattr(model.config, "hidden_sizes", None)
            else getattr(model.config, "hidden_size", None)
        )
        if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
            # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
            # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
            config_kwargs.update(
                {
                    "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
                    "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
                    "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
                }
            )

    # TODO check this
    # If ZeRO-3 is used, we shard both the active and reference model.
    # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
    if config_kwargs["zero_optimization"]["stage"] != 3:
        config_kwargs["zero_optimization"]["stage"] = 0
    ref_model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
    ref_model.eval()

    return ref_model


def _maybe_unwrap_chat(chat):
    """Strip chat-message structure and keep only the text content.

    Args:
        chat (str, List[str], List[dict], List[List[dict]]): conversational format, e.g.,
            [[{'role': ..., 'content': ...}], ...]

    Returns:
        A string is returned as is and a message dict is replaced by its
        'content'. A non-empty list is unwrapped element by element; when every
        unwrapped element is a one-item list holding a string, those inner
        lists are flattened to the strings themselves.

    Raises:
        RuntimeError: For anything else (including an empty list).
    """
    if isinstance(chat, str):
        return chat

    if isinstance(chat, dict) and 'content' in chat:
        return chat['content']

    if isinstance(chat, list) and len(chat) > 0:
        unwrapped = [_maybe_unwrap_chat(item) for item in chat]

        def _is_lone_text(item):
            return isinstance(item, list) and len(item) == 1 and isinstance(item[0], str)

        if all(_is_lone_text(item) for item in unwrapped):
            unwrapped = [item[0] for item in unwrapped]
        return unwrapped

    raise RuntimeError(f"Unexpect input: {chat}")


def extract_boxed_answer(pred_str):
    """Return the content of the last `boxed{...}` group in `pred_str`.

    Nested braces are tracked so the matching closing brace ends the answer;
    if it never appears, everything after the opening brace is returned.
    'ERROR' is returned when there is no `boxed` marker or it is not
    immediately followed by '{'.
    """
    _, marker, tail = pred_str.rpartition('boxed')  # look at the last occurrence only
    if not marker or not tail.startswith('{'):
        return 'ERROR'

    depth = 1
    collected = []
    for ch in tail[1:]:
        if ch == '{':
            depth += 1
        elif ch == '}':
            depth -= 1
            if depth == 0:
                break
        collected.append(ch)

    return ''.join(collected)
    
    
def math_parse_solution(solution):
    """Given a math solution with CoT, discover the final answer.

    We assume the final answer is wrapped by \\boxed{...}, or <answer>...</answer>, or <output>...</output>,
    or given as a trailing `#### <integer>`. The markers are tried in that order and only the first marker
    found in the text is used. Answers that are not plain numbers are wrapped in `$...$`.

    Args:
        solution (str or List[str]): the solution text, or a list of them.

    Returns:
        str or List[str]: the extracted answer ("" when none is found); a list input yields a list.

    Raises:
        RuntimeError: if `solution` is neither a string nor a list.
    """
    if isinstance(solution, list):
        return [math_parse_solution(sol) for sol in solution]
    if not isinstance(solution, str):
        raise RuntimeError(f"Cannot parse solution of type {type(solution)}.")

    ans = ""
    if "\\boxed" in solution:
        ans = extract_boxed_answer(solution).strip()
        if ans == "ERROR":
            ans = ""
    elif "<answer>" in solution:
        # `[\s\S]` lets the answer span several lines (plain `.` silently dropped
        # multi-line answers); the look-ahead keeps a match from starting at an
        # earlier, unclosed opening tag.
        matches = re.findall(r"<answer>((?:(?!<answer>)[\s\S])*?)</answer>", solution)
        if matches:
            ans = matches[-1].strip()  # find the last match
    elif "<output>" in solution:
        matches = re.findall(r"<output>((?:(?!<output>)[\s\S])*?)</output>", solution)
        if matches:
            ans = matches[-1].strip()  # find the last match
    elif "#### " in solution:  # solution ends like #### 23
        match = re.search(r"#### ([0-9]+)$", solution)
        if match:
            ans = match.group(1).strip()

    if ans and (not re.match(r"^-?[0-9\.]+$", ans)):  # if not simple numbers, add latex sign
        if not (ans.startswith("$") and ans.endswith("$")):
            ans = f"${ans}$"
    return ans



"""

Length-Control Methods Recipes

- AdapThink: math_accuracy_reward + adaptive_reasoning_control_reward

- GRPO: math_accuracy_reward

- LCPO: LCPO_max_reward

- TLB: TLB_reward

- CosFn: length_scaled_math_accuracy_reward

"""
def math_accuracy_reward(completion, label, **kwargs):
    """Score completions against ground-truth answers with math-verify.

    Each completion/label pair earns 1.0 when math-verify judges the parsed
    answers equivalent and 0.0 otherwise, including when parsing or
    verification fails for any reason.

    Args:
        completion (List[str], List[List[dict]]): a string answer with the final result given in \\boxed{...}
            or <answer>...</answer>, or <output>...</output>, or #### .
        label (List[str], List[List[dict]]): similar to completion, but with the ground truth.

    Returns:
        List[float]: one reward per completion/label pair.
    """
    # drop the chat structure first, then isolate the final answers
    predictions = _maybe_unwrap_chat(completion)
    references = _maybe_unwrap_chat(label)
    predictions = math_parse_solution(predictions)
    references = math_parse_solution(references)

    if isinstance(predictions, str):
        predictions = [predictions]
    if isinstance(references, str):
        references = [references]

    def _score(prediction, reference):
        try:
            parsed_prediction = mv_parse(prediction)
            parsed_reference = mv_parse(reference)
            return float(mv_verify(parsed_prediction, parsed_reference))
        except Exception:  # if it fails for any reason, return 0.0
            return 0.0

    return [_score(prediction, reference) for prediction, reference in zip(predictions, references)]

def length_scaled_math_accuracy_reward(
    completion, 
    label, 
    min_value_wrong: float = -1.0,
    max_value_wrong: float = -0.5,
    min_value_correct: float = 0.5,
    max_value_correct: float = 1.0,
    max_seq_length: int = 4096,
    tokenizer: PreTrainedTokenizerBase = None,
    **kwargs):
    """Math-verify accuracy reward scaled by a cosine function of the completion length.

    Correctness is decided by `math_accuracy_reward`. The length is the number of tokens
    when a tokenizer is given, otherwise the number of characters. The reward moves along
    a cosine curve between the two bounds of the matching (correct / wrong) range as the
    length grows relative to `max_seq_length`.

    This metric is from paper:
    Demystifying Long Chain-of-Thought Reasoning in LLMs, https://arxiv.org/abs/2502.03373.

    Args:
        completion (List[str], List[List[dict]]): a string answer with the final result given in \\boxed{...}
            or <answer>...</answer>, or <output>...</output>, or #### .
        label (List[str], List[List[dict]]): similar to completion, but with the ground truth.

    Returns:
        List[float]: one reward per completion.
    """
    # remove dict structure
    texts = _maybe_unwrap_chat(completion)
    references = _maybe_unwrap_chat(label)
    if isinstance(texts, str):
        texts = [texts]
    if isinstance(references, str):
        references = [references]

    correctness = math_accuracy_reward(texts, references, **kwargs)

    if tokenizer:
        lengths = [len(tokenizer.encode(text, add_special_tokens=False)) for text in texts]
    else:
        lengths = [len(text) for text in texts]

    rewards = []
    for length, correct in zip(lengths, correctness):
        if correct:
            start_value, end_value = min_value_correct, max_value_correct
        else:
            # Swap min/max for incorrect answers
            start_value, end_value = max_value_wrong, min_value_wrong

        cosine = math.cos(length / max_seq_length * math.pi)
        rewards.append(float(start_value + 0.5 * (end_value - start_value) * (1.0 + cosine)))

    return rewards

def think_ans_format_reward(completion, **kwargs):
    """Score whether each completion follows the ``<think>...</think><answer>...</answer>`` layout.

    Args:
        completion: a single completion or a list of completions. The value is
            first passed through ``_maybe_unwrap_chat`` (presumably to strip
            chat-message structure -- confirm against that helper).
        **kwargs: accepted for interface compatibility; not used here.

    Returns:
        List[float]: 1.0 for every completion that starts with a ``<think>``
        block and ends with an ``<answer>`` block (only newlines may separate
        the two), 0.0 otherwise.
    """
    texts = _maybe_unwrap_chat(completion)
    if isinstance(texts, str):
        texts = [texts]

    # Anchored at both ends: nothing may precede <think> or follow </answer>.
    format_regex = r"^<think>[\s\S]*?<\/think>\n*<answer>[\s\S]*?<\/answer>$"

    scores = []
    for text in texts:
        scores.append(1.0 if re.match(format_regex, text) else 0.0)
    return scores

def get_weight(p, p_low=0.3, p_high=0.7):
    """Map ``p`` to a weight that falls from +1 to -1 along a half cosine.

    Args:
        p: the value to map.
        p_low: at or below this value the weight is 1.
        p_high: at or above this value the weight is -1.

    Returns:
        ``1`` when ``p <= p_low``, ``-1`` when ``p >= p_high``, and otherwise
        ``cos(pi * (p - p_low) / (p_high - p_low))``, which moves smoothly
        from +1 to -1 as ``p`` goes from ``p_low`` to ``p_high``.
    """
    if p <= p_low:
        return 1
    if p >= p_high:
        return -1

    # Strictly inside (p_low, p_high): position in the interval, scaled to (0, pi).
    phase = (p - p_low) / (p_high - p_low) * np.pi
    return np.cos(phase)

def adaptive_reasoning_control_reward(
    completion,
    label,
    tokenizer: PreTrainedTokenizerBase = None,
    p_threshold: float = 1,
    p_low: float = 0.25,
    **kwargs):
    """
    Adaptive reasoning control reward combining relative length and "switch" word counts.

    All completions passed in are treated as one group. The group's mean
    accuracy selects a weight via ``get_weight(mean_accuracy, p_low, p_threshold)``:

    - mean accuracy <= p_low: weight is +1, so completions that are longer /
      use more switch words than their peers get a positive reward.
    - mean accuracy >= p_threshold: weight is -1, so those completions are
      penalised and shorter ones are rewarded.
    - in between: the weight follows a cosine transition from +1 to -1.

    "Peers" means completions with the same correctness: each measurement is
    compared with the mean of the correct (or of the wrong) completions, as
    implemented by ``calculate_relative_rewards``.

    Args:
        completion: completions of one group (str, list of str, or chat format).
        label: ground-truth labels, forwarded to ``math_accuracy_reward``.
        tokenizer: if given, length is measured in tokens; otherwise in characters.
        p_threshold: group accuracy at or above which the weight is -1.
        p_low: group accuracy at or below which the weight is +1.
        **kwargs: forwarded to ``math_accuracy_reward``.

    Returns:
        List[float]: one reward per completion, clipped to [-1, 1]. An empty
        list is returned when there are no accuracy scores.
    """
    # Accuracy drives both the weight and the correct/wrong split below.
    accuracy_rewards = math_accuracy_reward(completion, label, **kwargs)
    if len(accuracy_rewards) == 0:
        # Nothing to score; avoids dividing by zero for the group mean.
        return []
    group_mean_accuracy = sum(accuracy_rewards) / len(accuracy_rewards)

    completion = _maybe_unwrap_chat(completion)
    if isinstance(completion, str):
        completion = [completion]

    # Positive weight when the group accuracy is low, negative when it is high.
    weight = get_weight(group_mean_accuracy, p_low, p_threshold)

    def _length(text):
        if tokenizer:
            return len(tokenizer.encode(text, add_special_tokens=False))
        return len(text)

    length_rewards = calculate_relative_rewards(
        completion, accuracy_rewards, measure_func=_length
    )

    # Words that signal switching to a different line of reasoning.
    switch_pattern = re.compile(r'\b(' + '|'.join(map(re.escape, [
        "alternatively", 'another', 'instead', 'however'
    ])) + r')\b', re.IGNORECASE)
    switch_rewards = calculate_relative_rewards(
        completion, accuracy_rewards,
        measure_func=lambda x: len(switch_pattern.findall(x))
    )

    # Combine the two signals (ablation setting: switch words + length) and clip.
    return [
        float(max(min(weight * (switch_r + length_r), 1), -1))
        for switch_r, length_r in zip(switch_rewards, length_rewards)
    ]

def calculate_relative_rewards(completions, accuracy_rewards, measure_func):
    """Score each completion by how far its measurement deviates from its peers' mean.

    Completions are split by correctness (``accuracy > 0`` is correct,
    ``accuracy <= 0`` is wrong) and each one is compared with the mean
    measurement of its own side.

    Args:
        completions: items to measure.
        accuracy_rewards: one accuracy score per completion.
        measure_func: callable returning a number for a completion.

    Returns:
        list: ``measurement / side_mean - 1`` per completion, or ``0`` when the
        mean of that side is zero (including when the side is empty).
    """
    values = [measure_func(item) for item in completions]

    def _side_mean(belongs):
        # Mean measurement over the completions whose accuracy satisfies `belongs`.
        picked = [values[idx] for idx, acc in enumerate(accuracy_rewards) if belongs(acc)]
        return sum(picked) / len(picked) if picked else 0

    baseline_correct = _side_mean(lambda acc: acc > 0)
    baseline_wrong = _side_mean(lambda acc: acc <= 0)

    relative = []
    for idx, value in enumerate(values):
        baseline = baseline_correct if accuracy_rewards[idx] > 0 else baseline_wrong
        # A zero baseline gives no meaningful ratio; report a neutral 0.
        relative.append(float(value / baseline - 1) if baseline != 0 else 0)
    return relative

# Baselines

def LCPO_max_reward(completion, label, target_length, tokenizer, alpha=0.0003, delta=0.5, **kwargs):
    """L1-Max style reward from Length Controlled Policy Optimization.

    A correct completion earns ``clip(alpha * (target - length) + delta, 0, 1)``,
    where ``length`` is its token count; a wrong completion earns 0.

    Args:
        completion (List[str]): Model generated completions
        label (List[str]): Ground truth labels
        target_length (int, float or List[int]): Target token length; a scalar
            is applied to every completion
        tokenizer: Tokenizer used to count tokens
        alpha (float): Slope of the length term (default: 0.0003)
        delta (float): Offset added before clipping (default: 0.5)
        **kwargs: forwarded to ``math_accuracy_reward``

    Returns:
        List[float]: Rewards for each completion
    """
    completion = _maybe_unwrap_chat(completion)
    label = _maybe_unwrap_chat(label)
    if isinstance(completion, str):
        completion = [completion]
    if isinstance(label, str):
        label = [label]
    # Broadcast a scalar target to every completion.
    if isinstance(target_length, (int, float)):
        target_length = [target_length] * len(completion)

    correctness = math_accuracy_reward(completion, label, **kwargs)
    token_counts = [
        len(tokenizer.encode(text, add_special_tokens=False)) for text in completion
    ]

    def _score(flag, n_tokens, budget):
        # Wrong answers never earn anything, whatever their length.
        if not flag:
            return 0.0
        clipped = max(0, min(1, alpha * (budget - n_tokens) + delta))
        return float(flag) * clipped

    return [
        _score(flag, n_tokens, budget)
        for flag, n_tokens, budget in zip(correctness, token_counts, target_length)
    ]

def TLB_reward(completion, label, tokenizer, L_max, n_generations=8, **kwargs):
    """Reward function that implements Token Length Budget (TLB) based calibration.

    Completions are processed in consecutive groups of ``n_generations``. For
    each group a length budget ``L_budget = p * L_pi + (1 - p) * L_max`` is
    computed, where ``p`` is the summed correctness divided by
    ``n_generations`` and ``L_pi`` is the mean token length of the correct
    completions (``L_max`` when none is correct). With
    ``lambda = (L_i - L_budget) / L_budget``:

    - correct completion: ``max(-0.5 * lambda + 0.5, 0.1)``
    - wrong completion:   ``min(0.9 * lambda - 0.1, -0.1)``

    Args:
        completion (List[str]): Model generated completions
        label (List[str]): Ground truth labels
        tokenizer: Tokenizer to count tokens
        L_max (int): Maximum generation length
        n_generations (int): Number of generations per question (default: 8)
        **kwargs: accepted for interface compatibility; not used here

    Returns:
        List[float]: Calibrated rewards, one per completion that belongs to a
        complete group. Trailing completions that do not fill a whole group of
        ``n_generations`` are not scored.
    """
    # Ensure inputs are lists
    completion = _maybe_unwrap_chat(completion)
    label = _maybe_unwrap_chat(label)
    if isinstance(completion, str):
        completion = [completion]
    if isinstance(label, str):
        label = [label]

    # Get correctness for each completion
    correctness = math_accuracy_reward(completion, label)

    n_groups = len(completion) // n_generations
    rewards = []

    for group_idx in range(n_groups):
        start_idx = group_idx * n_generations
        end_idx = start_idx + n_generations

        group_correctness = correctness[start_idx:end_idx]
        # Tokenise each completion once; the lengths feed both the budget and the reward.
        group_lengths = [
            len(tokenizer.encode(comp, add_special_tokens=False))
            for comp in completion[start_idx:end_idx]
        ]

        # p: sampling accuracy of the group
        p = sum(group_correctness) / n_generations

        # L_pi: average length of the correct responses (L_max if there are none)
        correct_lengths = [
            length for length, is_correct in zip(group_lengths, group_correctness)
            if is_correct
        ]
        L_pi = sum(correct_lengths) / len(correct_lengths) if correct_lengths else L_max

        L_budget = p * L_pi + (1 - p) * L_max

        for L_i, is_correct in zip(group_lengths, group_correctness):
            # Relative overshoot (positive) or undershoot (negative) of the budget
            lambda_i = (L_i - L_budget) / L_budget

            if is_correct:
                reward = max(-0.5 * lambda_i + 0.5, 0.1)
            else:
                reward = min(0.9 * lambda_i - 0.1, -0.1)

            rewards.append(float(reward))

    return rewards

def ngram_repetition_penalty(completion, n_gram_size=20, penalty_value=-0.1, **kwargs):
    """Penalise completions that repeat word n-grams.

    Each completion is split on whitespace. Every word covered by an n-gram
    that has already occurred earlier in the same completion is assigned
    ``penalty_value``; all other words are assigned 0. The completion's score
    is the mean over its words.

    Args:
        completion (List[str]): Model generated completions
        n_gram_size (int): Number of words per n-gram
        penalty_value (float): Value assigned to each word inside a repeated n-gram
        **kwargs: accepted for interface compatibility; not used here

    Returns:
        List[float]: Mean per-word penalty for each completion (0.0 for a
        completion without words)
    """
    texts = _maybe_unwrap_chat(completion)
    if isinstance(texts, str):
        texts = [texts]

    scores = []
    for text in texts:
        words = text.split()
        n_words = len(words)
        per_word = [0.0] * n_words
        seen = set()

        for start in range(n_words - n_gram_size + 1):
            stop = start + n_gram_size
            window = tuple(words[start:stop])
            if window in seen:
                # Repeated n-gram: mark every word it covers.
                for pos in range(start, stop):
                    per_word[pos] = penalty_value
            else:
                seen.add(window)

        mean_penalty = sum(per_word) / n_words if per_word else 0.0
        scores.append(float(mean_penalty))

    return scores

def cosine_reward(completion, label, tokenizer, L_max,
                  r0_correct=1.0, r0_wrong=0.0,
                  rL_correct=0.5, rL_wrong=-0.5,
                  re=-1.0, **kwargs):
    """Cosine reward function based on completion correctness and length.

    For a completion of ``L_gen`` tokens with ``L_gen < L_max`` the reward is
    interpolated along a half cosine between the reward at length 0 and the
    reward at length ``L_max``:

        r_L + 0.5 * (r_0 - r_L) * (1 + cos(pi * L_gen / L_max))

    using the ``*_correct`` pair for correct answers and the ``*_wrong`` pair
    otherwise. Completions with ``L_gen >= L_max`` receive ``re``.

    Args:
        completion (List[str]): Model generated completions
        label (List[str]): Ground truth labels
        tokenizer: Tokenizer to count tokens
        L_max (int): Maximum allowed length
        r0_correct (float): Reward for correct answer at L=0
        r0_wrong (float): Reward for wrong answer at L=0
        rL_correct (float): Reward for correct answer at L=L_max
        rL_wrong (float): Reward for wrong answer at L=L_max
        re (float): Penalty for reaching or exceeding L_max. Note that this
            parameter shadows the ``re`` module inside this function; the
            module is not used here.
        **kwargs: accepted for interface compatibility; not used here

    Returns:
        List[float]: One reward per completion
    """
    def _cosine_interp(length, r_at_zero, r_at_max):
        # cos term is +1 at length 0 (-> r_at_zero) and -1 at L_max (-> r_at_max).
        return r_at_max + 0.5 * (r_at_zero - r_at_max) * (1 + math.cos(length * math.pi / L_max))

    completion = _maybe_unwrap_chat(completion)
    label = _maybe_unwrap_chat(label)
    if isinstance(completion, str):
        completion = [completion]
    if isinstance(label, str):
        label = [label]

    # Get correctness for each completion
    correctness = math_accuracy_reward(completion, label)

    rewards = []
    for comp, is_correct in zip(completion, correctness):
        L_gen = len(tokenizer.encode(comp, add_special_tokens=False))

        # Length limit reached or exceeded: fixed penalty regardless of correctness.
        # This also keeps L_max == 0 away from the division in the interpolation.
        if L_gen >= L_max:
            rewards.append(float(re))
            continue

        if is_correct:
            reward = _cosine_interp(L_gen, r0_correct, rL_correct)
        else:
            reward = _cosine_interp(L_gen, r0_wrong, rL_wrong)

        rewards.append(float(reward))

    return rewards

def group_average_switch_reward(
    completion,
    label,
    **kwargs):
    """
    Report the group's mean number of "switch" words for every completion.

    Switch words are "alternatively", "another", "however" and "instead",
    matched as whole words, case-insensitively.

    Args:
        completion: List of completion strings or single string
        label: List of label strings or single string (not used for the statistic)

    Returns:
        List[float]: the same group-level mean repeated once per completion
    """
    completion = _maybe_unwrap_chat(completion)
    label = _maybe_unwrap_chat(label)

    if isinstance(completion, str):
        completion = [completion]
    if isinstance(label, str):
        label = [label]

    switch_words = ["alternatively", 'another', 'however', 'instead']
    switch_regex = re.compile(
        r'\b(' + '|'.join(re.escape(word) for word in switch_words) + r')\b',
        re.IGNORECASE,
    )

    per_completion = [len(switch_regex.findall(text)) for text in completion]
    mean_switches = float(sum(per_completion) / len(per_completion))

    return [mean_switches] * len(completion)

def group_average_sequential_reward(
    completion,
    label,
    **kwargs):
    """
    Calculate the group average count of sequential-reasoning markers.

    Markers are "first", "then", "next", "finally", "therefore", "so," and
    "thus", matched case-insensitively. Plain words must match as whole
    words; "so," must start at a word boundary and include the comma.

    Args:
        completion: List of completion strings or single string
        label: List of label strings or single string (not used for the statistic)

    Returns:
        List[float]: the same group-level mean repeated once per completion
        (an empty list when there are no completions)
    """
    # Unwrap chat format if needed
    completion = _maybe_unwrap_chat(completion)
    label = _maybe_unwrap_chat(label)

    # Convert to list if string
    if isinstance(completion, str):
        completion = [completion]
    if isinstance(label, str):
        label = [label]

    if len(completion) == 0:
        # No group to average over.
        return []

    sequential_words = ["first", 'then', 'next', 'finally', 'therefore', 'so,', 'thus']
    # A trailing \b is only meaningful after a word character. After the comma
    # in "so," it would demand that a word character follows, so "So, we ..."
    # would never be counted.
    alternatives = [
        re.escape(word) + (r'\b' if re.match(r'\w', word[-1]) else '')
        for word in sequential_words
    ]
    sequential_pattern = re.compile(
        r'\b(?:' + '|'.join(alternatives) + r')', re.IGNORECASE
    )

    group_sequential = [len(sequential_pattern.findall(text)) for text in completion]
    avg_sequential = float(sum(group_sequential) / len(group_sequential))

    return [avg_sequential for _completion in completion]

def group_average_depth_reward(
    completion,
    label,
    **kwargs):
    """
    Report the group's mean number of "depth" words for every completion.

    Depth words are "wait", "check", "hold on" and "verify", matched as whole
    words, case-insensitively.

    Args:
        completion: List of completion strings or single string
        label: List of label strings or single string (not used for the statistic)

    Returns:
        List[float]: the same group-level mean repeated once per completion
    """
    completion = _maybe_unwrap_chat(completion)
    label = _maybe_unwrap_chat(label)

    if isinstance(completion, str):
        completion = [completion]
    if isinstance(label, str):
        label = [label]

    depth_words = ["wait", "check", 'hold on', 'verify']
    depth_regex = re.compile(
        r'\b(' + '|'.join(re.escape(word) for word in depth_words) + r')\b',
        re.IGNORECASE,
    )

    total_hits = 0
    n_texts = 0
    for text in completion:
        total_hits += len(depth_regex.findall(text))
        n_texts += 1
    mean_depth = float(total_hits / n_texts)

    return [mean_depth] * n_texts

def group_average_output_reward(
    completion,
    label,
    n_generations: int = 1,
    **kwargs):
    """
    Calculate, per group, the average number of "**final answer**" markers.

    Completions are processed in consecutive chunks of ``n_generations``; the
    marker is matched case-insensitively.

    Args:
        completion: List of completion strings or single string
        label: List of label strings or single string (not used for the statistic)
        n_generations: Number of generations to group together

    Returns:
        List[float]: for every completion, the mean marker count of its chunk
    """
    # Unwrap chat format if needed
    completion = _maybe_unwrap_chat(completion)
    label = _maybe_unwrap_chat(label)

    # Convert to list if string
    if isinstance(completion, str):
        completion = [completion]
    if isinstance(label, str):
        label = [label]

    # No \b anchors: the marker begins and ends with '*', a non-word character,
    # so \b would only match when a letter/digit touches the asterisks and the
    # usual "\n**Final Answer**\n" form would never be counted.
    output_pattern = re.compile(re.escape("**final answer**"), re.IGNORECASE)

    results = []
    # Process in chunks of n_generations
    for i in range(0, len(completion), n_generations):
        chunk = completion[i:i + n_generations]
        chunk_outputs = [len(output_pattern.findall(text)) for text in chunk]
        chunk_avg = float(sum(chunk_outputs) / len(chunk_outputs))

        # One entry per completion in the chunk (the last chunk may be shorter).
        results.extend([chunk_avg] * len(chunk))

    return results

# Group Diversity Calculation

def length_group_diversity(completion: Union[str, List[str], List[List[dict]]],
                           label: Union[str, List[str], List[List[dict]]],
                           tokenizer: PreTrainedTokenizerBase = None,
                           length_groups: Optional[List[Tuple[int, int]]] = None,
                           n_generations: int = 1,
                           **kwargs) -> List[float]:
    """Normalised entropy of length buckets within each group of completions.

    Each completion's length is mapped to a bucket, completions are processed
    in consecutive chunks of ``n_generations``, and for every chunk the
    Shannon entropy (base 2) of its bucket distribution is divided by
    ``log2(chunk size)``.

    Args:
        completion: completions (str, list of str, or chat format).
        label: unused; kept for a uniform reward-function signature.
        tokenizer: if given, length is the token count; otherwise the character count.
        length_groups: half-open ``(start, end)`` buckets, i.e. a length
            belongs to a bucket when ``start <= length < end``. Lengths that
            match no bucket share one overflow label. Defaults to 512-wide
            buckets up to 2048 plus an open-ended bucket above.
        n_generations: number of consecutive completions forming one group.

    Returns:
        List[float]: for every completion, the normalised entropy of its chunk
        (0.0 for a chunk with a single completion).
    """
    completions = _maybe_unwrap_chat(completion)
    if isinstance(completions, str):
        completions = [completions]

    if tokenizer is not None:
        lengths = [len(tokenizer.encode(text, add_special_tokens=False)) for text in completions]
    else:
        lengths = [len(text) for text in completions]

    if length_groups is None:
        # Contiguous half-open buckets (training uses a 2K token limit). With
        # gaps between buckets, boundary lengths such as 512 or 1024 and the
        # whole 1536-2047 range would all collapse into the overflow label.
        length_groups = [(0, 512), (512, 1024), (1024, 1536), (1536, 2048), (2048, float('inf'))]

    def map_length_to_group(length):
        for start, end in length_groups:
            if start <= length < end:
                return f"{start}-{end - 1}"
        # Not covered by any bucket.
        return f"{length_groups[-1][0]}+"

    grouped_lengths = [map_length_to_group(length) for length in lengths]

    results = []
    for i in range(0, len(grouped_lengths), n_generations):
        chunk = grouped_lengths[i:i + n_generations]
        group_counts = {}
        for group in chunk:
            group_counts[group] = group_counts.get(group, 0) + 1

        total = len(chunk)
        entropy = 0.0
        for count in group_counts.values():
            p = count / total
            entropy -= p * math.log(p, 2)
        normalized_entropy = entropy / math.log(total, 2) if total > 1 else 0.0

        # One entry per completion in the chunk; the final chunk may be
        # shorter than n_generations and must not be padded.
        results.extend([normalized_entropy] * len(chunk))

    return results

def process_group_words_diversity(completion: Union[str, List[str], List[List[dict]]],
                                label: Union[str, List[str], List[List[dict]]],
                                transition_words: List[str] = None,
                                n_generations: int = 1,
                                **kwargs) -> List[float]:
    """Normalised entropy of transition-word-count buckets within each group.

    The number of transition words in each completion (whole-word,
    case-insensitive matches) is mapped to one of the buckets 0-10, 11-20,
    21-30 or 31+. Completions are processed in consecutive chunks of
    ``n_generations``; for every chunk the Shannon entropy (base 2) of its
    bucket distribution is divided by ``log2(chunk size)``.

    Args:
        completion: completions (str, list of str, or chat format).
        label: unused; kept for a uniform reward-function signature.
        transition_words: words to count; defaults to a built-in list of
            reflection / switch words.
        n_generations: number of consecutive completions forming one group.

    Returns:
        List[float]: for every completion, the normalised entropy of its chunk
        (0.0 for a chunk with a single completion).
    """
    if transition_words is None:
        transition_words = [
            "wait", "check", 'hold on', 'verify', "alternatively", "however", 'another', 'instead', 'hmm', 'but'
        ]

    completions = _maybe_unwrap_chat(completion)
    if isinstance(completions, str):
        completions = [completions]

    # Inclusive (low, high) ranges; together they cover every non-negative count.
    count_groups = [
        (0, 10), (11, 20), (21, 30), (31, float('inf'))
    ]

    def map_count_to_group(count):
        for start, end in count_groups:
            # Both ends inclusive: with a half-open test, counts equal to an
            # upper bound (10, 20, 30) match nothing and land in "31-inf".
            if start <= count <= end:
                return f"{start}-{end if end != float('inf') else 'inf'}"
        return "31-inf"

    pattern = re.compile(r'\b(' + '|'.join(map(re.escape, transition_words)) + r')\b', re.IGNORECASE)
    grouped_counts = [map_count_to_group(len(pattern.findall(text))) for text in completions]

    results = []
    for i in range(0, len(grouped_counts), n_generations):
        chunk = grouped_counts[i:i + n_generations]
        group_counts = {}
        for group in chunk:
            group_counts[group] = group_counts.get(group, 0) + 1

        total = len(chunk)
        entropy = 0.0
        for count in group_counts.values():
            p = count / total
            entropy -= p * math.log(p, 2)

        normalized_entropy = entropy / math.log(total, 2) if total > 1 else 0.0
        # One entry per completion in the chunk; the final chunk may be
        # shorter than n_generations and must not be padded.
        results.extend([normalized_entropy] * len(chunk))

    return results


def process_group_depth_diversity(completion: Union[str, List[str], List[List[dict]]],
                                label: Union[str, List[str], List[List[dict]]],
                                transition_words: List[str] = None,
                                n_generations: int = 1,
                                **kwargs) -> List[float]:
    """Normalised entropy of "depth" word-count buckets within each group.

    The number of depth words in each completion (whole-word,
    case-insensitive matches) is mapped to one of the buckets 0-3, 4-6, 7-9,
    10-13 or 14+. Completions are processed in consecutive chunks of
    ``n_generations``; for every chunk the Shannon entropy (base 2) of its
    bucket distribution is divided by ``log2(chunk size)``.

    Args:
        completion: completions (str, list of str, or chat format).
        label: unused; kept for a uniform reward-function signature.
        transition_words: words to count; defaults to
            "wait", "check", "hold on", "verify".
        n_generations: number of consecutive completions forming one group.

    Returns:
        List[float]: for every completion, the normalised entropy of its chunk
        (0.0 for a chunk with a single completion).
    """
    if transition_words is None:
        transition_words = [
            "wait", "check", 'hold on', 'verify'
        ]

    completions = _maybe_unwrap_chat(completion)
    if isinstance(completions, str):
        completions = [completions]

    # Inclusive (low, high) ranges; together they cover every non-negative count.
    count_groups = [
        (0, 3), (4, 6), (7, 9), (10, 13), (14, float('inf'))
    ]

    def map_count_to_group(count):
        for start, end in count_groups:
            # Both ends inclusive: with a half-open test, counts equal to an
            # upper bound (3, 6, 9, 13) match nothing and land in "14-inf".
            if start <= count <= end:
                return f"{start}-{end if end != float('inf') else 'inf'}"
        return "14-inf"

    pattern = re.compile(r'\b(' + '|'.join(map(re.escape, transition_words)) + r')\b', re.IGNORECASE)
    grouped_counts = [map_count_to_group(len(pattern.findall(text))) for text in completions]

    results = []
    for i in range(0, len(grouped_counts), n_generations):
        chunk = grouped_counts[i:i + n_generations]
        group_counts = {}
        for group in chunk:
            group_counts[group] = group_counts.get(group, 0) + 1

        total = len(chunk)
        entropy = 0.0
        for count in group_counts.values():
            p = count / total
            entropy -= p * math.log(p, 2)

        normalized_entropy = entropy / math.log(total, 2) if total > 1 else 0.0
        # One entry per completion in the chunk; the final chunk may be
        # shorter than n_generations and must not be padded.
        results.extend([normalized_entropy] * len(chunk))

    return results

def process_group_switch_diversity(completion: Union[str, List[str], List[List[dict]]],
                                 label: Union[str, List[str], List[List[dict]]],
                                 transition_words: List[str] = None,
                                 n_generations: int = 1,
                                 **kwargs) -> List[float]:
    """Normalised entropy of "switch" word-count buckets within each group.

    The number of switch words in each completion (whole-word,
    case-insensitive matches) is mapped to one of the buckets 0-3, 4-6, 7-9,
    10-13 or 14+. Completions are processed in consecutive chunks of
    ``n_generations``; for every chunk the Shannon entropy (base 2) of its
    bucket distribution is divided by ``log2(chunk size)``.

    Args:
        completion: completions (str, list of str, or chat format).
        label: unused; kept for a uniform reward-function signature.
        transition_words: words to count; defaults to
            "alternatively", "however", "another", "instead".
        n_generations: number of consecutive completions forming one group.

    Returns:
        List[float]: for every completion, the normalised entropy of its chunk
        (0.0 for a chunk with a single completion).
    """
    if transition_words is None:
        transition_words = [
            "alternatively", "however", 'another', 'instead'
        ]

    completions = _maybe_unwrap_chat(completion)
    if isinstance(completions, str):
        completions = [completions]

    # Inclusive (low, high) ranges; together they cover every non-negative count.
    count_groups = [
        (0, 3), (4, 6), (7, 9), (10, 13), (14, float('inf'))
    ]

    def map_count_to_group(count):
        for start, end in count_groups:
            # Both ends inclusive: with a half-open test, counts equal to an
            # upper bound (3, 6, 9, 13) match nothing and land in "14-inf".
            if start <= count <= end:
                return f"{start}-{end if end != float('inf') else 'inf'}"
        return "14-inf"

    pattern = re.compile(r'\b(' + '|'.join(map(re.escape, transition_words)) + r')\b', re.IGNORECASE)
    grouped_counts = [map_count_to_group(len(pattern.findall(text))) for text in completions]

    results = []
    for i in range(0, len(grouped_counts), n_generations):
        chunk = grouped_counts[i:i + n_generations]
        group_counts = {}
        for group in chunk:
            group_counts[group] = group_counts.get(group, 0) + 1

        total = len(chunk)
        entropy = 0.0
        for count in group_counts.values():
            p = count / total
            entropy -= p * math.log(p, 2)

        normalized_entropy = entropy / math.log(total, 2) if total > 1 else 0.0
        # One entry per completion in the chunk; the final chunk may be
        # shorter than n_generations and must not be padded.
        results.extend([normalized_entropy] * len(chunk))

    return results


# Conditional Greedy Diversity Sampling 
def greedy_select_diverse(responses, labels, candidate_indices, k, tokenizer, p_threshold=0.8, current_accuracy=0.0, initial_indices=None):
    """Greedily grow a selection by adding, k times, the candidate that maximises diversity.

    In each round every remaining candidate is tentatively added to the
    current selection and the resulting subset is scored as the sum of its
    depth-word, switch-word and length diversity. The candidate with the
    highest score is kept; on ties the earliest candidate wins.

    Args:
        responses: List[str], all responses in current batch
        labels: labels aligned with ``responses``; passed on to the diversity functions
        candidate_indices: indices of candidate samples to select from
        k: number of selection rounds (samples to add)
        tokenizer: tokenizer for computing length diversity
        p_threshold: accepted but not used by this function
        current_accuracy: accepted but not used by this function
        initial_indices: indices of already selected samples (if any)
    Returns:
        List[int]: ``initial_indices`` (if given) followed by the newly selected indices
    """
    chosen = [] if initial_indices is None else initial_indices.copy()
    pool = candidate_indices.copy()  # work on a copy so the caller's list is untouched

    def _subset_score(indices):
        # Diversity of the subset, treating all of it as one group.
        subset_responses = [responses[i] for i in indices]
        subset_labels = [labels[i] for i in indices]
        size = len(indices)
        depth = process_group_depth_diversity(subset_responses, subset_labels, n_generations=size)[0]
        switch = process_group_switch_diversity(subset_responses, subset_labels, n_generations=size)[0]
        length = length_group_diversity(subset_responses, subset_labels, tokenizer=tokenizer, n_generations=size)[0]
        # Ablation setting: plain sum of the three diversity terms.
        return depth + switch + length

    for _ in range(k):
        winner = None
        winner_score = float('-inf')

        for idx in pool:
            if idx in chosen:
                continue
            trial_score = _subset_score(chosen + [idx])
            # Strict comparison keeps the first candidate on ties.
            if trial_score > winner_score:
                winner, winner_score = idx, trial_score

        if winner is not None:
            chosen.append(winner)
            pool.remove(winner)

    return chosen

def _subset_diversity_metrics(responses, labels, tokenizer):
    """Compute the diversity scores and mean accuracy of one group of responses.

    Returns:
        tuple: ``((switch_div, depth_div, length_div, accuracy), acc_rewards)`` where
        ``acc_rewards`` is the per-sample output of ``math_accuracy_reward``.
    """
    group_size = len(responses)
    switch_div = process_group_switch_diversity(responses, labels, n_generations=group_size)[0]
    depth_div = process_group_depth_diversity(responses, labels, n_generations=group_size)[0]
    length_div = length_group_diversity(responses, labels, tokenizer=tokenizer, n_generations=group_size)[0]

    acc_rewards = math_accuracy_reward(responses, labels)
    accuracy = sum(acc_rewards) / len(acc_rewards)
    return (switch_div, depth_div, length_div, accuracy), acc_rewards


def select_diverse_subset(completions, labels, n_generations=16, target_n_generations=8, min_correct=1, min_incorrect=1, tokenizer=None):
    """
    Select a diverse subset of generations from every prompt group.

    ``completions`` and ``labels`` are treated as consecutive groups of
    ``n_generations`` items. Trailing items that do not fill a whole group are
    ignored. For each group, ``greedy_select_diverse`` picks up to
    ``target_n_generations`` samples. When a group mixes samples with accuracy
    reward 1.0 and samples with reward 0.0, ``min_correct`` / ``min_incorrect``
    samples of each kind are picked first and the remaining slots are filled
    from all not-yet-selected samples of the group.

    Args:
        completions: flat list of responses, grouped by prompt
        labels: flat list of labels aligned with ``completions``
        n_generations: number of responses per group in the input
        target_n_generations: number of responses to keep per group
        min_correct: correct samples picked first in a mixed group
        min_incorrect: incorrect samples picked first in a mixed group
        tokenizer: tokenizer forwarded to the length diversity computation
    Returns:
        tuple: ``(selected_completions, selected_indices, initial_metrics, subset_metrics)``.
        ``selected_indices`` are positions in the flat ``completions`` list.
        Both metric dicts hold per-group averages under the keys
        ``avg_switch_div``, ``avg_process_div`` (the depth diversity),
        ``avg_length_div`` and ``avg_accuracy``. When there is no complete
        group, the selections are empty and every metric is 0.0.
    """
    # Order matches the tuple returned by _subset_diversity_metrics.
    metric_keys = ('avg_switch_div', 'avg_process_div', 'avg_length_div', 'avg_accuracy')

    batch_size = len(completions) // n_generations
    if batch_size == 0:
        # No complete group: nothing to select, and averaging over zero
        # groups would otherwise divide by zero.
        return [], [], dict.fromkeys(metric_keys, 0.0), dict.fromkeys(metric_keys, 0.0)

    selected_completions = []
    selected_indices = []
    initial_totals = [0] * len(metric_keys)
    subset_totals = [0] * len(metric_keys)

    for batch_idx in range(batch_size):
        start_idx = batch_idx * n_generations
        end_idx = start_idx + n_generations
        batch_responses = completions[start_idx:end_idx]
        batch_labels = labels[start_idx:end_idx]

        # Diversity and accuracy of the full group
        initial_values, acc_rewards = _subset_diversity_metrics(batch_responses, batch_labels, tokenizer)
        initial_totals = [total + value for total, value in zip(initial_totals, initial_values)]

        correct_indices = [i for i, r in enumerate(acc_rewards) if r == 1.0]
        wrong_indices = [i for i, r in enumerate(acc_rewards) if r == 0.0]

        if not correct_indices:
            # All wrong, just select diverse wrong samples
            batch_selected = greedy_select_diverse(
                batch_responses, batch_labels, wrong_indices,
                target_n_generations, tokenizer
            )
        elif not wrong_indices:
            # All correct, just select diverse correct samples
            batch_selected = greedy_select_diverse(
                batch_responses, batch_labels, correct_indices,
                target_n_generations, tokenizer
            )
        else:
            # Mix of correct and wrong - seed the selection with both kinds so
            # the subset is never all-correct or all-wrong
            batch_selected = []
            batch_selected.extend(greedy_select_diverse(
                batch_responses, batch_labels, correct_indices, min_correct, tokenizer
            ))
            batch_selected.extend(greedy_select_diverse(
                batch_responses, batch_labels, wrong_indices, min_incorrect, tokenizer
            ))

            # Fill remaining slots from all candidates except already selected
            remaining = target_n_generations - len(batch_selected)
            if remaining > 0:
                remaining_candidates = [i for i in range(n_generations) if i not in batch_selected]
                batch_selected.extend(greedy_select_diverse(
                    batch_responses, batch_labels, remaining_candidates,
                    remaining, tokenizer
                ))

        # Diversity and accuracy of the selected subset
        selected_responses = [batch_responses[i] for i in batch_selected]
        selected_labels = [batch_labels[i] for i in batch_selected]
        subset_values, _ = _subset_diversity_metrics(selected_responses, selected_labels, tokenizer)
        subset_totals = [total + value for total, value in zip(subset_totals, subset_values)]

        # Add selected samples and their global indexes
        for local_idx in batch_selected:
            selected_indices.append(start_idx + local_idx)
            selected_completions.append(batch_responses[local_idx])

    initial_metrics = {key: total / batch_size for key, total in zip(metric_keys, initial_totals)}
    subset_metrics = {key: total / batch_size for key, total in zip(metric_keys, subset_totals)}

    return selected_completions, selected_indices, initial_metrics, subset_metrics

# Registry mapping a preset name (string) to the module-level reward /
# diversity function registered under that name.
# NOTE(review): presumably looked up by name from training configuration -
# confirm against callers before renaming any key.
PRESET_REWARD_FUNCS = {
    "math_accuracy": math_accuracy_reward,
    "length_scaled_math_accuracy": length_scaled_math_accuracy_reward,
    "adaptive_reasoning_control": adaptive_reasoning_control_reward, 
    "think_ans_format": think_ans_format_reward,
    "length_group_diversity": length_group_diversity,
    "process_group_switch_diversity": process_group_switch_diversity,
    "process_group_depth_diversity": process_group_depth_diversity,
    "process_depth": group_average_depth_reward,
    "process_switch": group_average_switch_reward,
    "process_output": group_average_output_reward,
    "process_sequential": group_average_sequential_reward,
    "LCPO_max_reward": LCPO_max_reward,
    "TLB_reward": TLB_reward,
    "ngram_repetition_penalty": ngram_repetition_penalty,
}
