import argparse
import json
import re
import sys
from typing import Optional

import wandb
from openai import OpenAI

from extraction import extract_generated_data

def extract_question(content: Optional[str], pattern: str = r"Generate an alist from the following natural language:(.*?)\.Provide your answer") -> Optional[str]:
    """Return the text captured by the first group of *pattern* in *content*.

    The pattern is searched with ``re.DOTALL`` so the captured text may span
    several lines. Surrounding whitespace is stripped from the capture.

    Args:
        content: Prompt text to search. ``None`` or any non-string value is
            treated as "no match" instead of raising ``TypeError``.
        pattern: Regular expression with at least one capture group.

    Returns:
        The stripped contents of capture group 1, or ``None`` when *content*
        is not a string, the pattern does not match, or group 1 did not
        participate in the match.
    """
    if not isinstance(content, str):
        return None

    match = re.search(pattern, content, re.DOTALL)
    if match is None:
        return None

    # An optional group that did not participate yields None, not "".
    captured = match.group(1)
    return captured.strip() if captured is not None else None

def _prompt_content(request):
    """Return the prompt text stored in one batch request record, or None.

    Mirrors the original lookup ``body -> input[0] -> content`` but tolerates
    missing keys, an empty ``input`` list, and non-dict values.

    NOTE(review): a plain-string ``input`` and a list-of-parts ``content``
    (dicts carrying a ``text`` field) are also accepted; this assumes the
    request bodies follow the OpenAI Responses API input shapes -- confirm
    against the script that generated the batch file.
    """
    if not isinstance(request, dict):
        return None
    body = request.get('body') or {}
    if not isinstance(body, dict):
        return None

    prompt_input = body.get('input')
    if isinstance(prompt_input, str):
        return prompt_input
    if not isinstance(prompt_input, list) or not prompt_input:
        return None

    first_item = prompt_input[0]
    if not isinstance(first_item, dict):
        return None

    content = first_item.get('content')
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        texts = [
            part['text']
            for part in content
            if isinstance(part, dict) and isinstance(part.get('text'), str)
        ]
        return "".join(texts) or None
    return None


def _first_part_text(item) -> str:
    """Return ``item['content'][0]['text']`` or '' when any level is missing."""
    if not isinstance(item, dict):
        return ''
    parts = item.get('content')
    if not isinstance(parts, list) or not parts:
        return ''
    first_part = parts[0]
    if not isinstance(first_part, dict):
        return ''
    return first_part.get('text') or ''


def _extract_model_output(record) -> str:
    """Return the generated text of one batch output record, or ''.

    The original lookup (``response.body.output[1].content[0].text``) is tried
    first so results are unchanged wherever it used to succeed. If that yields
    nothing, the first output item whose ``type`` is ``"message"`` is used, so
    a response with fewer than two output items no longer raises IndexError.
    A null ``response`` (or any missing level) yields ''.
    """
    if not isinstance(record, dict):
        return ''
    response = record.get('response') or {}
    body = response.get('body') or {} if isinstance(response, dict) else {}
    output = body.get('output') if isinstance(body, dict) else None
    if not isinstance(output, list):
        return ''

    if len(output) > 1:
        text = _first_part_text(output[1])
        if text:
            return text

    for item in output:
        if isinstance(item, dict) and item.get('type') == 'message':
            text = _first_part_text(item)
            if text:
                return text
    return ''


def create_prompt_dict(file_path: str, pattern: Optional[str] = None):
    """Map each request's ``custom_id`` to the phrase extracted from its prompt.

    The file is read as JSON Lines (one request object per line). Blank lines,
    lines that are not valid JSON, and records without a usable id or phrase
    are reported with a printed message and skipped; they do not stop the
    remaining lines from being read.

    Args:
        file_path: Path to the batch request file.
        pattern: Optional regular expression passed to ``extract_question``.
            When ``None``, ``extract_question`` is called with its own default.

    Returns:
        Dict of ``custom_id -> phrase``. Empty when the file does not exist
        (an error is printed) or no line yields a phrase.
    """
    prompt_data = {}
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line_number, line in enumerate(f, start=1):
                if not line.strip():
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError as e:
                    # Skip only the offending line instead of dropping the rest of the file.
                    print(f"Error: Failed to decode JSON from line {line_number} in the file: {e}")
                    continue

                request_id = data.get('custom_id') if isinstance(data, dict) else None
                content = _prompt_content(data)
                if content is None:
                    phrase = None
                elif pattern is None:
                    phrase = extract_question(content)
                else:
                    phrase = extract_question(content, pattern)

                if request_id and phrase:
                    prompt_data[request_id] = phrase
                else:
                    print(f"Warning: Missing request_id or phrase in line: {line.strip()}")
    except FileNotFoundError:
        print(f"Error: The file '{file_path}' was not found.")
    return prompt_data


