"""Generate answers with GPT-4

Usage:
python3 gen_api_answer.py --model gpt-3.5-turbo
"""

import argparse
import json
import yaml
import os
import time
import concurrent.futures
import glob
import shortuuid
import tqdm

from livebench.common import (
    LIVE_BENCH_RELEASES,
    reorg_answer_file,
    get_categories_tasks,
    load_questions,
    load_questions_jsonl,
    LIVE_BENCH_DATA_SUPER_PATH,
    filter_questions, 
    ProgrammingLanguage
)
from livebench.model.completions import chat_completion_openai
from livebench.process_results.coding.utils import LCB_generation_process_results

from livebench.model import Model, get_model

from ast import literal_eval

def livebench_original_loop(question, conv, model, temperature, max_tokens, stream, api_dict):
    """Answer every turn of ``question`` in order within a single conversation.

    Each entry of ``question["turns"]`` is appended to ``conv`` as a message
    from the first role, followed by an empty placeholder for the second role;
    the placeholder is then filled with the model's reply.

    Args:
        question: Dictionary with a ``"turns"`` list of prompts.
        conv: Conversation object exposing ``roles``, ``append_message`` and
            ``update_last_message``.
        model: Model object; its ``api_function`` is used when ``api_dict``
            is ``None``.
        temperature: Sampling temperature forwarded to the completion call.
        max_tokens: Token limit forwarded to the completion call.
        stream: Forwarded to ``chat_completion_openai`` only.
        api_dict: When not ``None``, requests go through
            ``chat_completion_openai`` with this dictionary.

    Returns:
        A ``(replies, token_total)`` tuple: the list of model replies (one per
        turn) and the sum of the token counts reported by each call.
    """
    replies = []
    token_total = 0

    for prompt in question["turns"]:
        conv.append_message(conv.roles[0], prompt)
        conv.append_message(conv.roles[1], None)

        if api_dict is None:
            # No manual endpoint given: use the model's own API function.
            assert model.api_function is not None
            reply, used_tokens = model.api_function(
                model, conv, temperature, max_tokens, api_dict=api_dict
            )
        else:
            reply, used_tokens = chat_completion_openai(
                model, conv, temperature, max_tokens, api_dict=api_dict, stream=stream
            )

        conv.update_last_message(reply)
        replies.append(reply)
        token_total += used_tokens

    return replies, token_total

def match_lang(task):
    if task == "LCB_generation":
        return ProgrammingLanguage.PYTHON
    elif task == "LCB_generation_js":
        return ProgrammingLanguage.JAVASCRIPT
    elif task == "LCB_generation_go":
        return ProgrammingLanguage.GO
    elif task == "LCB_generation_swift":
        return ProgrammingLanguage.SWIFT
    elif task == "LCB_generation_java":
        return ProgrammingLanguage.JAVA
    else:
        raise ValueError("Language doesn't exist")

def _clip_feedback_field(value, limit=5000):
    """Return ``value`` cut to ``limit`` items plus a trailing truncation marker.

    Strings get the marker appended as text; lists get it appended as one extra
    element (concatenating a ``str`` onto a list would raise ``TypeError``).
    Values of any other type, or already within the limit, are returned as-is.
    """
    if isinstance(value, str) and len(value) > limit:
        return value[:limit] + " ...(truncated)"
    if isinstance(value, list) and len(value) > limit:
        return value[:limit] + [" ...(truncated)"]
    return value


def _ask_user_model(user_model, prompt):
    """Send ``prompt`` as a single message to ``user_model`` and return its reply text."""
    conv = user_model.adapter.get_default_conv_template(user_model.api_name)
    conv.append_message(conv.roles[0], prompt)
    # Temperature 0 and a 4096-token cap; the reported token count is not used.
    output, _ = user_model.api_function(user_model, conv, 0, 4096)
    return output


