from extraction import extract_answer, extract_gt
from evaluation import check_answer
from groq import Groq
import json
import logging
import sys
import time
from typing import Optional
from datasets import load_dataset
from evaluate_cs_models import setup_logging
import wandb
import argparse
def _wait_for_batch(client, batch_id: str, poll_interval: float):
    """Poll ``client.batches.retrieve`` until the batch reaches a terminal status.

    Returns the batch object once its status is ``"completed"``. Exits the
    process with status 1 when the status is ``"failed"``, ``"expired"`` or
    ``"cancelled"``; any other status is treated as still running.
    """
    while True:
        batch_status = client.batches.retrieve(batch_id)
        status = batch_status.status
        if status == "completed":
            logging.info("Batch generation completed successfully.")
            logging.info(f"Output file ID: {batch_status.output_file_id}")
            return batch_status
        # Checking only "failed" would poll forever on expired/cancelled batches.
        if status in ("failed", "expired", "cancelled"):
            logging.error(f"Batch generation {status}.")
            sys.exit(1)
        logging.info(f"Batch status: {status}. Waiting for {poll_interval} seconds before checking again.")
        time.sleep(poll_interval)


def _extract_model_output(gen: dict, provider: str) -> str:
    """Return the generated text from one parsed batch-output record.

    ``"openai"`` records are read as Responses-API bodies (``body.output`` list of
    items whose message item carries ``content[0].text``). ``"groq"`` records are
    read as chat-completions bodies (``body.choices[0].message.content``).
    Returns ``""`` when the expected fields are missing, e.g. for an errored
    request, so that one bad record is scored as incorrect instead of aborting
    the whole evaluation.
    """
    body = (gen.get("response") or {}).get("body") or {}
    if provider == "openai":
        # The output list may contain non-message items (such as reasoning
        # items) ahead of the assistant message, so select by type, not position.
        for item in body.get("output") or []:
            if item.get("type") == "message":
                content = item.get("content") or [{}]
                return content[0].get("text", "") or ""
        logging.warning(f"No message output found for {gen.get('custom_id')}.")
        return ""
    choices = body.get("choices") or []
    if not choices:
        logging.warning(f"No choices found for {gen.get('custom_id')}.")
        return ""
    return (choices[0].get("message") or {}).get("content") or ""


def fetch_and_eval_batch(
        provider: str,
        prompt_template: str,
        reasoning_type: str,
        client,
        dataset: Optional[list],
        batch_id: str,
        poll_interval: float = 300,
        ) -> Optional[tuple]:
    """Wait for a provider batch job to finish, then score its generations.

    Polls ``client.batches.retrieve(batch_id)`` every ``poll_interval`` seconds
    until the batch completes. Downloads the output file to
    ``batch_output_<batch_id>.json``, rewriting it as a pretty-printed JSON
    array. Then compares each generation's extracted answer with the matching
    ground truth from ``dataset['answer']``.

    Args:
        provider: Layout of the batch output records, ``"openai"`` or ``"groq"``.
        prompt_template: Recorded verbatim in every result row.
        reasoning_type: Forwarded to ``extract_answer``.
        client: Provider SDK client exposing ``batches.retrieve`` and
            ``files.content``.
        dataset: Object indexable as ``dataset['answer']`` (e.g. a Hugging Face
            ``Dataset``). Rows are looked up by the request number in each
            record's ``custom_id``.
        batch_id: ID of the batch job to retrieve.
        poll_interval: Seconds to sleep between status checks.

    Returns:
        ``(results, correct_predictions, total_predictions)``, where ``results``
        is a list of per-record dicts. Returns ``None`` when the batch has no
        output file or the output file contains no records.

    Raises:
        ValueError: If ``provider`` is not ``"openai"`` or ``"groq"``.
        SystemExit: If the batch ends as failed, expired or cancelled.
    """
    # Validate before polling, so a typo does not surface only after the wait.
    if provider not in ("openai", "groq"):
        raise ValueError(f"Unsupported provider: {provider!r} (expected 'openai' or 'groq')")
    logging.info("Starting fetching batch")
    logging.info(f"Batch created with ID: {batch_id}")
    batch_status = _wait_for_batch(client, batch_id, poll_interval)

    logging.info("Retrieving batch output...")
    output_file_id = batch_status.output_file_id
    if not output_file_id:
        logging.error(
            f"Batch {batch_id} finished without an output file "
            f"(error file ID: {getattr(batch_status, 'error_file_id', None)})."
        )
        return None
    response_batch = client.files.content(output_file_id)
    response_filename = f"batch_output_{batch_id}.json"
    response_batch.write_to_file(response_filename)
    # The downloaded file is JSONL: one JSON record per non-blank line.
    with open(response_filename, "r", encoding="utf-8") as f:
        response_data = [json.loads(line) for line in f if line.strip()]
    if not response_data:
        logging.warning("No responses found in the batch output.")
        return None

    logging.info(f"First response in batch: {response_data[0]}")
    # Overwrite the raw JSONL with a readable JSON array of the same records.
    with open(response_filename, "w", encoding="utf-8") as f:
        json.dump(response_data, f, indent=4)
    logging.info(f"Response data saved to {response_filename}")

    logging.info("Processing batch responses...")
    answers = dataset['answer']
    results = []
    correct_predictions = 0
    for gen in response_data:
        model_output = _extract_model_output(gen, provider)
        # custom_id is "request-<k>"; k is treated as 1-based into the dataset.
        ground_truth = answers[int(gen["custom_id"].replace("request-", "")) - 1]
        ground_truth = extract_gt(ground_truth) if isinstance(ground_truth, str) else ground_truth
        predicted_answer = extract_answer(model_output, reasoning_type)
        is_correct = check_answer(predicted_answer, ground_truth)
        if is_correct:
            correct_predictions += 1
        results.append({
            "prompt": prompt_template,
            "model_output": model_output,
            "predicted_answer": predicted_answer,
            "ground_truth": ground_truth,
            "is_correct": is_correct,
        })
    logging.info(f"Processed {len(results)} generations from batch output.")
    return results, correct_predictions, len(response_data)

