"""
Async Multi-Session Agent Runner

Runs multiple agent sessions concurrently by dispatching the synchronous
run_agent_session() function to a dedicated ThreadPoolExecutor via
loop.run_in_executor(). This approach:
- Reuses existing agent handlers (GeminiHandler, OpenAIHandler, etc.)
- Preserves all logging functionality (logs sent to backend)
- Maintains prompt loading and code extraction logic
- Enables true parallel execution with controlled concurrency

Usage:
    python -m agent.async_runner --config experiment_configs/main_exp.json --workers 5
    python -m agent.async_runner --config experiment_configs/main_exp.json --workers 5 --dry-run
    python -m agent.async_runner --resume results/async_results_20250101.json --workers 5
"""

import asyncio
import json
import os
import sys
import shutil
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Any, Optional, Tuple
import logging
import requests

# Module-wide executor used by run_session_async() for the blocking session
# calls. It is None at import time; run_all_sessions() assigns a
# ThreadPoolExecutor sized to the requested concurrency.
_thread_pool: Optional[ThreadPoolExecutor] = None

# Load environment variables from a .env file when python-dotenv is installed.
# The package is optional: if the import fails, the step is skipped silently.
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    pass

# Append the parent of this file's directory to sys.path so modules located
# there can be imported.
sys.path.append(os.path.dirname(os.path.dirname(__file__)))

# Configure root logging (INFO level, timestamped lines) and create the
# logger used throughout this module.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


@dataclass
class SessionState:
    """Mutable record for one agent session: its inputs, outcome, and timing."""

    # Task identity and limits (supplied when the task is created)
    experiment: str = ""
    model: str = ""
    run_number: int = 1
    max_turns: int = 100

    # Outcome fields, filled in once the session has finished
    session_id: Optional[str] = None
    success: bool = False
    victory: Optional[bool] = None
    survival_rate: Optional[float] = None
    error: Optional[str] = None
    turns: int = 0
    drones_used: int = 0

    # Wall-clock bounds of the session
    start_time: Optional[datetime] = None
    end_time: Optional[datetime] = None

    def task_key(self) -> Tuple[str, str, int]:
        """Return the (experiment, model, run_number) tuple identifying this task."""
        identity = (self.experiment, self.model, self.run_number)
        return identity

    def duration_seconds(self) -> float:
        """Return elapsed seconds between start and end, or 0.0 if either is unset."""
        began, finished = self.start_time, self.end_time
        if not began or not finished:
            return 0.0
        elapsed = finished - began
        return elapsed.total_seconds()


async def run_session_async(
    state: SessionState,
    semaphore: asyncio.Semaphore,
    base_url: str = "http://localhost:8000"
) -> SessionState:
    """
    Run a single agent session asynchronously.

    The synchronous run_agent_session() call is dispatched to the module-level
    thread pool via loop.run_in_executor(), so the event loop stays free while
    the session runs. If the module-level pool is still None, run_in_executor
    falls back to the loop's default executor.

    Args:
        state: Session record; its result and timing fields are updated in place.
        semaphore: Limits how many sessions execute at the same time.
        base_url: Backend API URL passed through to run_agent_session().

    Returns:
        The same ``state`` object. Exceptions raised by the session are not
        propagated; they are recorded as success=False with ``error`` set.
    """
    tag = f"[{state.experiment}/{state.model}#{state.run_number}]"

    async with semaphore:
        state.start_time = datetime.now()
        logger.info(f"{tag} Starting session")

        try:
            # Import here to avoid circular imports
            from run_agent import run_agent_session

            # We are inside a coroutine, so a loop is guaranteed to be running;
            # get_running_loop() is the supported way to obtain it.
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                _thread_pool,  # Use our custom thread pool
                lambda: run_agent_session(
                    model_name=state.model,
                    experiment=state.experiment,
                    max_turns=state.max_turns,
                    base_url=base_url,
                    enable_thinking=True,
                    verbose=False  # Quiet mode for async runs
                )
            )

            # Copy results to state
            state.success = result.get('success', False)
            state.victory = result.get('victory')
            state.survival_rate = result.get('survival_rate')
            state.session_id = result.get('session_id')
            state.error = result.get('error')
            state.turns = result.get('turns', 0)
            state.drones_used = result.get('drones_used', 0)

        except Exception as e:
            # One failing session must not take down the rest of the sweep.
            state.success = False
            state.error = str(e)
            logger.error(f"{tag} Error: {e}")

        state.end_time = datetime.now()
        duration = state.duration_seconds()

        # Log completion
        status_str = "SUCCESS" if state.success else "FAILED"
        victory_str = f"Victory={state.victory}" if state.victory is not None else ""
        rate_str = f"Rate={state.survival_rate:.1%}" if state.survival_rate is not None else ""
        logger.info(
            f"{tag} "
            f"{status_str} in {duration:.1f}s {victory_str} {rate_str}"
        )

        return state


