from typing import List, Dict, Any, Tuple
import os
import os.path as osp
import re
import json5
import json
import time
import uuid
import asyncio
import copy
import requests

from tqdm import tqdm
from transformers import AutoTokenizer

from sglang.srt.reasoning_parser import ReasoningParser
from sglang.srt.function_call.function_call_parser import FunctionCallParser


from slime.utils.async_utils import run
from slime.utils.data import Dataset
from slime.utils.types import Sample
from slime.utils.http_utils import get, post
from slime.utils.misc import SingletonMeta, load_function
from slime.rollout.rm_hub import async_rm, batched_async_rm

from .context_agent import ContextAgent, AGENT_TOOLS_SCHEMA, parse_text_output
from .reward_shaping import reward_shaping

__all__ = ["context_generate_rollout"]


from sglang.srt.entrypoints.openai.protocol import Tool

# The agent's raw tool schemas, validated once at import time into sglang's OpenAI-protocol
# `Tool` models; these are what FunctionCallParser is constructed with in `generate`.
pydantic_tools = list(map(Tool.model_validate, AGENT_TOOLS_SCHEMA))


def convert_samples_to_data(samples: list[Sample]):
    """Return ``to_dict()`` of every sample, preserving order (used for the saved rollout dump)."""
    return [item.to_dict() for item in samples]

def save_info_and_data(args, rollout_id: int, info: dict, data: List[List[dict]]):
    """Persist rollout statistics and raw sampled data under ``<args.save>/rollout_data``.

    Args:
        args: namespace providing ``save``, the root output directory.
        rollout_id: id used to name the per-rollout data file.
        info: JSON-serializable summary, appended as one line to ``rollout_info.jsonl``.
        data: JSON-serializable data written to ``rollout_<rollout_id>.json``.
    """
    save_path = osp.join(args.save, "rollout_data")
    os.makedirs(save_path, exist_ok=True)

    # JSON Lines log: one object per rollout, accumulated across calls.
    with open(osp.join(save_path, "rollout_info.jsonl"), "a", encoding="utf-8") as f:
        f.write(json.dumps(info, ensure_ascii=False) + "\n")

    # A single JSON document per rollout id. Overwrite instead of appending: appending a
    # second document (e.g. when the same rollout_id is saved twice) yields a file that
    # json.load cannot parse.
    with open(osp.join(save_path, f"rollout_{rollout_id}.json"), "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)


class GenerateState(metaclass=SingletonMeta):
    """
    Global state for the generation process.

    Constructed through ``SingletonMeta`` (presumably one shared instance; every call site in
    this module uses ``GenerateState(args)``). Holds the tokenizer, the semaphore capping
    concurrent generation requests, the default sampling parameters, and per-rollout
    bookkeeping that ``reset`` clears.
    """

    def __init__(self, args):
        self.args = args
        self.tokenizer = AutoTokenizer.from_pretrained(args.hf_checkpoint, trust_remote_code=True)

        # Request cap: sglang_server_concurrency scaled by
        # rollout_num_gpus / rollout_num_gpus_per_engine (presumably the engine count).
        max_inflight = args.sglang_server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine
        self.semaphore = asyncio.Semaphore(max_inflight)
        print(f"rollout concurrency: {max_inflight}", flush=True)

        # Default sglang sampling parameters for training rollouts.
        self.sampling_params = {
            "temperature": args.rollout_temperature,
            "top_p": args.rollout_top_p,
            "top_k": args.rollout_top_k,
            "max_new_tokens": args.rollout_max_response_len,
            "stop": args.rollout_stop,
            "stop_token_ids": args.rollout_stop_token_ids,
            "skip_special_tokens": args.rollout_skip_special_tokens,
            "no_stop_trim": True,
            "spaces_between_special_tokens": False,
        }
        self.reset()

    def reset(self):
        """Clear per-rollout bookkeeping: outstanding group count, pending tasks, abort flag."""
        self.aborted = False
        self.pendings = set()
        self.remaining_batch_size = 0

    def submit_generate_tasks(self, samples: list[list[Sample]]):
        """Schedule one ``generate_and_rm_group`` task per prompt group and count the groups."""
        for prompt_group in samples:
            # each group is submitted as a single task with its own copy of the sampling params
            group_coro = generate_and_rm_group(
                self.args,
                prompt_group,
                sampling_params=self.sampling_params.copy(),
                evaluation=False,
            )
            self.pendings.add(asyncio.create_task(group_coro))
        self.remaining_batch_size += len(samples)


TOOL_CALL_PATTERN = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)

