import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import re
from tqdm import tqdm
import json
import os
from datetime import datetime
import time
from zoneinfo import ZoneInfo
from peft import PeftModel

# Final-answer marker in a model response: "#### <number>". A leading digit is
# required so that stray punctuation after the marker is never taken as a number.
_MARKED_ANSWER_RE = re.compile(r'####\s*(-?\d[\d,]*\.?\d*)')
# Fallback: any number followed by whitespace or end of text. The leading digit
# is required because the looser form `[\d,]+` also matches a lone "," before a
# space, which then became the "last number" of ordinary prose.
_LOOSE_NUMBER_RE = re.compile(r'-?\d[\d,]*\.?\d*(?=\s|$)')
# Final-answer marker in the dataset's reference solution.
_REFERENCE_ANSWER_RE = re.compile(r'####\s*(-?[\d,]+)')
# Two parsed answers closer than this are treated as equal.
_ANSWER_TOLERANCE = 1e-4


def _extract_model_answer(response_text):
    """Return the model's final numeric answer as a string, or None if absent."""
    marked = _MARKED_ANSWER_RE.search(response_text)
    if marked:
        return marked.group(1).strip()
    candidates = _LOOSE_NUMBER_RE.findall(response_text)
    return candidates[-1] if candidates else None


def _parse_number(text):
    """Convert a string such as '1,234.5' to float; return None if not numeric."""
    if text is None:
        return None
    try:
        return float(text.replace(',', ''))
    except ValueError:
        return None


def evaluate_gsm8k_for_llama3():
    """Evaluate a Llama-3 model (optionally with a LoRA adapter) on GSM8K.

    Loads the base model and tokenizer, merges the LoRA adapter when
    ``lora_adapter_path`` is non-empty, and greedily generates an answer for
    each test question in batches. The final number of each response is
    compared with the number after ``####`` in the reference solution.

    One JSON object per sample is written to a timestamped ``.jsonl`` file in
    ``output_dir`` and a summary is printed to stdout. Paths, sample count and
    batch size are configured through the local constants at the top of the
    function body.

    Returns:
        None.
    """
    NUM_SAMPLES_TO_RUN = 1000   # set to None to evaluate the whole test split
    INFERENCE_BATCH_SIZE = 8

    base_model_path = "/root/model/Llama-3-8B-Instruct"   # Please modify this

    lora_adapter_path = "../ckpt/ssr-llama-gsm8k-0.35"    # Please modify this

    output_dir = "../eval_results/gsm8k"
    os.makedirs(output_dir, exist_ok=True)
    timezone = ZoneInfo("Asia/Shanghai")
    timestamp = datetime.now(timezone).strftime("%Y%m%d_%H%M%S")
    # Name the result file after the adapter, or after the base model when no
    # adapter is configured. rstrip (not strip) only drops trailing slashes.
    identifier_source = lora_adapter_path if lora_adapter_path else base_model_path
    model_identifier = os.path.basename(identifier_source.rstrip('/'))
    output_file = os.path.join(output_dir, f"{model_identifier}_{timestamp}.jsonl")

    tokenizer = AutoTokenizer.from_pretrained(base_model_path, trust_remote_code=True)

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Left padding keeps every prompt flush against its generated tokens, so
    # the responses of a batch all start at the same column.
    tokenizer.padding_side = 'left'
    model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
        attn_implementation="flash_attention_2"
    )

    if lora_adapter_path:
        model = PeftModel.from_pretrained(model, lora_adapter_path)
        model = model.merge_and_unload()
    model.eval()

    gsm8k_local_path = "../data/GSM8K/test-00000-of-00001.parquet"
    test_dataset = load_dataset("parquet", data_files={"test": gsm8k_local_path}, split="test")

    if NUM_SAMPLES_TO_RUN is not None:
        # Clamp so that a cap larger than the split does not raise in select().
        sample_limit = min(NUM_SAMPLES_TO_RUN, len(test_dataset))
        test_dataset = test_dataset.select(range(sample_limit))

    results_data = []
    correct_count = 0
    total_count = len(test_dataset)
    start_time = time.time()

    with torch.no_grad():
        for i in tqdm(range(0, total_count, INFERENCE_BATCH_SIZE), desc="running"):
            batch_indices = range(i, min(i + INFERENCE_BATCH_SIZE, total_count))
            batch_examples = [test_dataset[j] for j in batch_indices]

            questions = [ex["question"] for ex in batch_examples]
            ground_truth_answers = [ex["answer"] for ex in batch_examples]

            # Each question is sent as a single user turn in the chat template.
            prompts_text = [
                tokenizer.apply_chat_template(
                    [{"role": "user", "content": q}],
                    tokenize=False,
                    add_generation_prompt=True,
                )
                for q in questions
            ]

            model_inputs = tokenizer(
                prompts_text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=1024
            ).to(model.device)

            # Greedy decoding; pad_token_id is passed explicitly so generation
            # pads with the same token the tokenizer used.
            generated_ids = model.generate(
                **model_inputs,
                do_sample=False,
                num_beams=1,
                max_new_tokens=256,
                pad_token_id=tokenizer.pad_token_id
            )

            # With left padding all prompts share one padded length, so a
            # single column offset strips the prompt from every row.
            prompt_length = model_inputs['input_ids'].shape[1]
            response_ids = generated_ids[:, prompt_length:]
            responses_text = tokenizer.batch_decode(response_ids, skip_special_tokens=True)

            for current_index, question, ground_truth_answer, response_text in zip(
                batch_indices, questions, ground_truth_answers, responses_text
            ):
                reference_match = _REFERENCE_ANSWER_RE.search(ground_truth_answer)
                true_answer_str = reference_match.group(1) if reference_match else None
                if true_answer_str is None:
                    # Keep the run alive; the sample is scored as incorrect.
                    tqdm.write(
                        f"Warning: no '####' reference answer found for sample {current_index}"
                    )

                model_answer_str = _extract_model_answer(response_text)

                model_answer_val = _parse_number(model_answer_str)
                true_answer_val = _parse_number(true_answer_str)
                is_correct = (
                    model_answer_val is not None
                    and true_answer_val is not None
                    and abs(model_answer_val - true_answer_val) < _ANSWER_TOLERANCE
                )
                if is_correct:
                    correct_count += 1

                results_data.append({
                    "index": current_index, "question": question, "ground_truth_answer": ground_truth_answer,
                    "model_full_response": response_text, "extracted_model_answer": model_answer_str,
                    "extracted_true_answer": true_answer_str, "is_correct": is_correct
                })

    end_time = time.time()
    total_time = end_time - start_time
    samples_per_second = total_count / total_time if total_time > 0 else 0

    with open(output_file, 'w', encoding='utf-8') as f:
        for entry in results_data:
            f.write(json.dumps(entry, ensure_ascii=False) + '\n')

    print("\n==================== Batch Evaluation Complete ====================")
    print(f"Base Model: {base_model_path}")
    if lora_adapter_path:
        print(f"LoRA Adapter: {lora_adapter_path}")
    print(f"Total Samples Processed: {total_count}")
    print(f"Correct Answers: {correct_count}")
    print(f"Final Accuracy: {correct_count / total_count if total_count > 0 else 0:.4f}")
    print(f"Elapsed Time: {total_time:.1f}s ({samples_per_second:.2f} samples/s)")
    print(f"\nDetailed results saved to: {output_file}")
    print("================================================================")


# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    evaluate_gsm8k_for_llama3()