def save_results_incremental(
    results: "List[SessionState]",
    output_path: Path,
    config_path: str,
    existing_successful: Optional[List[Dict]] = None,
    start_time: Optional[datetime] = None
):
    """
    Write the current results snapshot to ``output_path`` as JSON.

    The snapshot is written to a sibling temporary file and then moved over
    ``output_path`` in a single replace, so an interruption mid-write never
    leaves a truncated results file behind (the file is what --resume reads).

    Args:
        results: Sessions finished so far in this run.
        output_path: Destination JSON file; parent directories are created.
        config_path: Config file path, stored as metadata.
        existing_successful: Result dicts carried over from a previous run;
            they are placed before the new results.
        start_time: When this run started; used for "duration_seconds"
            (0 when not given).
    """
    new_results = [
        {
            "experiment": r.experiment,
            "model": r.model,
            "run_number": r.run_number,
            "session_id": r.session_id,
            "success": r.success,
            "victory": r.victory,
            "survival_rate": r.survival_rate,
            "error": r.error,
            "turns": r.turns,
            "drones_used": r.drones_used,
            "duration_seconds": r.duration_seconds(),
        }
        for r in results
    ]

    # Merge with existing successful results
    all_results = list(existing_successful or []) + new_results

    duration = (datetime.now() - start_time).total_seconds() if start_time else 0

    results_data = {
        "timestamp": datetime.now().isoformat(),
        "duration_seconds": duration,
        "config": config_path,
        "results": all_results,
    }

    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Write to a temp file in the same directory, then atomically replace.
    tmp_path = output_path.with_name(output_path.name + ".tmp")
    try:
        with open(tmp_path, "w") as f:
            json.dump(results_data, f, indent=2)
        tmp_path.replace(output_path)
    finally:
        # Only still present if the write or the replace failed.
        if tmp_path.exists():
            tmp_path.unlink()

def archive_session_file(session_id: str, sweep_dir: Path) -> bool:
    """
    Copy a session's JSON record into ``<sweep_dir>/sessions`` (best effort).

    The record is looked up under ``agent_records/sessions/`` first and then
    directly under ``agent_records/`` (both relative to the current working
    directory).

    Args:
        session_id: ID of the session to archive.
        sweep_dir: Sweep folder that receives the copy.

    Returns:
        True if a copy was made. False if an argument is empty, the
        destination directory cannot be created, no source file exists, or
        copying failed. Filesystem errors are logged, never raised.
    """
    if not session_id or not sweep_dir:
        return False

    sessions_dest = sweep_dir / "sessions"
    try:
        # parents=True: the sweep folder itself may not exist yet.
        sessions_dest.mkdir(parents=True, exist_ok=True)
    except OSError as e:
        logger.warning(f"Failed to archive {session_id}: {e}")
        return False

    # Try to find the session file
    src_paths = [
        Path(f"agent_records/sessions/{session_id}.json"),
        Path(f"agent_records/{session_id}.json"),
    ]

    for src_file in src_paths:
        if src_file.exists():
            dest_file = sessions_dest / f"{session_id}.json"
            try:
                shutil.copy2(src_file, dest_file)
                logger.debug(f"Archived session {session_id}")
                return True
            except Exception as e:
                # Keep going: the next candidate location may still work.
                logger.warning(f"Failed to archive {session_id}: {e}")

    return False


