# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Generate responses given a dataset of prompts
"""

import os

import hydra
import numpy as np

# Environment variables are set at import time, before the pandas/verl/transformers imports
# below, presumably so those libraries see them when they initialize — verify if reordering.
os.environ["NCCL_DEBUG"] = "WARN"
os.environ["TOKENIZERS_PARALLELISM"] = "true"
# os.environ['TORCH_COMPILE_DISABLE'] = '1'

from pprint import pprint

import pandas as pd
from omegaconf import OmegaConf

from verl.utils import hf_tokenizer
from verl.utils.fs import copy_to_local

import time
import json
import torch
from transformers import AutoModelForSequenceClassification

from tqdm import tqdm

def _load_dataset(path):
    """Load the prompt/response dataset at ``path`` into a DataFrame.

    Args:
        path: Path to a ``.parquet`` file or a ``.jsonl`` file (one JSON object per line).

    Returns:
        pd.DataFrame with one row per prompt.

    Raises:
        ValueError: If ``path`` has neither a ``.parquet`` nor a ``.jsonl`` extension.
    """
    if path.endswith(".parquet"):
        return pd.read_parquet(path)
    if path.endswith(".jsonl"):
        # The context manager closes the handle; blank lines (e.g. a trailing empty line)
        # are skipped instead of crashing json.loads.
        with open(path, encoding="utf-8") as f:
            return pd.DataFrame([json.loads(line) for line in f if line.strip()])
    raise ValueError(f"Unsupported dataset format (expected .parquet or .jsonl): {path}")


def _output_path(data_path, model_name):
    """Return ``data_path`` with ``_<model_name>`` inserted before its file extension.

    Only the final extension is touched; ``str.replace`` would also rewrite a matching
    substring elsewhere in the path (e.g. a directory named ``x.parquet``).
    """
    stem, ext = os.path.splitext(data_path)
    return f"{stem}_{model_name}{ext}"


def _strip_reasoning(response_text):
    """Return the part of ``response_text`` after the last ``</think>`` tag.

    Text without the tag is returned unchanged. Splitting on the *last* tag keeps the whole
    final answer even if the tag string occurs more than once (``split(...)[1]`` would
    truncate at the second occurrence).
    """
    return response_text.rpartition("</think>")[2]


def _last_user_content(message):
    """Return the stripped text of the last user turn in a chat ``message``, or ``None``.

    Args:
        message: List of ``{"role": ..., "content": ...}`` dicts. Only user turns whose
            content is a string are considered; anything that is not a list yields ``None``.
    """
    if not isinstance(message, list):
        return None
    for turn in reversed(message):
        if turn.get("role") == "user" and isinstance(turn.get("content"), str):
            return turn["content"].strip()
    return None


@hydra.main(config_path="config", config_name="generation", version_base=None)
def main(config):
    """Score every (prompt, response) pair in a dataset with a sequence-classification reward model.

    Config keys read:
        model.path: reward model checkpoint (copied locally via ``copy_to_local``).
        data.path: input ``.parquet`` or ``.jsonl`` file.
        data.prompt_key: column holding the chat prompt (a list of role/content dicts).
        data.response_key: column holding the responses for each prompt.
        data.trust_remote_code: optional, defaults to False.

    For each row, the last user turn of the prompt is paired once with each response (with
    everything up to and including the last ``</think>`` removed) and scored; the first
    logit of each pair is collected into a per-row ``rewards`` list, aligned with the
    responses. Rows with no usable user turn get an empty list. The result is written next
    to the input as ``<stem>_<model_name><ext>`` with the ``rewards`` column, plus the
    input's ``scores`` column when it exists.

    Raises:
        ValueError: If ``data.path`` is neither ``.parquet`` nor ``.jsonl``.
    """
    pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values
    OmegaConf.resolve(config)

    local_path = copy_to_local(config.model.path)
    trust_remote_code = config.data.get("trust_remote_code", False)
    tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # The dataset must already contain prompts in chat format (a list of role/content dicts).
    dataset = _load_dataset(config.data.path)

    # Parquet yields numpy arrays for list cells; normalize prompts to plain lists.
    prompt_lst = [
        chat if isinstance(chat, list) else chat.tolist()
        for chat in dataset[config.data.prompt_key]
    ]
    responses_lst = dataset[config.data.response_key].tolist()

    # Load from the same local copy as the tokenizer so remote (non-local) paths work.
    model = AutoModelForSequenceClassification.from_pretrained(
        local_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation="flash_attention_2",
        num_labels=1,
        trust_remote_code=trust_remote_code,
    )
    # Name the output after the configured checkpoint path (ignoring a trailing slash).
    model_name = os.path.basename(config.model.path.rstrip("/"))

    rewards = []
    with torch.no_grad():
        for message, responses in tqdm(
            zip(prompt_lst, responses_lst), total=len(prompt_lst), desc="Processing dataset"
        ):
            problem_text = _last_user_content(message)
            if problem_text is None:
                rewards.append([])
                continue

            if isinstance(responses, str):
                # A bare string is a single response; iterating it would yield characters.
                responses = [responses]

            row_reward = []
            for response_text in responses:
                pair = [
                    {"role": "user", "content": problem_text},
                    {"role": "assistant", "content": _strip_reasoning(response_text)},
                ]
                conv = tokenizer.apply_chat_template(pair, tokenize=True, return_tensors="pt")
                row_reward.append(model(conv.to(model.device)).logits[0, 0].item())
            rewards.append(row_reward)

    dataset["rewards"] = rewards
    # Carry the input's 'scores' column through when present instead of raising KeyError.
    columns = [col for col in ("scores", "rewards") if col in dataset.columns]
    selected_data = dataset[columns]

    out_file = _output_path(config.data.path, model_name)
    if config.data.path.endswith(".parquet"):
        selected_data.to_parquet(out_file)
    else:  # .jsonl — _load_dataset already rejected every other extension
        selected_data.to_json(out_file, orient="records", lines=True, force_ascii=False)
    print(f"Saved rewards to {out_file}")


if __name__ == "__main__":
    main()