def parse_qwen_tool_calls(text: str) -> Tuple[List[Dict[str, Any]], str]:
    """Extract Qwen-style ``<tool_call>...</tool_call>`` blocks from ``text``.

    Each block body is decoded with json5 and must be an object with ``name`` and a dict
    ``arguments``. Malformed blocks are reported with a printed warning and skipped.

    Returns:
        ``(calls, remaining_text)``: OpenAI-style tool-call dicts (fresh ``call_<uuid4>`` id,
        ``arguments`` re-serialized as a JSON string) and the text with every block removed
        and stripped. If there are no blocks, ``([], text)`` with ``text`` unchanged.
    """
    raw_calls = TOOL_CALL_PATTERN.findall(text)
    if len(raw_calls) == 0:
        return [], text

    remaining_text = TOOL_CALL_PATTERN.sub("", text).strip()

    calls: List[Dict[str, Any]] = []
    for raw in raw_calls:
        payload = raw.strip()
        try:
            decoded = json5.loads(payload)

            has_required_keys = isinstance(decoded, dict) and "name" in decoded and "arguments" in decoded
            if not has_required_keys:
                print(f"Warning: Parsed tool call object is missing required keys or has wrong format. Content: {payload}")
                continue

            call_args = decoded["arguments"]
            if not isinstance(call_args, dict):
                raise ValueError(f"'arguments' field must be a dictionary, but got {type(decoded['arguments'])}.")

            calls.append({
                "id": f"call_{str(uuid.uuid4())}",
                "type": "function",
                "function": {
                    "name": decoded["name"],
                    "arguments": json.dumps(call_args),
                },
            })
        except Exception as err:
            print(f"Warning: Failed to parse or format tool call. Error: {err}. Content: {payload}")

    return calls, remaining_text

def text_dump(response):
    """Render the first choice of a chat-completion response as raw ``<think>``/``<tool_call>`` text.

    Args:
        response: object exposing ``choices[0].message`` with ``reasoning_content``,
            ``content`` and ``tool_calls`` (at most one tool call is allowed).

    Returns:
        ``(text, tool_calls_list)``. ``text`` is ``'<think>\\n' + reasoning + '</think>\\n\\n' + content``,
        followed by a ``<tool_call>`` block when a tool call is present. ``tool_calls_list`` is a
        one-element OpenAI-style list, or ``None`` when there is no tool call.
    """
    message = response.choices[0].message
    # Either field may be None in OpenAI-style messages (e.g. content on a pure tool-call
    # turn); render a missing field as an empty string instead of failing on concatenation.
    reasoning = message.reasoning_content or ''
    body = message.content or ''
    content = '<think>\n' + reasoning + '</think>\n\n' + body

    tool_calls = message.tool_calls
    assert tool_calls is None or len(tool_calls) <= 1, 'Multiple tool calls in a single turn are not allowed.'

    if not tool_calls:
        return content, None

    call = tool_calls[0]
    function = call.function
    # Arguments arrive as a string; parse leniently and re-serialize as canonical JSON.
    arguments = json5.loads(function.arguments)
    call_format = {
        "name": function.name,
        "arguments": arguments
    }
    tool_content = f"\n<tool_call>\n{json.dumps(call_format)}\n</tool_call>"

    tool_call_format = {
        "id": call.id,
        "type": "function",
        "function": {
            "name": function.name,
            "arguments": json.dumps(arguments)
        }
    }
    return content + tool_content, [tool_call_format]