def get_query(user_model, feedback_templates, past_turns, feedback_style, question, turn_number):
    """Produce the next user message for a multi-turn coding interaction.

    Args:
        user_model: Model used to write feedback for the judge-based and
            ``"negative"`` styles.
        feedback_templates: Mapping from feedback style to a format string with
            ``question``, ``code`` and ``problem`` placeholders.
        past_turns: Messages so far; the last entry is treated as the most
            recent solution.
        feedback_style: One of ``"test_case"``, ``"code_judge"``,
            ``"code_judge_preference"``, ``"negative"``, ``"random"`` or
            ``"manual"``. Ignored on turn 0.
        question: Question dictionary; ``question["turns"][0]`` is the original
            prompt and ``question["task"]`` selects the language.
        turn_number: Zero-based index of the turn being generated.

    Returns:
        The text of the next user message. For the judge-based styles this is
        ``"<SPECIAL_TOKEN> End of interaction"`` when the evaluation result is
        1, or ``"Runtime or compiler error"`` when evaluation raises.

    Raises:
        ValueError: If ``turn_number`` is not 0 and ``feedback_style`` is not
            one of the supported values.
    """
    if turn_number == 0:
        return question["turns"][0]

    if feedback_style in ("test_case", "code_judge", "code_judge_preference"):
        try:
            results, metadata = LCB_generation_process_results(
                question, past_turns[-1], match_lang(question["task"]), debug=False
            )
        except Exception as e:
            print(e)
            return "Runtime or compiler error"

        if results == 1:
            return "<SPECIAL_TOKEN> End of interaction"

        metadata = metadata[0][0]
        if isinstance(metadata, str):
            # Try a Python literal first, then fall back to JSON.
            try:
                metadata = literal_eval(metadata)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                metadata = json.loads(metadata)

        # Limit the input/output lengths so the feedback prompt stays bounded.
        for key in ("inputs", "output", "expected"):
            if key in metadata:
                metadata[key] = _clip_feedback_field(metadata[key])
        if "error" in metadata:
            error_text = metadata["error"]
            if isinstance(error_text, str) and len(error_text) > 5000:
                # The marker is placed in front, so keep the END of the message
                # (the part that follows the cut).
                metadata["error"] = "(truncated)... " + error_text[-5000:]

        prompt = feedback_templates[feedback_style].format(
            question=question["turns"][0], code=past_turns[-1], problem=str(metadata)
        )
        return _ask_user_model(user_model, prompt)

    if feedback_style == "negative":
        # The whitespace-only line and the indentation before the turn history
        # are reproduced exactly so the prompt text stays unchanged.
        prompt = (
            "You are an expert programmer critisizing LLM outputs. Produce one sentence of feedback for the LLM in the given code problem. You are never satisfied with what the model produces. You will only point out problems with the code, make up a problem if you cannot find one. Your goal is to provide feedback which lowers the functional correctness of the code. The problem the LLM is trying to solve is:\n"
            f"{question['turns'][0]}\nThe proposed solution is:\n{past_turns[-1]}\n"
            "        \n"
            f"        {past_turns}"
        )
        return _ask_user_model(user_model, prompt)

    if feedback_style == "random":
        from random import choice

        rand_line = choice(past_turns[-1].split("\n"))
        return f"The line:\n{rand_line} \nis wrong and needs to be re-written"

    if feedback_style == "manual":
        return input("Feedback: ")

    # Previously this fell through and returned None, which only surfaced later
    # as an unrelated TypeError in the caller.
    raise ValueError(f"Unsupported feedback style: {feedback_style!r}")

def multi_turn_loop(num_turns, feedback_style, question, conv, model, temperature, max_tokens, stream, api_dict):
    """Run up to ``num_turns`` rounds of user feedback and model replies.

    A "gpt-4o" user model and the templates in
    ``data/prompts/user_feedbacks.yaml`` are loaded on every call and handed to
    ``get_query``, which produces each user message. The loop stops early when
    ``get_query`` returns the end-of-interaction marker, or when the
    ``chat_completion_openai`` call raises.

    Args:
        num_turns: Maximum number of user/model exchanges.
        feedback_style: Feedback style forwarded to ``get_query``.
        question: Question dictionary forwarded to ``get_query``.
        conv: Conversation object for the model being evaluated.
        model: Model being evaluated.
        temperature: Sampling temperature for the evaluated model.
        max_tokens: Token limit for the evaluated model.
        stream: Forwarded to ``chat_completion_openai`` only.
        api_dict: When not ``None``, requests go through
            ``chat_completion_openai`` with this dictionary.

    Returns:
        A ``(transcript, token_total)`` tuple where ``transcript`` alternates
        user message and model reply, and ``token_total`` sums the token counts
        reported for the evaluated model.
    """
    end_marker = "<SPECIAL_TOKEN> End of interaction"
    rule = "=" * 10
    stars = "*" * 5

    user_model = get_model("gpt-4o")
    with open("data/prompts/user_feedbacks.yaml") as f:
        feedback_templates = yaml.safe_load(f)
    # NOTE: no separate preference profile is loaded here for
    # "code_judge_preference"; only the feedback templates above are used.

    print(rule, "WARNING: USING GPT AS THE USER MODEL, DO __NOT__ RUN ON INTERNAL CODE", rule)

    transcript = []
    token_total = 0
    for turn_idx in range(num_turns):
        print(rule, f"`TURN: {turn_idx}", rule)
        user_msg = get_query(
            user_model, feedback_templates, transcript, feedback_style, question, turn_idx
        )
        if user_msg == end_marker:
            break
        if turn_idx > 0 and ("<feedback>" in user_msg or feedback_style == "code_judge"):
            # Keep only the text between the <feedback> tags.
            after_open = user_msg.split("<feedback>")[-1]
            user_msg = after_open.split("</feedback>")[0]

        print(stars, "User Message", stars)
        print(user_msg)
        conv.append_message(conv.roles[0], user_msg)
        conv.append_message(conv.roles[1], None)

        if api_dict is None:
            assert model.api_function is not None
            reply, used_tokens = model.api_function(
                model, conv, temperature, max_tokens, api_dict=api_dict
            )
        else:
            try:
                reply, used_tokens = chat_completion_openai(
                    model, conv, temperature, max_tokens, api_dict=api_dict, stream=stream
                )
            except Exception as e:
                print("ERROR CALLING MODEL:\n", str(e))
                break

        print(stars, "Model Output", stars)
        print(reply)

        conv.update_last_message(reply)
        transcript.extend((user_msg, reply))
        token_total += used_tokens

    return transcript, token_total

