import os
import re
from typing import Dict, List, Union

import json
import random

from swift.llm import InferRequest
from swift.llm.infer.infer_engine import PtEngine
from tqdm import tqdm
import torch
import numpy as np
from rouge import Rouge
from mosestokenizer import MosesTokenizer
from hard import _passes_all_cases

def calculate_sari(orig_text: str, sys_text: str, ref_text: str) -> float:
    """Compute corpus-level SARI with ``easse`` for line-aligned texts.

    Each argument holds one sentence per line; blank lines are ignored and every
    line is stripped. The three texts must yield the same number of sentences.
    The single reference text is passed to easse as one reference set.

    Raises:
        ValueError: if the three texts have different sentence counts.
    """
    def _non_blank_lines(text):
        stripped_lines = (ln.strip() for ln in text.strip().splitlines())
        return [ln for ln in stripped_lines if ln]

    orig_sents, sys_sents, ref_sents = (
        _non_blank_lines(t) for t in (orig_text, sys_text, ref_text)
    )

    if len(orig_sents) != len(sys_sents) or len(sys_sents) != len(ref_sents):
        raise ValueError("All inputs must have the same number of sentences (one per line).")

    # Imported lazily so the module loads without easse installed.
    from easse.sari import corpus_sari
    return corpus_sari(orig_sents=orig_sents, sys_sents=sys_sents, refs_sents=[ref_sents])



def cal_rouge(output_texts, ref_texts):
    """Return the mean ROUGE-1, ROUGE-2 and ROUGE-L F-scores of outputs vs references.

    Both sides are Moses-tokenized (English) before scoring. Each pair is scored
    individually and the per-pair F-scores are averaged over all pairs.

    A pair whose hypothesis or reference is blank after tokenization, or which the
    ``rouge`` package rejects with ``ValueError``, contributes 0.0 instead of raising.
    This matters because an empty generation from the frozen model should not abort
    a whole reward batch.

    Args:
        output_texts: list of hypothesis strings.
        ref_texts: list of reference strings, aligned with ``output_texts``.

    Returns:
        Tuple ``(rouge_1_f, rouge_2_f, rouge_l_f)``. It is ``(0.0, 0.0, 0.0)`` for
        empty input.

    Raises:
        ValueError: if the two lists have different lengths.
    """
    print("calculating rouge score...")
    print("Output Texts:", output_texts)
    print("Reference Texts:", ref_texts)
    if len(output_texts) != len(ref_texts):
        raise ValueError("output_texts and ref_texts must have the same length.")

    metrics = ('rouge-1', 'rouge-2', 'rouge-l')
    totals = [0.0, 0.0, 0.0]
    rouge = Rouge()
    # Use ONE tokenizer for the whole call and close it on exit. The previous code
    # built (and leaked) a new MosesTokenizer per sentence; the mosestokenizer
    # package drives the Moses perl script in a helper process.
    with MosesTokenizer('en') as tokenize:
        for hyp, ref in zip(output_texts, ref_texts):
            hyp_tok = " ".join(tokenize(hyp))
            ref_tok = " ".join(tokenize(ref))
            if not hyp_tok.strip() or not ref_tok.strip():
                continue  # blank side: this pair scores 0
            try:
                scores = rouge.get_scores(hyp_tok, ref_tok)[0]
            except ValueError:
                continue  # rouge rejected the pair (e.g. no scorable sentence): score 0
            for i, metric in enumerate(metrics):
                totals[i] += scores[metric]['f']

    count = len(output_texts)
    if count == 0:
        return 0.0, 0.0, 0.0
    return totals[0] / count, totals[1] / count, totals[2] / count


class ORM:
    """Base class for outcome reward models.

    Subclasses implement ``__call__`` and return one float reward per completion.
    """

    def __call__(self, **kwargs) -> List[float]:
        """Compute rewards; must be overridden by subclasses."""
        raise NotImplementedError()


def extract_xml_answer(text: str) -> str:
    """
    Return the stripped text after the first ``<answer>``, up to the next ``</answer>``
    (or to the end of the text if no closing tag follows it).

    If either tag is missing from *text* altogether, return an empty string.
    """
    has_both_tags = "<answer>" in text and "</answer>" in text
    if not has_both_tags:
        return ""
    _, _, after_open = text.partition("<answer>")
    inner, _, _ = after_open.partition("</answer>")
    return inner.strip()

