# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import openai
import asyncio
import multiprocessing
from concurrent.futures import ThreadPoolExecutor, as_completed
from concurrent.futures import ProcessPoolExecutor
import concurrent.futures
import requests
from tqdm import tqdm
from functools import partial
from typing import Any, List, Dict, Tuple, Optional
from tqdm.asyncio import tqdm_asyncio
from transformers import PreTrainedTokenizer
import numpy as np
import pandas as pd
from tqdm.asyncio import tqdm_asyncio
import ast
import re
import random
import torch
import difflib

from verl import DataProto


# exec reward:
"""
Exec reward:
    no answer: -5
    compilation error: -2
    test case validation: 0-5
Format reward:
    -1, 1
"""
# Execution reward returned when the extracted code fails to parse.
REWARD_CE = -2
# Execution reward returned when the response contains no <answer>...</answer> block.
# NOTE(review): the summary string above lists "no answer: -5" while this
# constant is -2 — confirm which value is intended.
REWARD_NON_ANSWER = -2




# NOTE(review): not referenced in the visible part of this file; presumably a
# per-run timeout in seconds — confirm against the sandbox client.
RUN_TIMEOUT = 10
# Concurrency cap for execution-reward requests (semaphore size and default
# thread count of the reward methods).
MAX_REQUESTS = 64
# Default thread count for error-localization API requests.
LOC_MAX_REQUESTS = 256

def normalize_code(code_str):
    """Return ``code_str`` with assigned variable names canonicalised.

    The source is parsed into an AST, every name that appears as an
    assignment target is mapped to ``v_0``, ``v_1``, ... in the order the
    traversal first meets it in a store context, and the tree is unparsed
    again (which also normalises formatting).  A name is only rewritten once
    its first assignment has been visited; ``ast.arg`` parameters, attributes
    and names that are never assigned are left untouched.

    Args:
        code_str: Python source code.

    Returns:
        The normalised source, or ``code_str`` unchanged when it cannot be
        parsed.
    """
    try:
        tree = ast.parse(code_str)
    except (SyntaxError, ValueError):
        # ValueError: interpreters before 3.12 reject NUL bytes in the source
        # with ValueError instead of SyntaxError.
        return code_str

    # Original identifier -> canonical identifier, in order of first store.
    var_map = {}

    class VariableRenamer(ast.NodeTransformer):
        """Rewrite ``Name`` nodes according to ``var_map``."""

        def visit_Name(self, node):
            if isinstance(node.ctx, ast.Store) and node.id not in var_map:
                var_map[node.id] = f"v_{len(var_map)}"
            return ast.Name(id=var_map.get(node.id, node.id), ctx=node.ctx)

    return ast.unparse(VariableRenamer().visit(tree))



