import asyncio
import time
from typing import Optional

import typer
from dotenv import load_dotenv
from loguru import logger

from utils.exp import (
    Config,
    build_overrides_from_args,
    load_config,
    process_results,
    setup_experiment,
)
from utils.model_enums import ModelNames

# Populate environment variables from a .env file (python-dotenv).
# NOTE(review): this runs after the `utils.*` imports above, so any
# environment lookups those modules perform at import time will not see
# values from .env -- confirm none do, or move this above those imports.
load_dotenv()

# Typer application object; CLI commands are registered on it with
# `@app.command()` below, and it is invoked from the `__main__` guard.
app = typer.Typer()


async def run_experiment(
    overrides: Optional[dict[str, str]] = None,
    config_path: str = "config.yaml",
    *,
    inter_query_delay: float = 0.2,
):
    """Run one experiment end to end and forward its results to ``process_results``.

    Loads the config (with ``overrides`` applied), builds the experiment via
    ``setup_experiment``, prepares its data, then runs the agent on every item
    yielded by ``experiment.data_iterator()`` one at a time, calling
    ``experiment.cleanup`` after each item.

    Args:
        overrides: Config values to override; ``None`` means no overrides.
        config_path: Path of the config file handed to ``load_config``.
        inter_query_delay: Seconds to pause (without blocking the event loop)
            after each item. Defaults to 0.2, the previously hard-coded value;
            presumably used to throttle downstream calls -- confirm.

    Returns:
        None. The collected results and the elapsed time (seconds) are passed
        to ``process_results``.
    """
    if overrides is None:
        overrides = {}
    config: Config = load_config(config_path, overrides)

    experiment = setup_experiment(config)
    experiment.prepare_data()

    # perf_counter() is monotonic, so the measured duration cannot be skewed
    # by system clock adjustments the way time.time() can.
    start_time = time.perf_counter()
    result_list = []
    for data in experiment.data_iterator():
        result = await experiment.run_agent(data)
        result_list.append(result)
        # assumes run_agent returns a dict-like result that may carry a
        # "query_id" key -- confirm against the experiment implementation.
        experiment.cleanup(query_id=result.get("query_id"))
        await asyncio.sleep(inter_query_delay)
    elapsed_time = time.perf_counter() - start_time
    logger.info(
        f"Experiment {config.dataset_name}/{config.task.name} used: {elapsed_time:.2f} seconds"
    )
    process_results(result_list, config, experiment, elapsed_time=elapsed_time)


@app.command()
def main(
    task: Optional[str] = None,
    num_test: Optional[int] = None,
    model_name: Optional[ModelNames] = None,  # type: ignore
    dataset_name: Optional[str] = None,
    logs_dir: Optional[str] = None,
    result_dir: Optional[str] = None,
    config_path: str = "config.yaml",
):
    # CLI entry point: converts the command-line options into config
    # overrides and drives the async experiment to completion.
    # NOTE: deliberately no docstring -- Typer would render it as the
    # command's --help text, and this rewrite keeps the CLI output unchanged.
    #
    # config_path is not an override; it is forwarded separately.
    cli_options = dict(
        task=task,
        num_test=num_test,
        model_name=model_name,
        dataset_name=dataset_name,
        logs_dir=logs_dir,
        result_dir=result_dir,
    )
    experiment_coro = run_experiment(
        build_overrides_from_args(**cli_options), config_path=config_path
    )
    asyncio.run(experiment_coro)


# Script entry point: hand control to the Typer app only when this file is
# executed directly (not when imported). The only command registered in
# this file is `main`.
if __name__ == "__main__":
    app()