class ReactORM(ORM):
    """Outcome reward for ReAct-style tool calls ("Action:" / "Action Input:" text).

    A prediction scores 1.0 when its parsed action and action input match the
    reference closely enough (see ``evaluate_action_reward``), otherwise 0.0.
    """

    @staticmethod
    def evaluate_action_reward(action_pred: list, action_ref: list, cand_list: list, ref_list: list):
        """Score predicted actions/inputs against the references.

        A per-item score is collected in ``f1``:
          * 0 if the action names differ, or exactly one of the two inputs parses as JSON;
          * for two non-JSON inputs, a thresholded ROUGE-L score (0, 0.1 or 1);
          * 0 if either parsed JSON value is not a dict;
          * for an empty reference dict, 1 iff the candidate dict is also empty;
          * otherwise an F1 over reference keys, where a key with an equal value counts
            1 and a key present with a different value counts 0.5.

        Returns:
            True iff the score of the FIRST item is exactly 1.0; later items are ignored.
            ``__call__`` always passes single-element lists. Raises IndexError on empty
            input.
        """
        f1 = []
        for i in range(len(action_pred)):
            ref_action = action_ref[i]
            pred_action = action_pred[i]

            ref_input = ref_list[i]
            cand_input = cand_list[i]

            # Parse both inputs as JSON when possible; keep the raw string otherwise.
            ref_is_json = False
            try:
                ref_input_json = json.loads(ref_input)
                ref_is_json = True
            except Exception:
                ref_input_json = ref_input

            cand_is_json = False
            try:
                cand_input_json = json.loads(cand_input)
                cand_is_json = True
            except Exception:
                cand_input_json = cand_input

            if ref_action != pred_action or (ref_is_json ^ cand_is_json):
                f1.append(0)
            elif not ref_is_json and not cand_is_json:
                # NOTE(review): evaluate_rougel's signature is (cand_list, ref_list), but it is
                # called here with the reference first -- the arguments look swapped.
                # NOTE(review): the `rouge` package's F-score is presumably in [0, 1], so
                # `rougel < 10` would always hold and this branch would always score 0; the
                # 10/20 thresholds look like a percent scale -- confirm the intent.
                rougel = ReactORM.evaluate_rougel([ref_input_json], [cand_input_json])
                if rougel is None or rougel < 10:
                    f1.append(0)
                elif 10 <= rougel < 20:
                    f1.append(0.1)
                else:
                    f1.append(1)
            else:
                if not isinstance(ref_input_json, dict) or not isinstance(cand_input_json, dict):
                    # Defensive: json.loads can return a non-dict (e.g. a string or a list).
                    # Calling .items() on such a value below would raise AttributeError,
                    # so score it 0 instead.
                    f1.append(0)
                    continue

                half_match = 0
                full_match = 0
                if ref_input_json == {}:
                    if cand_input_json == {}:
                        f1.append(1)
                    else:
                        f1.append(0)
                else:
                    for k, v in ref_input_json.items():
                        if k in cand_input_json.keys():
                            if cand_input_json[k] == v:
                                full_match += 1
                            else:
                                half_match += 1

                    # The 1e-30 keeps an empty candidate dict from dividing by zero.
                    recall = (0.5 * half_match + full_match) / (len(ref_input_json) + 1e-30)
                    precision = (0.5 * half_match + full_match) / (len(cand_input_json) + 1e-30)
                    try:
                        f1.append((2 * recall * precision) / (recall + precision))
                    except Exception:
                        # recall + precision == 0 (no key overlap): ZeroDivisionError -> 0.
                        f1.append(0.0)

        if f1[0] == 1.0:
            return True
        else:
            return False

    @staticmethod
    def parse_action(text):
        """Split ReAct text into ``(action, action_input)``.

        ``action_input`` is the stripped text after the LAST "Action Input:", or '{}'
        when absent. ``action`` is the stripped text after the LAST "Action:", cut at
        the first "Action Input:" that follows it, or 'none' when absent.
        """
        if 'Action Input:' in text:
            input_idx = text.rindex('Action Input:')
            action_input = text[input_idx + len('Action Input:'):].strip()
        else:
            action_input = '{}'

        if 'Action:' in text:
            action_idx = text.rindex('Action:')
            action = text[action_idx + len('Action:'):].strip()
            if 'Action Input:' in action:
                input_idx = action.index('Action Input:')
                action = action[:input_idx].strip()
        else:
            action = 'none'
        return action, action_input

    @staticmethod
    def parse_output(text):
        """Thin alias for ``parse_action``; returns ``(action, action_input)``."""
        action, action_input = ReactORM.parse_action(text)
        return action, action_input

    def __call__(self, infer_requests: List[Union[InferRequest, Dict]], solution: List[str], **kwargs) -> List[float]:
        """Return 1.0/0.0 per prediction depending on ``evaluate_action_reward``.

        Args:
            infer_requests: either plain prediction strings, or request dicts whose last
                message's ``content`` is the prediction (decided by the first element's
                type; an empty list raises IndexError).
            solution: reference ReAct texts, aligned with the predictions.
        """
        rewards = []
        if not isinstance(infer_requests[0], str):
            predictions = [request['messages'][-1]['content'] for request in infer_requests]
        else:
            predictions = infer_requests
        for prediction, ground_truth in zip(predictions, solution):
            # Drop a trailing "Observation:" stub. The cut is made at its FIRST occurrence.
            if prediction.endswith('Observation:'):
                prediction = prediction[:prediction.index('Observation:')].strip()
            action_ref = []
            action_input_ref = []
            action_pred = []
            action_input_pred = []
            reference = ground_truth
            prediction = prediction.replace('<|endoftext|>', '').replace('<|im_end|>', '').strip()
            ref_action, ref_input = ReactORM.parse_output(reference)
            pred_action, pred_input = ReactORM.parse_output(prediction)
            action_ref.append(ref_action)
            action_input_ref.append(ref_input)
            # parse_output never returns None, so these fallbacks are purely defensive.
            if pred_action is None:
                action_pred.append('none')
            else:
                action_pred.append(pred_action)

            if pred_input is None:
                action_input_pred.append('{}')
            else:
                action_input_pred.append(pred_input)

            reward = ReactORM.evaluate_action_reward(action_pred, action_ref, action_input_pred, action_input_ref)
            rewards.append(float(reward))
        return rewards

    @staticmethod
    def evaluate_rougel(cand_list: list, ref_list: list):
        """Return the average ROUGE-L F-score of ``cand_list`` vs ``ref_list``.

        Returns None when ``ref_list`` is empty or when scoring raises any exception.
        """
        if len(ref_list) == 0:
            return None
        try:
            from rouge import Rouge
            rouge = Rouge()
            rouge_score = rouge.get_scores(hyps=cand_list, refs=ref_list, avg=True)
            rougel = rouge_score['rouge-l']['f']
            return rougel
        except Exception:
            return None


