# Copyright 2024 PRIME team and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations
import asyncio
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
from functools import partial
from typing import Callable, Optional
from tqdm import tqdm
import psutil
import torch
from transformers import PreTrainedTokenizer

from verl import DataProto
from verl.utils.reward_score import default_compute_score
from verl.workers.reward_manager import register
import multiprocessing as mp
from typing import Any, Callable, List, Sequence, Optional
from tqdm import tqdm


async def single_compute_score(evaluation_func, completion, reference, task, task_extra_info, executor, timeout=300.0):
    """Score one completion in ``executor`` and wait at most ``timeout`` seconds for it.

    ``evaluation_func`` is invoked positionally as
    ``evaluation_func(task, completion, reference, task_extra_info)``.

    Args:
        evaluation_func: Callable that produces the score.
        completion: Model output to be scored.
        reference: Ground truth passed to ``evaluation_func``.
        task: Task / data-source identifier passed as the first argument.
        task_extra_info: Extra per-sample information passed as the last argument.
        executor: Executor handed to ``loop.run_in_executor``.
        timeout: Seconds to wait for the result before giving up.

    Returns:
        Whatever ``evaluation_func`` returned, or ``None`` when the call timed
        out or raised.
    """
    loop = asyncio.get_running_loop()
    # Log only a short preview: completions can be very long, and ``str()``
    # keeps the logging itself from failing on non-sliceable completions.
    preview = str(completion)[:80]
    try:
        # Off-load the (blocking) scoring call so the event loop stays responsive.
        future = loop.run_in_executor(executor, partial(evaluation_func, task, completion, reference, task_extra_info))
        return await asyncio.wait_for(future, timeout=timeout)
    except asyncio.TimeoutError:
        print(f"[Timeout] Task timeout: {preview}")
        return None  # Default value for timed-out rows
    except Exception as e:
        print(f"[Error] Task failed: {e}, completion: {preview}")
        return None  # Default value for failed rows


async def parallel_compute_score_async(
    evaluation_func, completions, references, tasks, extra_info=None, num_processes=64, timeout=300.0
):
    """Score a batch concurrently on a thread pool and return one float per sample.

    Args:
        evaluation_func: Scoring callable forwarded to ``single_compute_score``.
        completions: Model outputs, one per sample.
        references: Ground truths, one per sample.
        tasks: Task / data-source identifiers, one per sample.
        extra_info: Optional per-sample extra information; defaults to ``None``
            for every sample.
        num_processes: Maximum number of worker *threads* (despite the name,
            this sizes a ``ThreadPoolExecutor``).
        timeout: Seconds each individual evaluation may be awaited.

    Returns:
        list[float]: ``0.0`` for samples whose evaluation failed or timed out,
        ``float(result)`` for numeric/boolean results, and ``float(result[0])``
        for any other (indexable) result.

    Raises:
        ValueError: If the input sequences have different lengths.
    """
    if extra_info is None:
        extra_info = [None] * len(tasks)

    # The executor is managed by hand instead of via ``with``: leaving a
    # ``with`` block calls ``shutdown(wait=True)``, which would block forever on
    # a worker thread stuck in a runaway evaluation and defeat the timeout.
    executor = ThreadPoolExecutor(max_workers=num_processes)
    try:
        # One coroutine per row; each applies its own timeout.
        pending = [
            single_compute_score(evaluation_func, c, r, t, ei, executor, timeout=timeout)
            for c, r, t, ei in zip(completions, references, tasks, extra_info, strict=True)
        ]
        results = await asyncio.gather(*pending, return_exceptions=False)
    except Exception as e:
        print(f"[Exception] async gather failed: {e}")
        raise
    finally:
        # Threads cannot be killed: abandon stragglers rather than joining them.
        # They keep running in the background until their evaluation returns.
        executor.shutdown(wait=False, cancel_futures=True)
        print("[Shutdown] ThreadPoolExecutor finished.")

    # Normalise every raw result to a float score.
    scores = []
    for result in results:
        if result is None or isinstance(result, Exception):
            # Failed or timed-out evaluation.
            scores.append(0.0)
        elif isinstance(result, (int, float, bool)):
            scores.append(float(result))
        else:
            # Sequence-like result: the first element is used as the score.
            scores.append(float(result[0]))
    return scores


