import asyncio
import json
import yaml
import os
import time
import argparse
from datetime import datetime
from typing import List, Dict

from core.history_manager import HistoryManager
from single import run_single_case

async def safe_run_case(semaphore: asyncio.Semaphore, idx: int, item: Dict, config: Dict, history_manager: HistoryManager, save_dir: str):
    """Run one dataset case under a concurrency limit, logging instead of raising.

    Args:
        semaphore: Shared limiter; the case runs only while a permit is held.
        idx: Index forwarded unchanged to ``run_single_case``.
        item: One dataset record. Its ``meta.id`` (preferred) or ``id`` is
            used only as a label in log lines.
        config: Run configuration, forwarded to ``run_single_case``.
        history_manager: Shared history store, forwarded to ``run_single_case``.
        save_dir: Output directory, forwarded to ``run_single_case``.

    Any ``Exception`` raised by ``run_single_case`` is printed and suppressed,
    so one bad case does not abort the batch. ``BaseException`` subclasses
    that are not ``Exception`` (e.g. ``asyncio.CancelledError``) still
    propagate. A "Finished" line is printed either way.
    """
    # Derive the log label without raising. ``meta`` may be absent, null, or
    # not a mapping; chaining ``.get`` on it would fail before the try below
    # and the case would never run. Explicit None/"" checks (instead of
    # truthiness) keep legitimate falsy ids such as 0.
    meta = item.get("meta") if isinstance(item, dict) else None
    case_id = meta.get("id") if isinstance(meta, dict) else None
    if case_id is None or case_id == "":
        case_id = item.get("id") if isinstance(item, dict) else None
    dataset_id = "unknown" if case_id is None or case_id == "" else str(case_id)

    async with semaphore:
        print(f">>> [Batch] Starting Case ID: {dataset_id}")
        try:
            await run_single_case(idx, item, config, history_manager, save_dir)
        except Exception as e:
            # Include the exception type: many exceptions (e.g. TimeoutError())
            # have an empty message, which made the old log line uninformative.
            print(f"!!! [Batch Error] Case {dataset_id} failed: {type(e).__name__}: {e}")
        finally:
            print(f"<<< [Batch] Finished Case ID: {dataset_id}")

async def run_batch(dataset_path: str, config_path: str, max_concurrency: int, start_idx: int, end_idx: int, instant_time: str):
    """Run every case in ``dataset[start_idx:end_idx]`` concurrently.

    Loads a YAML config and a JSON dataset, builds an output directory of the
    form ``data/<config stem>/<dataset stem>/<start>_<end>_<instant_time>``,
    and runs each selected case through ``safe_run_case`` with at most
    ``max_concurrency`` cases in flight.

    Args:
        dataset_path: Path to the dataset JSON file (sliced, so presumably a list).
        config_path: Path to the YAML config file.
        max_concurrency: Maximum number of cases running at once; must be >= 1.
        start_idx: First dataset index to run (inclusive).
        end_idx: Index one past the last case to run; clipped to the dataset length.
        instant_time: Timestamp string embedded in the output directory name.

    Raises:
        ValueError: If ``max_concurrency`` is less than 1.
    """
    # Validate before touching the filesystem. asyncio.Semaphore(0) never
    # grants a permit, so every case would wait forever. A negative value
    # would only fail after the config and dataset had been loaded.
    if max_concurrency < 1:
        raise ValueError(f"max_concurrency must be >= 1, got {max_concurrency}")

    print(f">>> [System] Loading Config from {config_path}...")
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    print(f">>> [System] Loading Dataset from {dataset_path}...")
    with open(dataset_path, 'r', encoding='utf-8') as f:
        dataset = json.load(f)

    # Clip first so the directory name reflects the range actually processed.
    end_idx = min(end_idx, len(dataset))
    dataset = dataset[start_idx:end_idx]
    print(f">>> [System] Found {len(dataset)} cases. Max concurrency set to: {max_concurrency}")

    # os.path.basename respects the platform's path separator (the old
    # split("/") did not on Windows). Keeping "text before the first dot"
    # preserves the existing directory names.
    config_name = os.path.basename(config_path).split(".")[0]
    dataset_name = os.path.basename(dataset_path).split(".")[0]
    save_dir = os.path.join('data', config_name, dataset_name, f'{start_idx}_{end_idx}_{instant_time}')
    history_manager = HistoryManager(save_dir=save_dir)

    semaphore = asyncio.Semaphore(max_concurrency)

    # NOTE(review): idx is the position within the sliced range (it starts at 0
    # even when start_idx > 0); confirm that is what run_single_case expects.
    tasks = [
        asyncio.create_task(safe_run_case(semaphore, idx, item, config, history_manager, save_dir))
        for idx, item in enumerate(dataset)
    ]

    results = await asyncio.gather(*tasks, return_exceptions=True)
    # safe_run_case handles ordinary case failures itself. Anything returned
    # here escaped it, so report it instead of discarding it silently.
    for idx, result in enumerate(results):
        if isinstance(result, BaseException):
            print(f"!!! [Batch Error] Task #{idx} raised {type(result).__name__}: {result}")

    print("\n" + "="*60)
    print(">>> [Batch] All batch tasks completed.")

def parse_cli_args(argv=None):
    """Parse the batch runner's command-line options.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace`` with ``data``, ``config``,
        ``concurrency``, ``instant_time``, ``start_idx`` and ``end_idx``.

    Exits via ``parser.error`` (SystemExit) when ``--concurrency`` is below 1,
    which would otherwise hang the batch, or ``--start_idx`` is negative.
    """
    parser = argparse.ArgumentParser(description="Async Batch Runner for Organizer Agent")
    parser.add_argument("--data", type=str, default="data/source/toolace_11295.json", help="Path to the dataset JSON file")
    parser.add_argument("--config", type=str, default="configs/80a3_explorer_toolace.yaml", help="Path to the config YAML file")
    # %(default)s keeps the help text in sync with the real default.
    parser.add_argument("--concurrency", type=int, default=100, help="Max concurrent tasks (default: %(default)s)")
    # NOTE(review): this timestamp ends up in the output directory name; ':' is
    # not allowed in Windows paths, so consider another format if that matters.
    parser.add_argument("--instant_time", type=str, default=datetime.now().strftime('%Y-%m-%d %H:%M:%S'), help="Run timestamp used in the output directory name (default: now)")
    parser.add_argument("--start_idx", type=int, default=0, help="Index of the first dataset case to run (inclusive)")
    parser.add_argument("--end_idx", type=int, default=11000, help="Index one past the last dataset case to run (exclusive; clipped to the dataset size)")

    args = parser.parse_args(argv)
    if args.concurrency < 1:
        parser.error("--concurrency must be at least 1")
    if args.start_idx < 0:
        parser.error("--start_idx must be non-negative")
    return args


if __name__ == "__main__":
    # perf_counter is monotonic, so wall-clock adjustments cannot skew the timing.
    start_time = time.perf_counter()
    args = parse_cli_args()

    asyncio.run(run_batch(args.data, args.config, args.concurrency, args.start_idx, args.end_idx, args.instant_time))

    elapsed = time.perf_counter() - start_time
    print(f'{elapsed:>8.2f} s')