def check_report_action(text, tool_calls):
    """Validate the visible (non-reasoning) part of an agent turn.

    Args:
        text: the response text with the reasoning block removed.
        tool_calls: the tool calls parsed from the response.

    Returns:
        ``(ok, reason)``. The turn is rejected if ``parse_text_output`` fails (its message is
        returned), if neither an action nor an answer tag is present, or, when there is no
        answer, if there is not exactly one tool call. Otherwise ``(True, 'Success')``.
    """
    parsed_ok, parsed = parse_text_output(text)
    if not parsed_ok:
        # on failure parse_text_output returns an error string instead of a dict
        return False, parsed

    if parsed['action'] is None and parsed['answer'] is None:
        return False, 'Neither action nor answer tag found.'

    answer = parsed['answer']
    if answer is None:
        call_count = len(tool_calls)
        if call_count == 0:
            return False, 'No answer and no tool call detected.'
        if call_count > 1:
            return False, 'Multiple tool calls in a single turn are not allowed.'

    return True, 'Success'

async def generate(args, sample: Sample, sampling_params) -> List[Sample]:
    state = GenerateState(args)

    url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}/generate"
    tool_parser = FunctionCallParser(pydantic_tools, "qwen25")
    reasoning_parser = ReasoningParser('qwen3')

    assert (
        sample.status == Sample.Status.PENDING or sample.status == Sample.Status.ABORTED
    ), f"Sample status is {sample.status}"

    async def sglang_llm_caller(messages: list, params: dict, use_tools: bool, max_retry: int=3) -> dict | None:
        prompt_text = state.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True,
            tools=AGENT_TOOLS_SCHEMA if use_tools else None, 
        )

        # initialize with prompt tokens
        input_token_ids = state.tokenizer(prompt_text, add_special_tokens=False)["input_ids"]

        params['max_new_tokens'] = params['max_new_tokens'] - len(input_token_ids) - 128

        if params['max_new_tokens'] < 0:
            return {
                "input_tokens": input_token_ids,
                "text": None,
                "output_ids": None,
                "meta_info": {"finish_reason": {"type": "exceeded_max_length"}},
            }
        
        # Prepare payload - shared structure
        payload = {
            "input_ids": input_token_ids,
            "sampling_params": params,
            "return_logprob": True,
        }
        
        output = await post(url, payload, use_http2=args.use_http2)

        # Extract new response tokens
        if "output_token_logprobs" in output["meta_info"]:
            new_response_tokens = [item[1] for item in output["meta_info"]["output_token_logprobs"]]
        else:
            # abort
            new_response_tokens = []

        output['input_tokens'] = input_token_ids
        output['new_response_tokens'] = new_response_tokens

        # Extract rollout log probabilities for off-policy correction
        if args.use_tis:
            new_response_log_probs = [item[0] for item in output["meta_info"]["output_token_logprobs"]]
            output['rollout_log_probs'] = new_response_log_probs

        if output['meta_info']["finish_reason"]["type"] in ["abort"]:
            return output

        # parse reasoning content and tool calls
        no_tool_text, tool_call_list = tool_parser.parse_non_stream(output['text'])
        reasoning_text, no_think_text = reasoning_parser.parse_non_stream(output['text'])
        tool_call_list = [dict(tool_call) for tool_call in tool_call_list]

        output['parse_info'] = {
            "reasoning_text": reasoning_text,
            "no_think_text": no_think_text,
            "tool_calls": tool_call_list,
        }

        # format check
        if len(reasoning_text) == 0 or len(no_think_text) == 0:
            output['meta_info']['finish_reason']['type'] = "format_error"
            print("Invalid LLM response: 'reasoning_content' or 'content' field is missing or empty.", flush=True)

        is_success, reason = check_report_action(no_think_text, tool_call_list)
        if not is_success:
            output['meta_info']['finish_reason']['type'] = "format_error"
            print(f"LLM response format check failed: {reason}", flush=True)
        
        return output
    
    agent = ContextAgent(
        question=sample.prompt,
        sampling_params=sampling_params.copy(),
        llm_caller=sglang_llm_caller,
    )
    
    conversations, traj_status = await agent.rollout()
    stop_reason = traj_status['status']

    if len(conversations) == 0 or (stop_reason in ['abort']):
        sample.status = Sample.Status.ABORTED
        return [sample]

    else:
        final_content = conversations[-1][-1]["content"] if stop_reason not in ["exceeded_max_length", "length", "format_error"] else ''

        sample_list = []
        for idx, mess in enumerate(conversations):
            response = mess[1]['content'] if mess[0]['role'] == "user" else mess[2]['content']
            assert mess[0]['role'] == "user", f"First message must be user for `token-in-token-out`, but got {mess[0]['role']}"

            cur_sample = Sample(
                index=sample.index,
                prompt=mess[:1],
                tokens=mess[0]['tokens'] + mess[1]['tokens'],
                response=response,
                response_length=len(mess[1]['tokens']),
                label=sample.label,
                rollout_log_probs=mess[1]['rollout_log_probs'] if args.use_tis else None,

                metadata=sample.metadata.copy(),
                tool_status=True if mess[-1]['role'] == 'tool' and mess[-1].get("tool_status", False) else False
            )
            # sample group info
            cur_sample.metadata['group_idx'] = idx
            cur_sample.metadata['group_size'] = len(conversations)
            cur_sample.metadata['final_content'] = final_content
            cur_sample.metadata['traj_status'] = stop_reason

            match stop_reason:
                case "length":
                    cur_sample.status = Sample.Status.TRUNCATED
                case "abort":
                    cur_sample.status = Sample.Status.ABORTED
                case "stop":
                    cur_sample.status = Sample.Status.COMPLETED
                case "format_error":
                    cur_sample.status = Sample.Status.FORMATERROR
                case "exceeded_max_length":
                    cur_sample.status = Sample.Status.TRUNCATED
            
            sample_list.append(cur_sample)

        if stop_reason in ['format_error']:
            # 只保留最后一个format error的
            sample_list = sample_list[-1:]

    return sample_list


