import os
import sys
import json
import logging
import wandb
from datetime import datetime
import csv
from transformers import HfArgumentParser
from datasets import load_dataset
from typing import Optional, Dict, Any
from dataclasses import dataclass, field
from typing import List
from tqdm import tqdm
import re
from extraction import extract_answer, extract_gt
from evaluation import format_prompt, check_answer
from openai import OpenAI
from groq import Groq
import time
# --- Argument Classes ---
@dataclass
class ModelArguments:
    """Which model to evaluate and which provider API serves it."""

    # Selects the client / API path; main() accepts only "openai" and "groq".
    provider: str = field(
        default="openai",
        metadata={"help": "The provider of the model to evaluate ('openai' or 'groq')."}
    )
    model_name: str = field(
        default="o4-mini",
        metadata={"help": "Name of the model to evaluate."}
    )
    # Passed to the provider client as its api_key; None lets the SDK fall back to its
    # environment variable (e.g. OPENAI_API_KEY for the OpenAI client).
    api_key: Optional[str] = field(
        default=None,
        metadata={"help": "API key for the model provider. If unset, the provider SDK reads it from the environment (e.g. OPENAI_API_KEY)."}
    )

@dataclass
class DataArguments:
    """Which dataset/split to evaluate on and which columns hold inputs and answers."""

    dataset_name: str = field(
        metadata={"help": "Name or path of the dataset to evaluate on (from Hugging Face Hub or local)."}
    )
    split: str = field(
        default="test", metadata={"help": "Dataset split to use for evaluation (e.g., 'test', 'validation')."}
    )
    # A list field needs default_factory: it gives every instance its own list, and
    # HfArgumentParser only forwards a list field's default to argparse when it comes
    # from default_factory (a plain `default=` silently becomes None on the CLI).
    input_columns: List[str] = field(
        default_factory=lambda: ["question"],
        metadata={"help": "The key(s) in the dataset that contain the prompt or question."},
    )
    answer_key: str = field(
        default="answer", metadata={"help": "The key in the dataset that contains the ground truth answer."}
    )
    subset_name: Optional[str] = field(
        default=None, metadata={"help": "Subset name of the dataset if applicable (e.g., 'main' for GSM8K)."}
    )
    # Forwarded to extract_answer() to select the answer-parsing strategy.
    reasoning_type: str = field(
        default="numerical", metadata={"help": "Type of reasoning expected in the answers"}
    )

@dataclass
class EvalArguments:
    """Evaluation run options: output location, prompt selection and batch mode."""

    output_dir: str = field(
        default="eval_results", metadata={"help": "Directory to save evaluation results."}
    )
    # NOTE(review): not read by any code visible in this file — confirm it is still used.
    eval_batch_size: int = field(
        default=16, metadata={"help": "Batch size for evaluation."}
    )
    # JSON file with a top-level "prompts" list of {"prompt_id", "template"} entries.
    custom_prompt_file: Optional[str] = field(
        default=None, metadata={"help": "Path to a custom prompt template file in JSON format."}
    )
    custom_prompt_id: Optional[str] = field(
        default=None, metadata={"help": "Custom prompt template to use for formatting questions"}
    )
    batched_eval: bool = field(
        default=False,
        metadata={"help": "Whether to use batched evaluation."}
    )
    early_stopping: bool = field(
        default=False,
        metadata={"help": "With --batched_eval: submit the batch job, append its ID to batch_ids.csv and stop without waiting for results. Also disables W&B logging."},
    )

@dataclass
class WandbArguments:
    """Weights & Biases logging options."""

    # Project that runs are logged to; a falsy value turns W&B logging off in main().
    wandb_project: Optional[str] = field(
        default="llm-evaluation",
        metadata={
            "help": "The name of the W&B project to log to. Set to None to disable.",
        },
    )
    # Display name for the run (None leaves naming to W&B).
    wandb_run_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "A specific name for the W&B run. If None, a random name will be generated.",
        },
    )

# --- Helper Functions ---

