import os
import sys
import json
import logging
import torch
import wandb
from datetime import datetime
from transformers import (
    HfArgumentParser,
    AutoTokenizer,
    AutoModelForCausalLM,
    GenerationConfig
)
from datasets import load_dataset
from typing import Optional, Dict, Any, Union
from dataclasses import dataclass, field
from typing import List
from tqdm import tqdm
import re
from extraction import extract_answer, extract_gt
import argparse
import yaml

# --- Argument Classes ---
@dataclass
class ModelArguments:
    """Options that select the model/tokenizer and control how the model is loaded."""

    # Required: has no default, so it must always be supplied.
    model_name_or_path: str = field(
        metadata=dict(
            help="Path to the pre-trained model or model identifier from Hugging Face Hub."
        ),
    )
    # Optional override for the tokenizer location.
    tokenizer_name_or_path: Optional[str] = field(
        default=None,
        metadata=dict(
            help="Path to the tokenizer from Hugging Face Hub. If None, defaults to model_name_or_path."
        ),
    )
    device_map: Optional[str] = field(
        default="auto",
        metadata=dict(help="Device map for model loading."),
    )
    # Kept as a string here; it is resolved against the torch module at load time.
    torch_dtype: Optional[str] = field(
        default="auto",
        metadata=dict(
            help="Torch dtype for model loading ('bfloat16', 'float16', 'auto')."
        ),
    )

@dataclass
class DataArguments:
    """Options that describe which dataset to evaluate on and how to read it."""

    # Required: has no default, so it must always be supplied.
    dataset_name: str = field(
        metadata={"help": "Name or path of the dataset to evaluate on (from Hugging Face Hub or local)."}
    )
    split: str = field(
        default="test", metadata={"help": "Dataset split to use for evaluation."}
    )
    # A list of column names. The default must itself be a list: a bare string
    # would be iterated character by character by code that indexes
    # input_columns[0] / input_columns[1]. default_factory gives every instance
    # its own list object.
    input_columns: List[str] = field(
        default_factory=lambda: ["question"],
        metadata={"help": "The key in the dataset that contains the prompt or question."},
    )
    answer_key: str = field(
        default="answer", metadata={"help": "The key in the dataset that contains the ground truth answer."}
    )
    subset_name: Optional[str] = field(
        default=None, metadata={"help": "Subset name of the dataset if applicable."}
    )
    reasoning_type: str = field(
        default="numerical", metadata={"help": "Type of reasoning expected used for answer extraction ('numerical', 'multiple_choice')."}
    )

@dataclass
class GenerationArguments:
    """Decoding options used when generating model outputs.

    Every field carries ``help`` metadata so that it is documented the same
    way as the other argument classes in this file.
    """

    max_new_tokens: int = field(
        default=1028, metadata={"help": "Maximum number of new tokens to generate."}
    )
    do_sample: bool = field(
        default=False, metadata={"help": "Whether to sample instead of decoding greedily."}
    )
    temperature: float = field(
        default=0.6, metadata={"help": "Sampling temperature."}
    )
    top_p: float = field(
        default=0.95, metadata={"help": "Cumulative probability cutoff for nucleus (top-p) sampling."}
    )
    top_k: int = field(
        default=20, metadata={"help": "Number of highest-probability tokens kept for top-k sampling."}
    )

@dataclass
class EvalArguments:
    """Options for the evaluation run itself: output location, batching, prompt override."""

    output_dir: str = field(
        default="eval_results",
        metadata=dict(help="Directory to save evaluation results."),
    )
    eval_batch_size: int = field(
        default=16,
        metadata=dict(help="Batch size for evaluation."),
    )
    # The two custom_prompt_* options are only meaningful together: the file
    # holds the templates and the id picks one of them.
    custom_prompt_file: Optional[str] = field(
        default=None,
        metadata=dict(help="Path to a custom prompt template file in JSON format."),
    )
    custom_prompt_id: Optional[str] = field(
        default=None,
        metadata=dict(help="Custom prompt template to use for formatting questions"),
    )

@dataclass
class WandbArguments:
    """Weights & Biases logging options."""

    # A falsy project name switches W&B logging off.
    wandb_project: Optional[str] = field(
        default="llm-evaluation",
        metadata=dict(
            help="The name of the W&B project to log to. Set to None to disable."
        ),
    )
    wandb_run_name: Optional[str] = field(
        default=None,
        metadata=dict(
            help="A specific name for the W&B run. If None, a random name will be generated."
        ),
    )

# --- Helper Functions ---

