# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import random
import re
from typing import Any

import datasets

from verl.tools.base_tool import OpenAIFunctionToolSchema
from verl.tools.sandbox_fusion_tools import SandboxFusionTool
from verl.utils.dataset import RLHFDataset

from verl.utils.rollout_trace import rollout_trace_op
from mathruler.grader import extract_boxed_content, grade_answer
from agentmath.reward_score import tool_text_r1_format
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

class ProxyManager:
    """Context manager that temporarily removes the HTTP(S) proxy settings.

    On entry the lowercase ``http_proxy`` and ``https_proxy`` environment
    variables are removed from ``os.environ``. On exit, every variable that
    had a value on entry gets that value back. A variable that was absent on
    entry is left untouched on exit, even if the ``with`` body set it.
    """

    def __enter__(self):
        """Remove the proxy variables and remember their previous values.

        Returns:
            This manager, so ``with ProxyManager() as pm:`` binds a real
            object instead of ``None``.
        """
        self._original_http_proxy = os.environ.pop("http_proxy", None)
        self._original_https_proxy = os.environ.pop("https_proxy", None)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Put back the proxy variables that were set on entry.

        Returns ``None``, so an exception raised inside the ``with`` body is
        not suppressed.
        """
        if self._original_http_proxy is not None:
            os.environ["http_proxy"] = self._original_http_proxy
        if self._original_https_proxy is not None:
            os.environ["https_proxy"] = self._original_https_proxy

class CustomSandboxFusionTool(SandboxFusionTool):
    """Sandbox Fusion tool that unwraps fenced Python code and can rotate endpoints.

    When the ``IP_Address_strings`` environment variable holds a
    comma-separated list of hosts, every ``execute`` call picks one of them at
    random and stores ``http://<host>:8080/run_code`` in
    ``self.sandbox_fusion_url``. When the variable is unset or yields no
    hosts, ``self.sandbox_fusion_url`` is not touched by this class.
    """

    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):
        """Initialise the parent tool and build the list of candidate sandbox URLs.

        Args:
            config: Tool configuration, forwarded unchanged to the parent class.
            tool_schema: Function schema, forwarded unchanged to the parent class.
        """
        super().__init__(config, tool_schema)
        # Captures the body of a ```python ... ``` fence; DOTALL lets it span lines.
        self.code_pattern = re.compile(r"```python(.*?)```", re.DOTALL)
        sandbox_url_ips = os.environ.get("IP_Address_strings", None)
        url_template = "http://{ip}:8080/run_code"
        self.sandbox_fusion_url_lists = []
        if sandbox_url_ips is not None:
            # Strip whitespace and drop empty entries so values such as
            # "10.0.0.1, 10.0.0.2," or "" do not produce malformed URLs.
            self.sandbox_fusion_url_lists = [
                url_template.format(ip=ip_k.strip()) for ip_k in sandbox_url_ips.split(",") if ip_k.strip()
            ]
        print(
            f"\nself.sandbox_fusion_url_lists ==== {self.sandbox_fusion_url_lists}, "
            f"sandbox_fusion_url_lists Nums === {len(self.sandbox_fusion_url_lists)}\n"
        )

    @rollout_trace_op
    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[str, float, dict]:
        """Run the code found in ``parameters`` through the execution pool.

        Args:
            instance_id: Identifier forwarded to ``self.execute_code``.
            parameters: Must contain ``"code"``; may contain ``"timeout"`` and
                ``"language"``, which default to ``self.default_timeout`` and
                ``self.default_language``.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            A 3-tuple ``(result, None, None)`` where ``result`` is whatever the
            execution pool returns; the second and third slots are always ``None``.

        Raises:
            KeyError: If ``parameters`` has no ``"code"`` entry.
        """
        if self.sandbox_fusion_url_lists:
            # Pick one endpoint uniformly at random for this call.
            self.sandbox_fusion_url = random.choice(self.sandbox_fusion_url_lists)
            print(f"sample sandbox_fusion_url ==== {self.sandbox_fusion_url}")

        code = parameters["code"]
        # Coerce before any string operation: the regex below raises TypeError
        # on non-string input, so the check has to come first to be useful.
        if not isinstance(code, str):
            code = str(code)

        # If the text contains a ```python fence, run only the first fenced snippet.
        matches = self.code_pattern.findall(code)
        if matches:
            code = matches[0].strip()

        # NOTE: the snippet is submitted unchanged; no print statement is
        # appended, so a script has to print its own result to produce output.

        timeout = parameters.get("timeout", self.default_timeout)
        language = parameters.get("language", self.default_language)

        result = await self.execution_pool.execute.remote(self.execute_code, instance_id, code, timeout, language)
        return result, None, None


class CustomRLHFDataset(RLHFDataset):
    """RLHF dataset that loads parquet files and tags every row with an agent name."""

    def _read_files_and_tokenize(self):
        """Load every file in ``self.data_files`` and concatenate them into ``self.dataframe``.

        Each file is read with the ``parquet`` loader and its ``train`` split is
        mapped through ``map_fn`` (raw AIME rows) or ``map_fn2`` (all other rows).
        """
        dataframes = []
        for parquet_file in self.data_files:
            dataframe = datasets.load_dataset("parquet", data_files=parquet_file)["train"]
            # Keeps at most the last two "/"-separated components of the path.
            data_source = "/".join(parquet_file.split("/")[-2:])
            # NOTE(review): for any path that has a directory part, data_source
            # contains a "/" and cannot equal either bare name below, so such
            # files take the map_fn2 branch. Confirm whether matching on the
            # dataset directory name was intended.
            if data_source in ["AIME_2024", "aime_2025"]:
                dataframe = dataframe.map(
                    self.map_fn, fn_kwargs={"data_source": data_source}, remove_columns=dataframe.column_names
                )
            else:
                dataframe = dataframe.map(self.map_fn2, num_proc=16)
            dataframes.append(dataframe)
        self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes)

        print(f"dataset len: {len(self.dataframe)}")

    def map_fn(self, row: dict, *, data_source: str = None):
        """Convert one raw AIME row into the prompt/reward record layout.

        Args:
            row: Source row. ``"AIME_2024"`` rows are read through the
                ``"Problem"``/``"Answer"`` keys, ``"aime_2025"`` rows through
                ``"problem"``/``"answer"``.
            data_source: Either ``"AIME_2024"`` or ``"aime_2025"``.

        Returns:
            A dict with ``data_source``, ``prompt``, ``ability``,
            ``reward_model`` and ``agent_name`` entries.

        Raises:
            ValueError: If ``data_source`` is not one of the two supported names.
        """
        if data_source == "AIME_2024":
            problem, answer = row["Problem"], row["Answer"]
        elif data_source == "aime_2025":
            problem, answer = row["problem"], row["answer"]
        else:
            # Fail with a clear message instead of an UnboundLocalError below.
            raise ValueError(f"Unsupported data_source: {data_source!r}")

        data = {
            # Last path component, lower-cased. Index -1 (not 1) so the bare
            # names accepted above, which contain no "/", do not raise IndexError.
            "data_source": data_source.split("/")[-1].lower(),  # aime_2024, aime_2025
            "prompt": [{"role": "user", "content": problem}],
            "ability": "MATH",
            "reward_model": {"ground_truth": str(answer)},
            "agent_name": "ag_math_agent_patial_rollout",
        }
        return data

    def map_fn2(self, row: dict):
        """Tag a row that already has a ``prompt`` column with the agent name.

        Args:
            row: Row whose ``row["prompt"][0]["content"]`` must exist.

        Returns:
            The same row with ``agent_name`` set.
        """
        # The first message content is read and written back unchanged; this
        # is an identity step that still fails fast if the prompt is missing.
        content = row["prompt"][0]["content"]
        row["prompt"][0]["content"] = content
        row["agent_name"] = "ag_math_agent_patial_rollout"
        return row





def tool_default_compute_score(data_source, solution_str, ground_truth, extra_info=None):
    """Score a solution string and report the extracted prediction.

    Args:
        data_source: Accepted for interface compatibility; not used here.
        solution_str: Model output to be scored.
        ground_truth: Reference answer passed to the answer-based scorers.
        extra_info: Accepted for interface compatibility; not used here.

    Returns:
        A dict with float entries ``score``, ``acc`` and ``format``, plus
        ``pred``: the boxed content of the extracted prediction text, or the
        prediction text itself when no boxed content is found.
    """
    scorer = tool_text_r1_format
    combined_score = scorer.compute_score_format_answer(solution_str, ground_truth)
    answer_score = scorer.compute_score_answer(solution_str, ground_truth)
    format_score = scorer.compute_score_format(solution_str)

    extracted_text = scorer.pred_ans_extract(solution_str)
    boxed = extract_boxed_content(extracted_text)
    prediction = extracted_text if boxed is None else boxed

    report = {}
    report["score"] = float(combined_score)
    report["acc"] = float(answer_score)
    report["format"] = float(format_score)
    report["pred"] = prediction
    return report


def compute_score(data_source, solution_str, ground_truth, extra_info):
    """Reward entry point: forward all arguments to ``tool_default_compute_score``.

    Returns:
        The value returned by ``tool_default_compute_score``, unchanged.
    """
    return tool_default_compute_score(
        data_source=data_source,
        solution_str=solution_str,
        ground_truth=ground_truth,
        extra_info=extra_info,
    )