def setup_logging(output_dir: str, model_name: str, dataset_name: str):
    """Configure root logging to stdout and to a timestamped file in ``output_dir``.

    The file is named ``eval_<model>_<dataset>_<YYYYmmdd_HHMMSS>.log``. "/" in the
    model and dataset names is replaced by "_" so Hub IDs such as "org/name" do
    not create sub-directories. ``output_dir`` is created if missing.

    Note: ``logging.basicConfig`` does nothing if the root logger already has
    handlers, in which case only the directory/file creation takes effect.
    """
    os.makedirs(output_dir, exist_ok=True)
    safe_model_name = model_name.replace("/", "_")
    safe_dataset_name = dataset_name.replace("/", "_")
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_filepath = os.path.join(output_dir, f"eval_{safe_model_name}_{safe_dataset_name}_{timestamp}.log")

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        handlers=[
            # Explicit UTF-8: raw model outputs (math symbols, non-Latin text) are logged,
            # and the platform default encoding (e.g. cp1252) cannot represent them.
            logging.FileHandler(log_filepath, encoding="utf-8"),
            logging.StreamHandler(sys.stdout),
        ],
    )
    logging.info(f"Logging to {log_filepath}")

def check_answer(predicted_answer: str, ground_truth: str) -> bool:
    """Return True when the prediction matches the ground truth.

    Both values are stringified, stripped and lower-cased. They match when they
    parse to equal floats (so "3.0" matches "3") or, failing that, when the
    normalized strings are identical. A ``None`` prediction never matches.

    NOTE(review): this definition shadows the ``check_answer`` imported from
    ``evaluation`` at the top of the file; this local version is the one used.
    """
    if predicted_answer is None:
        return False

    def normalize(value) -> str:
        return str(value).strip().lower()

    gt_norm = normalize(ground_truth)
    pred_norm = normalize(predicted_answer)
    try:
        numerically_equal = float(pred_norm) == float(gt_norm)
    except (ValueError, TypeError):
        numerically_equal = False
    # Fall back to exact text equality (this also covers "nan" == "nan").
    return numerically_equal or pred_norm == gt_norm

def _query_model(client, model_args, prompt: str) -> str:
    """Send one user prompt to the configured provider and return the generated text.

    "openai" uses the Responses API with medium reasoning effort; any other provider
    uses Chat Completions. Returns "" when the API returns no text, so answer
    extraction downstream always receives a string.
    """
    messages = [{"role": "user", "content": prompt}]
    if model_args.provider == "openai":
        response = client.responses.create(
            model=model_args.model_name,
            reasoning={"effort": "medium"},
            input=messages,
        )
        text = response.output_text
    else:
        response = client.chat.completions.create(
            model=model_args.model_name,
            messages=messages,
        )
        # The chat message `content` field is optional in the API and may be None.
        text = response.choices[0].message.content
    return text or ""


def single_eval(client: OpenAI, dataset, model_args, data_args, custom_prompt_template=None):
    """Evaluate the model one example at a time with synchronous API calls.

    Args:
        client: API client for ``model_args.provider`` (OpenAI or Groq).
        dataset: Indexable dataset; each row must contain ``data_args.answer_key``.
        model_args: Supplies ``provider`` and ``model_name``.
        data_args: Supplies ``input_columns``, ``answer_key`` and ``reasoning_type``.
        custom_prompt_template: Optional template forwarded to ``format_prompt``.

    Returns:
        ``(results, correct_predictions, total_predictions)``. ``results`` holds one dict
        per example with keys prompt, ground_truth, predicted_answer, model_output, is_correct.
    """
    results = []
    correct_predictions = 0
    logging.info("Starting evaluation loop...")
    for i in tqdm(range(len(dataset)), desc="Evaluating"):
        example = dataset[i]
        prompt = format_prompt(example, data_args.input_columns, custom_prompt_template)
        model_output = _query_model(client, model_args, prompt)
        if i == 0:
            logging.info(f"Generated output for first sample: {model_output}")

        ground_truth = extract_gt(str(example[data_args.answer_key]))
        predicted_answer = extract_answer(model_output, data_args.reasoning_type)
        is_correct = check_answer(predicted_answer, ground_truth)
        if is_correct:
            correct_predictions += 1
        results.append({
            "prompt": prompt,
            "ground_truth": ground_truth,
            "predicted_answer": predicted_answer,
            "model_output": model_output,
            "is_correct": is_correct,
        })

        if i == 0:
            logging.info(f"First example - Prompt: {prompt}, GT: {ground_truth}, Predicted: {predicted_answer}")
            # wandb.log raises when wandb.init() never ran (W&B disabled, init failed,
            # or early stopping), so only log when a run is active.
            if wandb.run is not None:
                wandb.log({
                    "first_example_prompt": prompt,
                    "first_example_ground_truth": ground_truth,
                    "first_example_predicted_answer": predicted_answer,
                    "first_example_model_output": model_output,
                })

    return results, correct_predictions, len(results)