def run_reward_scoring(evaluation_func, completions, references, tasks, extra_info=None, num_processes=64):
    """Synchronous entry point: run ``parallel_compute_score_async`` on a private event loop.

    Args:
        evaluation_func: Scoring callable.
        completions: Model outputs, one per sample.
        references: Ground truths, one per sample.
        tasks: Task / data-source identifiers, one per sample.
        extra_info: Optional per-sample extra information.
        num_processes: Worker count forwarded to ``parallel_compute_score_async``.

    Returns:
        The list of scores produced by ``parallel_compute_score_async``.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        print("Before run_until_complete")
        scores = loop.run_until_complete(
            parallel_compute_score_async(evaluation_func, completions, references, tasks, extra_info, num_processes=num_processes)
        )
        print("After run_until_complete")
        return scores
    finally:
        # Uninstall the loop before closing it so the thread is not left with a
        # closed loop registered as its current event loop.
        asyncio.set_event_loop(None)
        loop.close()


def run_reward_scoring_serial(evaluation_func, completions, references, tasks, extra_info=None, num_processes=None):
    """Score every sample one after another in the calling thread.

    ``evaluation_func`` is called with the keyword arguments ``data_source``,
    ``solution_str``, ``ground_truth`` and ``extra_info``. ``num_processes`` is
    accepted but not used by this function.

    Returns:
        list: The raw return value of ``evaluation_func`` for each sample, or
        ``None`` for samples whose evaluation raised.
    """
    infos = extra_info if extra_info is not None else [None] * len(tasks)

    def _score_one(completion, reference, task, info):
        # A failing sample is logged and recorded as ``None``; the rest still run.
        try:
            return evaluation_func(data_source=task, solution_str=completion, ground_truth=reference, extra_info=info)
        except Exception as e:
            print(f"[Error] Failed to evaluate task: {e}")
            return None

    rows = zip(completions, references, tasks, infos, strict=True)
    return [_score_one(*row) for row in tqdm(rows)]


@register("prime_backup")
class PrimeRewardManager:
    """Reward manager that scores decoded responses with a ``compute_score`` function.

    Calling the manager returns a reward tensor shaped like ``data.batch["responses"]``
    in which each sample's score is written at its last valid response token.
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        num_examine: int,
        compute_score: Optional[Callable] = None,
        reward_fn_key: str = "data_source",
    ) -> None:
        """Store the scoring configuration.

        Args:
            tokenizer: Tokenizer used to decode response token ids.
            num_examine: How many decoded responses to print per data source.
            compute_score: Scoring function; ``default_compute_score`` is used
                when omitted.
            reward_fn_key: Key in ``non_tensor_batch`` that holds the data
                source of every sample.
        """
        self.tokenizer = tokenizer
        self.num_examine = num_examine  # number of decoded responses to print per data source
        self.compute_score = compute_score or default_compute_score
        self.reward_fn_key = reward_fn_key

    def verify(self, data):
        """Score the batch and store the scores in ``data.batch["acc"]``.

        Responses are decoded with the tokenizer and scored against
        ``reward_model["ground_truth"]`` of every item via ``run_reward_scoring``.
        If scoring raises, every sample in the batch receives ``0.0``.

        Returns:
            list[float]: One score per sample.
        """
        # batched scoring
        prompt_ids = data.batch["prompts"]

        response_ids = data.batch["responses"]
        sequences_str = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True)
        ground_truth = [data_item.non_tensor_batch["reward_model"]["ground_truth"] for data_item in data]
        data_sources = data.non_tensor_batch[self.reward_fn_key]
        extra_info = data.non_tensor_batch.get("extra_info", None)

        assert len(sequences_str) == len(ground_truth) == len(data_sources)
        try:
            scores = run_reward_scoring(
                self.compute_score,
                completions=sequences_str,
                references=ground_truth,
                tasks=data_sources,
                extra_info=extra_info,
                num_processes=64,
            )
        except asyncio.TimeoutError:
            print("[Timeout] Global reward scoring timed out. Setting all as 0.")
            scores = [0.0 for _ in range(len(sequences_str))]
        except Exception as e:
            print(f"[Error] Unexpected error during scoring. Setting all as 0. {e}")
            scores = [0.0 for _ in range(len(sequences_str))]
        data.batch["acc"] = torch.tensor(scores, dtype=torch.float32, device=prompt_ids.device)
        return scores

    def __call__(self, data: DataProto, return_dict: bool = False):
        """Compute the token-level reward tensor for ``data``.

        Args:
            data: Batch providing ``prompts``, ``responses`` and
                ``attention_mask`` tensors plus the per-sample non-tensor fields
                read by ``verify``.
            return_dict: When true, wrap the result as
                ``{"reward_tensor": tensor}``.

        Returns:
            The reward tensor, or a dict containing it when ``return_dict`` is
            true. If ``data.batch`` already has ``rm_scores``, those are
            returned instead of computing new scores.
        """
        # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn
        if "rm_scores" in data.batch.keys():
            rm_scores = data.batch["rm_scores"]
            # Honour ``return_dict`` on this path as well.
            return {"reward_tensor": rm_scores} if return_dict else rm_scores

        reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)

        already_print_data_sources = {}

        # batched scoring
        prompt_ids = data.batch["prompts"]
        prompt_length = prompt_ids.shape[-1]

        response_ids = data.batch["responses"]
        # Attention-mask columns after the prompt count the real response tokens.
        valid_response_length = data.batch["attention_mask"][:, prompt_length:].sum(dim=-1)
        sequences_str = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True)
        # Use the configured key so this stays consistent with ``verify``.
        data_sources = data.non_tensor_batch[self.reward_fn_key]

        scores = self.verify(data)

        for i in range(len(data)):
            data_source = data_sources[i]
            # The whole-sequence score goes on the last valid response token.
            reward_tensor[i, valid_response_length[i].item() - 1] = scores[i]

            printed = already_print_data_sources.get(data_source, 0)
            if printed < self.num_examine:
                already_print_data_sources[data_source] = printed + 1
                # Print only this sample's response, not the whole decoded batch.
                print(sequences_str[i])

        if return_dict:
            return {"reward_tensor": reward_tensor}
        else:
            return reward_tensor
        