def setup_logging(output_dir: str, model_name: str, dataset_name: str):
    """Configure root logging to write to both a timestamped log file and stdout.

    The log file is created inside ``output_dir`` (which is created if needed)
    and is named after the model, the dataset and the current local time.

    Args:
        output_dir: Directory that receives the log file.
        model_name: Model identifier; "/" is replaced by "_" in the file name.
        dataset_name: Dataset identifier; "/" is replaced by "_" in the file name.
    """
    os.makedirs(output_dir, exist_ok=True)

    safe_model_name = model_name.replace("/", "_")
    safe_dataset_name = dataset_name.replace("/", "_")
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_filename = f"eval_{safe_model_name}_{safe_dataset_name}_{timestamp}.log"
    log_filepath = os.path.join(output_dir, log_filename)

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        handlers=[
            # Explicit encoding so non-ASCII model output can always be written.
            logging.FileHandler(log_filepath, encoding="utf-8"),
            logging.StreamHandler(sys.stdout),
        ],
        # basicConfig() silently does nothing when the root logger already has
        # handlers (e.g. because something logged before this call). force=True
        # replaces them so the file/stdout handlers are always installed.
        force=True,
    )
    logging.info(f"Logging to {log_filepath}")

def format_prompt(example: Dict[str, Any], input_columns: List[str], custom_prompt_template: Optional[str] = None) -> str:
    """
    Formats a single dataset example into a prompt string for the model.

    Args:
        example: One dataset row, keyed by column name.
        input_columns: Column names to read. With one column, its value is the
            question. With more than one, the first column is the question and
            the second holds an iterable of answer choices, which are appended
            as a numbered list starting at 0. A bare string is accepted and
            treated as a single column name.
        custom_prompt_template: Optional template containing a ``{question}``
            placeholder. When omitted, the built-in template is used.

    Returns:
        The formatted prompt, ending with the answer-format instruction.

    Raises:
        ValueError: If ``input_columns`` is empty.
    """
    # A plain string would otherwise be indexed character by character
    # ("question" -> "q", "u"), so normalise it to a one-element list.
    if isinstance(input_columns, str):
        input_columns = [input_columns]
    if not input_columns:
        raise ValueError("input_columns must contain at least one column name.")

    append_str = "\nProvide the final answer in the format 'Answer: <answer>'.\nSolution:"
    if len(input_columns) == 1:
        question = example[input_columns[0]]
    else:
        question_col = input_columns[0]
        if question_col in example:
            question = example[question_col] + "\nSelect the correct answer from the choices below:"
        else:
            # Missing question column: only the choices are listed.
            question = ""
        for i, answer in enumerate(example[input_columns[1]]):
            question += f"\n{i}. {answer}"
        append_str = "From the choices provided, select the number of the choice that best answers the question (0,1,2, or 3) and provide your answer in the format 'Answer: <answer>'. \nSolution:"

    if not question:
        logging.warning("Empty question found in example.")

    if custom_prompt_template:
        return custom_prompt_template.format(question=question) + " " + append_str
    prompt_template = (
        "Solve the following problem. Think step-by-step, provide your reasoning and then provide the final answer in the format 'Answer: <answer>'. Make sure to provide your reasoning before the final answer.\n"
        "Question: {question}\n"
    )
    return prompt_template.format(question=question) + " " + append_str

def check_answer(predicted_answer:str, ground_truth: str) -> bool:
    """Decide whether a predicted answer matches the ground truth.

    Both values are stringified, stripped and lower-cased. They match when
    they are equal as floats (so "42" matches "42.0") or, failing that, when
    the normalised strings are identical. A ``None`` prediction never matches.
    """
    if predicted_answer is None:
        return False

    expected = str(ground_truth).strip().lower()
    candidate = str(predicted_answer).strip().lower()

    try:
        numerically_equal = float(candidate) == float(expected)
    except (ValueError, TypeError):
        # At least one side is not a number; fall through to text comparison.
        numerically_equal = False

    return numerically_equal or candidate == expected

def _load_hyperparam_set():
    """Resolve this job's hyperparameter set from the environment.

    Reads ``JOB_NUM`` (a 1-based index into the YAML list) and
    ``HYPERPARAM_FILE`` (defaults to "hyperparams.yml"). Logging is not
    configured yet at this point, so failures are reported with
    ``sys.exit(message)``, which prints to stderr and exits with status 1.

    Returns:
        A ``(job_id, hyperparam_file, hyperparam_set)`` tuple.
    """
    raw_job_id = os.environ.get("JOB_NUM")
    job_id = int(raw_job_id) if raw_job_id and raw_job_id.isdigit() else None
    if not job_id:
        sys.exit("No valid JOB_NUM found in environment variables (expected a positive integer)")

    hyperparam_file = os.environ.get("HYPERPARAM_FILE") or "hyperparams.yml"
    if not os.path.exists(hyperparam_file):
        sys.exit(f"Hyperparameter file {hyperparam_file} does not exist.")
    with open(hyperparam_file, "r", encoding="utf-8") as f:
        hyperparams = yaml.safe_load(f)

    # An out-of-range job number or a malformed entry is reported as a clean
    # "not found" error instead of surfacing as an uncaught exception.
    try:
        hyperparam_set = hyperparams[job_id - 1]["hyperparam_set"]
    except (IndexError, KeyError, TypeError):
        hyperparam_set = None
    if hyperparam_set is None:
        sys.exit(f"Job ID {job_id} not found in {hyperparam_file}")

    return job_id, hyperparam_file, hyperparam_set


