import os
# Process-wide defaults for the GPU selection and vLLM's model source. They are
# set before torch/vllm are imported (below) so those libraries see them.
# setdefault keeps the original defaults ("3" / "True") but lets a caller
# override either one from the shell instead of being silently clobbered.
# NOTE(review): VLLM_USE_MODELSCOPE presumably makes vLLM fetch models from
# ModelScope rather than the Hugging Face Hub — confirm against vLLM docs.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "3")
os.environ.setdefault("VLLM_USE_MODELSCOPE", "True")



from transformers import  set_seed
import torch
import argparse
import json

from tqdm import tqdm
from peft import PeftModel, PeftConfig
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
# Alpaca-format prompt templates for str.format_map:
#   "prompt_input"    -> needs {instruction} and {input}
#   "prompt_no_input" -> needs {instruction} only
PROMPT_DICT_ALPACA = dict(
    prompt_input=(
        "Below is an instruction that describes a task, "
        "paired with an input that provides further context. "
        "Write a response that appropriately completes the request."
        "\n\n### Instruction:\n{instruction}"
        "\n\n### Input:\n{input}"
        "\n\n### Response:"
    ),
    prompt_no_input=(
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request."
        "\n\n### Instruction:\n{instruction}"
        "\n\n### Response:"
    ),
)
def _str2bool(value):
    """Convert a command-line string such as ``"True"`` or ``"no"`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(string)``, which is ``True``
    for every non-empty string (including ``"False"``), so boolean options
    need an explicit parser.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean
            spelling; argparse reports it as a usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"1", "true", "t", "yes", "y"}:
        return True
    if normalized in {"0", "false", "f", "no", "n"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv=None):
    """Parse the command-line options for the generation script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, exactly as before; passing a list is
            mainly useful for testing.

    Returns:
        argparse.Namespace with the options declared below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset_name",
        type=str,
        default='vicuna',
        choices=['vicuna', 'alpaca', 'koala', 'sinstruct', 'wizardlm', 'lima'],
        help="Which local JSON-Lines test set to generate completions for.",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default='alpaca',
        choices=['alpaca', 'wiz', 'vicuna'],
        help="Prompt template family: alpaca, wiz, vicuna.",
    )
    parser.add_argument(
        "--num_beams",
        type=int,
        default=3,
        help="Number of beams (currently not used by the vLLM sampling path).",
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default='../../final_model/lenth_sorted_6134_2',
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    # type=bool would treat any non-empty string (even "False") as True.
    parser.add_argument("--peft", type=_str2bool, default=True, help="Whether to use peft")
    # Tag stored in each output record and in the output file name,
    # e.g. vicuna_kl_T_low_3300_G2.
    parser.add_argument("--mode_name", type=str, default='lenth_sorted_6134_2')
    parser.add_argument("--seed", type=int, default=199, help="A seed for reproducible generation.")
    parser.add_argument(
        "--max_length",
        type=int,
        default=4096,
        help="Maximum number of tokens to generate per prompt.",
    )
    return parser.parse_args(argv)


# (test-set path, JSON key holding each example's prompt text), keyed by the
# --dataset_name value.
# NOTE(review): for lima the 'conversations' value is formatted into the
# template as-is; if it is a list, its repr ends up in the prompt — confirm.
_DATASET_SPECS = {
    "vicuna": ('../test_data/vicuna_test_set.jsonl', 'text'),
    "alpaca": ('../test_data/alpaca_eval_gpt4_baseline.jsonl', 'instruction'),
    "koala": ('../test_data/koala_test_set.jsonl', 'prompt'),
    "sinstruct": ('../test_data/sinstruct_test_set.jsonl', 'instruction'),
    "wizardlm": ('../test_data/wizardlm_test_set.jsonl', 'Instruction'),
    "lima": ('../test_data/lima_test_set.jsonl', 'conversations'),
}

# Directory the generated records are written to.
_SAVE_DIR = '../data_save/epoches_3/kl_final'


def _select_templates(prompt_name):
    """Return the ``(prompt_input, prompt_no_input)`` templates for ``--prompt``.

    Raises:
        ValueError: if *prompt_name* is not one of alpaca, wiz, vicuna.
    """
    if prompt_name == 'alpaca':
        table = PROMPT_DICT_ALPACA
    elif prompt_name == 'wiz':
        # NOTE(review): PROMPT_DICT_WIZARDLM is not defined in this file as
        # shown; this branch raises NameError unless it is added.
        table = PROMPT_DICT_WIZARDLM
    elif prompt_name == 'vicuna':
        # NOTE(review): same caveat for PROMPT_DICT_VICUNA.
        table = PROMPT_DICT_VICUNA
    else:
        raise ValueError(
            f"Unknown --prompt {prompt_name!r}; expected 'alpaca', 'wiz' or 'vicuna'."
        )
    return table["prompt_input"], table["prompt_no_input"]


def _load_records(dataset_path):
    """Read a JSON-Lines file into a list of dicts, skipping blank lines."""
    with open(dataset_path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def _build_prompt(point, dataset_name, prompt_key, prompt_input, prompt_no_input):
    """Format one test record into a model prompt string.

    Only the sinstruct branch reads the record's ``instances`` list (which must
    hold exactly one entry) and uses the with-input template when that entry's
    ``input`` is non-empty; every other dataset uses the no-input template.

    Raises:
        ValueError: if a sinstruct record does not have exactly one instance.
    """
    instruction = point[prompt_key]
    if dataset_name != "sinstruct":
        return prompt_no_input.format_map({"instruction": instruction})
    instances = point['instances']
    if len(instances) != 1:
        raise ValueError(
            f"Expected exactly one sinstruct instance, got {len(instances)}."
        )
    extra_input = instances[0]['input']
    if extra_input:
        return prompt_input.format_map({"instruction": instruction, "input": extra_input})
    return prompt_no_input.format_map({"instruction": instruction})


def main():
    """Generate one sampled completion per test prompt with vLLM and save them.

    Parses CLI options with parse_args(), builds prompts from the selected
    JSON-Lines test set, runs them through ``LLM.generate`` in a single batch,
    and writes the input records, each augmented with ``output``,
    ``instruction`` and ``generator`` fields, to
    ``../data_save/epoches_3/kl_final/<dataset_name>_<mode_name>.json``.

    Raises:
        ValueError: for an unknown ``--prompt`` or ``--dataset_name`` value.
    """
    args = parse_args()
    set_seed(args.seed)

    # Resolve templates and data before loading the model so that a bad
    # option or missing file fails fast, not after the slow model load.
    prompt_input, prompt_no_input = _select_templates(args.prompt)
    try:
        dataset_path, prompt_key = _DATASET_SPECS[args.dataset_name]
    except KeyError:
        raise ValueError(
            f"Unknown --dataset_name {args.dataset_name!r}; "
            f"expected one of {sorted(_DATASET_SPECS)}."
        ) from None

    point_list = _load_records(dataset_path)
    data = [
        _build_prompt(point, args.dataset_name, prompt_key, prompt_input, prompt_no_input)
        for point in tqdm(point_list)
    ]

    sampling_params = SamplingParams(
        temperature=0.8,
        top_k=50,
        top_p=0.9,
        repetition_penalty=1.2,
        max_tokens=args.max_length,
        seed=args.seed,
    )
    model = LLM(model=args.model_name_or_path, enforce_eager=True, seed=args.seed, dtype=torch.bfloat16)
    outputs = model.generate(data, sampling_params)

    for point, output in zip(point_list, outputs):
        point['output'] = output.outputs[0].text
        point['instruction'] = point[prompt_key]
        point['generator'] = args.mode_name
        # Drop the dataset-specific prompt field now that it is copied into
        # 'instruction'. Skip this when the field *is* 'instruction' (alpaca,
        # sinstruct); deleting it would remove the value just written.
        if prompt_key != 'instruction':
            del point[prompt_key]

    os.makedirs(_SAVE_DIR, exist_ok=True)
    saved_name = args.dataset_name + '_' + args.mode_name + ".json"
    with open(os.path.join(_SAVE_DIR, saved_name), "w", encoding="utf-8") as f:
        json.dump(point_list, f, indent=4)

# Run generation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()