#!/usr/bin/python3
"""
International Warfarin Pharmacogenetics Consortium (IWPC) Warfarin Dataset.

Author(s):
    Anonymized Authors @anonymized-authors

Citation(s):
    [1] The International Warfarin Pharmacogenetics Consortium. Estimation of
        the warfarin dose with clinical and pharmacogenetic data. New Eng J
        Med 360(8): 753-64. (2009). doi: 10.1056/NEJMoa0809329
    [2] Truda G, Marais P. Evaluating warfarin dosing models on multiple
        datasets with a novel software framework and evolutionary optimisation.
        J Biomed Inform 113: 103634. (2021). doi: 10.1016/j.jbi.2020.103634

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
import numpy as np
import pandas as pd
import torch
from warfit_learn import datasets, preprocessing  # type: ignore
from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union

from .base import BaseDataset


class IWPCWarfarinPatient(NamedTuple):
    """
    Immutable record describing a single patient, together with an optional
    warfarin dose and an optional fitness score.
    """
    # Patient identifier.
    id_: str
    # Height in centimeters (converted to meters for the BMI in `__str__`).
    height: float
    # Weight in kilograms (used as the BMI numerator in `__str__`).
    weight: float
    # Race label; title-cased when printed.
    race: str
    # Age label; printed verbatim as "age {age} years old".
    age: str
    # Whether the patient is a current smoker.
    is_current_smoker: bool
    # Either a preformatted string or a set of medication names.
    medications: Union[str, Set[str]]
    # CYP2C9 genotype variant label.
    CYP2C9_Consensus: str
    # VKORC1 SNP label.
    Imputed_VKORC1: str
    # Warfarin dose, printed in mg/week; `None` when unavailable.
    warfarin_dose: Optional[float] = None
    # Objective value associated with the patient; `None` when unavailable.
    fitness_score: Optional[float] = None

    def __str__(self) -> str:
        """
        Returns a natural-language description of the patient.
        Input:
            None.
        Returns:
            A string describing the patient's demographics, BMI, medications
            and genotypes, followed by the warfarin dose when it is set.
        """
        height_m = self.height / 100.0
        bmi = self.weight / (height_m * height_m)
        desc = f"{self.race.title()} "
        desc += f"{'current smoker' if self.is_current_smoker else 'patient'} "
        desc += f"(age {self.age} years old) with a BMI of {bmi:.1f}"
        if self.medications:
            if isinstance(self.medications, set):
                # Set iteration order is arbitrary (and varies between runs
                # for strings), so sort to keep the description reproducible.
                meds = ", ".join(sorted(self.medications))
            else:
                meds = self.medications
            desc += f", currently on {meds}"
        desc += f". CYP2C9 Genotype Variant: {self.CYP2C9_Consensus}. "
        desc += f"VKORC1 SNP: {self.Imputed_VKORC1}."

        if self.warfarin_dose is not None:
            desc += f"\n\nWarfarin Dose: {self.warfarin_dose:.1f} mg/week"
        return desc

    def __repr__(self) -> str:
        """
        Returns the same description as `__str__` (this replaces the default
        field-by-field tuple representation).
        Input:
            None.
        Returns:
            A string representation of the patient.
        """
        return str(self)

    def as_tensor(self) -> torch.Tensor:
        """
        Returns the patient's warfarin dose as a tensor.
        Input:
            None.
        Returns:
            A float32 tensor of shape (1,) holding the warfarin dose. All
            other fields are excluded (see `ignored_features`).
        Raises:
            AssertionError: if the warfarin dose is not set.
        """
        assert self.warfarin_dose is not None, "warfarin_dose is not set"
        return torch.tensor([self.warfarin_dose], dtype=torch.float32)

    @classmethod
    def ignored_features(cls) -> List[str]:
        """
        Returns a list of features to ignore.
        Input:
            None.
        Returns:
            The names of every field except `warfarin_dose`.
        """
        return [
            "id_",
            "height",
            "weight",
            "race",
            "age",
            "is_current_smoker",
            "medications",
            "CYP2C9_Consensus",
            "Imputed_VKORC1",
            "fitness_score"
        ]

    @classmethod
    def discrete_features(cls) -> Dict[str, List[Union[str, bool]]]:
        """
        Returns a dictionary of discrete features and their possible values.
        Input:
            None.
        Returns:
            An empty dictionary: no discrete features are declared.
        """
        return {}


class IWPCWarfarinDataset(BaseDataset):
    """
    IWPC warfarin patients restricted to one self-reported race split. Each
    patient is scored by the normalized negative absolute difference between
    the IWPC pharmacogenetic dose prediction and the recorded dose.
    """
    # Column-name prefixes of the one-hot encoded categorical features.
    race_prefix: str = "Race (OMB)_"

    age_prefix: str = "Age_"

    CYP2C9_consensus_prefix: str = "CYP2C9 consensus_"

    imputed_VKORC1_prefix: str = "Imputed VKORC1_"

    # Medications that share a single weight in the dosing model.
    enzyme_inducers: Tuple[str, str, str] = (
        "Carbamazepine", "Phenytoin", "Rifampin"
    )

    # The normalization factors below are calculated according to the
    # distribution of scores from White patients only.
    # NOTE(review): presumably consumed by `BaseDataset.normalize` - confirm.
    _mu: float = 0.0

    _std: float = 9.315738124764447

    def __init__(self, race_split: str, seed: int = 2025, **kwargs):
        """
        Args:
            race_split: the patient self-reported race to filter by. Must be
                one of [`white`, `non-white`].
            seed: random seed used to shuffle the patients.
        Raises:
            AssertionError: if the split is not one of the supported values.
            ValueError: if a patient has no recorded warfarin dose.
        """
        super(IWPCWarfarinDataset, self).__init__(split=race_split, **kwargs)
        self.__mask_designs = False
        # NOTE: asserts are stripped under `-O`; the `NotImplementedError`
        # branch below is the fallback in that case.
        assert self.split in ["white", "non-white"]
        self._dataset = preprocessing.prepare_iwpc(
            datasets.load_iwpc(), drop_inr=True
        )
        # Cast to bool so that row filtering and `~` behave as a boolean mask
        # even if the one-hot column is stored as integers.
        is_white = self._dataset["Race (OMB)_White"].astype(bool)
        if str(self.split).lower() == "white":
            self._dataset = self._dataset[is_white]
        elif str(self.split).lower() == "non-white":
            self._dataset = self._dataset[~is_white]
        else:
            raise NotImplementedError

        # Shuffle the rows reproducibly.
        self._dataset = self._dataset.sample(frac=1, random_state=seed)

        # Warfarin dosing oracle from the IWPC et al. NEJM (2009).
        _diff: List[float] = []
        for i in range(len(self)):
            # Build the patient and the oracle prediction once per row.
            pt = self[i]
            if pt.warfarin_dose is None:
                raise ValueError(f"Patient {pt.id_} has no warfarin dose.")
            best_dose = self._best_dose(pt)
            # A falsy (0.0) dose falls back to the predicted dose, which
            # yields a zero error for that patient.
            real_dose = pt.warfarin_dose or best_dose
            _diff.append(best_dose - real_dose)
        diff = np.array(_diff)
        # Negative absolute error: higher scores mean the recorded dose is
        # closer to the oracle prediction.
        self._y: np.ndarray = self.normalize(  # type: ignore
            -1.0 * np.sqrt(diff * diff)
        )

    def __len__(self) -> int:
        """
        Returns the number of patients in the dataset.
        Input:
            None.
        Returns:
            The number of patients in the dataset.
        """
        return len(self._dataset)

    def __getitem__(self, idx: int) -> IWPCWarfarinPatient:
        """
        Returns a specified patient from the dataset.
        Input:
            idx: the positional index of the patient to retrieve.
        Returns:
            The requested patient. The warfarin dose and fitness score are
            `None` once `mask_designs()` has been called; the fitness score is
            also `None` while the scores have not been computed yet.
        """
        pt = self._dataset.iloc[idx]

        def ohe_to_label(prefix: str) -> str:
            # Collect the one-hot columns with this prefix that are set for
            # the patient and strip the prefix to recover the label.
            ohes = filter(
                lambda col: col.startswith(prefix), self._dataset.columns
            )
            label = " ".join(filter(pt.__getitem__, ohes))
            return label.replace(prefix, "")

        features = ["race", "age", "CYP2C9_Consensus", "Imputed_VKORC1"]
        labels = map(
            ohe_to_label,
            [
                self.race_prefix,
                self.age_prefix,
                self.CYP2C9_consensus_prefix,
                self.imputed_VKORC1_prefix
            ]
        )
        kwargs = dict(zip(features, labels))

        meds = [
            "Amiodarone (Cordarone)",
            "Carbamazepine (Tegretol)",
            "Phenytoin (Dilantin)",
            "Rifampin or Rifampicin"
        ]
        # Keep only the first word of each medication column that is set.
        medications = set(
            map(lambda m: m.split(" ", 1)[0], filter(pt.__getitem__, meds))
        )

        fitness_score: Optional[float] = None
        if hasattr(self, "_y") and not self.__mask_designs:
            fitness_score = self._y[idx]
        warfarin_dose: Optional[float] = None
        if not self.__mask_designs:
            warfarin_dose = pt["Therapeutic Dose of Warfarin"].item()

        return IWPCWarfarinPatient(
            id_=f"PT{idx:04}",
            height=pt["Height (cm)"].item(),
            weight=pt["Weight (kg)"].item(),
            is_current_smoker=bool(pt["Current Smoker"].item()),
            medications=medications,
            fitness_score=fitness_score,
            warfarin_dose=warfarin_dose,
            **kwargs
        )

    @property
    def data(self) -> pd.DataFrame:
        """
        Returns a DataFrame table of the X values as inputs in the dataset.
        Input:
            None.
        Returns:
            A DataFrame with one row per patient and one column per patient
            field except the target. Medications are rendered as a sorted,
            `; `-separated string.
        """
        # Take the column names from the record type rather than `self[0]`
        # so that an empty dataset does not raise an IndexError.
        df = pd.DataFrame.from_records(
            [self[i] for i in range(len(self))],
            columns=IWPCWarfarinPatient._fields
        )
        df["medications"] = df["medications"].apply(
            lambda x: "; ".join(sorted(x))
        )
        return df.drop(self.target_name, axis=1)

    @property
    def target(self) -> np.ndarray:
        """
        Returns a vector of the y values to predict from the dataset.
        Input:
            None.
        Returns:
            A copy of the vector of y values to predict from the dataset.
        """
        return np.array(self._y)

    @property
    def target_name(self) -> str:
        """
        Returns the name of the target feature in the dataset.
        Input:
            None.
        Returns:
            The string name of the target feature in the dataset.
        """
        return "fitness_score"

    def relabel(self, y: Union[List[float], np.ndarray]) -> None:
        """
        Relabels the objective values in the training dataset.
        Input:
            y: the computed objective values to use for relabelling. Must
                have one value per patient.
        Returns:
            None.
        """
        assert len(y) == len(self._y), "y must have one value per patient"
        self._y = y if isinstance(y, np.ndarray) else np.array(y)

    def mask_designs(self) -> None:
        """
        Masks the designs available in the test dataset. Afterwards, patients
        returned by the dataset carry neither a warfarin dose nor a fitness
        score.
        Input:
            None.
        Returns:
            None.
        """
        self.__mask_designs = True

    def _best_dose(self, x: IWPCWarfarinPatient) -> float:
        """
        Predicts the stable weekly dose of warfarin (in mg) according to the
        pharmacogenetic dosing algorithm from IWPC. (NEJM 2009).
        Input:
            x: the patient to predict the warfarin dose for.
        Returns:
            The predicted stable weekly dose of warfarin for the patient.
        """
        # `_weights` builds a new dictionary on every access; bind it once.
        weights = self._weights

        rdose = weights["bias"]
        rdose += weights["height"] * x.height
        rdose += weights["weight"] * x.weight
        # Categories without an entry in the weights contribute nothing.
        race_key = "race_" + x.race
        if race_key in weights:
            rdose += weights[race_key]
        # The age label is reduced to its leading integer, i.e. the text
        # before the first "-" or "+".
        rdose += weights["age"] * int(
            x.age.split("-")[0].split("+")[0].strip()
        )
        cyp2c9_key = "CYP2C9 consensus_" + x.CYP2C9_Consensus
        if cyp2c9_key in weights:
            rdose += weights[cyp2c9_key]
        vkorc1_key = "Imputed VKORC1_" + x.Imputed_VKORC1
        if vkorc1_key in weights:
            rdose += weights[vkorc1_key]
        # Any number of enzyme inducers counts once.
        if any(med in x.medications for med in self.enzyme_inducers):
            rdose += weights[self.enzyme_inducers]
        if "Amiodarone" in x.medications:
            rdose += weights["Amiodarone"]
        # The linear model predicts the square root of the weekly dose.
        return rdose * rdose

    @property
    def _weights(self) -> Dict[Union[Tuple[str, str, str], str], float]:
        """
        Returns the pharmacogenetic dosing algorithm model parameters.
        Input:
            None.
        Returns:
            A dictionary mapping patient attributes to their respective
            weights in the linear regression dosing model. The enzyme
            inducers share one weight, keyed by the `enzyme_inducers` tuple.
        """
        return {
            "bias": 5.6044,
            "age": -0.02614,
            "height": 0.0087,
            "weight": 0.0128,
            "Imputed VKORC1_A/G": -0.8677,
            "Imputed VKORC1_A/A": -1.6974,
            "Imputed VKORC1_Unknown": -0.4854,
            "CYP2C9 consensus_*1/*2": -0.5211,
            "CYP2C9 consensus_*1/*3": -0.9357,
            "CYP2C9 consensus_*2/*2": -1.0616,
            "CYP2C9 consensus_*2/*3": -1.9206,
            "CYP2C9 consensus_*3/*3": -2.3312,
            "CYP2C9 consensus_unknown": -0.2188,
            "race_Asian": -0.1092,
            "race_Black or African American": -0.2760,
            "race_Unknown": -0.1032,
            self.enzyme_inducers: 1.1816,
            "Amiodarone": -0.5503
        }
