import argparse
import os
import json
import random
import torch
import vllm
import evaluate
import numpy as np
from eval.utils import (
    generate_completions, 
    load_hf_lm, 
    query_openai_chat_model,
    dynamic_import_function,
    load_hf_tokenizer,
)


encoding_templates_with_context = {
    "english": ("Answer the following question based on the information in the given passage.", "Passage:", "Question:", "Answer:"),
    "arabic": ("أجب على السؤال التالي بناءً على المعلومات في المقطع المعطى.", "المقطع:", "السؤال:", "الإجابة:"),
    "bengali": ("প্রদত্ত অধ্যায়ের তথ্যের উপর ভিত্তি করে নিম্নলিখিত প্রশ্নের উত্তর দিন।", "অধ্যায়:", "প্রশ্ন:", "উত্তর:"),
    "finnish": ("Vastaa seuraavaan kysymykseen annetun kappaleen tiedon perusteella.", "Kappale:", "Kysymys:", "Vastaus:"),
    "indonesian": ("Jawab pertanyaan berikut berdasarkan informasi di bagian yang diberikan.", "Bagian:", "Pertanyaan:", "Jawaban:"),
    "korean": ("주어진 문단의 정보에 기반하여 다음 질문에 답하십시오.", "문단:", "질문:", "답변:"),
    "russian": ("Ответьте на следующий вопрос на основе информации в данном отрывке.", "Отрывок:", "Вопрос:", "Ответ:"),
    "swahili": ("Jibu swali lifuatalo kulingana na habari kwenye kifungu kilichotolewa.", "Kifungu:", "Swali:", "Jibu:"),
    "telugu": ("ఇచ్చిన పేరాలోని సమాచారం ఆధారంగా కింది ప్రశ్నకు సమాధానం ఇవ్వండి.", "పేరా:", "ప్రశ్న:", "సమాధానం:")
}

encoding_templates_without_context = {
    "english": ("Answer the following question.", "Question:", "Answer:"),
    "arabic": ("أجب على السؤال التالي.", "السؤال:", "الإجابة:"),
    "bengali": ("নিম্নলিখিত প্রশ্নের উত্তর দিন।", "প্রশ্ন:", "উত্তর:"),
    "finnish": ("Vastaa seuraavaan kysymykseen.", "Kysymys:", "Vastaus:"),
    "indonesian": ("Jawab pertanyaan berikut.", "Pertanyaan:", "Jawaban:"),
    "korean": ("다음 질문에 답하십시오.", "질문:", "답변:"),
    "russian": ("Ответьте на следующий вопрос.", "Вопрос:", "Ответ:"),
    "swahili": ("Jibu swali lifuatalo.", "Swali:", "Jibu:"),
    "telugu": ("క్రింది ప్రశ్నకు సమాధానం ఇవ్వండి.", "ప్రశ్న:", "సమాధానం:")
}


