import torch
import json
import glob
from tqdm import tqdm
import re
import os
import argparse
from unsloth import FastModel


def parse_args():
    """Parse the evaluation script's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with ``model_path``, ``data_dir``, ``output_dir``
        and ``log_file`` (all required strings), ``max_seq_length`` (int,
        defaults to 4096) and ``load_in_4bit`` (bool flag, off by default).
    """
    cli = argparse.ArgumentParser(
        description="Evaluate a fine-tuned language model on a QA dataset."
    )

    # The four path-like options share the same shape: required strings.
    required_paths = (
        ("--model_path", "Path to the trained LoRA model directory."),
        ("--data_dir", "Directory containing the JSON test files."),
        ("--output_dir", "Directory to save the prediction results."),
        ("--log_file", "File to write the accuracy logs."),
    )
    for flag, help_text in required_paths:
        cli.add_argument(flag, type=str, required=True, help=help_text)

    cli.add_argument(
        "--max_seq_length",
        type=int,
        default=4096,
        help="Maximum sequence length for the model.",
    )
    cli.add_argument(
        "--load_in_4bit",
        action="store_true",
        help="Load the model in 4-bit precision.",
    )

    namespace = cli.parse_args()
    return namespace


# Last-resort answer format requested in the user instruction: \boxed{...}.
_BOXED_ANSWER_RE = re.compile(r"\\boxed\{+([^}]+?)\}+")
# Opening part of a LaTeX wrapper such as "\text{" or "\mathrm{"; the boxed
# regex stops at the first "}", so "\boxed{\text{B}}" is captured as "\text{B".
_LATEX_WRAPPER_RE = re.compile(r"\\[A-Za-z]+\s*\{")


