#!/usr/bin/env python3
"""
Submit collected prompts to DeepInfra API.
Processes prompts sequentially with rate limiting, or in parallel with a worker pool.
"""

import argparse
import json
import os
from datetime import datetime
from typing import Any, Dict, List, Optional

from api_models.deepinfra.query_deepinfra import solve_graphs_batch


def load_collected_prompts(prompts_file: str) -> List[Dict[str, Any]]:
    """Read the prompt records stored in *prompts_file*.

    The file is decoded as UTF-8 and parsed as JSON; the parsed value is
    returned unchanged.
    """
    with open(prompts_file, encoding="utf-8") as handle:
        records = json.load(handle)
    return records


def generate_response_path(prompt_data: Dict[str, Any], model: str) -> str:
    """Build the response file path for one prompt record.

    All seven metadata fields are read before touching the filesystem, so a
    record missing any of them raises ``KeyError`` without creating a
    directory. The ``datasets/<benchmark>/<graph_type>/responses`` directory
    is then created if it does not exist yet.

    File names follow the existing convention
    ``<encoding>-<pattern>-<system_prompt>[-<question_type>-<target>]-<model>.txt``,
    where the bracketed part is omitted for ``full_output`` questions.
    """
    meta = prompt_data["metadata"]
    field_names = (
        "benchmark",
        "graph_type",
        "encoding",
        "pattern",
        "system_prompt",
        "question_type",
        "target",
    )
    (
        benchmark,
        graph_type,
        encoding,
        pattern,
        system_prompt,
        question_type,
        target,
    ) = (meta[name] for name in field_names)

    response_dir = f"datasets/{benchmark}/{graph_type}/responses"
    os.makedirs(response_dir, exist_ok=True)

    # Only non-"full_output" questions encode question type and target.
    stem_parts = [encoding, pattern, system_prompt]
    if question_type != "full_output":
        stem_parts.extend((question_type, target))
    stem_parts.append(model)
    filename = "-".join(f"{part}" for part in stem_parts) + ".txt"

    return os.path.join(response_dir, filename)


def _build_full_prompt(prompt_data: Dict[str, Any]) -> str:
    """Join a record's optional ``system_prompt`` and ``text`` into one prompt.

    The parts are separated by a blank line. When ``system_prompt`` is missing
    or empty, only ``text`` is returned (``""`` if ``text`` is missing too).
    """
    system_prompt = prompt_data.get("system_prompt", "")
    text = prompt_data.get("text", "")
    return f"{system_prompt}\n\n{text}" if system_prompt else text


def submit_prompts_deepinfra(
    prompts_file: str = "llm-inference/prompts/prompts_xml_SU234_M3.json",
    model: str = "qwq",
    max_prompts: Optional[int] = None,
    skip_existing: bool = True,
    mode: str = "sequential",
    max_workers: int = 5,
    delay: float = 0.5,
    verbose: bool = False,
) -> List[Dict[str, Any]]:
    """
    Submit prompts to DeepInfra API.

    Parameters:
    - prompts_file: Path to the collected prompts JSON file
    - model: DeepInfra model to use ('qwq', 'deepseek-r1', etc.)
    - max_prompts: Maximum number of prompts to process (None for all;
      0 processes nothing)
    - skip_existing: Skip prompts whose response file already exists
    - mode: 'sequential' or 'parallel'
    - max_workers: number of concurrent workers for parallel mode
    - delay: Delay between requests (sequential) or batches (parallel), seconds
    - verbose: print verbose banners (forwarded to parallel mode only)

    Returns:
    - The result list returned by the solver. Returns an empty list when
      there is nothing to submit; no summary file is written in that case.

    Side effects:
    - Creates response directories (via ``generate_response_path``) and writes
      a JSON run summary into ``batch_jobs/``.

    Raises:
    - ValueError: if ``mode`` is not 'sequential' or 'parallel', or if
      ``max_prompts`` is negative.
    """
    # Validate before doing any I/O so bad arguments fail fast.
    if mode not in ("sequential", "parallel"):
        raise ValueError(f"mode must be 'sequential' or 'parallel', got {mode!r}")
    if max_prompts is not None and max_prompts < 0:
        raise ValueError(f"max_prompts must be non-negative, got {max_prompts}")

    print(f"📂 Loading prompts from {prompts_file}...")
    prompts_data = load_collected_prompts(prompts_file)

    # `is not None` so that max_prompts=0 means "no prompts", not "all".
    if max_prompts is not None:
        prompts_data = prompts_data[:max_prompts]

    print(f"✅ Loaded {len(prompts_data)} prompts")

    # Sanitize the model name so it can be embedded in file names.
    model_filename = model.replace("/", "-").lower()

    prompts_to_submit: List[str] = []
    output_paths: List[str] = []
    skipped_count = 0
    for prompt_data in prompts_data:
        response_path = generate_response_path(prompt_data, model_filename)
        if skip_existing and os.path.exists(response_path):
            skipped_count += 1
            continue
        prompts_to_submit.append(_build_full_prompt(prompt_data))
        output_paths.append(response_path)

    print(f"📊 Prompts to submit: {len(prompts_to_submit)}")
    print(f"⏩ Skipped existing: {skipped_count}")

    if not prompts_to_submit:
        print("⚠️ No prompts to submit!")
        return []

    if mode == "parallel":
        # Imported only when parallel mode is actually requested.
        from api_models.deepinfra.query_deepinfra import solve_graphs_parallel

        results = solve_graphs_parallel(
            prompts=prompts_to_submit,
            output_paths=output_paths,
            model=model,
            max_workers=max_workers,
            delay_between_batches=delay,
            verbose=verbose,
        )
    else:  # sequential
        results = solve_graphs_batch(
            prompts=prompts_to_submit,
            output_paths=output_paths,
            model=model,
            delay_between_requests=delay,
        )

    # Take one timestamp so the summary file name and its "timestamp" field
    # always describe the same instant.
    finished_at = datetime.now()
    summary_dir = "batch_jobs"
    os.makedirs(summary_dir, exist_ok=True)
    summary_path = os.path.join(
        summary_dir,
        f"deepinfra_{model_filename}_{finished_at.strftime('%Y%m%d_%H%M%S')}.json",
    )

    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(
            {
                "model": model,
                "prompts_file": prompts_file,
                "timestamp": finished_at.isoformat(),
                "total_prompts": len(prompts_to_submit),
                "skipped": skipped_count,
                "results": results,
            },
            f,
            indent=2,
        )

    print(f"📝 Summary saved to {summary_path}")

    return results