async def run_all_sessions(
    tasks: List[Tuple[str, str, int]],  # (experiment, model, run_number)
    base_url: str = "http://localhost:8000",
    max_concurrent: int = 5,
    max_turns: int = 100,
    sweep_dir: Optional[Path] = None,
    output_path: Optional[Path] = None,
    config_path: str = "",
    existing_successful: Optional[List[Dict]] = None
) -> List[SessionState]:
    """
    Run all sessions concurrently with controlled parallelism.

    Args:
        tasks: List of (experiment, model, run_number) tuples
        base_url: Backend API URL
        max_concurrent: Maximum concurrent sessions
        max_turns: Maximum turns per session
        sweep_dir: Directory to archive session files
        output_path: Path to save results incrementally
        config_path: Config file path for metadata
        existing_successful: Existing successful results to merge

    Returns:
        The SessionState objects in completion order (not task order).
        An empty list is returned when ``tasks`` is empty.

    Raises:
        ValueError: From ThreadPoolExecutor if ``max_concurrent`` is not positive.
    """
    global _thread_pool

    # Create session states
    states = [
        SessionState(
            experiment=exp,
            model=model,
            run_number=run_num,
            max_turns=max_turns
        )
        for exp, model, run_num in tasks
    ]

    if not states:
        logger.warning("No tasks to run!")
        return []

    # Initialize thread pool with enough workers for max_concurrent sessions
    # This overrides Python's default limit of min(32, cpu_count + 4)
    pool = ThreadPoolExecutor(max_workers=max_concurrent)
    _thread_pool = pool
    logger.info(f"Initialized thread pool with {max_concurrent} workers")

    total = len(states)
    logger.info(f"Starting {total} sessions with max {max_concurrent} concurrent")

    # Run with progress tracking and incremental saving
    results: List[SessionState] = []
    start_time = datetime.now()

    try:
        # Create semaphore to limit concurrent sessions
        semaphore = asyncio.Semaphore(max_concurrent)

        # Create coroutines for all sessions
        coroutines = [
            run_session_async(state, semaphore, base_url)
            for state in states
        ]

        for coro in asyncio.as_completed(coroutines):
            result = await coro
            results.append(result)

            # Archive session file immediately
            if sweep_dir and result.session_id:
                archive_session_file(result.session_id, sweep_dir)

            # Save results incrementally
            if output_path:
                save_results_incremental(
                    results, output_path, config_path,
                    existing_successful, start_time
                )

            # Progress update
            success_count = sum(1 for r in results if r.success)
            victory_count = sum(1 for r in results if r.victory)
            logger.info(
                f"Progress: {len(results)}/{total} "
                f"(Success: {success_count}, Victory: {victory_count})"
            )
    finally:
        # Always release the pool, even when archiving or saving raised, and
        # drop the global reference so nothing submits to a shut-down executor.
        pool.shutdown(wait=False)
        if _thread_pool is pool:
            _thread_pool = None
        logger.info("Thread pool shutdown")

    return results


def delete_session_from_backend(session_id: str, base_url: str = "http://localhost:8000") -> bool:
    """
    Ask the backend to remove one session via its admin endpoint.

    Returns True only when the DELETE request answers with HTTP 200. Any other
    status code, or any exception raised while making the request, is logged
    as a warning and reported as False.
    """
    try:
        endpoint = f"{base_url}/api/admin/sessions/{session_id}"
        status = requests.delete(endpoint, timeout=10).status_code
        if status != 200:
            logger.warning(f"Failed to delete session {session_id}: {status}")
            return False
        logger.info(f"Deleted failed session: {session_id}")
        return True
    except Exception as e:
        logger.warning(f"Failed to delete session {session_id}: {e}")
        return False


