import re
import click
import pandas as pd
from eval.util import (
    # load_model_and_tokenizer,
    # batched_generate,
    parse_number,
    format_example,
    prep_incontext_examples,
    write_results,
)
from eval.eval_hacks import load_sampler, batched_generate
from utils import read_json, seed_all
from tqdm.auto import tqdm as tqdm
import datasets
from pathlib import Path
import json
# Fix the random seed once at import time so repeated runs are comparable.
# NOTE(review): seed_all comes from utils; presumably it seeds every RNG the
# sampler relies on -- confirm against its definition.
seed_all(42)

# Reference prompt template. evaluate_mmlu below emits a four-option variant of
# it: the letters are limited to (A)-(D) and each choice is rendered as
# "(A): text".
#
# Answer the following multiple-choice question by giving the correct answer letter in parentheses.
# Provide CONCISE reasoning for the answer, and make sure to finish the response with "Therefore, the
# answer is (ANSWER_LETTER)" where (ANSWER_LETTER) is one of (A), (B), (C), (D), (E), etc.
# Question: {question}
# (A) {choice_A}
# (B) {choice_B}
# (C) ...
# Answer the above question and REMEMBER to finish your response with the exact phrase
# "Therefore, the answer is (ANSWER_LETTER)" where (ANSWER_LETTER) is one of (A), (B),
# (C), (D), (E), etc.

# Answer letters the prompt offers; also bounds the number of choices per row.
_MMLU_LETTERS = "ABCD"


def _build_mmlu_prompt(question, choices, qa_format):
    """Render one question and its choices as a single prompt string."""
    parts = [
        "Answer the following multiple-choice question by giving the correct answer letter in parentheses. "
        'Provide CONCISE reasoning for the answer, and make sure to finish the response with '
        '"Therefore, the answer is (ANSWER_LETTER)" where (ANSWER_LETTER) is one of (A), (B), (C), (D).\n'
        f'Question: {question.strip()}\n'
    ]
    parts.extend(
        f"({letter}): {choice.strip()}\n"
        for letter, choice in zip(_MMLU_LETTERS, choices)
    )
    parts.append(
        "Answer the above question and REMEMBER to finish your response with the exact phrase "
        '"Therefore, the answer is (ANSWER_LETTER)" where (ANSWER_LETTER) is one of (A), (B), (C), (D).'
    )
    # qa_format is appended verbatim (e.g. '\nAnswer: ' or the empty string).
    parts.append(qa_format)
    return ''.join(parts)


def evaluate_mmlu(sampler_factory,  test_df, batch_size, qa_format):
    """Prompt a model on MMLU rows and score the letter it finishes with.

    Args:
        sampler_factory: Passed through to ``batched_generate`` as ``bsf``.
        test_df: DataFrame with ``question``, ``choices``, ``answer`` (index
            into ``choices``) and ``subject`` columns.
        batch_size: Generation batch size forwarded to ``batched_generate``.
        qa_format: Suffix appended verbatim to every prompt.

    Returns:
        One dict per row with the keys ``prompt``, ``output``, ``answer``
        (gold letter), ``subject``, ``valid`` (output ends with one of
        (A)-(D)) and ``correct`` (output ends with the gold letter).
        An empty ``test_df`` yields ``[]`` without invoking generation.

    Raises:
        ValueError: If a row has more than four choices or an ``answer``
            that does not index into its ``choices``.
    """
    test_df = test_df.reset_index(drop=True)

    prompts = []
    rows = zip(test_df["question"], test_df["choices"], test_df["answer"])
    for row_idx, (question, choices, answer) in enumerate(rows):
        # Explicit check instead of ``assert``: asserts vanish under ``-O`` and
        # a negative answer would otherwise silently index from the end.
        if len(choices) > len(_MMLU_LETTERS) or not 0 <= answer < len(choices):
            raise ValueError(
                f"Row {row_idx}: answer index {answer!r} / {len(choices)} choices "
                f"cannot be mapped onto letters {_MMLU_LETTERS}"
            )
        prompts.append(_build_mmlu_prompt(question, choices, qa_format))

    # An empty slice (e.g. a shard past the end of the data) has nothing to
    # generate; bail out before indexing prompts[0] below.
    if not prompts:
        return []

    print(f"--- MMLU example prompt ---\n{prompts[0]}\n----------------------")

    outputs = batched_generate(
        prompts=prompts,
        bsf=sampler_factory,
        do_sample=False,
        max_new_bytes=1000,
        batch_size=batch_size,
        # (E) is listed as a stop string even though only (A)-(D) are offered;
        # an output ending in (E) is then scored as invalid below.
        stop_strings=[f"Therefore, the answer is ({letter})" for letter in "ABCDE"],
        include_stop_str_in_output=True
    )

    # Built once: every suffix that counts as a well-formed final answer.
    valid_suffixes = tuple(f"({letter})" for letter in _MMLU_LETTERS)

    results = []
    for prompt, output, answer, subject in zip(prompts, outputs, test_df.answer, test_df.subject):
        gold_letter = _MMLU_LETTERS[answer]
        results.append(
            {
                "prompt": prompt,
                "output": output,
                "answer": gold_letter,
                # Recorded so accuracy can be broken down per subject.
                "subject": subject,
                "valid": output.endswith(valid_suffixes),
                "correct": output.endswith(f"({gold_letter})"),
            }
        )

    return results


@click.command()
@click.option("--model_name_or_path", type=str, default="pile-npt25k")
@click.option("--output_dir", type=str, default="results/squad/olmo-20k")
@click.option("--eval_batch_size", type=int, default=4)
@click.option("--start", type=int)
@click.option("--end", type=int)
def main(
    model_name_or_path: str,
    output_dir: str,
    eval_batch_size: int,
    start: int,
    end: int
):
    """Evaluate a model on rows [start:end] of the MMLU test set and save the results as JSON."""
    out_dir = Path(output_dir)
    # parents=True: the output directory is usually nested (the default is
    # three levels deep) and its ancestors may not exist yet.
    out_dir.mkdir(parents=True, exist_ok=True)
    out_file = out_dir / f'out_{start}_{end}.json'
    # A shard whose output file already exists is treated as done.
    if out_file.exists():
        return

    # Base (non-instruct) checkpoints are recognised purely by name: a "Base"
    # suffix, or a Llama-3 id ending in "B" (i.e. no "-Instruct" suffix).
    is_base = model_name_or_path.endswith("Base") or (
        model_name_or_path.startswith("meta-llama/Llama-3")
        and model_name_or_path.endswith("B")
    )
    if is_base:
        print("Detected base model")

    sampler_factory = load_sampler(model_name_or_path)
    test_df = pd.read_json("olmo_data/eval/mmlu/test.jsonl", lines=True)
    results = evaluate_mmlu(
        sampler_factory,
        # Positional slice; start/end may be None, meaning "from the
        # beginning" / "to the end".
        test_df.iloc[start:end],
        batch_size=eval_batch_size,
        # Base models get an explicit answer cue appended to the prompt.
        qa_format='\nAnswer: ' if is_base else '',
    )
    with open(out_file, 'w', encoding='utf-8') as f:
        json.dump(results, f)


# Script entry point: main is a click command, so its options are read from
# the command line.
if __name__ == "__main__":
    main()
# Leftover debug invocation, kept for reference. Its arguments do not line up
# with main's current options (e.g. 'qa' lands where the integer --end goes),
# so it needs updating before it can be re-enabled.
# main("allenai/OLMo-2-0425-1B-Instruct", None, 5, 10, 'qa')