# Adapted from https://github.com/PeterGriffinJin/Search-R1/blob/ceee7b89655ed52f205b9beb98e1190c3eedcfb0/search_r1/llm_agent/generation.py
from typing import List

import asyncio
import re
import copy
import json
import os
import os.path as osp
import numpy as np

from tqdm import tqdm
from transformers import AutoTokenizer

from slime.utils.async_utils import run
from slime.utils.data import Dataset
from slime.utils.http_utils import get, post
from slime.utils.misc import SingletonMeta, load_function
from slime.utils.types import Sample

from slime.rollout.rm_hub import async_rm, batched_async_rm

from lean_plugins.format_check import check_output_format

# Names exported by `from <this module> import *`.
__all__ = ["lean_generate_rollout"]

# Module-level settings.
# NOTE(review): neither key is read anywhere in the visible part of this file;
# confirm whether they are still needed.
CONFIGS = {
    "max_retries": 3,
    "search_concurrency": 256,
}

def convert_samples_to_data(samples: list[Sample]):
    """Return the ``to_dict()`` form of every sample, in the original order."""
    serialized = []
    for item in samples:
        serialized.append(item.to_dict())
    return serialized

def save_info_and_data(args, rollout_id: int, info: dict, data: List[List[dict]]):
    """Persist rollout statistics and the raw sampled data under ``args.save``.

    Two files are written inside ``<args.save>/rollout_data`` (created if missing):

    * ``rollout_info.jsonl`` - ``info`` is appended as one JSON line per call.
    * ``rollout_<rollout_id>.json`` - ``data`` is written as a single JSON document.

    Args:
        args: object exposing a ``save`` attribute (the output root directory).
        rollout_id: id used to name the per-rollout data file.
        info: JSON-serializable summary of the rollout.
        data: JSON-serializable sample dicts, one inner list per group.
    """
    save_path = osp.join(args.save, "rollout_data")
    os.makedirs(save_path, exist_ok=True)

    # ensure_ascii=False emits raw non-ASCII characters, so pin the encoding
    # instead of relying on the platform default.
    with open(osp.join(save_path, "rollout_info.jsonl"), "a", encoding="utf-8") as f:
        f.write(json.dumps(info, ensure_ascii=False) + "\n")

    # A .json file holds exactly one document: overwrite rather than append, so
    # re-running the same rollout id does not leave an unparseable file behind.
    with open(osp.join(save_path, f"rollout_{rollout_id}.json"), "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False)

# User-turn template; `{informal_statement}` is filled with `sample.prompt` in `generate`.
PROMPT = """Think step by step to translate the mathematical problem in natural language to Lean 4, and verify the consistency.\n{informal_statement}\n"""

class GenerateState(metaclass=SingletonMeta):
    """
    Shared state of the generation process (tokenizer, concurrency limit,
    default sampling parameters and the set of in-flight tasks).
    """

    def __init__(self, args):
        # state that persists across rollouts
        self.args = args
        self.tokenizer = AutoTokenizer.from_pretrained(args.hf_checkpoint, trust_remote_code=True)

        # one concurrency slot budget per engine, scaled by the number of engines
        concurrency = args.sglang_server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine
        self.semaphore = asyncio.Semaphore(concurrency)
        print(f"rollout concurrency: {concurrency}", flush=True)

        self.sampling_params = {
            "temperature": args.rollout_temperature,
            "top_p": args.rollout_top_p,
            "top_k": args.rollout_top_k,
            "max_new_tokens": args.rollout_max_response_len,
            "stop": args.rollout_stop,
            "stop_token_ids": args.rollout_stop_token_ids,
            "skip_special_tokens": args.rollout_skip_special_tokens,
            "no_stop_trim": True,
            "spaces_between_special_tokens": False,
            "presence_penalty": 1.1,
        }

        self.reset()

    def reset(self):
        """Clear the per-rollout bookkeeping (pending tasks, counters, abort flag)."""
        self.aborted = False
        self.pendings = set()
        self.remaining_batch_size = 0

    def submit_generate_tasks(self, samples: list[list[Sample]]):
        """Schedule one asyncio task per group and track it in ``self.pendings``."""
        for sample_group in samples:
            # each group of samples is generated and scored by a single task
            group_coro = generate_and_rm_group(
                self.args,
                sample_group,
                sampling_params=self.sampling_params.copy(),
                evaluation=False,
            )
            self.pendings.add(asyncio.create_task(group_coro))
        self.remaining_batch_size += len(samples)


async def generate(args, sample: Sample, sampling_params, evaluation=False) -> Sample:
    assert not args.partial_rollout, f"Partial rollout is not supported for this function at the moment."
    state = GenerateState(args)

    url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate"

    assert (
        sample.status == Sample.Status.PENDING or sample.status == Sample.Status.ABORTED
    ), f"Sample status is {sample.status}"

    if len(sample.response) > 0:
        response_token_ids = state.tokenizer(sample.response, add_special_tokens=False)["input_ids"]
        sampling_params["max_new_tokens"] -= len(response_token_ids)

    assert (
        sampling_params["max_new_tokens"] >= 0
    ), f"max_new_tokens: {sampling_params['max_new_tokens']} should not be less than 0"
    if sampling_params["max_new_tokens"] == 0:
        sample.status = Sample.Status.TRUNCATED
        return sample

    # Handle partial rollout samples: continue generation from existing response
    messages = [
        {"role": "user", "content": PROMPT.format(informal_statement=sample.prompt)}
    ]
    prompt = state.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    input_token_ids = state.tokenizer(prompt, add_special_tokens=False)["input_ids"]
    sampling_params['max_new_tokens'] = sampling_params['max_new_tokens'] - len(input_token_ids) - 64
    
    # Prepare payload - shared structure
    payload = {
        "input_ids": input_token_ids,
        "sampling_params": sampling_params,
        "return_logprob": True,
    }

    output = await post(url, payload, use_http2=args.use_http2)

    # Extract new response tokens
    if "output_token_logprobs" in output["meta_info"]:
        new_response_tokens = [item[1] for item in output["meta_info"]["output_token_logprobs"]]
        new_response_log_probs = [item[0] for item in output["meta_info"]["output_token_logprobs"]]
    else:
        # abort
        new_response_tokens = []
        new_response_log_probs = []

    # Update sample with tokens directly - avoiding re-tokenization
    sample.tokens = input_token_ids + new_response_tokens
    sample.response_length = len(new_response_tokens)
    sample.response = output["text"]
    if sample.rollout_log_probs is None:
        sample.rollout_log_probs = []
    sample.rollout_log_probs += new_response_log_probs

    match output["meta_info"]["finish_reason"]["type"]:
        case "length":
            sample.status = Sample.Status.TRUNCATED
        case "abort":
            sample.status = Sample.Status.ABORTED
        case "stop":
            sample.status = Sample.Status.COMPLETED

    return sample


async def generate_and_rm(args, sample: Sample, sampling_params: dict, evaluation=False) -> Sample:
    """Generate a response for ``sample`` and, unless group RM is enabled, score it.

    Samples that are already COMPLETED or TRUNCATED are returned untouched.
    Generation runs under the shared concurrency semaphore; an aborted sample
    is returned without a reward.
    """
    status = sample.status
    if status == Sample.Status.COMPLETED or status == Sample.Status.TRUNCATED:
        # already finished: only sanity-check what should be present
        assert sample.response is not None
        if not args.group_rm:
            assert sample.reward is not None
        return sample

    state = GenerateState(args)

    async with state.semaphore:
        if state.aborted:
            sample.status = Sample.Status.ABORTED
            return sample

        custom_path = args.custom_generate_function_path
        if custom_path is None:
            sample = await generate(args, sample, sampling_params.copy(), evaluation=evaluation)
        else:
            generate_fn = load_function(custom_path)
            sample = await generate_fn(args, sample, sampling_params)

    # Aborted samples get no reward; group-level reward models score later, on the whole group.
    if sample.status == Sample.Status.ABORTED or args.group_rm:
        return sample

    sample.reward = await async_rm(args, sample, evaluation=evaluation)
    return sample


async def generate_and_rm_group(args, group: list[Sample], sampling_params: dict, evaluation=False) -> list[Sample]:
    """Generate (and score) every sample of ``group`` concurrently.

    Returns the group unchanged when the rollout has already been aborted.
    When ``args.group_rm`` is set and no abort happened meanwhile, rewards are
    computed for the whole group in one batched call.
    """
    state = GenerateState(args)
    if state.aborted:
        return group

    member_coros = [
        generate_and_rm(args, member, sampling_params.copy(), evaluation=evaluation) for member in group
    ]
    group = await asyncio.gather(*member_coros)

    # group-level reward models are applied here, once all members are generated
    use_group_rm = not state.aborted and args.group_rm
    if use_group_rm:
        group_rewards = await batched_async_rm(args, group)
        for member, member_reward in zip(group, group_rewards):
            member.reward = member_reward

    return group


async def abort(args, rollout_id: int):
    """Abort all in-flight generation requests and drain the pending tasks.

    Sets the shared ``aborted`` flag, asks every worker listed by the router to
    abort its requests, then waits for all pending tasks. With
    ``args.partial_rollout`` the finished groups are collected and returned
    (tagging samples that have a response with ``start_rollout_id``);
    otherwise an empty list is returned.
    """
    state = GenerateState(args)
    assert not state.aborted
    state.aborted = True

    workers_endpoint = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/list_workers"
    response = await get(workers_endpoint, use_http2=args.use_http2)

    # tell every worker to drop all of its requests
    for worker_url in response["urls"]:
        print(f"Abort request for {worker_url}", flush=True)
        await post(f"{worker_url}/abort_request", {"abort_all": True}, use_http2=False)

    # wait until no task is pending any more
    collected = []
    while state.pendings:
        finished, state.pendings = await asyncio.wait(state.pendings, return_when=asyncio.FIRST_COMPLETED)

        if args.partial_rollout:
            # keep the partially generated samples so they can go back to the data buffer
            for task in finished:
                group = task.result()
                for sample in group:
                    if sample.response and "start_rollout_id" not in sample.metadata:
                        sample.metadata["start_rollout_id"] = rollout_id
                collected.extend(group)

    if args.partial_rollout:
        print(f"Collected {len(collected)} partial samples into the data buffer", flush=True)

    return collected


async def generate_rollout_async(args, rollout_id: int, data_source) -> tuple[list[list[Sample]], list[Sample]]:
    """Generate one rollout batch with a rule-based reward model.

    Groups are requested from ``data_source`` and submitted until
    ``args.rollout_batch_size`` groups have passed the optional dynamic filter.
    Remaining requests are then aborted, the unfiltered samples and statistics
    are saved through ``save_info_and_data``, and the shared state is reset.

    Args:
        args: the whole args
        rollout_id: int, the id of the rollout, used for deterministic data generation
        data_source: callable taking a batch size and returning groups of samples

    Returns:
        A ``(data, aborted_samples)`` tuple. ``data`` holds exactly
        ``args.rollout_batch_size`` groups sorted by the index of their first
        sample; ``aborted_samples`` is the list returned by ``abort``.
    """
    assert args.rollout_global_dataset

    state = GenerateState(args)

    # instantiate data filters
    dynamic_filter = (
        load_function(args.dynamic_sampling_filter_path) if args.dynamic_sampling_filter_path is not None else None
    )

    # target_data_size is the total number of valid groups to collect
    target_data_size = args.rollout_batch_size

    # statistics gathered over every finished group, before filtering
    dapo_info = {
        "rollout_id": rollout_id,
        "reward": [], "truncated": [],
    }
    sampled_data = []

    data = []
    do_print = True
    pbar = tqdm(total=target_data_size * args.n_samples_per_prompt, desc="Rollout generation")
    while len(data) < target_data_size:
        while state.remaining_batch_size < target_data_size:
            # get samples from the buffer and submit the generation requests.
            samples = data_source(args.over_sampling_batch_size)
            state.submit_generate_tasks(samples)

        # wait for the generation to finish
        done, state.pendings = await asyncio.wait(state.pendings, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            group: list[Sample] = task.result()

            # validate the group before any element of it is dereferenced
            assert len(group) == args.n_samples_per_prompt

            sampled_data.append(convert_samples_to_data(group))
            # info before filtering
            for sample in group:
                dapo_info['reward'].append(sample.reward)
                dapo_info['truncated'].append(1 if sample.status == Sample.Status.TRUNCATED else 0)

            if do_print:
                print(
                    f"First rollout sample: {[group[0].prompt + group[0].response]}, label: {group[0].label}, reward: {group[0].reward}",
                    flush=True,
                )
                do_print = False

            if dynamic_filter is not None and not dynamic_filter(args, group):
                print(f"dynamically filter out this group, all {group[0].reward=} | {len(group)=} | {len(data)=}")
                # a rejected group no longer counts towards the in-flight budget
                state.remaining_batch_size -= 1
                continue

            # add the samples to the data
            # NOTE: here we have not stored all the unused samples back to the data buffer.
            if len(data) < target_data_size:
                data.append(group)
                pbar.update(args.n_samples_per_prompt)

    pbar.close()
    print(
        f"Finish rollout: {[data[-1][0].prompt + data[-1][0].response]}, label: {data[-1][0].label}, reward: {data[-1][0].reward}",
        flush=True,
    )

    # there are still some unfinished requests, abort them
    aborted_samples = await abort(args, rollout_id)

    assert len(data) == args.rollout_batch_size, f"Got {len(data)} samples, expected {args.rollout_batch_size}"
    data = sorted(data, key=lambda group: group[0].index)

    # save sampled data and info before filter
    save_info_and_data(args, rollout_id, dapo_info, sampled_data)

    # reset the global state to prevent effects on the next rollout or eval.
    state.reset()
    return data, aborted_samples


# Cache of evaluation datasets keyed by dataset name; filled lazily by
# `eval_rollout_single_dataset` so each dataset is only loaded once.
EVAL_PROMPT_DATASET = {}


async def eval_rollout(args, rollout_id):
    """Evaluate every dataset listed in ``args.eval_prompt_data``.

    ``args.eval_prompt_data`` is read as a flat ``[name, path, name, path, ...]``
    sequence. Returns ``(results, [])`` where ``results`` merges the
    per-dataset result dicts.
    """
    assert not args.group_rm, "Group RM is not supported for eval rollout"
    merged = {}
    spec = args.eval_prompt_data
    for start in range(0, len(spec), 2):
        name, path = spec[start : start + 2]
        dataset_result = await eval_rollout_single_dataset(args, rollout_id, name, path)
        merged.update(dataset_result)
    return merged, []


async def eval_rollout_single_dataset(args, rollout_id, name, path):
    """Run generation and reward scoring over one evaluation dataset.

    Every prompt is sampled ``args.n_samples_per_eval_prompt`` times. The
    generated samples are appended to ``<args.save>/eval_data/eval_rollout_<rollout_id>.jsonl``
    and the per-key reward sums to ``<args.save>/eval_data/eval_result.jsonl``.

    Args:
        args: the whole args
        rollout_id: int, the id of the rollout, used for deterministic data generation
        name: str, the name of the dataset
        path: str, the path of the dataset

    Returns:
        dict: for each ``data_source`` found in the sample metadata, an entry
        ``{"rewards": [...], "truncated": [...]}`` and an entry
        ``"<data_source>_verify"`` with ``{"rewards": [...]}``, plus the pooled
        ``"avg"`` and ``"avg_verify"`` entries.
    """
    assert not args.group_rm, "Group RM is not supported for eval rollout"

    global EVAL_PROMPT_DATASET

    # load each eval dataset only once per process
    if name not in EVAL_PROMPT_DATASET:
        tokenizer = AutoTokenizer.from_pretrained(args.hf_checkpoint, trust_remote_code=True)
        EVAL_PROMPT_DATASET[name] = Dataset(
            path,
            tokenizer=tokenizer,
            max_length=args.rollout_max_prompt_len,
            prompt_key=args.input_key if args.eval_input_key is None else args.eval_input_key,
            label_key=args.label_key if args.eval_label_key is None else args.eval_label_key,
            metadata_key=args.metadata_key,
            tool_key=args.tool_key if args.eval_tool_key is None else args.eval_tool_key,
            apply_chat_template=args.apply_chat_template,
        )
    dataset = EVAL_PROMPT_DATASET[name]

    # eval-specific settings override the rollout ones when they are provided
    sampling_params = dict(
        temperature=args.rollout_temperature if args.eval_temperature is None else args.eval_temperature,
        top_p=args.rollout_top_p if args.eval_top_p is None else args.eval_top_p,
        top_k=args.rollout_top_k if args.eval_top_k is None else args.eval_top_k,
        max_new_tokens=(
            args.rollout_max_response_len if args.eval_max_response_len is None else args.eval_max_response_len
        ),
        stop=args.rollout_stop,
        stop_token_ids=args.rollout_stop_token_ids,
        skip_special_tokens=args.rollout_skip_special_tokens,
        no_stop_trim=True,
        spaces_between_special_tokens=False,
    )

    tasks = []
    # do multiple samples for eval prompts
    sample_index = 0
    for prompt_sample in dataset.samples:
        for _ in range(args.n_samples_per_eval_prompt):
            # use the same prompt for multiple samples
            sample = copy.deepcopy(prompt_sample)
            sample.index = sample_index
            sample_index += 1
            tasks.append(
                generate_and_rm(
                    args,
                    sample,
                    sampling_params=sampling_params.copy(),
                    evaluation=True,
                )
            )

    data = []
    do_print = True
    pbar = tqdm(total=len(tasks), desc="Rollout generation", disable=not do_print)
    for coro in asyncio.as_completed(tasks):
        sample = await coro
        if do_print:
            print([sample.prompt + sample.response], sample.reward, flush=True)
            do_print = False
        data.append(sample)
        pbar.update(1)
    pbar.close()

    # restore submission order (completion order is arbitrary)
    data.sort(key=lambda sample: sample.index)

    # save the generated samples
    save_eval = convert_samples_to_data(data)
    save_eval_path = osp.join(args.save, "eval_data")
    os.makedirs(save_eval_path, exist_ok=True)

    # ensure_ascii=False writes raw non-ASCII characters, so pin the encoding
    with open(osp.join(save_eval_path, f"eval_rollout_{rollout_id}.jsonl"), "a", encoding="utf-8") as f:
        for item in save_eval:
            json.dump(item, f, ensure_ascii=False)
            f.write("\n")

    # Like the other eval_* options, the eval-specific key takes precedence.
    reward_key = args.eval_reward_key or args.reward_key
    eval_result = {}
    source_names = []  # data sources in first-seen order
    for item in data:
        data_source = item.metadata['data_source']
        verify_name = data_source + "_verify"
        if data_source not in eval_result:
            eval_result[data_source] = {"rewards": [], "truncated": []}
            eval_result[verify_name] = {"rewards": []}
            source_names.append(data_source)

        # The first reward entry feeds the main bucket and the last one the
        # "_verify" bucket. NOTE(review): presumably translation vs. verification
        # scores - confirm against the reward model.
        rewards = item.reward[reward_key] if reward_key else item.reward
        eval_result[data_source]["rewards"].append(rewards[0])
        eval_result[verify_name]["rewards"].append(rewards[-1])

        eval_result[data_source]["truncated"].append(item.status == Sample.Status.TRUNCATED)

    # Pool the rewards of all data sources. The buckets are addressed by the
    # recorded source names, so a data source whose own name contains "_verify"
    # is not misfiled into the verify pool.
    pooled_rewards = []
    pooled_verify_rewards = []
    for source in source_names:
        pooled_rewards.extend(eval_result[source]["rewards"])
        pooled_verify_rewards.extend(eval_result[source + "_verify"]["rewards"])

    eval_result["avg"] = {"rewards": pooled_rewards}
    eval_result["avg_verify"] = {"rewards": pooled_verify_rewards}

    # Reduce every non-empty list to a scalar before storing the metrics.
    # NOTE(review): the original comment said "average first", but the code sums;
    # the sum is kept here - confirm which one is intended.
    save_eval_result = {}
    for k, v in eval_result.items():
        save_eval_result[k] = {}
        for k1, v1 in v.items():
            if len(v1) > 0:
                save_eval_result[k][k1] = float(np.sum(v1))

    with open(osp.join(save_eval_path, "eval_result.jsonl"), "a", encoding="utf-8") as f:
        json.dump(save_eval_result, f, ensure_ascii=False)
        f.write("\n")

    return eval_result

# TODO remove this temp function
def lean_generate_rollout(args, rollout_id, data_buffer, evaluation=False):
    """Rollout entry point: generate samples and hand aborted ones back to the buffer.

    Args:
        args: the whole args
        rollout_id: int, the id of the rollout, used for deterministic data generation
        data_buffer: buffer providing ``get_samples`` and receiving aborted samples
            through ``add_samples``
        evaluation: bool, whether the rollout is for evaluation or not

    Returns:
        The first element of the pair produced by ``generate_abortable_samples``
        (the completed samples).
    """
    outcome = generate_abortable_samples(args, rollout_id, data_buffer.get_samples, evaluation=evaluation)
    completed, aborted = outcome
    # aborted samples go back to the buffer
    data_buffer.add_samples(aborted)
    return completed


def generate_abortable_samples(args, rollout_id, data_source, evaluation=False):
    """Synchronously run either the eval rollout or the training rollout coroutine.

    Requires ``args.rollout_global_dataset`` to be truthy. The chosen coroutine
    is executed through ``run`` and its result is returned as is.
    """
    assert args.rollout_global_dataset
    if evaluation:
        rollout_coro = eval_rollout(args, rollout_id)
    else:
        rollout_coro = generate_rollout_async(args, rollout_id, data_source)
    return run(rollout_coro)