class MathORM(ORM):
    """Outcome reward for math answers: 1.0 if prediction and ground truth are equivalent.

    The answer is the text after ``# Answer`` (when present), narrowed to the content
    of the first ``\\boxed{...}``. Equivalence is decided by OpenCompass's
    ``MATHEvaluator`` when the ``USE_OPENCOMPASS_EVALUATOR`` environment variable is
    truthy. Otherwise it is decided by a SymPy comparison of the parsed LaTeX.
    """

    def __init__(self):
        from transformers.utils import strtobool
        self.use_opencompass = strtobool(os.environ.get('USE_OPENCOMPASS_EVALUATOR', 'False'))
        if self.use_opencompass:
            from opencompass.datasets.math import MATHEvaluator
            self.evaluator = MATHEvaluator()

    @staticmethod
    def check_terminate(answers: Union[str, List[str]]) -> List[bool]:
        """Return, for each answer, whether it contains a ``\\boxed`` expression."""
        if isinstance(answers, str):
            answers = [answers]
        return ['\\boxed' in answer for answer in answers]

    @staticmethod
    def extract_boxed_result(text):
        """Return the stripped content of the first ``\\boxed{...}`` in *text*.

        Braces are matched by depth, so nested groups such as
        ``\\boxed{\\frac{1}{2}}`` are returned whole. The previous ``[^}]*`` regex cut
        them at the first ``}``. If there is no ``\\boxed{`` or its braces never
        close, *text* is returned unchanged.
        """
        marker = '\\boxed{'
        start = text.find(marker)
        if start == -1:
            return text
        content_start = start + len(marker)
        depth = 1
        for pos in range(content_start, len(text)):
            char = text[pos]
            if char == '{':
                depth += 1
            elif char == '}':
                depth -= 1
                if depth == 0:
                    return text[content_start:pos].strip()
        return text

    @staticmethod
    def clean_latex(latex_str):
        """Remove ``\\(``, ``\\)``, ``\\[``, ``\\]`` delimiters and all curly braces, then strip."""
        latex_str = re.sub(r'\\\(|\\\)|\\\[|\\]', '', latex_str)
        # Removing every brace makes any preliminary '}}' -> '}' collapse redundant.
        latex_str = latex_str.replace('{', '').replace('}', '')
        return latex_str.strip()

    @staticmethod
    def parse_expression(latex_str):
        """Parse LaTeX with SymPy and simplify it; return None on any failure."""
        from sympy import simplify
        from sympy.parsing.latex import parse_latex
        try:
            expr = parse_latex(latex_str)
            return simplify(expr)
        except Exception:
            return None

    @staticmethod
    def compare_consecutive(first, second):
        """Return True if the two LaTeX answers are judged equivalent.

        Both answers are cleaned and parsed. If either fails to parse, the cleaned
        strings are compared exactly instead. Previously ``None == None`` made any two
        unparseable answers "equal"; that includes every pair when SymPy's optional
        LaTeX parser backend is unavailable. An undecidable SymPy ``equals`` (None)
        counts as not equal.
        """
        cleaned = [MathORM.clean_latex(latex) for latex in (first, second)]
        parsed = [MathORM.parse_expression(latex) for latex in cleaned]
        if parsed[0] is None or parsed[1] is None:
            return cleaned[0] == cleaned[1]
        if hasattr(parsed[0], 'equals') and hasattr(parsed[1], 'equals'):
            value = parsed[0].equals(parsed[1])
        else:
            value = parsed[0] == parsed[1]
        if value is None:
            value = False
        return value

    def __call__(self, infer_requests: List[Union[InferRequest, Dict]], ground_truths: List[str],
                 **kwargs) -> List[float]:
        """Return 1.0/0.0 per (request, ground truth) pair.

        The prediction is the ``content`` of each request's last message.
        """
        rewards = []
        predictions = [request['messages'][-1]['content'] for request in infer_requests]
        for prediction, ground_truth in zip(predictions, ground_truths):
            if '# Answer' in prediction:
                prediction = prediction.split('# Answer')[1]
            if '# Answer' in ground_truth:
                ground_truth = ground_truth.split('# Answer')[1]
            prediction = MathORM.extract_boxed_result(prediction.strip())
            ground_truth = MathORM.extract_boxed_result(ground_truth.strip())
            if self.use_opencompass:
                reward = self.evaluator.is_equiv(prediction, ground_truth)
            else:
                reward = MathORM.compare_consecutive(prediction, ground_truth)
            rewards.append(float(reward))
        return rewards