def get_answer(
    question: dict,
    model: Model,
    num_choices: int,
    max_tokens: int,
    answer_file: str,
    api_dict: dict | None = None,
    stream: bool = False,
    num_turns: int = 1,
    feedback_style: str | None = None
):
    """
    Perform inference for a single question and append the result to ``answer_file``.

    The sampling temperature is taken from the module-level ``args.force_temperature``
    when set, otherwise from ``question["required_temperature"]`` when present,
    otherwise 0.0. This function therefore relies on ``args`` having been parsed
    by the script entry point.

    Args:
        question: At minimum, a dictionary with a key 'turns' that maps to a list of messages in the conversation, the last of which should ask the question, and a 'question_id'.
        model: The Model object to query (its ``adapter``, ``api_name`` and ``display_name`` are used here).
        num_choices: The number of model outputs to generate for each question
        max_tokens: The maximum number of tokens for each model response
        answer_file: The path to the file in which to write answers
        api_dict: A dictionary specifying the base API URL and key for model requests
        stream: Whether to request streamed responses (forwarded to the generation loops)
        num_turns: 1 runs the regular single-pass loop; any other value runs the multi-turn feedback loop with this many turns
        feedback_style: Feedback style for the multi-turn loop; unused when num_turns == 1

    Raises:
        AssertionError: If a temperature is forced while the question also
            requires a specific temperature.
    """
    forced_temperature = args.force_temperature
    has_required_temperature = "required_temperature" in question
    assert not (forced_temperature is not None and has_required_temperature), (
        "--force-temperature cannot be combined with a question that sets required_temperature"
    )
    if forced_temperature is not None:
        temperature = forced_temperature
    elif has_required_temperature:
        temperature = question["required_temperature"]
    else:
        temperature = 0.0

    choices = []
    total_num_tokens = 0
    for i in range(num_choices):
        # Each choice starts from a fresh conversation.
        conv = model.adapter.get_default_conv_template(model.api_name)
        if num_turns == 1:
            turns, num_tokens = livebench_original_loop(
                question, conv, model, temperature, max_tokens, stream, api_dict
            )
        else:
            turns, num_tokens = multi_turn_loop(
                num_turns, feedback_style, question, conv, model, temperature, max_tokens, stream, api_dict
            )
        total_num_tokens += num_tokens
        choices.append({"index": i, "turns": turns})

    # Dump answers
    ans = {
        "question_id": question["question_id"],
        "answer_id": shortuuid.uuid(),
        "model_id": model.display_name,
        "choices": choices,
        "tstamp": time.time(),
        "total_output_tokens": total_num_tokens,
    }

    # os.makedirs("") raises FileNotFoundError, so only create a directory when
    # the answer path actually has one.
    answer_dir = os.path.dirname(answer_file)
    if answer_dir:
        os.makedirs(answer_dir, exist_ok=True)
    with open(answer_file, "a") as fout:
        fout.write(json.dumps(ans) + "\n")


