#!/usr/bin/env python3
"""
Submit collected prompts to Google Gemini API.
Supports both sequential and parallel processing modes.
"""

import argparse
import json
import os
import time
from datetime import datetime
from typing import Any, Dict, List, Optional

from api_models.google_gemini.query_gemini import solve_graph, solve_graphs_parallel


def load_collected_prompts(prompts_file: str) -> List[Dict[str, Any]]:
    """Read the prompt records stored in *prompts_file*.

    The file is decoded as UTF-8 and parsed as JSON. The parsed top-level
    value is returned as-is (expected to be a list of prompt dicts).
    """
    with open(prompts_file, encoding="utf-8") as handle:
        payload = json.load(handle)
    return payload


def generate_response_path(prompt_data: Dict[str, Any], model: str) -> str:
    """Return the file path where *model*'s response to this prompt is stored.

    Side effect: creates ``datasets/<benchmark>/<graph_type>/responses`` if it
    does not exist yet.

    The file name is the dash-joined sequence encoding, pattern, system_prompt,
    then (unless question_type is ``"full_output"``) question_type and target,
    and finally the model name with "-" and "." turned into "_" and
    lower-cased, followed by ``.txt``.

    Raises:
        KeyError: if ``prompt_data["metadata"]`` lacks any of the required keys.
    """
    meta = prompt_data["metadata"]
    # Read every field up front (same order as before) so a missing key fails
    # before any directory is created.
    (
        benchmark,
        graph_type,
        encoding,
        pattern,
        sys_prompt_name,
        q_type,
        target,
    ) = (
        meta[key]
        for key in (
            "benchmark",
            "graph_type",
            "encoding",
            "pattern",
            "system_prompt",
            "question_type",
            "target",
        )
    )

    out_dir = f"datasets/{benchmark}/{graph_type}/responses"
    os.makedirs(out_dir, exist_ok=True)

    model_tag = model.replace("-", "_").replace(".", "_").lower()

    name_parts = [encoding, pattern, sys_prompt_name]
    if q_type != "full_output":
        name_parts.extend([q_type, target])
    name_parts.append(model_tag)

    # f"{x}" formats each part exactly as the original f-string template did.
    stem = "-".join(f"{part}" for part in name_parts)
    return os.path.join(out_dir, stem + ".txt")


def _build_full_prompt(prompt_data: Dict[str, Any]) -> str:
    """Return the record's text, prefixed by its system prompt (if non-empty) and a blank line."""
    system_prompt = prompt_data.get("system_prompt", "")
    text = prompt_data.get("text", "")
    return f"{system_prompt}\n\n{text}" if system_prompt else text


def _run_sequential(
    prompts: List[str], output_paths: List[str], model: str, delay: float
) -> List[Dict[str, Any]]:
    """Submit prompts one by one via solve_graph, sleeping *delay* seconds between calls."""
    results: List[Dict[str, Any]] = []
    total = len(prompts)

    print(f"🚀 Processing {total} prompts using Gemini {model} sequentially...")
    # NOTE: this estimate only counts the inter-request delays, not API latency.
    print(f"⏱️ Estimated time: {total * delay / 60:.1f} minutes")

    for i, (prompt, output_path) in enumerate(zip(prompts, output_paths), 1):
        print(f"Processing {i}/{total}: {os.path.basename(output_path)}")

        try:
            solve_graph(prompt, output_path, model)
        except Exception as e:
            # Deliberate boundary: record the failure and continue with the rest.
            print(f"  ❌ Failed: {e}")
            results.append(
                {"output_path": output_path, "success": False, "error": str(e)}
            )
        else:
            results.append({"output_path": output_path, "success": True})

        # Rate limiting: pause between requests, but not after the last one.
        if i < total:
            time.sleep(delay)

    return results


def _save_summary(
    prompts_file: str,
    model: str,
    mode: str,
    total_prompts: int,
    skipped: int,
    results: List[Dict[str, Any]],
) -> str:
    """Write a JSON run summary under batch_jobs/ and return its path."""
    # Take one timestamp so the file name and the "timestamp" field agree.
    now = datetime.now()
    os.makedirs("batch_jobs", exist_ok=True)
    summary_path = (
        f"batch_jobs/gemini_{model.replace('/', '_')}_"
        f"{now.strftime('%Y%m%d_%H%M%S')}.json"
    )

    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(
            {
                "model": model,
                "prompts_file": prompts_file,
                "timestamp": now.isoformat(),
                "total_prompts": total_prompts,
                "skipped": skipped,
                "mode": mode,
                "results": results,
            },
            f,
            indent=2,
        )

    return summary_path