class ErrorLocRewardManager:
    """
    """

    def __init__(
        self,
        config,
        file,  # train_file or test_file
        tokenizer: PreTrainedTokenizer,
        num_examine=0,
        run_all_cases=True,
    ):
        self.config = config
        self.file = file
        self.tokenizer = tokenizer
        self.num_examine = num_examine  # the number of batches of decoded responses to print to the console
        self.run_all_cases = run_all_cases
        self.client = openai.Client(base_url=self.config.reward_model.error_local_model_url, api_key="None")
        
        
        
        
        if isinstance(self.file, str):
            self.dataset = pd.read_parquet(self.file)
        else:
            self.dataset = pd.concat([pd.read_parquet(f) for f in self.file])

        self.dataset["proxy_id"] = self.dataset["task_id"]
        self.id_to_infos = self.dataset.set_index("proxy_id").to_dict(orient="index")

        # implement code hash so that we can reduce sandbox usage
        self.code_to_reward = {}
        self.executor = ThreadPoolExecutor(max_workers=32)  # 定义最大并发线程数
    
    def extract_question(self, prompt_str: str) -> Tuple[Optional[str], str]:
       
        # Split prompt_str to isolate question string
        if "<|im_start|>user" in prompt_str:
            system_prompt = prompt_str.split("<|im_start|>user", 1)[0]
            question_str = prompt_str.split("<|im_start|>user", 1)[1]
        else:
            question_str = prompt_str
        return question_str
    
    def desanitize(self, text: str) -> str:
        # Pattern to match code blocks starting with ```, with optional language identifier
        # and capturing everything after until the end or until another ```
        pattern = r"```(?:python)?\s*([\s\S]*?)(?:\s*```|$)"
        # Find all matches in the text
        matches = re.findall(pattern, text, re.IGNORECASE)

        # Return the first one
        return f"```python\n{matches[0]}\n```" if matches and len(matches[0]) > 0 else text
    
    def sanitize(self, text: str) -> str:
        # Remove the starting and ending ```
        pattern = r"```(?:python)?\s*([\s\S]*?)\s*```"
        match = re.search(pattern, text, re.IGNORECASE)

        if match:
            return match.group(1).strip()
        return text
    
        
    def check_ce(self, code_str: str) -> bool:
        
        if not isinstance(code_str, str):
            return True
        try:
            ast.parse(code_str)
            return False
        except:
            return True
        
    
    def parse_response(self, data: DataProto) -> DataProto:
        task_ids = []
        questions = []
        responses = []
        samples = []
        valid_response_lengths = []
        valid_response_idss = []

        for i in range(len(data)):
            data_item = data[i]
            
            prompt_ids = data_item.batch["prompts"]
            prompt_length = prompt_ids.shape[-1]
            valid_prompt_length = data_item.batch['attention_mask'][:prompt_length].sum()
            valid_prompt_ids = prompt_ids[-valid_prompt_length:]
            
            response_ids = data_item.batch["responses"]
            valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum()
            valid_response_ids = response_ids[:valid_response_length]

            # decode
            prompt_str = self.tokenizer.decode(valid_prompt_ids)
            response_str = self.tokenizer.decode(valid_response_ids)
            
            # # remove <eos> token
            # response_str = response_str.replace(self.tokenizer.eos_token, "")
            
            question = self.extract_question(prompt_str)
            task_id = data_item.non_tensor_batch["task_id"]

            
            task_ids.append(task_id)
            questions.append(question)
            responses.append(response_str)
            samples.append(data_item.non_tensor_batch)
            valid_response_lengths.append(valid_response_length)
            valid_response_idss.append(valid_response_ids)
            
            
        return task_ids, questions, responses, samples, valid_response_lengths, valid_response_idss
        
        
    def build_message(self, question, response):
        
        # generate revision
        assert isinstance(question, str)
        assert isinstance(response, str)
        
        response_code_block = self.sanitize(response)
        
        
        prompt = f"""I have an algorithm competition problem and the corresponding python solution, but I have verified that the solution is wrong, and I would like you to point out on which line the error occurs.

### Problem
{question}

### Python Solution
{response_code_block}

Please note that your formatting must adhere to the following rules.
1. Select the only line of code that you think is most likely to cause an error in your code.
2. You need to give the contents of the wrong line of code, enclosed in a python code block.
3. You don't need to give anything to understand, analyze, or correct the error.

### Example Output
```python
Error Code Line
```
""" 
        
        message = [{
            "role": "user",
            "content": prompt
        }]
        
        return message
    
    async def exec_reward_sanboxfusion(
        self, response: str, sample: dict, semaphore: asyncio.Semaphore
    ) -> float:
        
        answer_pattern = r'<answer>(.*?)</answer>'
        matches = list(re.finditer(answer_pattern, response, re.DOTALL))
        if not matches:
            return REWARD_NON_ANSWER
        answer = matches[-1].group(1).strip()

        code_block = self.sanitize(answer).strip()
        if '\x00' in code_block:
            code_block = code_block.replace('\x00', '')
        
        if self.check_ce(code_block):
            return REWARD_CE
        
        success = await self.get_sandbox_response(code_block, sample, semaphore)

        
        exec_reward = 5 * success
        return exec_reward
    
    def exec_reward_oj(
        self, response: str, sample: dict
    ) -> float:
        
        answer_pattern = r'<answer>(.*?)</answer>'
        matches = list(re.finditer(answer_pattern, response, re.DOTALL))
        if not matches:
            return REWARD_NON_ANSWER
        answer = matches[-1].group(1).strip()

        code_block = self.sanitize(answer).strip()
        if '\x00' in code_block:
            code_block = code_block.replace('\x00', '')
        
        if self.check_ce(code_block):
            return REWARD_CE

        success = self.get_oj_response(code_block, sample)
        
        exec_reward = 5 * success
        return exec_reward
    
    
    
    def calculate_format_reward(self, processed_str: str, format_reward: int = 1):
        """Performs comprehensive validation of response structure.

        Args:
            processed_str: Processed response string from the model

        Returns:
            Boolean indicating whether all formatting requirements are met
        """
        debug_str = []
        debug_str.append("\n[Structure Validation]")
        validation_passed = True

        # 检查唯一标签的位置
        unique_tags = {
            'think': '<think>',
            'think_end': '</think>',
            'answer': '<answer>',
            'answer_end': '</answer>'
        }

        positions = {}
        for tag_name, tag_str in unique_tags.items():
            count = processed_str.count(tag_str)
            positions[tag_name] = pos = processed_str.find(tag_str)
            
            debug_str.append(f"  {tag_str}: count={count}, position={pos}")
            
            if count != 1:
                debug_str.append(f"  [Error] {tag_str} appears {count} times (expected 1)")
                validation_passed = False

        # 验证基本顺序（think在开头，answer在结尾）
        if processed_str.strip()[0:len("<think>")] != "<think>":
            debug_str.append("  [Error] Incorrect start: Expected <think> at beginning")
            validation_passed = False
        elif positions['think'] > positions['think_end'] or positions['think_end'] > positions['answer'] or positions['answer'] > positions['answer_end']:
            debug_str.append("  [Error] Incorrect tag order: Expected <think>...</think><answer>...</answer>")
            validation_passed = False
        elif not (processed_str.strip().endswith("</answer><|endoftext|>") or 
                processed_str.strip().endswith("</answer><|im_end|>")):
            debug_str.append("  [Error] Incorrect ending: Expected </answer><|endoftext|> or </answer><|im_end|>")
            validation_passed = False

        # 验证step和code的数量和顺序
        step_positions = [i for i in range(len(processed_str)) if processed_str.startswith('<step>', i)]
        step_end_positions = [i for i in range(len(processed_str)) if processed_str.startswith('</step>', i)]
        code_positions = [i for i in range(len(processed_str)) if processed_str.startswith('<code>', i)]
        code_end_positions = [i for i in range(len(processed_str)) if processed_str.startswith('</code>', i)]

        # 记录多次出现的标签信息
        debug_str.append(f"  <step> count: {len(step_positions)}")
        debug_str.append(f"  </step> count: {len(step_end_positions)}")
        debug_str.append(f"  <code> count: {len(code_positions)}")
        debug_str.append(f"  </code> count: {len(code_end_positions)}")

        if len(step_positions) != len(code_positions):
            debug_str.append(f"  [Error] Number of <step> ({len(step_positions)}) does not match <code> ({len(code_positions)})")
            validation_passed = False
        else:
            for i in range(len(step_positions)):
                if step_positions[i] > code_positions[i]:
                    debug_str.append(f"  [Error] <step> at position {step_positions[i]} comes after <code> at position {code_positions[i]}")
                    validation_passed = False

        # 验证每对标签的配对
        if len(step_end_positions) != len(step_positions):
            debug_str.append(f"  [Error] Number of </step> ({len(step_end_positions)}) does not match <step> ({len(step_positions)})")
            validation_passed = False

        if len(code_end_positions) != len(code_positions):
            debug_str.append(f"  [Error] Number of </code> ({len(code_end_positions)}) does not match <code> ({len(code_positions)})")
            validation_passed = False

        format_score = format_reward if validation_passed else -abs(format_reward)

        return format_score, debug_str
    
    async def get_reward_all_sandboxfusion(
        self, responses: List[str], samples: List[Any]
    ) -> List[float]:
        """
        samples.keys() = ['task_id', 'problem', 'public_uts', 'private_uts', 'generated_uts', 'all_uts', 'prompter_type', 'selected_uts', 'problem_id', 'id', 'index', 'uid']
        """
        
        # exec_reward calculation
        semaphore = asyncio.Semaphore(MAX_REQUESTS)
        exec_rewards = await tqdm_asyncio.gather(
            *[
                self.exec_reward_sanboxfusion(response, sample, semaphore)
                for response, sample in zip(responses, samples)
            ],
            desc="Generating rewards",
            mininterval=10.0 # min print interval 10s
        )
        
        # format_reward calculation
        loop = asyncio.get_running_loop()
        format_rewards, debug_infos = await loop.run_in_executor(
            self.executor,
            lambda: zip(*[self.calculate_format_reward(response) for response in responses])
        )
        format_rewards = list(format_rewards)
        debug_infos = list(debug_infos)
        
        # total_reward calculation
        total_rewards = [
            exec_reward + format_reward
            for exec_reward, format_reward in zip(exec_rewards, format_rewards)
        ]
            
        for i in range(len(samples)):
            if i < self.num_examine:
                print("*" * 40)
                print("[TASK_ID: ]", samples[i]["task_id"])
                print("[PROMPT: ]\n", samples[i]["problem"])
                print("[RESPONSE: ]\n", responses[i])
                print("-" * 20)
                print("\n".join(debug_infos[i]))
                print(f"  Format: {format_rewards[i]}")
                print(f"  Answer: {exec_rewards[i]}")
                print(f"  Total: {total_rewards[i]}")
                print("-" * 20)
                print("*" * 40)

        return total_rewards, exec_rewards, format_rewards
    
    def get_reward_all_oj(self, responses: List[str], samples: List[Any],
                        max_parallel_threads: int = MAX_REQUESTS) -> List[float]:
        """
        samples.keys() = ['task_id', 'problem', 'public_uts', 'private_uts', 'generated_uts', 'all_uts', 'prompter_type', 'selected_uts', 'problem_id', 'id', 'index', 'uid']
        """
        
        exec_rewards = [REWARD_NON_ANSWER] * len(responses)
        with ThreadPoolExecutor(max_workers=max_parallel_threads) as executor:
            future_to_index = {executor.submit(self.exec_reward_oj, response, sample) : index
                            for index, (response, sample) in enumerate(zip(responses, samples))}

            for future in as_completed(future_to_index):
                index = future_to_index[future]
                try:
                    exec_reward = future.result()
                    exec_rewards[index] = exec_reward
                except Exception as exc:
                    print(f'Generated an exception: {exc}')

        # format_reward calculation using ThreadPoolExecutor
        with ThreadPoolExecutor(max_workers=max_parallel_threads) as executor:
            format_rewards, debug_infos = zip(*list(executor.map(self.calculate_format_reward, responses)))
        format_rewards = list(format_rewards)
        debug_infos = list(debug_infos)
        
        # total_reward calculation
        total_rewards = [
            exec_reward + format_reward
            for exec_reward, format_reward in zip(exec_rewards, format_rewards)
        ]
            
        for i in range(len(samples)):
            if i < self.num_examine:
                print("*" * 40)
                print("[TASK_ID: ]", samples[i]["task_id"])
                print("[PROMPT: ]\n", samples[i]["problem"])
                print("[RESPONSE: ]\n", responses[i])
                print("-" * 20)
                print("\n".join(debug_infos[i]))
                print(f"  Format: {format_rewards[i]}")
                print(f"  Answer: {exec_rewards[i]}")
                print(f"  Total: {total_rewards[i]}")
                print("-" * 20)
                print("*" * 40)

        return total_rewards, exec_rewards, format_rewards




    def get_oj_response(self, code_str: str, sample: dict):
        
        submissions = []
        def build_submission(code_str, output_str, input_str=None) -> dict:
            submission = {
                "type": "python",
                "solution": code_str,
                "expected_output": output_str,
            }
            if input_str is not None:
                submission["input"] = input_str
            return submission

        def SubmitBatch(submissions) -> list:
            data = {
                "type": "batch",
                "submissions": submissions
            }

            try:
                response = requests.post(self.config.reward_model.url, json=data)
                response.raise_for_status()

                results = response.json()['results']
                success_list = [res['success'] for res in results]
                assert len(success_list) == len(submissions)
                return success_list
            except requests.exceptions.RequestException as e:
                print(f"Request failed: {e}")
                return [False] * len(submissions)
            except (ValueError, KeyError, AssertionError) as e:
                print(f"Failed to process response: {e}")
                return [False] * len(submissions)

        uts = sample.get("selected_uts", [])
        if not isinstance(uts, list): 
            try:
                uts = ast.literal_eval(uts)
            except (ValueError, SyntaxError):
                uts = []
        if not uts:  
            return 0
        
        sample_size = 10
        if len(uts) > sample_size:
            uts = random.sample(uts, sample_size)
        
        if  sample["prompter_type"] == "mbppplus":
            if self.run_all_cases:
                for ut in uts:
                    submissions.append(build_submission(code_str + ut, ""))
            else:
                submissions.append(build_submission(code_str, uts[0], ""))
                
        elif sample["prompter_type"] == "code_contests":
            if self.run_all_cases:
                for ut in uts:
                    submissions.append(build_submission(code_str, ut['output']['stdout'], ut['input']['stdin']))
            else:
                submissions.append(build_submission(code_str, uts[0]['output']['stdout'], uts[0]['input']['stdin']))
        else:
            raise RuntimeError("Invalid prompter_type!")

        res = SubmitBatch(submissions)
        reward = np.mean(res)
        return reward


    def get_error_localization(self, questions: List[str], responses: List[str], valid_response_idss: List[Any],
                        max_parallel_threads: int = LOC_MAX_REQUESTS) -> List[Any]:
        """
        """
        def preprocess(response: str) -> bool:
            """
            check if need error localization
            True: valid code
            False: invalid code
            """
            answer_pattern = r'<answer>(.*?)</answer>'
            matches = list(re.finditer(answer_pattern, response, re.DOTALL))
            if not matches:
                return None
            answer = matches[-1].group(1).strip()

            code_block = self.sanitize(answer).strip()
            if '\x00' in code_block:
                code_block = code_block.replace('\x00', '')
            
            if self.check_ce(code_block):
                return None
            # return f"```python\n{code_block}\n```"
            return code_block
            
            
                    
        def postprocess(api_response: str) -> str:
            """
            sanitize the api response
            """
            code_line = self.sanitize(api_response)
            
            return code_line
            
            
        
        def request_api(question, response, max_retries=3):
            
            message = self.build_message(question, response)
            
            attempts = 0
            while attempts < max_retries:
                try:
                    api_response = self.client.chat.completions.create(
                        model=self.config.reward_model.error_local_model,
                        messages=message,
                        temperature=0.5,
                        max_tokens=128,
                    )
                    if api_response.choices and len(api_response.choices) > 0:
                        raw_answer = api_response.choices[0].message.content
                        error_code = postprocess(raw_answer)
                        if error_code in response:
                            return error_code
                finally:
                    attempts += 1
            return None
            
            # ## just for debug
            # lines = response.split('\n')
            # non_empty_lines = [line for line in lines if line != ''] 
            # if not non_empty_lines:  
            #     return None
            # error_code = random.choice(non_empty_lines)
            # return error_code

        
        def localization(question, response, valid_response_ids):
            
            response = preprocess(response)
            # Check if response is None, empty string, or only contains whitespaces
            if response is None or response.strip() == "":
                return None, None
            
            error_code = request_api(question, response)
            if error_code is None:
                return None, None
            
            error_start = response.find(error_code)  
            error_end = error_start + len(error_code)
            
            tokens_before_error = self.tokenizer.encode(response[:error_start], add_special_tokens=False)
            tokens_in_error = self.tokenizer.encode(error_code, add_special_tokens=False)
            
            error_start_idx = len(tokens_before_error)
            error_end_idx = error_start_idx + len(tokens_in_error)

            return error_start_idx, error_end_idx
        
        error_indices = []
        with ThreadPoolExecutor(max_workers=max_parallel_threads) as executor:
            futures = [
                (i, executor.submit(localization, question, response, valid_response_ids))
                for i, (question, response, valid_response_ids) in enumerate(zip(questions, responses, valid_response_idss))
            ]
            
            # Collect the results
            results = []
            for idx, future in futures:
                error_start_idx, error_end_idx = future.result()
                results.append((idx, error_start_idx, error_end_idx))
            
            # Sort results based on the original index
            results.sort(key=lambda x: x[0])
            
            # Extract the error indices in original order
            error_indices = [(error_start_idx, error_end_idx) for _, error_start_idx, error_end_idx in results]

        return error_indices