import argparse
import json
import os
import re
import time
import tempfile
import shutil
from typing import List, Dict, Any, Optional
import logging
from tqdm import tqdm

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM
)
from vllm import LLM, SamplingParams
from data_loader import *
from model import ModelEvaluator
from result_evaluator import ResultEvaluator


# Root logging setup: INFO and above, one "<time> - <level> - <message>" line
# per record.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used by main() for progress and result reporting.
logger = logging.getLogger(__name__)


# Dataset families. Every dataset in a family shares one conversation builder;
# these sets are the single source of truth for both prompt construction and
# evaluation dispatch.
_MATH_DATASETS = frozenset(
    {"gsm8k", "math", "mathqa", "svamp", "asdiv", "mawps", "aime24"})
_LOGIC_DATASETS = frozenset(
    {"folio", "arlsat", "logiqa", "reclor", "abductionr", "fld",
     "proofwriter", "ruletaker"})
_RULEARENA_DATASET = "rulearena"
_SUPPORTED_DATASETS = (
    _MATH_DATASETS | _LOGIC_DATASETS | {_RULEARENA_DATASET})


def _build_arg_parser():
    """Create the command-line parser for the evaluation script."""
    parser = argparse.ArgumentParser(
        description="Model evaluation on various datasets")
    # Model-related parameters
    parser.add_argument("--model_path", type=str, required=True,
                        help="Path to the base model")
    parser.add_argument("--lora_path", type=str, default=None,
                        help="Path to LoRA weights (optional)")
    parser.add_argument("--torch_dtype", type=str, default="auto",
                        choices=["auto", "float16", "float32", "bfloat16"],
                        help="Model data type")

    # GPU
    parser.add_argument("--gpu", type=str, default="auto",
                        help="GPU device to use (e.g., 0, 1, auto)")
    parser.add_argument("--tensor_parallel_size", type=int, default=1,
                        help="Number of GPUs to use for tensor parallelism in vLLM")
    parser.add_argument("--gpu_memory_utilization", type=float, default=0.8)

    # Dataset-related parameters
    parser.add_argument("--dataset", type=str, required=True,
                        help="Dataset type")
    parser.add_argument("--data_path", type=str, default=None,
                        help="Path to custom dataset")
    parser.add_argument("--data_split", type=str, default="test",
                        help="Dataset split (train/test/validation)")

    # Generation parameters
    parser.add_argument("--max_new_tokens", type=int, default=2048,
                        help="Maximum new tokens to generate")
    parser.add_argument("--temperature", type=float, default=0.3,
                        help="Temperature for generation")
    parser.add_argument("--top_p", type=float, default=0.8,
                        help="Top-p for generation")
    parser.add_argument("--top_k", type=int, default=20,
                        help="Top-k for generation")

    # Output settings
    parser.add_argument("--output_dir", type=str, default="./results",
                        help="Output directory for results")
    parser.add_argument("--save_predictions", action="store_true",
                        help="Save detailed predictions")

    # Debugging
    parser.add_argument("--max_samples", type=int, default=None,
                        help="Maximum number of samples to evaluate (for debugging)")
    return parser


def _build_conversation(item, args):
    """Build the model conversation for one sample of ``args.dataset``.

    Raises:
        ValueError: If ``args.dataset`` belongs to no known dataset family.
    """
    if args.dataset in _MATH_DATASETS:
        return prepare_math_dataset_conversation(item)
    if args.dataset in _LOGIC_DATASETS:
        return prepare_logic_dataset_conversation(item)
    if args.dataset == _RULEARENA_DATASET:
        return prepare_rulearena_conversation(item, args)
    raise ValueError(f"Unsupported dataset: {args.dataset}")


def _evaluate_predictions(args, predictions, ground_truths):
    """Score ``predictions`` with the ResultEvaluator method for the dataset.

    Raises:
        ValueError: If ``args.dataset`` has no associated evaluation method.
    """
    if args.dataset in _MATH_DATASETS:
        return ResultEvaluator.evaluate_mathdata(predictions, ground_truths)
    if args.dataset == _RULEARENA_DATASET:
        # rulearena is the only evaluator that also receives the data split.
        return ResultEvaluator.evaluate_rulearena(
            predictions, ground_truths, args.data_split)
    if args.dataset in _LOGIC_DATASETS:
        # Each logic dataset is scored by its own ``evaluate_<dataset>``
        # method; look up only the one that is needed.
        evaluate = getattr(ResultEvaluator, f"evaluate_{args.dataset}")
        return evaluate(predictions, ground_truths)
    raise ValueError(f"Unsupported dataset for evaluation: {args.dataset}")