async def generate_and_rm(args, sample: Sample, sampling_params: dict, evaluation=False) -> List[Sample]:
    """Generate one trajectory for ``sample`` and attach a trajectory-level reward to every turn.

    Args:
        args: global args (``custom_generate_function_path``, ``group_rm``, ...).
        sample: prompt sample; COMPLETED/TRUNCATED samples are returned unchanged.
        sampling_params: sglang sampling parameters for this sample.
        evaluation: when True, an ABORTED generation is still scored instead of being
            returned as the aborted input sample.

    Returns:
        The per-turn samples of the trajectory, all sharing one reward. That reward is ``-1.0``
        for TRUNCATED/FORMATERROR trajectories; otherwise ``async_rm`` scores the final answer.
        ``[sample]`` marked ABORTED when the global state is aborted or, outside evaluation,
        when generation was aborted.

    Raises:
        NotImplementedError: if ``args.group_rm`` is set (group reward models are not supported).
    """
    # For samples with existing response, check if they're complete
    if sample.status == Sample.Status.COMPLETED or sample.status == Sample.Status.TRUNCATED:
        assert sample.response and sample.reward is not None
        return [sample]

    state = GenerateState(args)

    # generate
    async with state.semaphore:
        if state.aborted:
            sample.status = Sample.Status.ABORTED
            return [sample]

        if args.custom_generate_function_path is not None:
            custom_generate_func = load_function(args.custom_generate_function_path)
            sample_list = await custom_generate_func(args, sample, sampling_params)
        else:
            sample_list = await generate(args, sample, sampling_params)

    assert sample_list[0].status != Sample.Status.PENDING, "Generated samples should not be in PENDING status."
    if not evaluation and sample_list[0].status in [Sample.Status.ABORTED]:
        sample.status = Sample.Status.ABORTED
        return [sample]

    # Group reward models would need the whole group and are not implemented. Raise explicitly:
    # an `assert False` disappears under `python -O` and would let an unrewarded sample through.
    if args.group_rm:
        raise NotImplementedError("group rm is not supported yet")

    if sample_list[0].status in [Sample.Status.TRUNCATED, Sample.Status.FORMATERROR]:
        # fixed penalty; the reward model is not consulted
        traj_reward = -1.0
    else:
        # score the trajectory once, on its final answer
        group_reward_sample = Sample(
            index=sample.index,
            prompt=sample.prompt,
            response=sample_list[0].metadata.get('final_content', ''),
            label=sample.label,
            status=sample_list[0].status,
            metadata=sample_list[0].metadata,
        )
        traj_reward = await async_rm(args, group_reward_sample)

    for cur_sample in sample_list:
        cur_sample.reward = traj_reward

    return sample_list


