"""
Defines the AIBB-sourced dataset class.

Author(s):
    [Anonymous Authors]

Licensed under the MIT License. Copyright 2022 Anonymized Institution.
"""
from collections import defaultdict
from datetime import datetime
import numpy as np
import os
import pandas as pd
from pathlib import Path
import pickle
import pytorch_lightning as pl
import torch
from typing import Any, Dict, NamedTuple, Optional, Sequence, Tuple, Union

from data.accession import AccessionConverter


class AIBBSample(NamedTuple):
    """
    A single sample as returned by the datasets defined in this module.
    """
    # Sample identifier. AIBBDataset passes the patient's AIBB ID here, while
    # AIBBDataModule.DatasetFromTabular passes the DataFrame row label.
    AIBB_ID: str
    # Sample payload: a dictionary of named fields (AIBBDataset) or a flat
    # feature tensor (AIBBDataModule.DatasetFromTabular).
    data: Union[Dict[str, Any], torch.Tensor]
    # Number of days between the earliest and latest dated entries that were
    # combined into this sample (AIBBDataset), or the placeholder value -1
    # (AIBBDataModule.DatasetFromTabular).
    daterange: int


class AIBBDataModule(pl.LightningDataModule):
    """
    Lightning data module that serves pre-partitioned tabular AIBB data.

    The training, validation, and test tables are unpickled from
    `train_dataset.pkl`, `val_dataset.pkl`, and `test_dataset.pkl` inside
    `data_dir` when the module is constructed.
    """

    def __init__(
        self,
        data_dir: Union[Path, str],
        batch_size: int = 16,
        seed: int = 42
    ):
        """
        Args:
            data_dir: data directory with training, validation, and test
                datasets.
            batch_size: batch size. Default 16.
            seed: random seed. Default 42.
        """
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.seed = seed
        self.train_tabular = self._load_partition("train_dataset.pkl")
        self.val_tabular = self._load_partition("val_dataset.pkl")
        self.test_tabular = self._load_partition("test_dataset.pkl")

    def _load_partition(self, filename: str) -> pd.DataFrame:
        """
        Unpickles one partition table from the data directory.
        Input:
            filename: name of the pickle file inside `self.data_dir`.
        Returns:
            The unpickled object, expected to be a pandas DataFrame.
        """
        # NOTE: unpickling can execute arbitrary code, so `data_dir` must
        # only ever point at trusted files.
        with open(os.path.join(self.data_dir, filename), "rb") as f:
            return pickle.load(f)

    def _make_dataloader(
        self, table: pd.DataFrame
    ) -> torch.utils.data.DataLoader:
        """
        Wraps a partition table into a batched DataLoader.
        Input:
            table: partition table to serve.
        Returns:
            A DataLoader over a seeded `DatasetFromTabular` of the table.
        """
        dataset = self.DatasetFromTabular(table, self.seed)
        return torch.utils.data.DataLoader(
            dataset, batch_size=self.batch_size
        )

    def train_dataloader(self) -> torch.utils.data.DataLoader:
        """
        Returns a DataLoader over the training partition.
        """
        return self._make_dataloader(self.train_tabular)

    def val_dataloader(self) -> torch.utils.data.DataLoader:
        """
        Returns a DataLoader over the validation partition.
        """
        return self._make_dataloader(self.val_tabular)

    def test_dataloader(self) -> torch.utils.data.DataLoader:
        """
        Returns a DataLoader over the test partition.
        """
        return self._make_dataloader(self.test_tabular)

    def num_features(self) -> int:
        """
        Returns the length of a sample's feature tensor minus one.

        The last tensor entry is `RESULT_VALUE_NUM`; the subtraction
        presumably excludes it because it is the prediction target.
        """
        # Only a single sample is needed, so index the dataset directly
        # instead of wrapping it in a DataLoader first.
        sample = self.DatasetFromTabular(self.train_tabular, self.seed)[0]
        return len(sample.data) - 1

    class DatasetFromTabular(torch.utils.data.Dataset):
        """
        Map-style dataset that turns the rows of a partition table into flat
        feature tensors. The row order is shuffled once, at construction
        time, using the provided seed.
        """

        # Race codes that each get their own exact-match indicator feature.
        _RACE_CODES = (
            "WHITE", "BLACK", "ASIAN", "AM IND AK NATIVE", "HI PAC ISLAND"
        )
        # Columns copied into the feature tensor unchanged, in this order.
        _VALUE_COLUMNS = (
            "BP_SYSTOLIC",
            "BP_DIASTOLIC",
            "HEIGHT_INCHES",
            "WEIGHT_LBS",
            "LIVER_MEAN_HU",
            "SPLEEN_MEAN_HU",
            "VISCERAL_METRIC_AREA_MEAN",
            "SUBQ_METRIC_AREA_MEAN",
            "RESULT_VALUE_NUM"
        )

        def __init__(self, df: pd.DataFrame, seed: int = 42):
            """
            Args:
                df: table with one row per sample.
                seed: random seed for the one-time row shuffle. Default 42.
            """
            super().__init__()
            self.df = df
            self.rng = np.random.RandomState(seed)
            self.idxs = np.arange(0, len(self))
            self.rng.shuffle(self.idxs)

        def __len__(self) -> int:
            """
            Returns the number of rows in the underlying table.
            """
            return self.df.shape[0]

        def __getitem__(self, idx: int) -> AIBBSample:
            """
            Retrieves a row of the table as a flat feature tensor.
            Input:
                idx: index into the shuffled row order.
            Returns:
                An AIBBSample whose ID is the row label, whose data is the
                feature tensor, and whose date range is the placeholder -1.
            """
            item = self.df.iloc[self.idxs[idx]]
            race = item["RACE_CODE"]
            # NOTE(review): a RACE_CODE of "HISPANIC" (as assigned by
            # AIBBDataset.__getitem__) matches none of the six race
            # indicators below, so those features are all zero.
            features = [int(race == code) for code in self._RACE_CODES]
            race_upper = race.upper()
            features.append(
                int("OTHER" in race_upper or "UNKNOWN" in race_upper)
            )
            features.append(int(item["Sex"].title() == "Female"))
            features.append(float(item["AGE"]))
            features.extend(item[column] for column in self._VALUE_COLUMNS)
            return AIBBSample(item.name, torch.tensor(features), -1)


class AIBBDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        accession_converter: AccessionConverter,
        filenames: Sequence[Union[Path, str]] = [],
        cache_path: Optional[Union[Path, str]] = None,
        verbose: bool = True,
        identifier: str = "A1C",
        seed: int = 42
    ):
        """
        Args:
            accession_converter: a mapping from AIBB accession numbers to study
                dates.
            filenames: a sequence of filenames to the raw AIBB data sheets.
            cache_path: an optional path to a cache of previously loaded data
                for faster load times.
            verbose: optional flag for verbose stdout messages. Default True.
            identifier: specifies key for data points. If `A1C`, each A1C
                value is treated as a separate datapoint and patients can
                be represented multiple times. If `AIBB_ID`, only the most
                recent A1C value for a patient is used.
            seed: optional random seed. Default to seconds since epoch.
        """
        self.accession_converter = accession_converter
        self.filenames = filenames
        self.cache_path = cache_path
        self.identifier = identifier
        self.verbose = verbose
        self.aibb_id_a1c_map = {}
        self.aibb_ids = []
        self.rng = np.random.RandomState(seed)
        self.data = {}

        if self.cache_path is not None and os.path.isfile(self.cache_path):
            self._load_from_cache()
        else:
            self._manual_load()
            self.aibb_id_a1c_map = self._make_trainable()
            self.aibb_ids = np.array(list(self.aibb_id_a1c_map.keys()))
            self.rng.shuffle(self.aibb_ids)
            if self.cache_path is not None and self.cache_path != "None":
                self._save_to_cache()

    def _load_from_cache(self) -> None:
        """
        Loads a dataset object from a cache file.
        Input:
            None.
        Returns:
            None. The fields of the dataset are modified directly.
        """
        with open(os.path.abspath(self.cache_path), "rb") as f:
            cache = pickle.load(f)
        self.filenames = cache["filenames"]
        self.identifier = cache["identifier"]
        if self.verbose != cache["verbose"] and self.verbose:
            print(
                f"Verbose flag updated: {self.verbose} to {cache['verbose']}"
            )
        self.verbose = cache["verbose"]
        self.data = cache["data"]
        self.aibb_id_a1c_map = cache["aibb_id_a1c_map"]
        self.aibb_ids = cache["aibb_ids"]
        self.accession_converter = cache["accession_converter"]
        if self.verbose:
            print(f"Loaded dataset from {os.path.abspath(self.cache_path)}")

    def _manual_load(self) -> None:
        """
        Loads a dataset object by parsing self.filenames.
        Input:
            None.
        Returns:
            None. The fields of the dataset are modified directly.
        """
        # Load the raw specified data files.
        for fn in self.filenames:
            # Assume that .txt file inputs are TSV files.
            if os.path.splitext(fn)[-1].lower() == ".txt":
                sep, header = "\t", 0
            # Otherwise, assume that the file is a CSV file.
            else:
                sep, header = ",", "infer"
            aibb_data = pd.read_csv(
                os.path.abspath(fn), sep=sep, header=header
            )
            # Remove values from emergency or inpatient visits.
            if "PATIENT_CLASS" in aibb_data.columns:
                aibb_data = aibb_data[
                    aibb_data.PATIENT_CLASS == "OUTPATIENT"
                ]
            good_columns = AIBBDataset._columns()[os.path.basename(fn)]
            if len(good_columns) == 0 and self.verbose:
                print(
                    f"Warning: {os.path.abspath(fn)} has no specified columns."
                )
            self.data[os.path.basename(fn)] = aibb_data[good_columns].dropna()

        return
        # For the imaging data, convert AIBB Accession numbers to study dates.
        imaging_keys = {
            "steatosis_run_2_merge.csv": "AIBB_ACCESSION",
            "visceral_merge_log_run_11_1_20.csv": "AIBB_ACCESSION"
        }
        for key, val in imaging_keys.items():
            if key == "steatosis_run_2_merge.csv":
                self.data[key].AIBB_ACCESSION = self.data[
                    key
                ].AIBB_ACCESSION.apply(
                    self.accession_converter.get_date
                )
            elif key == "visceral_merge_log_fun_11_1_20.csv":
                self.data[key].AIBB_ACCESSION = self.data[
                    key
                ].AIBB_ACCESSION.apply(
                    self.accession_converter.get_date
                )
            self.data[key].rename(columns={val: "ENC_DATE_SHIFT"})

    def _make_trainable(self) -> Dict[str, pd.DataFrame]:
        """
        Takes a loaded dataset and gradually restricts it into a dataset that
        can actually be used for model training. This function performs the
        following operations:
            (1) Restricts each data type to only the rows with AIBB IDs of
                patients that have at least one complete entry in every data
                type in this AIBBDataset object.
            (2) Restricts each data type to just the row with the most recent
                entry for every AIBB ID.
        Input:
            None.
        Returns:
            A mapping of AIBB IDs to A1C values and dates.
        """
        valid_aibb_ids = None
        # Generate the intersection of all the AIBB IDs.
        for key in sorted(self.data, key=lambda x: self.data[x].shape[0]):
            if valid_aibb_ids is None:
                valid_aibb_ids = set(self.data[key].AIBB_ID.tolist())
                continue
            valid_aibb_ids = valid_aibb_ids & set(
                self.data[key].AIBB_ID.tolist()
            )
        # (1) Restrict each data type to only the rows with valid AIBB IDs.
        for key in self.data.keys():
            self.data[key] = self.data[key][
                self.data[key]["AIBB_ID"].isin(valid_aibb_ids)
            ]
        # (2) Restrict each data type to just the row with the most recent
        #     entry for every remaining AIBB ID. This step is only performed
        #     if self.identifier == `AIBB_ID`.
        if self.identifier.upper() == "AIBB_ID":
            for key in self.data.keys():
                if "ENC_DATE_SHIFT" in self._columns()[key]:
                    sort_key = ["ENC_DATE_SHIFT"]
                elif "ENC_DT_SHIFT" in self._columns()[key]:
                    sort_key = ["ENC_DT_SHIFT"]
                else:
                    sort_key = None

                if sort_key is not None:
                    self.data[key] = self.data[key].sort_values(
                        by=sort_key, ascending=False
                    )
                self.data[key] = self.data[key].drop_duplicates(
                    subset="AIBB_ID", keep="first"
                ).sort_index()
        aibb_id_a1c_map = {}
        a1c_key = "AIBB_A1C_Deidentified_042020.csv"
        for aibb_id in valid_aibb_ids:
            aibb_id_a1c_map[aibb_id] = self.data[a1c_key][
                self.data[a1c_key].AIBB_ID == aibb_id
            ][["RESULT_DATE_SHIFT", "RESULT_VALUE_NUM"]]
        return aibb_id_a1c_map

    def _save_to_cache(self) -> None:
        """
        Saves a dataset object to a specified cache file.
        Input:
            None.
        Returns:
            None.
        """
        cache = {}
        cache["filenames"] = self.filenames
        cache["verbose"] = self.verbose
        cache["data"] = self.data
        cache["aibb_id_a1c_map"] = self.aibb_id_a1c_map
        cache["aibb_ids"] = self.aibb_ids
        cache["identifier"] = self.identifier
        cache["columns"] = AIBBDataset._columns()
        cache["accession_converter"] = self.accession_converter
        with open(os.path.abspath(self.cache_path), "w+b") as f:
            pickle.dump(cache, f)
        if self.verbose:
            print(f"Saved dataset to {os.path.abspath(self.cache_path)}")

    def __len__(self) -> int:
        """
        Retrieves the length of the map-style dataset.
        Input:
            None.
        Returns:
            Length of the map-style dataset.
        """
        return sum([df.shape[0] for _id, df in self.aibb_id_a1c_map.items()])

    def __getitem__(self, idx: int) -> AIBBSample:
        """
        Retrieves a sample from the AIBB dataset.
        Input:
            idx: index of sample to retrieve.
            None.
        Returns:
            A AIBBSample object from the AIBB dataset.
        """
        if self.identifier.upper() == "AIBB_ID":
            aibb_id = self.aibb_ids[idx]
            data = {}
            min_time = None
            max_time = None
            a1c_date = None
            birth_date = None
            hispanic = False
            for key in self.data.keys():
                patient_data = self.data[key][
                    self.data[key]["AIBB_ID"] == aibb_id
                ]
                for header in patient_data.columns.tolist():
                    if header == "RACE_HISPANIC_YN":
                        hispanic = hispanic or bool(
                            int(float(patient_data[header].tolist()[0]))
                        )
                    # Skip over AIBB IDs and AIBB Accession numbers.
                    if header in ["AIBB_ID", "AIBB_ACCESSION"]:
                        continue
                    # We also don't need normal range limits for A1C values.
                    if header in [
                        "VALUE_LOWER_LIMIT_NUM", "VALUE_UPPER_LIMIT_NUM"
                    ]:
                        continue
                    # Record the max time span of the sample data, and also
                    # calculate the age of the patient at the time of the A1C
                    # result.
                    elif header in [
                        "ENC_DATE_SHIFT",
                        "ENC_DT_SHIFT",
                        "RESULT_DATE_SHIFT",
                        "Birth_date_SHIFT"
                    ]:
                        data_date = datetime.strptime(
                            patient_data[header].tolist()[0], "%Y-%m-%d"
                        )
                        if header == "Birth_date_SHIFT":
                            birth_date = data_date
                            continue
                        if min_time is None or data_date < min_time:
                            min_time = data_date
                        if max_time is None or data_date > max_time:
                            max_time = data_date
                        if header == "RESULT_DATE_SHIFT":
                            a1c_date = data_date
                    else:
                        data[header] = patient_data[header].tolist()[0]
        elif self.identifier.upper() == "A1C":
            counter = 0
            for aibb_id in self.aibb_ids:
                aibb_id_count = self.aibb_id_a1c_map[aibb_id].shape[0]
                if counter + aibb_id_count > idx:
                    break
                counter += aibb_id_count
            a1c_date, a1c_val = list(self.aibb_id_a1c_map[aibb_id].to_numpy()[
                idx - counter
            ])
            data = {}
            data["RESULT_VALUE_NUM"] = a1c_val
            a1c_date = datetime.strptime(a1c_date, "%Y-%m-%d")
            min_time, max_time = a1c_date, a1c_date
            birth_date = None
            hispanic = False

            def nearest(items: Sequence[str], pivot: datetime) -> str:
                return datetime.strftime(
                    pd.to_datetime(
                        min(
                            [i for i in items],
                            key=lambda x: abs(
                                datetime.strptime(x, "%Y-%m-%d") - pivot
                            )
                        )
                    ),
                    "%Y-%m-%d"
                )

            for key in self.data.keys():
                patient_data = self.data[key][
                    self.data[key]["AIBB_ID"] == aibb_id
                ]
                if "ENC_DATE_SHIFT" in patient_data.columns:
                    patient_data = patient_data[
                        patient_data.ENC_DATE_SHIFT == nearest(
                            patient_data.ENC_DATE_SHIFT.to_list(), a1c_date
                        )
                    ]
                elif "ENC_DT_SHIFT" in patient_data.columns:
                    patient_data = patient_data[
                        patient_data.ENC_DT_SHIFT == nearest(
                            patient_data.ENC_DT_SHIFT.to_list(), a1c_date
                        )
                    ]
                else:
                    patient_data = patient_data.sample(frac=1).reset_index()
                for header in patient_data.columns.tolist():
                    if header == "index":
                        continue
                    if header == "RACE_HISPANIC_YN":
                        hispanic = hispanic or bool(
                            int(float(patient_data[header].tolist()[0]))
                        )
                    # Skip over AIBB IDs and AIBB Accession numbers.
                    if header in ["AIBB_ID", "AIBB_ACCESSION"]:
                        continue
                    # We also don't need normal range limits for A1C values.
                    if header in [
                        "VALUE_LOWER_LIMIT_NUM", "VALUE_UPPER_LIMIT_NUM"
                    ]:
                        continue
                    # Record the max time span of the sample data, and also
                    # calculate the age of the patient at the time of the A1C
                    # result.
                    elif header in [
                        "ENC_DATE_SHIFT",
                        "ENC_DT_SHIFT",
                        "Birth_date_SHIFT"
                    ]:
                        data_date = datetime.strptime(
                            patient_data[header].tolist()[0], "%Y-%m-%d"
                        )
                        if header == "Birth_date_SHIFT":
                            birth_date = data_date
                            continue
                        if data_date < min_time:
                            min_time = data_date
                        if data_date > max_time:
                            max_time = data_date
                    else:
                        data[header] = patient_data[header].tolist()[0]

        data["AGE"] = a1c_date.year - birth_date.year - (
            (a1c_date.month, a1c_date.day) < (
                birth_date.month, birth_date.day
            )
        )
        if hispanic:
            data["RACE_CODE"] = "HISPANIC"
        data.pop("RACE_HISPANIC_YN", "")
        return AIBBSample(aibb_id, data, (max_time - min_time).days)

    def to_tabular_partitions(
        self, partitions: Sequence[float] = [0.8, 0.1, 0.1]
    ) -> Tuple[pd.DataFrame]:
        """
        Convert the dataset to partitions of training, validation, and test
        sets in table format for compatibility with the PyTorch Tabular
        framework.
        Input:
            partitions: fractional paritions of training, validation, and test
                set sizes. Should sum to 1.0.
        Returns:
            Training, validation, and test dataset DataFrames.
        """
        if sum(partitions) != 1.0:
            raise ValueError(f"Partitions {partitions} do not sum to 1.0.")
        idxs = np.array(list(range(len(self))))
        self.rng.shuffle(idxs)
        train_cutoff = int(partitions[0] * len(self))
        val_cutoff = int((partitions[0] + partitions[1]) * len(self))

        train_table, val_table, test_table = [], [], []
        columns = list(self[0].data.keys())
        for i in idxs[:train_cutoff]:
            train_item_data = self[i].data
            ordered_data = []
            for key in columns:
                ordered_data.append(train_item_data[key])
            train_table.append(ordered_data)
        for j in idxs[train_cutoff:val_cutoff]:
            val_item_data = self[j].data
            ordered_data = []
            for key in columns:
                ordered_data.append(val_item_data[key])
            val_table.append(ordered_data)
        for k in idxs[val_cutoff:]:
            test_item_data = self[k].data
            ordered_data = []
            for key in columns:
                ordered_data.append(test_item_data[key])
            test_table.append(ordered_data)
        train_df = pd.DataFrame(train_table, columns=columns)
        val_df = pd.DataFrame(val_table, columns=columns)
        test_df = pd.DataFrame(test_table, columns=columns)
        return train_df, val_df, test_df

    @staticmethod
    def _columns() -> defaultdict:
        """
        Returns the columns to keep by AIBB filename.
        Input:
            None.
        Returns:
            A dictionary specifying the columns to keep within each AIBB data
            file.
        """
        columns = defaultdict(list)

        columns[
            "AIBB-Release-2020-2.2_phenotype_vitals-BMI-ct-studies.txt"
        ] = ["AIBB_ID", "BMI", "ENC_DATE_SHIFT"]
        columns[
            "AIBB-Release-2020-2.2_phenotype_race_eth-ct-studies.txt"
        ] = ["AIBB_ID", "RACE_CODE", "RACE_HISPANIC_YN"]
        columns[
            "AIBB-Release-2020-2.2_phenotype_demographics-ct-studies.txt"
        ] = ["AIBB_ID", "Sex", "Birth_date_SHIFT"]
        columns["AIBB_A1C_Deidentified_042020.csv"] = [
            "AIBB_ID",
            "RESULT_DATE_SHIFT",
            "RESULT_VALUE_NUM",
            "VALUE_LOWER_LIMIT_NUM",
            "VALUE_UPPER_LIMIT_NUM"
        ]
        columns["AIBB_SBP_Deidentified_042020.csv"] = [
            "AIBB_ID", "BP_SYSTOLIC", "ENC_DATE_SHIFT"
        ]
        columns["AIBB_DBP_Deidentified_042020.csv"] = [
            "AIBB_ID", "BP_DIASTOLIC", "ENC_DATE_SHIFT"
        ]
        columns["AIBB_Height_Deidentified_042020.csv"] = [
            "AIBB_ID", "HEIGHT_INCHES", "ENC_DATE_SHIFT"
        ]
        columns["AIBB_Weight_Deidentified_042020.csv"] = [
            "AIBB_ID", "WEIGHT_LBS", "ENC_DATE_SHIFT"
        ]
        columns["AIBB_Smoking_History_Deidentified_042020.csv"] = [
            "AIBB_ID", "SOCIAL_HISTORY_USE", "ENC_DT_SHIFT"
        ]
        columns["steatosis_run_2_merge.csv"] = [
            "AIBB_ID", "AIBB_ACCESSION", "LIVER_MEAN_HU", "SPLEEN_MEAN_HU"
        ]
        columns["visceral_merge_log_run_11_1_20.csv"] = [
            "AIBB_ID",
            "AIBB_ACCESSION",
            "VISCERAL_METRIC_AREA_MEAN",
            "SUBQ_METRIC_AREA_MEAN"
        ]

        return columns

    @staticmethod
    def get_num_col_names(
        use_clinical: bool = True,
        use_idp: bool = True,
        use_intelligent: bool = False
    ) -> Sequence[str]:
        """
        Retrieves the column names of continuous data fields in the dataset.
        Input:
            use_clinical: whether to include clinical data names.
            use_idp: whether to include image-derived phenotype (IDP) data
                names.
            use_intelligent: whether to use intelligently derived clinical
                variables and IDPs.
        Returns:
            Column names of the appropriate continuous data fields.
        """
        clinical_features = [
            "AGE", "BP_SYSTOLIC", "BP_DIASTOLIC", "WEIGHT_LBS", "HEIGHT_INCHES"
        ]
        if use_intelligent:
            clinical_features = clinical_features[:-2] + ["BMI"]
        idp_features = [
            "LIVER_MEAN_HU",
            "SPLEEN_MEAN_HU",
            "VISCERAL_METRIC_AREA_MEAN",
            "SUBQ_METRIC_AREA_MEAN"
        ]
        if use_intelligent:
            idp_features = idp_features[2:] + ["HEPATIC_FAT"]
        features = []
        if use_clinical:
            features += clinical_features
        if use_idp:
            features += idp_features
        return features

    @staticmethod
    def get_cat_col_names(
        use_clinical: bool = True,
        use_idp: bool = True,
        use_intelligent: bool = False
    ) -> Sequence[str]:
        """
        Retrieves the column names of categorical data fields in the dataset.
        Input:
            use_clinical: whether to include clinical data names.
            use_idp: whether to include image-derived phenotype (IDP) data
                names.
            use_intelligent: whether to use intelligently derived clinical
                variables and IDPs.
        Returns:
            Column names of the appropriate categorical data fields.
        """
        clinical_features = ["RACE_CODE", "Sex"]
        idp_features = []
        features = []
        if use_clinical:
            features += clinical_features
        if use_idp:
            features += idp_features
        return features

    @staticmethod
    def get_target_col_name() -> Sequence[str]:
        """
        Retrieves the column names of output data field in the dataset.
        Input:
            None.
        Returns:
            Column names of the target output data field.
        """
        return ["RESULT_VALUE_NUM"]