class MathAccuracy(ORM):
    """Accuracy reward backed by the ``math_verify`` package.

    Returns 1.0 for a completion whose extracted LaTeX answer verifies against the
    gold solution and 0.0 otherwise. Returns 1.0 whenever the gold solution itself
    cannot be parsed, so such examples are effectively skipped.
    """

    def __init__(self):
        import importlib.util
        # NOTE: this is an `assert`, so the check disappears under `python -O`.
        assert importlib.util.find_spec('math_verify') is not None, (
            "The math_verify package is required but not installed. Please install it using 'pip install math_verify'.")

    def __call__(self, completions, solution, **kwargs) -> List[float]:
        from latex2sympy2_extended import NormalizationConfig
        from math_verify import LatexExtractionConfig, parse, verify

        def _answer_extraction_config():
            # Require well-formed LaTeX (no malformed operators), try boxed answers
            # first, and never extract without an anchor.
            return [
                LatexExtractionConfig(
                    normalization_config=NormalizationConfig(
                        nits=False,
                        malformed_operators=False,
                        basic_latex=True,
                        equations=True,
                        boxed=True,
                        units=True,
                    ),
                    boxed_match_priority=0,
                    try_extract_without_anchor=False,
                )
            ]

        scores = []
        for completion, gold in zip(completions, solution):
            gold_parsed = parse(gold, extraction_mode='first_match', extraction_config=[LatexExtractionConfig()])
            if not gold_parsed:
                # Unparseable gold solution: full credit so the example is skipped.
                scores.append(1.0)
                continue
            answer_parsed = parse(
                completion,
                extraction_config=_answer_extraction_config(),
                extraction_mode='first_match',
            )
            scores.append(float(verify(answer_parsed, gold_parsed)))
        return scores


class Format(ORM):
    """Reward 1.0 for completions shaped exactly as <think>...</think> then <answer>...</answer>."""

    def __call__(self, completions, **kwargs) -> List[float]:
        """Reward function that checks if the completion has a specific format."""
        # (?![\s\S]) asserts nothing at all follows the closing </answer>.
        shape = re.compile(r'^<think>.*?</think>\s*<answer>.*?</answer>(?![\s\S])', re.DOTALL | re.MULTILINE)
        return [0.0 if shape.match(content) is None else 1.0 for content in completions]


class ReActFormat(ORM):
    """Reward 1.0 when a completion is a <think> block followed by "Action:" and "Action Input:"."""

    def __call__(self, completions, **kwargs) -> List[float]:
        """Reward function that checks if the completion has a specific format."""
        react_shape = re.compile(r'^<think>.*?</think>\s*Action:.*?Action Input:.*?$', re.DOTALL | re.MULTILINE)
        scores = []
        for content in completions:
            scores.append(1.0 if react_shape.match(content) else 0.0)
        return scores


