# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Generate responses given a dataset of prompts
"""

import os

import hydra
import numpy as np
import ray

# Process-wide environment defaults, assigned before the imports that follow.
# NOTE(review): presumably placed here so that libraries imported below see
# these values at import time — confirm before moving these lines.
os.environ["NCCL_DEBUG"] = "WARN"
os.environ["TOKENIZERS_PARALLELISM"] = "true"
# os.environ['TORCH_COMPILE_DISABLE'] = '1'

from pprint import pprint

import pandas as pd
from omegaconf import OmegaConf, open_dict

from verl import DataProto
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from verl.utils import hf_tokenizer
from verl.utils.fs import copy_to_local
from verl.utils.hdfs_io import makedirs
from verl.utils.model import compute_position_id_with_mask
from verl.workers.fsdp_workers import ActorRolloutRefWorker
from verl.utils.device import is_cuda_available

import json
import math
import re
import random

@hydra.main(config_path="config", config_name="generation", version_base=None)
def main(config) -> None:
    """Hydra entry point.

    Hydra loads the ``generation`` config from the ``config`` directory and
    passes it in as ``config``; all the work is delegated to
    ``run_generation``.
    """
    run_generation(config)

# System message for the pairwise judge: lists the evaluation rules and asks
# for a verdict written as \boxed{Assistant 1} or \boxed{Assistant 2}.
# NOTE(review): rule (4) talks about "Response A and Response B" while the rest
# of the prompt says "Assistant 1" / "Assistant 2" — consider aligning the
# wording (changing it alters the text the judge model sees).
Skywork_SYSTEM_PROMPT = """You are a helpful assistant in evaluating the quality of the responses for a given instruction. Your goal is to select the best response for the given instruction.
Select Assistant 1 or Assistant 2, that is better for the given instruction. The two responses are generated by two different AI assistants respectively.
Do NOT say both / neither are good.

Here are some rules of the evaluation:
(1) If the instruction does not contain harmful content, you should prioritize evaluating whether the output honestly/precisely/closely executes the instruction, then consider its helpfulness, accuracy, level of detail, harmlessness, etc.
(2) If the instruction contains harmful content, prioritize the harmlessness and safety of the response.
(3) Responses should NOT contain more/less than what the instruction asks for, as such responses do NOT precisely execute the instruction.
(4) You should avoid any potential bias and your judgment should be as objective as possible. Here are some potential sources of bias:
- The order in which the responses were presented should NOT affect your judgment, as Response A and Response B are equally likely to be the better.
- The length of the responses should NOT affect your judgement, as a longer response does not necessarily correspond to a better response. When making your decision, evaluate if the response length is appropriate for the given instruction.
(5) Your output should only consist of '\\boxed{Assistant 1}' if assistant 1 is better, or '\\boxed{Assistant 2}' if assistant 2 is better. Omit any other output.

"""

# User-message template, filled with str.format(); the placeholders are
# {question}, {answer1} and {answer2}.
Skywork_PROMPT = """## Query

{question}

## Assistant responses

### Assistant 1

{answer1}


### Assistant 2

{answer2}

"""

# Text appended to the formatted user message. It is concatenated AFTER
# str.format() has run, so the literal braces in \boxed{...} need no escaping.
Skywork_ASSISTANT_PROMPT = """## Analysis

Let's analyze this step by step and decide which assistant is better, and then answer \\boxed{Assistant 1} or \\boxed{Assistant 2}."""



# Matches ``\boxed{...}`` and captures the (non-greedy) text between the braces.
# Compiled once at import time instead of on every call.
_BOXED_VERDICT_RE = re.compile(r"\\boxed\{(.*?)\}")