def main(args):
    random.seed(42)

    print("Loading data...")

    test_data = []
    with open(os.path.join(args.data_dir, "tydiqa-goldp-v1.1-dev.json")) as fin:
        dev_data = json.load(fin)
        for article in dev_data["data"]:
            for paragraph in article["paragraphs"]:
                for qa in paragraph["qas"]:
                    example = {
                        "id": qa["id"], 
                        "lang": qa["id"].split("-")[0],
                        "context": paragraph["context"],
                        "question": qa["question"],
                        "answers": qa["answers"]
                    }
                    test_data.append(example)
    data_languages = sorted(list(set([example["lang"] for example in test_data]))) 
    if args.max_num_examples_per_lang:
        sampled_examples = []
        for lang in data_languages:
            examples_for_lang = [example for example in test_data if example["lang"] == lang]
            if len(examples_for_lang) > args.max_num_examples_per_lang:
                examples_for_lang = random.sample(examples_for_lang, args.max_num_examples_per_lang)
            sampled_examples += examples_for_lang
        test_data = sampled_examples
    
    print(f"Loaded {len(test_data)} examples from {len(data_languages)} languages: {data_languages}")

    if args.n_shot > 0:
        train_data_for_langs = {lang: [] for lang in data_languages}
        with open(os.path.join(args.data_dir, "tydiqa-goldp-v1.1-train.json")) as fin:
            train_data = json.load(fin)
            for article in train_data["data"]:
                for paragraph in article["paragraphs"]:
                    for qa in paragraph["qas"]:
                        lang = qa["id"].split("-")[0]
                        if lang in data_languages:
                            example = {
                                "id": qa["id"],
                                "lang": lang,
                                "context": paragraph["context"],
                                "question": qa["question"],
                                "answers": qa["answers"]
                            }
                            train_data_for_langs[lang].append(example)
            for lang in data_languages:
                # sample n_shot examples from each language
                train_data_for_langs[lang] = random.sample(train_data_for_langs[lang], args.n_shot)
        # assert that we have exactly n_shot examples for each language
        assert all([len(train_data_for_langs[lang]) == args.n_shot for lang in data_languages])

    
    # assert we have templates for all data languages
    assert all([lang in encoding_templates_with_context.keys() for lang in data_languages])
        
    if args.model_name_or_path:
        print("Loading model and tokenizer...")
        tokenizer = load_hf_tokenizer(
            model_name_or_path=args.model_name_or_path,
            tokenizer_name_or_path=args.tokenizer_name_or_path,
            use_fast_tokenizer=not args.use_slow_tokenizer,
        )
        if args.use_vllm:
            model = vllm.LLM(
                model=args.model_name_or_path,
                tokenizer=args.tokenizer_name_or_path if args.tokenizer_name_or_path else args.model_name_or_path,
                tokenizer_mode="slow" if args.use_slow_tokenizer else "auto",
                tensor_parallel_size=torch.cuda.device_count(),
            )
            
        else:
            model = load_hf_lm(
                model_name_or_path=args.model_name_or_path, 
                load_in_8bit=args.load_in_8bit, 
                device_map="balanced_low_0" if torch.cuda.device_count() > 1 else "auto",
                gptq_model=args.gptq,
            )
            from transformers import GPTNeoXForCausalLM, OPTForCausalLM
            if isinstance(model, GPTNeoXForCausalLM) or isinstance(model, OPTForCausalLM):
                tokenizer.model_max_length = model.config.max_position_embeddings
                print("Set tokenizer.model_max_length to model.config.max_position_embeddings: {}".format(model.config.max_position_embeddings))
    else:
        import tiktoken
        tokenizer = tiktoken.get_encoding("cl100k_base")

    # reduce context length to max_context_length
    if args.max_context_length:
        for example in test_data:
            tokenized_context = tokenizer.encode(example["context"])
            if len(tokenized_context) > args.max_context_length:
                example["context"] = tokenizer.decode(tokenized_context[:args.max_context_length])
        if args.n_shot > 0:
            for lang in data_languages:
                for example in train_data_for_langs[lang]:
                    tokenized_context = tokenizer.encode(example["context"])
                    if len(tokenized_context) > args.max_context_length:
                        example["context"] = tokenizer.decode(tokenized_context[:args.max_context_length])

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir, exist_ok=True)

    prompts = []
    chat_formatting_function = dynamic_import_function(args.chat_formatting_function) if args.use_chat_format else None
    for example in test_data:
        lang = example["lang"]

        if args.no_context:
            prompt, q_template, a_template = encoding_templates_without_context[lang]
            p_template = ""
        else:
            prompt, p_template, q_template, a_template = encoding_templates_with_context[lang]

        prompt += "\n\n"
        
        if args.n_shot > 0:
            formatted_demo_examples = []
            for train_example in train_data_for_langs[lang]:
                if args.no_context:
                    formatted_demo_examples.append(
                        q_template + " " + train_example["question"] + "\n" + a_template + " " + train_example["answers"][0]["text"]
                    )
                else:
                    formatted_demo_examples.append(
                        p_template + " " + train_example["context"] + "\n" + q_template + " " + train_example["question"] + "\n" + a_template + " " + train_example["answers"][0]["text"]
                    )
            prompt += "\n\n".join(formatted_demo_examples) + "\n\n"
        
        if args.no_context:
            prompt += q_template + " " + format(example["question"]) + "\n"
        else:
            prompt += p_template + " " + format(example["context"]) + "\n" + q_template + " " + format(example["question"]) + "\n"

        if args.use_chat_format:
            messages = [{"role": "user", "content": prompt}]
            prompt = chat_formatting_function(messages, tokenizer, add_bos=False)
            prompt += a_template if prompt[-1] in ["\n", " "] else " " + a_template
        else:
            prompt += a_template
        prompts.append(prompt)

    if args.model_name_or_path:
        if args.use_vllm:
            sampling_params = vllm.SamplingParams(
                temperature=0,
                max_tokens=50,
                stop=["\n"] if not args.use_chat_format else None,  # we only use stop token for non-chat format (usually applied to vanilla pretrained language models). For chat format, we will rely on the model knows when to stop.
            )
            # We need to remap the outputs to the prompts because vllm might not return outputs for some prompts (e.g., if the prompt is too long)
            generations = model.generate(prompts, sampling_params)
            prompt_to_output = {
                g.prompt: g.outputs[0].text for g in generations
            }
            outputs = [prompt_to_output[prompt].strip() if prompt in prompt_to_output else "" for prompt in prompts]
        else:
            new_line_token = tokenizer.encode("\n", add_special_tokens=False)[-1] # get the last token because the tokenizer may add space tokens at the start.
            outputs = generate_completions(
                model=model,
                tokenizer=tokenizer,
                prompts=prompts,
                max_new_tokens=50,
                batch_size=args.eval_batch_size,
                stop_id_sequences=[[new_line_token]] if not args.use_chat_format else None,  # we only use stop token for non-chat format (usually applied to vanilla pretrained language models). For chat format, we will rely on the model knows when to stop.
            )
            # remove unnecessary space
            outputs = [output.strip() for output in outputs]
    else:
        instances = [{"id": example["id"], "prompt": prompt} for example, prompt in zip(test_data, prompts)]
        results = query_openai_chat_model(
            engine=args.openai_engine,
            instances=instances,
            output_path=os.path.join(args.save_dir, "tydiaqa_openai_results.jsonl"),
            batch_size=args.eval_batch_size,
        )
        outputs = [result["output"].strip().split("\n")[0].strip() for result in results]
    
    with open(os.path.join(args.save_dir, "tydiaqa_predictions.jsonl"), "w") as fout:
        for example, output in zip(test_data, outputs):
            example["prediction_text"] = output
            fout.write(json.dumps(example) + "\n")

    print("Calculating F1, EM ...")
    metric = evaluate.load("squad")
    
    eval_scores = {}
    for lang in data_languages:
        lang_predictions = [{"id": example["id"], "prediction_text": output} for example, output in zip(test_data, outputs) if example["lang"] == lang]
        lang_references = [{"id": example["id"], "answers": example["answers"]} for example in test_data if example["lang"] == lang]
        eval_scores[lang] = metric.compute(predictions=lang_predictions, references=lang_references)

    print("Calculating recall ...")
    for lang in data_languages:
        lang_predictions = [output for example, output in zip(test_data, outputs) if example["lang"] == lang]
        lang_references = [example["answers"] for example in test_data if example["lang"] == lang]
        recalls = []
        for pred, refs in zip(lang_predictions, lang_references):
            recall = 100.0 if any([ref['text'].strip().lower() in pred.strip().lower() for ref in refs]) else 0.0
            recalls.append(recall)
        eval_scores[lang]['recall'] = np.mean(recalls)

    eval_scores["average"] = {metric: np.mean([scores[metric] for scores in eval_scores.values()]) for metric in ["f1", "exact_match", "recall"]}

    print("Scores:")
    print(json.dumps(eval_scores, indent=4))
    
    with open(os.path.join(args.save_dir, "metrics.json"), "w") as fout:
        json.dump(eval_scores, fout, indent=4)
    print("Done!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_dir",
        type=str,
        default="data/xorqa/"
    )
    parser.add_argument(
        "--max_num_examples_per_lang",
        type=int,
        default=None,
        help="maximum number of examples per language to evaluate."
    )
    parser.add_argument(
        "--n_shot",
        type=int,
        default=1,
        help="number of examples to use for few-shot evaluation."
    )
    parser.add_argument(
        "--no_context",
        action="store_true",
        help="If given, we're evaluating a model without the gold context passage."
    )
    parser.add_argument(
        "--max_context_length",
        type=int,
        default=512,
        help="maximum number of tokens in the context passage."
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="results/tydiqa/"
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default=None,
        help="if specified, we will load the model to generate the predictions."
    )
    parser.add_argument(
        "--tokenizer_name_or_path",
        type=str,
        default=None,
        help="if specified, we will load the tokenizer from here."
    )
    parser.add_argument(
        "--use_slow_tokenizer",
        action="store_true",
        help="If given, we will use the slow tokenizer."
    )
    parser.add_argument(
        "--openai_engine",
        type=str,
        default=None,
        help="if specified, we will use the OpenAI API to generate the predictions."
    )
    parser.add_argument(
        "--eval_batch_size",
        type=int,
        default=1,
        help="batch size for evaluation."
    )
    parser.add_argument(
        "--load_in_8bit",
        action="store_true",
        help="load model in 8bit mode, which will reduce memory and speed up inference."
    )
    parser.add_argument(
        "--gptq",
        action="store_true",
        help="If given, we're evaluating a 4-bit quantized GPTQ model."
    )
    parser.add_argument(
        "--use_vllm",
        action="store_true", 
        help="If given, we will use the vllm library, which will likely increase the inference throughput."
    )
    parser.add_argument(
        "--use_chat_format", 
        action="store_true", 
        help="If given, we will use the chat format for the prompts."
    )
    parser.add_argument(
        "--chat_formatting_function", 
        type=str, 
        default="eval.templates.create_prompt_with_tulu_chat_format", 
        help="The function to use to create the chat format. This function will be dynamically imported. Please see examples in `eval/templates.py`."
    )
    args = parser.parse_args()
    # model_name_or_path and openai_engine cannot be both None or both not None.
    assert (args.model_name_or_path is None) != (args.openai_engine is None), "Either model_name_or_path or openai_engine should be specified."
    main(args)