def main():
    """Fetch a finished batch output file, extract alists, and log them to WandB.

    Steps:
      1. Parse CLI arguments and validate ``--prompt_pattern``.
      2. Build the ``custom_id -> phrase`` map from the prompt file.
      3. Download the batch output file and parse it as JSON Lines.
      4. Optionally save the parsed responses to ``--output_file``.
      5. Run ``extract_generated_data`` on every response that has a known
         prompt and log the results as a WandB table.

    Exits with status 1 when no prompts could be loaded or the batch output is
    not valid JSON Lines, and with status 2 (argparse) for an unusable
    ``--prompt_pattern``. The WandB run is always finished, with a non-zero
    exit code if processing did not complete.
    """
    parser = argparse.ArgumentParser(description="Process batch responses and log to WandB. To be used if the batch generation script is disrupted during the retrieval process.")
    parser.add_argument('--prompt_file_path', type=str, required=True, help='Path to the file containing prompts.')
    parser.add_argument('--api_key', type=str, required=True, help='OpenAI API key.')
    # NOTE: --batch_id is accepted (and required) but not read anywhere in this script.
    parser.add_argument('--batch_id', type=str, required=True, help='ID of the batch to process.')
    parser.add_argument('--batch_file_id', type=str, required=True, help='ID of the batch file to retrieve responses from.')
    parser.add_argument('--save_responses', action='store_true', help='Flag to save responses to a file.')
    parser.add_argument('--output_file', type=str, default='batch_output.json', help='File to save the batch output responses.')
    parser.add_argument('--prompt_pattern', type=str, default=r"Generate an alist from the following natural language:(.*?)\.Provide your answer", help='Pattern to extract questions from prompts.')
    parser.add_argument('--wandb_project', type=str, default='dataset_generation', help='WandB project name.')
    parser.add_argument('--wandb_run_name', type=str, default='deepmind-mathematics-alist-from-nl', help='WandB run name.')
    args = parser.parse_args()

    # Reject an unusable pattern up front: extraction reads capture group 1.
    try:
        compiled_pattern = re.compile(args.prompt_pattern)
    except re.error as e:
        parser.error(f"--prompt_pattern is not a valid regular expression: {e}")
    if compiled_pattern.groups < 1:
        parser.error("--prompt_pattern must contain at least one capture group.")

    # get original prompts (the user-supplied pattern is now actually applied)
    prompts = create_prompt_dict(args.prompt_file_path, args.prompt_pattern)
    if not prompts:
        # Without prompts no response can be paired with its phrase.
        print(f"Error: No prompts could be loaded from '{args.prompt_file_path}'.")
        sys.exit(1)

    # initialize WandB
    wandb.init(
        project=args.wandb_project,
        name=args.wandb_run_name,
    )
    exit_code = 1  # flipped to 0 only when processing completes
    try:
        # get responses
        client = OpenAI(api_key=args.api_key)
        response = client.files.content(file_id=args.batch_file_id).text
        response_filename = args.output_file
        try:
            response_data = [json.loads(line) for line in response.splitlines() if line.strip()]
        except json.JSONDecodeError as e:
            print(f"Error decoding JSON from response: {e}")
            sys.exit(1)

        if response_data:
            if args.save_responses:
                print(f"Saving batch responses to {response_filename}...")
                try:
                    with open(response_filename, 'w', encoding='utf-8') as f:
                        json.dump(response_data, f, indent=2)
                    print(f"Batch responses saved to {response_filename}.")
                except OSError as e:
                    # Saving is best-effort; processing continues without the file.
                    print(f"Error saving batch responses: {e}")

            print("Processing batch responses...")
            generations = []
            skipped = 0
            for gen in response_data:
                # retrieve the phrase in the prompt
                request_id = gen.get("custom_id") if isinstance(gen, dict) else None
                phrase = prompts.get(request_id)
                if phrase is None:
                    print(f"Warning: No prompt found for custom_id {request_id!r}; skipping response.")
                    skipped += 1
                    continue
                # extract the alist from the response
                model_output = _extract_model_output(gen)
                phrase, alist = extract_generated_data(prompt_id="alist_from_nl", model_output=model_output, natural_language=phrase)
                generations.append({
                    "model_output": model_output,
                    "phrase": phrase,
                    "alist": alist
                })
            print(f"Processed {len(generations)} generations from batch output.")
            if skipped:
                print(f"Skipped {skipped} responses without a matching prompt.")

            # log to WandB
            print("Logging generated samples to WandB...")
            generation_table = wandb.Table(columns=["model_output", "phrase", "alist"])
            for gen in generations:
                generation_table.add_data(gen["model_output"], gen["phrase"], gen["alist"])
            wandb.log({"generated_samples": generation_table})
        exit_code = 0
    finally:
        # Always close the run, including on errors and sys.exit().
        wandb.finish(exit_code=exit_code)

# Run the command-line entry point only when this file is executed as a
# script; importing the module does not call main().
if __name__ == "__main__":
    main()
