#!/usr/bin/python3
"""
HIV medication combination oracle.

Author(s):
    Anonymized Authors @anonymized-authors

Citation(s):
    [1] Tang MW, Liu TF, Shafer RW. The HIVdb system for HIV-1 genotypic
        resistance interpretation. Intervirol 55(2): 98-101. (2012).
        doi: 10.1159/000331998
    [2] de Oliveira T, Shafer RW, Seebregts C. Public database for HIV drug
        resistance in southern Africa. Nature 464(7289): 673. (2010).
        doi: 10.1038/464673c

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
import os
import numpy as np
import pandas as pd
import pickle
import torch
import torch_frame  # type: ignore
from pandas.api.types import is_numeric_dtype
from pandas.core.indexes.base import Index
from pathlib import Path
from pydantic import BaseModel, ConfigDict, create_model
from sklearn.neural_network import MLPRegressor  # type: ignore
from torch_frame.data import Dataset  # type: ignore
from typing import Any, Dict, Final, List, Optional, Type, Union

from .base import BaseTask
from ..data import HIVDBDataset, HIVDBPatient
from ..data.utils import HIVDB_FEATURES
from ..model import TabularResNet


class HIVMedicationTask(BaseTask[HIVDBDataset]):
    """
    Optimization task for proposing HIV medication regimens.

    At construction time an ``MLPRegressor`` is fitted on (or loaded for) the
    train split and a ``TabularResNet`` is fitted on (or loaded for) the test
    split; both are handed to the base class as the train model and the test
    model respectively. Designs are boolean indicators over the medications
    listed in ``HIVDB_FEATURES["medications"]``.
    """

    def __init__(
        self,
        task_id: str,
        train: HIVDBDataset,
        test: HIVDBDataset,
        seed: int,
        online: Union[float, bool] = False,
        cache_dir: Optional[Union[Path, str]] = (
            Path.home() / ".cache" / "leon" / "checkpoints"
        ),
        **kwargs: Any
    ):
        """
        Args:
            task_id: the string ID of the task.
            train: a dataset of the train patients to fit the train model to.
            test: a dataset of the test patients to fit the test model to.
            seed: random seed.
            online: whether to make the task an online optimization task.
            cache_dir: the directory to cache the fitted models in. If None,
                the models are always refitted and never written to disk.
            y_thresh: optional keyword; maximum absolute error of the test
                model on a test patient for that patient to be kept in the
                test set (default 0.1). Pass None to skip the filtering.
            hidden_layer_sizes: optional keyword; hidden layer sizes of the
                train model (default [1024, 1024]).
        """
        # Threshold for the target prediction L1 error.
        self.y_thresh: Optional[float] = kwargs.get("y_thresh", 0.1)
        hidden_layer_sizes = kwargs.get("hidden_layer_sizes", [1024, 1024])

        X_train, y_train = train.data, train.target.astype(float)
        X_test, y_test = test.data, test.target.astype(float)
        self.columns: Final[Index] = X_train.columns

        # Boolean columns are treated as categorical features; every other
        # numeric dtype is treated as a numerical feature.
        self.col_to_stype = {
            key: (
                torch_frame.numerical
                if is_numeric_dtype(val) and str(val) != "bool"
                else torch_frame.categorical
            )
            for key, val in X_train.dtypes.items()
        }
        self.col_to_stype["y_"] = torch_frame.numerical

        # Attach the target on a new frame rather than writing the "y_"
        # column into the frame returned by `test.data`, so that the caller's
        # dataset is never modified as a side effect of building the task.
        test_frame = X_test.assign(y_=y_test)
        test_tf = Dataset(
            test_frame, col_to_stype=self.col_to_stype, target_col="y_"
        )
        test_tf.materialize()

        self.drug_abbrs: Final[Dict[str, str]] = train._drug_abbrs

        # NOTE: the checkpoint file names depend only on the class name, so
        # a checkpoint fitted with a different seed, dataset, or
        # `hidden_layer_sizes` is reused as-is when it is found on disk.
        train_path: Optional[str] = None
        test_path: Optional[str] = None
        if cache_dir is not None:
            stem = self.__class__.__name__
            train_path = os.path.join(
                str(cache_dir), f"{stem}_train_model.pkl"
            )
            test_path = os.path.join(
                str(cache_dir), f"{stem}_test_model.pt"
            )

        if train_path is not None and os.path.exists(train_path):
            # NOTE(security): unpickling executes arbitrary code. Only point
            # `cache_dir` at a directory whose contents are trusted.
            with open(train_path, "rb") as f:
                train_model = pickle.load(f)
        else:
            train_model = MLPRegressor(
                random_state=(seed + sum(map(ord, "train"))),
                hidden_layer_sizes=hidden_layer_sizes
            )
            train_model.fit(X_train, y_train)
            if train_path is not None:
                os.makedirs(os.path.dirname(train_path), exist_ok=True)
                with open(train_path, "wb") as f:
                    pickle.dump(train_model, f)

        if test_path is not None and os.path.exists(test_path):
            test_model = TabularResNet.load(test_path)
        else:
            test_model = TabularResNet(
                target_name=test.target_name,
                col_stats=test_tf.col_stats,
                col_names_dict=test_tf.tensor_frame.col_names_dict,
                channels=256,
                num_layers=4,
                lr=1e-4,
                epochs=200,
                batch_size=512,
                random_state=(seed + sum(map(ord, "test")))
            )
            test_model.fit(test_tf)
            if test_path is not None:
                test_model.save(test_path)

        if self.y_thresh is not None:
            # Keep only the test patients whose observed target the test
            # model reproduces to within `y_thresh`.
            patients = [test[i] for i in range(len(test))]
            ypred = test_model.predict(patients)
            abs_err = np.abs(np.array(ypred) - y_test)
            keep = np.where(abs_err <= self.y_thresh)
            test.filter(keep[0])

        self.__design_schema: Final[Type[BaseModel]] = (
            self._load_design_schema()
        )

        super().__init__(
            task_id,
            train=train,
            test=test,
            train_model=train_model,
            test_model=test_model,
            seed=seed,
            online=online,
            **kwargs
        )

    @torch.no_grad()
    def predict(
        self, x: List[HIVDBPatient], **kwargs: Any
    ) -> List[float]:
        """
        Evaluate a set of designs according to the test function.
        Input:
            x: the set of designs to evaluate.
        Returns:
            The oracle scores associated with the designs.
        Raises:
            ValueError: if any design has no medication list.
        """
        # Raise explicitly instead of using `assert`, which is removed when
        # Python runs with optimizations enabled (-O).
        if any(pt.medication_list is None for pt in x):
            raise ValueError(
                "Every design must have a medication list (got None)."
            )
        return super().predict(x, **kwargs)

    @torch.no_grad()
    def __call__(self, x: List[HIVDBPatient], **kwargs: Any) -> List[float]:
        """
        Evaluate a set of designs according to the training function.
        Input:
            x: the set of designs to evaluate.
        Returns:
            The predicted scores associated with the designs.
        """
        inp = pd.DataFrame([pt.to_dict() for pt in x])  # type: ignore
        # The target column is not a model input; drop it when present.
        target = self.train.target_name
        if target in inp.columns:
            inp = inp.drop(columns=target)
        return super().__call__(inp, **kwargs)

    def _load_design_schema(self) -> Type[BaseModel]:
        """
        Create a design schema for the HIV medication task.
        Input:
            None.
        Returns:
            A Pydantic model for the design, with one boolean field (default
            False) per medication and a human-readable string form.
        """
        model_type = create_model(  # type: ignore
            "HIVMedicationDesign",
            **{med: (bool, False) for med in HIVDB_FEATURES["medications"]},
            __config__=ConfigDict(use_enum_values=True)
        )

        def design2str(design: BaseModel) -> str:
            # Only medications that are switched on and have a known
            # abbreviation entry are listed.
            meds = [
                f"[{abbr}] {self.drug_abbrs[abbr]}"
                for abbr, val in design.model_dump().items()
                if val and abbr in self.drug_abbrs
            ]
            if not meds:
                return "No active therapy"
            if len(meds) == 1:
                return meds[0]
            # Covers two items ("A and B") as well as longer lists
            # ("A, B and C").
            return f"{', '.join(meds[:-1])} and {meds[-1]}"

        model_type.__str__ = design2str
        model_type.__repr__ = design2str
        return model_type

    @property
    def design_schema(self) -> Type[BaseModel]:
        """
        The Pydantic model for the designs proposed by an optimizer.
        Input:
            None.
        Returns:
            A Pydantic model type for the designs proposed by an optimizer.
        """
        return self.__design_schema

    def task_description(self, ablate_distribution_shift: bool = False) -> str:
        """
        The task description for the task.
        Input:
            ablate_distribution_shift: whether to ablate knowledge of the
                distribution shift.
        Returns:
            A string of the task description.
        """
        instruction = (
            "Propose an optimal HIV medication regimen for the patient."
        )
        if ablate_distribution_shift:
            # No prefix: return the bare instruction so that the description
            # does not start with a stray space.
            return instruction
        prefix = (
            "The provided design scores are predictions from a model "
            "trained on patients from older studies (before 2008) only, "
            "and therefore may not be accurate for all patients."
        )
        return f"{prefix} {instruction}"

    @property
    def ndim(self) -> int:
        """
        The number of dimensions of the design space (not including the fixed
        covariates).
        Input:
            None.
        Returns:
            An integer of the number of dimensions of the design space.
        """
        return len(self.train.drugs)

    def extend(
        self, x: List[BaseModel], ref: HIVDBPatient
    ) -> List[HIVDBPatient]:
        """
        Extend a design or set of designs to include the fixed covariates.
        Input:
            x: the design or set of designs to extend.
            ref: the reference design to extend from.
        Returns:
            A list of designs with the fixed covariates included. Each one is
            a copy of `ref` whose medication list is the set of fields that
            are truthy in the corresponding design.
        """
        extended = []
        for design in x:
            # Dump the design once and reuse the resulting mapping.
            flags = design.model_dump()
            active = {name for name, on in flags.items() if on}
            extended.append(ref._replace(medication_list=active))
        return extended

    def reduce(self, x: List[Any]) -> List[BaseModel]:
        """
        Reduce a design or set of designs to exclude the fixed covariates.
        Input:
            x: the design or set of designs to reduce.
        Returns:
            A list of designs with the fixed covariates excluded.
        Raises:
            ValueError: if an element of `x` is not an `HIVDBPatient`, or if
                a patient has no medication list.
        """
        if not all(isinstance(xi, HIVDBPatient) for xi in x):
            raise ValueError(
                "All designs to reduce must be HIVDBPatient instances."
            )
        # A missing medication list would otherwise surface as an opaque
        # TypeError from the membership test below.
        if any(xi.medication_list is None for xi in x):
            raise ValueError(
                "Every design must have a medication list (got None)."
            )

        medications = HIVDB_FEATURES["medications"]

        return [
            self.design_schema(
                **{med: (med in xi.medication_list) for med in medications}
            )
            for xi in x
        ]

    @property
    def disease_name(self) -> str:
        """
        The name of the disease associated with the task.
        Input:
            None.
        Returns:
            A string of the disease name.
        """
        return "HIV"