from pathlib import Path

import numpy as np
import pandas as pd
import sparse
from xarray import DataArray

from pyrregular.data_utils import data_final_folder, data_original_folder
from pyrregular.io_utils import (
    load_from_file,
    load_yaml,
    read_csv,
    save_to_file,
)
from pyrregular.reader_interface import ReaderInterface


class Physionet2019(ReaderInterface):
    """Reader for the PhysioNet 2019 sepsis dataset (early classification of septic patients).

    The raw files are parsed by ``read_physionet2019``; ``_fix_intermediate_version``
    then turns the per-time-step ``SepsisLabel`` signal into a per-series class label.
    """

    @staticmethod
    def read_original_version(verbose=False):
        """Parse the raw ``.psv`` files; delegates to ``read_physionet2019``."""
        return read_physionet2019(verbose=verbose)

    @staticmethod
    def _fix_intermediate_version(data: DataArray, verbose=True) -> DataArray:
        """Derive per-series labels and splits, then remove ``SepsisLabel`` as a signal.

        Args:
            data: Intermediate array with a ``ts_id`` dimension, a ``signal_id``
                dimension containing ``"SepsisLabel"`` and a ``set`` coordinate
                (values ``"a"`` / ``"b"``) along ``ts_id``. Its ``.data`` is expected
                to be a ``sparse`` array (``.todense()`` is called on it).
            verbose: Unused; kept for interface compatibility.

        Returns:
            ``data`` without the ``SepsisLabel`` signal. It gains two coordinates
            along ``ts_id``: ``split_default`` (``"train"`` for set ``a``,
            ``"test"`` for set ``b``) and ``class_default`` (0/1). Series that
            contain only NaNs once the label is removed are dropped.

        Raises:
            KeyError: if the ``set`` coordinate holds a value other than ``"a"``/``"b"``.
        """
        # Reduce over every axis except ts_id, so the result holds one value per
        # series regardless of the array's dimension order.
        sepsis = data.sel({"signal_id": "SepsisLabel"})
        sepsis_axes = tuple(i for i, dim in enumerate(sepsis.dims) if dim != "ts_id")
        # The task is early classification: a patient is labelled septic (1) if at
        # least one of its time points is septic (label=1).
        # NOTE(review): a series with no SepsisLabel value at all would give NaN
        # here, and NaN.astype(int) is undefined.
        labels = np.nanmax(sepsis.data.todense(), axis=sepsis_axes).astype(int)

        mapping = {"a": "train", "b": "test"}
        split = [mapping[s] for s in data["set"].to_numpy()]

        # The sepsis label is the target, not an input signal.
        data = data.drop_sel({"signal_id": "SepsisLabel"})
        data = data.assign_coords(
            split_default=("ts_id", split), class_default=("ts_id", labels)
        )

        # Dropping the label leaves a few series with no observed value at all
        # (2 in this dataset); remove them by their ts_id labels.
        data_axes = tuple(i for i, dim in enumerate(data.dims) if dim != "ts_id")
        all_nan = sparse.all(sparse.isnan(data.data), axis=data_axes).todense()
        to_drop = data["ts_id"].to_numpy()[np.flatnonzero(all_nan)]
        return data.drop_sel({"ts_id": to_drop})


def _dataset_physionet2019(filenames: dict):
    static_columns = ["Age", "Gender", "Unit1", "Unit2", "HospAdmTime"]

    set_a_files = sorted([file for file in filenames["set-a"].glob("*.psv")])
    set_b_files = sorted([file for file in filenames["set-b"].glob("*.psv")])

    dfs_a = pd.DataFrame()
    for i, file in enumerate(set_a_files):
        df = pd.read_csv(file, sep="|")
        df["tid"] = file.stem
        df["set"] = "a"
        dfs_a = pd.concat([dfs_a, df])
        # break

    dfs_b = pd.DataFrame()
    for i, file in enumerate(set_b_files):
        df = pd.read_csv(file, sep="|")
        df["tid"] = file.stem
        df["set"] = "b"
        dfs_b = pd.concat([dfs_b, df])
        # break

    dfs = pd.concat([dfs_a, dfs_b])
    dfs_melted = dfs.melt(id_vars=static_columns + ["tid", "ICULOS"] + ["set"])

    dfs_melted = dfs_melted[dfs_melted["value"].notna()].reset_index(drop=True)
    dfs_melted.replace(np.nan, pd.NA, inplace=True)
    for i in range(len(dfs_melted)):
        yield dfs_melted.iloc[i : i + 1].to_dict(orient="records")[0]


def read_physionet2019(verbose=False):
    """Load the raw PhysioNet 2019 training sets A and B through ``read_csv``.

    Dataset attributes come from ``physionet2019/attrs.yml``. Long-format records are
    produced by ``_dataset_physionet2019``: series are keyed by ``tid``, time by
    ``ICULOS``, signals by ``variable`` and values by ``value``. The static patient
    columns and the ``set`` tag are declared as ``ts_id`` dims, and
    ``time_index_as_datetime=False`` is passed so the time index is not treated
    as datetimes.
    """
    dataset_root = data_original_folder() / "physionet2019"
    training_root = dataset_root / "training"
    dataset_attrs = load_yaml(str(dataset_root / "attrs.yml"))

    per_series_columns = ["Age", "Gender", "Unit1", "Unit2", "HospAdmTime", "set"]
    sources = {
        "set-a": training_root / "training_setA",
        "set-b": training_root / "training_setB",
    }
    return read_csv(
        filenames=sources,
        ts_id="tid",
        time_id="ICULOS",
        signal_id="variable",
        value_id="value",
        dims={"ts_id": per_series_columns, "signal_id": [], "time_id": []},
        reader_fun=_dataset_physionet2019,
        attrs=dataset_attrs,
        verbose=verbose,
        time_index_as_datetime=False,
    )


if __name__ == "__main__":
    # Build and persist the final dataset, then reload it from disk.
    Physionet2019.save_final_version()
    dataset = Physionet2019.load_final_version()