def create_batch_prompt_file(num_samples: int, prompt_template: str, provider: str, model_name_or_path: str, dataset: Optional[list], input_columns: List[str], output_column: str) -> tuple:
    """Write a JSONL batch-request file in the current directory.

    One request is written per example, for the first ``num_samples`` rows of ``dataset``
    (capped at its length; ``None`` or empty yields an empty file). Each request has
    ``custom_id`` "request-<1-based index>". OpenAI requests target ``/v1/responses`` with
    medium reasoning effort; Groq requests target ``/v1/chat/completions``.

    Returns:
        ``(batch_filename, answers)``. ``answers`` maps each custom_id to the row's raw
        ``output_column`` value (None if the column is missing).

    Raises:
        ValueError: if ``provider`` is neither "openai" nor "groq".
    """
    if provider == "openai":
        endpoint = "/v1/responses"
    elif provider == "groq":
        endpoint = "/v1/chat/completions"
    else:
        # Previously this silently produced an empty batch file.
        raise ValueError(f"Unsupported batch provider: {provider}. Supported providers are 'openai' and 'groq'.")

    # Never index past the end of the dataset; a missing dataset means no requests.
    num_samples = 0 if dataset is None else min(num_samples, len(dataset))
    logging.info(f"Using {num_samples} samples from the dataset.")

    requests = []
    answers = {}
    for i in range(num_samples):
        example = dataset[i]
        custom_id = f"request-{i + 1}"
        prompt = format_prompt(example, input_columns, custom_prompt_template=prompt_template)
        messages = [{"role": "user", "content": prompt}]
        if provider == "openai":
            body = {"model": str(model_name_or_path), "reasoning": {"effort": "medium"}, "input": messages}
        else:
            body = {"model": str(model_name_or_path), "messages": messages}
        requests.append({"custom_id": custom_id, "method": "POST", "url": endpoint, "body": body})
        answers[custom_id] = example.get(output_column, None)

    # Timestamped name so repeated runs do not overwrite each other (second resolution).
    batch_filename = f"batch_prompts_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl"
    with open(batch_filename, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(request) + "\n" for request in requests)
    logging.info(f"Batch prompt file created: {batch_filename}")
    logging.info(f"Number of prompts in batch file: {len(requests)}")
    return batch_filename, answers

def _wait_for_batch(client, batch_id: str, poll_interval: float):
    """Poll a batch job until it reaches a terminal status.

    Returns the batch object once it is "completed"; exits the process with status 1
    if it ends as "failed", "expired" or "cancelled".
    """
    while True:
        batch_status = client.batches.retrieve(batch_id)
        status = batch_status.status
        if status == "completed":
            logging.info("Batch generation completed successfully.")
            logging.info(f"Output file ID: {batch_status.output_file_id}")
            return batch_status
        # Checking only "failed" would poll an expired/cancelled batch forever.
        if status in ("failed", "expired", "cancelled"):
            logging.error(f"Batch generation {status}.")
            sys.exit(1)
        logging.info(f"Batch status: {status}. Waiting {poll_interval} seconds before checking again.")
        time.sleep(poll_interval)


def _extract_batch_output_text(record: Dict[str, Any], provider: str) -> str:
    """Return the generated text from one line of a batch output file ("" if absent)."""
    # "response" is null for requests that errored.
    body = (record.get("response") or {}).get("body") or {}
    if provider == "openai":
        # Responses API: `output` is a list of items (e.g. a reasoning item, then a message).
        # Concatenate every output_text part of every message item, like `output_text`.
        return "".join(
            part.get("text", "")
            for item in body.get("output") or []
            if item.get("type") == "message"
            for part in item.get("content") or []
            if part.get("type") == "output_text"
        )
    choices = body.get("choices") or [{}]
    return (choices[0].get("message") or {}).get("content") or ""