def backup_and_clear_sessions(base_url: str = "http://localhost:8000", sweep_name: str = "unknown") -> int:
    """
    Move local session records aside and wipe the backend's session list.

    Every ``*.json`` file in ``agent_records/sessions`` is copied into
    ``tmp/backup_sessions/<sweep_name>_<timestamp>/`` and then removed from its
    original location. Afterwards a DELETE request is sent to the backend's
    ``/api/admin/sessions`` endpoint. Failures in either step are logged as
    warnings rather than raised.

    Returns:
        The number of local session files that were copied and removed.
    """
    records_dir = Path("agent_records/sessions")

    # The backup folder is created up front, named after the sweep and the current time
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    backup_dir = Path("tmp/backup_sessions") / f"{sweep_name}_{stamp}"
    backup_dir.mkdir(parents=True, exist_ok=True)

    # Step 1: move local session files into the backup folder
    backed_up = 0
    pending = list(records_dir.glob("*.json")) if records_dir.exists() else []
    if pending:
        logger.info(f"[BACKUP] Found {len(pending)} existing session files")
        for src in pending:
            try:
                shutil.copy2(src, backup_dir / src.name)
                src.unlink()  # original is dropped only after the copy succeeded
                backed_up += 1
            except Exception as e:
                logger.warning(f"Failed to backup {src.name}: {e}")

        if backed_up > 0:
            logger.info(f"[BACKUP] Backed up {backed_up} sessions to {backup_dir}")

    # Step 2: ask the backend to drop all of its sessions
    try:
        response = requests.delete(f"{base_url}/api/admin/sessions", timeout=10)
        if response.status_code != 200:
            logger.warning(f"Failed to clear backend sessions: {response.status_code}")
        else:
            deleted_count = response.json().get('deleted', 0)
            logger.info(f"[CLEAR] Cleared {deleted_count} sessions from backend")
    except Exception as e:
        logger.warning(f"Failed to clear backend sessions: {e}")

    return backed_up


def load_existing_results(
    resume_path: str,
    base_url: str = "http://localhost:8000",
    delete_failed: bool = True
) -> Tuple[Dict[str, Any], set]:
    """
    Load existing results and return (data, completed_keys).

    Only tasks that succeeded AND carry a stage 2 result (``victory`` or
    ``survival_rate`` present) are considered completed. Failed and incomplete
    sessions are deleted from the backend so they can be re-run cleanly.

    When a ``sessions`` folder sits next to the results file, backend sessions
    whose IDs do not appear in that folder are also deleted. This sync is
    skipped if no valid session ID could be read from the folder.

    Args:
        resume_path: Path to results file or directory
        base_url: Backend URL for deleting failed sessions
        delete_failed: Whether to delete failed sessions from backend
            (also gates the orphaned-session sync)

    Returns:
        (data, completed): the parsed results JSON and the set of
        (experiment, model, run_number) keys that need no re-run. Returns
        ({}, set()) when no results file can be found.
    """
    path = Path(resume_path)

    # If it's a directory, find the results file
    if path.is_dir():
        results_file = path / "results.json"
        if not results_file.exists():
            # Try to find async_results_*.json (newest by modification time)
            async_files = list(path.glob("async_results_*.json"))
            if async_files:
                results_file = max(async_files, key=lambda p: p.stat().st_mtime)
            else:
                logger.warning(f"No results file found in {path}")
                return {}, set()
        path = results_file

    if not path.exists():
        logger.warning(f"Results file not found: {path}")
        return {}, set()

    with open(path) as f:
        data = json.load(f)

    # Build set of completed task keys (only fully successful ones with stage 2 complete)
    # Also collect failed/incomplete sessions to delete
    completed = set()
    failed_sessions = []
    incomplete_sessions = []

    for r in data.get("results", []):
        # A task is only truly completed if:
        # 1. success=True AND
        # 2. stage 2 was completed (victory is not None or survival_rate is not None)
        if r.get("success") is True:
            has_stage2_result = r.get("victory") is not None or r.get("survival_rate") is not None
            if has_stage2_result:
                key = (r["experiment"], r["model"], r.get("run_number", 1))
                completed.add(key)
            elif r.get("session_id"):
                # Success but no stage 2 result - needs to be re-run
                incomplete_sessions.append(r["session_id"])
                logger.info(f"Marking incomplete (no stage2): {r['model']}/{r['experiment']}")
        elif r.get("session_id"):
            # Failed run with a session_id - needs to be deleted
            failed_sessions.append(r["session_id"])

    logger.info(f"Loaded {len(completed)} completed, {len(failed_sessions)} failed, {len(incomplete_sessions)} incomplete tasks from {path}")

    # Delete failed and incomplete sessions from backend
    sessions_to_delete = failed_sessions + incomplete_sessions

    if delete_failed and sessions_to_delete:
        logger.info(f"Deleting {len(sessions_to_delete)} failed/incomplete sessions from backend...")
        deleted = 0
        for session_id in sessions_to_delete:
            if delete_session_from_backend(session_id, base_url):
                deleted += 1
        logger.info(f"Deleted {deleted}/{len(sessions_to_delete)} sessions")

    # Sync backend with local session files - delete any backend sessions not in local
    resume_dir = path.parent if path.is_file() else path
    sessions_dir = resume_dir / "sessions"
    if sessions_dir.exists() and delete_failed:
        local_session_ids = set()
        for session_file in sessions_dir.glob("*.json"):
            try:
                with open(session_file) as fp:
                    session = json.load(fp)
            except (OSError, ValueError) as e:
                # ValueError covers malformed JSON and undecodable bytes.
                logger.warning(f"Skipping unreadable session file {session_file.name}: {e}")
                continue

            local_id = session.get("session_id") if isinstance(session, dict) else None
            # Only real IDs count. A None entry would make the set non-empty
            # and cause every backend session to be treated as orphaned.
            if local_id:
                local_session_ids.add(local_id)

        if local_session_ids:
            # Get backend sessions and delete those not in local
            try:
                response = requests.get(f"{base_url}/api/admin/sessions", timeout=10)
                if response.status_code == 200:
                    backend_sessions = response.json()
                    orphaned = [s["session_id"] for s in backend_sessions
                               if s.get("session_id") and s["session_id"] not in local_session_ids]
                    if orphaned:
                        logger.info(f"Syncing backend: deleting {len(orphaned)} orphaned sessions...")
                        deleted = 0
                        for sid in orphaned:
                            if delete_session_from_backend(sid, base_url):
                                deleted += 1
                        logger.info(f"Deleted {deleted}/{len(orphaned)} orphaned sessions")
            except Exception as e:
                logger.warning(f"Failed to sync backend sessions: {e}")

    return data, completed


