#!/usr/bin/python3
"""
Stanford HIV Drug Database (HIVDB).

Author(s):
    Anonymized Authors @anonymized-authors

Citation(s):
    [1] Tang MW, Liu TF, Shafer RW. The HIVdb system for HIV-1 genotypic
        resistance interpretation. Intervirol 55(2): 98-101. (2012).
        doi: 10.1159/000331998
    [2] de Oliveira T, Shafer RW, Seebregts C. Public database for HIV drug
        resistance in southern Africa. Nature 464(7289): 673. (2010).
        doi: 10.1038/464673c

Licensed under the Apache License, Version 2.0. Copyright Anonymized, Inc. 2025.
"""
import json
import numpy as np
import os
import re
import pandas as pd
import torch
from bs4 import BeautifulSoup
from math import isnan
from pathlib import Path
from selenium import webdriver
from selenium.webdriver import FirefoxOptions
from typing import (
    Any, Dict, Final, List, NamedTuple, Optional, Set, Tuple, Union
)
from urllib.parse import urljoin

from .base import BaseDataset
from .utils import HIVDB_FEATURES


class HIVDBPatient(NamedTuple):
    """
    Immutable record describing a single patient from the HIVDB dataset.

    Fields:
        id_: the patient identifier.
        protease_mutations: the patient's HIV protease mutations, written as
            <reference residue><position><observed residue(s)>.
        reverse_transcriptase_mutations: the patient's HIV reverse
            transcriptase mutations, written in the same notation.
        neg_viral_load: the (negated) viral load objective, if available.
        cd4_count: the CD4 count of the patient, if available.
        medication_list: the prescribed medications, or None if the design
            has been masked.
    """
    id_: str
    protease_mutations: Set[str]
    reverse_transcriptase_mutations: Set[str]
    neg_viral_load: Optional[float] = None
    cd4_count: Optional[int] = None
    medication_list: Optional[Set[str]] = None

    def __str__(self) -> str:
        """
        Returns a string representation of the patient.
        Input:
            None.
        Returns:
            A string representation of the patient. Mutations and medications
            are listed in sorted order so that the text is reproducible
            across interpreter runs (set iteration order is not).
        """
        pr_text = "; ".join(sorted(self.protease_mutations))
        rt_text = "; ".join(sorted(self.reverse_transcriptase_mutations))
        desc = (
            "Patient newly diagnosed with HIV-1 has the following HIV "
            f"Protease Mutations: {pr_text}. "
            "The patient also has the following HIV Reverse Transcriptase "
            f"Mutations: {rt_text}"
        )
        if self.medication_list is None:
            # Masked design: do not mention medications at all.
            return desc
        desc += "\n\nPrescribed Medications: "
        if len(self.medication_list):
            desc += ", ".join(sorted(self.medication_list))
        else:
            desc += "None"
        return desc

    def __repr__(self) -> str:
        """
        Returns a string representation of the patient.
        Input:
            None.
        Returns:
            The same text as `str(self)`.
        """
        return str(self)

    def as_tensor(self) -> torch.Tensor:
        """
        Returns a tensor of the patient's features.
        Input:
            None.
        Returns:
            An int32 tensor with one 0/1 entry per medication listed in
            HIVDB_FEATURES["medications"] (in that order), set to 1 when the
            medication is in the patient's medication list.
        Raises:
            AssertionError: if the medication list is None (masked).
        """
        assert self.medication_list is not None, (
            "Cannot featurize a patient without a medication list."
        )
        prescribed = self.medication_list
        features: List[int] = [
            int(med in prescribed) for med in HIVDB_FEATURES["medications"]
        ]
        return torch.tensor(features, dtype=torch.int32)

    @classmethod
    def ignored_features(cls) -> List[str]:
        """
        Returns a list of features to ignore.
        Input:
            None.
        Returns:
            A list of the field names to ignore.
        """
        return [
            "id_",
            "neg_viral_load",
            "cd4_count",
            "protease_mutations",
            "reverse_transcriptase_mutations"
        ]

    @classmethod
    def discrete_features(cls) -> Dict[str, List[Union[str, bool]]]:
        """
        Returns a dictionary of discrete features and their possible values.
        Input:
            None.
        Returns:
            A dictionary mapping each medication in
            HIVDB_FEATURES["medications"] to its possible values
            [False, True].
        """
        return {med: [False, True] for med in HIVDB_FEATURES["medications"]}

    def to_dict(self) -> Dict[str, Any]:
        """
        One-hot-encodes the patient for input into a tabular deep learning
        model.
        Input:
            None.
        Returns:
            A dict of 0/1 features keyed, in order, by: the sorted medication
            names; PR_P1 to PR_P99 (protease positions); and RT_P1 to RT_P240
            (reverse transcriptase positions).
        Raises:
            ValueError: if a mutation string is malformed, or if a protease
                mutation position falls outside of 1 to 99.
        """
        protease_seqlength: Final[int] = 99
        reverse_transcriptase_seqlength: Final[int] = 240
        aa_mut_re = re.compile(r"([A-Za-z])([1-9]\d*)([A-Za-z*]+)")

        def position_of(mutation: str) -> int:
            # Extract the 1-indexed residue position from a mutation string.
            token = mutation.strip()
            match = aa_mut_re.fullmatch(token)
            if match is None:
                raise ValueError(token)
            return int(match.group(2))

        # Each flag is computed for the very column name it is stored under,
        # so the values stay aligned with the sorted column order even when
        # HIVDB_FEATURES["medications"] itself is not sorted.
        # NOTE(review): membership is tested against the entries of
        # HIVDB_FEATURES["medications"]; confirm that medication_list uses
        # the same naming vocabulary.
        prescribed = (
            self.medication_list if self.medication_list is not None else ()
        )
        encoded: Dict[str, Any] = {
            med: int(med in prescribed)
            for med in sorted(HIVDB_FEATURES["medications"])
        }

        protease = [0] * protease_seqlength
        for mut in self.protease_mutations:
            pos = position_of(mut)
            if not (1 <= pos <= protease_seqlength):
                raise ValueError(pos)
            protease[pos - 1] = 1

        rt = [0] * reverse_transcriptase_seqlength
        for mut in self.reverse_transcriptase_mutations:
            pos = position_of(mut)
            # Positions beyond the encoded window are skipped, not rejected.
            if not (1 <= pos <= reverse_transcriptase_seqlength):
                continue
            rt[pos - 1] = 1

        for i, flag in enumerate(protease):
            encoded[f"PR_P{i + 1}"] = flag
        for i, flag in enumerate(rt):
            encoded[f"RT_P{i + 1}"] = flag
        return encoded