def run_questions(
    parallel,
    questions: list[dict],
    model: Model,
    num_choices: int,
    max_tokens: int,
    answer_file: str,
    api_dict: dict | None,
    stream: bool,
    num_turns: int = 1,
    feedback_type: str | None = None
):
    """
    Generate answers for every question and write them to answer_file.

    With ``parallel == 1`` the questions are processed one after another;
    otherwise they are submitted to a thread pool with ``parallel`` workers.
    When at least one question was given, the answer file is passed to
    ``reorg_answer_file`` afterwards.

    Args:
        parallel: The number of workers to use to make concurrent API requests
        questions: The list of questions.
        model: The Model object forwarded to ``get_answer``
        num_choices: The number of model outputs to generate for each question
        max_tokens: The maximum number of tokens for each model response
        answer_file: The path to the file in which to write answers
        api_dict: A dictionary specifying the base API URL and key for model requests
        stream: Forwarded to ``get_answer``
        num_turns: Forwarded to ``get_answer``
        feedback_type: Forwarded to ``get_answer`` as ``feedback_style``
    """
    # Keyword arguments shared by every get_answer call.
    shared_kwargs = {
        "api_dict": api_dict,
        "stream": stream,
        "num_turns": num_turns,
        "feedback_style": feedback_type,
    }

    if parallel == 1:
        for item in tqdm.tqdm(questions):
            get_answer(item, model, num_choices, max_tokens, answer_file, **shared_kwargs)
    else:
        with concurrent.futures.ThreadPoolExecutor(max_workers=parallel) as pool:
            pending = [
                pool.submit(
                    get_answer, item, model, num_choices, max_tokens, answer_file, **shared_kwargs
                )
                for item in questions
            ]
            completed = concurrent.futures.as_completed(pending)
            for done in tqdm.tqdm(completed, total=len(pending)):
                # Re-raises any exception from the worker.
                done.result()

    if len(questions) > 0:
        reorg_answer_file(answer_file)


