import os 
import json
from tqdm import tqdm
import logging 
from transformers import HfArgumentParser
from typing import Optional
from dataclasses import dataclass, field
from datetime import datetime
import sys
import wandb
from extraction import extract_generated_data
from openai import OpenAI
from datasets import load_dataset
import time
import glob
@dataclass
class ModelArguments:
    """Command-line options that select the model and its sampling settings.

    ``main`` passes ``model_name_or_path`` to the generation functions as the
    model identifier and ``api_key`` to the ``OpenAI`` client constructor.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to the pre-trained model."}
    )
    # Forwarded verbatim to ``OpenAI(api_key=...)`` in ``main``; the help text
    # therefore names OpenAI rather than Hugging Face.
    api_key: Optional[str] = field(
        default=None,
        metadata={"help": "OpenAI API key for model access."}
    )
    temperature: Optional[float] = field(
        default=None,
        metadata={"help": "Temperature for sampling. Higher values mean more random outputs."}
    )
    top_p: Optional[float] = field(
        default=None,
        metadata={"help": "Top-p (nucleus) sampling parameter"}
    )

@dataclass
class GenerationArguments:
    """Command-line options describing what to generate and where to store it.

    ``prompt_file`` and ``prompt_id`` are required; every other option has a
    default.
    """

    # Required: the prompt collection and the entry to pick from it.
    prompt_file: str = field(metadata={"help": "Path to the file containing prompts."})
    prompt_id: str = field(metadata={"help": "ID of the prompt to be used."})

    # How much to generate and through which code path.
    num_samples: int = field(default=100, metadata={"help": "Number of samples to generate."})
    batched_generation: bool = field(default=False, metadata={"help": "Whether to use batched generation."})

    # Also used by ``main`` as the directory for the log file.
    output_dir: str = field(default="output", metadata={"help": "Directory to save the generated samples."})

    # Optional dataset source; ``main`` checks ``dataset_name`` first.
    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "Name of the dataset to use for generation."}
    )
    dataset_path: Optional[str] = field(
        default=None, metadata={"help": "Path to the dataset file, if applicable."}
    )


@dataclass
class WandbArguments:
    """Command-line options controlling Weights & Biases logging."""

    # ``main`` skips all W&B calls when this value is falsy.
    wandb_project: Optional[str] = field(
        default="llm-evaluation",
        metadata={"help": "The name of the W&B project to log to. Set to None to disable."},
    )
    # Passed to ``wandb.init`` as the run name.
    wandb_run_name: Optional[str] = field(
        default=None,
        metadata={"help": "A specific name for the W&B run. If None, a random name will be generated."},
    )

# setup logging
def setup_logging(output_dir: str, model_name: str):
    """Configure root logging to a timestamped file in ``output_dir`` and to stdout.

    Args:
        output_dir: Directory for the log file; created if it does not exist.
        model_name: Model identifier used in the log filename. Any ``/`` is
            replaced with ``_`` so the name stays a single path component.
    """
    os.makedirs(output_dir, exist_ok=True)

    safe_model_name = model_name.replace("/", "_")
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_filepath = os.path.join(output_dir, f"eval_{safe_model_name}_{timestamp}.log")

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        handlers=[
            # Explicit UTF-8 so logged model output with non-ASCII text does
            # not depend on the platform's default encoding.
            logging.FileHandler(log_filepath, encoding="utf-8"),
            logging.StreamHandler(sys.stdout),
        ],
        # Without force=True, basicConfig does nothing when the root logger
        # already has handlers, and the log file would never be attached.
        force=True,
    )
    logging.info(f"Logging to {log_filepath}")

def load_prompt_template(filepath: str, prompt_id: str) -> str:
    """Return the template registered under ``prompt_id`` in a JSON prompt file.

    The file must hold an object with a ``"prompts"`` list whose entries carry
    ``"prompt_id"`` and ``"template"`` keys. The first entry whose
    ``"prompt_id"`` equals ``prompt_id`` wins.

    Args:
        filepath: Path to the JSON prompt file.
        prompt_id: Identifier of the prompt to look up.

    Returns:
        The matching entry's ``"template"`` value.

    Raises:
        SystemExit: With code 1, after logging an error, when an argument is
            empty, the file is missing or not valid JSON, the file does not
            have the expected structure, or no entry matches ``prompt_id``.
    """
    if not (filepath and prompt_id):
        logging.error("Prompt file or prompt ID not provided.")
        sys.exit(1)

    try:
        with open(filepath, "r", encoding="utf-8") as f:
            prompts = json.load(f)
    except FileNotFoundError:
        logging.error(f"Prompt file {filepath} not found.")
        sys.exit(1)
    except json.JSONDecodeError as err:
        logging.error(f"Prompt file {filepath} is not valid JSON: {err}")
        sys.exit(1)

    try:
        for prompt in prompts["prompts"]:
            if prompt["prompt_id"] == prompt_id:
                logging.info(f"Found prompt ID {prompt_id} in {filepath}")
                return prompt["template"]
    except (KeyError, TypeError) as err:
        # A file that parses but lacks the expected keys gets the same
        # logged exit as the other failures, not a bare traceback.
        logging.error(f"Prompt file {filepath} has an unexpected structure: {err!r}")
        sys.exit(1)

    logging.error(f"Prompt ID {prompt_id} not found in {filepath}")
    sys.exit(1)

def format_prompt(template: str, natural_language: Optional[str]) -> str:
    """Fill the ``phrase`` placeholder of ``template`` with ``natural_language``.

    When ``natural_language`` is falsy (``None`` or an empty string) the
    template is returned untouched, without any formatting pass.
    """
    if not natural_language:
        return template
    filled = template.format(phrase=natural_language)
    return filled

def single_generation(num_samples: int, model_name_or_path: str, prompt_template: str, prompt_id: str, client: OpenAI, dataset: Optional[object]) -> list:
    """Generate samples one request at a time through ``client.responses.create``.

    Args:
        num_samples: Number of requests to send. When a dataset is given this
            is capped at the number of available rows.
        model_name_or_path: Model identifier sent with every request.
        prompt_template: Template passed to ``format_prompt``.
        prompt_id: Forwarded to ``extract_generated_data``.
        client: OpenAI client used to issue the requests.
        dataset: Optional source of rows. Either an indexable sequence of
            dict-like rows, or a mapping of splits containing ``"train"``
            (the split ``create_batch_prompt_file`` also reads). Falsy means
            the template is used as-is for every sample.

    Returns:
        One dict per sample with the keys ``prompt``, ``model_output``,
        ``phrase`` and ``alist``.
    """
    rows = None
    if dataset:
        # A split mapping cannot be indexed by position, so read "train".
        if isinstance(dataset, dict) and "train" in dataset:
            rows = dataset["train"]
        else:
            rows = dataset
        num_samples = min(num_samples, len(rows))
        logging.info(f"Using {num_samples} samples from the dataset.")

    generations = []
    for i in tqdm(range(num_samples), desc="Generating samples"):
        # NOTE(review): this path reads the "phrase" field while
        # create_batch_prompt_file reads "question" - confirm which is intended.
        natural_language_example = rows[i].get("phrase", None) if rows is not None else None
        prompt = format_prompt(prompt_template, natural_language_example)
        output = client.responses.create(
            model=model_name_or_path,
            reasoning={"effort": "medium"},
            input=[
                {
                    "role": "user",
                    "content": prompt,
                }
            ]
        )
        model_output = output.output_text
        if i == 0:
            logging.info(f"Generated output for first sample: {model_output}")
        phrase, alist = extract_generated_data(prompt_id, model_output, natural_language_example)
        generations.append({
            "prompt": prompt,
            "model_output": model_output,
            "phrase": phrase,
            "alist": alist
        })
    logging.info(f"Generated {len(generations)} samples.")
    return generations

def create_batch_prompt_file(num_samples: int, prompt_template: str, model_name_or_path: str, dataset: Optional[list]) -> str:
    """Write one ``/v1/responses`` batch request per prompt to a JSONL file.

    When ``dataset`` is truthy, prompts are built from the ``"question"``
    field of ``dataset['train']`` rows and ``num_samples`` is capped at the
    split's length; otherwise the bare template is repeated ``num_samples``
    times. The file is created in the current working directory.

    Returns:
        The name of the JSONL file that was written.
    """
    use_dataset = bool(dataset)
    if use_dataset:
        num_samples = min(num_samples, len(dataset['train']))
        logging.info(f"Using {num_samples} samples from the dataset.")

    batch_prompts = [
        format_prompt(
            prompt_template,
            dataset['train'][idx].get("question", None) if use_dataset else None,
        )
        for idx in range(num_samples)
    ]

    # The timestamp makes the filename unique per run (second resolution).
    batch_filename = "batch_prompts_" + datetime.now().strftime("%Y%m%d_%H%M%S") + ".jsonl"
    with open(batch_filename, "w") as out_file:
        # custom_id values are 1-based: request-1, request-2, ...
        for request_number, prompt in enumerate(batch_prompts, start=1):
            request = {
                "custom_id": f"request-{request_number}",
                "method": "POST",
                "url": "/v1/responses",
                "body": {
                    "model": f"{model_name_or_path}",
                    "reasoning": {"effort": "medium"},
                    "input": [{"role": "user", "content": prompt}],
                },
            }
            out_file.write(json.dumps(request) + "\n")

    logging.info(f"Batch prompt file created: {batch_filename}")
    logging.info(f"Number of prompts in batch file: {len(batch_prompts)}")
    return batch_filename

def _request_index(batch_line: dict) -> int:
    """Return the numeric suffix of a ``request-N`` custom_id, or a large value if absent."""
    suffix = str(batch_line.get("custom_id", "")).rpartition("-")[2]
    return int(suffix) if suffix.isdigit() else sys.maxsize


def _extract_output_text(batch_line: dict) -> str:
    """Return the concatenated ``output_text`` parts of one batch result line, or ``""``."""
    # "response" can be null for a failed request, hence the `or {}` guards.
    body = (batch_line.get("response") or {}).get("body") or {}
    texts = []
    for item in body.get("output") or []:
        # Skip non-message items rather than relying on a fixed position.
        if item.get("type") != "message":
            continue
        for part in item.get("content") or []:
            if part.get("type") == "output_text":
                texts.append(part.get("text") or "")
    return "".join(texts)


def batched_generation(num_samples: int, model_name_or_path: str, prompt_template: str, prompt_id: str, client: OpenAI, dataset: Optional[list]) -> list:
    """Generate samples through the OpenAI Batch API and parse the results.

    Builds the request file with ``create_batch_prompt_file``, uploads it,
    creates a batch against ``/v1/responses``, polls every five minutes until
    the batch reaches a terminal status, then downloads and parses the output.

    Args:
        num_samples: Number of requests to put in the batch.
        model_name_or_path: Model identifier written into every request.
        prompt_template: Template used to build the prompts; also stored as
            the ``prompt`` value of every returned record.
        prompt_id: Forwarded to ``extract_generated_data``.
        client: OpenAI client used for the upload, the batch and the download.
        dataset: Optional dataset forwarded to ``create_batch_prompt_file``.

    Returns:
        A list of dicts with the keys ``prompt``, ``model_output``, ``phrase``
        and ``alist``, ordered by request number; ``None`` when the batch
        produced no output.

    Raises:
        SystemExit: With code 1 when the batch ends as failed, expired or
            cancelled.
    """
    logging.info(f"Starting batched generation with {num_samples} samples.")
    logging.info("Creating batch prompt file...")
    batch_filename = create_batch_prompt_file(num_samples, prompt_template, model_name_or_path, dataset)
    logging.info(f"Batch prompt file created: {batch_filename}")
    logging.info("Uploading batch prompt file to OpenAI...")
    # The context manager closes the handle once the upload call returns.
    with open(batch_filename, "rb") as batch_file:
        batch_input_file = client.files.create(
            file=batch_file,
            purpose="batch"
        )
    logging.info("Batch prompt file uploaded successfully.")
    batch_input_file_id = batch_input_file.id
    logging.info(f"Batch input file created with ID: {batch_input_file_id}")
    logging.info("Creating batch for generation...")
    batch = client.batches.create(
        input_file_id=batch_input_file_id,
        endpoint="/v1/responses",
        completion_window="24h",
    )
    batch_id = batch.id
    logging.info(f"Batch created with ID: {batch_id}")

    # Statuses after which the batch will never complete; polling on them
    # would loop forever.
    terminal_failures = ("failed", "expired", "cancelled")
    # every 5 minutes check the status of the batch
    while True:
        batch_status = client.batches.retrieve(batch_id)
        status = batch_status.status
        if status == "completed":
            logging.info("Batch generation completed successfully.")
            logging.info(f"Output file ID: {batch_status.output_file_id}")
            break
        if status in terminal_failures:
            logging.error(f"Batch generation did not complete (status: {status}).")
            sys.exit(1)
        logging.info(f"Batch status: {status}. Waiting for 5 minutes before checking again.")
        time.sleep(300)

    if not batch_status.output_file_id:
        # A completed batch may carry no output file when no request succeeded.
        logging.error("Batch completed without an output file.")
        return None

    logging.info("Retrieving batch output...")
    response_batch = client.files.content(batch_status.output_file_id).text
    response_filename = f"batch_output_{batch_id}.json"
    response_data = [json.loads(line) for line in response_batch.splitlines() if line.strip()]
    if not response_data:
        logging.warning("No responses found in the batch output.")
        return None

    # save response data to a file
    logging.info(f"First response in batch: {response_data[0]}")
    with open(response_filename, "w") as f:
        json.dump(response_data, f, indent=4)
    logging.info(f"Response data saved to {response_filename}")

    # create a list of generations from the response data
    logging.info("Processing batch responses...")
    generations = []
    # Output lines are not guaranteed to follow input order, so sort them by
    # the request number embedded in custom_id.
    for gen in sorted(response_data, key=_request_index):
        model_output = _extract_output_text(gen)
        phrase, alist = extract_generated_data(prompt_id, model_output)
        generations.append({
            "prompt": prompt_template,
            "model_output": model_output,
            "phrase": phrase,
            "alist": alist
        })
    logging.info(f"Processed {len(generations)} generations from batch output.")
    return generations

def main():
    """Parse the CLI arguments, run generation and persist the results.

    Steps: set up logging, optionally start a W&B run, optionally load a
    dataset (by name, or from the ``*.json`` files in a directory), load the
    prompt template, generate samples in single or batched mode, write them
    to a CSV file in the output directory and, when W&B is active, log them
    as a table.

    Exits with status 1 when no dataset files are found, the prompt template
    cannot be loaded, or generation returns nothing.
    """
    import csv  # local import: only needed here, for writing the results file

    # --- Argument Parsing ---
    parser = HfArgumentParser((ModelArguments, GenerationArguments, WandbArguments))
    model_args, gen_args, wandb_args = parser.parse_args_into_dataclasses()

    # --- Setup Logging ---
    setup_logging(gen_args.output_dir, model_args.model_name_or_path)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    logging.info("Starting LLM evaluation script.")
    logging.info(f"W&B Project: {wandb_args.wandb_project}, Run Name: {wandb_args.wandb_run_name}")

    # --- Initialize W&B ---
    if wandb_args.wandb_project:
        try:
            wandb.init(
                project=wandb_args.wandb_project,
                name=wandb_args.wandb_run_name,
            )
            logging.info(f"Logging to W&B project: {wandb_args.wandb_project}")
        except Exception as e:
            logging.warning(f"Could not initialize W&B: {e}. Disabling W&B logging.")
            wandb_args.wandb_project = None # Disable wandb if init fails

    # --- Load Dataset ---
    # Start from None so the generation calls below work when neither
    # dataset option is supplied.
    dataset = None
    if gen_args.dataset_name:
        dataset = load_dataset(gen_args.dataset_name)
    elif gen_args.dataset_path:
        dataset_files = glob.glob(os.path.join(gen_args.dataset_path, "*.json"))
        logging.info(f"Found {len(dataset_files)} JSON files in {gen_args.dataset_path}.")
        if not dataset_files:
            logging.error(f"No JSON files found in {gen_args.dataset_path}.")
            if wandb_args.wandb_project: wandb.finish(exit_code=1)
            sys.exit(1)
        dataset = load_dataset("json", data_files=dataset_files)
        logging.info(f"Loaded dataset of length {len(dataset['train'])}.")
        logging.info(f"Dataset sample: {dataset['train'][0] if 'train' in dataset else dataset[0]}")

    # --- Generate Dataset ---
    client = OpenAI(api_key=model_args.api_key)
    prompt_template = load_prompt_template(gen_args.prompt_file, gen_args.prompt_id)
    if not prompt_template:
        logging.error("Prompt template could not be loaded.")
        if wandb_args.wandb_project: wandb.finish(exit_code=1)
        sys.exit(1)
    logging.info("Starting dataset generation.")

    if gen_args.batched_generation:
        generations = batched_generation(
            gen_args.num_samples,
            model_args.model_name_or_path,
            prompt_template,
            gen_args.prompt_id,
            client,
            dataset
        )
    else:
        # single_generation has no sampling parameters, so passing
        # temperature/top_p positionally raised a TypeError. Tell the user
        # the values are dropped instead.
        if model_args.temperature is not None or model_args.top_p is not None:
            logging.warning("temperature/top_p were provided but are not used by single generation; ignoring them.")
        generations = single_generation(
            gen_args.num_samples,
            model_args.model_name_or_path,
            prompt_template,
            gen_args.prompt_id,
            client,
            dataset
        )

    if generations is None:
        logging.error("No generations returned from batched generation.")
        if wandb_args.wandb_project: wandb.finish(exit_code=1)
        sys.exit(1)

    safe_model_name = model_args.model_name_or_path.replace("/", "_")
    results_filename = f"{gen_args.prompt_id}_{safe_model_name}_generated_samples.csv"
    results_filepath = os.path.join(gen_args.output_dir, results_filename)
    columns = ("prompt", "model_output", "phrase", "alist")
    with open(results_filepath, "w", newline="", encoding="utf-8") as f:
        # The csv writer escapes embedded quotes and keeps newlines inside
        # quoted fields; hand-built lines corrupted the file on such text.
        writer = csv.writer(f, quoting=csv.QUOTE_ALL, lineterminator="\n")
        f.write("prompt,model_output,phrase,alist\n")
        for gen in generations:
            # str() keeps the earlier rendering of non-string values
            # (e.g. None is written as "None", lists via their repr).
            writer.writerow([str(gen[column]) for column in columns])
    logging.info(f"Generated samples saved to {results_filepath}")

    if wandb_args.wandb_project:
        logging.info("Logging generated samples to W&B.")
        generation_table = wandb.Table(columns=["prompt", "model_output", "phrase", "alist"])
        for gen in generations:
            generation_table.add_data(gen["prompt"], gen["model_output"], gen["phrase"], gen["alist"])
        wandb.log({"generated_samples": generation_table})
        wandb.finish()

# Run the generation pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
 