def main():
    """Run one end-to-end evaluation driven by command-line arguments.

    Parses the arguments, loads the dataset, builds one conversation per
    sample, generates predictions with ``ModelEvaluator``, scores them with
    ``ResultEvaluator`` and writes a JSON report to
    ``<output_dir>/evaluation_results_<dataset>_<data_split>.json``. Individual
    predictions are included in the report only with ``--save_predictions``.

    Raises:
        SystemExit: Raised by argparse on invalid or missing arguments.
        ValueError: If ``--dataset`` is not one of the supported datasets.
    """
    args = _build_arg_parser().parse_args()

    # Fail fast: reject an unknown dataset before creating directories,
    # loading data or initializing the model.
    if args.dataset not in _SUPPORTED_DATASETS:
        raise ValueError(f"Unsupported dataset: {args.dataset}")

    if args.gpu != "auto":
        # Restrict which GPUs are visible to the libraries used below.
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    if args.max_samples is not None and args.max_samples < 0:
        # If set to a negative number, there will be no limit on the number of samples.
        args.max_samples = None

    os.makedirs(args.output_dir, exist_ok=True)

    start_time = time.time()

    # data loading
    data = load_data(
        dataset_name=args.dataset,
        args=args,
        data_path=args.data_path,
        max_samples=args.max_samples,
        data_dir="./datasets"
    )
    num_samples = len(data)

    logger.info(f"Loaded {num_samples} samples")
    if num_samples == 0:
        logger.warning("No samples were loaded; metrics will be computed on an empty set")

    # initialize evaluator
    evaluator = ModelEvaluator(
        model_path=args.model_path,
        lora_path=args.lora_path,
        torch_dtype=args.torch_dtype,
        tensor_parallel_size=args.tensor_parallel_size,
        gpu_memory_utilization=args.gpu_memory_utilization
    )

    # prepare conversations and ground truths
    conversations = []
    ground_truths = []
    for item in data:
        conversation = _build_conversation(item, args)
        ground_truths.append(item["answer"])
        conversations.append(conversation)

    # Generate response
    logger.info("Starting response generation...")
    predictions = evaluator.generate_response(
        conversations=conversations,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
        top_k=args.top_k
    )

    # Measured from just before data loading, so it also covers model
    # initialization and generation.
    total_time = time.time() - start_time

    # evaluate results
    results_data = {
        "args": vars(args),
        "dataset": args.dataset,
        "total_samples": num_samples,
        "model_path": args.model_path,
        "lora_path": args.lora_path,
        "evaluation_time": total_time,
        # Guard against a zero elapsed time on coarse clocks.
        "samples_per_second": num_samples / total_time if total_time > 0 else 0.0
    }

    eval_results = _evaluate_predictions(args, predictions, ground_truths)
    results_data["evaluation_metrics"] = eval_results

    logger.info(f"{args.dataset.upper()} Evaluation Results:")
    logger.info(f"Accuracy: {eval_results['accuracy']:.4f}")
    logger.info(f"Correct: {eval_results['correct']}/{eval_results['total']}")

    # Save prediction example
    if args.save_predictions:
        results_data["predictions"] = [
            {
                "index": i,
                "conversation": conversation,
                # Flattened "role: content" view of the same conversation.
                "input": "\n".join(
                    f"{msg['role']}: {msg['content']}" for msg in conversation
                ),
                "prediction": pred,
                "ground_truth": truth
            }
            for i, (conversation, pred, truth) in enumerate(
                zip(conversations, predictions, ground_truths))
        ]

    output_file = os.path.join(
        args.output_dir, f"evaluation_results_{args.dataset}_{args.data_split}.json")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results_data, f, ensure_ascii=False, indent=2)

    logger.info(f"Results saved to: {output_file}")
    logger.info(f"Total evaluation time: {total_time:.2f} seconds")
    if num_samples:
        # Skipped for an empty dataset, where the average is undefined.
        logger.info(f"Average time per sample: {total_time/num_samples:.2f} seconds")


# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