def parse_experiment_config(config_path: str) -> Tuple[List[Dict], List[Dict]]:
    """
    Read an experiment config and keep only the enabled entries.

    An entry counts as enabled unless it sets ``"enabled"`` to a falsy value.
    The ``"models"`` section may be a flat list, or a dict mapping provider
    name to a list of models; in the dict form, keys starting with ``_`` and
    non-list values are skipped, and each kept model gets a ``"provider"`` key.

    Returns:
        (experiments, models) - Lists of enabled experiments and models
    """
    with open(config_path) as handle:
        raw = json.load(handle)

    def is_enabled(entry):
        return entry.get("enabled", True)

    experiments = list(filter(is_enabled, raw.get("experiments", [])))

    model_section = raw.get("models", {})

    # Flat-list format: nothing to annotate, just filter
    if isinstance(model_section, list):
        return experiments, [entry for entry in model_section if is_enabled(entry)]

    # Provider-keyed format
    models = []
    for provider, entries in model_section.items():
        if provider.startswith("_") or not isinstance(entries, list):
            continue
        for entry in entries:
            if not is_enabled(entry):
                continue
            entry["provider"] = provider
            models.append(entry)

    return experiments, models


def load_agent_config() -> Dict[str, Any]:
    """Return the parsed contents of config/agent.json, or {} if the file is absent."""
    agent_cfg_file = Path("config/agent.json")
    if not agent_cfg_file.exists():
        return {}
    with open(agent_cfg_file) as handle:
        return json.load(handle)


