#!/usr/bin/python3
"""
LLM Entropy-guided Optimization with kNowledgeable priors (LEON).

Author(s):
    Anonymized Authors @anonymized-authors

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
import os
from importlib import import_module
from importlib.metadata import version
from math import isclose
from typing import Any, Dict, Optional, Union

import gymnasium as gym
from gymnasium import pprint_registry
from gymnasium.envs.registration import register, registry, make as _make

from . import (
    core, data, embedding, envs, knowledge, model, optim, optimize, utils
)


def _installed_version() -> str:
    """
    Look up the installed distribution version of this package.
    Returns:
        The version string, or "unknown" if no installed distribution matches.
    """
    try:
        return version(__name__)
    except ImportError:
        # importlib.metadata.PackageNotFoundError derives from ImportError, so
        # this branch covers e.g. running straight from a source checkout.
        return "unknown"


__version__ = _installed_version()


# Public API exported by `from leon import *`.
__all__ = [
    # Subpackages.
    "core", "data", "embedding", "envs", "knowledge",
    "model", "optim", "optimize", "utils",
    # Gymnasium registry helpers plus this package's `make` factory.
    "pprint_registry", "make", "registry",
    # Package metadata.
    "__version__",
]


# Unless LEON_RETAIN_REGISTRY is set to something other than the exact string
# "False", wipe Gymnasium's global registry so that only the tasks registered
# below remain visible. Note that this also drops every environment registered
# before this import (including any registered by Gymnasium itself).
# NOTE(review): values such as "false" or "0" count as "retain", since only
# the exact string "False" triggers the clear.
_RETAIN_REGISTRY = os.environ.get("LEON_RETAIN_REGISTRY", "False") != "False"
if not _RETAIN_REGISTRY:
    registry.clear()


# Warfarin dosing task backed by the IWPC dataset: the train dataset uses the
# "white" split and the test dataset the "non-white" split (see `make`).
# Episodes are single-step and Gymnasium's env checker is skipped.
register(
    id="IWPCWarfarin-v0",
    entry_point="leon.envs:WarfarinDosingTask",
    kwargs=dict(
        dataset="leon.data:IWPCWarfarinDataset",
        train_split="white",
        test_split="non-white",
    ),
    max_episode_steps=1,
    disable_env_checker=True,
)


# HIV medication task backed by the HIVDB dataset. The splits are passed to the
# dataset constructor in `make`; they are presumably (start, end) year ranges.
# NOTE(review): the test range (2002, 2020) contains the train range
# (2002, 2008); confirm this overlap is intended.
# Episodes are single-step and Gymnasium's env checker is skipped.
register(
    id="HIVDB-v0",
    entry_point="leon.envs:HIVMedicationTask",
    kwargs=dict(
        dataset="leon.data:HIVDBDataset",
        train_split=(2002, 2008),
        test_split=(2002, 2020),
    ),
    max_episode_steps=1,
    disable_env_checker=True,
)


def make(
    task_name: str,
    seed: int = 2025,
    relabel: bool = True,
    online: Union[float, bool] = False,
    dataset_kwargs: Optional[Dict[str, Any]] = None,
    env_kwargs: Optional[Dict[str, Any]] = None
) -> envs.BaseTask:
    """
    Create an offline (or online) optimization task.
    Input:
        task_name: the registered ID of the task to create (e.g.,
            "IWPCWarfarin-v0" or "HIVDB-v0").
        seed: random seed passed to the task constructor. Default 2025.
        relabel: whether to relabel the train dataset with the task's own
            outputs on every train design. Default True.
        online: whether to make the task an online optimization task. Only a
            value equal to 1.0 (including `True`) builds the train dataset
            from the registered train split; any other value builds it from
            the registered test split. Default False.
        dataset_kwargs: additional keyword arguments passed to both the train
            and the test dataset constructors. Default None (no extras).
        env_kwargs: additional keyword arguments passed to the task
            constructor; they take precedence over the arguments built here.
            Default None (no extras).
    Returns:
        The instantiated (unwrapped) optimization task.
    """
    # `None` sentinels instead of `{}` defaults so no dict object is shared
    # across calls; caller-supplied dicts are only read, never mutated.
    dataset_kwargs = {} if dataset_kwargs is None else dataset_kwargs
    env_kwargs = {} if env_kwargs is None else env_kwargs

    # The registered spec names its dataset class as "module:attribute".
    env_spec = gym.spec(task_name)
    module, attribute = env_spec.kwargs["dataset"].split(":", 1)
    dataset_cls = getattr(import_module(module), attribute)

    test_split = env_spec.kwargs["test_split"]
    # NOTE(review): unless `online` equals 1.0, the train dataset is also
    # built from the *test* split; confirm this is the intended protocol.
    train_split = (
        env_spec.kwargs["train_split"]
        if isclose(float(online), 1.0) else test_split
    )
    kwargs = {
        "task_id": task_name,
        "train": dataset_cls(train_split, **dataset_kwargs),
        "test": dataset_cls(test_split, **dataset_kwargs),
        "seed": seed,
        "online": online,
        **env_kwargs
    }
    # `.unwrapped` strips any wrappers Gymnasium's make applied, exposing the
    # underlying task object.
    task: envs.BaseTask = _make(task_name, **kwargs).unwrapped  # type: ignore
    if relabel:
        # Evaluate the task on every train design (gathered by index) and hand
        # the outputs to `relabel` (presumably replacing the stored labels).
        train_data = task.train
        train_data.relabel(
            task([train_data[i] for i in range(len(train_data))])
        )
    # `mask_designs()` is called on the test dataset unless DEBUG is exactly
    # "True" (presumably hiding test designs from the optimizer).
    if os.getenv("DEBUG", "False") != "True":
        task.test.mask_designs()
    return task