def submit_prompts_gemini(
    prompts_file: str = "llm-inference/prompts/prompts_xml_SU234_M3.json",
    model: str = "gemini-flash",
    max_prompts: Optional[int] = None,
    skip_existing: bool = True,
    mode: str = "sequential",
    max_workers: int = 5,
    delay: float = 0.5,
) -> List[Dict[str, Any]]:
    """
    Submit prompts to Gemini API.

    Parameters:
    - prompts_file: Path to the collected prompts JSON file
    - model: Gemini model to use
    - max_prompts: Maximum number of prompts to process; None means all
      (0 means none)
    - skip_existing: Skip prompts whose response file already exists
    - mode: 'parallel' uses solve_graphs_parallel; any other value runs
      sequentially
    - max_workers: Number of concurrent workers for parallel mode
    - delay: Delay between requests (sequential) or batches (parallel)

    Returns:
    - List of results. In sequential mode each entry is a dict with
      "output_path", "success" and, on failure, "error". In parallel mode it is
      whatever solve_graphs_parallel returns. An empty list is returned (and no
      summary is written) when there is nothing to submit.

    Side effects: creates response directories, writes a JSON summary under
    batch_jobs/, and prints progress to stdout.
    """
    print(f"📂 Loading prompts from {prompts_file}...")
    prompts_data = load_collected_prompts(prompts_file)

    # Compare against None so that max_prompts=0 selects no prompts rather
    # than silently processing all of them.
    if max_prompts is not None:
        prompts_data = prompts_data[:max_prompts]

    print(f"✅ Loaded {len(prompts_data)} prompts")

    # Pair each pending prompt with its response path.
    prompts_to_submit: List[str] = []
    output_paths: List[str] = []
    skipped_count = 0

    for prompt_data in prompts_data:
        response_path = generate_response_path(prompt_data, model)

        if skip_existing and os.path.exists(response_path):
            skipped_count += 1
            continue

        prompts_to_submit.append(_build_full_prompt(prompt_data))
        output_paths.append(response_path)

    print(f"📊 Prompts to submit: {len(prompts_to_submit)}")
    print(f"⏩ Skipped existing: {skipped_count}")

    if not prompts_to_submit:
        print("⚠️ No prompts to submit!")
        return []

    if mode == "parallel":
        results = solve_graphs_parallel(
            prompts=prompts_to_submit,
            output_paths=output_paths,
            model=model,
            max_workers=max_workers,
            batch_delay=delay,
        )
    else:  # sequential
        results = _run_sequential(prompts_to_submit, output_paths, model, delay)

    summary_path = _save_summary(
        prompts_file=prompts_file,
        model=model,
        mode=mode,
        total_prompts=len(prompts_to_submit),
        skipped=skipped_count,
        results=results,
    )
    print(f"📝 Summary saved to {summary_path}")

    return results


def main():
    """Command-line entry point: parse arguments and submit prompts to Gemini.

    Exits via argparse (status 2) on invalid arguments, or with status 1 if
    the prompts file does not exist. Otherwise it delegates to
    submit_prompts_gemini and prints a success tally.
    """
    parser = argparse.ArgumentParser(
        description="Submit collected prompts to Google Gemini API"
    )

    parser.add_argument(
        "--prompts_file",
        default="llm-inference/prompts/prompts_xml_SU234_M3.json",
        help="Path to collected prompts JSON file",
    )
    parser.add_argument(
        "--model",
        default="gemini-flash",
        # argparse does not check defaults against `choices`, so the default
        # must be listed explicitly. Otherwise it silently bypasses validation
        # and cannot be passed on the command line.
        # NOTE(review): "gemini-flash" is assumed to be an alias that
        # query_gemini resolves — confirm.
        choices=[
            "gemini-flash",
            "gemini-2.5-pro",
            "gemini-2.5-flash",
            "gemini-2.5-flash-lite",
        ],
        help="Gemini model to use",
    )
    parser.add_argument(
        "--max_prompts",
        type=int,
        help="Maximum number of prompts to process (for testing)",
    )
    parser.add_argument(
        "--mode",
        default="sequential",
        choices=["sequential", "parallel"],
        help="Processing mode: sequential (safe) or parallel (faster)",
    )
    parser.add_argument(
        "--max_workers",
        type=int,
        default=5,
        help="Maximum concurrent requests for parallel mode",
    )
    parser.add_argument(
        "--no_skip_existing",
        action="store_true",
        help="Don't skip prompts that already have responses",
    )
    parser.add_argument(
        "--delay",
        type=float,
        default=0.5,
        help="Delay between requests/batches in seconds",
    )

    args = parser.parse_args()

    # Reject values that would otherwise fail mid-run (e.g. time.sleep raises
    # ValueError on a negative delay) or behave surprisingly (negative slices).
    if args.max_workers < 1:
        parser.error("--max_workers must be at least 1")
    if args.delay < 0:
        parser.error("--delay must be non-negative")
    if args.max_prompts is not None and args.max_prompts < 0:
        parser.error("--max_prompts must be non-negative")

    if not os.path.exists(args.prompts_file):
        print(f"❌ Prompts file not found: {args.prompts_file}")
        # Non-zero exit so shells and schedulers see the failure.
        raise SystemExit(1)

    results = submit_prompts_gemini(
        prompts_file=args.prompts_file,
        model=args.model,
        max_prompts=args.max_prompts,
        skip_existing=not args.no_skip_existing,
        mode=args.mode,
        max_workers=args.max_workers,
        delay=args.delay,
    )

    if results:
        successful = sum(1 for r in results if r.get("success", False))
        print(
            f"\n🎉 Completed processing {successful}/{len(results)} prompts successfully!"
        )


if __name__ == "__main__":
    main()
