import sys
sys.path.append("..")

import re
import os
from tqdm import tqdm
from utils import initialize_text_to_text_model, model_inference
from data import load_gsm8k
import argparse

import datetime
# Run timestamp, captured once at import time and written next to each results row.
# NOTE: this module-level name shadows the stdlib `time` module; avoid `import time` here.
time = f"{datetime.datetime.now():%Y-%m-%d-%H-%M-%S}"

def parse_args():
    """Parse the command line for this evaluation script.

    Returns:
        An ``argparse.Namespace`` with a single required string attribute,
        ``model_name``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--model_name", required=True, type=str)
    namespace = cli.parse_args()
    return namespace

# GSM8K final-answer marker: "#### <number>". Accepts an optional "$", a leading
# minus sign, and thousands separators (e.g. "#### -1,250" or "#### $5").
_ANSWER_PATTERN = re.compile(r"####\s*\$?\s*(-?\d[\d,]*)")


def extract_num(text):
    """Return the integer final answer that follows ``####`` in *text*.

    Only the first ``####`` marker is considered. Commas inside the number are
    treated as thousands separators. Any fractional part is ignored, so
    ``"#### 18.00"`` yields ``18``.

    Args:
        text: A GSM8K reference solution or a model generation. ``None`` is
            treated like text without an answer.

    Returns:
        The parsed integer, or ``0`` when no ``#### <number>`` answer is found.
        The 0 fallback is kept for backward compatibility. Note that it cannot
        be told apart from a genuine answer of 0.
    """
    match = _ANSWER_PATTERN.search(text or "")
    if match is None:
        print("No '#### <number>' answer found")
        return 0
    # The captured group is an optional '-' followed by digits and commas, so
    # int() cannot fail once the separators are stripped.
    return int(match.group(1).replace(",", ""))


def _append_result(model_name, acc, results_path):
    """Append one ``<time> | <model> <acc>`` row to *results_path*, creating the file (with header) and its directory if needed."""
    results_dir = os.path.dirname(results_path)
    if results_dir:
        os.makedirs(results_dir, exist_ok=True)
    if not os.path.exists(results_path):
        with open(results_path, "w") as f:
            f.write("Time | Model Acc\n")
    with open(results_path, "a") as f:
        # `time` is the module-level run timestamp.
        f.write(f"{time} | {model_name} {acc}\n")


def eval_gsm8k(model_name,
               tokenizer_name="meta-llama/Llama-2-7b-hf",
               max_target_length=512,
               results_path="./logs/gsm8k_results.txt"):
    """Evaluate a causal LM on the GSM8K test split and log its accuracy.

    For every test example, the model generates from ``example["x"]``. Both
    the generation and the reference ``example["y"]`` are reduced to an
    integer with ``extract_num``. The example counts as correct when the two
    integers are equal.

    Args:
        model_name: Model identifier passed to ``initialize_text_to_text_model``.
        tokenizer_name: Tokenizer identifier passed as ``tokenizer=``.
        max_target_length: Generation length limit passed to ``model_inference``.
        results_path: Text file the final accuracy row is appended to. Its
            parent directory is created if missing.

    Returns:
        Accuracy as a fraction in ``[0, 1]``, or ``0.0`` for an empty test set.
    """
    _, _, test_set = load_gsm8k()
    model_type = "CausalLM"
    model, tokenizer = initialize_text_to_text_model(
        model_name,
        model_type,
        True,  # NOTE(review): meaning of this positional flag is defined in utils — confirm
        tokenizer=tokenizer_name,
        flash_attention=True)

    num_examples = len(test_set)
    total = 0
    correct = 0
    progress = tqdm(test_set)
    for example in progress:
        pred_text = model_inference(model,
                                    tokenizer,
                                    example['x'],
                                    model_type,
                                    max_target_length=max_target_length)
        gt = extract_num(example["y"])
        pred = extract_num(pred_text)
        correct += int(gt == pred)
        total += 1
        acc_pct = correct / total * 100
        # tqdm.write prints above the bar instead of corrupting it like print().
        progress.write(str(pred_text))
        progress.write(f"{gt} {pred}")
        progress.write(f"{total} / {num_examples}: Accuracy: {acc_pct:.2f}%")
        progress.set_description(f"Accuracy: {acc_pct:.2f}%")

    # Guard against an empty test set instead of dividing by zero.
    acc = correct / total if total else 0.0
    print("Acc:", acc)
    _append_result(model_name, acc, results_path)
    return acc


if __name__ == "__main__":
    # Script entry point: evaluate the model named by --model_name.
    eval_gsm8k(parse_args().model_name)
