# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import openai
import asyncio
import multiprocessing
from concurrent.futures import ThreadPoolExecutor, as_completed
from concurrent.futures import ProcessPoolExecutor
import concurrent.futures
import requests
from tqdm import tqdm
from functools import partial
from typing import Any, List, Dict, Tuple, Optional
from tqdm.asyncio import tqdm_asyncio
from transformers import PreTrainedTokenizer
import numpy as np
import pandas as pd
from tqdm.asyncio import tqdm_asyncio
import ast
import re
import json
import random
import torch
import difflib
import fcntl  # 导入文件锁模块

from verl import DataProto
from verl.utils.reward_score import _default_compute_score
from verl.utils.reward_score.our_code.sandbox_eval.sandbox_utils import RunCodeRequest, SubmitRequest, TestConfig, submit_to_sandbox


# Reward scheme used by ErrorLocRewardManager:
#   Exec reward:
#       no <answer> block / no submissions built:  REWARD_NON_ANSWER (-2)
#       code does not parse ("compilation error"): REWARD_CE (-2)
#       test-case validation:                      5 * pass rate, in [0, 5]
#   Format reward:
#       -1 or +1 (see ErrorLocRewardManager.calculate_format_reward)
REWARD_CE = -2
REWARD_NON_ANSWER = -2

# NOTE(review): RUN_TIMEOUT and MAX_REQUESTS are not referenced in the code
# shown here; kept for importers elsewhere.
RUN_TIMEOUT = 10
MAX_REQUESTS = 64
# Default thread-pool size for concurrent error-localization API requests.
LOC_MAX_REQUESTS = 256

def normalize_code(code_str):
    """Canonicalize Python source by renaming assigned variables to ``v_0``, ``v_1``, ...

    Names are numbered in the order in which their first *store* (assignment
    target, loop variable, ...) is visited in the AST. Every occurrence visited
    after that store is renamed consistently. Occurrences visited before it keep
    the original name. Names that are never stored are left untouched, as are
    function parameters and function/class names (they are not ``ast.Name``
    nodes). The result is re-rendered with ``ast.unparse``, so comments and the
    original formatting are dropped.

    Args:
        code_str: Python source text.

    Returns:
        The normalized source, or ``code_str`` unchanged if it cannot be parsed.
    """
    try:
        tree = ast.parse(code_str)
    except (SyntaxError, ValueError):
        # SyntaxError: invalid code. ValueError: NUL bytes on Python < 3.12
        # (3.12+ reports those as SyntaxError as well).
        return code_str

    # Original name -> canonical name; its size doubles as the next counter value.
    var_map = {}

    class VariableRenamer(ast.NodeTransformer):
        def visit_Name(self, node):
            if isinstance(node.ctx, ast.Store) and node.id not in var_map:
                var_map[node.id] = f"v_{len(var_map)}"
            node.id = var_map.get(node.id, node.id)
            return node

    return ast.unparse(VariableRenamer().visit(tree))