################################################################################
# Internal helpers
################################################################################

def _evaluate_wrapper(args: tuple[Any, ...]) -> Any:
    """Wrap *evaluation_func* so it can be pickled and exceptions are trapped.

    The wrapper ensures that an exception inside ``evaluation_func`` does **not**
    crash the pool worker but is reported back to the main process as ``None``.
    """
    evaluation_func, c, r, t, ei = args
    try:
        return evaluation_func(
            data_source=t, solution_str=c, ground_truth=r, extra_info=ei
        )
    except Exception as exc:  # noqa: BLE001 (broad – we want to catch *everything*)
        print(f"[Error] Failed to evaluate task: {exc}")
        return None


################################################################################
# Public API
################################################################################

def run_reward_scoring_my(
    evaluation_func: Callable[..., Any],
    completions: Sequence[str],
    references: Sequence[str],
    tasks: Sequence[Any],
    extra_info: Optional[Sequence[Any]] = None,
    *,
    num_processes: Optional[int] = 64,
    timeout: Optional[float] = 30,
) -> List[Any]:
    """Evaluate *evaluation_func* over a batch of tasks in parallel.

    Parameters
    ----------
    evaluation_func
        User‑supplied scoring function. Must be **picklable** so it can be sent
        to worker processes.
    completions, references, tasks, extra_info
        Iterables of equal length providing the inputs for each evaluation.
    num_processes
        Number of worker processes.  ``None`` → use :pyfunc:`os.cpu_count`.
        ``1`` effectively disables parallelism and runs serially (still with
        timeout handling).
    timeout
        Maximum seconds allowed for a single call to *evaluation_func*.
        ``None`` → wait indefinitely (previous behaviour).

    Returns
    -------
    list
        Scores aligned with the input order.

    Notes
    -----
    * If a call **times‑out** it returns ``0``.
    * If a call **raises** an exception it returns ``None``.
    * The function shows a progress bar powered by *tqdm*.
    """

    if extra_info is None:
        extra_info = [None] * len(tasks)

    if not (
        len(completions)
        == len(references)
        == len(tasks)
        == len(extra_info)
    ):
        raise ValueError("All input sequences must have the same length.")

    # Prepare work items for the pool – keep original indices so we can reorder
    work_items: list[tuple[int, tuple[Any, ...]]] = [
        (
            idx,
            (
                evaluation_func,
                c,
                r,
                t,
                ei,
            ),
        )
        for idx, (c, r, t, ei) in enumerate(
            zip(completions, references, tasks, extra_info, strict=True)
        )
    ]

    # Output buffer with deterministic ordering
    scores: list[Any] = [None] * len(tasks)

    # Serial fallback when num_processes == 1 (or "parallel off" for debugging)
    if num_processes == 1:
        for idx, payload in tqdm(work_items):
            try:
                scores[idx] = _evaluate_wrapper(payload)
            except Exception:
                # _evaluate_wrapper already catches and logs – keep behaviour
                pass
        return scores

    # ------------------------------------------------------------------
    # Parallel execution using multiprocessing.Pool
    # ------------------------------------------------------------------

    with mp.Pool(processes=num_processes) as pool:
        async_results: list[tuple[int, mp.pool.ApplyResult]] = [
            (idx, pool.apply_async(_evaluate_wrapper, (payload,)))
            for idx, payload in work_items
        ]

        for idx, async_res in tqdm(async_results):
            try:
                scores[idx] = async_res.get(timeout=timeout)
            except mp.TimeoutError:
                print(
                    f"[Timeout] Task {idx} exceeded {timeout} seconds. "
                    "Assigning score=0."
                )
                scores[idx] = 0
            # Any other Exception has already been caught in the wrapper.

    return scores