async def generate_and_rm_group(args, group: list[Sample], sampling_params: dict, evaluation=False) -> list[list[Sample]]:
    """Run ``generate_and_rm`` concurrently for every prompt sample of one group.

    Each sample gets its own copy of ``sampling_params``. The normal result holds one
    ``list[Sample]`` (the trajectory's turns) per input sample, in input order.

    NOTE(review): if the global state is already aborted on entry, the input ``group`` (a flat
    ``list[Sample]``) is returned unchanged, so callers can see two different shapes.
    """
    state = GenerateState(args)
    if state.aborted:
        return group

    per_sample = [
        generate_and_rm(args, prompt_sample, sampling_params.copy(), evaluation=evaluation)
        for prompt_sample in group
    ]
    results = await asyncio.gather(*per_sample)

    # Reward models that need the whole group would run here, not in generate_and_rm.
    # NOTE(review): `results` holds list[Sample] entries, so setting `.reward` on them would
    # fail; this is unreachable while generate_and_rm rejects group_rm.
    if args.group_rm and not state.aborted:
        scores = await batched_async_rm(args, results)
        for item, score in zip(results, scores):
            item.reward = score

    return results


async def abort(args, rollout_id: int):
    """Abort every in-flight generation request and drain the pending rollout tasks.

    Sets the global aborted flag, which coroutines that check it before generating honour by
    returning early. Then asks each sglang worker listed by the router to abort all requests,
    cancels unfinished tasks and waits for all of them to settle.

    Args:
        args: global args (router address, ``use_http2``, ``partial_rollout``).
        rollout_id: id of the current rollout; recorded as ``start_rollout_id`` on samples
            that already have a response.

    Returns:
        With ``args.partial_rollout``: the contents of every pending task that finished without
        being cancelled. Otherwise an empty list.
    """
    aborted_samples = []

    state = GenerateState(args)
    assert not state.aborted
    state.aborted = True
    response = await get(
        f"http://{args.sglang_router_ip}:{args.sglang_router_port}/list_workers", use_http2=args.use_http2
    )

    # abort all the requests on every worker
    for url in response["urls"]:
        print(f"Abort request for {url}", flush=True)
        await post(f"{url}/abort_request", {"abort_all": True}, use_http2=False)

    for task in list(state.pendings):
        if not task.done():
            task.cancel()

    # make sure all the pending tasks are finished
    count = 0
    while state.pendings:
        done, state.pendings = await asyncio.wait(state.pendings, return_when=asyncio.FIRST_COMPLETED)

        if not args.partial_rollout:
            continue

        # for partial rollout, collect the partial samples into the data buffer
        for task in done:
            try:
                group = task.result()
            except asyncio.CancelledError:
                continue

            # A finished group is either list[list[Sample]] (one trajectory per prompt sample)
            # or, when generate_and_rm_group saw the aborted flag first, the flat list[Sample]
            # it was given. Tag samples in both shapes instead of assuming flat Samples.
            for entry in group:
                for sample in (entry if isinstance(entry, list) else [entry]):
                    if sample.response and "start_rollout_id" not in sample.metadata:
                        sample.metadata["start_rollout_id"] = rollout_id
            # NOTE(review): both shapes end up in `aborted_samples`; confirm which one the data
            # buffer's add_samples expects.
            aborted_samples += group
            count += len(group)

    if args.partial_rollout:
        print(f"Collected {count} partial samples into the data buffer", flush=True)

    return aborted_samples


