from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from datasets import load_dataset
import torch
from huggingface_hub import Repository, snapshot_download
import numpy as np
import random
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
import argparse
import time
from accelerate import Accelerator
import openai
import os
import llm_blender
import tqdm
import json
from alpaca_eval import evaluate


parser = argparse.ArgumentParser(
    description="Generate AlpacaEval responses with vLLM and score them.")

parser.add_argument("--model", type=str,
                    help="Model path or Hugging Face id passed to vLLM "
                         "(required unless --load is given).")
# --name is required: it is concatenated into every output path, so a missing
# value would otherwise only crash after the (expensive) generation finished.
parser.add_argument("--name", type=str, required=True,
                    help="Generator name; results go to ./alpaca/<name>/.")
parser.add_argument("--load", action="store_true",
                    help="Reuse ./alpaca/<name>/data.json instead of "
                         "generating new responses.")


args = parser.parse_args()

model_dir = args.model
col_name = args.name
load_from_cache = args.load

if not load_from_cache and model_dir is None:
    # Fail fast instead of handing None to vLLM.
    parser.error("--model is required unless --load is given")

openai.api_key = os.environ['OPENAI_API_KEY']
# NOTE(review): this is not an official OpenAI/Azure endpoint (a workers.dev
# proxy); confirm it is trusted before routing the API key through it.
openai.base_url = "https://azure-openai-api.shenmishajing.workers.dev/v1/"

# The API key is intentionally NOT printed: echoing it leaks the secret into
# terminal scrollback and job logs.

home_directory = os.path.expanduser("~")
# Download directory handed to vLLM for model weights.
cache_path = os.path.join(home_directory, ".cache/vllm")


def generate_response_vllm(col_name,
                           tokenizer_name="alignment-handbook/zephyr-7b-sft-full"):
    """Generate AlpacaEval responses with vLLM and cache them to disk.

    The model is loaded from the module-level ``model_dir`` (``--model``)
    with weights downloaded to ``cache_path``. Each instruction of the
    ``tatsu-lab/alpaca_eval`` "eval" split is wrapped in the tokenizer's chat
    template (empty system message, generation prompt appended) and sampled
    once.

    Args:
        col_name: Generator name stored in each example's ``"generator"``
            field; also the sub-directory of ``./alpaca`` written to.
        tokenizer_name: Hugging Face id/path of the tokenizer whose chat
            template and EOS token are used. Defaults to the tokenizer that
            was previously hard-coded.

    Returns:
        list[dict]: The eval rows, each extended with ``"output"`` (the
        first generated completion) and ``"generator"``. The same list is
        written as JSON to ``./alpaca/<col_name>/data.json``.
    """
    model = LLM(model_dir, download_dir=cache_path)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    with torch.inference_mode():
        sampling_params = SamplingParams(
            temperature=0.7,
            top_p=1.0,
            max_tokens=2048,
            # Stop generation at the tokenizer's EOS string.
            stop=tokenizer.eos_token,
            skip_special_tokens=True,
            seed=24,
        )
        alpaca_eval_data = load_dataset(
            "tatsu-lab/alpaca_eval", "alpaca_eval")["eval"]

        chat_prompts = [
            tokenizer.apply_chat_template(
                [{"role": "system", "content": ""},
                 {"role": "user", "content": row["instruction"]}],
                tokenize=False,
                add_generation_prompt=True,
            )
            for row in alpaca_eval_data
        ]

        # One RequestOutput per prompt, in prompt order; keep completion 0.
        request_outputs = model.generate(chat_prompts, sampling_params)

        result = []
        for example, request_output in zip(alpaca_eval_data, request_outputs):
            example["output"] = request_output.outputs[0].text
            example["generator"] = col_name
            result.append(example)

        output_dir = os.path.join("./alpaca", col_name)
        os.makedirs(output_dir, exist_ok=True)
        file_path = os.path.join(output_dir, "data.json")
        # Explicit UTF-8 to match the utf-8 reader used for --load.
        with open(file_path, "w", encoding="utf-8") as json_file:
            json.dump(result, json_file, indent=2)
        return result


if __name__ == "__main__":
    # Every artifact of a run lives under ./alpaca/<col_name>/.
    output_dir = os.path.join("./alpaca", col_name)

    if load_from_cache:
        # Reuse generations written by a previous run instead of re-running vLLM.
        file_path = os.path.join(output_dir, "data.json")
        with open(file_path, "r", encoding="utf-8") as json_file:
            result = json.load(json_file)
    else:
        result = generate_response_vllm(col_name)

    # Score the outputs with the "azure_gpt4_turbo" annotator config; the
    # leaderboard DataFrame is returned (and printed below). The per-example
    # annotations are not used by this script.
    df_leaderboard, _annotations = evaluate(
        model_outputs=result,
        annotators_config="azure_gpt4_turbo",
        is_return_instead_of_print=True,
        precomputed_leaderboard=None,
        is_cache_leaderboard=False,
        is_overwrite_leaderboard=True,
    )
    print(df_leaderboard.to_string(float_format="%.2f"))

    leaderboard_path = os.path.join(output_dir, "leaderboard.json")
    with open(leaderboard_path, "w", encoding="utf-8") as fout:
        json.dump(df_leaderboard.to_dict(), fout)
