"""Remote evaluation server for SWE-Smith.

This server runs on a VM with Docker access and handles evaluation requests
from training/eval jobs running on Mosaic (which don't have Docker access).

Usage:
    python -m advisor_models.swe_smith.eval_server --port 5152 --host 0.0.0.0

Requirements:
    - Docker must be installed and running
    - Port must be accessible from Mosaic cluster
"""

import argparse
import logging
import uuid
from flask import Flask, request, jsonify
import threading
import queue
import time

# Import the actual evaluation function
from .config import compute_score as _compute_score_local

# Setup logging: configure the root logger for INFO-level, timestamped records.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Flask application; the HTTP endpoints below are registered on it via @app.route.
app = Flask(__name__)

# Queue for managing concurrent evaluations.
# Items are (eval_id, patch, instance) tuples put by /evaluate and consumed by
# the worker() threads started in main().
eval_queue = queue.Queue()
# eval_id -> result dict (always "status" and "timestamp", plus result fields
# such as reward/info or error). Shared by request-handler threads, worker
# threads and the cleanup thread with no explicit lock.
results_cache = {}
# NOTE(review): MAX_CACHE_SIZE is not referenced anywhere in this file, so the
# cache is bounded only by CACHE_TTL eviction — either enforce it or remove it.
MAX_CACHE_SIZE = 1000
# Age in seconds (measured from an entry's "timestamp") after which
# cleanup_cache() evicts the entry.
CACHE_TTL = 3600  # 1 hour


def worker():
    """Consume evaluation jobs from ``eval_queue`` forever.

    For each ``(eval_id, patch, instance)`` job, the job's entry in
    ``results_cache`` is first marked ``"processing"``, then the patch is scored
    with ``_compute_score_local``. The entry is then replaced by one of:

    * a ``"completed"`` result carrying ``reward``, ``run_id`` and ``info``;
    * an ``"error"`` result carrying the exception message.

    Every exception is logged; none terminates the thread, so it is meant to
    run as a daemon thread.
    """
    while True:
        try:
            eval_id, patch, instance = eval_queue.get()
            logger.info(f"Processing evaluation {eval_id}")

            try:
                # Publish the "processing" state that /result advertises.
                # Refreshing the timestamp makes the TTL count from the moment
                # work actually starts, not from submission.
                results_cache[eval_id] = {
                    "status": "processing",
                    "timestamp": time.time(),
                }
                reward, run_id, info = _compute_score_local(patch, instance)
                results_cache[eval_id] = {
                    "status": "completed",
                    "reward": reward,
                    # Returned by /evaluate_sync as well; keep both paths consistent.
                    "run_id": run_id,
                    "info": info,
                    "timestamp": time.time(),
                }
                logger.info(f"Completed evaluation {eval_id}: reward={reward}")
            except Exception as e:
                logger.error(f"Error in evaluation {eval_id}: {e}", exc_info=True)
                results_cache[eval_id] = {
                    "status": "error",
                    "error": str(e),
                    "timestamp": time.time(),
                }
            finally:
                eval_queue.task_done()

        except Exception as e:
            # Last-resort guard so a single bad job never kills the worker.
            logger.error(f"Worker error: {e}", exc_info=True)


def cleanup_cache():
    """Periodically clean up old cache entries."""
    while True:
        time.sleep(300)  # Every 5 minutes
        current_time = time.time()
        to_delete = []

        for eval_id, result in results_cache.items():
            if current_time - result.get("timestamp", 0) > CACHE_TTL:
                to_delete.append(eval_id)

        for eval_id in to_delete:
            del results_cache[eval_id]

        if to_delete:
            logger.info(f"Cleaned up {len(to_delete)} old cache entries")


@app.route("/health", methods=["GET"])
def health():
    """Liveness probe.

    Reports a fixed "healthy" status together with the number of jobs waiting
    in the evaluation queue and the number of entries held in the result cache.
    """
    payload = {"status": "healthy"}
    payload["queue_size"] = eval_queue.qsize()
    payload["cache_size"] = len(results_cache)
    return jsonify(payload)


