import logging
import random
from copy import deepcopy
from pathlib import Path

import numpy as np
from agentlab.agents.agent_args import AgentArgs
from agentlab.agents.generic_agent.generic_agent import GenericAgentArgs
from agentlab.agents.generic_agent_hinter.generic_agent import (
    GenericAgentArgs as GenericAgentHinterArgs,
)
from agentlab.agents.tool_use_agent import TaskHint, ToolUseAgentArgs
from agentlab.experiments.study import Study
from bgym import DEFAULT_BENCHMARKS
from browsergym.experiments.benchmark.metadata.utils import task_list_from_metadata, task_metadata
from browsergym.experiments.benchmark.utils import make_env_args_list_from_repeat_tasks
from flask.cli import load_dotenv

from finetuning.src.finetuning.core.benchmarks_splits import TASKS_MINIWOB
from jephhinter.configs import AgentLabRunConfig

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
load_dotenv()


class AgentLabRun:
    """Configure and launch an AgentLab study for one agent on one benchmark.

    All settings come from an ``AgentLabRunConfig``. Construction prepares the
    agent args (``self.agent_args``) and the benchmark (``self.benchmark``);
    ``run()`` then executes a new study or relaunches the most recent one.
    """

    def __init__(self, config: AgentLabRunConfig):
        logger.info("Initializing AgentLab run with config: %s", config)
        self.config = config
        self._configure_agent()
        self._configure_benchmark()

    def _configure_agent(self):
        """Copy ``config.agent_args`` and apply this run's task-hint settings.

        The agent args are deep-copied so the object held by the config is not
        mutated. The configured copy is stored as a one-element list in
        ``self.agent_args``.

        Raises:
            ValueError: if the agent args are neither ``ToolUseAgentArgs`` nor the
                hinter variant of ``GenericAgentArgs``.
        """
        agent_args = deepcopy(self.config.agent_args)
        if isinstance(agent_args, ToolUseAgentArgs):
            agent_args.config.obs.use_som = False
            agent_args.config.multiaction = False
            agent_args.config.task_hint = TaskHint(use_task_hint=self.config.use_task_hint)
            # Hint DB location and retrieval mode are only set when hints are enabled.
            if self.config.use_task_hint:
                agent_args.config.task_hint.hint_db_rel_path = self.config.hint_db_path
                agent_args.config.task_hint.hint_retrieval_mode = self.config.hint_mode
        elif isinstance(agent_args, GenericAgentHinterArgs):
            agent_args.flags.hint_db_path = self.config.hint_db_path  # type: ignore
            agent_args.flags.use_task_hint = self.config.use_task_hint
            agent_args.flags.task_hint_retrieval_mode = self.config.hint_mode  # type: ignore
        else:
            raise ValueError(f"Unsupported agent args type: {type(agent_args)}")
        logger.info("Configured agent args: %s", agent_args)
        self.agent_args: list[AgentArgs] = [agent_args]

    @staticmethod
    def _make_env_args_list(metadata_name: str, max_steps, n_repeats):
        """Build env args for every task listed in BrowserGym metadata ``metadata_name``.

        A fresh ``RandomState(42)`` is passed as ``seeds_rng`` on every call, so
        seed generation is deterministic and identical across benchmarks/runs.
        """
        return make_env_args_list_from_repeat_tasks(
            task_list=task_list_from_metadata(metadata=task_metadata(metadata_name)),
            max_steps=max_steps,
            n_repeats=n_repeats,
            seeds_rng=np.random.RandomState(42),
        )

    def _configure_benchmark(self):
        """Instantiate ``config.benchmark_name`` and rebuild its environment list.

        The benchmark comes from ``DEFAULT_BENCHMARKS``. Its ``env_args_list`` is
        regenerated from task metadata with this run's step/repeat settings,
        optionally restricted to ``selected_tasks``, and every environment is
        forced headless.

        Raises:
            ValueError: for benchmark names other than ``miniwob``,
                ``workarena_l1`` and ``webarena_lite`` (the ``DEFAULT_BENCHMARKS``
                lookup may fail first for names it does not know).
        """
        benchmark_name = self.config.benchmark_name
        self.benchmark = DEFAULT_BENCHMARKS[benchmark_name]()
        # Optional whitelist of task names; empty means "run every task".
        # Bound up front so every branch below leaves it defined (previously
        # the workarena_l1 branch hit an UnboundLocalError).
        selected_tasks: list[str] = []
        if benchmark_name == "miniwob":
            self.benchmark.env_args_list = self._make_env_args_list(
                "miniwob", self.config.max_steps, self.config.n_repeats
            )
            # selected_tasks = [
            #     "miniwob.book-flight",
            #     "miniwob.count-shape",
            #     "miniwob.form-sequence-2",
            #     "miniwob.number-checkboxes",
            #     "miniwob.search-engine",
            #     "miniwob.stock-market",
            #     "miniwob.use-colorwheel-2",
            #     "miniwob.bisect-angle",
            #     "miniwob.click-menu",
            #     "miniwob.click-scroll-list",
            #     "miniwob.daily-calendar",
            #     "miniwob.drag-items-grid",
            #     "miniwob.grid-coordinate",
            #     "miniwob.hot-cold",
            #     "miniwob.right-angle",
            #     "miniwob.social-media-all",
            # ]
            # selected_tasks = TASKS_MINIWOB
        elif benchmark_name == "workarena_l1":
            self.benchmark.env_args_list = self._make_env_args_list(
                "workarena", self.config.max_steps, self.config.n_repeats
            )
            # selected_tasks = [
            #     "workarena.servicenow.all-menu",
            #     "workarena.servicenow.create-hardware-asset",
            #     "workarena.servicenow.create-incident",
            #     "workarena.servicenow.filter-asset-list",
            #     "workarena.servicenow.filter-change-request-list",
            #     "workarena.servicenow.filter-hardware-list",
            #     "workarena.servicenow.filter-incident-list",
            #     "workarena.servicenow.filter-service-catalog-item-list",
            #     "workarena.servicenow.filter-user-list",
            #     "workarena.servicenow.knowledge-base-search",
            #     "workarena.servicenow.order-apple-mac-book-pro15",
            #     "workarena.servicenow.order-apple-watch",
            #     "workarena.servicenow.order-developer-laptop",
            #     "workarena.servicenow.order-development-laptop-p-c",
            #     "workarena.servicenow.order-ipad-mini",
            #     "workarena.servicenow.order-ipad-pro",
            #     "workarena.servicenow.order-sales-laptop",
            #     "workarena.servicenow.sort-asset-list",
            #     "workarena.servicenow.sort-change-request-list",
            #     "workarena.servicenow.sort-hardware-list",
            #     "workarena.servicenow.sort-incident-list",
            #     "workarena.servicenow.sort-service-catalog-item-list",
            #     "workarena.servicenow.sort-user-list",
            #     "workarena.servicenow.multi-chart-value-retrieval",
            #     "workarena.servicenow.single-chart-value-retrieval",
            #     "workarena.servicenow.single-chart-min-max-retrieval",
            # ]
        elif benchmark_name == "webarena_lite":
            # NOTE(review): max_steps is pinned to 30 here rather than taken from
            # config.max_steps, and n_repeats falls back to 1 when falsy; the
            # original author noted the seed RNG is not relevant for this
            # benchmark. Confirm the hard-coded step budget is intended.
            self.benchmark.env_args_list = self._make_env_args_list(
                "webarenalite", 30, self.config.n_repeats or 1
            )
        else:
            raise ValueError(f"Unsupported benchmark: {benchmark_name}")
        if selected_tasks:
            self.benchmark = self.benchmark.subset_from_list(selected_tasks)
        for env_args in self.benchmark.env_args_list:
            env_args.headless = True

    def run(self):
        """Run (or relaunch) the study and log its summary.

        With ``config.relaunch`` the most recent study under ``config.exp_root``
        is loaded and ``find_incomplete(include_errors=True)`` is called on it
        (presumably re-queuing unfinished and errored experiments). Otherwise a
        new study is built from ``self.agent_args`` and ``self.benchmark``. In
        reproducibility mode the results are also appended to the journal.
        """
        if self.config.reproducibility_mode:
            # NOTE(review): self.agent_args is not passed to a relaunched study
            # (it is loaded from disk), so this only affects a fresh study.
            for agent_args in self.agent_args:
                agent_args.set_reproducibility_mode()
        exp_root = Path(self.config.exp_root)
        if self.config.relaunch:
            study = Study.load_most_recent(root_dir=exp_root, contains=None)
            study.find_incomplete(include_errors=True)
        else:
            study = Study(self.agent_args, self.benchmark, logging_level_stdout=logging.WARNING)
        logger.info(
            "Running study with %d environments and %d agents: %s",
            len(self.benchmark.env_args_list),
            len(study.agent_args),
            study.agent_args,
        )
        study.run(
            n_jobs=self.config.n_jobs,
            parallel_backend=self.config.backend,
            strict_reproducibility=self.config.reproducibility_mode,
            n_relaunch=self.config.n_relaunch,
            exp_root=exp_root,
        )
        _, summary, _ = study.get_results()
        logger.info("Study results:\n%s", summary)
        if self.config.reproducibility_mode:
            study.append_to_journal(strict_reproducibility=True)


# Example usage: run the default configuration with three repeats per task.
if __name__ == "__main__":
    example_config = AgentLabRunConfig(n_repeats=3)
    AgentLabRun(example_config).run()
