import argparse
import csv
import numpy
import gc
import sys
import torch
from datasets import load_dataset
from transformers import AutoTokenizer
from tqdm import tqdm
from model_loader import load_model_and_apply_patches, add_args


# Prompt templates for the four-way multiple-choice task. Every template is a
# ``str.format`` string with the placeholders {context}, {question}, {a}, {b},
# {c} and {d}.

# Plain-text template without any chat markup.
# Not referenced anywhere in the visible part of this file.
ZERO_SHOT_PROMPT = "Please read the following text and answer the question below.\n\n<text>\n{context}\n</text>\n\nWhat is the correct answer to this question: {question}\nChoices:\n(A) {a}\n(B) {b}\n(C) {c}\n(D) {d}\n\nChoose the best answer by writing its corresponding letter (either A, B, C, or D)."

# Template wrapped in <|im_start|>/<|im_end|> chat markup; it ends with the
# start of the assistant turn so the next generated token is the answer letter.
# Not referenced anywhere in the visible part of this file.
ZERO_SHOT_PROMPT_QWEN = "<|im_start|>system\nYou are a helpful AI bot that answers questions for a user. Keep your response short and direct, following the rules provided by user.<|im_end|>\n<|im_start|>user\n<context>\n{context}\n</context>\n\n<question>What is the correct answer to this question: {question}\n</question>\n\n<choices>\n(A) {a}\n(B) {b}\n(C) {c}\n(D) {d}\n</choices>\n\n<rule>\nChoose the best answer by writing its corresponding letter (either A, B, C, or D).\n</rule><|im_end|>\n<|im_start|>assistant\nHere is the letter of the correct answer:"

# Template wrapped in <|start_header_id|>/<|eot_id|> chat markup; it also ends
# with the start of the assistant turn. This is the template that
# get_prompt() formats.
ZERO_SHOT_PROMPT_LLAMA3 = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYou are a helpful AI bot that answers questions for a user. Keep your response short and direct, following the rules provided by user.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n<context>\n{context}\n</context>\n\n<question>What is the correct answer to this question: {question}\n</question>\n\n<choices>\n(A) {a}\n(B) {b}\n(C) {c}\n(D) {d}\n</choices>\n\n<rule>\nChoose the best answer by writing its corresponding letter (either A, B, C, or D).\n</rule><|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nHere is the letter of the correct answer:"
# Answer letters; main() tokenizes them and maps the highest-scoring one back
# to a letter by its position in this list.
CHOICES = ["A", "B", "C", "D"]

def get_prompt(sample, template=None):
    """Render the multiple-choice prompt for one dataset sample.

    Args:
        sample: Mapping that provides the keys ``context``, ``question``,
            ``choice_A``, ``choice_B``, ``choice_C`` and ``choice_D``.
        template: Optional ``str.format`` template that may use the
            placeholders ``{context}``, ``{question}``, ``{a}``, ``{b}``,
            ``{c}`` and ``{d}``. When omitted, ``ZERO_SHOT_PROMPT_LLAMA3``
            is used, which matches the previous hard-coded behavior.

    Returns:
        The formatted prompt string.

    Raises:
        KeyError: If ``sample`` lacks one of the required keys.
    """
    if template is None:
        # Looked up at call time so that callers passing an explicit template
        # do not depend on the module-level default.
        template = ZERO_SHOT_PROMPT_LLAMA3
    return template.format(
        context=sample["context"],
        question=sample["question"],
        a=sample["choice_A"],
        b=sample["choice_B"],
        c=sample["choice_C"],
        d=sample["choice_D"],
    )