class CosineReward(ORM):
    """Length-shaped accuracy reward following a cosine schedule.

    Reference: https://arxiv.org/abs/2502.03373

    The reward interpolates along half a cosine period as generation length goes
    from 0 to ``max_len`` tokens:
      * correct answers: from ``cosine_min_len_value_correct`` down to
        ``cosine_max_len_value_correct``;
      * wrong answers: from ``cosine_max_len_value_wrong`` up to
        ``cosine_min_len_value_wrong``.
    Correctness comes from ``accuracy_orm`` (default ``MathAccuracy``). A reward of
    at least 1.0 counts as correct.
    """

    def __init__(self,
                 tokenizer=None,
                 cosine_min_len_value_wrong: float = 0.0,
                 cosine_max_len_value_wrong: float = -0.5,
                 cosine_min_len_value_correct: float = 1.0,
                 cosine_max_len_value_correct: float = 0.5,
                 cosine_max_len: int = 1000,
                 accuracy_orm=None):
        self.tokenizer = tokenizer
        self.min_len_value_wrong = cosine_min_len_value_wrong
        self.max_len_value_wrong = cosine_max_len_value_wrong
        self.min_len_value_correct = cosine_min_len_value_correct
        self.max_len_value_correct = cosine_max_len_value_correct
        self.max_len = cosine_max_len
        self.accuracy_orm = accuracy_orm or MathAccuracy()

    @staticmethod
    def cosfn(t, T, min_value, max_value):
        """Cosine interpolation: returns ``max_value`` at t=0 and ``min_value`` at t=T."""
        import math
        return max_value - (max_value - min_value) * (1 - math.cos(t * math.pi / T)) / 2

    def __call__(self, completions, solution, **kwargs) -> List[float]:
        """Return one length-shaped reward per completion (requires ``self.tokenizer``)."""
        acc_rewards = self.accuracy_orm(completions, solution, **kwargs)
        rewards = []
        for content, acc_reward in zip(completions, acc_rewards):
            if acc_reward >= 1.:
                # Swap min/max for correct answers: short correct answers score highest.
                min_value, max_value = self.max_len_value_correct, self.min_len_value_correct
            else:
                min_value, max_value = self.min_len_value_wrong, self.max_len_value_wrong
            gen_len = len(self.tokenizer.encode(content))
            # Clamp to max_len: cos(t*pi/T) is periodic, so for t > T the reward would
            # swing back toward its zero-length value (e.g. a wrong answer of 2*T tokens
            # would get the same penalty as an empty one).
            reward = self.cosfn(min(gen_len, self.max_len), self.max_len, min_value, max_value)
            rewards.append(reward)
        return rewards


class RepetitionPenalty(ORM):
    """Penalty proportional to the share of repeated word n-grams.

    Reference: https://arxiv.org/abs/2502.03373

    reward = (1 - unique_ngrams / total_ngrams) * max_penalty. Completions that are
    empty or have fewer than ``ngram_size`` words get 0.0.
    """

    def __init__(self, repetition_n_grams: int = 3, repetition_max_penalty: float = -1.0):
        self.ngram_size = repetition_n_grams
        self.max_penalty = repetition_max_penalty

    @staticmethod
    def zipngram(text: str, ngram_size: int):
        """Return an iterator of word n-gram tuples over the lower-cased, whitespace-split *text*."""
        tokens = text.lower().split()
        shifted = [tokens[offset:] for offset in range(ngram_size)]
        return zip(*shifted)

    def __call__(self, completions, **kwargs) -> List[float]:
        """
        Return one repetition penalty per completion.

        Args:
            completions: List of model completions
        """
        penalties = []
        for completion in completions:
            too_short = completion == '' or len(completion.split()) < self.ngram_size
            if too_short:
                penalties.append(0.0)
                continue
            grams = list(self.zipngram(completion, self.ngram_size))
            repeated_share = 1 - len(set(grams)) / len(grams)
            penalties.append(repeated_share * self.max_penalty)
        return penalties
    
def count_xml(text: str) -> float:
    """Partial format score for XML tags that occur exactly once in *text*.

    Controlled by the ``REASONING`` environment variable:
      * ``"False"``: ``<answer>`` and ``</answer>`` are worth 0.375 each;
      * ``"True"``: ``<think>``, ``</think>``, ``<answer>`` and ``</answer>`` are
        worth 0.1875 each.
    Either way the maximum is 0.75. A tag appearing zero or several times earns nothing.

    Raises:
        KeyError: if ``REASONING`` is not set.
        ValueError: if ``REASONING`` is neither ``"True"`` nor ``"False"``.
    """
    mode = os.environ["REASONING"]
    if mode == "False":
        tags, weight = ("<answer>", "</answer>"), 0.375
    elif mode == "True":
        tags, weight = ("<think>", "</think>", "<answer>", "</answer>"), 0.1875
    else:
        raise ValueError(f"REASONING must be 'True' or 'False', got {mode!r}")
    # Start the sum at 0.0 so the result is always a float.
    return sum((weight for tag in tags if text.count(tag) == 1), 0.0)

class GridFormat(ORM):
    """Format reward based on the ``<answer>`` (and, in reasoning mode, ``<think>``) tags.

    Each completion earns ``count_xml(text)`` (at most 0.75), plus 0.75 when the
    expected tag sequence appears somewhere in the text:
      * ``REASONING == "True"``: ``<think>...</think>`` followed by ``<answer>...</answer>``;
      * ``REASONING == "False"``: ``<answer>...</answer>``.
    The maximum reward is therefore 1.5.
    """

    def __call__(self, completions, solution=None, **kwargs) -> List[float]:
        """Return one format reward per completion; ``solution`` is accepted but unused.

        Raises:
            ValueError: if ``REASONING`` is neither ``"True"`` nor ``"False"``.
        """
        reasoning = os.environ["REASONING"]
        if reasoning == "False":
            pattern_soft = r"<answer>.*?</answer>"
        elif reasoning == "True":
            pattern_soft = r"<think>.*?</think>\s*<answer>.*?</answer>"
        else:
            raise ValueError(f"REASONING must be 'True' or 'False', got {reasoning!r}")
        soft_re = re.compile(pattern_soft, flags=re.DOTALL)

        return [count_xml(text) + (0.75 if soft_re.search(text) else 0.0) for text in completions]
    

