import argparse
import os
import json
import random
import torch
try:
    import vllm
except ImportError:
    print("VLLM not installed. Will not be able to use VLLM.")
    vllm = None
import evaluate
import numpy as np
from minimal_multitask.eval.utils import (
    generate_completions,
    load_hf_lm,
    query_openai_chat_model,
    dynamic_import_function,
    load_hf_tokenizer,
)


# Prompt pieces for the setting where the gold passage is shown to the model.
# Keys are the language names taken from the prefix of each example id
# (the part before the first "-"). Each value is a 4-tuple that main()
# unpacks in this order:
#   (instruction, passage label, question label, answer label)
encoding_templates_with_context = {
    "english": (
        "Answer the following question based on the information in the given passage.",
        "Passage:",
        "Question:",
        "Answer:",
    ),
    "arabic": ("أجب على السؤال التالي بناءً على المعلومات في المقطع المعطى.", "المقطع:", "السؤال:", "الإجابة:"),
    "bengali": (
        "প্রদত্ত অধ্যায়ের তথ্যের উপর ভিত্তি করে নিম্নলিখিত প্রশ্নের উত্তর দিন।",
        "অধ্যায়:",
        "প্রশ্ন:",
        "উত্তর:",
    ),
    "finnish": (
        "Vastaa seuraavaan kysymykseen annetun kappaleen tiedon perusteella.",
        "Kappale:",
        "Kysymys:",
        "Vastaus:",
    ),
    "indonesian": (
        "Jawab pertanyaan berikut berdasarkan informasi di bagian yang diberikan.",
        "Bagian:",
        "Pertanyaan:",
        "Jawaban:",
    ),
    "korean": ("주어진 문단의 정보에 기반하여 다음 질문에 답하십시오.", "문단:", "질문:", "답변:"),
    "russian": (
        "Ответьте на следующий вопрос на основе информации в данном отрывке.",
        "Отрывок:",
        "Вопрос:",
        "Ответ:",
    ),
    "swahili": (
        "Jibu swali lifuatalo kulingana na habari kwenye kifungu kilichotolewa.",
        "Kifungu:",
        "Swali:",
        "Jibu:",
    ),
    "telugu": ("ఇచ్చిన పేరాలోని సమాచారం ఆధారంగా కింది ప్రశ్నకు సమాధానం ఇవ్వండి.", "పేరా:", "ప్రశ్న:", "సమాధానం:"),
}

# Prompt pieces for the closed-book setting (--no_context), where no passage
# is shown. Same language keys as the with-context table; each value is a
# 3-tuple that main() unpacks in this order:
#   (instruction, question label, answer label)
encoding_templates_without_context = {
    "english": ("Answer the following question.", "Question:", "Answer:"),
    "arabic": ("أجب على السؤال التالي.", "السؤال:", "الإجابة:"),
    "bengali": ("নিম্নলিখিত প্রশ্নের উত্তর দিন।", "প্রশ্ন:", "উত্তর:"),
    "finnish": ("Vastaa seuraavaan kysymykseen.", "Kysymys:", "Vastaus:"),
    "indonesian": ("Jawab pertanyaan berikut.", "Pertanyaan:", "Jawaban:"),
    "korean": ("다음 질문에 답하십시오.", "질문:", "답변:"),
    "russian": ("Ответьте на следующий вопрос.", "Вопрос:", "Ответ:"),
    "swahili": ("Jibu swali lifuatalo.", "Swali:", "Jibu:"),
    "telugu": ("క్రింది ప్రశ్నకు సమాధానం ఇవ్వండి.", "ప్రశ్న:", "సమాధానం:"),
}


def _read_tydiqa_examples(path, languages=None):
    """Flatten a SQuAD-style TyDiQA GoldP JSON file into a list of example dicts.

    Args:
        path: path to the JSON file (top-level "data" -> "paragraphs" -> "qas").
        languages: optional collection of language names; when given, examples
            whose language is not in it are skipped.

    Returns:
        A list of dicts with keys "id", "lang", "context", "question" and
        "answers", in file order. The language is the part of the question id
        before the first "-".
    """
    # The files contain text in many scripts, so decode explicitly as UTF-8
    # rather than relying on the platform's default encoding.
    with open(path, encoding="utf-8") as fin:
        dataset = json.load(fin)

    examples = []
    for article in dataset["data"]:
        for paragraph in article["paragraphs"]:
            for qa in paragraph["qas"]:
                lang = qa["id"].split("-")[0]
                if languages is not None and lang not in languages:
                    continue
                examples.append(
                    {
                        "id": qa["id"],
                        "lang": lang,
                        "context": paragraph["context"],
                        "question": qa["question"],
                        "answers": qa["answers"],
                    }
                )
    return examples


