import json
import sys
from argparse import ArgumentParser, Namespace
from typing import List
import os

from transformers import AutoTokenizer
from tqdm import tqdm
import requests
from vllm import LLM, SamplingParams

def get_args() -> Namespace:
    """Parse the command-line options for an InstructBench evaluation run.

    Returns:
        Namespace with the attributes ``config``, ``input_path``,
        ``batch_size``, ``model_name``, ``max_input_length``,
        ``max_output_length`` and ``output_folder``. Options without an
        explicit default are ``None`` when omitted on the command line.
    """
    arg_parser = ArgumentParser()
    eval_group = arg_parser.add_argument_group("Evaluate LLM on InstructBench")

    # (flag, add_argument keyword arguments), registered in --help order.
    option_specs = (
        ("--config", {
            "type": str,
            "help": "Datasets and corresponding Instruction transformations to evaluate",
        }),
        ("--input_path", {"type": str, "help": "data path"}),
        ("--batch_size", {
            "type": int,
            "default": 1,
            "help": "batch size for inference",
        }),
        ("--model_name", {
            "type": str,
            "default": "mistralai/Mistral-7B-v0.1",
            "help": "model name to use",
        }),
        ("--max_input_length", {
            "type": int,
            "default": 8192,
            "help": "max input length",
        }),
        ("--max_output_length", {"type": int, "help": "max output length"}),
        ("--output_folder", {"type": str, "help": "output folder"}),
    )
    for flag, options in option_specs:
        eval_group.add_argument(flag, **options)

    return arg_parser.parse_args()

def _read_jsonl(path: str) -> List[dict]:
    """Load every non-blank line of a JSON Lines file as one parsed record."""
    with open(path, "r", errors="ignore", encoding="utf8") as reader:
        return [json.loads(line) for line in reader if line.strip()]


def _generate_batch(llm, sampling_params, batch, out_file) -> None:
    """Generate text for one batch and append the results to ``out_file``.

    Each example's ``"input"`` value is used as the prompt; the first
    completion's text is stored under ``"generated_text"`` and the example
    is written as one JSON line.
    """
    prompts = [example["input"] for example in batch]
    outputs = llm.generate(prompts, sampling_params)

    for example, output in zip(batch, outputs):
        example["generated_text"] = output.outputs[0].text
        out_file.write(json.dumps(example))
        out_file.write("\n")
    # Flush once per batch so finished work survives an interrupted run.
    out_file.flush()


def main() -> None:
    """Run generation for every dataset/instruction pair listed in the config.

    The config JSON maps each dataset name to an iterable of instruction
    names. For every pair, ``<input_path>/<dataset>_<instruction>.jsonl`` is
    read, each record's ``"input"`` is sent to the model, and the record
    (extended with ``"generated_text"``) is written to a file of the same
    name inside ``--output_folder``.

    Raises:
        OSError: if the config or an input file cannot be opened.
        json.JSONDecodeError: if the config or an input line is not valid JSON.
        KeyError: if an input record has no ``"input"`` key.
    """
    args = get_args()

    with open(args.config, "r", errors="ignore", encoding="utf8") as reader:
        config = json.load(reader)

    output_folder = args.output_folder
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_folder, exist_ok=True)

    # A non-positive batch size degrades to one example per batch.
    batch_size = max(1, args.batch_size)

    # Honour --max_output_length when given; 30 is the previous fixed limit.
    max_tokens = 30 if args.max_output_length is None else args.max_output_length
    sampling_params = SamplingParams(temperature=0.0, max_tokens=max_tokens, n=1)

    # Build the model once: constructing it per file would reload it for
    # every dataset/instruction pair.
    llm = LLM(model=args.model_name)

    for each_dataset in config:
        for each_instruction in config[each_dataset]:
            file_stem = f"{each_dataset}_{each_instruction}.jsonl"
            data_instances = _read_jsonl(os.path.join(args.input_path, file_stem))

            print(f"Number of instances in dataset is {len(data_instances)}")

            out_path = os.path.join(output_folder, file_stem)
            with open(out_path, "w", encoding="utf8") as out_file, tqdm(
                total=len(data_instances), desc="Generating for each example"
            ) as progress:
                # Slicing by batch_size yields exactly batch_size examples per
                # batch and includes the final partial batch.
                for start in range(0, len(data_instances), batch_size):
                    batch = data_instances[start:start + batch_size]
                    _generate_batch(llm, sampling_params, batch, out_file)
                    progress.update(len(batch))


# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()