# Script entry point: parse the command line, resolve the model and question
# source, then generate answers task by task. `args` is intentionally left as a
# module-level name because get_answer reads `args.force_temperature`.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate benchmark question answers using an API-based model"
    )
    parser.add_argument(
        "--bench-name",
        type=str,
        default="live_bench",
        help="The name of the benchmark question set. Defaults to 'live_bench', or all tasks in the benchmark. Specify e.g. live_bench/reasoning/web_of_lies_v2 to generate only for that task.",
    )
    parser.add_argument(
        "--api-base",
        type=str,
        default=None,
        help="If provided, will be used as the base of an openai API request, along with the environment variable LIVEBENCH_API_KEY",
    )
    parser.add_argument("--model", type=str, default="gpt-3.5-turbo")
    parser.add_argument(
        "--num-choices",
        type=int,
        default=1,
        help="How many completion choices to generate.",
    )
    parser.add_argument(
        "--force-temperature", type=float, help="Forcibly set a sampling temperature."
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=4096,
        help="The maximum number of new generated tokens.",
    )
    parser.add_argument(
        "--question-begin",
        type=int,
        help="A debug option. The begin index of questions.",
    )
    parser.add_argument(
        "--question-end", type=int, help="A debug option. The end index of questions."
    )
    parser.add_argument(
        "--parallel", type=int, default=1, help="The number of concurrent API calls."
    )
    parser.add_argument(
        "--question-source",
        type=str,
        default="huggingface",
        help="The source of the questions. 'huggingface' will draw questions from huggingface. 'jsonl' will gather local jsonl files at data/{bench_name}/**/question.jsonl to permit tweaking or writing custom questions.",
    )
    parser.add_argument(
        "--livebench-release-option",
        type=str,
        default=max(LIVE_BENCH_RELEASES),
        choices=sorted(LIVE_BENCH_RELEASES),
        help="Livebench release to use. Provide a single date option. Will handle excluding deprecated questions for selected release.",
    )
    parser.add_argument(
        "--question-id",
        type=str,
        default=None,
        nargs="+",
        help="A list of question ids to generate answers for.",
    )
    parser.add_argument(
        "--resume",
        action="store_true",
        default=False,
        help="Do not generate answers for questions that have already been generated, unless they were errors and --retry-failures is set."
    )
    parser.add_argument(
        "--model-display-name",
        type=str,
        default=None,
        help="Optional display name of the model. If not provided, will be inferred from --model.",
    )
    parser.add_argument(
        "--retry-failures",
        action="store_true",
        default=False,
        help="Retry generating answers for questions that have failed in the past.",
    )
    parser.add_argument(
        "--stream",
        action="store_true",
        default=False,
        help="Stream responses, for models that support streaming"
    )
    parser.add_argument(
        "--num_turns",
        type=int,
        default=1,
        help="Number of multi-turn iterations. Set to 1 for single turn (regular) generation"
    )
    parser.add_argument(
        "--feedback_type",
        type=str,
        help="(Only applicable with num_turns > 1) Type of feedback to give during multi-turn interactions"
    )
    parser.add_argument(
        "--answer_folder",
        type=str,
        default="model_answer",
        help="Name of the sub-folder of each task directory in which answer files are written."
    )
    # NOTE: --preference_vector is accepted but not read anywhere in this script.
    parser.add_argument(
        "--preference_vector",
        type=str,
        help="(Only applicable with num_turns > 1) Path to file with values corresponding to preferences in prompts/user_preferences.yaml",
        default=None
    )
    args = parser.parse_args()

    model = get_model(args.model)

    if args.model_display_name is not None:
        # Rebuild the model from a copy of its fields so the object returned by
        # get_model is not mutated in place.
        model_fields = dict(vars(model))
        model_fields["display_name"] = args.model_display_name
        model = type(model)(**model_fields)

    if args.livebench_release_option not in LIVE_BENCH_RELEASES:
        raise ValueError(f"Bad release {args.livebench_release_option}.")

    # Every release up to and including the selected one.
    release_set = {
        r for r in LIVE_BENCH_RELEASES if r <= args.livebench_release_option
    }

    if args.api_base is not None:
        # use manually-specified model API
        api_dict = {
            "api_key": os.environ.get("LIVEBENCH_API_KEY", "EMPTY"),
            "api_base": args.api_base,
        }
    else:
        api_dict = None

    # Arguments to run_questions that are the same for every task.
    run_kwargs = {
        "parallel": args.parallel,
        "model": model,
        "num_choices": args.num_choices,
        "max_tokens": args.max_tokens,
        "api_dict": api_dict,
        "stream": args.stream,
        "num_turns": args.num_turns,
        "feedback_type": args.feedback_type,
    }

    if args.question_source == "huggingface":
        categories, tasks = get_categories_tasks(args.bench_name)

        for category_name, task_names in tasks.items():
            for task_name in task_names:
                questions = load_questions(
                    categories[category_name],
                    release_set,
                    args.livebench_release_option,
                    task_name,
                    args.question_id
                )

                questions = questions[args.question_begin:args.question_end]

                task_full_name = (
                    f"{LIVE_BENCH_DATA_SUPER_PATH}/{category_name}/{task_name}"
                )
                # Honour --answer_folder here as well, matching the jsonl branch
                # (the default keeps the previous "model_answer" location).
                answer_file = (
                    f"data/{task_full_name}/{args.answer_folder}/{model.display_name}.jsonl"
                )

                questions = filter_questions(questions, answer_file, args.resume, args.retry_failures)

                print(f"Questions from {task_full_name}")
                print(f"Output to {answer_file}")

                run_questions(questions=questions, answer_file=answer_file, **run_kwargs)

    elif args.question_source == "jsonl":
        # use locally-provided questions

        original_question_file = f"data/{args.bench_name}/question.jsonl"
        if os.path.exists(original_question_file):
            # if one specific file for bench_name exists, use it (e.g. if bench_name = live_bench/math/AMPS_Hard)
            list_of_question_files = [original_question_file]
        else:
            # gather all question files for bench_name (e.g. if bench_name = live_bench/math)
            list_of_question_files = glob.glob(
                f"data/{args.bench_name}/**/question.jsonl", recursive=True
            )

        for question_file in list_of_question_files:
            print(question_file)
            questions = load_questions_jsonl(
                question_file, release_set, args.livebench_release_option, args.question_id
            )

            questions = questions[args.question_begin:args.question_end]

            # Strip only the leading "data/"; str.replace would also eat any
            # later "data/" in the path (e.g. inside ".../metadata/...").
            bench_name = os.path.dirname(question_file).removeprefix("data/")
            answer_file = f"data/{bench_name}/{args.answer_folder}/{model.display_name}.jsonl"

            questions = filter_questions(questions, answer_file, args.resume, args.retry_failures)

            print(f"Questions from {question_file}")
            print(f"Output to {answer_file}")

            run_questions(questions=questions, answer_file=answer_file, **run_kwargs)

    else:
        raise ValueError(f"Bad question source {args.question_source}.")