def _build_arg_parser(default_max_turns: int):
    """Create the CLI parser; ``default_max_turns`` seeds the --max-turns default."""
    import argparse

    parser = argparse.ArgumentParser(
        description="Async Multi-Session Agent Runner",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Run with default config
  python -m agent.async_runner --config experiment_configs/main_exp.json --workers 5

  # Dry run to see what would run
  python -m agent.async_runner --config experiment_configs/main_exp.json --dry-run

  # Resume interrupted run
  python -m agent.async_runner --resume results/async_results_20250101.json --workers 5
        """
    )
    parser.add_argument(
        "--config", type=str, default="experiment_configs/default.json",
        help="Experiment config file"
    )
    parser.add_argument(
        "--workers", type=int, default=5,
        help="Max concurrent sessions (default: 5)"
    )
    parser.add_argument(
        "--max-turns", type=int, default=default_max_turns,
        help=f"Max turns per session (default: {default_max_turns} from config/agent.json)"
    )
    parser.add_argument(
        "--base-url", type=str, default="http://localhost:8000",
        help="Backend URL"
    )
    parser.add_argument(
        "--dry-run", action="store_true",
        help="Show what would run without running"
    )
    parser.add_argument(
        "--resume", type=str, default=None,
        help="Resume from existing results file or directory"
    )
    parser.add_argument(
        "--output-dir", type=str, default="experiment_results",
        help="Output directory for results (default: experiment_results)"
    )
    parser.add_argument(
        "--no-delete-failed", action="store_true",
        help="Don't delete failed sessions from backend when resuming"
    )
    parser.add_argument(
        "--no-clear", action="store_true",
        help="Don't backup and clear existing sessions before starting (default: clear with backup)"
    )
    return parser


def _resolve_resume_output_path(resume_path: Path) -> Path:
    """Return the results file a resumed run writes to: the same file it was loaded from."""
    if not resume_path.is_dir():
        return resume_path

    # Same lookup order as load_existing_results(), so the next --resume on
    # this directory sees the progress made by this one.
    results_file = resume_path / "results.json"
    if results_file.exists():
        return results_file
    async_files = list(resume_path.glob("async_results_*.json"))
    if async_files:
        return max(async_files, key=lambda p: p.stat().st_mtime)
    return results_file


def _print_dry_run(args, experiments, models, all_tasks, tasks, completed_tasks) -> None:
    """Print what a real run would execute, without running anything."""
    print(f"\n{'='*60}")
    print(f"DRY RUN - Would run {len(tasks)} sessions")
    print(f"{'='*60}")
    print(f"Workers: {args.workers}")
    print(f"Max turns: {args.max_turns}")
    print(f"Backend: {args.base_url}")
    print()

    for exp in experiments:
        total = sum(1 for t in all_tasks if t[0] == exp["name"])
        remaining = sum(1 for t in tasks if t[0] == exp["name"])
        if completed_tasks:
            print(f"  {exp['name']}: {remaining}/{total} tasks remaining")
        else:
            print(f"  {exp['name']}: {total} tasks")

    print()
    print("Models:")
    for model in models:
        runs = model.get("runs", 1)
        print(f"  {model['name']}: {runs} run(s)")


def _print_run_summary(results, experiments, duration: float, sweep_dir) -> None:
    """Print overall and per-experiment statistics for the sessions of this run."""
    success_count = sum(1 for r in results if r.success)
    victory_count = sum(1 for r in results if r.victory)
    error_count = sum(1 for r in results if r.error)

    print(f"\n{'='*60}")
    print("ASYNC RUN COMPLETE")
    print(f"{'='*60}")
    print(f"Total time: {duration:.1f}s ({duration/60:.1f}m)")
    print(f"Sessions: {len(results)}")
    print(f"Success: {success_count}")
    print(f"Victory: {victory_count}")
    print(f"Errors: {error_count}")
    if duration > 0:
        print(f"Throughput: {len(results)/duration*60:.1f} sessions/min")
    print(f"Sweep folder: {sweep_dir}")
    print(f"{'='*60}")

    # Print per-experiment summary
    print("\nPer-experiment summary:")
    for exp in experiments:
        exp_results = [r for r in results if r.experiment == exp["name"]]
        exp_victories = sum(1 for r in exp_results if r.victory)
        exp_success = sum(1 for r in exp_results if r.success)
        print(f"  {exp['name']}: {exp_victories}/{exp_success} victories")


async def main():
    """
    CLI entry point: parse arguments, build the task list, run it, and report.

    Flow:
      1. Parse arguments (the --max-turns default comes from config/agent.json).
      2. With --resume, load previous results and skip completed tasks.
      3. Expand the config into (experiment, model, run_number) tasks.
      4. With --dry-run, print the plan and stop.
      5. For a new run, create a sweep folder and, unless --no-clear is given,
         back up and clear existing sessions. For a resumed run, reuse the
         existing folder.
      6. Run all sessions with incremental saving, then print a summary.
    """
    # Load default max_turns from agent config
    agent_config = load_agent_config()
    default_max_turns = agent_config.get("execution", {}).get("max_turns", 100)

    parser = _build_arg_parser(default_max_turns)
    args = parser.parse_args()

    # A non-positive worker count cannot run anything; fail with a usage error
    # instead of an exception from the thread pool later on.
    if args.workers < 1:
        parser.error("--workers must be at least 1")

    # Load existing results if resuming
    existing_data = {}
    completed_tasks = set()
    output_path = None

    if args.resume:
        existing_data, completed_tasks = load_existing_results(
            args.resume,
            base_url=args.base_url,
            delete_failed=not args.no_delete_failed  # Delete failed sessions so they can be re-run
        )
        # Write back to the file the results were loaded from.
        output_path = _resolve_resume_output_path(Path(args.resume))

    # Parse experiment config
    experiments, models = parse_experiment_config(args.config)

    if not experiments:
        logger.error("No experiments enabled in config!")
        return

    if not models:
        logger.error("No models enabled in config!")
        return

    # Collect all tasks
    all_tasks = []
    for exp in experiments:
        for model in models:
            runs = model.get("runs", 1)
            for run_num in range(1, runs + 1):
                all_tasks.append((exp["name"], model["name"], run_num))

    # Filter out completed tasks if resuming
    if completed_tasks:
        tasks = [t for t in all_tasks if t not in completed_tasks]
        logger.info(
            f"Resuming: {len(all_tasks)} total, "
            f"{len(completed_tasks)} completed, "
            f"{len(tasks)} remaining"
        )
    else:
        tasks = all_tasks

    logger.info(
        f"Config: {len(experiments)} experiments x {len(models)} models = "
        f"{len(tasks)} tasks"
    )

    # Dry run mode
    if args.dry_run:
        _print_dry_run(args, experiments, models, all_tasks, tasks, completed_tasks)
        return

    if not tasks:
        print("All tasks already completed!")
        return

    # Create sweep folder (for new runs, not resume)
    sweep_dir = None
    existing_successful = []

    if args.resume:
        # Resuming - use existing sweep folder if it's a directory
        resume_path = Path(args.resume)
        if resume_path.is_dir():
            sweep_dir = resume_path
        else:
            sweep_dir = resume_path.parent
        # Extract existing successful results for merging (must have stage 2 complete)
        existing_successful = [
            r for r in existing_data.get("results", [])
            if r.get("success") is True and (r.get("victory") is not None or r.get("survival_rate") is not None)
        ]
        logger.info(f"Resuming with {len(existing_successful)} existing successful results (with stage 2)")
    else:
        # New run - create sweep folder and backup/clear sessions
        output_dir = Path(args.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        # Extract sweep name from config
        with open(args.config) as f:
            config_data = json.load(f)
        sweep_name = config_data.get("sweep_name", "default")
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

        sweep_dir = output_dir / f"sweep_{sweep_name}_{timestamp}"
        sweep_dir.mkdir(parents=True, exist_ok=True)
        (sweep_dir / "sessions").mkdir(exist_ok=True)
        (sweep_dir / "logs").mkdir(exist_ok=True)

        logger.info(f"Created sweep folder: {sweep_dir}")

        # Backup and clear existing sessions (unless --no-clear)
        if not args.no_clear:
            backup_and_clear_sessions(args.base_url, sweep_name)

    # Set output path
    if output_path is None:
        output_path = sweep_dir / "results.json"

    # Run all sessions with incremental saving
    start_time = datetime.now()

    results = await run_all_sessions(
        tasks=tasks,
        base_url=args.base_url,
        max_concurrent=args.workers,
        max_turns=args.max_turns,
        sweep_dir=sweep_dir,
        output_path=output_path,
        config_path=args.config,
        existing_successful=existing_successful
    )

    duration = (datetime.now() - start_time).total_seconds()

    _print_run_summary(results, experiments, duration, sweep_dir)


# Script entry point: when executed directly, run the main() coroutine to
# completion on a new event loop.
if __name__ == "__main__":
    asyncio.run(main())