def main():
    """Retrieve a finished provider batch, score it, and report the results.

    Parses CLI arguments, loads the evaluation dataset, and scores the batch
    with ``fetch_and_eval_batch``. Logs the accuracy locally. When a W&B
    project is configured and ``wandb.init`` succeeds, it also logs the summary
    metrics and a per-example table. Exits with status 1 if the batch yields no
    evaluable responses.
    """
    parser = argparse.ArgumentParser(description="Used to retrieve batch evaluation results if the evaluate_cs_models.py script is interrupted during the retrieval process.")
    parser.add_argument("--provider", type=str, default="groq", help="Provider name (e.g., groq, openai)")
    parser.add_argument("--model_name", type=str, default="oss-20b", help="Model name")
    parser.add_argument("--prompt_template", type=str, default="alist_few_shot", help="Prompt template")
    parser.add_argument("--reasoning_type", type=str, default="numerical", help="Reasoning type")
    parser.add_argument("--dataset_name", type=str, default="openai/gsm8k", help="Dataset name")
    parser.add_argument("--dataset_subset", type=str, default="main", help="Dataset subset")
    parser.add_argument("--dataset_split", type=str, default="test", help="Dataset split")
    parser.add_argument("--batch_id", type=str, required=True, help="Batch ID")
    parser.add_argument("--wandb_project", type=str, default="prompt-finetuning-groq", help="Weights & Biases project name")
    parser.add_argument("--wandb_run_name", type=str, help="Weights & Biases run name")
    parser.add_argument("--api_key", type=str, required=True, help="API key for the provider")

    args = parser.parse_args()

    # Previously never assigned, so wandb.init raised UnboundLocalError and
    # W&B logging was always silently disabled.
    wandb_project = args.wandb_project
    wandb_run_name = args.wandb_run_name or f"{args.model_name}-{args.dataset_name}-{args.prompt_template}".replace("/", "-")

    setup_logging(".", args.model_name, args.dataset_name)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    logging.info("Starting LLM evaluation script.")

    # NOTE(review): only a Groq client is constructed, so --provider openai
    # changes how responses are parsed but not which API is queried.
    if args.provider != "groq":
        logging.warning(f"Provider '{args.provider}' requested, but a Groq client is used to retrieve the batch; retrieval may fail.")
    client = Groq(
        api_key=args.api_key,
    )

    dataset = load_dataset(args.dataset_name, args.dataset_subset, split=args.dataset_split)
    batch_result = fetch_and_eval_batch(args.provider, args.prompt_template, args.reasoning_type, client, dataset, args.batch_id)
    # fetch_and_eval_batch returns None when there is nothing to score;
    # unpacking it directly would raise TypeError.
    if batch_result is None:
        logging.error("Batch produced no evaluable responses; nothing to report.")
        sys.exit(1)
    results, correct_predictions, total_predictions = batch_result
    print(f"Correct Predictions: {correct_predictions}, Total Predictions: {total_predictions}")
    accuracy = (correct_predictions / total_predictions) * 100 if total_predictions > 0 else 0
    logging.info(f"Evaluation finished. Accuracy: {accuracy:.2f}%")

    final_results = {
        "model": args.model_name,
        "dataset": args.dataset_name,
        "split": args.dataset_split,
        "accuracy": accuracy,
        "total_correct": correct_predictions,
        "total_predictions": total_predictions,
        "evaluation_details": results
    }
    if wandb_project:
        try:
            wandb.init(
                project=wandb_project,
                name=wandb_run_name,
                config={
                    "model_name": args.model_name,
                    "data_args": args.dataset_name,
                    "eval_args": args.prompt_template,
                    "reasoning_type": args.reasoning_type
                }
            )
            logging.info(f"Logging to W&B project: {wandb_project}")
        except Exception as e:
            logging.warning(f"Could not initialize W&B: {e}. Disabling W&B logging.")
            wandb_project = None  # Disable wandb if init fails

    if wandb_project:
        logging.info("Logging results to W&B...")
        # Per-example details go into the Table below, so keep the raw list
        # out of the summary log. Otherwise the same key would be logged twice
        # with different types.
        wandb.log({key: value for key, value in final_results.items() if key != "evaluation_details"})
        results_table = wandb.Table(columns=["prompt", "ground_truth", "predicted_answer", "is_correct", "model_output"])
        for res in results:
            results_table.add_data(
                res["prompt"], res["ground_truth"], res["predicted_answer"], res["is_correct"], res["model_output"]
            )
        wandb.log({"evaluation_details": results_table})
        wandb.finish()

if __name__ == "__main__":
    # Run the batch-retrieval CLI only when executed as a script, not on import.
    main()