# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Generate responses given a dataset of prompts
"""

import os

import hydra
import numpy as np
import ray

# Process-wide environment defaults. They are set before the heavier imports
# below, presumably so libraries that read them at import/initialization time
# see these values (NOTE(review): confirm this ordering actually matters).
os.environ["NCCL_DEBUG"] = "WARN"  # limit NCCL log output to warnings
os.environ["TOKENIZERS_PARALLELISM"] = "true"  # explicitly allow HF tokenizers' internal parallelism
# Uncomment to turn off torch.compile while debugging:
# os.environ['TORCH_COMPILE_DISABLE'] = '1'

from pprint import pprint

import pandas as pd
from omegaconf import OmegaConf

from verl import DataProto
from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from verl.utils import hf_tokenizer
from verl.utils.fs import copy_to_local
from verl.utils.hdfs_io import makedirs
from verl.utils.model import compute_position_id_with_mask
from verl.workers.fsdp_workers import ActorRolloutRefWorker
from verl.utils.device import is_cuda_available

from verl.workers.reward_manager.reward import rlcot_reward_fn
import time
import json

@hydra.main(config_path="config", config_name="generation", version_base=None)
def main(config):
    """Hydra entry point: forward the composed ``generation`` config to ``run_generation``."""
    return run_generation(config)


def run_generation(config) -> None:
    """Ensure Ray is running, then execute ``main_task`` remotely and wait for it.

    Args:
        config: Resolved generation config; ``config.ray_init.num_cpus`` is used
            only when a new (local) Ray cluster has to be started.
    """
    if not ray.is_initialized():
        # Local Ray cluster: forward the same env vars to every Ray worker.
        worker_env = {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}
        ray.init(runtime_env={"env_vars": worker_env}, num_cpus=config.ray_init.num_cpus)

    task_ref = main_task.remote(config)
    ray.get(task_ref)  # block until generation finishes; re-raises remote errors


# Instruction text placed in front of every aggregated candidate prompt.
_CANDIDATE_PROMPT_PREFIX = (
    "You are an expert & creative solver, provided with a challenging problem and a set of candidate responses "
    "which may be correct, partially correct or even wrong.\n"
    "You should first fully summarize the connection between candidate responses and problem, then generate a new and superior solution. "
    "You should generate a correct solution yourself if all candidates are wrong. "
    "Don't copy candidates, use insights selectively and give your summarized answer and reasons. "
    "\nProblem:\n"
)

# Tokens reserved out of rollout.prompt_length when truncating aggregated
# prompts, presumably headroom for the chat-template tokens that
# apply_chat_template adds later (NOTE(review): confirm 50 is sufficient).
_PROMPT_TOKEN_HEADROOM = 50


def _load_records(path):
    """Load a ``.parquet`` or ``.jsonl`` file into a DataFrame.

    The records are expected to already contain chat-template formatted
    prompts (e.g. a list of ``{"role": ..., "content": ...}`` dicts per row).

    Args:
        path: Input file path; the reader is chosen from its extension.

    Returns:
        A ``pd.DataFrame`` with one row per record.

    Raises:
        ValueError: If ``path`` ends with neither ``.parquet`` nor ``.jsonl``.
    """
    if path.endswith(".parquet"):
        return pd.read_parquet(path)
    if path.endswith(".jsonl"):
        # The context manager closes the handle. Blank lines (e.g. a trailing
        # newline at EOF) are skipped instead of crashing json.loads.
        with open(path, encoding="utf-8") as f:
            return pd.DataFrame([json.loads(line) for line in f if line.strip()])
    raise ValueError(f"Unsupported data file extension (expected .parquet or .jsonl): {path}")


def _find_user_turn(prompt):
    """Return the index of the first message in ``prompt`` whose role is ``'user'``.

    Raises:
        ValueError: If no user message exists. Raising prevents silently
            overwriting some other message with the aggregated prompt.
    """
    for idx, message in enumerate(prompt):
        if message["role"] == "user":
            return idx
    raise ValueError(f"Prompt has no message with role 'user': {prompt!r}")


def _truncate_to_tokens(text, tokenizer, max_tokens):
    """Keep only the characters of ``text`` covered by its first ``max_tokens`` tokens.

    Uses the tokenizer's character offset mapping, so the cut lands on a token
    boundary. Hugging Face only provides offset mappings for fast tokenizers.
    """
    offsets = tokenizer(text, return_offsets_mapping=True)["offset_mapping"]
    if len(offsets) > max_tokens:
        text = text[: offsets[max_tokens - 1][1]]
    return text


def _build_candidate_dataset(dataset, candidate_dataset, tokenizer, config):
    """Expand every row into ``config.rollout.candidate_turn`` candidate-aggregation rows.

    For row ``r`` and turn ``t``, the responses
    ``candidate_dataset[config.data.response_key][r][t*k:(t+1)*k]`` (with
    ``k = config.rollout.n_candidates``) are processed in four steps:
    (1) each is cut to the text after ``</think>``; responses without that tag
    are kept verbatim unless ``config.rollout.no_solution_discard`` is truthy;
    (2) they are sorted by length;
    (3) they are appended under numbered "Candidate Response" headers after
    ``_CANDIDATE_PROMPT_PREFIX`` and the row's original user message;
    (4) the result is truncated to ``prompt_length - _PROMPT_TOKEN_HEADROOM``
    tokens.
    The result replaces the first user message in a deep copy of the row.

    Returns:
        New DataFrame with ``len(dataset) * candidate_turn`` rows, ordered
        row-major (all turns of row 0, then row 1, ...).

    Raises:
        ValueError: On a row-count mismatch, a non-positive ``candidate_turn``
            or a prompt length too small to leave any token budget.
    """
    import copy
    import re

    if len(dataset) != len(candidate_dataset):
        raise ValueError(f"Dataset length mismatch: {len(dataset)} != {len(candidate_dataset)}")

    prompt_key = config.data.prompt_key
    n_candidates = config.rollout.n_candidates
    candidate_turn = config.rollout.candidate_turn
    if candidate_turn < 1:
        raise ValueError(f"rollout.candidate_turn must be >= 1, got {candidate_turn}")
    discard_unfinished = bool(config.rollout.get("no_solution_discard", False))
    token_budget = config.rollout.prompt_length - _PROMPT_TOKEN_HEADROOM
    if token_budget <= 0:
        raise ValueError(
            f"rollout.prompt_length ({config.rollout.prompt_length}) must exceed "
            f"{_PROMPT_TOKEN_HEADROOM} to leave room for candidate responses."
        )
    think_tail = re.compile(r"</think>\s*(.*)", re.DOTALL)

    # Parquet stores list columns as numpy arrays; normalize to Python lists.
    candidate_lists = [
        chat.tolist() if not isinstance(chat, list) else chat
        for chat in candidate_dataset[config.data.response_key].tolist()
    ]

    new_rows = []
    # to_dict("records") keeps the real column names (itertuples renames
    # columns that are not valid Python identifiers).
    for row, responses in zip(dataset.to_dict("records"), candidate_lists):
        prompt = row[prompt_key]
        user_idx = _find_user_turn(prompt)
        question = prompt[user_idx]["content"]

        for turn in range(candidate_turn):
            group = responses[turn * n_candidates : (turn + 1) * n_candidates]
            kept = []
            for response in group:
                match = think_tail.search(response)
                if match:
                    kept.append(match.group(1).strip())
                elif not discard_unfinished:
                    kept.append(response)
            kept.sort(key=len)  # shortest candidates first

            listing = "".join(
                f"\nCandidate Response {j}:\n{text}" for j, text in enumerate(kept, start=1)
            )
            aggregated = _truncate_to_tokens(
                _CANDIDATE_PROMPT_PREFIX + question + listing, tokenizer, token_budget
            )

            # Deep copy so the original row (and its nested messages) is never mutated.
            new_row = copy.deepcopy(row)
            new_row[prompt_key][user_idx]["content"] = aggregated
            new_rows.append(new_row)

    return pd.DataFrame(new_rows)


def _score_responses(dataset, config):
    """Score every generated response with ``rlcot_reward_fn``.

    Returns:
        One list per row holding one score per generated sample. A reward call
        that raises is scored 0.0 (best effort). Failures are counted and
        reported instead of being silently swallowed.
    """
    scores = []
    num_failed = 0
    first_error = None
    # Iterating the Series positionally (via zip) stays correct even when the
    # DataFrame index is not 0..N-1 (e.g. an index restored from parquet).
    columns = zip(
        dataset[config.data.prompt_key],
        dataset[config.data.response_key],
        dataset[config.data.reward_model_key],
    )
    for prompt, response_lst, reward_data in columns:
        ground_truth = reward_data["ground_truth"]
        row_scores = []
        for response in response_lst:
            try:
                score, _ = rlcot_reward_fn(prompt, response, ground_truth)
            except Exception as exc:  # best effort: a scorer failure counts as 0
                num_failed += 1
                if first_error is None:
                    first_error = exc
                score = 0.0
            row_scores.append(score)
        scores.append(row_scores)

    if num_failed:
        print(
            f"Warning: reward computation failed for {num_failed} response(s); "
            f"they were scored 0.0. First error: {first_error!r}"
        )
    return scores


@ray.remote(num_cpus=1)
def main_task(config):
    """Generate ``n_samples`` responses per prompt, score them and save the results.

    Steps:

    1. Load ``config.data.path``. If ``config.data.candidate_path`` is set,
       expand each row with candidate responses (see ``_build_candidate_dataset``).
    2. Generate with a Ray ``ActorRolloutRefWorker`` rollout group in batches
       of ``config.data.batch_size``.
    3. Score each response with ``rlcot_reward_fn``.
    4. Write all rows to ``{output_path}/{dataset_name}/{prefix}.jsonl``. When
       ``config.data.do_metrics`` is set, also write ``{prefix}_metrics.json``.

    Returns early, doing nothing, when that metrics file already exists.

    Raises:
        ValueError: On invalid sampling settings or unsupported input files.
    """
    pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values
    OmegaConf.resolve(config)

    local_path = copy_to_local(config.model.path)
    trust_remote_code = config.data.get("trust_remote_code", False)
    tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)

    n_samples = config.data.n_samples
    if config.rollout.temperature == 0.0 and n_samples != 1:
        raise ValueError("When temperature=0, n_samples must be 1.")
    if n_samples < 1:
        raise ValueError("n_samples should always >= 1")

    # Output naming: {output_path}/{dataset_name}/{prefix}.jsonl (+ _metrics.json).
    dataset_name = os.path.splitext(os.path.basename(config.data.path))[0]
    candidate_path = config.data.get("candidate_path", None)

    out_file_prefix = f"n{n_samples}_t{config.rollout.temperature}"
    if candidate_path is not None:
        out_file_prefix += (
            f"_cand{config.rollout.n_candidates}_turn{config.rollout.candidate_turn}"
            f"_maxlen{config.rollout.prompt_length}"
        )
        if config.data.get("tag", None) is not None:  # optional extra filename tag for debugging
            out_file_prefix += str(config.data.tag)

    out_dir = f"{config.data.output_path}/{dataset_name}"
    os.makedirs(out_dir, exist_ok=True)
    out_stem = f"{out_dir}/{out_file_prefix}"
    out_jsonl = f"{out_stem}.jsonl"
    out_metric_json = f"{out_stem}_metrics.json"

    # NOTE: the metrics file is only written when data.do_metrics is set, so
    # without it this skip never triggers and the run is repeated.
    if os.path.exists(out_metric_json):
        print(f"Skipping {dataset_name} because {out_metric_json} already exists.")
        return
    start_time = time.time()

    dataset = _load_records(config.data.path)
    if candidate_path is not None:
        candidate_dataset = _load_records(candidate_path)
        dataset = _build_candidate_dataset(dataset, candidate_dataset, tokenizer, config)

    # Parquet stores list columns as numpy arrays; apply_chat_template expects lists.
    chat_lst = [
        chat.tolist() if not isinstance(chat, list) else chat
        for chat in dataset[config.data.prompt_key].tolist()
    ]

    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role="rollout")
    resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes)
    wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, device_name="cuda" if is_cuda_available else "npu")
    wg.init_model()

    total_samples = len(dataset)
    batch_size = config.data.batch_size
    num_batch = -(-total_samples // batch_size)  # ceiling division
    output_lst = [[] for _ in range(n_samples)]  # output_lst[k][i]: k-th sample for row i

    for batch_idx in range(num_batch):
        print(f"[{batch_idx + 1}/{num_batch}] Start to process.")
        batch_chat_lst = chat_lst[batch_idx * batch_size : (batch_idx + 1) * batch_size]
        inputs = tokenizer.apply_chat_template(
            batch_chat_lst,
            add_generation_prompt=True,
            padding=True,
            truncation=True,
            max_length=config.rollout.prompt_length,
            return_tensors="pt",
            return_dict=True,
            tokenize=True,
        )
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        position_ids = compute_position_id_with_mask(attention_mask)
        batch_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids}

        data = DataProto.from_dict(batch_dict)
        # Pad to a multiple of the worker count; the padding is removed after generation.
        data_padded, pad_size = pad_dataproto_to_divisor(data, wg.world_size)

        # Generate n_samples independent completions for the whole batch.
        print(f"[{batch_idx + 1}/{num_batch}] Start to generate.")
        for n_sample in range(n_samples):
            output_padded = wg.generate_sequences(data_padded)
            output = unpad_dataproto(output_padded, pad_size=pad_size)

            output_texts = []
            for i in range(len(output)):
                data_item = output[i]
                prompt_length = data_item.batch["prompts"].shape[-1]
                # Mask entries after the prompt count the response tokens to keep.
                valid_response_length = data_item.batch["attention_mask"][prompt_length:].sum()
                valid_response_ids = data_item.batch["responses"][:valid_response_length]
                output_texts.append(tokenizer.decode(valid_response_ids, skip_special_tokens=True))

            output_lst[n_sample].extend(output_texts)

    # Transpose (n_samples, n_rows) -> (n_rows, n_samples).
    dataset[config.data.response_key] = [list(samples) for samples in zip(*output_lst)]

    scores = _score_responses(dataset, config)
    dataset["scores"] = scores

    # Results are written as JSONL (one record per row).
    dataset.to_json(out_jsonl, orient="records", lines=True, force_ascii=False)

    if config.data.get("do_metrics", False):
        time_use = time.time() - start_time
        result_json = {
            "model_path": config.model.path,
            "dataset": dataset_name,
            "num_samples": len(dataset),
            "num_scores": sum(len(row_scores) for row_scores in scores),
            "mean_acc": np.mean(scores),
            # Mean score per sample index, averaged over rows.
            "all_acc": np.mean(scores, axis=0).tolist(),
            "time_use_in_second": time_use,
            "time_use_in_minute": f"{int(time_use // 60)}:{int(time_use % 60):02d}",
        }

        print("result_json", result_json)

        with open(out_metric_json, "w") as f:
            json.dump(result_json, f, indent=4)


# Script entry point: Hydra composes the config and invokes main().
if __name__ == "__main__":
    main()