@app.route("/evaluate", methods=["POST"])
def evaluate():
    """Submit an evaluation request for asynchronous processing.

    Request body (JSON object):
    {
        "patch": "diff content...",
        "instance": {
            "instance_id": "...",
            "repo": "...",
            ...
        }
    }

    Returns:
    {
        "eval_id": "unique-id",
        "status": "queued",
        "queue_position": int  # queue size right after enqueueing
    }

    Responds 400 if the body is not a JSON object with "patch" and an
    object-valued "instance", and 500 on unexpected errors. Poll
    /result/<eval_id> for the outcome.
    """
    try:
        # silent=True: a malformed body or wrong Content-Type yields None
        # (reported as a 400 below) instead of raising and surfacing as a 500.
        data = request.get_json(silent=True)

        if (
            not isinstance(data, dict)
            or "patch" not in data
            or "instance" not in data
        ):
            return jsonify({"error": "Missing 'patch' or 'instance' in request"}), 400

        patch = data["patch"]
        instance = data["instance"]
        # Validate before enqueueing: a non-object instance would otherwise be
        # queued and then crash this handler on instance.get() below.
        if not isinstance(instance, dict):
            return jsonify({"error": "'instance' must be a JSON object"}), 400

        # Generate unique ID
        eval_id = str(uuid.uuid4())

        # Initialize the cache entry BEFORE enqueueing. A worker may pick the
        # job up immediately; writing "queued" afterwards could overwrite its
        # "completed"/"error" result and leave the entry stuck as "queued".
        results_cache[eval_id] = {"status": "queued", "timestamp": time.time()}
        eval_queue.put((eval_id, patch, instance))

        logger.info(
            f"Queued evaluation {eval_id} for instance {instance.get('instance_id', 'unknown')}"
        )

        return jsonify(
            {
                "eval_id": eval_id,
                "status": "queued",
                "queue_position": eval_queue.qsize(),
            }
        )

    except Exception as e:
        logger.error(f"Error in /evaluate: {e}", exc_info=True)
        return jsonify({"error": str(e)}), 500


@app.route("/result/<eval_id>", methods=["GET"])
def get_result(eval_id):
    """Get the current state or result of an evaluation.

    Returns the cache entry for *eval_id* as JSON. It always has "status"
    and "timestamp". Completed entries also carry the evaluation result
    (reward, info, ...), and error entries carry "error".

    Responds 404 if the ID is unknown or its entry has already been evicted
    by the periodic cache cleanup, and 500 on unexpected errors.
    """
    try:
        # One lookup rather than `in` followed by indexing. The cleanup thread
        # may evict the entry between those two steps, which would raise a
        # KeyError and produce a 500 instead of a 404.
        result = results_cache.get(eval_id)
        if result is None:
            return jsonify({"error": "Evaluation ID not found"}), 404

        return jsonify(result)

    except Exception as e:
        logger.error(f"Error in /result: {e}", exc_info=True)
        return jsonify({"error": str(e)}), 500


@app.route("/evaluate_sync", methods=["POST"])
def evaluate_sync():
    """Synchronous evaluation endpoint (blocks until complete).

    Request body: same as /evaluate

    Returns:
    {
        "reward": float,
        "run_id": ...,
        "info": str
    }

    Responds 400 on an invalid body and 500 if the evaluation raises.

    NOTE: the evaluation runs in the request thread and bypasses the worker
    queue, so these calls are not limited by the --workers setting.
    """
    try:
        # silent=True: a malformed body or wrong Content-Type yields None
        # (reported as a 400 below) instead of raising and surfacing as a 500.
        data = request.get_json(silent=True)

        if (
            not isinstance(data, dict)
            or "patch" not in data
            or "instance" not in data
        ):
            return jsonify({"error": "Missing 'patch' or 'instance' in request"}), 400

        patch = data["patch"]
        instance = data["instance"]
        if not isinstance(instance, dict):
            return jsonify({"error": "'instance' must be a JSON object"}), 400

        logger.info(
            f"Synchronous evaluation for instance {instance.get('instance_id', 'unknown')}"
        )

        # Run evaluation directly
        reward, run_id, info = _compute_score_local(patch, instance)

        return jsonify({"reward": reward, "run_id": run_id, "info": info})

    except Exception as e:
        logger.error(f"Error in /evaluate_sync: {e}", exc_info=True)
        return jsonify({"error": str(e)}), 500


def main():
    """Parse CLI arguments, start background threads, and serve HTTP requests.

    Starts ``--workers`` daemon threads running ``worker`` plus one daemon
    thread running ``cache cleanup``, then blocks in the Flask server.
    """
    parser = argparse.ArgumentParser(description="SWE-Smith Remote Evaluation Server")
    parser.add_argument("--host", type=str, required=True, help="Host to bind to")
    parser.add_argument("--port", type=int, required=True, help="Port to listen on")
    parser.add_argument(
        "--workers", type=int, default=12, help="Number of worker threads (default: 12)"
    )

    args = parser.parse_args()
    if args.workers < 1:
        # With no workers, /evaluate would accept jobs that are never processed.
        parser.error("--workers must be at least 1")

    # Start worker threads (named so log records / thread dumps identify them)
    logger.info(f"Starting {args.workers} worker threads")
    for index in range(args.workers):
        thread = threading.Thread(
            target=worker, name=f"eval-worker-{index}", daemon=True
        )
        thread.start()

    # Start cache cleanup thread
    cleanup_thread = threading.Thread(
        target=cleanup_cache, name="cache-cleanup", daemon=True
    )
    cleanup_thread.start()

    # Start server.
    # NOTE: app.run() is Flask's built-in development server; Flask recommends
    # a production WSGI server for production deployments.
    logger.info(f"Starting server on {args.host}:{args.port}")
    app.run(host=args.host, port=args.port, threaded=True)


if __name__ == "__main__":
    main()