def generate_answer(model, tokenizer, question, options, is_multiple_choice):
    r"""Prompt the model with one four-option question and extract its answer.

    Args:
        model: Object exposing ``generate(**inputs, ...)``; its ``device``
            attribute (when present) decides where the inputs are placed,
            otherwise ``"cuda"`` is used.
        tokenizer: Tokenizer exposing ``apply_chat_template``, ``__call__``,
            ``decode`` and ``eos_token_id``.
        question: Question text inserted into the user message.
        options: Mapping with the keys ``'opa'``, ``'opb'``, ``'opc'`` and
            ``'opd'`` holding the text of options A-D.
        is_multiple_choice: When true the instruction asks for ALL correct
            options, otherwise for exactly ONE.

    Returns:
        ``(generated_text, answer)``. ``generated_text`` is the decoded output
        sequence as returned by ``generate`` (special tokens skipped).
        ``answer`` is the stripped content of the last ``\boxed{...}`` found in
        the newly generated part of the sequence, with any leftover LaTeX
        wrapper opener (e.g. ``\text{``) removed, or ``None`` when the
        completion contains no usable boxed answer.
    """
    system_prompt = """Respond in the following format:
    <think>
    ...
    </think>
    <answer>
    ...
    </answer>
    """
    if is_multiple_choice:
        instruction = "This is a MULTIPLE choice question. Choose ALL correct answers and think step by step.\nProvide your final answer using \\boxed{} format to include all correct options.\n"
    else:
        instruction = "This is a SINGLE choice question. Choose only ONE correct answer and think step by step.\nProvide your final answer using \\boxed{} format.\n"

    content = f"""
        {instruction}
        Question: {question}
        A. {options['opa']}
        B. {options['opb']}
        C. {options['opc']}
        D. {options['opd']}
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": content}
    ]

    input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # Follow the model's own device instead of assuming CUDA; keep "cuda" as
    # the fallback so behaviour is unchanged for models without `.device`.
    device = getattr(model, "device", "cuda")
    inputs = tokenizer(input_text, return_tensors="pt").to(device)
    prompt_length = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=4096,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id
        )

    sequence = outputs[0]
    generated_text = tokenizer.decode(sequence, skip_special_tokens=True)

    # Search only the newly generated tokens: the output sequence starts with
    # the prompt (assumes a decoder-only model — confirm for other
    # architectures), and a "\boxed{...}" appearing in the question or options
    # must not be mistaken for the model's answer.
    completion = tokenizer.decode(sequence[prompt_length:], skip_special_tokens=True)
    boxed = _BOXED_ANSWER_RE.findall(completion)
    if not boxed:
        return generated_text, None

    # The last boxed expression is taken as the final answer.
    final_match = _LATEX_WRAPPER_RE.sub("", boxed[-1]).strip()
    return generated_text, final_match or None


_OPTION_LETTERS = frozenset("ABCD")


def _choice_letters(answer):
    """Return the set of option letters (A-D) named in an answer string.

    The string is split on non-alphanumeric characters and every token made
    up solely of the upper-case letters A-D contributes its letters, so
    "A, B", "B,A", "AB" and "\\text{A" all normalise sensibly while words such
    as "Diabetes" are ignored. ``None`` or an empty string yields an empty set.
    """
    if not answer:
        return frozenset()
    letters = set()
    for token in re.split(r"[^A-Za-z0-9]+", answer):
        if token and set(token) <= _OPTION_LETTERS:
            letters.update(token)
    return frozenset(letters)


def main():
    """Evaluate the model on every ``*.json`` file in ``--data_dir``.

    For each file (processed in sorted order) every item is sent through
    ``generate_answer``; the raw generation and the extracted answer are
    stored on the item under ``'prediction'`` and ``'prediction_answer'`` and
    the annotated items are written to a file of the same name in
    ``--output_dir``. Per-file accuracy and the overall accuracy are printed
    and appended to ``--log_file``, which is truncated at start-up.

    An item counts as correct when the set of option letters in the predicted
    answer equals the set of ground-truth letters, so formatting differences
    ("A,B" vs "A, B", different ordering) do not affect the score. Items
    without an extracted answer count as incorrect.
    """
    args = parse_args()

    print("Loading model and tokenizer with Unsloth FastModel...")
    model, tokenizer = FastModel.from_pretrained(
        model_name=args.model_path,
        max_seq_length=args.max_seq_length,
        load_in_4bit=args.load_in_4bit,
    )
    print("Model loaded successfully.")

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    os.makedirs(args.output_dir, exist_ok=True)
    # The log file may live in a directory that does not exist yet.
    log_dir = os.path.dirname(args.log_file)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    # glob order is arbitrary; sort so logs are reproducible between runs.
    file_paths = sorted(glob.glob(os.path.join(args.data_dir, "*.json")))
    if not file_paths:
        print(f"No JSON files found in {args.data_dir}")

    file_accuracies = {}
    grand_total_correct = 0
    grand_total_questions = 0

    # Maps the dataset's option keys to the letters shown in the prompt.
    true_answer_map = {'opa': 'A', 'opb': 'B', 'opc': 'C', 'opd': 'D'}

    # Clear log file before starting
    with open(args.log_file, 'w', encoding="utf-8"):
        pass

    for file_path in file_paths:
        file_name = os.path.basename(file_path)
        with open(file_path, 'r', encoding="utf-8") as f:
            data = json.load(f)

        total_correct = 0
        total_questions = 0

        output_data = []

        for item in tqdm(data, desc=f"Processing {file_name}"):
            question = item['question']
            options = {key: item[key] for key in true_answer_map}

            # Items carrying an 'answers' list are multiple-choice; all others
            # are expected to provide a single 'answer' key.
            is_multiple_choice = "answers" in item
            if is_multiple_choice:
                true_answer_keys = sorted(
                    true_answer_map[ans] for ans in item['answers'] if ans in true_answer_map
                )
                true_answer = ", ".join(true_answer_keys)
            else:
                true_answer = true_answer_map.get(item['answer'])

            prediction, pred_answer = generate_answer(model, tokenizer, question, options, is_multiple_choice)

            if pred_answer:
                # tqdm.write keeps the progress bar intact while logging.
                tqdm.write(f"Ground Truth: {true_answer}, Prediction: {pred_answer}")
                true_letters = _choice_letters(true_answer)
                if true_letters and _choice_letters(pred_answer) == true_letters:
                    total_correct += 1

            total_questions += 1

            item['prediction'] = prediction
            item['prediction_answer'] = pred_answer
            output_data.append(item)

        grand_total_correct += total_correct
        grand_total_questions += total_questions

        new_file_path = os.path.join(args.output_dir, file_name)
        with open(new_file_path, 'w', encoding="utf-8") as f:
            json.dump(output_data, f, indent=4)

        file_accuracy = total_correct / total_questions if total_questions > 0 else 0
        file_accuracies[file_name] = file_accuracy

        print(f"Accuracy for {file_name}: {file_accuracy:.4f}")
        with open(args.log_file, 'a', encoding="utf-8") as out_file:
            out_file.write(f"{file_name}: {file_accuracy:.4f}\n")

    print("\n--- Final Results ---")
    for file_name, accuracy in file_accuracies.items():
        print(f"{file_name}: {accuracy:.4f}")

    overall_accuracy = grand_total_correct / grand_total_questions if grand_total_questions > 0 else 0
    print(f"\nOverall Accuracy: {overall_accuracy:.4f}")
    with open(args.log_file, 'a', encoding="utf-8") as out_file:
        out_file.write(f"\nOverall Accuracy: {overall_accuracy:.4f}\n")


# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()