class GridAccuracy(ORM):
    """
    Accuracy reward that scores a generated prompt by querying a frozen model with it.

    For each completion of the trainable policy:
      (1) The text inside <answer>...</answer> is taken as a candidate prompt.
      (2) The prompt is combined with one or more task inputs and sent to
          ``self.frozen_model.infer`` under ``torch.no_grad()``.
      (3) The frozen model's replies are scored against gold answers: +1 for a
          well-formed reply (valid label / digits for gsm8k), plus a content score
          (exact match, unit tests for "gen", ROUGE for "sum", SARI for "sim").

    Configuration is read from environment variables on every call: DATASET,
    LLMREGUL, SAMPLINGPROB, USE_BOTH_REG, PSP and NUMBER_OF_SAMPLES.
    """

    # (substring, task_type) pairs used to classify the observation prompt, tested in
    # order. The "deepmath" instruction begins with the "gsm8k" instruction, so it must
    # be tested first; otherwise every deepmath prompt is classified as gsm8k.
    _TASK_MARKERS = (
        ("Choose one of the correct answers. Return only the correct response [`A`, `B`, `C`, `D`, `E`] without any additional text.", "multiple_ae"),
        ("Choose one of the correct answers. Return only the correct response [`A`, `B`, `C`, `D`] without any additional text.", "openbookqa"),
        ("Solve this riddle and return ONLY the integer answer or `Yes`, `No` without any other text", "deepmath"),
        ("Solve this riddle and return ONLY the integer answer", "gsm8k"),
        ("subjectivity classification task", "subj"),
        ("Solve this coding task. Provide the python code that solves this problem (with return statements). Return this function and nothing else. Do not provide any usage examples. Every argument should be defined inside the function.", "mbpp"),
        ("news classification task", "news"),
        ("Please perform Sentiment Classification task. Given the sentence, assign a sentiment label from [’negative’,  ’positive’].", "binary_class"),
        ("from [’terrible’,  ’bad’, ’okay’, ’good’, ’great’].", "sst5"),
        ("Question Classification task", "trec"),
        ("simplification task.", "sim"),
        ("summarization task.", "sum"),
    )

    # Separator between the instruction and the task input in observation prompts.
    _OBSERVATION_SEP = "OBSERVATION: \n\n "

    # Allowed reply labels per DATASET; a reply in this set earns the +1 format bonus.
    _VALID_LABELS = {
        "sst-2": ("negative", "positive"),
        "mr": ("negative", "positive"),
        "cr": ("negative", "positive"),
        "sst5": ("terrible", "bad", "okay", "good", "great"),
        "trec": ("Abbreviation", "Description", "Entity", "Human", "Location", "Numeric"),
        "subj": ("subjective", "objective"),
        "news": ("World", "Sports", "Business", "Tech"),
    }

    # Datasets without a fixed reply format: they never earn the format bonus.
    _FREE_FORM_DATASETS = ("sim", "sum", "gen")

    def __init__(
        self,
        frozen_model,            # a copy of the trainable model, but kept frozen
        template,                # same template used for inference
        train_data,              # the training data used to query the frozen model
        request_config=None,     # same request config used for inference
    ):
        self.frozen_model = frozen_model
        self.template = template
        self.request_config = request_config
        self.train_data = train_data
        # NOTE(review): parameters are not explicitly frozen here; the model is only
        # called through `.infer(...)` inside `torch.no_grad()` blocks.

    # ------------------------------------------------------------------ helpers

    @classmethod
    def _detect_task_type(cls, prompt: str) -> str:
        """Map an observation prompt to its task type via ``_TASK_MARKERS``."""
        for marker, task_type in cls._TASK_MARKERS:
            if marker in prompt:
                return task_type
        raise Exception("This observation cannot be classified into any task_type")

    @staticmethod
    def _user_request(content: str) -> dict:
        """Build a single-turn chat request with the fixed system prompt."""
        return {
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": content},
            ]
        }

    @staticmethod
    def _reply_text(response) -> str:
        """Stripped content of the first choice of an inference response, or ''."""
        if response and len(response.choices) > 0:
            return response.choices[0].message.content.strip()
        return ""

    @classmethod
    def _is_well_formed(cls, ds: str, final_answer: str, strict: bool) -> bool:
        """Whether *final_answer* has the reply format expected for dataset *ds*.

        With ``strict=True`` an unknown dataset raises; otherwise it counts as False.
        """
        if ds in cls._VALID_LABELS:
            return final_answer in cls._VALID_LABELS[ds]
        if ds == "gsm8k":
            return final_answer.isdigit()
        if strict and ds not in cls._FREE_FORM_DATASETS:
            raise Exception("Unknown dataset")
        return False

    @staticmethod
    def _looks_like_code(final_answer: str) -> bool:
        """Cheap pre-check that a reply contains a Python function with a return."""
        return "python" in final_answer and "\ndef" in final_answer and "return" in final_answer

    @staticmethod
    def _rouge_reward(hyp: str, ref: str) -> float:
        """Mean of ROUGE-1/2/L F-scores with newlines removed from both texts."""
        return np.mean(cal_rouge([hyp.replace("\n", "")], [ref.replace("\n", "")]))

    @staticmethod
    def _sari_reward(src: str, hyp: str, ref: str) -> float:
        """SARI scaled to [0, 2] with newlines removed; 0.0 if scoring fails."""
        try:
            return calculate_sari(
                src.replace("\n", ""), hyp.replace("\n", ""), ref.replace("\n", "")
            ) / 100 * 2
        except Exception:
            return 0.0

    @staticmethod
    def _sample_training_examples(task_type: str):
        """Sample NUMBER_OF_SAMPLES examples from ``datasets/original/<task>_train.jsonl``.

        "multiple_ae" and "binary_class" are first resolved to a random concrete dataset.

        Returns:
            ``(resolved_task_type, input_sentences, gold_answers)``.

        Raises:
            ValueError: if the file holds fewer lines than requested.
        """
        if task_type == "multiple_ae":
            task_type = random.choice(["commonsense", "med_qa"])
        elif task_type == "binary_class":
            task_type = random.choice(["cr", "mr", "sst-2"])
        dataset_path = os.path.join("datasets", "original", f"{task_type}_train.jsonl")
        num_samples = int(os.environ["NUMBER_OF_SAMPLES"])
        with open(dataset_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
        if len(lines) < num_samples:
            raise ValueError(
                f"Not enough examples in {dataset_path}: "
                f"requested {num_samples}, but only found {len(lines)}"
            )
        samples = [json.loads(line) for line in random.sample(lines, num_samples)]
        sentences = [
            sample["messages"][1]["content"].split(GridAccuracy._OBSERVATION_SEP)[-1]
            for sample in samples
        ]
        gold_answers = [sample["messages"][-1]["content"] for sample in samples]
        return task_type, sentences, gold_answers

    @classmethod
    def _score_sampled_reply(cls, ds, task_type, final_answer, snt, gold) -> float:
        """Score one frozen-model reply on a sampled training example.

        +1 for a well-formed reply. Then, for non sum/sim datasets: +2 if a "gen"
        reply passes all test cases, else +2 for a case-insensitive exact match (or
        ROUGE/SARI when the sampled task is sum/sim). For sum/sim datasets the
        ROUGE/SARI score is added.
        """
        reward = 0.0
        if cls._is_well_formed(ds, final_answer, strict=False):
            reward += 1
        if ds not in ("sum", "sim"):
            # The exact-match branches used to be nested under `ds == "gen"`, so
            # classification datasets never received accuracy credit.
            if ds == "gen" and cls._looks_like_code(final_answer) and _passes_all_cases(final_answer, gold):
                reward += 2
            elif task_type not in ("sum", "sim"):
                if final_answer.lower() == gold.lower():
                    reward += 2
            elif task_type == "sum":
                reward += cls._rouge_reward(final_answer, gold)
            else:  # task_type == "sim"
                reward += cls._sari_reward(snt, final_answer, gold)
        elif ds == "sum":
            reward += cls._rouge_reward(final_answer, gold)
        else:  # ds == "sim"
            reward += cls._sari_reward(snt, final_answer, gold)
        return reward

    @classmethod
    def _score_observation_reply(cls, ds, final_answer, sentence, gold_ans) -> float:
        """Score the frozen-model reply on the current observation.

        +1 for a well-formed reply (an unknown DATASET raises). Then, for non sum/sim
        datasets: "gen" replies that look like code get +2 if they pass all test
        cases; otherwise a case-insensitive exact match adds +2 for "gen" and +1 for
        other datasets. For "sum"/"sim" the ROUGE/SARI score is added.
        """
        reward = 0.0
        if cls._is_well_formed(ds, final_answer, strict=True):
            reward += 1
        if ds not in ("sum", "sim"):
            # Previously this exact-match branch was nested under `ds == "gen"`, so the
            # `+1` for non-gen datasets was unreachable.
            if ds == "gen" and cls._looks_like_code(final_answer):
                if _passes_all_cases(final_answer, gold_ans):
                    reward += 2
            elif final_answer.lower() == gold_ans.lower():
                reward += 2 if ds == "gen" else 1
        elif ds == "sum":
            reward += cls._rouge_reward(final_answer, gold_ans)
        else:  # ds == "sim"
            reward += cls._sari_reward(sentence, final_answer, gold_ans)
        return reward

    # -------------------------------------------------------------- entry point

    def __call__(self, completions: List[str], solution: List[str] = None, observations=None, **kwargs) -> List[float]:
        """
        completions: model outputs from your *trainable* policy model
        solution: only zipped with ``completions`` to bound the loop; the gold answers
            actually used come from ``observations`` or from sampled training examples
        observations: rows with "messages" (observation prompt at index 1) and "solution"

        Return:
          a list of float rewards, one per (completion, solution) pair
        """
        pred_answers = [extract_xml_answer(resp) for resp in completions]
        rewards = []
        if os.environ["LLMREGUL"] == "True" and float(os.environ["SAMPLINGPROB"]) != -100 and os.environ["USE_BOTH_REG"] == "False":
            raise Exception("You have to choose one regularzation between LLMREGUL and SAMPLINGPROB")

        both_regularizations = os.environ['USE_BOTH_REG'] == "True"
        # One draw per call: either all completions in this batch are scored on sampled
        # training examples or none are (unless PSP == "False" forces it below).
        sampled_prob = random.random()
        regularization = sampled_prob < float(os.environ["SAMPLINGPROB"]) and (os.environ["LLMREGUL"] == 'False' or both_regularizations)

        for pred_ans, _ in tqdm(zip(pred_answers, solution or [])):
            reward_per_prompt = 0.0
            # NOTE(review): every completion is scored against observations[0]; this is
            # only correct if the batch holds completions of a single observation -- confirm.
            observation_prompt = observations[0]["messages"][1]["content"]
            task_type = self._detect_task_type(observation_prompt)

            if os.environ["PSP"] == "False":
                regularization = True
            # Always derived, so the validator below has it even when PSP == "False".
            sentence = observation_prompt.split(self._OBSERVATION_SEP)[-1]

            if regularization:
                task_type, sentences, gold_answers = self._sample_training_examples(task_type)
                new_requests = [self._user_request(f"{pred_ans} \n {snt}") for snt in sentences]
                with torch.no_grad():
                    outputs = self.frozen_model.infer(new_requests, self.request_config)
                ds = os.environ["DATASET"]
                sample_rewards = [
                    self._score_sampled_reply(ds, task_type, self._reply_text(out), snt, gold)
                    for out, snt, gold in zip(outputs, sentences, gold_answers)
                ]
                reward_per_prompt = float(sum(sample_rewards) / max(len(sample_rewards), 1))
                rewards.append(reward_per_prompt)
                if not both_regularizations:
                    continue

            if os.environ["LLMREGUL"] == "True" or both_regularizations:

                prompt_for_validator = """You will receive a single text input: a “candidate prompt” that is intended for another LLM to solve a specific problem. Your job is to decide whether this candidate prompt already contains the solution (i.e., it reveals the answer) to that problem.

                If the candidate prompt does include already the solution, reply with exactly:
                1

                If the candidate prompt does not include the solution (it only contains instructions, context, examples, or background without giving the answer), reply with exactly:
                0

                Return only '1' or '0' without any additional text.

                Examples:
                Example 1: No solution embedded
                Candidate prompt:
                “Summarize the following paragraph in one sentence:
                ‘Artificial intelligence (AI) refers to machines that can perform tasks
                typically requiring human intelligence, such as recognizing speech,
                translating languages, and making decisions.’
                Provide a concise summary without revealing the original paragraph verbatim.”
                → 0

                Example 2: Solution embedded
                Candidate prompt:
                “Translate the English sentence ‘The cat sits on the mat’ into French. 
                The answer is ‘Le chat est assis sur le tapis’. 
                After that, explain each word’s role in the sentence.”
                → 1

                # The prompt to evaluate: \n {}
                """

                validator_request = self._user_request(prompt_for_validator.format(f"{pred_ans} \n {sentence}"))
                with torch.no_grad():
                    outputs = self.frozen_model.infer([validator_request], self.request_config)
                verdict = self._reply_text(outputs[0]) if outputs else ""

                if verdict == "1":
                    # The prompt leaks the answer: subtract 1. With regularization active we
                    # only reach this point when both regularizers are enabled, and this
                    # completion's sampled reward is the one just appended.
                    if regularization:
                        rewards[-1] -= 1
                    else:
                        reward_per_prompt -= 1

            if regularization:
                # Both regularizers on: the (possibly penalised) sampled reward is final.
                continue

            gold_ans = observations[0]["solution"]
            with torch.no_grad():
                outputs = self.frozen_model.infer(
                    [self._user_request(f"{pred_ans} \n {sentence}")], self.request_config
                )
            final_answer = self._reply_text(outputs[0]) if outputs else ""
            reward_per_prompt += self._score_observation_reply(
                os.environ["DATASET"], final_answer, sentence, gold_ans
            )
            rewards.append(reward_per_prompt)

        return rewards

# Registry of reward functions exposed to the trainer, keyed by reward name.
orms = dict(
    accuracy=GridAccuracy,
    format=GridFormat,
)
# Previous registry, kept for reference:
# orms = {
#     'toolbench': ReactORM,
#     'math': MathORM,
#     'accuracy': MathAccuracy,
#     'format': Format,
#     'react_format': ReActFormat,
#     'cosine': CosineReward,
#     'repetition': RepetitionPenalty,
# }
