# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
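"""
Generate responses for a dataset of prompts, score them, and write the results.

If ``data.candidate_path`` is set, previously generated candidate responses are first
merged into answer-synthesis prompts (see ``COT_SYNTHESIZER_PROMPT``) before generation.
Responses are written to a ``.jsonl`` file and, optionally, summary metrics to JSON.
"""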

import os

import hydra
import numpy as np
import ray

# Process-wide settings for the driver, assigned before the imports that follow
# (presumably so the libraries loaded below read them at import time — confirm).
# run_generation passes the same two values to Ray workers via runtime_env.
os.environ["NCCL_DEBUG"] = "WARN"
os.environ["TOKENIZERS_PARALLELISM"] = "true"
# os.environ['TORCH_COMPILE_DISABLE'] = '1'

from pprint import pprint

import pandas as pd
from omegaconf import OmegaConf, open_dict

from verl import DataProto
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from verl.utils import hf_tokenizer
from verl.utils.fs import copy_to_local
from verl.utils.hdfs_io import makedirs
from verl.utils.model import compute_position_id_with_mask
from verl.workers.fsdp_workers import ActorRolloutRefWorker
from verl.utils.device import is_cuda_available

from verl.workers.reward_manager.reward import rlcot_reward_fn
import json
import time

@hydra.main(config_path="config", config_name="generation", version_base=None)
def main(config):
    """Hydra entry point: hand the composed ``generation`` config to ``run_generation``."""
    return run_generation(config)

# Prompt template for the answer-synthesis pass (used only when data.candidate_path
# is set). ``{question}`` receives the content of the prompt's first user message and
# ``{responses}`` the concatenated, length-sorted candidate responses. It is filled
# with ``str.format``, so any literal braces added to this text must be doubled.
# NOTE(review): "Do not copy The candidate's answer" looks like a typo for
# "Do not copy the candidates' answers"; it is left verbatim because editing the
# prompt changes model inputs and breaks comparability with earlier runs.
COT_SYNTHESIZER_PROMPT = """
Please act as an excellent summarizer and summarize the following AI responses to the questions. Your summary should fully consider the connection between the question and AI responses, resulting in a correct, high-quality answer. In most cases, the same response that appears most often in the response may be the correct answer. If you find that there is no correct answer, please try to generate a correct answer yourself. Do not copy The candidate's answer, give your summarized answer and reasons, and give the correct answer at the end of the sentence in the format: The answer is...

[The Start of Original Question]
{question}
[The End of Original Question]

[The Start of AI Responses]
{responses}
[The End of AI Responses]
"""

def run_generation(config) -> None:
    """Run ``main_task`` on Ray and block until it finishes.

    A local Ray cluster is started first unless Ray is already initialized.

    Args:
        config: Generation config. ``config.ray_init.num_cpus`` sizes the local
            cluster; the whole config is forwarded to ``main_task``.
    """
    if not ray.is_initialized():
        # Same values the driver sets in os.environ, propagated to every Ray worker.
        worker_env_vars = {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}
        ray.init(runtime_env={"env_vars": worker_env_vars}, num_cpus=config.ray_init.num_cpus)

    task_ref = main_task.remote(config)
    ray.get(task_ref)


# Tokens left free when truncating a synthesized candidate prompt, presumably so the
# chat-template wrapper still fits within ``rollout.prompt_length`` — confirm.
_CHAT_TEMPLATE_TOKEN_RESERVE = 50


def _read_records(path):
    """Load a ``.parquet`` or ``.jsonl`` file into a DataFrame.

    Args:
        path: File path; its extension selects the reader.

    Returns:
        A ``pd.DataFrame`` with one row per record.

    Raises:
        ValueError: If ``path`` ends with neither ``.parquet`` nor ``.jsonl``.
    """
    if path.endswith(".parquet"):
        return pd.read_parquet(path)
    if path.endswith(".jsonl"):
        # Close the handle deterministically; skip blank lines (e.g. a trailing empty
        # line), which json.loads would reject.
        with open(path, encoding="utf-8") as f:
            return pd.DataFrame([json.loads(line) for line in f if line.strip()])
    raise ValueError(f"Unsupported data file {path!r}: expected a .parquet or .jsonl file.")


def _first_user_index(prompt):
    """Return the index of the first message in ``prompt`` whose role is ``"user"``.

    Raises:
        ValueError: If no message has role ``"user"`` (instead of silently
            overwriting some other message).
    """
    for idx, message in enumerate(prompt):
        if message["role"] == "user":
            return idx
    raise ValueError("Prompt has no message with role 'user'; cannot insert candidate responses.")


def _build_candidate_dataset(dataset, config, tokenizer):
    """Expand ``dataset`` into one answer-synthesis prompt per candidate group.

    Row ``i`` of ``config.data.candidate_path`` supplies, under
    ``config.data.response_key``, the candidate responses for row ``i`` of
    ``dataset``. That list is cut into ``config.rollout.candidate_turn`` consecutive
    groups of ``config.rollout.n_candidates``. For each group, the text after
    ``</think>`` is kept. Responses without that marker are kept verbatim unless
    ``rollout.no_solution_discard`` is truthy, in which case they are dropped. The
    group is then sorted shortest-first and formatted with ``COT_SYNTHESIZER_PROMPT``.
    The result is truncated on a token boundary and written into the first user
    message of a deep copy of the row.

    Returns:
        A new DataFrame with ``len(dataset) * candidate_turn`` rows: all turns of
        row 0, then all turns of row 1, and so on.

    Raises:
        AssertionError: On a row-count mismatch between the two files.
        ValueError: On an unsupported candidate file, a prompt with no user message,
            or a ``rollout.prompt_length`` that leaves no room after the reserve.
    """
    import copy
    import re

    candidate_dataset = _read_records(config.data.candidate_path)
    assert len(dataset) == len(candidate_dataset), f"Dataset length mismatch: {len(dataset)} != {len(candidate_dataset)}"
    candidate_responses_lst = [
        chat if isinstance(chat, list) else chat.tolist()
        for chat in candidate_dataset[config.data.response_key].tolist()
    ]

    prompt_key = config.data.prompt_key
    n_candidates = config.rollout.n_candidates
    candidate_turn = config.rollout.candidate_turn
    keep_unsolved = not config.rollout.get("no_solution_discard", False)
    max_prompt_tokens = config.rollout.prompt_length - _CHAT_TEMPLATE_TOKEN_RESERVE
    if max_prompt_tokens <= 0:
        # Otherwise the offset lookup below would use a negative index.
        raise ValueError(
            f"rollout.prompt_length={config.rollout.prompt_length} must exceed "
            f"{_CHAT_TEMPLATE_TOKEN_RESERVE} when data.candidate_path is set."
        )
    solution_pattern = re.compile(r"</think>\s*(.*)", re.DOTALL)

    new_rows = []
    # to_dict("records") keeps the real column names; itertuples()._asdict() renames
    # columns that are not valid Python identifiers.
    for row, response_list in zip(dataset.to_dict("records"), candidate_responses_lst):
        prompt = row[prompt_key]
        user_idx = _first_user_index(prompt)

        for turn in range(candidate_turn):
            group = response_list[turn * n_candidates : (turn + 1) * n_candidates]
            clipped = []
            for response in group:
                match = solution_pattern.search(response)
                if match:
                    clipped.append(match.group(1).strip())
                elif keep_unsolved:
                    clipped.append(response)

            # sort by length (stable, shortest first)
            clipped.sort(key=len)
            responses_text = "".join(
                f"\nCandidate Response {j + 1}:\n{text}" for j, text in enumerate(clipped)
            )
            synthesized = COT_SYNTHESIZER_PROMPT.format(
                question=prompt[user_idx]["content"],
                responses=responses_text,
            )

            # Truncate to at most max_prompt_tokens tokens, cutting at the end
            # character of the last kept token.
            offsets = tokenizer(synthesized, return_offsets_mapping=True)["offset_mapping"]
            if len(offsets) > max_prompt_tokens:
                synthesized = synthesized[: offsets[max_prompt_tokens - 1][1]]

            new_row = copy.deepcopy(row)
            new_row[prompt_key][user_idx]["content"] = synthesized
            new_rows.append(new_row)

    new_dataset = pd.DataFrame(new_rows)
    expected_rows = len(dataset) * candidate_turn
    assert len(new_dataset) == expected_rows, f"New dataset length mismatch: {len(new_dataset)} != {expected_rows}"
    return new_dataset


def _score_responses(prompt, responses, ground_truth):
    """Score each response with ``rlcot_reward_fn``.

    Scoring is best-effort: a response whose scoring raises gets 0.0, and the error
    is printed rather than silently discarded.

    Returns:
        A list with one score per response, in the same order.
    """
    score_lst = []
    for response in responses:
        try:
            score, _ = rlcot_reward_fn(prompt, response, ground_truth)
        except Exception as e:  # one bad response must not abort the whole run
            print(f"[warning] rlcot_reward_fn failed, scoring response as 0.0: {e!r}")
            score = 0.0
        score_lst.append(score)
    return score_lst


@ray.remote(num_cpus=1)
def main_task(config):
    """Generate, score, and save responses for every prompt in ``config.data.path``.

    Steps:
      1. Build output paths under ``<data.output_path>/<dataset_name>/``. Return
         early if this run's ``*_metrics.json`` already exists.
      2. Load the prompts (.parquet or .jsonl). If ``data.candidate_path`` is set,
         replace each prompt with ``rollout.candidate_turn`` answer-synthesis prompts
         built from previously generated candidates.
      3. Generate ``data.n_samples`` responses per prompt in batches of
         ``data.batch_size`` with a Ray worker group running ``ActorRolloutRefWorker``
         in the ``rollout`` role.
      4. Score every response with ``rlcot_reward_fn`` and write the table as
         ``.jsonl``. If ``data.do_metrics`` is set, also write mean accuracy and
         timing to ``*_metrics.json``.

    Args:
        config: Generation config; resolved in place via ``OmegaConf.resolve``.

    Raises:
        AssertionError: On inconsistent sampling settings or row-count mismatches.
        ValueError: On an unsupported data file extension, a prompt without a user
            message, or a ``rollout.prompt_length`` too small for candidate prompts.
    """
    pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values
    OmegaConf.resolve(config)

    local_path = copy_to_local(config.model.path)
    trust_remote_code = config.data.get("trust_remote_code", False)
    tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)

    if config.rollout.temperature == 0.0:
        assert config.data.n_samples == 1, "When temperature=0, n_samples must be 1."
    assert config.data.n_samples >= 1, "n_samples should always >= 1"

    use_candidates = config.data.get("candidate_path", None) is not None

    # get out_file name && make dir
    dataset_name = os.path.splitext(os.path.basename(config.data.path))[0]
    # rstrip so that a trailing "/" in the model path does not yield an empty name.
    model_name = config.model.path.rstrip("/").split("/")[-1]
    out_file_prefix = f"n{config.data.n_samples}_t{config.rollout.temperature}"
    if use_candidates:
        out_file_prefix += (
            f"_{model_name}_cand{config.rollout.n_candidates}"
            f"_turn{config.rollout.candidate_turn}_maxlen{config.rollout.prompt_length}"
        )
        if config.data.get("tag", None) is not None:  # add some tags for filename to debug
            out_file_prefix += config.data.tag

    output_dir = config.data.output_path
    # Nominal .parquet path; the actual outputs are derived from it (.jsonl / _metrics.json).
    out_file = f"{output_dir}/{dataset_name}/{out_file_prefix}.parquet"
    os.makedirs(f"{output_dir}/{dataset_name}", exist_ok=True)

    out_metric_json = out_file.replace(".parquet", "_metrics.json")
    # NOTE(review): the metrics file is only written when data.do_metrics is set, so
    # without it this skip never fires and the run is repeated.
    if os.path.exists(out_metric_json):
        print(f"Skipping {dataset_name} because {out_metric_json} already exists.")
        return
    start_time = time.time()

    # The dataset should directly contain chat-template messages (a list of dicts per row).
    dataset = _read_records(config.data.path)
    if use_candidates:
        dataset = _build_candidate_dataset(dataset, config, tokenizer)

    chat_lst = [
        chat if isinstance(chat, list) else chat.tolist()
        for chat in dataset[config.data.prompt_key].tolist()
    ]

    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role="rollout")
    resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes)
    wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, device_name="cuda" if is_cuda_available else "npu")
    wg.init_model()

    total_samples = len(dataset)
    config_batch_size = config.data.batch_size
    num_batch = -(-total_samples // config_batch_size)  # ceiling division
    output_lst = [[] for _ in range(config.data.n_samples)]  # (n_samples, n_data)

    for batch_idx in range(num_batch):
        print(f"[{batch_idx + 1}/{num_batch}] Start to process.")
        batch_chat_lst = chat_lst[batch_idx * config_batch_size : (batch_idx + 1) * config_batch_size]
        inputs = tokenizer.apply_chat_template(
            batch_chat_lst,
            add_generation_prompt=True,
            padding=True,
            truncation=True,
            max_length=config.rollout.prompt_length,
            return_tensors="pt",
            return_dict=True,
            tokenize=True,
        )
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        position_ids = compute_position_id_with_mask(attention_mask)
        batch_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids}

        data = DataProto.from_dict(batch_dict)
        # Pad so the batch splits evenly across workers; undone by unpad_dataproto below.
        data_padded, pad_size = pad_dataproto_to_divisor(data, wg.world_size)

        # START TO GENERATE FOR n_samples TIMES
        print(f"[{batch_idx + 1}/{num_batch}] Start to generate.")
        for n_sample in range(config.data.n_samples):
            output_padded = wg.generate_sequences(data_padded)
            output = unpad_dataproto(output_padded, pad_size=pad_size)

            output_texts = []
            for i in range(len(output)):
                data_item = output[i]
                prompt_length = data_item.batch["prompts"].shape[-1]
                # The attention mask past the prompt marks the non-padding response tokens.
                valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum()
                valid_response_ids = data_item.batch["responses"][:valid_response_length]
                response_str = tokenizer.decode(valid_response_ids, skip_special_tokens=True)
                output_texts.append(response_str)

            output_lst[n_sample].extend(output_texts)

    # Transpose (n_samples, n_data) -> (n_data, n_samples) and attach to the table.
    dataset[config.data.response_key] = [list(per_prompt) for per_prompt in zip(*output_lst)]

    # Compute evaluation metrics. Columns are iterated positionally so that a
    # non-default index (which parquet may restore) cannot misalign rows.
    scores = [
        _score_responses(prompt, response_lst, reward_data["ground_truth"])
        for prompt, response_lst, reward_data in zip(
            dataset[config.data.prompt_key].tolist(),
            dataset[config.data.response_key].tolist(),
            dataset[config.data.reward_model_key].tolist(),
        )
    ]
    dataset["scores"] = scores

    # Written as .jsonl next to the nominal .parquet path.
    dataset.to_json(out_file.replace(".parquet", ".jsonl"), orient="records", lines=True, force_ascii=False)

    if config.data.get("do_metrics", False):
        mean_acc = np.mean(scores)
        per_sample_acc = np.mean(scores, axis=0).tolist()  # accuracy of the k-th sample across prompts
        num_scores = np.array(scores).size
        time_use = time.time() - start_time

        result_json = {
            "model_path": config.model.path,
            "dataset": dataset_name,
            "num_samples": len(dataset),
            "num_scores": num_scores,
            "mean_acc": mean_acc,
            "all_acc": per_sample_acc,
            "time_use_in_second": time_use,
            "time_use_in_minute": f"{int(time_use // 60)}:{int(time_use % 60):02d}",
        }

        print("result_json", result_json)

        with open(out_metric_json, "w") as f:
            json.dump(result_json, f, indent=4)

if __name__ == "__main__":
    # ``main`` is wrapped by ``hydra.main``, which composes ``config`` from the
    # ``generation`` config under ``config/`` plus any command-line overrides.
    main()
