import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

import torch as t
import argparse
import csv
from cut.utils import load_model

# Maps each answer letter to its zero-based option index: A -> 0, ..., D -> 3.
ans_map = {letter: position for position, letter in enumerate("ABCD")}

def prepare_data(data, batch_size=8):
    """
    Yield batches of (prompt, answer_index) pairs built from multiple choice rows.

    Args:
        data: iterable of rows, each indexable as
            [question, option A, option B, option C, option D, answer letter].
            Empty rows (e.g. blank lines returned by csv.reader) are skipped.
        batch_size: number of examples per yielded batch.

    Yields:
        Lists of (prompt_text, answer_index) tuples, where answer_index is the
        ans_map value for the row's answer letter. Every batch has batch_size
        items except possibly the last one, which holds the remaining rows.

    Raises:
        KeyError: if a row's answer letter is not a key of ans_map.
        IndexError: if a non-empty row has fewer than six fields.
    """
    batch = []
    for row in data:
        if not row:
            # Blank line in the input: nothing to evaluate.
            continue

        # NOTE(review): the prompt ends with a newline after "Answer:", so the
        # model predicts the token following that newline - confirm intended.
        question = f"""\
The following are multiple choice questions (with answers).

{row[0]}
A. {row[1]}
B. {row[2]}
C. {row[3]}
D. {row[4]}
Answer:
"""
        ans = row[5]
        batch.append((question, ans_map[ans]))
        if len(batch) == batch_size:
            yield batch
            batch = []

    # Emit the trailing partial batch so no rows are dropped from evaluation.
    if batch:
        yield batch

def get_accuracy(model, tokenizer, batches):
    """
    Score a model on batches of multiple choice prompts.

    For each prompt, the logits at the last non-padding position are restricted
    to the token ids of "A", "B", "C" and "D"; the highest-scoring letter is the
    prediction.

    Args:
        model: callable as model(**inputs) returning an object with a `.logits`
            tensor indexed as [batch, position, vocab], and exposing `.device`.
        tokenizer: provides `encode(str)` and is callable on a list of strings
            with `return_tensors="pt", padding=True`.
        batches: iterable of batches, each a list of (prompt, answer_index).

    Returns:
        A flat list of booleans, one per example, True where the prediction
        matches the answer index. Empty if `batches` is empty.
    """
    # Token ids for the four answer letters; the last id of each encoding is
    # used so any leading special tokens added by the tokenizer are ignored.
    choice_idxs = t.tensor(
        [tokenizer.encode(letter)[-1] for letter in "ABCD"]
    ).to(model.device)

    corrects = []
    for batch in batches:
        texts = [x[0] for x in batch]
        answers = t.tensor([x[1] for x in batch]).to(model.device)
        inputs = tokenizer(texts, return_tensors="pt", padding=True).to(model.device)
        logits = model(**inputs).logits

        if "attention_mask" in inputs:
            # Find the last attended position of every row. With left padding
            # this is always the final position; with right padding the final
            # position is a pad token and must not be used for prediction.
            mask = inputs["attention_mask"].to(t.long)
            last_pos = mask.shape[1] - 1 - mask.flip(dims=[1]).argmax(dim=1)
            last_pos = last_pos.to(logits.device)
        else:
            # No mask available: fall back to the final position.
            last_pos = t.full(
                (logits.shape[0],), logits.shape[1] - 1, device=logits.device
            )

        row_idx = t.arange(logits.shape[0], device=logits.device)
        outputs = logits[row_idx, last_pos][:, choice_idxs]
        predictions = outputs.argmax(dim=-1)
        corrects.extend((predictions == answers).tolist())
    return corrects

if __name__ == '__main__':
    # Evaluate a model on every .csv file in --data_dir and report per-file
    # and overall accuracy.
    parser = argparse.ArgumentParser(description='Evaluate a model on a multiple choice dataset')
    parser.add_argument('--model_name_or_path', type=str, default="HuggingFaceH4/zephyr-7b-beta")
    # NOTE(review): the default looks like a file path, but the value is passed
    # to os.listdir below and so must be a directory - confirm the intended default.
    parser.add_argument('--data_dir', type=str, default='data/mmlu/mmlu.jsonl',
                        help='Directory containing the .csv files to evaluate')
    parser.add_argument('--batch_size', type=int, default=8, help='The batch size to use for evaluation')
    args = parser.parse_args()

    model, tokenizer = load_model(args.model_name_or_path)
    # Inference only: no gradients are needed.
    t.set_grad_enabled(False)

    corrects = {}
    # iterate over all .csv files in data_dir, in a deterministic order
    for file in sorted(os.listdir(args.data_dir)):
        if not file.endswith(".csv"):
            continue
        # prepare_data is lazy, so the file must stay open until get_accuracy
        # has consumed every batch; newline='' is what the csv module expects.
        with open(os.path.join(args.data_dir, file), 'r', newline='') as f:
            batches = prepare_data(csv.reader(f), args.batch_size)
            corrects[file] = get_accuracy(model, tokenizer, batches)
        if corrects[file]:
            print(f"Accuracy for {file}: {sum(corrects[file]) / len(corrects[file]):.2f}")
        else:
            # Avoid dividing by zero when a file produced no examples.
            print(f"No examples evaluated for {file}")
    all_corrects = [x for sublist in corrects.values() for x in sublist]
    if all_corrects:
        print(f"Overall accuracy: {sum(all_corrects) / len(all_corrects):.2f}")
    else:
        print("No examples evaluated")



