from copy import deepcopy
import asyncio

from agents.base_agent import BaseAgent
from envs.base_env import BaseEnv
from utils.helper import load_perception_config
from utils.logger import PerceptionLogger, PerceptionInfo

async def run_predict(predict_id, config, env, dataset_i, logger):
    """Run one perception prediction and hand the outcome to the logger.

    A new agent is built from ``config["agent"]`` for every call, asked to
    perceive the sample's prompt, scored by the environment against the
    sample's label, and the result is logged under ``predict_id``.

    Args:
        predict_id: Identifier passed through to ``logger.log_predict``.
        config: Mapping with an ``"agent"`` entry accepted by
            ``BaseAgent.from_config``.
        env: Object exposing ``schema`` and ``get_perception_reward``.
        dataset_i: Mapping for a single sample with ``'prompt'`` and
            ``'label'`` entries.
        logger: Object exposing ``log_predict(predict_id, info)``.

    Returns:
        None. The result is only recorded through ``logger``.
    """
    agent = BaseAgent.from_config(config["agent"])

    # Build the coroutine first, then await it; the agent returns a pair of
    # (raw response, agent-side info).
    pending = agent.perception(dataset_i['prompt'], env.schema)
    response, info = await pending

    score = env.get_perception_reward(response, dataset_i['label'])
    record = PerceptionInfo(response, dataset_i['label'], info, score)
    logger.log_predict(predict_id, record)

async def run_perception(exp_name, model=None):
    """Run the perception experiment ``exp_name`` over its whole dataset.

    Loads the config and dataset with ``load_perception_config``, builds a
    single environment that every prediction shares, runs ``run_predict``
    once per sample index, then saves the results. The logger is always
    closed, even when a prediction or the save step raises.

    Args:
        exp_name: Experiment name, forwarded to ``load_perception_config``
            and ``PerceptionLogger``.
        model: Optional value forwarded to ``load_perception_config``.

    Raises:
        ValueError: In async mode, if ``config['experiment']['batch_size']``
            is smaller than 1 (a negative value would otherwise silently
            process nothing).
    """
    config, dataset = load_perception_config(exp_name, model)
    total = len(dataset)
    logger = PerceptionLogger(config, exp_name, total)
    try:
        # Async mode is the default when the experiment config omits the key.
        async_mode = config["experiment"].get("async_mode", True)
        # The image directory is blanked before the environment is built.
        config["environment"]["params"]["image_dir"] = ""
        env = BaseEnv.from_config(config["environment"])

        # Samples are looked up by the stringified index, so ``dataset`` is
        # indexed with str(idx) rather than iterated directly.
        if async_mode:
            batch_size = config['experiment']['batch_size']
            if batch_size < 1:
                raise ValueError(
                    f"batch_size must be a positive integer, got {batch_size!r}"
                )
            for start in range(0, total, batch_size):
                stop = min(start + batch_size, total)
                # Each prediction gets its own deep copy of the config, while
                # the environment and logger are shared across the batch.
                tasks = [
                    run_predict(idx, deepcopy(config), env, dataset[str(idx)], logger)
                    for idx in range(start, stop)
                ]
                # Wait for the whole batch before starting the next one.
                await asyncio.gather(*tasks)
        else:
            for idx in range(total):
                await run_predict(idx, deepcopy(config), env, dataset[str(idx)], logger)

        # Results are saved only after every prediction has succeeded.
        logger.save_results()
    finally:
        # Release the logger no matter how the run ended.
        logger.close()