def _truncate_contexts(examples, tokenizer, max_context_length):
    """Cut each example's "context" down to at most max_context_length tokens, in place."""
    # NOTE(review): if the tokenizer inserts special tokens in encode(), they are
    # decoded back into the truncated context text — confirm this is intended.
    for example in examples:
        token_ids = tokenizer.encode(example["context"])
        if len(token_ids) > max_context_length:
            example["context"] = tokenizer.decode(token_ids[:max_context_length])


def _format_question(example, p_template, q_template, no_context):
    """Render the passage (unless no_context) and question lines, each ending in a newline."""
    question_part = f"{q_template} {example['question']}\n"
    if no_context:
        return question_part
    return f"{p_template} {example['context']}\n{question_part}"


def main(args):
    """Run the TyDiQA GoldP evaluation described by the parsed CLI arguments.

    Loads the dev set (and, for few-shot, demonstrations from the train set)
    from ``args.data_dir``, builds one prompt per test example, generates
    answers with a vLLM model, a Hugging Face model, or an OpenAI chat engine,
    and scores them per language with the "squad" metric.

    Writes ``tydiaqa_predictions.jsonl`` and ``metrics.json`` into
    ``args.save_dir`` (plus ``tydiaqa_openai_results.jsonl`` on the OpenAI path).

    Args:
        args: namespace with the options defined by this script's argument
            parser (data_dir, n_shot, no_context, max_context_length, save_dir,
            model_name_or_path, openai_engine, use_vllm, use_chat_format, ...).

    Raises:
        ValueError: if a language has fewer than ``n_shot`` training examples,
            or has no prompt template.
        ImportError: if ``--use_vllm`` is requested but vllm is not installed.
    """
    # Fixed seed so that the sub-sampled test set and the demonstrations are reproducible.
    random.seed(42)

    print("Loading data...")

    test_data = _read_tydiqa_examples(os.path.join(args.data_dir, "tydiqa-goldp-v1.1-dev.json"))
    data_languages = sorted({example["lang"] for example in test_data})
    if args.max_num_examples_per_lang:
        sampled_examples = []
        for lang in data_languages:
            examples_for_lang = [example for example in test_data if example["lang"] == lang]
            if len(examples_for_lang) > args.max_num_examples_per_lang:
                examples_for_lang = random.sample(examples_for_lang, args.max_num_examples_per_lang)
            sampled_examples += examples_for_lang
        test_data = sampled_examples

    print(f"Loaded {len(test_data)} examples from {len(data_languages)} languages: {data_languages}")

    if args.n_shot > 0:
        train_data_for_langs = {lang: [] for lang in data_languages}
        train_examples = _read_tydiqa_examples(
            os.path.join(args.data_dir, "tydiqa-goldp-v1.1-train.json"),
            languages=set(data_languages),
        )
        for example in train_examples:
            train_data_for_langs[example["lang"]].append(example)
        for lang in data_languages:
            available = len(train_data_for_langs[lang])
            if available < args.n_shot:
                raise ValueError(
                    f"Requested {args.n_shot}-shot prompts but only {available} training "
                    f"examples are available for language '{lang}'."
                )
            # sample n_shot examples from each language
            train_data_for_langs[lang] = random.sample(train_data_for_langs[lang], args.n_shot)

    # Fail early (before loading a model) if any data language lacks a template
    # in the table that will actually be used to build prompts.
    templates = encoding_templates_without_context if args.no_context else encoding_templates_with_context
    missing_languages = [lang for lang in data_languages if lang not in templates]
    if missing_languages:
        raise ValueError(f"No prompt template defined for language(s): {missing_languages}")

    if args.model_name_or_path:
        if args.use_vllm and vllm is None:
            raise ImportError("--use_vllm was given but the vllm package is not installed.")

        print("Loading model and tokenizer...")
        tokenizer = load_hf_tokenizer(
            model_name_or_path=args.model_name_or_path,
            tokenizer_name_or_path=args.tokenizer_name_or_path,
            use_fast_tokenizer=not args.use_slow_tokenizer,
        )
        if args.use_vllm:
            model = vllm.LLM(
                model=args.model_name_or_path,
                tokenizer=args.tokenizer_name_or_path if args.tokenizer_name_or_path else args.model_name_or_path,
                tokenizer_mode="slow" if args.use_slow_tokenizer else "auto",
                tensor_parallel_size=torch.cuda.device_count(),
            )
        else:
            model = load_hf_lm(
                model_name_or_path=args.model_name_or_path,
                load_in_8bit=args.load_in_8bit,
                device_map="balanced_low_0" if torch.cuda.device_count() > 1 else "auto",
                gptq_model=args.gptq,
            )
            from transformers import GPTNeoXForCausalLM, OPTForCausalLM

            if isinstance(model, (GPTNeoXForCausalLM, OPTForCausalLM)):
                tokenizer.model_max_length = model.config.max_position_embeddings
                print(
                    "Set tokenizer.model_max_length to model.config.max_position_embeddings: {}".format(
                        model.config.max_position_embeddings
                    )
                )
    else:
        # OpenAI path: a tiktoken encoding is only used to truncate contexts below.
        import tiktoken

        tokenizer = tiktoken.get_encoding("cl100k_base")

    # reduce context length to max_context_length
    if args.max_context_length:
        _truncate_contexts(test_data, tokenizer, args.max_context_length)
        if args.n_shot > 0:
            for lang in data_languages:
                _truncate_contexts(train_data_for_langs[lang], tokenizer, args.max_context_length)

    os.makedirs(args.save_dir, exist_ok=True)

    prompts = []
    chat_formatting_function = dynamic_import_function(args.chat_formatting_function) if args.use_chat_format else None
    for example in test_data:
        lang = example["lang"]

        if args.no_context:
            instruction, q_template, a_template = encoding_templates_without_context[lang]
            p_template = ""
        else:
            instruction, p_template, q_template, a_template = encoding_templates_with_context[lang]

        prompt = instruction + "\n\n"

        if args.n_shot > 0:
            # Each demonstration is the question block followed by the answer
            # label and the first gold answer.
            formatted_demo_examples = [
                _format_question(train_example, p_template, q_template, args.no_context)
                + a_template
                + " "
                + train_example["answers"][0]["text"]
                for train_example in train_data_for_langs[lang]
            ]
            prompt += "\n\n".join(formatted_demo_examples) + "\n\n"

        prompt += _format_question(example, p_template, q_template, args.no_context)

        if args.use_chat_format:
            messages = [{"role": "user", "content": prompt}]
            prompt = chat_formatting_function(messages, tokenizer, add_bos=False)
            # Start the assistant turn with the answer label, separated by
            # whitespace from whatever the chat template ends with.
            prompt += a_template if prompt.endswith(("\n", " ")) else " " + a_template
        else:
            prompt += a_template
        prompts.append(prompt)

    if args.model_name_or_path:
        if args.use_vllm:
            sampling_params = vllm.SamplingParams(
                temperature=0,
                max_tokens=50,
                # we only use stop token for non-chat format (usually applied to vanilla pretrained language models). For chat format, we will rely on the model knows when to stop.
                stop=["\n"] if not args.use_chat_format else None,
            )
            # We need to remap the outputs to the prompts because vllm might not return outputs for some prompts (e.g., if the prompt is too long)
            generations = model.generate(prompts, sampling_params)
            prompt_to_output = {g.prompt: g.outputs[0].text for g in generations}
            outputs = [prompt_to_output.get(prompt, "").strip() for prompt in prompts]
        else:
            # get the last token because the tokenizer may add space tokens at the start.
            new_line_token = tokenizer.encode("\n", add_special_tokens=False)[-1]
            outputs = generate_completions(
                model=model,
                tokenizer=tokenizer,
                prompts=prompts,
                max_new_tokens=50,
                batch_size=args.eval_batch_size,
                # we only use stop token for non-chat format (usually applied to vanilla pretrained language models). For chat format, we will rely on the model knows when to stop.
                stop_id_sequences=[[new_line_token]] if not args.use_chat_format else None,
            )
            # remove unnecessary space
            outputs = [output.strip() for output in outputs]
    else:
        instances = [{"id": example["id"], "prompt": prompt} for example, prompt in zip(test_data, prompts)]
        results = query_openai_chat_model(
            engine=args.openai_engine,
            instances=instances,
            output_path=os.path.join(args.save_dir, "tydiaqa_openai_results.jsonl"),
            batch_size=args.eval_batch_size,
        )
        # Keep only the first line of each response as the answer.
        outputs = [result["output"].strip().split("\n")[0].strip() for result in results]

    with open(os.path.join(args.save_dir, "tydiaqa_predictions.jsonl"), "w", encoding="utf-8") as fout:
        for example, output in zip(test_data, outputs):
            example["prediction_text"] = output
            fout.write(json.dumps(example) + "\n")

    print("Calculating F1, EM ...")
    metric = evaluate.load("squad")

    eval_scores = {}
    for lang in data_languages:
        lang_predictions = [
            {"id": example["id"], "prediction_text": output}
            for example, output in zip(test_data, outputs)
            if example["lang"] == lang
        ]
        lang_references = [
            {"id": example["id"], "answers": example["answers"]} for example in test_data if example["lang"] == lang
        ]
        eval_scores[lang] = metric.compute(predictions=lang_predictions, references=lang_references)

    # Macro-average over languages; snapshot the per-language scores first so
    # the "average" entry never feeds into itself.
    per_language_scores = list(eval_scores.values())
    eval_scores["average"] = {
        metric_name: np.mean([scores[metric_name] for scores in per_language_scores])
        for metric_name in ["f1", "exact_match"]
    }

    print("Scores:")
    print(json.dumps(eval_scores, indent=4))

    with open(os.path.join(args.save_dir, "metrics.json"), "w", encoding="utf-8") as fout:
        json.dump(eval_scores, fout, indent=4)
    print("Done!")