def _load_batch_prompts(batch_filename: str) -> Dict[str, Any]:
    """Map each custom_id in a batch request file to the user prompt it sent."""
    prompts = {}
    with open(batch_filename, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            request = json.loads(line)
            body = request.get("body") or {}
            # Responses API requests carry "input"; chat-completions requests carry "messages".
            messages = body.get("input") or body.get("messages") or []
            prompts[request.get("custom_id")] = messages[0].get("content") if messages else None
    return prompts


def batched_generation(prompt_file: str, prompt_id: str, early_stopping: bool, provider: str, model_name_or_path: str, prompt_template: str, reasoning_type: str, client: OpenAI, dataset: Optional[list], input_column: str, output_column: str, poll_interval: float = 300) -> tuple:
    """Evaluate the whole dataset through the provider's Batch API.

    Steps: write the batch request file, upload it, create the batch job, poll it every
    ``poll_interval`` seconds, download the output and score each response.

    If ``early_stopping`` is set, the batch is only submitted. Its ID is appended to
    ``batch_ids.csv`` together with ``prompt_file``/``prompt_id``, and ``([], 0, 0)`` is returned.

    Returns:
        ``(results, correct_predictions, total_predictions)``. Result dicts use the same keys
        as ``single_eval`` (prompt, ground_truth, predicted_answer, model_output, is_correct).
        Returns ``([], 0, 0)`` when there is no output to score. Exits the process if the
        batch fails, expires or is cancelled.
    """
    logging.info(f"Starting batched generation with {len(dataset)} samples.")
    logging.info("Creating batch prompt file...")
    batch_filename, answers = create_batch_prompt_file(len(dataset), prompt_template, provider, model_name_or_path, dataset, input_column, output_column)
    logging.info(f"Batch prompt file created: {batch_filename}")

    logging.info(f"Uploading batch prompt file to {provider}...")
    with open(batch_filename, "rb") as batch_file:
        batch_input_file = client.files.create(file=batch_file, purpose="batch")
    logging.info("Batch prompt file uploaded successfully.")
    batch_input_file_id = batch_input_file.id
    logging.info(f"Batch input file created with ID: {batch_input_file_id}")

    logging.info("Creating batch for generation...")
    batch = client.batches.create(
        input_file_id=batch_input_file_id,
        endpoint="/v1/responses" if provider == "openai" else "/v1/chat/completions",
        completion_window="24h",
    )
    batch_id = batch.id
    logging.info(f"Batch created with ID: {batch_id}")

    if early_stopping:
        with open("batch_ids.csv", "a", newline="") as f:
            csv.writer(f).writerow([prompt_file, prompt_id, batch_id])
        logging.info("Early stopping is enabled. Exiting after batch creation. Batch ID saved to batch_ids.csv")
        return [], 0, 0

    batch_status = _wait_for_batch(client, batch_id, poll_interval)
    if not batch_status.output_file_id:
        # A completed batch whose every request errored has only an error file.
        logging.error(f"Batch completed without an output file (error file ID: {getattr(batch_status, 'error_file_id', None)}).")
        return [], 0, 0

    logging.info("Retrieving batch output...")
    # files.content() returns a binary response wrapper; read() yields the raw JSONL bytes.
    raw_output = client.files.content(batch_status.output_file_id).read().decode("utf-8")
    response_data = [json.loads(line) for line in raw_output.splitlines() if line.strip()]
    if not response_data:
        logging.warning("No responses found in the batch output.")
        return [], 0, 0

    response_filename = f"batch_output_{batch_id}.json"
    logging.info(f"First response in batch: {response_data[0]}")
    with open(response_filename, "w", encoding="utf-8") as f:
        json.dump(response_data, f, indent=4)
    logging.info(f"Response data saved to {response_filename}")

    logging.info("Processing batch responses...")
    prompts = _load_batch_prompts(batch_filename)
    results = []
    correct_predictions = 0
    for record in response_data:
        custom_id = record.get("custom_id")
        model_output = _extract_batch_output_text(record, provider)
        ground_truth = answers.get(custom_id)
        # Mirror single_eval: stringify the gold value before extracting the answer.
        if ground_truth is not None:
            ground_truth = extract_gt(str(ground_truth))
        predicted_answer = extract_answer(model_output, reasoning_type)
        is_correct = check_answer(predicted_answer, ground_truth)
        if is_correct:
            correct_predictions += 1
        results.append({
            "prompt": prompts.get(custom_id),
            "ground_truth": ground_truth,
            "predicted_answer": predicted_answer,
            "model_output": model_output,
            "is_correct": is_correct,
        })
    logging.info(f"Processed {len(results)} generations from batch output.")
    return results, correct_predictions, len(results)

def _load_custom_prompt_template(prompt_file: Optional[str], prompt_id: Optional[str]) -> Optional[str]:
    """Return the template with ``prompt_id`` from ``prompt_file``, or None for the default prompt.

    The file is JSON with a top-level "prompts" list of {"prompt_id", "template"} objects.
    Loading is best effort: any read/parse problem falls back to the default template.
    """
    if not (prompt_file and prompt_id):
        logging.info("No custom prompt file or ID provided. Using default prompt template.")
        return None
    logging.info(f"Checking prompt ID {prompt_id} in {prompt_file}")
    try:
        with open(prompt_file, "r", encoding="utf-8") as f:
            custom_prompts = json.load(f)
        for prompt in custom_prompts["prompts"]:
            if prompt["prompt_id"] == prompt_id:
                template = prompt["template"]
                logging.info(f"Using custom prompt template: {template}")
                return template
    except (OSError, ValueError, KeyError, TypeError) as e:
        logging.error(f"Failed to load custom prompt template: {e}. Using default prompt template.")
        return None
    logging.warning(f"Prompt ID {prompt_id} not found in {prompt_file}. Using default prompt template.")
    return None


def _build_client(model_args):
    """Create the API client for ``model_args.provider``; None if the provider is unsupported."""
    if model_args.provider == "openai":
        return OpenAI(api_key=model_args.api_key)
    if model_args.provider == "groq":
        # Honor an explicit --api_key, otherwise fall back to the GROQ_API_KEY env variable.
        return Groq(api_key=model_args.api_key or os.environ.get("GROQ_API_KEY"))
    return None


def _log_results_to_wandb(accuracy: float, correct_predictions: int, total_predictions: int, results: list) -> None:
    """Log summary metrics and a per-example table to the active W&B run, then finish it."""
    logging.info("Logging results to W&B...")
    wandb.log({
        "accuracy": accuracy,
        "total_correct": correct_predictions,
        "total_predictions": total_predictions
    })
    results_table = wandb.Table(columns=["prompt", "ground_truth", "predicted_answer", "is_correct", "model_output"])
    for res in results:
        # .get(): result dicts from different eval paths may not carry every column.
        results_table.add_data(
            res.get("prompt"), res.get("ground_truth"), res.get("predicted_answer"), res.get("is_correct"), res.get("model_output")
        )
    wandb.log({"evaluation_details": results_table})
    wandb.finish()


def main():
    """Entry point: parse CLI arguments, run the evaluation, save and log the results.

    Results are written as JSON to ``eval_args.output_dir``. In batched mode with early
    stopping, the batch is only submitted and no results file is written.
    Exits with status 1 if the dataset cannot be loaded or the provider is unsupported.
    """
    from dataclasses import asdict  # only dataclass/field are imported at module level

    # --- Argument Parsing ---
    parser = HfArgumentParser((ModelArguments, DataArguments, EvalArguments, WandbArguments))
    model_args, data_args, eval_args, wandb_args = parser.parse_args_into_dataclasses()

    # --- Setup Logging ---
    setup_logging(eval_args.output_dir, model_args.model_name, data_args.dataset_name)
    logging.getLogger().setLevel(logging.INFO)
    logging.info("Starting LLM evaluation script.")
    logging.info(f"W&B Project: {wandb_args.wandb_project}, Run Name: {wandb_args.wandb_run_name}")

    # --- Initialize W&B (never in early-stopping mode) ---
    use_wandb = bool(wandb_args.wandb_project) and not eval_args.early_stopping
    if use_wandb:
        model_config = asdict(model_args)
        # Never upload credentials to the experiment tracker.
        model_config.pop("api_key", None)
        try:
            wandb.init(
                project=wandb_args.wandb_project,
                name=wandb_args.wandb_run_name,
                config={
                    "model_args": model_config,
                    "data_args": asdict(data_args),
                    "eval_args": asdict(eval_args),
                }
            )
            logging.info(f"Logging to W&B project: {wandb_args.wandb_project}")
        except Exception as e:
            # Best effort: evaluation proceeds without W&B if init fails.
            logging.warning(f"Could not initialize W&B: {e}. Disabling W&B logging.")
            wandb_args.wandb_project = None
            use_wandb = False

    # --- Load and Prepare Dataset ---
    logging.info(f"Loading dataset: {data_args.dataset_name}, split: {data_args.split}")
    try:
        dataset = load_dataset(data_args.dataset_name, data_args.subset_name or None, split=data_args.split)
    except Exception as e:
        logging.error(f"Failed to load dataset: {e}")
        if use_wandb:
            wandb.finish(exit_code=1)
        sys.exit(1)

    # --- Load Custom Prompt ---
    custom_prompt_template = _load_custom_prompt_template(eval_args.custom_prompt_file, eval_args.custom_prompt_id)

    # --- Evaluation Loop ---
    client = _build_client(model_args)
    if client is None:
        logging.error(f"Unsupported model provider: {model_args.provider}. Supported providers are 'openai' and 'groq'.")
        if use_wandb:
            wandb.finish(exit_code=1)
        sys.exit(1)

    if eval_args.batched_eval:
        outcome = batched_generation(
            prompt_file=eval_args.custom_prompt_file if eval_args.custom_prompt_file else "default_prompt",
            prompt_id=eval_args.custom_prompt_id if eval_args.custom_prompt_id else "default_prompt",
            early_stopping=eval_args.early_stopping,
            provider=model_args.provider,
            model_name_or_path=model_args.model_name,
            prompt_template=custom_prompt_template,
            reasoning_type=data_args.reasoning_type,
            client=client,
            dataset=dataset,
            input_column=data_args.input_columns,
            output_column=data_args.answer_key
        )
        if eval_args.early_stopping:
            # Only the batch was submitted; a 0%-accuracy results file would be misleading.
            logging.info("Early stopping is enabled: batch submitted, skipping result aggregation.")
            return
        # Guard against a bare None return (no batch output) instead of a 3-tuple.
        results, correct_predictions, total_predictions = outcome if outcome is not None else ([], 0, 0)
    else:
        results, correct_predictions, total_predictions = single_eval(
            client, dataset, model_args, data_args, custom_prompt_template
        )
    results = results or []

    # --- Calculate and Save Results ---
    accuracy = (correct_predictions / total_predictions) * 100 if total_predictions > 0 else 0
    logging.info(f"Evaluation finished. Accuracy: {accuracy:.2f}%")

    final_results = {
        "model": model_args.model_name,
        "dataset": data_args.dataset_name,
        "split": data_args.split,
        "accuracy": accuracy,
        "total_correct": correct_predictions,
        "total_predictions": total_predictions,
        "evaluation_details": results
    }

    safe_model_name = model_args.model_name.replace("/", "_")
    safe_dataset_name = data_args.dataset_name.replace("/", "_")
    results_filename = f"results_{safe_model_name}_{safe_dataset_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
    results_filepath = os.path.join(eval_args.output_dir, results_filename)

    with open(results_filepath, "w", encoding="utf-8") as f:
        json.dump(final_results, f, indent=4)
    logging.info(f"Results saved to {results_filepath}")

    # --- Log to W&B ---
    if use_wandb:
        _log_results_to_wandb(accuracy, correct_predictions, total_predictions, results)

if __name__ == "__main__":
    # Script entry point: run the evaluation with arguments parsed from the command line.
    main()
