import json
import torch
import re
import argparse
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_local_model(model_path):
    """
    Load a causal language model and its tokenizer from ``model_path``.

    The model is loaded in float16 and moved to CUDA when available,
    otherwise to the CPU. The tokenizer's pad token is set to its EOS token.

    Args:
        model_path: Path or identifier handed to ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # NOTE(review): both values below look like template placeholders;
    # replace them with a real access token / cache directory before use.
    shared_kwargs = {
        "trust_remote_code": True,
        "token": "[Acc token]",
        "cache_dir": "[Cache dir path]",
    }

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        **shared_kwargs,
    )
    if torch.cuda.is_available():
        target_device = 'cuda'
    else:
        target_device = 'cpu'
    model = model.to(target_device)

    tokenizer = AutoTokenizer.from_pretrained(model_path, **shared_kwargs)
    # Reuse the EOS token for padding.
    tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer

def generate_response(model, tokenizer, prompt, max_length=1024):
    """
    Generate a greedy-decoded completion for ``prompt``.

    Args:
        model: Object exposing ``device`` and ``generate`` (e.g. a causal LM).
        tokenizer: Callable tokenizer that also provides ``pad_token_id``
            and ``decode``.
        prompt: Text to feed to the model. It is truncated to 1024 tokens.
        max_length: Maximum number of *new* tokens to generate (forwarded as
            ``max_new_tokens``).

    Returns:
        The decoded text of the newly generated tokens only (the prompt
        tokens are stripped), with special tokens removed.
    """
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=1024)
    input_ids = inputs["input_ids"].to(model.device)

    generate_kwargs = {
        "max_new_tokens": max_length,
        "do_sample": False,
        "pad_token_id": tokenizer.pad_token_id,
    }
    # Forward the attention mask explicitly: when the pad token equals the
    # EOS token, generate() cannot reliably infer it from the input ids.
    attention_mask = inputs.get("attention_mask")
    if attention_mask is not None:
        generate_kwargs["attention_mask"] = attention_mask.to(model.device)

    outputs = model.generate(input_ids, **generate_kwargs)

    # The output sequence starts with the prompt tokens; keep only what follows.
    generated_tokens = outputs[0][input_ids.shape[1]:]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True)

# Ordered from most to least specific; the first pattern that matches wins.
_ANSWER_PATTERNS = (
    # "Answer: B", "answer: (c)" -- the trailing \b stops "Answer: Both ..."
    # from being read as option B.
    re.compile(r'Answer:\s*\(?([A-D])\b', re.IGNORECASE),
    # A line consisting only of the option letter: "b", "C.", "(D)".
    re.compile(r'^\s*\(?([A-D])[).]?\s*$', re.IGNORECASE | re.MULTILINE),
    # "B)" / "(b)"
    re.compile(r'\b([A-D])\)', re.IGNORECASE),
    # "B."
    re.compile(r'\b([A-D])\.', re.IGNORECASE),
    # Bare standalone letter. Deliberately case-sensitive so that ordinary
    # lowercase words such as the article "a" or the "d" in "I'd" are not
    # mistaken for an option.
    re.compile(r'\b([A-D])\b'),
)


def extract_answer(model_output):
    """
    Extract the chosen option letter (A-D) from raw model output.

    Patterns are tried from most to least specific and the first match
    wins: an explicit "Answer: X", a line containing only the letter, the
    forms "X)" and "X.", and finally a bare uppercase letter.

    Args:
        model_output: Text produced by the model. May be empty or ``None``.

    Returns:
        The uppercase option letter, or ``None`` if no option was found.
    """
    if not model_output:
        return None

    for pattern in _ANSWER_PATTERNS:
        match = pattern.search(model_output)
        if match:
            return match.group(1).upper()

    return None

def _build_prompt(item):
    """Assemble the multiple-choice prompt text for one dataset item."""
    question = (item.get("input") or "").strip()
    article = (item.get("article") or "").strip()
    system_instruction = (item.get("system") or "").strip()

    options = item.get("options") or {}
    option_A = (options.get("A") or "").strip()
    option_B = (options.get("B") or "").strip()
    option_C = (options.get("C") or "").strip()
    option_D = (options.get("D") or "").strip()

    return (
        f"{system_instruction}\n"
        f"Article: {article}\n"
        f"Question: {question}\n"
        f"Options: A) {option_A}, "
        f"B) {option_B}, "
        f"C) {option_C}, "
        f"D) {option_D}\n"
        f"Answer: Please provide only the letter of the correct answer (A, B, C, or D)."
    )


def evaluate_accuracy(model, tokenizer, data):
    """
    Evaluate multiple-choice accuracy of ``model`` over ``data``.

    Each item is a dict read through the keys ``input`` (question),
    ``article``, ``system``, ``answer`` (gold letter) and ``options`` (dict
    keyed by "A".."D"). Items without a question or gold answer are skipped
    and do not count towards the total. Details of every incorrect answer
    and the final accuracy are printed.

    Args:
        model: Model passed through to ``generate_response``.
        tokenizer: Tokenizer passed through to ``generate_response``.
        data: Iterable of dataset items.

    Returns:
        The accuracy as a float in [0, 1]; 0 when no item was evaluated.
    """
    total_questions = 0
    correct_answers = 0

    for idx, item in enumerate(data):
        question = (item.get("input") or "").strip()
        gold_answer = (item.get("answer") or "").strip().upper()

        if not question or not gold_answer:
            # Reported 1-based, consistent with the incorrect-answer report below.
            print(f"Data item {idx + 1} is missing a question or answer, skipping.")
            continue

        # Count only items that are actually evaluated, so skipped entries
        # do not drag the accuracy down.
        total_questions += 1

        # Generate the model's response
        model_output = generate_response(model, tokenizer, _build_prompt(item))

        # Extract the answer
        selected_option = extract_answer(model_output)

        # Check the answer
        if selected_option == gold_answer:
            correct_answers += 1
        else:
            print(f"Data item {idx + 1}: Incorrect answer")
            print(f"Question: {question}")
            print(f"Correct Answer: {gold_answer}")
            print(f"Model Answer: {model_output.strip()}\n")

    acc = correct_answers / total_questions if total_questions > 0 else 0
    print(f"Accuracy (acc): {acc:.2%} ({correct_answers}/{total_questions})")
    return acc

def main():
    """
    Command-line entry point: load the model and dataset, then evaluate.

    Both ``--model_path`` and ``--data_path`` are required; the dataset is
    read as a UTF-8 JSON file.
    """
    parser = argparse.ArgumentParser(description="Evaluate the model's performance on the provided dataset.")
    parser.add_argument('--model_path', type=str, required=True, help="Path to the pre-trained model.")
    # Required: without it, open(None) would fail with a TypeError only
    # after the (expensive) model load had already completed.
    parser.add_argument('--data_path', type=str, required=True, help="Path to the dataset.")
    args = parser.parse_args()

    print("Loading model and tokenizer...")
    model, tokenizer = load_local_model(args.model_path)
    print("Model and tokenizer loaded.")

    print(f"Loading dataset: {args.data_path} ...")
    with open(args.data_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    print(f"Dataset loaded, {len(data)} questions in total.")

    print("Starting model accuracy evaluation...")
    evaluate_accuracy(model, tokenizer, data)
    print("Evaluation completed.")

# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()