# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Generate responses given a dataset of prompts
"""
import ray
import numpy as np
import hydra
import os
import json
import copy
import torch

from typing import List, Tuple
from omegaconf import ListConfig
from tqdm import tqdm

# Process-wide environment knobs, assigned before the verl / transformers
# imports that follow.
# NOTE(review): presumably they are set this early so the libraries that read
# them see the values at import / initialisation time — confirm.
# NCCL_DEBUG selects the NCCL log level; 'WARN' is assumed to limit output to
# warnings/errors — TODO confirm against the NCCL documentation.
os.environ['NCCL_DEBUG'] = 'WARN'
# TOKENIZERS_PARALLELISM is read by Hugging Face tokenizers; 'true' presumably
# keeps its internal parallelism enabled — TODO confirm.
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
# os.environ['TORCH_COMPILE_DISABLE'] = '1'

from verl.utils.model import compute_position_id_with_mask

import pandas as pd

from transformers import AutoTokenizer
# from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

from verl import DataProto
from verl.utils.fs import copy_local_path_from_hdfs
from verl.workers.fsdp_workers import ActorRolloutRefWorker
from verl.utils.hdfs_io import makedirs
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
from verl.trainer.main_ppo import RewardManager
from verl.protocol import pad_dataproto_to_divisor

from ray.util import pdb

def load_jsonl(file):
    """Load a JSON Lines file into a list of parsed records.

    Blank or whitespace-only lines (for example a trailing empty line at the
    end of the file) are skipped instead of raising ``json.JSONDecodeError``.
    The file is decoded as UTF-8 rather than with the locale default so that
    non-ASCII content loads the same way on every machine.

    Args:
        file: Path of the JSONL file to read.

    Returns:
        A list holding one parsed JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(file, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]

def write_jsonl(data, file):
    """Write every element of ``data`` to ``file`` as one JSON document per line.

    Args:
        data: Iterable of JSON-serialisable values.
        file: Destination path; an existing file is overwritten.
    """
    with open(file, 'w') as sink:
        # Records are serialised lazily, one line at a time, as they are written.
        sink.writelines(json.dumps(record) + '\n' for record in data)

def write_json(data, file):
    """Write ``data`` to ``file`` as a single JSON document indented by 2 spaces.

    Args:
        data: JSON-serialisable value.
        file: Destination path; an existing file is overwritten.
    """
    encoder = json.JSONEncoder(indent=2)
    with open(file, 'w') as sink:
        # Stream the encoder's chunks straight into the file.
        sink.writelines(encoder.iterencode(data))

def filter_data_by_keywords(data: DataProto, filter_keys: List[str], key: str = 'task') -> Tuple[DataProto, dict]:
    """Drop every sample whose ``key`` field contains any of the given keywords.

    A sample is removed when ``keyword in value`` is true for at least one
    keyword, where ``value`` is that sample's entry in
    ``data.non_tensor_batch[key]``.
    NOTE(review): the entries are assumed to be strings (substring match) —
    confirm against the code that produces the batch.

    Args:
        data: Batch to filter.
        filter_keys: Keywords to exclude, e.g. ``['system1', 'dummy']``.
        key: Name of the ``non_tensor_batch`` field to test. Defaults to ``'task'``.

    Returns:
        A ``(filtered_data, stats)`` tuple. ``filtered_data`` is a new
        ``DataProto`` holding only the kept rows and sharing ``data.meta_info``.
        ``stats`` has the keys ``'total_samples'``, ``'filtered_samples'``,
        ``'remaining_samples'`` (plain Python ints) and ``'filtered_keywords'``.
    """
    # True = keep the row. Building the mask in one pass with an explicit bool
    # dtype also works for an empty batch and an empty keyword list.
    field_values = data.non_tensor_batch[key]
    mask = np.array(
        [not any(keyword in value for keyword in filter_keys) for value in field_values],
        dtype=bool,
    )

    # Select the kept rows of the tensor batch, if there is one.
    if data.batch is not None:
        filtered_batch = data.batch[torch.tensor(mask)]
    else:
        filtered_batch = None

    # Select the kept rows of every non-tensor field. The loop variable is
    # deliberately not called `key`, so the parameter is not shadowed.
    filtered_non_tensor = {}
    for field_name, field_array in data.non_tensor_batch.items():
        kept = field_array[mask]
        # Defensive: make sure the stored value is always a numpy array.
        if not isinstance(kept, np.ndarray):
            kept = np.array([kept], dtype=object)
        filtered_non_tensor[field_name] = kept

    # Plain ints (not numpy scalars) keep the stats JSON-serialisable.
    total = len(data)
    remaining = int(mask.sum())
    stats = {
        'total_samples': total,
        'filtered_samples': total - remaining,
        'remaining_samples': remaining,
        'filtered_keywords': filter_keys,
    }

    filtered_data = DataProto(
        batch=filtered_batch,
        non_tensor_batch=filtered_non_tensor,
        meta_info=data.meta_info,
    )

    return filtered_data, stats