class HIVDBDataset(BaseDataset):
    """
    Dataset of patients enrolled in clinical studies hosted by the Stanford
    HIV Drug Database. The per-study tables are downloaded from `base_url`
    and cached locally under `cachedir`.
    """

    # Publication year of each supported clinical study.
    study_to_publication_date: Dict[str, int] = {
        "ACTG5241": 2020,
        "ACTG5257": 2014,
        "ACTG5208": 2012,
        "ACTG5202": 2012,
        "ACTGA5142": 2008,
        "ACTGA5095": 2004,
        "ACTG384": 2003,
        "ACTG364": 2003,
        "ACTG306": 2004,
        "ACTG302": 2002,
        "HAVANA": 2002
    }

    base_url: str = "https://hivdb.stanford.edu/clinical_studies/"

    # Per-study tables to load; the file stem is used as the table key.
    files: List[str] = ["RX.txt", "PR.txt", "RT.txt", "RNA.txt", "CD4.txt"]

    drug_ref_url: str = "https://hivdb.stanford.edu/TCEs/pages/drug_code.html"

    # The normalization factors below are calculated according to the
    # distribution of scores from patients from studies published between 2002
    # and 2008 inclusive only.
    _mu: float = -4.570365314738427

    _std: float = 0.9068654087317167

    def __init__(
        self,
        year_split: Tuple[int, int],
        cachedir: Union[Path, str] = (
            Path.home() / ".cache" / "leon" / "hivdb"
        ),
        seed: Optional[int] = 2025,
        **kwargs: Any
    ):
        """
        Args:
            year_split: a range of years to filter by. Only patients enrolled
                in trials published between the two years (inclusive) will
                be included in the dataset.
            cachedir: a path to locally cache the dataset to.
            seed: random seed. Default 2025.
        Raises:
            ValueError: if no study was published within `year_split`.
        """
        del kwargs
        super().__init__(split=year_split)
        self.cachedir = cachedir
        self.__mask_designs = False
        self._rng = np.random.default_rng(seed=seed)
        os.makedirs(self.cachedir, exist_ok=True)

        miny, maxy = year_split
        studies = [
            study
            for study, year in self.study_to_publication_date.items()
            if miny <= year <= maxy
        ]
        if not studies:
            # Fail with a clear message rather than letting pd.concat() raise
            # an opaque error on an empty list of tables.
            raise ValueError(
                f"No HIVDB studies were published between {miny} and {maxy}."
            )

        per_study = [self._load_and_cache_dataset(st) for st in studies]
        self._dataset = {
            key: pd.concat([d[key] for d in per_study], ignore_index=True)
            for key in (fn.split(".")[0] for fn in self.files)
        }

        # Keep only the patients that appear in every table. The ID sets are
        # built once instead of re-materializing each column per patient.
        id_sets = [
            set(table["PtID"].tolist()) for table in self._dataset.values()
        ]
        patients: List[str] = sorted(set.intersection(*id_sets))
        self._rng.shuffle(patients)

        self._drug_abbrs: Final[Dict[str, str]] = self._load_drug_info()

        # Patients without any active medication are dropped. Patients are
        # built in shuffled order since _build_patient() may draw from the
        # random number generator.
        self._processed_dataset = []
        for ptid in patients:
            patient = self._build_patient(ptid)
            if patient.medication_list:
                self._processed_dataset.append(patient)

    def filter(self, idxs: np.ndarray) -> None:
        """
        Filter the dataset by the given indices.
        Input:
            idxs: the indices to include in the filter dataset.
        Returns:
            None.
        """
        self._processed_dataset = [
            x for i, x in enumerate(self._processed_dataset) if i in idxs
        ]

    @property
    def drugs(self) -> List[str]:
        """
        Returns the list of drugs in the HIVDB dataset.
        Input:
            None.
        Returns:
            The list stored under HIVDB_FEATURES["medications"].
        """
        return HIVDB_FEATURES["medications"]

    @property
    def data(self) -> pd.DataFrame:
        """
        Returns a DataFrame table of the X values as inputs in the dataset.
        Input:
            None.
        Returns:
            A DataFrame with one row per patient, built from
            HIVDBPatient.to_dict().
        """
        # Index through __getitem__ so that design masking is respected.
        return pd.DataFrame([self[i].to_dict() for i in range(len(self))])

    @property
    def target(self) -> np.ndarray:
        """
        Returns a vector of the y values to predict from the dataset.
        Input:
            None.
        Returns:
            A vector of the y values to predict from the dataset.
        """
        name = self.target_name
        return np.array([getattr(x, name) for x in self._processed_dataset])

    @property
    def target_name(self) -> str:
        """
        Returns the name of the target feature in the dataset.
        Input:
            None.
        Returns:
            The string name of the target feature in the dataset.
        """
        return "neg_viral_load"

    def _load_and_cache_dataset(self, study: str) -> Dict[str, pd.DataFrame]:
        """
        Loads and (optionally) caches a dataset associated with a clinical
        trial in the HIVDB dataset.
        Input:
            study: the name of the study to load.
        Returns:
            A dictionary containing the loaded datasets for the clinical
            trial, keyed by the stem of each file name in `self.files`.
        """
        assert study in self.study_to_publication_date
        dfs = {}
        for fn in self.files:
            key = fn.split(".")[0]
            cachepath = None
            if self.cachedir is not None:
                cachepath = os.path.join(self.cachedir, f"{study}_{fn}")
                if os.path.isfile(cachepath):
                    # Cached copies are written by to_csv() below and are
                    # therefore comma-separated.
                    dfs[key] = pd.read_csv(cachepath)
                    continue

            # The remote files are tab-separated.
            dfs[key] = pd.read_csv(
                urljoin(self.base_url, f"{study}/{study}_{fn}"), sep="\t"
            )
            if cachepath is not None:
                dfs[key].to_csv(cachepath, index=False)

        return dfs

    def __len__(self) -> int:
        """
        Returns the number of patients in the dataset.
        Input:
            None.
        Returns:
            The number of patients in the dataset.
        """
        return len(self._processed_dataset)

    def __getitem__(self, idx: int) -> HIVDBPatient:
        """
        Returns a specified patient from the dataset.
        Input:
            idx: the index of the patient to retrieve from the dataset.
        Returns:
            The requested patient from the dataset. If mask_designs() has
            been called, the viral load, CD4 count, and medication list of
            the returned patient are set to None.
        """
        item = self._processed_dataset[idx]
        if self.__mask_designs:
            return item._replace(
                neg_viral_load=None, cd4_count=None, medication_list=None
            )
        return item

    def relabel(self, y: List[float]) -> None:
        """
        Relabels the objective values in the training dataset.
        Input:
            y: the computed objective values to use for relabelling.
        Returns:
            None.
        """
        assert len(y) == len(self._processed_dataset)
        self._processed_dataset = [
            patient._replace(neg_viral_load=score)
            for patient, score in zip(self._processed_dataset, y)
        ]

    def mask_designs(self) -> None:
        """
        Masks the designs available in the test dataset.
        Input:
            None.
        Returns:
            None.
        """
        self.__mask_designs = True

    def _latest_mutations(
        self,
        table: str,
        ptid: str,
        anchor_date: Any,
        feature_key: str,
        prefix: str
    ) -> List[str]:
        """
        Finds the mutations in the most recent sequence isolated on or before
        the anchor date, keeping only those at tracked positions.
        Input:
            table: the key of the sequence table to read ("PR" or "RT").
            ptid: the patient ID to look up.
            anchor_date: the latest isolate date to consider.
            feature_key: the key in HIVDB_FEATURES listing tracked positions.
            prefix: the prefix to strip from each tracked feature name to
                recover its integer position (e.g., "PR_P").
        Returns:
            The list of retained mutation strings.
        """
        seqs = self._dataset[table][self._dataset[table]["PtID"] == ptid]
        assert not seqs.empty
        seqs = seqs[seqs["IsolateDate"] <= anchor_date]
        raw = ""
        if not seqs.empty:
            raw = seqs.loc[seqs["IsolateDate"].idxmax()]["MutList"]
            # A missing mutation list is read back as NaN instead of a str.
            if not isinstance(raw, str):
                raw = ""

        tracked = {
            int(name.replace(prefix, "", 1))
            for name in HIVDB_FEATURES[feature_key]
        }
        kept: List[str] = []
        for token in raw.split(","):
            token = token.strip()
            digits = "".join(ch for ch in token if ch.isdigit())
            # Empty tokens (e.g., from a trailing comma) and tokens without
            # a position cannot match a tracked position; skip them instead
            # of failing in int("").
            if digits and int(digits) in tracked:
                kept.append(token)
        return kept

    def _build_patient(self, ptid: str) -> HIVDBPatient:
        """
        Builds a patient entity from the raw, pre-processed dataset.
        Input:
            ptid: the specific patient ID to build the dataset entry for.
        Returns:
            The requested patient entity.
        """
        # 1. Anchor on the earliest viral-load date.
        rna = self._dataset["RNA"][self._dataset["RNA"]["PtID"] == ptid]
        assert not rna.empty
        anchor = rna.loc[rna["RNADate"].idxmin()]
        anchor_date = anchor["RNADate"].item()
        viral_load = anchor["VLoad"].item()
        censor_flag = anchor["VLoadMatch"]
        # A "<" flag marks a measurement reported as an upper bound; sample
        # a value below that bound. A missing flag is read back as NaN, so
        # guard the substring test against non-string values.
        if isinstance(censor_flag, str) and "<" in censor_flag:
            viral_load = self._rng.uniform(low=0.0, high=viral_load)
        # NOTE(review): normalize() is inherited from BaseDataset and
        # presumably standardizes with _mu and _std; confirm there.
        neg_viral_load: float = self.normalize(  # type: ignore
            -1.0 * viral_load
        )

        # 2. Find the closest CD4 count measurement (either before or after
        #    the anchor date).
        cd4 = self._dataset["CD4"][self._dataset["CD4"]["PtID"] == ptid].copy()
        assert not cd4.empty
        cd4["delta"] = (cd4["CD4Date"] - anchor_date).abs()
        cd4_count = int(cd4.loc[cd4["delta"].idxmin()]["CD4Count"])

        # 3. Find the drug regimen with most recent start date (before the
        #    anchor date).
        rx = self._dataset["RX"][self._dataset["RX"]["PtID"] == ptid]
        assert not rx.empty
        active = rx[rx["StartDate"] <= anchor_date]
        medications: List[str] = []
        if not active.empty:
            # Look up the most recent regimen once, not once per drug.
            regimen = active.loc[active["StartDate"].idxmax()]
            # NOTE(review): medication_list stores the display names from
            # self._drug_abbrs, whereas HIVDBPatient.as_tensor()/to_dict()
            # test membership with the raw HIVDB_FEATURES["medications"]
            # entries; confirm that the two vocabularies agree.
            medications = [
                self._drug_abbrs[drug]
                for drug in HIVDB_FEATURES["medications"]
                if regimen[drug] == 1 and drug in self._drug_abbrs
            ]

        # 4. Find the most recent protease sequence (before the anchor date).
        p_muts = self._latest_mutations(
            "PR", ptid, anchor_date, "protease_mutations", "PR_P"
        )

        # 5. Find the most recent reverse transcriptase mutations (before the
        #    anchor date).
        rt_muts = self._latest_mutations(
            "RT",
            ptid,
            anchor_date,
            "reverse_transcriptase_mutations",
            "RT_P"
        )

        return HIVDBPatient(
            id_=ptid,
            neg_viral_load=neg_viral_load,
            cd4_count=cd4_count,
            medication_list=set(medications),
            protease_mutations=set(p_muts),
            reverse_transcriptase_mutations=set(rt_muts)
        )

    def _load_drug_info(self) -> Dict[str, str]:
        """
        Loads the drug names associated with the drug abbreviations in the
        dataset. The result is cached to `drugs.json` in the cache directory;
        if that file exists, it is returned without opening a browser.
        Input:
            None
        Returns:
            A dictionary of the HIV drug abbreviations and corresponding names.
        """
        cache_fn = os.path.join(self.cachedir, "drugs.json")
        if os.path.isfile(cache_fn):
            with open(cache_fn) as f:
                return json.load(f)

        opts = FirefoxOptions()
        opts.add_argument("--headless")
        driver = webdriver.Firefox(options=opts)
        try:
            driver.get(self.drug_ref_url)
            page_source = driver.page_source
        finally:
            # Always shut the browser down so that no Firefox process is
            # left running, even if the page fails to load.
            driver.quit()
        soup = BeautifulSoup(page_source, "html.parser")

        def cell_lines(cell: Any) -> List[str]:
            # Split the text of a table cell into its non-empty lines.
            lines = cell.get_text(separator="\n").split("\n")
            return [s.strip() for s in lines if s.strip()]

        drug_dict: Dict[str, str] = {}
        for table in soup.find_all(  # type: ignore
            "table", class_="drug_code_table"
        ):
            headers = [
                th.get_text(strip=True)
                for th in table.find_all("th")  # type: ignore
            ]
            if "Generic Names" not in headers:
                continue

            # Skip the first (header) row of the table.
            for row in table.find_all("tr")[1:]:  # type: ignore
                cells = row.find_all("td")  # type: ignore
                if len(cells) < 3:
                    continue
                abbreviations = cell_lines(cells[0])
                generics = cell_lines(cells[1])
                brands = cell_lines(cells[2])
                for abbr, gen, br in zip(abbreviations, generics, brands):
                    drug_dict[abbr] = f"{gen} ({br})"

        # An abbreviation written as "X (Y)" is also registered under both
        # "X" and "Y" individually.
        for abbr in list(drug_dict.keys()):
            if "(" not in abbr or ")" not in abbr:
                continue
            name1, name2 = abbr.rsplit(")", 1)[0].split("(", 1)
            drug_dict[name1.strip()] = drug_dict[abbr]
            drug_dict[name2.strip()] = drug_dict[abbr]

        with open(cache_fn, "w") as f:
            json.dump(drug_dict, f, indent=2, ensure_ascii=True)
        return drug_dict
