"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the CC-By-NC license found in the
LICENSE file in the root directory of this source tree.
"""

import concurrent
import json
import random

import openai
import torch
from datasets import load_dataset
from fire import Fire
from huggingface_hub import login
from openai import OpenAI
from tqdm import tqdm
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams


def _flatten_dialogues(records):
    """Expand dialogue records into one sample per turn that has an ``"input"``.

    Each record must provide ``task["ground_truth"]``, ``reward`` and a
    ``dialogue_history`` list; every turn carrying ``"input"`` must also carry
    ``"output"``. A sample copies the turn's input/output (the output as
    ``older_output``), the record's reward and ground truth, and the dialogue
    prefix up to and including that turn (``messages``). The empty
    ``additional_*`` lists are filled in later by the sampling step.
    """
    samples = []
    for record in records:
        ground_truth = record["task"]["ground_truth"]
        history = record["dialogue_history"]
        for i, turn in enumerate(history):
            if "input" not in turn:
                continue
            samples.append(
                {
                    "input": turn["input"],
                    # NOTE(review): there is no separator between "." and the
                    # turn input; kept verbatim so outputs stay comparable.
                    "input_with_ground_truth": f"In light that the final answer is: {ground_truth}."
                    + turn["input"],
                    "older_output": turn["output"],
                    "reward": record["reward"],
                    "ground_truth": ground_truth,
                    "messages": history[: i + 1],
                    "additional_outputs": [],
                    "additional_reference_logprobs": [],
                    "additional_reference_logprobs_sum": [],
                }
            )
    return samples


def _logprob_stats(token_logprobs):
    """Return ``(mean, sum)`` of every log-prob entry of one completion.

    ``token_logprobs`` is a per-position sequence of ``{token_id: obj}`` dicts
    (``obj.logprob`` is read); ``None`` positions, or a ``None`` sequence, are
    skipped. Every entry at a position is summed, so the result is the sampled
    tokens' log-probability only when a single entry is returned per position.
    An empty completion yields ``(0.0, 0.0)`` instead of dividing by zero.
    """
    values = [
        entry.logprob
        for position in (token_logprobs or [])
        if position is not None
        for entry in position.values()
    ]
    total = sum(values, 0.0)
    mean = total / len(values) if values else 0.0
    return mean, total


def main(
    input_path="auto",
    output_path="/fsx-ram/yifeizhou/collab_llm/outputs/temp_samplebestofn.jsonl",
    agent_model="/fsx-ram/shared/Meta-Llama-3.1-8B-Instruct",  # meta-llama/Llama-3.1-8B-Instruct",
    temperature=1.0,
    best_of_n=16,
    data_fraction=1.0,
):
    """Sample ``best_of_n`` completions per dialogue turn and record their log-probs.

    Reads a JSONL file of dialogue records and expands every turn that has an
    ``"input"`` into its own sample. For each sample it generates ``best_of_n``
    completions from ``agent_model`` with vLLM and writes one JSON line per
    sample. Each line holds the completions' texts, mean token log-probs and
    summed token log-probs.

    Args:
        input_path: JSONL input file. The default ``"auto"`` is opened
            literally here; presumably callers always pass a real path.
        output_path: destination JSONL file (overwritten).
        agent_model: model path or ID passed to ``vllm.LLM(model=...)``.
        temperature: sampling temperature.
        best_of_n: number of completions sampled per prompt.
        data_fraction: fraction of records to keep, taken from the start of
            the file in order.

    Raises:
        ValueError: if no records remain after applying ``data_fraction``,
            or a record has no ``"dialogue_history"``.
    """
    with open(input_path, "r", encoding="utf-8") as fb:
        data = [json.loads(line) for line in fb]

    # Deterministic leading slice (no shuffling) so runs are reproducible.
    data = data[: int(len(data) * data_fraction)]
    if not data:
        raise ValueError(
            f"no records to process in {input_path!r} (data_fraction={data_fraction})"
        )
    if any("dialogue_history" not in record for record in data):
        raise ValueError("no dialogue history found in data!")
    samples = _flatten_dialogues(data)

    # Probe GPUs only once the input is known to be valid; shard the model
    # across every visible device.
    tensor_parallel_size = torch.cuda.device_count()
    print(f"tensor_parallel_size: {tensor_parallel_size}")

    llm = LLM(
        model=agent_model,
        distributed_executor_backend="ray",
        tensor_parallel_size=tensor_parallel_size,
        enforce_eager=True,
    )
    sampling_params = SamplingParams(
        temperature=temperature,
        logprobs=0,  # only the sampled token's log-prob per position
        n=best_of_n,
        max_tokens=1024,
    )
    prompts = [sample["input"] for sample in samples]
    outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
    for sample, output in zip(samples, outputs):
        for completion in output.outputs:
            mean_logprob, sum_logprob = _logprob_stats(completion.logprobs)
            sample["additional_outputs"].append(completion.text)
            sample["additional_reference_logprobs"].append(mean_logprob)
            sample["additional_reference_logprobs_sum"].append(sum_logprob)

    with open(output_path, "w", encoding="utf-8") as fb:
        fb.writelines(json.dumps(sample) + "\n" for sample in samples)


if __name__ == "__main__":
    # python-fire turns main's keyword arguments into command-line flags,
    # e.g. --input_path=... --best_of_n=8.
    Fire(component=main)