def _load_custom_prompt_template(prompt_file: Optional[str], prompt_id: Optional[str]) -> Optional[str]:
    """Return the template registered under ``prompt_id`` in the JSON prompt file.

    Returns ``None`` (meaning: use the default template) when either argument
    is missing, the file cannot be read or parsed, or the id is not present.
    """
    if not (prompt_file and prompt_id):
        logging.info("No custom prompt file or ID provided. Using default prompt template.")
        return None

    logging.info(f"Checking prompt ID {prompt_id} in {prompt_file}")
    try:
        with open(prompt_file, "r", encoding="utf-8") as f:
            custom_prompts = json.load(f)
        for prompt in custom_prompts["prompts"]:
            if prompt["prompt_id"] == prompt_id:
                template = prompt["template"]
                logging.info(f"Using custom prompt template: {template}")
                return template
    except Exception as e:
        # Best effort: a broken prompt file must not abort the evaluation.
        logging.error(f"Failed to load custom prompt template: {e}. Using default prompt template.")
        return None

    logging.warning(f"Prompt ID {prompt_id} not found in {prompt_file}. Using default prompt template.")
    return None


def main():
    """Run one evaluation job end to end.

    Steps: resolve the hyperparameter set for ``JOB_NUM``, parse it into the
    argument dataclasses, configure logging and (optionally) W&B, load the
    model, tokenizer and dataset, generate answers batch by batch, score them
    with ``check_answer``, then write a JSON results file to the output
    directory and log the metrics to W&B when it is enabled.

    Exits with status 1 if the job configuration, model or dataset cannot be
    loaded.
    """
    # --- Argument Parsing ---
    job_id, hyperparam_file, hyperparam_set = _load_hyperparam_set()

    parser = HfArgumentParser((ModelArguments, DataArguments, GenerationArguments, EvalArguments, WandbArguments))
    model_args, data_args, gen_args, eval_args, wandb_args = parser.parse_dict(args=hyperparam_set)

    # --- Setup Logging ---
    # Nothing is logged before this point, so the handlers installed here are
    # the first ones on the root logger and no early message is lost.
    setup_logging(eval_args.output_dir, model_args.model_name_or_path, data_args.dataset_name)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    logging.info("Starting LLM evaluation script.")
    logging.info(f"Running evaluation for JOB_NUM: {job_id}")
    logging.info(f"Using hyperparameter file: {hyperparam_file}")
    logging.info(f"W&B Project: {wandb_args.wandb_project}, Run Name: {wandb_args.wandb_run_name}")

    # --- Initialize W&B ---
    wandb_enabled = bool(wandb_args.wandb_project)
    if wandb_enabled:
        try:
            wandb.init(
                project=wandb_args.wandb_project,
                name=wandb_args.wandb_run_name,
                config={
                    "model_args": model_args,
                    "data_args": data_args,
                    "gen_args": gen_args,
                    "eval_args": eval_args,
                }
            )
            logging.info(f"Logging to W&B project: {wandb_args.wandb_project}")
        except Exception as e:
            logging.warning(f"Could not initialize W&B: {e}. Disabling W&B logging.")
            wandb_args.wandb_project = None  # Disable wandb if init fails
            wandb_enabled = False

    # --- Load Model and Tokenizer ---
    logging.info(f"Loading model: {model_args.model_name_or_path}")
    tokenizer_path = model_args.tokenizer_name_or_path or model_args.model_name_or_path

    try:
        # Left padding keeps every prompt flush against its generated tokens.
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, padding_side="left")
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Only accept names that resolve to an actual torch dtype; anything
        # else (including "auto" and None) falls back to "auto".
        requested_dtype = getattr(torch, model_args.torch_dtype or "auto", None)
        torch_dtype = requested_dtype if isinstance(requested_dtype, torch.dtype) else "auto"

        # Sampling parameters are not model-loading options; they are passed
        # through the GenerationConfig built below.
        model = AutoModelForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            device_map=model_args.device_map,
            torch_dtype=torch_dtype,
        )
        logging.info("Model and tokenizer loaded successfully.")
    except Exception as e:
        logging.error(f"Failed to load model or tokenizer: {e}")
        if wandb_enabled:
            wandb.finish(exit_code=1)
        sys.exit(1)

    # --- Load and Prepare Dataset ---
    logging.info(f"Loading dataset: {data_args.dataset_name}, split: {data_args.split}")
    try:
        dataset = load_dataset(data_args.dataset_name, data_args.subset_name or None, split=data_args.split)
    except Exception as e:
        logging.error(f"Failed to load dataset: {e}")
        if wandb_enabled:
            wandb.finish(exit_code=1)
        sys.exit(1)

    # --- Load Custom Prompt ---
    custom_prompt_template = _load_custom_prompt_template(
        eval_args.custom_prompt_file, eval_args.custom_prompt_id
    )

    # --- Evaluation Loop ---
    model.eval()
    results = []
    correct_predictions = 0
    total_predictions = 0

    generation_kwargs = {
        "max_new_tokens": gen_args.max_new_tokens,
        "do_sample": gen_args.do_sample,
        "pad_token_id": tokenizer.pad_token_id,
    }
    if gen_args.do_sample:
        # Sampling knobs only apply when sampling; top_k is included so the
        # configured value is actually used during generation.
        generation_kwargs.update(
            temperature=gen_args.temperature,
            top_p=gen_args.top_p,
            top_k=gen_args.top_k,
        )
    generation_config = GenerationConfig(**generation_kwargs)

    num_examples = len(dataset)
    batch_size = eval_args.eval_batch_size

    logging.info("Starting evaluation loop...")
    with torch.no_grad():
        for batch_start in tqdm(range(0, num_examples, batch_size), desc="Evaluating"):
            batch_end = min(batch_start + batch_size, num_examples)
            batch_examples = [dataset[idx] for idx in range(batch_start, batch_end)]
            prompts = [format_prompt(ex, data_args.input_columns, custom_prompt_template) for ex in batch_examples]
            inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to(model.device)
            outputs = model.generate(**inputs, generation_config=generation_config)

            # The returned sequences start with the (left-padded) prompt tokens.
            # Drop them at token level rather than by decoded character count,
            # which breaks whenever decoding does not reproduce the prompt
            # text exactly.
            prompt_token_count = inputs["input_ids"].shape[1]
            model_outputs = tokenizer.batch_decode(outputs[:, prompt_token_count:], skip_special_tokens=True)

            for j, example in enumerate(batch_examples):
                model_output = model_outputs[j]
                ground_truth = extract_gt(str(example[data_args.answer_key]))
                predicted_answer = extract_answer(model_output, data_args.reasoning_type)
                is_correct = check_answer(predicted_answer, ground_truth)
                if is_correct:
                    correct_predictions += 1
                total_predictions += 1
                results.append({
                    "prompt": prompts[j],
                    "ground_truth": ground_truth,
                    "predicted_answer": predicted_answer,
                    "model_output": model_output,
                    "is_correct": is_correct
                })
                if batch_start == 0 and j == 0:
                    logging.info(f"First example - Prompt: {prompts[j]}, GT: {ground_truth}, Predicted: {predicted_answer}, Model Output: {model_output}")
                    # wandb.log must not be called when W&B is disabled or
                    # failed to initialise.
                    if wandb_enabled:
                        wandb.log({
                            "first_example_prompt": prompts[j],
                            "first_example_ground_truth": ground_truth,
                            "first_example_predicted_answer": predicted_answer,
                            "first_example_model_output": model_output
                        })

    # --- Calculate and Save Results ---
    accuracy = (correct_predictions / total_predictions) * 100 if total_predictions > 0 else 0
    logging.info(f"Evaluation finished. Accuracy: {accuracy:.2f}%")

    final_results = {
        "model": model_args.model_name_or_path,
        "dataset": data_args.dataset_name,
        "split": data_args.split,
        "accuracy": accuracy,
        "total_correct": correct_predictions,
        "total_predictions": total_predictions,
        "evaluation_details": results
    }

    safe_model_name = model_args.model_name_or_path.replace("/", "_")
    safe_dataset_name = data_args.dataset_name.replace("/", "_")
    results_filename = f"results_{safe_model_name}_{safe_dataset_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
    results_filepath = os.path.join(eval_args.output_dir, results_filename)

    os.makedirs(eval_args.output_dir, exist_ok=True)
    with open(results_filepath, "w", encoding="utf-8") as f:
        json.dump(final_results, f, indent=4)
    logging.info(f"Results saved to {results_filepath}")

    # --- Log to W&B ---
    if wandb_enabled:
        logging.info("Logging results to W&B...")
        wandb.log({
            "accuracy": accuracy,
            "total_correct": correct_predictions,
            "total_predictions": total_predictions
        })

        results_table = wandb.Table(columns=["prompt", "ground_truth", "predicted_answer", "is_correct", "model_output"])
        for res in results:
            results_table.add_data(
                res["prompt"], res["ground_truth"], res["predicted_answer"], res["is_correct"], res["model_output"]
            )
        wandb.log({"evaluation_details": results_table})
        wandb.finish()

# Run the evaluation only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