def main(args):
    """Evaluate each requested model on the multiple-choice dataset.

    The dataset is loaded from ``./testset/longbenchv2``, every sample is
    rendered with ``get_prompt`` and optionally filtered. For each model a
    single token is generated per prompt; the answer is whichever of the
    letters in ``CHOICES`` has the highest score for that token. Per-domain
    accuracies and an overall ``avg`` entry are printed for every model.

    Args:
        args: Parsed command-line namespace. Reads ``model``, ``split``,
            ``filter_length``, ``max_tokens``, ``min_tokens``, ``limit``,
            ``print_choices``, ``aggressive_memory`` and ``output_file``;
            the namespace is also forwarded to
            ``load_model_and_apply_patches``.

    Raises:
        ValueError: If no model was supplied.

    Side effects:
        When ``args.output_file`` is set, a CSV with the columns ``domain``
        and ``acc`` is written. It holds the accuracies of the last model
        that was evaluated, or only the header if nothing was evaluated.
    """
    if not args.model:
        raise ValueError("at least one model must be given via -m/--model")
    # ``-m`` uses action="append" with nargs="+", so every entry is a list;
    # only its first element is used as the model name.
    models = [x[0] for x in args.model]

    # The tokenizer of the first model is used for all models.
    tokenizer = AutoTokenizer.from_pretrained(
        models[0], model_max_length=sys.maxsize, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

    dataset = load_dataset("./testset/longbenchv2", split=args.split)
    dataset = dataset.map(lambda sample: {"prompt": get_prompt(sample)})
    if args.filter_length:
        # NOTE(review): this drops samples whose "length" equals
        # --filter-length; confirm that exclusion (not selection) is intended.
        dataset = dataset.filter(lambda sample: sample["length"] != args.filter_length)

    if args.max_tokens or args.min_tokens:
        def _within_token_bounds(sample):
            # Tokenize once and apply whichever bounds were requested.
            n_tokens = len(tokenizer(sample["prompt"]).input_ids)
            if args.max_tokens and n_tokens > args.max_tokens:
                return False
            if args.min_tokens and n_tokens < args.min_tokens:
                return False
            return True

        dataset = dataset.filter(_within_token_bounds)

    # First token id of each answer letter, in CHOICES order.
    choice_tokens = [ids[0] for ids in tokenizer(
        CHOICES, add_special_tokens=False).input_ids]

    # Clamp the limit so it can neither run past the filtered dataset nor go
    # negative.
    if args.limit is None:
        total = len(dataset)
    else:
        total = max(0, min(args.limit, len(dataset)))

    # Defined up front so the CSV step below is safe even if nothing is
    # evaluated.
    accuracies = {}

    for model in models:
        if total == 0:
            # Checked before loading so no model is loaded for an empty run.
            print(f"No samples found for model {model} with the given filters.")
            break
        print(f"Evaluating {model} on {total} samples...")

        torch.cuda.empty_cache()
        loaded = load_model_and_apply_patches(model, args)

        num_of_samples = {}
        num_of_correct = {}
        correct_answers = 0

        with tqdm(total=total) as bar:
            for i in range(total):
                sample = dataset[i]
                domain = sample["domain"]
                num_of_samples[domain] = num_of_samples.get(domain, 0) + 1
                num_of_correct.setdefault(domain, 0)

                tokenized_prompt = tokenizer(sample["prompt"], return_tensors="pt")
                input_ids = tokenized_prompt.input_ids.to("cuda")
                attention_mask = tokenized_prompt.attention_mask.to("cuda")

                output = loaded.generate(
                    input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=1,
                    return_dict_in_generate=True,
                    output_scores=True,
                    pad_token_id=tokenizer.eos_token_id,
                )
                # Scores of the single generated step for the single prompt.
                scores = output.scores[0][0]
                # Cast to float32 before handing the values to numpy.
                choice_scores = scores[choice_tokens].float().cpu().numpy()
                selection = int(numpy.argmax(choice_scores))

                if CHOICES[selection] == sample["answer"]:
                    correct_answers += 1
                    num_of_correct[domain] += 1

                if args.print_choices:
                    print(
                        f"Choice: {CHOICES[selection]} Correct: {sample['answer']}")

                percent = (correct_answers / (i + 1)) * 100.0
                bar.desc = f"{model}: {percent:.1f}%"
                bar.update()

                if args.aggressive_memory:
                    # Drop every reference to this step's tensors before
                    # collecting, so the cache can actually be released.
                    del output, scores, input_ids, attention_mask, tokenized_prompt
                    gc.collect()
                    torch.cuda.empty_cache()

        num_of_samples["avg"] = total
        num_of_correct["avg"] = correct_answers

        accuracies = {}
        for domain in num_of_samples:
            domain_accuracy = (num_of_correct[domain] / num_of_samples[domain]) * 100.0
            accuracies[domain] = domain_accuracy
            print(f"Domain: {domain}, Accuracy: {domain_accuracy:.2f} ({num_of_correct[domain]}/{num_of_samples[domain]})")

        # Release this model before the next one is loaded; otherwise both
        # stay referenced during the next load.
        loaded = None
        gc.collect()
        torch.cuda.empty_cache()

    if args.output_file:
        with open(args.output_file, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(["domain", "acc"])
            writer.writerows(accuracies.items())


        

if __name__ == "__main__":
    # Script entry point: declare this script's options, let add_args()
    # register its own, then run the evaluation.
    parser = argparse.ArgumentParser()
    # main() cannot run without a model, so argparse reports a missing -m as a
    # usage error instead of letting main() fail later.
    parser.add_argument(
        "-m", "--model", action="append", nargs="+", required=True,
        help="Model to evaluate; repeat -m to evaluate several models. "
             "Only the first value after each -m is used.")
    parser.add_argument(
        "--limit", type=int,
        help="Evaluate only the first N samples that remain after filtering.")
    parser.add_argument(
        "--max-tokens", type=int,
        help="Keep only samples whose tokenized prompt has at most this many tokens.")
    parser.add_argument(
        "--min-tokens", type=int,
        help="Keep only samples whose tokenized prompt has at least this many tokens.")
    parser.add_argument(
        "--split", type=str, default="validation",
        help="Dataset split to load (default: validation).")
    parser.add_argument(
        "--print-choices", action="store_true",
        help="Print the chosen and the reference letter for every sample.")
    parser.add_argument(
        "--output-file", type=str,
        help="CSV file to write the per-domain accuracies to.")
    parser.add_argument(
        "--filter-length", type=str,
        help="Drop samples whose 'length' field equals this value.")
    parser.add_argument(
        "--aggressive-memory", action="store_true",
        help="Run garbage collection and empty the CUDA cache after every sample.")

    args = add_args(parser).parse_args()
    main(args)
