#!/usr/bin/env python3
"""
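Submit collected prompts to the OpenAI Batch API.
By default, reads prompts from llm-inference/prompts/prompts_xml_SU234_M3.json
(override with --prompts_file) and submits them as one or more batch jobs,
skipping prompts whose response files already exist.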
"""

import argparse
import json
import os
from datetime import datetime
from typing import Any, Dict, List, Optional

from api_models.openai.query_openai import solve_graphs_batch


def load_collected_prompts(prompts_file: str) -> List[Dict[str, Any]]:
    """Read the prompt records stored in *prompts_file*.

    The file is decoded as UTF-8 JSON and its parsed contents are returned
    unchanged.
    """
    with open(prompts_file, encoding="utf-8") as handle:
        records = json.load(handle)
    return records


def generate_response_path(prompt_data: Dict[str, Any], model: str) -> str:
    """Build the response file path for a prompt and ensure its directory exists.

    The path is ``datasets/<benchmark>/<graph_type>/responses/<filename>``.
    The filename is ``<encoding>-<pattern>-<system_prompt>-<model>.txt`` for
    ``full_output`` questions. Otherwise it is
    ``<encoding>-<pattern>-<system_prompt>-<question_type>-<target>-<model>.txt``.

    Parameters:
    - prompt_data: Prompt record whose ``metadata`` dict holds ``benchmark``,
      ``graph_type``, ``encoding``, ``pattern``, ``system_prompt`` and
      ``question_type``. It must also hold ``target`` unless
      ``question_type == "full_output"``.
    - model: Model name embedded in the filename

    Returns:
    - Path of the response file. The file itself is not created.

    Side effects:
    - Creates the ``responses`` directory if it does not exist yet.

    Raises:
    - KeyError: if a required metadata key is missing. This is checked before
      any directory is created.
    """
    metadata = prompt_data["metadata"]

    benchmark = metadata["benchmark"]
    graph_type = metadata["graph_type"]
    encoding = metadata["encoding"]
    pattern = metadata["pattern"]
    system_prompt = metadata["system_prompt"]
    question_type = metadata["question_type"]

    # Filename follows the existing convention. ``target`` only appears in
    # (and is therefore only required for) non-full-output questions.
    if question_type == "full_output":
        filename = f"{encoding}-{pattern}-{system_prompt}-{model}.txt"
    else:
        target = metadata["target"]
        filename = (
            f"{encoding}-{pattern}-{system_prompt}-{question_type}-{target}-{model}.txt"
        )

    # Create the response directory only once all required keys are known
    # to exist.
    response_dir = f"datasets/{benchmark}/{graph_type}/responses"
    os.makedirs(response_dir, exist_ok=True)

    return os.path.join(response_dir, filename)


def generate_unique_batch_name(base_name: str, include_timestamp: bool = True) -> str:
    """
    Return a batch name, optionally suffixed with the current local time.

    Parameters:
    - base_name: Stem of the batch name
    - include_timestamp: When True (the default), append "_YYYYMMDD_HHMMSS".
      Pass False to get base_name back unchanged, which is useful in tests.

    Returns:
    - base_name, or "<base_name>_<YYYYMMDD_HHMMSS>" when include_timestamp is set
    """
    if not include_timestamp:
        return base_name
    # Zero-padded, most-significant-first fields keep the names sortable as text.
    suffix = datetime.now().strftime("%Y%m%d_%H%M%S")
    return f"{base_name}_{suffix}"


def save_batch_tracking_info(
    batch_name: str,
    batch_ids: List[str],
    prompts_file: str,
    model: str,
    total_prompts: int,
    skipped: int,
    tracking_file: str = "batch_jobs/batch_tracking.json",
) -> None:
    """
    Append a record for a submitted batch to the master tracking file.

    The tracking file is a JSON list with one entry per submission. Each entry
    holds the batch name, batch IDs, ISO timestamp, prompts file, model,
    prompt counts and status "submitted".

    If the existing file cannot be parsed, or does not hold a JSON list, it is
    moved aside to "<tracking_file>.corrupt-<timestamp>" instead of being
    silently overwritten, so earlier records stay recoverable. The updated
    list is written to a temporary sibling file and swapped in with
    os.replace, so an interrupted write cannot truncate the history.

    Parameters:
    - batch_name: Name the batch(es) were submitted under
    - batch_ids: IDs returned for each submitted part
    - prompts_file: Source prompts JSON file
    - model: Model name (without the -batch-api suffix)
    - total_prompts: Number of prompts submitted
    - skipped: Number of prompts skipped because a response already existed
    - tracking_file: Path of the tracking JSON. The default keeps the
      historic location.
    """
    tracking_dir = os.path.dirname(tracking_file)
    if tracking_dir:
        os.makedirs(tracking_dir, exist_ok=True)

    # Load existing tracking data if it exists
    tracking_data: List[Dict[str, Any]] = []
    if os.path.exists(tracking_file):
        loaded: Any = None
        try:
            with open(tracking_file, "r", encoding="utf-8") as f:
                loaded = json.load(f)
        except (ValueError, OSError) as err:
            # ValueError covers both JSONDecodeError and UnicodeDecodeError.
            print(f"⚠️ Could not parse {tracking_file}: {err}")

        if isinstance(loaded, list):
            tracking_data = loaded
        else:
            # Preserve the old history rather than discarding it.
            stamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
            backup = f"{tracking_file}.corrupt-{stamp}"
            os.replace(tracking_file, backup)
            print(f"⚠️ Unreadable tracking file preserved as {backup}")

    tracking_data.append(
        {
            "batch_name": batch_name,
            "batch_ids": batch_ids,
            "timestamp": datetime.now().isoformat(),
            "prompts_file": prompts_file,
            "model": model,
            "total_prompts": total_prompts,
            "skipped_existing": skipped,
            "status": "submitted",
        }
    )

    # Write-then-rename so readers never see a half-written file.
    tmp_file = f"{tracking_file}.tmp"
    with open(tmp_file, "w", encoding="utf-8") as f:
        json.dump(tracking_data, f, indent=2)
    os.replace(tmp_file, tracking_file)

    print(f"📝 Batch tracking info saved to {tracking_file}")


def submit_prompts_batch(
    prompts_file: str = "llm-inference/prompts/prompts_xml_SU234_M3.json",
    model: str = "gpt-4.1-nano",
    batch_name: Optional[str] = None,
    max_batch_size: int = 50000,
    skip_existing: bool = True,
    dry_run: bool = False,
    add_timestamp: bool = True,
) -> List[str]:
    """
    Submit collected prompts to OpenAI Batch API.

    Prompts are split into consecutive chunks of at most ``max_batch_size``.
    Each chunk is submitted as a separate batch via ``solve_graphs_batch``,
    and a summary with status/retrieval commands is appended to
    ``batch_jobs/<name>_summary.txt``. Finally the submission is recorded
    with ``save_batch_tracking_info``.

    Parameters:
    - prompts_file: Path to the collected prompts JSON file
    - model: OpenAI model to use (without -batch-api suffix)
    - batch_name: Custom batch name (auto-generated if None)
    - max_batch_size: Maximum number of prompts per batch (must be positive)
    - skip_existing: Skip prompts that already have responses
    - dry_run: Show what would be submitted without actually submitting
    - add_timestamp: Add timestamp to batch name for uniqueness

    Returns:
    - List of batch IDs (empty for a dry run or when nothing needs submitting)

    Raises:
    - ValueError: if max_batch_size is not positive. This is checked before
      any I/O.
    - Any exception raised by ``solve_graphs_batch``. Batches submitted before
      the failure are still written to the tracking file first.
    """
    # A zero step makes range() raise, and a negative one silently submits
    # nothing, so reject both up front.
    if max_batch_size <= 0:
        raise ValueError(
            f"max_batch_size must be a positive integer, got {max_batch_size}"
        )

    print(f"📂 Loading prompts from {prompts_file}...")
    prompts_data = load_collected_prompts(prompts_file)
    print(f"✅ Loaded {len(prompts_data)} prompts")

    prompts_to_submit, output_paths, skipped_count = _prepare_submission(
        prompts_data, model, skip_existing
    )

    print(f"📊 Prompts to submit: {len(prompts_to_submit)}")
    print(f"⏩ Skipped existing: {skipped_count}")

    if not prompts_to_submit:
        print("⚠️ No prompts to submit!")
        return []

    # Default base name: the prompts filename without directory or extension.
    if not batch_name:
        batch_name = os.path.splitext(os.path.basename(prompts_file))[0]

    unique_batch_name = generate_unique_batch_name(batch_name, add_timestamp)
    print(f"🏷️ Using batch name: {unique_batch_name}")

    total = len(prompts_to_submit)
    # (part number starting at 1, start offset) for each chunk.
    chunks = list(enumerate(range(0, total, max_batch_size), 1))

    if dry_run:
        print("🔍 DRY RUN - Would submit the following batches:")
        for part, start in chunks:
            print(f"  Batch {part}: {min(max_batch_size, total - start)} prompts")
        return []

    batch_ids: List[str] = []
    try:
        for part, start in chunks:
            end = start + max_batch_size
            batch_prompts = prompts_to_submit[start:end]
            batch_outputs = output_paths[start:end]

            # Only add the "_part_N" suffix when the submission is actually split.
            if total > max_batch_size:
                current_batch_name = f"{unique_batch_name}_part_{part}"
            else:
                current_batch_name = unique_batch_name

            print(
                f"🚀 Submitting batch {part} with {len(batch_prompts)} prompts..."
            )

            batch_id = solve_graphs_batch(
                prompts=batch_prompts,
                output_paths=batch_outputs,
                model=f"{model}-batch-api",
                batch_name=current_batch_name,
            )

            batch_ids.append(batch_id)
            print(f"✅ Batch submitted: {batch_id}")

            _write_batch_summary(
                current_batch_name, batch_id, prompts_file, model, len(batch_prompts)
            )
    finally:
        # Record whatever was actually submitted, even if a later chunk failed,
        # so already-created batch jobs are not missing from the tracking file.
        if batch_ids:
            if len(batch_ids) < len(chunks):
                print(
                    f"⚠️ Only {len(batch_ids)} of {len(chunks)} batches were "
                    "submitted; recording them before propagating the error."
                )
            # Prompts actually submitted. This equals `total` when every
            # chunk succeeded.
            submitted = min(total, len(batch_ids) * max_batch_size)
            save_batch_tracking_info(
                batch_name=unique_batch_name,
                batch_ids=batch_ids,
                prompts_file=prompts_file,
                model=model,
                total_prompts=submitted,
                skipped=skipped_count,
            )

    return batch_ids


def _prepare_submission(
    prompts_data: List[Dict[str, Any]], model: str, skip_existing: bool
) -> tuple:
    """Return (prompt texts, response paths, skipped count) for the prompts to submit.

    Each prompt text is the record's top-level ``system_prompt`` (if
    non-empty), then a blank line, then its ``text``.
    """
    prompts_to_submit: List[str] = []
    output_paths: List[str] = []
    skipped_count = 0

    for prompt_data in prompts_data:
        response_path = generate_response_path(prompt_data, model)

        # Responses already on disk do not need to be requested again.
        if skip_existing and os.path.exists(response_path):
            skipped_count += 1
            continue

        system_prompt = prompt_data.get("system_prompt", "")
        text = prompt_data.get("text", "")
        full_prompt = f"{system_prompt}\n\n{text}" if system_prompt else text

        prompts_to_submit.append(full_prompt)
        output_paths.append(response_path)

    return prompts_to_submit, output_paths, skipped_count


def _write_batch_summary(
    batch_name: str, batch_id: str, prompts_file: str, model: str, num_prompts: int
) -> None:
    """Append status/retrieval instructions for one submitted batch to its summary file."""
    os.makedirs("batch_jobs", exist_ok=True)
    summary_path = f"batch_jobs/{batch_name}_summary.txt"
    # Append mode: reusing a batch name (e.g. with --no_timestamp) keeps older entries.
    with open(summary_path, "a", encoding="utf-8") as f:
        f.write(f"Batch submitted at: {datetime.now().isoformat()}\n")
        f.write(f"Prompts file: {prompts_file}\n")
        f.write(f"Model: {model}\n")
        f.write(f"Batch ID: {batch_id}\n")
        f.write(f"Total prompts in this batch: {num_prompts}\n")
        f.write(
            f"Check status: python -m scripts.batch_run_tasks --check {batch_id}\n"
        )
        f.write(
            f"Retrieve results: python -m scripts.batch_run_tasks --retrieve {batch_id}\n\n"
        )


def list_recent_batches(
    limit: int = 10, tracking_file: str = "batch_jobs/batch_tracking.json"
) -> None:
    """
    List recent batch submissions from the tracking file, newest first.

    Parameters:
    - limit: Maximum number of recent batches to display (must be positive)
    - tracking_file: Tracking JSON written by save_batch_tracking_info

    Entries that are not JSON objects are ignored, and missing fields print
    as "?".
    """
    if not os.path.exists(tracking_file):
        print("⚠️ No batch tracking file found. No batches have been submitted yet.")
        return

    try:
        with open(tracking_file, "r", encoding="utf-8") as f:
            tracking_data = json.load(f)
    except (ValueError, OSError) as e:
        # ValueError covers both JSONDecodeError and UnicodeDecodeError.
        print(f"❌ Error reading tracking file: {e}")
        return

    if not isinstance(tracking_data, list):
        print(
            "❌ Error reading tracking file: expected a JSON list, "
            f"got {type(tracking_data).__name__}"
        )
        return

    entries = [entry for entry in tracking_data if isinstance(entry, dict)]
    if not entries:
        print("⚠️ No batches found in tracking file.")
        return

    if limit <= 0:
        # A negative slice bound would drop the *last* entries instead of limiting.
        print(f"⚠️ Nothing to show: limit must be positive (got {limit}).")
        return

    # Timestamps are stored as isoformat() strings, which sort chronologically as text.
    entries.sort(key=lambda entry: str(entry.get("timestamp", "")), reverse=True)
    shown = entries[:limit]

    print(f"\n📋 Recent Batch Submissions (showing last {len(shown)}):")
    print("=" * 80)

    for i, batch in enumerate(shown, 1):
        batch_ids = batch.get("batch_ids") or []
        print(f"\n{i}. Batch: {batch.get('batch_name', '?')}")
        print(f"   Submitted: {batch.get('timestamp', '?')}")
        print(f"   Model: {batch.get('model', '?')}")
        print(
            f"   Prompts: {batch.get('total_prompts', '?')} "
            f"(skipped {batch.get('skipped_existing', '?')})"
        )
        print("   Batch IDs:")
        for batch_id in batch_ids:
            print(f"     - {batch_id}")
        print("   Commands:")
        for batch_id in batch_ids:
            print(f"     python -m scripts.batch_run_tasks --check {batch_id}")


def main():
    """Command-line entry point: parse arguments, then submit or list batches.

    Exits with status 2 (via argparse) for invalid numeric options, and with
    status 1 when the prompts file does not exist.
    """
    parser = argparse.ArgumentParser(
        description="Submit collected prompts to OpenAI Batch API"
    )

    parser.add_argument(
        "--prompts_file",
        default="llm-inference/prompts/prompts_xml_SU234_M3.json",
        help="Path to collected prompts JSON file",
    )
    parser.add_argument(
        "--model",
        default="gpt-4.1-nano",
        help="OpenAI model to use (default: gpt-4.1-nano)",
    )
    parser.add_argument(
        "--batch_name",
        help="Custom batch name (auto-generated with timestamp if not specified)",
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        default=50000,
        help="Maximum prompts per batch (default: 50000)",
    )
    parser.add_argument(
        "--no_skip_existing",
        action="store_true",
        help="Don't skip prompts that already have responses",
    )
    parser.add_argument(
        "--dry_run",
        action="store_true",
        help="Show what would be submitted without actually submitting",
    )
    parser.add_argument(
        "--no_timestamp",
        action="store_true",
        help="Don't add timestamp to batch name (not recommended for production)",
    )
    parser.add_argument(
        "--list_recent",
        action="store_true",
        help="List recent batch submissions and exit",
    )
    parser.add_argument(
        "--list_limit",
        type=int,
        default=10,
        help="Number of recent batches to show (default: 10)",
    )

    args = parser.parse_args()

    # Reject values that would make chunking or listing meaningless.
    # parser.error prints usage to stderr and exits with status 2.
    if args.max_batch_size <= 0:
        parser.error("--max_batch_size must be a positive integer")
    if args.list_limit <= 0:
        parser.error("--list_limit must be a positive integer")

    # Handle list_recent command
    if args.list_recent:
        list_recent_batches(args.list_limit)
        return

    # A missing input is a failure: SystemExit with a message prints it to
    # stderr and exits with status 1, so shells and CI can detect it.
    if not os.path.exists(args.prompts_file):
        raise SystemExit(f"❌ Prompts file not found: {args.prompts_file}")

    # Submit batches
    batch_ids = submit_prompts_batch(
        prompts_file=args.prompts_file,
        model=args.model,
        batch_name=args.batch_name,
        max_batch_size=args.max_batch_size,
        skip_existing=not args.no_skip_existing,
        dry_run=args.dry_run,
        add_timestamp=not args.no_timestamp,
    )

    if batch_ids:
        print(f"\n🎉 Successfully submitted {len(batch_ids)} batch(es)!")
        print("📋 Batch IDs:")
        for i, batch_id in enumerate(batch_ids, 1):
            print(f"  {i}. {batch_id}")

        print("\n💡 Next steps:")
        print("1. Check batch status:")
        for batch_id in batch_ids:
            print(f"   python -m scripts.batch_run_tasks --check {batch_id}")
        print("2. Retrieve results when complete:")
        for batch_id in batch_ids:
            print(f"   python -m scripts.batch_run_tasks --retrieve {batch_id}")
        print("\n3. View recent batches:")
        print("   python -m scripts.submit_collected_prompts_batch --list_recent")


if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not when imported.
    main()
