#!/usr/bin/python3
"""
Base oracle and environment implementation and API.

Author(s):
    Anonymized Authors @anonymized-authors

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
import abc
import gymnasium as gym
import torch
from math import isclose
from pydantic import BaseModel
from typing import (
    Any, Dict, Final, Generic, List, NamedTuple, Tuple, Type, TypeVar, Union
)

from ..data import BaseDataset


# Type variable for the dataset type held by a task (the `train` / `test`
# attributes of `BaseTask`); bound to `BaseDataset`.
T = TypeVar("T", bound=BaseDataset)


class BaseTask(abc.ABC, gym.Env, Generic[T]):
    """
    Abstract base class for an optimization task that is exposed both as an
    oracle and as a single-step Gymnasium environment.

    A task holds two predictive models:
        - a train (surrogate) model, queried via `__call__`, and
        - a test (oracle) model, queried via `predict` / `step`.
    When the task is offline (falsy `online`), the oracle may be queried only
    once successfully.
    """

    def __init__(
        self,
        task_name: str,
        train: T,
        test: T,
        train_model: Any,
        test_model: Any,
        seed: int,
        online: Union[float, bool] = False,
        **kwargs: Dict[str, Any]
    ):
        """
        Args:
            task_name: the name of the task associated with the oracle.
            train: a training dataset to fit the train model to.
            test: a test dataset to fit the test model to.
            train_model: the predictive model trained on the training dataset.
            test_model: the predictive model trained on the test dataset.
            seed: random seed.
            online: whether to make the task an online optimization task.
                `True` (or a float close to 1.0) makes `__call__` use the test
                model instead of the train model. Any falsy value limits the
                oracle to a single successful `predict` call; any truthy value
                allows repeated calls.
            kwargs: additional attributes to set on the task instance. They
                are applied last, so they may override the `mu` / `std`
                defaults.
        """
        self.task_name: Final[str] = task_name
        self.seed: Final[int] = seed
        self.online: Final[Union[float, bool]] = online
        self.train: T = train
        self.test: T = test
        self._test_model: Final[Any] = test_model
        # A fully online task (True, or a float ~1.0) exposes the oracle itself
        # as the surrogate. Since `float(True) == 1.0` and
        # `float(False) == 0.0`, a single `isclose` check covers booleans too.
        if isclose(float(self.online), 1.0):
            _train_model = test_model
        else:
            _train_model = train_model
        self._train_model: Final[Any] = _train_model

        # Defaults are assigned BEFORE applying `kwargs` so that caller-
        # supplied values (e.g., `mu=` / `std=`) are not silently clobbered.
        self.__is_dirty: bool = False
        self.mu: Dict[str, float] = NotImplemented
        self.std: Dict[str, float] = NotImplemented

        for key, val in kwargs.items():
            setattr(self, key, val)

    def step(
        self, x: NamedTuple
    ) -> Tuple[NamedTuple, float, bool, bool, Dict[str, Any]]:
        """
        Runs a single time step of the design interaction with the environment
        consistent with the Gymnasium API.
        Input:
            x: the design to evaluate.
        Returns:
            observation: the original input design.
            reward: the reward associated with the design.
            terminated: always True since the oracle function is always myopic.
            truncated: always False since there are no truncation conditions.
            info: a dictionary with optional auxiliary diagnostic information.
        Raises:
            RuntimeError: if the task is offline and the oracle has already
                been queried.
        """
        return x, self.predict([x])[0], True, False, {}

    @torch.no_grad()
    def predict(self, x: Any, **kwargs: Dict[str, Any]) -> List[float]:
        """
        Evaluate a set of designs according to the test function.
        Input:
            x: the set of designs to evaluate.
            kwargs: forwarded unchanged to the test model's `predict`.
        Returns:
            The oracle scores associated with the designs.
        Raises:
            RuntimeError: if the task is offline and the oracle has already
                been queried successfully.
        """
        if self.__is_dirty:
            raise RuntimeError(
                "The oracle has already been queried; offline tasks permit "
                "only a single call to `predict`."
            )
        scores = self._test_model.predict(x, **kwargs)
        # Consume the single offline query only once the oracle has actually
        # produced scores, so a failed call (e.g., malformed input) does not
        # permanently lock the task.
        self.__is_dirty = not self.online
        return scores

    @torch.no_grad()
    def __call__(self, x: Any, **kwargs: Dict[str, Any]) -> List[float]:
        """
        Evaluate a set of designs according to the training function. For a
        fully online task this is the test model.
        Input:
            x: the set of designs to evaluate.
            kwargs: forwarded unchanged to the train model's `predict`.
        Returns:
            The predicted scores associated with the designs.
        """
        return self._train_model.predict(x, **kwargs)

    @property
    @abc.abstractmethod
    def design_schema(self) -> Type[BaseModel]:
        """
        The Pydantic model for the designs proposed by an optimizer.
        Input:
            None.
        Returns:
            A Pydantic model type for the designs proposed by an optimizer.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def task_description(self, ablate_distribution_shift: bool = False) -> str:
        """
        The task description for the task.
        Input:
            ablate_distribution_shift: whether to ablate knowledge of the
                distribution shift.
        Returns:
            A string of the task description.
        """
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def ndim(self) -> int:
        """
        The number of dimensions of the design space (not including the fixed
        covariates).
        Input:
            None.
        Returns:
            An integer of the number of dimensions of the design space.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def extend(self, x: List[BaseModel], ref: Any) -> List[Any]:
        """
        Extend a design or set of designs to include the fixed covariates.
        Input:
            x: the design or set of designs to extend.
            ref: the reference design to extend from.
        Returns:
            A list of designs with the fixed covariates included.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def reduce(self, x: List[NamedTuple]) -> List[Any]:
        """
        Reduce a design or set of designs to exclude the fixed covariates.
        Input:
            x: the design or set of designs to reduce.
        Returns:
            A list of designs with the fixed covariates excluded.
        """
        raise NotImplementedError

    @property
    def sampling_bounds(self) -> torch.Tensor:
        """
        The sampling bounds for the task, computed from the training dataset
        on every access.
        Input:
            None.
        Returns:
            A detached tensor whose first row holds the per-dimension minima
            and whose second row holds the per-dimension maxima over the
            training designs. (Assumes each `as_tensor()` is a 1D or 1 x D
            tensor -- TODO confirm.) The training dataset must be non-empty.
        """
        designs = torch.vstack([
            self.train[i].as_tensor() for i in range(len(self.train))
        ])
        lower = designs.min(dim=0).values
        upper = designs.max(dim=0).values
        return torch.vstack([lower, upper]).detach()

    @property
    def disease_name(self) -> str:
        """
        The name of the disease associated with the task. Not abstract:
        subclasses that have an associated disease override this; the default
        raises.
        Input:
            None.
        Returns:
            A string of the disease name.
        Raises:
            NotImplementedError: if the subclass does not override it.
        """
        raise NotImplementedError
