#!/usr/bin/python3
"""
IWPC Warfarin Dosing Task.

Author(s):
    Anonymized Authors @anonymized-authors

Citation(s):
    [1] The International Warfarin Pharmacogenetics Consortium. Estimation of
        the warfarin dose with clinical and pharmacogenetic data. New Eng J
        Med 360(8): 753-64. (2009). doi: 10.1056/NEJMoa0809329

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
import json
import os
import numpy as np
import pandas as pd
import torch
from copy import deepcopy
from math import isclose
from pathlib import Path
from pydantic import BaseModel
from tabpfn import TabPFNRegressor  # type: ignore
from tabpfn.model.loading import (  # type: ignore
    load_fitted_tabpfn_model, save_fitted_tabpfn_model
)
from typing import Any, Dict, List, NamedTuple, Optional, Type, Union

from .base import BaseTask
from ..data import IWPCWarfarinDataset, IWPCWarfarinPatient


class WarfarinDoseDesign(BaseModel):
    # The single design variable proposed by an optimizer. A `#` comment is
    # used here (rather than a class docstring) so that the Pydantic schema
    # generated from this model is left untouched.
    warfarin_dose: float

    def __repr__(self) -> str:
        """
        Return the formatted dose followed by its unit.
        Input:
            None.
        Returns:
            The `str()` form of the design with " mg/week" appended.
        """
        return f"{self!s} mg/week"

    def __str__(self) -> str:
        """
        Return the dose formatted as a fixed-point number.
        Input:
            None.
        Returns:
            The warfarin dose rendered with four decimal places.
        """
        return format(self.warfarin_dose, ".4f")


class WarfarinDosingTask(BaseTask[IWPCWarfarinDataset]):
    """
    Warfarin dose optimization task built on the IWPC dataset.

    The train-side surrogate is a TabPFN regressor fitted on the train
    dataset (and optionally cached on disk). The test-side oracle scores a
    dose by its negative absolute distance to the dose returned by the test
    dataset's `_best_dose()`, standardized with the test dataset's `_mu` and
    `_std`.
    """

    def __init__(
        self,
        task_id: str,
        train: IWPCWarfarinDataset,
        test: IWPCWarfarinDataset,
        seed: int,
        online: Union[float, bool] = False,
        cache_dir: Optional[Union[Path, str]] = (
            Path.home() / ".cache" / "leon" / "checkpoints"
        ),
        **kwargs: Any
    ):
        """
        Args:
            task_id: the string ID of the task.
            train: a dataset of the train patients to fit the train model to.
            test: a dataset of the test patients used by the test oracle.
            seed: random seed.
            online: whether to make the task an online optimization task. A
                float value is used by `__call__()` as the weight given to
                the oracle score when blending it with the surrogate score.
            cache_dir: the directory to cache the fitted train model in. If
                None, the model is always re-fitted and never saved.
            **kwargs: forwarded unchanged to the `BaseTask` constructor.
        """
        X_train, y_train = train.data, train.target.astype(float)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # NOTE(review): the cache file name depends only on the class name,
        # so a checkpoint fitted with a different seed or train dataset is
        # reused as-is — confirm that this is intended.
        train_path = None
        if cache_dir is not None:
            train_path = os.path.join(
                str(cache_dir),
                f"{self.__class__.__name__}_train_model.tabpfn_fit"
            )

        if train_path is not None and os.path.exists(train_path):
            train_model = load_fitted_tabpfn_model(
                train_path, device=self.device
            )
        else:
            train_model = TabPFNRegressor(
                device=self.device,
                # Offset the seed by a constant derived from the string
                # "train".
                random_state=(seed + sum(map(ord, "train"))),
                ignore_pretraining_limits=True
            )
            train_model.fit(X_train, y_train)
            if train_path is not None:
                # Make sure the cache directory exists before writing the
                # checkpoint into it.
                os.makedirs(str(cache_dir), exist_ok=True)
                save_fitted_tabpfn_model(train_model, train_path)

        class DummyModel:
            # Oracle stand-in exposing the `predict()` interface. `self`
            # below is the enclosing task (captured by closure), and
            # `self.test` is only read when `predict()` is called —
            # presumably after BaseTask.__init__ has assigned it; confirm.
            def predict(inner_self, x: List[NamedTuple]) -> List[float]:
                scores = []
                for xi in x:
                    best_dose = (
                        self.test._best_dose(xi)  # type: ignore[arg-type]
                    )
                    # Designs without a `warfarin_dose` field are scored as
                    # a dose of 0.0.
                    dose = getattr(xi, "warfarin_dose", 0.0)
                    scores.append(-1.0 * abs(best_dose - dose))
                # Standardize with the statistics stored on the test set.
                return [(sc - self.test._mu) / self.test._std for sc in scores]
        test_model = DummyModel()

        super().__init__(
            task_id,
            train=train,
            test=test,
            train_model=train_model,
            test_model=test_model,
            seed=seed,
            online=online,
            **kwargs
        )

        # Mean and standard deviation of the doses in the train dataset.
        doses = np.array([train[i].warfarin_dose for i in range(len(train))])
        self.mu: Dict[str, float] = {"warfarin_dose": doses.mean()}
        self.std: Dict[str, float] = {"warfarin_dose": doses.std()}

    @torch.no_grad()
    def predict(
        self, x: List[IWPCWarfarinPatient], **kwargs: Any
    ) -> List[float]:
        """
        Evaluate a set of designs according to the test function.
        Input:
            x: the set of designs to evaluate. Every design must have a
                warfarin dose that is not None.
        Returns:
            The oracle scores associated with the designs.
        """
        assert all(pt.warfarin_dose is not None for pt in x), (
            "Every design must have a warfarin dose to be scored."
        )
        return super().predict(x, **kwargs)

    @torch.no_grad()
    def __call__(
        self, x: List[IWPCWarfarinPatient], **kwargs: Any
    ) -> List[float]:
        """
        Evaluate a set of designs according to the training function.
        Input:
            x: the set of designs to evaluate.
        Returns:
            The predicted scores associated with the designs. When `online`
            is a float that is not close to 0.0, each score is the convex
            combination `online * oracle + (1 - online) * surrogate`. An
            empty list is returned for an empty `x` unless the task is fully
            online, in which case the call is delegated to `BaseTask`.
        """
        if isclose(float(self.online), 1.0):
            return super().__call__(x, **kwargs)
        # There is no first element to read the column names from, so an
        # empty batch is answered directly instead of failing below.
        if len(x) == 0:
            return []

        # Keep the untouched designs for the oracle call in the blend below.
        orig_x = deepcopy(x)
        # Collapse the medications into a single "; "-separated string of
        # the sorted entries before building the data frame.
        x = [
            pt._replace(medications="; ".join(sorted(pt.medications)))
            for pt in x
        ]
        inp = pd.DataFrame.from_records(x, columns=x[0]._fields)

        # The surrogate must not see the target column as an input feature.
        if self.train.target_name in inp.columns:
            inp = inp.drop(self.train.target_name, axis=1)

        ypred = super().__call__(inp, **kwargs)
        if not isinstance(self.online, bool) and not isclose(self.online, 0.0):
            # NOTE(review): if BaseTask stores this flag as the private
            # attribute `__is_dirty`, Python mangles it to
            # `_BaseTask__is_dirty`, and the two assignments below would only
            # create an unrelated attribute — confirm against BaseTask.
            self._Base_Task__is_dirty = False
            oracle_scores = self.predict(orig_x, **kwargs)
            ypred = [
                (self.online * _ygt) + ((1.0 - self.online) * _ypred)
                for _ygt, _ypred in zip(oracle_scores, ypred)
            ]
            self._Base_Task__is_dirty = False
        return ypred

    @property
    def design_schema(self) -> Type[BaseModel]:
        """
        The Pydantic model for the designs proposed by an optimizer.
        Input:
            None.
        Returns:
            A Pydantic model type for the designs proposed by an optimizer.
        """
        return WarfarinDoseDesign

    def task_description(self, ablate_distribution_shift: bool = False) -> str:
        """
        The task description for the task.
        Input:
            ablate_distribution_shift: whether to ablate knowledge of the
                distribution shift.
        Returns:
            A string of the task description. The note about the training
            population is prepended unless `ablate_distribution_shift` is
            set; the ablated description carries no leading whitespace.
        """
        instruction = (
            "Propose an optimal warfarin dose (in mg/week) for the patient."
        )
        if ablate_distribution_shift:
            return instruction
        return (
            "The provided design scores are predictions from a model "
            "trained on White patients only, and therefore may not be "
            "accurate for all patients. " + instruction
        )

    @property
    def ndim(self) -> int:
        """
        The number of dimensions of the design space (not including the fixed
        covariates).
        Input:
            None.
        Returns:
            An integer of the number of dimensions of the design space.
        """
        return 1

    def extend(
        self, x: List[BaseModel], ref: IWPCWarfarinPatient
    ) -> List[IWPCWarfarinPatient]:
        """
        Extend a design or set of designs to include the fixed covariates.
        Input:
            x: the design or set of designs to extend.
            ref: the reference design to extend from.
        Returns:
            A list of designs with the fixed covariates included.
        """
        # Each design is round-tripped through its JSON dump and the
        # resulting fields overwrite the matching fields of `ref`.
        return [
            ref._replace(**json.loads(design.model_dump_json()))
            for design in x
        ]

    def reduce(self, x: List[Any]) -> List[BaseModel]:
        """
        Reduce a design or set of designs to exclude the fixed covariates.
        Input:
            x: the design or set of designs to reduce.
        Returns:
            A list of designs with the fixed covariates excluded.
        Raises:
            ValueError: if any element of `x` is not an IWPCWarfarinPatient.
        """
        if any(not isinstance(xi, IWPCWarfarinPatient) for xi in x):
            raise ValueError(
                "reduce() expects every design to be an IWPCWarfarinPatient."
            )
        return [self.design_schema(warfarin_dose=xi.warfarin_dose) for xi in x]

    @property
    def disease_name(self) -> str:
        """
        The name of the disease associated with the task.
        Input:
            None.
        Returns:
            A string of the disease name.
        """
        return "blood clotting"