@hydra.main(config_path='config', config_name='generation', version_base=None)
def main(config):
    """Generate, score and save ``config.data.n_samples`` responses per prompt.

    Steps:
      1. Load the prompts from the JSONL file(s) listed in ``config.data.path``
         (only the first 8 records are kept when ``config.debug`` is set).
      2. Create a Ray worker group of ``ActorRolloutRefWorker`` and initialise
         the model.
      3. For each batch of ``config.data.batch_size`` prompts: tokenize with the
         chat template, pad the batch to a multiple of the worker-group world
         size, generate ``n_samples`` times, drop the rows whose ``task``
         contains ``'system1'`` or ``'dummy'`` and score the remaining rows
         with ``RewardManager``.
      4. Print the best-of-i score for ``i = 1..n_samples`` and write the
         per-sample records (``.jsonl``) and the best-of-i summary
         (``_results.json``) under ``config.data.output_path``.

    Args:
        config: Hydra/OmegaConf configuration. Fields read here:
            ``model.path``, ``rollout.temperature``, ``rollout.prompt_length``,
            ``rollout.tensor_model_parallel_size``, ``data.path``,
            ``data.prompt_key``, ``data.batch_size``, ``data.n_samples``,
            ``data.output_path``, ``trainer.n_gpus_per_node``,
            ``trainer.nnodes``, ``debug`` and ``timestamp``.

    Raises:
        ValueError: If no prompt could be loaded from ``config.data.path``.
    """
    import time
    from pprint import pprint
    from omegaconf import OmegaConf
    pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values
    OmegaConf.resolve(config)
    local_path = copy_local_path_from_hdfs(config.model.path)
    from verl.utils import hf_tokenizer
    tokenizer = hf_tokenizer(local_path)

    if config.rollout.temperature == 0.:
        assert config.data.n_samples == 1, 'When temperature=0, n_samples must be 1.'

    # Read the dataset. Each record must directly contain the prompt in chat
    # template format (e.g. a list of dictionaries) under config.data.prompt_key.
    if not isinstance(config.data.path, (List, ListConfig)):
        config.data.path = [config.data.path]
    dataset = []
    for path in config.data.path:
        dataset.extend(load_jsonl(path))  # List[Dict]

    if config.debug:
        dataset = dataset[:8]

    # Fail early with a clear message, before any worker is started.
    if not dataset:
        raise ValueError(f'No samples loaded from {list(config.data.path)}')

    val_reward_fn = RewardManager(tokenizer=tokenizer, num_examine=1, compute_score=None)

    tokenizer.padding_side = 'left'
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role='actor_rollout')
    resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes)
    wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)
    wg.init_model()

    total_samples = len(dataset)
    print(f'Size of test data: {total_samples}')
    config_batch_size = config.data.batch_size
    dp_size = wg.world_size // config.rollout.tensor_model_parallel_size
    # Ceiling division: `total // batch + 1` would schedule an extra, empty
    # batch whenever total_samples is an exact multiple of the batch size.
    num_batch = (total_samples + config_batch_size - 1) // config_batch_size
    # output_lst[s] collects the scored records produced by sampling round s.
    output_lst = [[] for _ in range(config.data.n_samples)]
    # Padding rows get their task changed and are filtered out after generation.
    multi_agent_train = True
    tensor_keys = ('input_ids', 'attention_mask', 'position_ids')
    start_time = time.time()
    for batch_idx in tqdm(range(num_batch)):
        print(f'[{batch_idx+1}/{num_batch}] Start to process.')
        batch_data = dataset[batch_idx * config_batch_size:(batch_idx + 1) * config_batch_size]
        batch_chat_lst = [item[config.data.prompt_key] for item in batch_data]
        inputs = tokenizer.apply_chat_template(batch_chat_lst,
                                               add_generation_prompt=True,
                                               padding=True,
                                               truncation=True,
                                               max_length=config.rollout.prompt_length,
                                               return_tensors='pt',
                                               return_dict=True,
                                               tokenize=True)
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        position_ids = compute_position_id_with_mask(attention_mask)

        batch_dict = {'input_ids': input_ids, 'attention_mask': attention_mask, 'position_ids': position_ids}

        # Every other field of the records travels along as non-tensor data;
        # the field names are taken from the first record of the batch.
        non_tensors = {}
        for key in batch_data[0].keys():
            if key not in tensor_keys:
                non_tensors[key] = np.array([item[key] for item in batch_data], dtype=object)
        # Keep a full copy of each source record under 'job'.
        non_tensors['job'] = np.array(copy.deepcopy(batch_data), dtype=object)
        data = DataProto.from_dict(batch_dict, non_tensors=non_tensors)
        original_task_set = set(data.non_tensor_batch['task'])  # for the sanity check below
        real_batch_size = data.batch['input_ids'].shape[0]
        if real_batch_size % wg.world_size != 0:
            data, dummy_data_size = pad_dataproto_to_divisor(data, wg.world_size, change_task=multi_agent_train)
            print(
                f'real_batch_size {real_batch_size} is not divisible by world_size {wg.world_size}, add {dummy_data_size} dummy data'
            )

        batch_size = data.batch['input_ids'].shape[0]
        assert batch_size % dp_size == 0, f'batch_size {batch_size} is not divisible by dp_size {dp_size}'

        print(f'[{batch_idx+1}/{num_batch}] Start to generate.')
        # Generate n_samples times for the same batch.
        for sample_idx in tqdm(range(config.data.n_samples), desc="repeat"):
            sampled_data = copy.deepcopy(data)
            output = wg.generate_sequences(sampled_data)  # verl.protocol.DataProto
            if multi_agent_train:
                # Drop the rows whose task contains 'system1' or 'dummy'.
                # NOTE(review): presumably these are the padding rows and
                # auxiliary-agent rows — confirm against
                # pad_dataproto_to_divisor / generate_sequences.
                test_batch, status = filter_data_by_keywords(output, ['system1', 'dummy'], 'task')
            else:
                test_batch = DataProto(
                    batch=output.batch,
                    non_tensor_batch=output.non_tensor_batch,
                    meta_info=output.meta_info,
                )

            # Sanity check: after filtering, exactly the original tasks remain.
            assert original_task_set == set(test_batch.non_tensor_batch['task']), f"validation ERROR: {original_task_set}\n!=\n{set(test_batch.non_tensor_batch['task'])}"

            reward_tensor, _ = val_reward_fn(test_batch)
            # One scalar score per row: sum of the reward over the last axis.
            scores = reward_tensor.sum(-1).cpu().tolist()

            output_item = []
            for idx, score in enumerate(scores):
                record = copy.deepcopy(test_batch[idx].non_tensor_batch)
                record['score'] = score
                output_item.append(record)

            output_lst[sample_idx].extend(output_item)

    # TODO: split the results by dataset.
    # Convert output_lst from (n_samples, n_data) to (n_data, n_samples).
    output_lst = np.array(output_lst, dtype=object)
    output_lst = np.transpose(output_lst, axes=(1, 0)).tolist()

    scores = np.array([[item['score'] for item in row] for row in output_lst])
    # Best-of-i for every i in 1..n_samples: for each prompt take the maximum
    # score among its first i samples, then average over all prompts.
    n_samples = scores.shape[1]
    results = {
        n: float(np.mean(np.max(scores[:, :n], axis=1)))
        for n in range(1, n_samples + 1)
    }
    for n, accuracy in results.items():
        print(f"Best-of-{n} Accuracy: {accuracy:.4f}")

    # Destination of the per-sample records (a .jsonl file despite the name).
    output_dir = os.path.join(config.data.output_path, f"{config.data.n_samples}_{config.timestamp}.jsonl")
    os.makedirs(os.path.dirname(output_dir), exist_ok=True)

    # json cannot serialise numpy containers/scalars: convert top-level values.
    for row in output_lst:
        for record in row:
            for k, v in record.items():
                if isinstance(v, np.ndarray):
                    record[k] = v.tolist()
                elif isinstance(v, np.generic):
                    # Any numpy scalar (float, int, bool, ...) -> Python scalar.
                    record[k] = v.item()

    write_jsonl(output_lst, output_dir)
    write_json(results, os.path.join(config.data.output_path, f"{config.data.n_samples}_{config.timestamp}_results.json"))
    print(f"best of {config.data.n_samples} save at {output_dir}")

    print(f"time cost: {time.time() - start_time}s")


# Script entry point. main() is wrapped by @hydra.main, so it is called without
# arguments here and receives its `config` from hydra.
if __name__ == '__main__':
    main()