def parse_args(argv=None):
    """Parse and validate the command-line options of the TyDiQA evaluation script.

    Args:
        argv: list of argument strings to parse; when None, argparse reads
            the process command line.

    Returns:
        The parsed ``argparse.Namespace``.

    Exits with status 2 (via ``parser.error``) unless exactly one of
    ``--model_name_or_path`` and ``--openai_engine`` is given. This is an
    explicit check rather than an ``assert`` so it still runs under ``python -O``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, default="data/xorqa/")
    parser.add_argument(
        "--max_num_examples_per_lang",
        type=int,
        default=None,
        help="maximum number of examples per language to evaluate.",
    )
    parser.add_argument("--n_shot", type=int, default=1, help="number of examples to use for few-shot evaluation.")
    parser.add_argument(
        "--no_context",
        action="store_true",
        help="If given, we're evaluating a model without the gold context passage.",
    )
    parser.add_argument(
        "--max_context_length", type=int, default=512, help="maximum number of tokens in the context passage."
    )
    parser.add_argument("--save_dir", type=str, default="results/tydiqa/")
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default=None,
        help="if specified, we will load the model to generate the predictions.",
    )
    parser.add_argument(
        "--tokenizer_name_or_path", type=str, default=None, help="if specified, we will load the tokenizer from here."
    )
    parser.add_argument("--use_slow_tokenizer", action="store_true", help="If given, we will use the slow tokenizer.")
    parser.add_argument(
        "--openai_engine",
        type=str,
        default=None,
        help="if specified, we will use the OpenAI API to generate the predictions.",
    )
    parser.add_argument("--eval_batch_size", type=int, default=1, help="batch size for evaluation.")
    parser.add_argument(
        "--load_in_8bit",
        action="store_true",
        help="load model in 8bit mode, which will reduce memory and speed up inference.",
    )
    parser.add_argument("--gptq", action="store_true", help="If given, we're evaluating a 4-bit quantized GPTQ model.")
    parser.add_argument(
        "--use_vllm",
        action="store_true",
        help="If given, we will use the vllm library, which will likely increase the inference throughput.",
    )
    parser.add_argument(
        "--use_chat_format", action="store_true", help="If given, we will use the chat format for the prompts."
    )
    parser.add_argument(
        "--chat_formatting_function",
        type=str,
        default="minimal_multitask.eval.templates.create_prompt_with_tulu_chat_format",
        help="The function to use to create the chat format. This function will be dynamically imported. Please see examples in `eval/templates.py`.",
    )
    args = parser.parse_args(argv)
    # model_name_or_path and openai_engine cannot be both None or both not None.
    if (args.model_name_or_path is None) == (args.openai_engine is None):
        parser.error("Exactly one of --model_name_or_path and --openai_engine must be specified.")
    return args


if __name__ == "__main__":
    main(parse_args())
