# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Generate responses given a dataset of prompts
"""

import os

import hydra
import numpy as np
import ray

os.environ["NCCL_DEBUG"] = "WARN"
os.environ["TOKENIZERS_PARALLELISM"] = "true"
# os.environ['TORCH_COMPILE_DISABLE'] = '1'

from pprint import pprint

import pandas as pd
from omegaconf import OmegaConf, open_dict

from verl import DataProto
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from verl.utils import hf_tokenizer
from verl.utils.fs import copy_to_local
from verl.utils.hdfs_io import makedirs
from verl.utils.model import compute_position_id_with_mask
from verl.workers.fsdp_workers import ActorRolloutRefWorker
from verl.utils.device import is_cuda_available

import json
import math
import re
import random

@hydra.main(config_path="config", config_name="generation", version_base=None)
def main(config):
    run_generation(config)

# Pairwise judge prompt for reasoning models, sent as a single user turn (default User/Assistant format).
# Placeholders: {question}, {answer_a}, {answer_b}. The judge is told to finish with either
# '<answer>[[A]]</answer>' or '<answer>[[B]]</answer>', which extract_verdict parses.
REASONING_SINGLE_PROMPT_TEMPLATE = (
    "Please act as an impartial judge and evaluate the quality of the responses provided by two AI Chatbots "
    "to the Client question displayed below. \n\n"
    "[Client Question]\n{question}\n\n"
    "[The Start of Chatbot A's Response]\n{answer_a}\n[The End of Chatbot A's Response]\n\n"
    "[The Start of Chatbot B's Response]\n{answer_b}\n[The End of Chatbot B's Response]\n\n"
    "Output your final verdict at last by strictly following this format: "
    "'<answer>[[A]]</answer>' if Chatbot A is better, or '<answer>[[B]]</answer>' if Chatbot B is better."
)


# One verdict tag such as "<answer>[[A]]</answer>"; whitespace inside the tag is tolerated.
_VERDICT_PATTERN = re.compile(r"<answer>\s*\[\[(A|B)\]\]\s*</answer>")


def extract_verdict(text):
    """Return the judge's final verdict, ``"A"`` or ``"B"``, parsed from *text*.

    The judge prompt asks for the verdict "at last", and its format instruction itself contains
    '<answer>[[A]]</answer>' followed by '<answer>[[B]]</answer>'. A judge that restates the
    instructions in its reasoning would make the first match always "A", so the last match wins.

    Args:
        text: decoded judge output; ``None`` or empty is accepted.

    Returns:
        ``"A"``, ``"B"``, or ``None`` when no verdict tag is present.
    """
    if not text:
        return None
    verdicts = _VERDICT_PATTERN.findall(text)
    return verdicts[-1] if verdicts else None


def run_generation(config) -> None:
    """Make sure a Ray runtime exists, then run ``main_task`` remotely and block until it finishes.

    A local Ray cluster is started only when Ray is not initialized yet. Its CPU budget comes from
    ``config.ray_init.num_cpus``.
    """
    if not ray.is_initialized():
        # Local cluster: forward the env vars this driver sets at import time to every worker.
        worker_env_vars = {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}
        ray.init(runtime_env={"env_vars": worker_env_vars}, num_cpus=config.ray_init.num_cpus)

    task_ref = main_task.remote(config)
    ray.get(task_ref)


def _load_dataset(path):
    """Load the prompt/response table from a ``.parquet`` or ``.jsonl`` file as a DataFrame.

    Raises:
        ValueError: if *path* has any other extension.
    """
    if path.endswith(".parquet"):
        return pd.read_parquet(path)
    if path.endswith(".jsonl"):
        # The context manager closes the handle; blank lines (e.g. a trailing newline) are skipped.
        with open(path, encoding="utf-8") as f:
            records = [json.loads(line) for line in f if line.strip()]
        return pd.DataFrame(records)
    raise ValueError(f"Unsupported dataset format (expected .parquet or .jsonl): {path}")


def _as_list(value):
    """Return *value* as a plain list; non-list sequences (e.g. numpy arrays) go through ``tolist()``."""
    return value if isinstance(value, list) else value.tolist()


def _extract_question(chat):
    """Return the stripped text of the last ``user`` turn with string content in *chat*, or "" if none."""
    question = ""
    for turn in chat:
        if turn.get("role") == "user" and isinstance(turn.get("content"), str):
            question = turn["content"].strip()
    return question


def _strip_thinking(response):
    """Return the part of *response* after the last ``</think>`` tag (the whole string if there is none)."""
    return response.split("</think>")[-1]


def _output_path(data_path, suffix, n_candidates, model_name):
    """Insert ``_@<n_candidates>_<model_name>`` before the trailing *suffix* of *data_path* only."""
    return f"{data_path[:-len(suffix)]}_@{n_candidates}_{model_name}{suffix}"


def _generate_texts(wg, tokenizer, conversations, max_prompt_length):
    """Generate one judge completion per chat in *conversations* with worker group *wg* and decode it.

    Returns:
        list of decoded response strings, in the same order as *conversations*.
    """
    # NOTE(review): truncation=True cuts over-long prompts on the tokenizer's truncation_side; if that is
    # "right", the verdict instructions and generation prompt are lost. Confirm prompt_length is large enough.
    inputs = tokenizer.apply_chat_template(
        conversations,
        add_generation_prompt=True,
        padding=True,
        truncation=True,
        max_length=max_prompt_length,
        return_tensors="pt",
        return_dict=True,
        tokenize=True,
    )
    attention_mask = inputs["attention_mask"]
    data = DataProto.from_dict(
        {
            "input_ids": inputs["input_ids"],
            "attention_mask": attention_mask,
            "position_ids": compute_position_id_with_mask(attention_mask),
        }
    )
    # Pad the batch to a multiple of the worker count, then drop the padding rows from the output.
    data_padded, pad_size = pad_dataproto_to_divisor(data, wg.world_size)
    output = unpad_dataproto(wg.generate_sequences(data_padded), pad_size=pad_size)

    texts = []
    for i in range(len(output)):
        item = output[i]
        prompt_length = item.batch["prompts"].shape[-1]
        # Assumes the returned attention_mask spans prompt + response, so its tail counts real response tokens.
        valid_response_length = item.batch["attention_mask"][prompt_length:].sum()
        texts.append(tokenizer.decode(item.batch["responses"][:valid_response_length], skip_special_tokens=True))
    return texts


@ray.remote(num_cpus=1)
def main_task(config):
    """Run a single-elimination (knockout) tournament over each row's responses with an LLM judge.

    Each row's responses are split into consecutive groups of ``rollout.n_candidates``. Within every
    group, adjacent survivors are judged pairwise for log2(n_candidates) rounds until one champion
    remains. The result is written next to the input file: a ``knockout_flags`` column (one bool per
    response, True for group champions), plus ``scores`` when the dataset has that column.

    Raises:
        ValueError: on an unsupported dataset extension, an empty dataset, invalid sampling settings,
            a tournament size that cannot be bracketed, or rows with too few responses.
        RuntimeError: if the number of judge outputs does not match the number of matches.
    """
    pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values
    OmegaConf.resolve(config)

    if config.rollout.temperature == 0.0 and config.data.n_samples != 1:
        raise ValueError("When temperature=0, n_samples must be 1.")
    if config.data.n_samples < 1:
        raise ValueError("n_samples should always >= 1")

    # The dataset must already contain chat-template formatted prompts (lists of role/content dicts).
    dataset = _load_dataset(config.data.path)
    prompt_lst = [_as_list(chat) for chat in dataset[config.data.prompt_key].tolist()]
    responses_lst = [_as_list(responses) for responses in dataset[config.data.response_key].tolist()]
    if not responses_lst:
        raise ValueError(f"No rows found in dataset {config.data.path}")

    with open_dict(config.rollout):
        if config.rollout.get("responses_size", None) is None:
            config.rollout.responses_size = len(responses_lst[0])

    n_candidates = config.rollout.n_candidates
    responses_size = config.rollout.responses_size
    print(f"group size {n_candidates}")
    print(f"total size {responses_size}")
    # Validate before starting workers: a bracket needs a power-of-two group, otherwise pairing
    # runs off the end of a group with an odd number of survivors.
    if n_candidates < 1 or n_candidates & (n_candidates - 1):
        raise ValueError(f"rollout.n_candidates must be a positive power of two, got {n_candidates}")
    if responses_size % n_candidates != 0:
        raise ValueError(f"responses_size ({responses_size}) must be a multiple of n_candidates ({n_candidates})")
    short_rows = [row_idx for row_idx, responses in enumerate(responses_lst) if len(responses) < responses_size]
    if short_rows:
        raise ValueError(
            f"{len(short_rows)} row(s) have fewer than {responses_size} responses (first: row {short_rows[0]})"
        )
    num_groups_per_prompt = responses_size // n_candidates
    num_rounds = int(math.log2(n_candidates))

    local_path = copy_to_local(config.model.path)
    trust_remote_code = config.data.get("trust_remote_code", False)
    tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role="rollout")
    resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes)
    wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, device_name="cuda" if is_cuda_available else "npu")
    wg.init_model()

    # basename of the stripped path also works for paths ending in "/".
    model_name = os.path.basename(config.model.path.rstrip("/"))

    # The question is fixed per prompt, so extract it once instead of once per match.
    questions = [_extract_question(chat) for chat in prompt_lst]

    # active_participants[prompt_idx][group_idx] -> response indices still alive in that group.
    active_participants = [
        [list(range(start, start + n_candidates)) for start in range(0, responses_size, n_candidates)]
        for _ in responses_lst
    ]

    for round_num in range(num_rounds):
        print(f"[{round_num+1}/{num_rounds}] knockout tournament to start.")

        # 1. Pair adjacent survivors in every group and build one judge prompt per match;
        #    match_metadata maps each judge output back to its prompt, group and pair.
        batch_prompts = []
        match_metadata = []
        for prompt_idx, groups_indices in enumerate(active_participants):
            responses = responses_lst[prompt_idx]
            for group_idx, participants_indices in enumerate(groups_indices):
                for index_a, index_b in zip(participants_indices[0::2], participants_indices[1::2]):
                    user_prompt = REASONING_SINGLE_PROMPT_TEMPLATE.format(
                        question=questions[prompt_idx],
                        answer_a=_strip_thinking(responses[index_a]),
                        answer_b=_strip_thinking(responses[index_b]),
                    )
                    batch_prompts.append([{"role": "user", "content": user_prompt}])
                    match_metadata.append(
                        {"prompt_idx": prompt_idx, "group_idx": group_idx, "pair_indices": (index_a, index_b)}
                    )

        if not batch_prompts:
            print("本轮没有比赛，锦标赛结束。")
            break

        # 2. Judge all matches of this round in one batched generation call.
        print(f"{len(batch_prompts)} samples to generate.")
        output_texts = _generate_texts(wg, tokenizer, batch_prompts, config.rollout.prompt_length)
        if len(output_texts) != len(match_metadata):
            raise RuntimeError(f"Expected {len(match_metadata)} judge outputs, got {len(output_texts)}")

        # 3. Advance each winner in match order; an unparsable verdict falls back to a random pick.
        next_round_participants = [[[] for _ in range(num_groups_per_prompt)] for _ in range(len(prompt_lst))]
        for meta, output_text in zip(match_metadata, output_texts):
            index_a, index_b = meta["pair_indices"]
            verdict = extract_verdict(output_text)
            if verdict == "A":
                winner_index = index_a
            elif verdict == "B":
                winner_index = index_b
            else:
                winner_index = random.choice([index_a, index_b])
            next_round_participants[meta["prompt_idx"]][meta["group_idx"]].append(winner_index)

        active_participants = next_round_participants
        print(f"第 {round_num + 1} 轮结束。")

    # After all rounds each group holds exactly one champion; flag those response indices per row.
    knockout_flags = []
    for prompt_groups in active_participants:
        champions = {group[0] for group in prompt_groups}
        knockout_flags.append([idx in champions for idx in range(responses_size)])
    dataset["knockout_flags"] = knockout_flags

    # Keep the "scores" column alongside the flags only when the dataset provides one.
    output_columns = [col for col in ("scores", "knockout_flags") if col in dataset.columns]
    selected_data = dataset[output_columns]

    data_path = config.data.path
    if data_path.endswith(".parquet"):
        out_file = _output_path(data_path, ".parquet", n_candidates, model_name)
        selected_data.to_parquet(out_file)
    elif data_path.endswith(".jsonl"):
        out_file = _output_path(data_path, ".jsonl", n_candidates, model_name)
        selected_data.to_json(out_file, orient="records", lines=True, force_ascii=False)

if __name__ == "__main__":
    main()