def extract_verdict(text):
    """Return the judge's final ``\\boxed{...}`` verdict found in ``text``.

    The judge is asked to analyse first and answer at the end, so its output
    may contain several ``\\boxed{...}`` spans (for example when it repeats
    both options before deciding). The LAST span is therefore treated as the
    verdict, and surrounding whitespace inside the braces is removed so that
    ``\\boxed{ Assistant 1 }`` still compares equal to ``"Assistant 1"``.

    Args:
        text: Decoded judge output. ``None`` or an empty string is accepted.

    Returns:
        The stripped content of the last ``\\boxed{...}`` span, or ``None``
        when ``text`` is empty or contains no such span.
    """
    if not text:
        return None
    verdicts = _BOXED_VERDICT_RE.findall(text)
    if not verdicts:
        return None
    return verdicts[-1].strip()


def run_generation(config) -> None:
    """Run ``main_task`` on Ray and block until it has finished.

    If no Ray runtime is initialized yet, a local one is started with
    ``config.ray_init.num_cpus`` CPUs and the tokenizer / NCCL environment
    variables forwarded through ``runtime_env``.
    """
    ray_already_up = ray.is_initialized()
    if not ray_already_up:
        # this is for local ray cluster
        worker_env = {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}
        ray.init(runtime_env={"env_vars": worker_env}, num_cpus=config.ray_init.num_cpus)

    pending_task = main_task.remote(config)
    ray.get(pending_task)


def _load_dataset(path):
    """Read the input dataset at ``path`` into a pandas DataFrame.

    ``.parquet`` files are read with ``pd.read_parquet``; ``.jsonl`` files are
    parsed line by line, skipping blank lines.

    Raises:
        ValueError: If ``path`` ends with neither ``.parquet`` nor ``.jsonl``.
    """
    if path.endswith(".parquet"):
        return pd.read_parquet(path)
    if path.endswith(".jsonl"):
        # Context manager so the file handle is closed deterministically.
        with open(path, encoding="utf-8") as jsonl_file:
            records = [json.loads(line) for line in jsonl_file if line.strip()]
        return pd.DataFrame(records)
    raise ValueError(f"Unsupported dataset format (expected .parquet or .jsonl): {path}")


def _to_list(value):
    """Return ``value`` unchanged if it is a list, otherwise ``value.tolist()``."""
    return value if isinstance(value, list) else value.tolist()


def _strip_thinking(response):
    """Return everything after the first ``</think>`` tag, or ``response`` if there is none."""
    _, tag, answer = response.partition("</think>")
    return answer if tag else response


def _extract_question(chat, prompt_idx):
    """Return the stripped content of the last user turn in ``chat`` that has string content.

    Raises:
        ValueError: If ``chat`` has no such turn. Failing here prevents a
            question from another prompt being used by accident.
    """
    question = None
    for turn in chat:
        if turn.get("role") == "user" and isinstance(turn.get("content"), str):
            question = turn["content"].strip()
    if question is None:
        raise ValueError(f"Prompt {prompt_idx} has no user turn with string content.")
    return question