class ErrorLocRewardManager:
    """
    """

    def __init__(
        self,
        config,
        file,  # train_file or test_file
        tokenizer: PreTrainedTokenizer,
        num_examine=0,
        run_all_cases=True,
    ):
        """Load the task table(s) and create the error-localization API client.

        Args:
            config: Trainer config. ``reward_model.error_local_model_url`` is read
                here; other methods also read ``reward_model.url`` and
                ``reward_model.error_local_model``.
            file: A parquet path, or an iterable of parquet paths whose tables are
                concatenated. The table(s) must have a ``task_id`` column.
            tokenizer: Tokenizer used to decode prompts and responses.
            num_examine: Number of samples per call whose details are printed.
            run_all_cases: If True, every sampled unit test becomes a submission;
                otherwise only the first one is used.
        """
        self.config = config
        self.file = file
        self.tokenizer = tokenizer
        self.num_examine = num_examine
        self.run_all_cases = run_all_cases

        reward_cfg = self.config.reward_model
        self.client = openai.Client(base_url=reward_cfg.error_local_model_url, api_key="None")

        if isinstance(file, str):
            self.dataset = pd.read_parquet(file)
        else:
            self.dataset = pd.concat([pd.read_parquet(path) for path in file])

        # Index rows by task id through a helper column, so that ``task_id`` also
        # stays inside each row dict after set_index.
        table = self.dataset
        table["proxy_id"] = table["task_id"]
        self.id_to_infos = table.set_index("proxy_id").to_dict(orient="index")

        # NOTE(review): the reward cache and this 32-worker pool are not
        # referenced elsewhere in the class as shown.
        self.code_to_reward = {}
        self.executor = ThreadPoolExecutor(max_workers=32)
    
    def extract_question(self, prompt_str: str) -> Tuple[Optional[str], str]:
       
        # Split prompt_str to isolate question string
        if "<|im_start|>user" in prompt_str:
            system_prompt = prompt_str.split("<|im_start|>user", 1)[0]
            question_str = prompt_str.split("<|im_start|>user", 1)[1]
        elif "<|start_header_id|>user<|end_header_id|>" in prompt_str:
            system_prompt = prompt_str.split("<|start_header_id|>user<|end_header_id|>", 1)[0]
            question_str = prompt_str.split("<|start_header_id|>user<|end_header_id|>", 1)[1]
        else:
            question_str = prompt_str
        return question_str
    
    def desanitize(self, text: str) -> str:
        # Pattern to match code blocks starting with ```, with optional language identifier
        # and capturing everything after until the end or until another ```
        pattern = r"```(?:python)?\s*([\s\S]*?)(?:\s*```|$)"
        # Find all matches in the text
        matches = re.findall(pattern, text, re.IGNORECASE)

        # Return the first one
        return f"```python\n{matches[0]}\n```" if matches and len(matches[0]) > 0 else text
    
    def sanitize(self, text: str) -> str:
        # Remove the starting and ending ```
        pattern = r"```(?:python)?\s*([\s\S]*?)\s*```"
        match = re.search(pattern, text, re.IGNORECASE)

        if match:
            return match.group(1).strip()
        return text
    
    def get_samples(self, ids: List[str]) -> List[pd.Series]:
        
        samples = []
        for task_id in ids:
            sample = self.id_to_infos[task_id]
            samples.append(sample)
        
        print("*" * 40)
        print(f"Len of samples:  {len(samples)}")
        print("*" * 40)
        
        for i in range(len(samples)):
            sample = samples[i]
            if i < self.num_examine:
                print("*" * 40)
                print("[TESTS: ]", sample["selected_uts"])
                print("*" * 40)
        
        return samples
        
    def check_ce(self, code_str: str) -> bool:
        
        if not isinstance(code_str, str):
            return True
        try:
            ast.parse(code_str)
            return False
        except:
            return True
        
    
    def parse_response(self, data: DataProto) -> DataProto:
        task_ids = []
        questions = []
        responses = []
        samples = []
        valid_response_lengths = []
        valid_response_idss = []

        for i in range(len(data)):
            data_item = data[i]
            
            prompt_ids = data_item.batch["prompts"]
            prompt_length = prompt_ids.shape[-1]
            valid_prompt_length = data_item.batch['attention_mask'][:prompt_length].sum()
            valid_prompt_ids = prompt_ids[-valid_prompt_length:]
            
            response_ids = data_item.batch["responses"]
            valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum()
            valid_response_ids = response_ids[:valid_response_length]

            # decode
            prompt_str = self.tokenizer.decode(valid_prompt_ids)
            response_str = self.tokenizer.decode(valid_response_ids)
            
            # # remove <eos> token
            # response_str = response_str.replace(self.tokenizer.eos_token, "")
            
            question = self.extract_question(prompt_str)
            task_id = data_item.non_tensor_batch["task_id"]

            
            task_ids.append(task_id)
            questions.append(question)
            responses.append(response_str)
            samples.append(data_item.non_tensor_batch)
            valid_response_lengths.append(valid_response_length)
            valid_response_idss.append(valid_response_ids)
            
            
        return task_ids, questions, responses, samples, valid_response_lengths, valid_response_idss
        
        
    def build_message(self, question, response):
        
        # generate revision
        assert isinstance(question, str)
        assert isinstance(response, str)
        
        response_code_block = self.sanitize(response)
        
        
        prompt = f"""I have an algorithm competition problem and the corresponding python solution, but I have verified that the solution is wrong, and I would like you to point out on which line the error occurs.

### Problem
{question}

### Python Solution
{response_code_block}

Please note that your formatting must adhere to the following rules.
1. Select the only line of code that you think is most likely to cause an error in your code.
2. You need to give the line number of the error line, enclosed in $$. 
3. You need to give the contents of the wrong line of code, enclosed in a python code block.
4. You don't need to give anything to understand, analyze, or correct the error.

### Example Output
```python
Error Code Line
```
""" 
        
        message = [{
            "role": "user",
            "content": prompt
        }]
        
        return message
    

    
    
    
    def calculate_format_reward(self, processed_str: str, format_reward: int = 1):
        """Score the structural format of a model response.

        The response passes only if all of these hold:
          * ``<answer>`` and ``</answer>`` each occur exactly once, in that order;
          * the stripped text starts with ``<think>`` and ends with ``</answer>``
            immediately followed by one of ``<|endoftext|>``, ``<|im_end|>``,
            ``<|eot_id|>`` or ``<|end_of_text|>``;
          * ``<think>``/``</think>`` and ``<code>``/``</code>`` occur equally often;
          * the text between ``<answer>`` and ``</answer>`` contains a ``` fence.

        Args:
            processed_str: Decoded response string from the model.
            format_reward: Magnitude of the reward.

        Returns:
            ``(score, debug_lines)``: ``score`` is ``format_reward`` if every check
            passes and ``-abs(format_reward)`` otherwise; ``debug_lines`` is a list
            of human-readable diagnostics.
        """
        debug_str = ["\n[Structure Validation]"]
        validation_passed = True
        stripped = processed_str.strip()

        # <answer> and </answer> must each appear exactly once.
        positions = {}
        for tag_name, tag_str in (("answer", "<answer>"), ("answer_end", "</answer>")):
            count = processed_str.count(tag_str)
            positions[tag_name] = pos = processed_str.find(tag_str)
            debug_str.append(f"  {tag_str}: count={count}, position={pos}")
            if count != 1:
                debug_str.append(f"  [Error] {tag_str} appears {count} times (expected 1)")
                validation_passed = False

        # Overall layout: <think> first, <answer> before </answer>, and </answer>
        # right before a closing special token. Only the first violated rule of
        # this chain is reported.
        valid_endings = (
            "</answer><|endoftext|>",
            "</answer><|im_end|>",
            "</answer><|eot_id|>",
            "</answer><|end_of_text|>",
        )
        if not stripped.startswith("<think>"):
            debug_str.append("  [Error] Incorrect start: Expected <think> at beginning")
            validation_passed = False
        elif positions["answer"] > positions["answer_end"]:
            debug_str.append("  [Error] Incorrect tag order: Expected <answer>...</answer>")
            validation_passed = False
        elif not stripped.endswith(valid_endings):
            debug_str.append("  [Error] Incorrect ending: Expected </answer><|endoftext|> or </answer><|im_end|> or </answer><|eot_id|> or </answer><|end_of_text|>")
            validation_passed = False

        # Paired tags must balance. str.count is exact here: none of these tags can
        # overlap with another occurrence of itself.
        think_count = processed_str.count("<think>")
        think_end_count = processed_str.count("</think>")
        code_count = processed_str.count("<code>")
        code_end_count = processed_str.count("</code>")

        debug_str.append(f"  <think> count: {think_count}")
        debug_str.append(f"  </think> count: {think_end_count}")
        debug_str.append(f"  <code> count: {code_count}")
        debug_str.append(f"  </code> count: {code_end_count}")

        if think_end_count != think_count:
            debug_str.append(f"  [Error] Number of </think> ({think_end_count}) does not match <think> ({think_count})")
            validation_passed = False

        if code_end_count != code_count:
            debug_str.append(f"  [Error] Number of </code> ({code_end_count}) does not match <code> ({code_count})")
            validation_passed = False

        # The answer must contain a fenced code block (any ``` fence qualifies,
        # which includes ```python).
        answer_start, answer_end = positions["answer"], positions["answer_end"]
        if answer_start != -1 and answer_end != -1:
            answer_content = processed_str[answer_start + len("<answer>"):answer_end]
            if "```" not in answer_content:
                debug_str.append("  [Error] Python code block ' ```python ``` ' not found in <answer>...</answer>")
                validation_passed = False
            else:
                debug_str.append("  [Success] Python code block ' ```python ``` ' found in <answer>...</answer>")
        else:
            debug_str.append("  [Error] <answer> or </answer> not found, skipping Python code block check")
            validation_passed = False

        format_score = format_reward if validation_passed else -abs(format_reward)
        return format_score, debug_str
    


    ############################  oj reward   ############################
    def get_reward_all_oj(self, responses: list, samples: list, global_step: int = -1, batch_size: int = 16, max_parallel_threads: int = 16) -> Tuple[list, list, list]:
        """Score responses by sandbox execution plus a structural format check.

        Responses are split into chunks of ``batch_size``, and each chunk is scored
        by ``self.process_batch`` on a pool of ``max_parallel_threads`` threads. A
        chunk whose scoring raises keeps ``REWARD_NON_ANSWER`` for all its items.
        The format reward comes from ``self.calculate_format_reward``. The total is
        the sum of both. Details of the first ``self.num_examine`` samples are
        printed.

        Args:
            responses: Decoded model responses.
            samples: Per-response sample dicts, parallel to *responses*. The printed
                examples read their ``task_id`` and ``problem`` keys.
            global_step: Training step; only used in the printed examples.
            batch_size: Number of responses per ``process_batch`` call.
            max_parallel_threads: Thread-pool size for the execution reward.

        Returns:
            ``(total_rewards, exec_rewards, format_rewards)``: three lists parallel
            to *responses*.
        """
        if not responses:
            # Nothing to score (and the format-reward unpacking needs >= 1 item).
            return [], [], []

        exec_rewards = [REWARD_NON_ANSWER] * len(responses)

        with ThreadPoolExecutor(max_workers=max_parallel_threads) as executor:
            future_to_index = {}
            for start_index in range(0, len(responses), batch_size):
                end_index = min(start_index + batch_size, len(responses))
                future = executor.submit(
                    self.process_batch,
                    responses[start_index:end_index],
                    samples[start_index:end_index],
                )
                future_to_index[future] = (start_index, end_index)

            for future in as_completed(future_to_index):
                start_index, end_index = future_to_index[future]
                try:
                    exec_rewards[start_index:end_index] = future.result()
                except Exception as exc:
                    # Best effort: this chunk keeps the default reward.
                    print(f"Generated an exception: {exc}")

        # The format check is GIL-bound string work; a thread pool only adds
        # overhead, so run it inline.
        format_results = [self.calculate_format_reward(response) for response in responses]
        format_rewards = [score for score, _ in format_results]
        debug_infos = [info for _, info in format_results]

        total_rewards = [
            exec_reward + format_reward
            for exec_reward, format_reward in zip(exec_rewards, format_rewards)
        ]

        for i in range(len(samples)):
            if i >= self.num_examine:
                break
            print("*" * 40)
            print("[GLOBAL_STEP: ]", global_step)
            print("[TASK_ID: ]", samples[i]["task_id"])
            print("[PROMPT: ]\n", samples[i]["problem"])
            print("[RESPONSE: ]\n", responses[i])
            print("-" * 20)
            print("\n".join(debug_infos[i]))
            print(f"  Format: {format_rewards[i]}")
            print(f"  Answer: {exec_rewards[i]}")
            print(f"  Total: {total_rewards[i]}")
            print("-" * 20)
            print("*" * 40)

        return total_rewards, exec_rewards, format_rewards
                        
    def process_batch(self, responses_batch: list, samples_batch: list) -> list:
        """Compute the execution reward of every response in one chunk.

        For each response, the code in its last ``<answer>...</answer>`` block is
        extracted (fences stripped via ``self.sanitize``, NUL bytes removed), then:
          * no answer block, or no submissions built -> ``REWARD_NON_ANSWER``
          * code that fails to parse                 -> ``REWARD_CE``
          * otherwise                                -> ``5 * mean(per-test success)``
        All submissions of the chunk go to the sandbox in one ``submit_batch`` call.

        Returns:
            One reward per response, aligned with *responses_batch*.
        """
        answer_pattern = re.compile(r'<answer>(.*?)</answer>', re.DOTALL)
        # Rewards are written by index: items that exit early and items scored
        # after the sandbox call must both stay aligned with responses_batch.
        batch_rewards = [REWARD_NON_ANSWER] * len(responses_batch)
        all_submissions = []
        pending = []  # (index into the chunk, number of submissions for it)

        for index, (response, sample) in enumerate(zip(responses_batch, samples_batch)):
            answers = answer_pattern.findall(response)
            if not answers:
                continue

            code_block = self.sanitize(answers[-1].strip()).strip().replace('\x00', '')

            if self.check_ce(code_block):
                batch_rewards[index] = REWARD_CE
                continue

            submissions = self.build_oj_submissions(code_block, sample)
            if not submissions:
                continue

            all_submissions.extend(submissions)
            pending.append((index, len(submissions)))

        if all_submissions:
            response_results = self.submit_batch(all_submissions)
            pos = 0
            for index, num_submissions in pending:
                result_slice = response_results[pos:pos + num_submissions]
                pos += num_submissions
                batch_rewards[index] = 5 * np.mean(result_slice)

        return batch_rewards

    def build_oj_submissions(self, code_str: str, sample: dict) -> list:
        sample_size = 10
        submissions = []

        def build_submission(code_str, output_str=None, input_str=None) -> dict:
            submission = {
                "type": "python",
                "solution": code_str,
            }
            if output_str is not None:
                submission["expected_output"] = output_str
            if input_str is not None:
                submission["input"] = input_str
            return submission

        if sample["prompter_type"] == "mbppplus":
            uts = sample["selected_uts"].tolist()
            if len(uts) > sample_size:
                uts = random.sample(uts, sample_size)
            
            if self.run_all_cases:
                for ut in uts:
                    code_str = code_str + "\n" + ut
                    submissions.append(build_submission(code_str))
            else:
                code_str = code_str + "\n" + uts[0]
                submissions.append(build_submission(code_str))
                
        elif sample["prompter_type"] == "apps_fn":
            # 确保 sample["selected_uts"] 是字符串类型
            # if isinstance(sample["selected_uts"], np.ndarray):
            #     # 将 ndarray 转换为列表并序列化为 JSON 字符串
            #     uts = sample["selected_uts"].tolist()
            # else:
            uts = json.loads(sample["selected_uts"])
            
            if len(uts) > sample_size:
                uts = random.sample(uts, sample_size)
        
            if self.run_all_cases:
                for ut in uts:
                    code_str = code_str + "\n" + ut
                    submissions.append(build_submission(code_str))
            else:
                code_str = code_str + "\n" + uts[0]
                submissions.append(build_submission(code_str))
                
        elif sample["prompter_type"] == "code_contests" or sample["prompter_type"] == "codeforces" or sample["prompter_type"] == "apps":
            # 确保 sample["selected_uts"] 是字符串类型
            if isinstance(sample["selected_uts"], np.ndarray):
                # 将 ndarray 转换为列表并序列化为 JSON 字符串
                uts = sample["selected_uts"].tolist()
            else:
                uts = json.loads(sample["selected_uts"])
            
            if len(uts) > sample_size:
                uts = random.sample(uts, sample_size)
        
            if self.run_all_cases:
                for ut in uts:
                    submissions.append(build_submission(code_str, ut['output']['stdout'], ut['input']['stdin']))
            else:
                submissions.append(build_submission(code_str, uts[0]['output']['stdout'], uts[0]['input']['stdin']))
        
        elif sample["prompter_type"] == "livecodebench":
            """
            sample = {
                "selected_uts": {
                    "input_output": {
                        "inputs": List,
                        "outputs: List,
                        "fn_name": str or None,
                    }
                }
            }
            """
            selected_uts = json.loads(sample["selected_uts"])
            input_output = json.loads(selected_uts["input_output"])
            
            assert len(input_output["inputs"]) == len(input_output["outputs"])
            uts = list(zip(input_output["inputs"], input_output["outputs"]))
            
            if len(uts) > sample_size:
                uts = random.sample(uts, sample_size)
            
            fn_name = input_output["fn_name"]
            if fn_name is not None:
                def create_function_call_str(func_name, args_list):
                    args_str = ", ".join(repr(arg) for arg in args_list)
                    return f"{func_name}({args_str})"
                
                if self.run_all_cases:
                    for stdin, stdout in uts:
                        suffix = f"solution = Solution()\nassert {create_function_call_str(fn_name, stdin)} == {repr(stdout)}"
                        submissions.append(build_submission(code_str + "\n" + suffix))
                else:
                    stdin, stdout = uts[0]
                    suffix = f"solution = Solution()\nassert {create_function_call_str(fn_name, stdin)} == {repr(stdout)}"
                    submissions.append(build_submission(code_str + "\n" + suffix))
            else:
                if self.run_all_cases:
                    for stdin, stdout in uts:
                        submissions.append(build_submission(code_str, stdout, stdin))
                else:
                    stdin, stdout = uts[0]
                    submissions.append(build_submission(code_str, stdout, stdin))

        else:
                raise RuntimeError("Invalid prompter_type!")

        return submissions            
                    
        
        
        # uts = sample.get("selected_uts", [])
        
        # #TODO: ndarray to list
        # if sample["prompter_type"] == "mbppplus":
        #     uts = uts.tolist()

        # if not isinstance(uts, list):
        #     try:
        #         uts = ast.literal_eval(uts)
        #     except (ValueError, SyntaxError):
        #         uts = []
        # if not uts:
        #     return []
        # sample_size = 10
        # if len(uts) > sample_size:
        #     uts = random.sample(uts, sample_size)

        # if sample["prompter_type"] == "mbppplus":
        #     if self.run_all_cases:
        #         for ut in uts:
        #             code_str = code_str + "\n" + ut
        #             submissions.append(build_submission(code_str))
        #     else:
        #         code_str = code_str + "\n" + uts[0]
        #         submissions.append(build_submission(code_str))
        # elif sample["prompter_type"] == "code_contests":
        #     if self.run_all_cases:
        #         for ut in uts:
        #             submissions.append(build_submission(code_str, ut['output']['stdout'], ut['input']['stdin']))
        #     else:
        #         submissions.append(build_submission(code_str, uts[0]['output']['stdout'], uts[0]['input']['stdin']))
        # else:
        #     raise RuntimeError("Invalid prompter_type!")

        # return submissions


    def submit_batch(self, submissions: list, timeout: Optional[float] = None,
                     failure_dump_path: Optional[str] = None) -> list:
        """Send submissions to the sandbox service as a single batch request.

        Posts ``{"type": "batch", "submissions": submissions}`` to
        ``self.config.reward_model.url`` and reads ``results[i]["success"]``.

        Args:
            submissions: Submission dicts as built by ``build_oj_submissions``.
            timeout: Seconds forwarded to ``requests.post``. ``None`` (the default)
                waits indefinitely, which can stall a worker thread if the sandbox
                hangs; callers should pass a bound.
            failure_dump_path: File to which the payload is appended when the HTTP
                request fails. Defaults to the previously hard-coded location.

        Returns:
            One success flag per submission. On any request or response-format
            failure, every flag is ``False``.
        """
        if failure_dump_path is None:
            # NOTE(review): machine-specific default kept for compatibility;
            # consider moving it into the config.
            failure_dump_path = "/home/superbench/xinzhang3/haoling/epicoder2/submissions.json"

        data = {
            "type": "batch",
            "submissions": submissions
        }

        def write_data_to_json(file_path, payload):
            """Append *payload* as indented JSON to *file_path* under an exclusive flock."""
            try:
                with open(file_path, 'a') as f:
                    fcntl.flock(f.fileno(), fcntl.LOCK_EX)
                    try:
                        json.dump(payload, f, indent=4)
                        f.write('\n')  # separate consecutive records
                    finally:
                        fcntl.flock(f.fileno(), fcntl.LOCK_UN)
            except OSError as e:
                print(f"Failed to write to file: {e}")

        try:
            response = requests.post(self.config.reward_model.url, json=data, timeout=timeout)
            response.raise_for_status()

            results = response.json()['results']
            success_list = [res['success'] for res in results]
            if len(success_list) != len(submissions):
                # Explicit check: an `assert` would be stripped under `python -O`.
                raise ValueError(
                    f"expected {len(submissions)} results, got {len(success_list)}"
                )
            return success_list
        except requests.exceptions.RequestException as e:
            write_data_to_json(failure_dump_path, data)
            print(f"Request failed: {e}")
            return [False] * len(submissions)
        except (ValueError, KeyError, TypeError) as e:
            # Malformed body, e.g. missing keys or unexpected shapes.
            print(f"Failed to process response: {e}")
            return [False] * len(submissions)


    
    
    
    
    
    
    

    ############################  error localization  ############################
    def get_error_localization(self, questions: List[str], responses: List[str], valid_response_idss: List[Any],
                        max_parallel_threads: int = LOC_MAX_REQUESTS) -> List[Any]:
        """Ask the localization model which line of each response's code is wrong.

        For each response, the code in its last ``<answer>...</answer>`` block is
        extracted. Responses without an answer block, or whose code is empty or
        does not parse, get ``(None, None)``. Otherwise the model named by
        ``config.reward_model.error_local_model`` is queried (up to 3 attempts).
        The reported line must be non-blank and appear verbatim in the extracted
        code, else the result is ``(None, None)``.

        NOTE(review): the returned indices are token offsets inside the
        *re-tokenized extracted code*, not inside the full response, and
        ``valid_response_idss`` is accepted but unused. If callers index response
        tokens with these values, the code block's offset within the response
        must be added (and re-tokenizing a prefix may not match the original
        tokenization exactly). Confirm against the caller.

        Returns:
            One ``(error_start_idx, error_end_idx)`` pair per response, in input
            order.
        """
        answer_pattern = re.compile(r'<answer>(.*?)</answer>', re.DOTALL)

        def preprocess(response: str) -> Optional[str]:
            """Return the code to localize errors in, or None if there is none."""
            answers = answer_pattern.findall(response)
            if not answers:
                return None
            code_block = self.sanitize(answers[-1].strip()).strip().replace('\x00', '')
            if self.check_ce(code_block):
                return None
            return code_block

        def request_api(question, code, max_retries=3):
            """Return the faulty line reported by the model, or None after all retries."""
            message = self.build_message(question, code)

            for attempt in range(max_retries):
                try:
                    api_response = self.client.chat.completions.create(
                        model=self.config.reward_model.error_local_model,
                        messages=message,
                        temperature=0.5,
                        max_tokens=128,
                    )
                except Exception as exc:
                    # Localization is best effort: log, retry, and fall back to
                    # None instead of aborting the whole batch.
                    print(f"Error-localization request failed (attempt {attempt + 1}/{max_retries}): {exc}")
                    continue

                if not api_response.choices:
                    continue
                raw_answer = api_response.choices[0].message.content
                if not raw_answer:
                    continue
                error_code = self.sanitize(raw_answer)
                # A blank line is trivially "contained" in any code; reject it.
                if error_code.strip() and error_code in code:
                    return error_code
            return None

        def localization(question, response, valid_response_ids):
            code = preprocess(response)
            if code is None or code.strip() == "":
                return None, None

            error_code = request_api(question, code)
            if error_code is None:
                return None, None

            error_start = code.find(error_code)
            tokens_before_error = self.tokenizer.encode(code[:error_start], add_special_tokens=False)
            tokens_in_error = self.tokenizer.encode(error_code, add_special_tokens=False)

            error_start_idx = len(tokens_before_error)
            error_end_idx = error_start_idx + len(tokens_in_error)
            return error_start_idx, error_end_idx

        with ThreadPoolExecutor(max_workers=max_parallel_threads) as executor:
            futures = [
                executor.submit(localization, question, response, valid_response_ids)
                for question, response, valid_response_ids in zip(questions, responses, valid_response_idss)
            ]
            # Futures are kept in submission order, so results are already ordered.
            error_indices = [future.result() for future in futures]

        return error_indices