# example: python bon_inference.py --gpu_num=1 --total_gpus=8 --batch_size=4 --num_prompts=16 --num_completions_per_prompt=16

import argparse
import datetime
import json
import math
import os
import pathlib
import pickle
import random
import sys
from typing import Callable, Iterable, List, Sequence, Tuple

import fire
import pandas as pd
import torch
from accelerate import Accelerator
from datasets import load_dataset, load_from_disk
from torch import nn
from tqdm import tqdm
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

# Local modules: ../reward must be on sys.path before these imports run.
sys.path.append("../reward")
from sft.scenario_datasets import RolloutDataset  # noqa: E402
from generate_bon_completions import generate_continuations  # noqa: E402


def get_prompt_dataset(
    prompts: Sequence[str],
    completions: Sequence[str],
    max_length: int,
    tokenizer: AutoTokenizer,
) -> Tuple[List[str], List[str]]:
    """Keep only prompts that fit under ``max_length`` tokens, with their completions.

    Each prompt is tokenized with truncation at ``max_length``. It is dropped when
    it reaches ``max_length`` tokens, because it was (or may have been) truncated.
    It is also dropped when it decodes to an empty string. Surviving prompts are
    round-tripped through ``tokenizer.decode`` with special tokens skipped and
    tokenization spaces cleaned up. The returned text is therefore the tokenizer's
    reconstruction, not necessarily byte-identical to the input.

    Args:
        prompts: Raw prompt strings.
        completions: Completions aligned index-for-index with ``prompts``.
        max_length: Token budget; prompts of ``max_length`` tokens or more are dropped.
        tokenizer: Hugging Face tokenizer used for the length check and round-trip.

    Returns:
        ``(formatted_prompts, formatted_completions)``: two lists of equal length.

    Raises:
        ValueError: If ``prompts`` and ``completions`` have different lengths.
    """
    if len(prompts) != len(completions):
        raise ValueError(
            f"prompts and completions must align: got {len(prompts)} prompts "
            f"and {len(completions)} completions"
        )

    formatted_prompts: List[str] = []
    formatted_completions: List[str] = []
    for prompt, completion in tqdm(zip(prompts, completions), total=len(prompts)):
        input_ids = tokenizer(prompt, truncation=True, max_length=max_length)["input_ids"]
        # With truncation enabled the length can never exceed max_length. Reaching
        # the cap means the prompt was (or might have been) cut, so drop it.
        if len(input_ids) >= max_length:
            continue
        text = tokenizer.decode(
            input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        if not text:
            continue
        formatted_prompts.append(text)
        formatted_completions.append(completion)
    return formatted_prompts, formatted_completions


def get_samples(
    dataset: Iterable[str],
    log: Callable[..., None] = print,
) -> List[Tuple[str, str]]:
    """Turn multi-turn Human/Assistant transcripts into (prompt, response) pairs.

    Each transcript is split on ``"\\n\\nHuman: "``, and every turn on
    ``"\\n\\nAssistant: "``. Only the final turn of each transcript is kept:

    - Its prompt is the preceding conversation plus the final human message,
      ending in ``"\\n\\nAssistant:"`` (trailing whitespace stripped).
    - Its response is the final assistant reply.

    Some transcripts are skipped and counted:

    - those where a turn does not contain exactly one ``"\\n\\nAssistant: "``
      separator;
    - those that contain no turns at all (e.g. an empty string).

    The result is shuffled with the module-level ``random`` generator. Seed it
    first if the order must be reproducible.

    Args:
        dataset: Iterable of transcript strings (assumed to be what
            ``RolloutDataset`` yields — TODO confirm).
        log: Callable used to report the number of skipped transcripts.

    Returns:
        Shuffled list of ``(prompt, response)`` tuples.
    """
    samples: List[Tuple[str, str]] = []
    skipped = 0
    for sample in dataset:
        turns: List[Tuple[str, str]] = []
        context = ""
        try:
            for chunk in sample.split("\n\nHuman: "):
                if not chunk:
                    continue
                question, response = chunk.split("\n\nAssistant: ")
                context += "\n\nHuman: " + question + "\n\nAssistant: "
                turns.append((context.rstrip(), response))
                context += response
        except ValueError:
            # The unpack above fails when a turn has zero or several separators.
            skipped += 1
            continue
        if not turns:
            # Nothing parseable (e.g. an empty transcript): there is no final turn.
            skipped += 1
            continue
        samples.append(turns[-1])
    log("skipped", skipped, "samples")
    random.shuffle(samples)
    return samples


def main(
    model_name="EleutherAI/pythia-1.4b-deduped",
    model_location="/data/private_models/xx_models/policy_models/llama-7b",
    tokenizer_location='/data/private_models/xx_models/llama/llama_hf_weights_v1.1/llama-7b',
    temperature=1,
    max_input_length=384,
    max_completion_length=128,
    batch_size=64,
    num_prompts=1200,
    num_completions_per_prompt=10112,
    offset=0,
    gpu_num=None,
    total_gpus=None,
):
    """Sample best-of-N continuations for a slice of the RolloutDataset train prompts.

    Prompt selection:

    - Build one (prompt, response) pair per train transcript, shuffled with seed 1.
    - Filter to prompts shorter than ``max_input_length`` tokens.
    - Take a slice. With ``total_gpus`` set, this process takes
      ``[start + offset, end + offset)``, where ``[start, end)`` is its share of
      ``num_prompts`` split evenly across ``total_gpus``. Otherwise it takes
      ``[offset, offset + num_prompts)``.

    Generation and output:

    - For each prompt, ``generate_continuations`` produces completions, which
      are decoded to text.
    - After every prompt, the full list of results (one list of decoded strings
      per prompt) is re-pickled to ``generations/bon_generations_<timestamp>.pkl``.
      A ``_<gpu_num>`` suffix is added when ``total_gpus`` is set.

    Args:
        model_name: Hub id used when ``model_location`` / ``tokenizer_location``
            is empty.
        model_location: Local path of the causal LM weights.
        tokenizer_location: Local path of the tokenizer.
        temperature, max_completion_length, batch_size,
        num_completions_per_prompt, max_input_length: Forwarded to
            ``generate_continuations``. ``max_input_length`` is also the prompt
            length filter.
        num_prompts: Total prompts to process across all processes.
        offset: Index of the first prompt in the shuffled, filtered list.
        gpu_num: This process's index in ``[0, total_gpus)``; required with
            ``total_gpus``.
        total_gpus: Number of processes sharing the work; falsy means a single
            process.

    Raises:
        ValueError: If ``total_gpus`` is set and ``gpu_num`` is missing or out
            of range.
    """
    start, end = 0, 0
    if total_gpus:
        if gpu_num is None or not 0 <= gpu_num < total_gpus:
            raise ValueError(
                f"gpu_num must be in [0, {total_gpus}) when total_gpus is set, got {gpu_num!r}"
            )
        start = gpu_num * num_prompts // total_gpus
        end = (gpu_num + 1) * num_prompts // total_gpus
        print("gpu_num:", gpu_num)
        print("total_gpus:", total_gpus)
        print("start:", start)
        print("end:", end)

    formatted_now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    accelerator = Accelerator()

    # The fixed seed makes the shuffle in get_samples identical in every per-GPU
    # process. Without it, the [start, end) slices would not partition one
    # common prompt order.
    seed_val = 1
    random.seed(seed_val)

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_location or model_name)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    tokenizer.truncation_side = "left"
    # Cast to fp16 before moving to the GPU so full-precision weights never
    # occupy device memory.
    model = (
        AutoModelForCausalLM.from_pretrained(
            model_location or model_name, pad_token_id=tokenizer.eos_token_id
        )
        .half()
        .to("cuda")
    )

    train_set = get_samples(RolloutDataset("train"), accelerator.print)
    train_posts, train_continuations = zip(*train_set)

    train_prompts, train_continuations = get_prompt_dataset(
        train_posts, train_continuations, max_input_length, tokenizer
    )
    print("train_prompts length:", len(train_prompts))
    if total_gpus:
        train_prompts = train_prompts[start + offset : end + offset]
    else:
        train_prompts = train_prompts[offset : num_prompts + offset]
    accelerator.print("loaded dataset, LEN:", len(train_prompts))
    if not train_prompts:
        accelerator.print("no prompts in the selected slice; nothing to generate")
        return
    accelerator.print(train_prompts[-1])

    if total_gpus:
        out_filename = f"generations/bon_generations_{formatted_now}_{gpu_num}.pkl"
    else:
        out_filename = f"generations/bon_generations_{formatted_now}.pkl"
    os.makedirs(os.path.dirname(out_filename), exist_ok=True)

    bon_generations = []
    for i, prompt in enumerate(tqdm(train_prompts)):
        accelerator.print(f"prompt{i+start}: {prompt}")

        continuations = generate_continuations(
            model=model,
            batch_size=batch_size,
            num_completions=num_completions_per_prompt,
            temperature=temperature,
            max_new_tokens=max_completion_length,
            max_input_length=max_input_length,
            tokenizer=tokenizer,
            prompt=prompt,
            accelerator=accelerator,
        )

        bon_generations.append(
            tokenizer.batch_decode(
                continuations,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
        )

        # Re-pickle everything after each prompt so a crash keeps prior progress.
        print("saving generations")
        with open(out_filename, "wb") as f:
            pickle.dump(bon_generations, f)

if __name__ == "__main__":
    # Fire maps command-line flags (e.g. --gpu_num=1) onto main()'s keyword arguments.
    fire.Fire(component=main)