def _build_judge_conversation(question, response_a, response_b):
    """Build the system + user chat that asks the judge to compare two responses."""
    user_prompt = (
        Skywork_PROMPT.format(question=question, answer1=response_a, answer2=response_b)
        + Skywork_ASSISTANT_PROMPT
    )
    return [
        {"role": "system", "content": Skywork_SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
    ]


def _pick_winner(judge_output, index_a, index_b):
    """Map the judge's verdict to a response index; pick at random when the verdict is unusable."""
    verdict = extract_verdict(judge_output)
    if verdict == "Assistant 1":
        return index_a
    if verdict == "Assistant 2":
        return index_b
    return random.choice([index_a, index_b])


@ray.remote(num_cpus=1)
def main_task(config):
    """Run a knockout tournament over each prompt's candidate responses.

    For every prompt, the first ``responses_size`` responses are split into
    consecutive groups of ``config.rollout.n_candidates``. In each round the
    members of a group are paired up, the judge model compares each pair, and
    the winners advance until one champion per group is left. The result is
    stored per prompt as ``knockout_flags`` (a list of booleans, ``True`` at
    the index of every group champion) and written, together with the
    ``scores`` column, next to the input file as
    ``<input stem>_@<n_candidates>_<model name><extension>``.

    Args:
        config: Hydra/OmegaConf config. Fields read here: ``model.path``,
            ``data.path``, ``data.prompt_key``, ``data.response_key``,
            ``data.n_samples``, optional ``data.trust_remote_code``,
            ``rollout.temperature``, ``rollout.n_candidates``,
            ``rollout.prompt_length``, optional ``rollout.responses_size``
            (defaults to the number of responses of the first prompt),
            ``trainer.n_gpus_per_node`` and ``trainer.nnodes``.

    Raises:
        ValueError: Unsupported dataset extension, empty dataset without an
            explicit ``responses_size``, ``n_candidates`` that is not a power
            of two, a prompt with too few responses, or a prompt without a
            user turn.
        KeyError: A required dataset column is missing.
        RuntimeError: The number of generated judge outputs differs from the
            number of matches submitted.
    """
    pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values
    OmegaConf.resolve(config)

    local_path = copy_to_local(config.model.path)
    trust_remote_code = config.data.get("trust_remote_code", False)
    tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)

    if config.rollout.temperature == 0.0:
        assert config.data.n_samples == 1, "When temperature=0, n_samples must be 1."
    assert config.data.n_samples >= 1, "n_samples should always >= 1"

    # Read the dataset. Prompts are expected in chat format (a list of dicts).
    dataset = _load_dataset(config.data.path)

    # Check every column that is needed up front, so a missing one is reported
    # before any model is loaded rather than after all generation has finished.
    required_columns = (config.data.prompt_key, config.data.response_key, "scores")
    missing_columns = [column for column in required_columns if column not in dataset.columns]
    if missing_columns:
        raise KeyError(f"Dataset {config.data.path} is missing required columns: {missing_columns}")

    prompt_lst = [_to_list(chat) for chat in dataset[config.data.prompt_key].tolist()]
    responses_lst = [_to_list(responses) for responses in dataset[config.data.response_key].tolist()]

    # Tournament geometry, validated before the worker group is created.
    n_candidates = config.rollout.n_candidates
    responses_size = config.rollout.get("responses_size", None)
    if responses_size is None:
        if not responses_lst:
            raise ValueError("Dataset is empty; cannot infer rollout.responses_size.")
        responses_size = len(responses_lst[0])

    print(f"group size {n_candidates}")
    print(f"total size {responses_size}")

    # Pairing participants two by two only works when every round halves the
    # group exactly, i.e. when the group size is a power of two.
    if n_candidates < 1 or n_candidates & (n_candidates - 1):
        raise ValueError(f"rollout.n_candidates must be a power of two, got {n_candidates}.")
    assert responses_size % n_candidates == 0, "responses_size must be a multiple of n_candidates"
    for prompt_idx, responses in enumerate(responses_lst):
        if len(responses) < responses_size:
            raise ValueError(
                f"Prompt {prompt_idx} has {len(responses)} responses, expected at least {responses_size}."
            )
    num_rounds = n_candidates.bit_length() - 1  # exact log2 for a power of two

    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role="rollout")
    resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes)
    wg = RayWorkerGroup(
        resource_pool=resource_pool,
        ray_cls_with_init=ray_cls_with_init,
        device_name="cuda" if is_cuda_available else "npu",
    )
    wg.init_model()

    # Last path component; a trailing "/" would otherwise yield an empty name.
    model_name = config.model.path.rstrip("/").split("/")[-1]

    # The question text is the same for every match of a prompt, so extract it
    # once. Skipped when there are no rounds (n_candidates == 1).
    questions = [_extract_question(chat, idx) for idx, chat in enumerate(prompt_lst)] if num_rounds else []

    # active_participants[prompt_idx][group_idx] is the list of response
    # indices still competing in that group.
    active_participants = [
        [list(range(start, start + n_candidates)) for start in range(0, responses_size, n_candidates)]
        for _ in responses_lst
    ]

    for round_num in range(num_rounds):
        print(f"[{round_num+1}/{num_rounds}] knockout tournament to start.")

        # 1. Collect every match of this round. match_metadata[i] records
        #    where batch_prompts[i] came from so its result can be mapped back.
        batch_prompts = []
        match_metadata = []
        for prompt_idx, groups in enumerate(active_participants):
            for group_idx, participants in enumerate(groups):
                # Pair neighbours: (0, 1), (2, 3), ...
                for index_a, index_b in zip(participants[::2], participants[1::2]):
                    # Only the text after </think> is shown to the judge.
                    response_a = _strip_thinking(responses_lst[prompt_idx][index_a])
                    response_b = _strip_thinking(responses_lst[prompt_idx][index_b])
                    batch_prompts.append(
                        _build_judge_conversation(questions[prompt_idx], response_a, response_b)
                    )
                    match_metadata.append((prompt_idx, group_idx, index_a, index_b))

        if not batch_prompts:
            print("本轮没有比赛，锦标赛结束。")
            break

        # 2. Batched inference through the rollout worker group.
        total_samples = len(batch_prompts)
        print(f"{total_samples} samples to generate.")

        inputs = tokenizer.apply_chat_template(
            batch_prompts,
            add_generation_prompt=True,
            padding=True,
            truncation=True,
            max_length=config.rollout.prompt_length,
            return_tensors="pt",
            return_dict=True,
            tokenize=True,
        )

        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        position_ids = compute_position_id_with_mask(attention_mask)
        batch_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids}

        data = DataProto.from_dict(batch_dict)
        # Pad the batch so it divides evenly across workers; undone below.
        data_padded, pad_size = pad_dataproto_to_divisor(data, wg.world_size)

        output_padded = wg.generate_sequences(data_padded)
        output = unpad_dataproto(output_padded, pad_size=pad_size)
        output_texts = []
        for i in range(len(output)):
            data_item = output[i]
            prompt_length = data_item.batch["prompts"].shape[-1]
            # Number of attended tokens after the prompt = length of the response.
            valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum()
            valid_response_ids = data_item.batch["responses"][:valid_response_length]
            output_texts.append(tokenizer.decode(valid_response_ids, skip_special_tokens=True))

        # Results are mapped back to matches by position, so the counts must agree.
        if len(output_texts) != len(match_metadata):
            raise RuntimeError(
                f"Expected {len(match_metadata)} judge outputs but got {len(output_texts)}; "
                "results cannot be mapped back to matches."
            )

        # 3. Resolve every match and build next round's participant lists.
        next_round_participants = [[[] for _ in groups] for groups in active_participants]
        for (prompt_idx, group_idx, index_a, index_b), output_text in zip(match_metadata, output_texts):
            winner_index = _pick_winner(output_text, index_a, index_b)
            next_round_participants[prompt_idx][group_idx].append(winner_index)

        active_participants = next_round_participants
        print(f"第 {round_num + 1} 轮结束。")

    # After the last round every group holds exactly one index: its champion.
    knockout_flags = []
    for groups in active_participants:
        champions = {group[0] for group in groups}
        knockout_flags.append([idx in champions for idx in range(responses_size)])

    dataset["knockout_flags"] = knockout_flags
    selected_data = dataset[["scores", "knockout_flags"]]

    # Insert the suffix before the extension only; str.replace would also
    # rewrite any directory name that happens to contain the extension.
    stem, extension = os.path.splitext(config.data.path)
    out_file = f"{stem}_@{n_candidates}_{model_name}{extension}"
    if extension == ".parquet":
        selected_data.to_parquet(out_file)
    else:
        # ".jsonl" — the only other extension _load_dataset accepts.
        selected_data.to_json(out_file, orient="records", lines=True, force_ascii=False)

# Run the Hydra entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