async def generate_rollout_async(args, rollout_id: int, data_source) -> list[list[Sample]]:
    """Collect ``args.rollout_batch_size`` accepted prompt groups of agent trajectories.

    Keeps submitting prompt groups until enough are outstanding, then consumes finished
    groups. Each group is filtered (too few trainable trajectories, or the optional dynamic
    sampling filter) and optionally gets step-level reward shaping. Finally, leftover requests
    are aborted and pre-filter statistics are saved.

    Args:
        args: the whole args
        rollout_id: int, the id of the rollout; used to tag partial samples and to name the
            saved data file
        data_source: callable taking a count and returning that many prompt groups
            (``list[list[Sample]]``)

    Returns:
        tuple ``(data, aborted_samples)``. ``data`` is the flat list of trajectories (each a
        ``list[Sample]``) from ``args.rollout_batch_size`` accepted groups, sorted by the first
        sample's index. ``aborted_samples`` is what ``abort`` collected.
    """
    assert args.rollout_global_dataset

    state = GenerateState(args)

    # instantiate data filters
    dynamic_filter = (
        load_function(args.dynamic_sampling_filter_path) if args.dynamic_sampling_filter_path is not None else None
    )
    over_sampling_filter = (
        load_function(args.over_sampling_filter_path) if args.over_sampling_filter_path is not None else None
    )

    # target_data_size is the total number of valid groups to get
    target_data_size = args.over_sampling_batch_size if over_sampling_filter is not None else args.rollout_batch_size

    # per-trajectory statistics recorded before any filtering; saved at the end
    dapo_info = {
        "rollout_id": rollout_id,
        "reward": [], "truncated": [], "format_error": [],
    }
    sampled_data = []

    # each element is one accepted group: a list of trajectories (each a list[Sample])
    data: list[list[list[Sample]]] = []
    do_print = True
    pbar = tqdm(total=target_data_size * args.n_samples_per_prompt, desc="Rollout generation")
    # trajectory statuses that are kept for training
    train_traj_status = [Sample.Status.COMPLETED, Sample.Status.TRUNCATED, Sample.Status.FORMATERROR]
    while len(data) < target_data_size:
        while state.remaining_batch_size < target_data_size:
            # get samples from the buffer and submit the generation requests.
            samples = data_source(args.over_sampling_batch_size)
            state.submit_generate_tasks(samples)

        # wait for the generation to finish
        done, state.pendings = await asyncio.wait(state.pendings, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            # one entry per prompt sample: the list[Sample] turns of its trajectory
            group = task.result()

            # raw data and statistics before filtering
            sampled_data.append([convert_samples_to_data(sample_group) for sample_group in group])
            for sample_group in group:
                dapo_info['reward'].append(sample_group[0].reward if sample_group[0].reward is not None else 0)
                dapo_info['truncated'].append(1 if sample_group[0].status == Sample.Status.TRUNCATED else 0)
                dapo_info['format_error'].append(1 if sample_group[0].status == Sample.Status.FORMATERROR else 0)

            if do_print:
                prompt_text = state.tokenizer.apply_chat_template(
                    group[0][0].prompt,
                    tokenize=False,
                    add_generation_prompt=True,
                    enable_thinking=True,
                )
                print(
                    f"First rollout sample: {[prompt_text + group[0][0].response]}, reward: {group[0][0].reward}",
                    flush=True,
                )
                do_print = False

            assert len(group) == args.n_samples_per_prompt

            # Drop the whole group when under 80% of its trajectories are trainable;
            # otherwise drop only the untrainable trajectories.
            success_traj = sum(traj[0].status in train_traj_status for traj in group)
            if success_traj != len(group):
                if success_traj / len(group) < 0.8:
                    print(f"{success_traj=}, {len(group)=}, so filter out this group | {len(data)=}")
                    state.remaining_batch_size -= 1
                    continue
                group = [traj for traj in group if traj[0].status in train_traj_status]

            flatten_group = [sample for traj in group for sample in traj]
            if dynamic_filter is not None and not dynamic_filter(args, flatten_group):
                print(f"dynamically filter out this group, all {flatten_group[0].reward=} | {len(flatten_group)=} | {len(data)=}")
                state.remaining_batch_size -= 1
                continue

            # reward shaping for step RL
            if args.step_reward_shaping:
                for traj_idx, conv in enumerate(group):
                    for conv_sample in conv:
                        conv_sample.metadata['raw_reward'] = conv_sample.reward
                    shaped = reward_shaping(args, conv, conv[-1].reward, args.step_reward_shaping_gamma)
                    # Store the result back into the group: rebinding only the loop variable
                    # would silently drop a returned list. A None return is taken to mean the
                    # trajectory was shaped in place.
                    if shaped is not None:
                        group[traj_idx] = shaped

            # add the samples to the data
            # NOTE: here we have not stored all the unused samples back to the data buffer.
            if len(data) < target_data_size:
                data.append(group)
                pbar.update(args.n_samples_per_prompt)

    pbar.close()
    prompt_text = state.tokenizer.apply_chat_template(
        data[-1][0][-1].prompt,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=True,
    )
    print(f"Finish rollout: {[prompt_text + data[-1][0][-1].response]}, reward: {data[-1][0][-1].reward}", flush=True)

    # there are still some unfinished requests, abort them
    aborted_samples = await abort(args, rollout_id)

    if over_sampling_filter is not None:
        data = over_sampling_filter(args, data)[: args.rollout_batch_size]
    assert len(data) == args.rollout_batch_size, f"Got {len(data)} samples, expected {args.rollout_batch_size}"

    # flatten the groups into a list of trajectories, ordered by prompt index
    data = [traj for group in data for traj in group]
    data = sorted(data, key=lambda traj: traj[0].index)

    # save sampled data and info before filter
    save_info_and_data(args, rollout_id, dapo_info, sampled_data)

    # reset the global state to prevent effects on the next rollout or eval.
    state.reset()
    return data, aborted_samples


# Cache of loaded eval datasets, keyed by dataset name.
EVAL_PROMPT_DATASET = {}


async def eval_rollout(args, rollout_id):
    """Evaluate every dataset listed in ``args.eval_prompt_data``.

    ``args.eval_prompt_data`` is a flat list ``[name0, path0, name1, path1, ...]``.

    Returns:
        tuple ``(results, [])``: the merged result dicts of every dataset and an empty list
        of aborted samples.
    """
    assert not args.group_rm, "Group RM is not supported for eval rollout"
    specs = args.eval_prompt_data
    merged = {}
    for start in range(0, len(specs), 2):
        dataset_name, dataset_path = specs[start : start + 2]
        per_dataset = await eval_rollout_single_dataset(args, rollout_id, dataset_name, dataset_path)
        merged.update(per_dataset)
    return merged, []


async def eval_rollout_single_dataset(args, rollout_id, name, path):
    """Generate and score ``args.n_samples_per_eval_prompt`` trajectories per prompt of one eval dataset.

    Args:
        args: the whole args
        rollout_id: int, the id of the rollout (not used by this function)
        name: str, the name of the dataset; the cache key, and the fallback label when a
            sample has no ``data_source`` metadata
        path: str, the path of the dataset

    Returns:
        dict mapping each data source to ``{"rewards": [...], "truncated": [...]}``, built from
        the last sample of each trajectory in index order.
    """
    assert not args.group_rm, "Group RM is not supported for eval rollout"

    global EVAL_PROMPT_DATASET

    if name not in EVAL_PROMPT_DATASET:
        tokenizer = AutoTokenizer.from_pretrained(args.hf_checkpoint, trust_remote_code=True)
        EVAL_PROMPT_DATASET[name] = Dataset(
            path,
            tokenizer=tokenizer,
            max_length=args.rollout_max_prompt_len,
            prompt_key=args.input_key if args.eval_input_key is None else args.eval_input_key,
            label_key=args.label_key if args.eval_label_key is None else args.eval_label_key,
            metadata_key=args.metadata_key,
            tool_key=args.tool_key if args.eval_tool_key is None else args.eval_tool_key,
            apply_chat_template=args.apply_chat_template,
        )
    dataset = EVAL_PROMPT_DATASET[name]

    # eval_* settings override the corresponding rollout_* settings when given
    sampling_params = dict(
        temperature=args.rollout_temperature if args.eval_temperature is None else args.eval_temperature,
        top_p=args.rollout_top_p if args.eval_top_p is None else args.eval_top_p,
        top_k=args.rollout_top_k if args.eval_top_k is None else args.eval_top_k,
        max_new_tokens=(
            args.rollout_max_response_len if args.eval_max_response_len is None else args.eval_max_response_len
        ),
        stop=args.rollout_stop,
        stop_token_ids=args.rollout_stop_token_ids,
        skip_special_tokens=args.rollout_skip_special_tokens,
        no_stop_trim=True,
        spaces_between_special_tokens=False,
    )

    tasks = []
    # do multiple samples for eval prompts, each with a unique running index
    sample_index = 0
    for prompt_sample in dataset.samples:
        for _ in range(args.n_samples_per_eval_prompt):
            sample = copy.deepcopy(prompt_sample)
            sample.index = sample_index
            sample_index += 1
            tasks.append(
                generate_and_rm(
                    args,
                    sample,
                    sampling_params=sampling_params.copy(),
                    evaluation=True,
                )
            )

    data = []
    do_print = True
    pbar = tqdm(total=len(tasks), desc="Rollout generation", disable=not do_print)
    for coro in asyncio.as_completed(tasks):
        # one trajectory: the list[Sample] returned by generate_and_rm
        group: List[Sample] = await coro
        if do_print:
            print(group[-1].response,  f"\nreward: {group[-1].reward}", flush=True)
            do_print = False
        data.append(group)
        pbar.update(1)
    pbar.close()

    data.sort(key=lambda traj: traj[0].index)
    # The eval-specific key takes precedence, consistent with the eval_* overrides above.
    reward_key = args.eval_reward_key or args.reward_key
    eval_result = {}
    for item in data:
        last = item[-1]
        data_source = item[0].metadata.get('data_source', name)
        bucket = eval_result.setdefault(data_source, {"rewards": [], "truncated": []})
        reward = last.reward
        # Penalty rewards (-1.0 for truncated/format-error trajectories) are plain numbers even
        # when the reward model returns a dict, so only index dict rewards by the key.
        if reward_key and isinstance(reward, dict):
            reward = reward[reward_key]
        bucket["rewards"].append(reward)
        bucket["truncated"].append(last.status == Sample.Status.TRUNCATED)

    return eval_result



# TODO remove this temp function
def context_generate_rollout(args, rollout_id, data_buffer, evaluation=False):
    """Produce one training rollout (or run evaluation) and return the completed results.

    Args:
        args: the whole args
        rollout_id: int, the id of the rollout
        data_buffer: provides ``get_samples`` (the prompt source) and ``add_samples``, which
            receives the samples collected while aborting leftover requests
        evaluation: bool, whether the rollout is for evaluation or not

    Returns:
        For training, the trajectories produced by ``generate_rollout_async``. For evaluation,
        the results dict produced by ``eval_rollout``.
    """
    source_fn = data_buffer.get_samples
    outcome = generate_abortable_samples(args, rollout_id, source_fn, evaluation=evaluation)
    completed, leftovers = outcome
    data_buffer.add_samples(leftovers)
    return completed


def generate_abortable_samples(args, rollout_id, data_source, evaluation=False):
    """Synchronously run the eval rollout or the training rollout coroutine via ``run``.

    Returns the ``(results, aborted_samples)`` tuple of whichever coroutine was run.
    """
    assert args.rollout_global_dataset
    coro = eval_rollout(args, rollout_id) if evaluation else generate_rollout_async(args, rollout_id, data_source)
    return run(coro)
