
from tau_trait.envs.base import Env
from tau_trait.envs.telehealth.data import load_data
from tau_trait.envs.telehealth.rules import RULES
from tau_trait.envs.telehealth.tools import ALL_TOOLS
from tau_trait.envs.telehealth.wiki import WIKI
from typing import Optional, Union, Dict
from tau_trait.envs.user import UserStrategy


class MockTelehealthDomainEnv(Env):
    """Telehealth domain environment built on top of ``Env``.

    Wires the telehealth data loader, tool set, wiki, and rules into the
    base environment and picks the task list matching ``task_split``.
    """

    def __init__(
        self,
        user_strategy: Union[str, UserStrategy] = UserStrategy.LLM,
        user_model: str = "gpt-4o",
        user_provider: Optional[str] = None,
        task_split: str = "test",
        task_index: Optional[int] = None,
        trait_dict: Optional[Dict[str, int]] = None,
        endpoint: Optional[str] = None,
    ):
        """Select the task list for ``task_split`` and initialise the base env.

        Args:
            user_strategy: Forwarded unchanged to ``Env.__init__``.
            user_model: Forwarded unchanged to ``Env.__init__``.
            user_provider: Forwarded unchanged to ``Env.__init__``.
            task_split: One of ``"test"``, ``"train"`` or ``"dev"``; decides
                which telehealth task module is imported.
            task_index: Forwarded unchanged to ``Env.__init__``.
            trait_dict: Forwarded unchanged to ``Env.__init__``.
            endpoint: Forwarded unchanged to ``Env.__init__``.

        Raises:
            ValueError: If ``task_split`` is not a recognised split name.
        """
        # Task modules are imported lazily so that only the requested
        # split's module is loaded.
        if task_split == "test":
            from tau_trait.envs.telehealth.tasks_test import (
                TASKS_TEST as selected_tasks,
            )
        elif task_split == "train":
            from tau_trait.envs.telehealth.tasks_train import (
                TASKS_TRAIN as selected_tasks,
            )
        elif task_split == "dev":
            from tau_trait.envs.telehealth.tasks_dev import (
                TASKS_DEV as selected_tasks,
            )
        else:
            raise ValueError(f"Unknown task split: {task_split}")

        super().__init__(
            data_load_func=load_data,
            tools=ALL_TOOLS,
            tasks=selected_tasks,
            wiki=WIKI,
            rules=RULES,
            user_strategy=user_strategy,
            user_model=user_model,
            user_provider=user_provider,
            trait_dict=trait_dict,
            task_index=task_index,
            endpoint=endpoint,
        )

        # Assigned after the base initialiser has run.
        # NOTE(review): presumably the tool names that end an episode —
        # confirm against how Env consumes ``terminate_tools``.
        self.terminate_tools = ["transfer_to_human_support"]