def main():
    """Parse command-line arguments and submit the collected prompts to DeepInfra.

    Invalid numeric arguments are reported through ``argparse`` (exit status 2).
    A missing prompts file prints an error and exits with status 1.
    """
    parser = argparse.ArgumentParser(
        description="Submit collected prompts to DeepInfra API"
    )

    parser.add_argument(
        "--prompts_file",
        default="llm-inference/prompts/prompts_xml_SU234_M3.json",
        help="Path to collected prompts JSON file",
    )
    parser.add_argument(
        "--model",
        default="qwq",
        choices=[
            "qwq",
            "qwq-32b",
            "deepseek-r1",
            "deepseek-r1-distill-qwen",
            "deepseek-r1-distill-llama",
        ],
        help="DeepInfra model to use",
    )
    parser.add_argument(
        "--max_prompts",
        type=int,
        help="Maximum number of prompts to process (for testing)",
    )
    parser.add_argument(
        "--no_skip_existing",
        action="store_true",
        help="Don't skip prompts that already have responses",
    )
    parser.add_argument(
        "--delay",
        type=float,
        default=0.5,
        help="Delay between requests in seconds (default: 0.5)",
    )
    parser.add_argument(
        "--mode",
        default="sequential",
        choices=["sequential", "parallel"],
        help="Processing mode: sequential (safe) or parallel (faster)",
    )
    parser.add_argument(
        "--max_workers",
        type=int,
        default=5,
        help="Maximum concurrent requests for parallel mode (default: 5)",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Verbose logging (model banner, debug prints)",
    )

    args = parser.parse_args()

    # Reject nonsensical numeric arguments up front. A negative --max_prompts
    # would otherwise be used as a slice bound on the prompt list.
    if args.max_prompts is not None and args.max_prompts < 0:
        parser.error("--max_prompts must be non-negative")
    if args.max_workers < 1:
        parser.error("--max_workers must be at least 1")
    if args.delay < 0:
        parser.error("--delay must be non-negative")

    # isfile (not exists) so that a directory path is reported here instead of
    # failing later when it is opened for reading.
    if not os.path.isfile(args.prompts_file):
        print(f"❌ Prompts file not found: {args.prompts_file}")
        # Non-zero exit status so shell scripts / CI can detect the failure.
        raise SystemExit(1)

    results = submit_prompts_deepinfra(
        prompts_file=args.prompts_file,
        model=args.model,
        max_prompts=args.max_prompts,
        skip_existing=not args.no_skip_existing,
        delay=args.delay,
        mode=args.mode,
        max_workers=args.max_workers,
        verbose=args.verbose,
    )

    if results:
        successful = sum(1 for r in results if r.get("success", False))
        print(
            f"\n🎉 Completed processing {successful}/{len(results)} prompts successfully!"
        )


if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script,
    # not when this module is imported.
    main()
