import numpy as np
import torch
import torch.utils.data as data

from utils import OneHotEncoder


class PhysioNet2012FromRaindrop(data.Dataset):
    """
    GMAN-compatible dataset that builds per-biomarker graphs from Raindrop's
    preprocessed P12 patient dictionaries (in-memory), so that splits and
    preprocessing sources are identical to the Raindrop baselines.

    Each sample returns ``(graph_dict, label)`` where ``label`` is a float
    tensor of shape ``[1]`` and ``graph_dict`` maps a biomarker name to a tuple
    ``(node_features, edge_weights, label)``:

    * ``node_features`` -- float tensor with one row per observation, laid out
      as ``[value, Age, Gender, ICUType, *biomarker_embedding]``.
    * ``edge_weights`` -- float ``[N, N]`` tensor of inverse time differences:
      ``1 / (t_i - t_j + 0.1)`` below the diagonal, ``1 / 0.1`` on the
      diagonal, and ``0`` above the diagonal or where two timestamps are equal.
    * ``label`` -- the same tensor object returned as the sample label,
      repeated per graph to match the existing GMAN collate and model
      expectations.

    Biomarkers with no observed values for a patient are omitted from
    ``graph_dict``.
    """

    # Order must match Raindrop's processed feature ordering (36 biomarkers):
    # column j of a patient's 'arr' is interpreted as BIOMARKER_FEATURES[j].
    BIOMARKER_FEATURES = (
        'ALP', 'ALT', 'AST', 'Albumin', 'BUN', 'Bilirubin', 'Cholesterol', 'Creatinine',
        'DiasABP', 'FiO2', 'GCS', 'Glucose', 'HCO3', 'HCT', 'HR', 'K', 'Lactate', 'MAP',
        'MechVent', 'Mg', 'NIDiasABP', 'NIMAP', 'NISysABP', 'Na', 'PaCO2', 'PaO2',
        'Platelets', 'RespRate', 'SaO2', 'SysABP', 'Temp', 'TroponinI', 'TroponinT',
        'Urine', 'WBC', 'pH',
    )

    # GMAN uses these three static features. They are reconstructed from
    # Raindrop's 'extended_static' vector, whose layout is assumed to be:
    # ['Age','Gender=0','Gender=1','Height','ICUType=1','ICUType=2','ICUType=3','ICUType=4','Weight']
    STATIC_FEATURE_NAMES = ('Age', 'Gender', 'ICUType')

    # Added to every time difference before inversion; bounds the diagonal
    # (self-edge) weight at 1 / 0.1.
    _DISTANCE_EPS = 0.1

    def __init__(
        self,
        patients,
        labels,
        biom_one_hot_embedder,
        predictive_label: str = 'mortality',
        los_threshold_days: int = 3,
    ):
        """
        Args:
            patients: indexable collection of per-patient dicts read via the
                keys 'arr' (2-D, one column per biomarker), 'time' (one
                timestamp per row of 'arr', any shape that flattens to that)
                and 'extended_static' (static vector, layout described above).
            labels: indexable collection of per-patient labels; each entry
                must convert to a single float (a scalar or size-1 array).
            biom_one_hot_embedder: callable mapping a float one-hot tensor
                over the biomarker list to an embedding; its output must
                support ``.tolist()`` (assumed 1-D). Called once per observed
                biomarker per sample, so a trainable module is always
                evaluated with its current weights.
            predictive_label: stored as an attribute; not read by this
                class's methods (``labels`` is used as given).
            los_threshold_days: stored as an attribute; not read by this
                class's methods.
        """
        self.patients = patients
        self.labels = labels

        self.biomarker_features = list(self.BIOMARKER_FEATURES)
        self.static_feature_names = list(self.STATIC_FEATURE_NAMES)

        self.predictive_label = predictive_label
        self.los_threshold_days = los_threshold_days

        self.biom_encoder = OneHotEncoder(self.biomarker_features)
        self.biom_one_hot_embedder = biom_one_hot_embedder

    def __len__(self):
        return len(self.patients)

    def __getitem__(self, idx):
        """Build ``(graph_dict, label_tensor)`` for patient ``idx``.

        Raises:
            ValueError: if the patient's 'arr' is not 2-D with one column per
                biomarker.
        """
        sample = self.patients[idx]
        label_tensor = self._label_tensor(self.labels[idx])

        arr = np.asarray(sample['arr'])
        time_vec = np.asarray(sample['time']).reshape(-1)
        static_triplet = self._static_triplet(sample['extended_static'])

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        n_biomarkers = len(self.biomarker_features)
        if arr.ndim != 2 or arr.shape[1] != n_biomarkers:
            raise ValueError(
                "Unexpected number of biomarkers in arr: expected "
                f"{n_biomarkers} columns, got shape {arr.shape}"
            )

        graph_dict = {}
        for biom_idx, biom_name in enumerate(self.biomarker_features):
            values = arr[:, biom_idx]
            # Observed entries per Raindrop convention: value > 0 and not NaN
            # (presumably missing measurements are stored as 0 -- confirm).
            obs_mask = (values > 0) & ~np.isnan(values)
            if not obs_mask.any():
                continue

            biom_values = values[obs_mask].astype(float)
            biom_times = time_vec[obs_mask].astype(float)

            biom_embed = self._embed_biomarker(biom_name)
            node_features = self._node_features(biom_values, static_triplet, biom_embed)
            edge_weights = self._inverse_time_weights(biom_times)

            graph_dict[biom_name] = (
                node_features,
                torch.tensor(edge_weights, dtype=torch.float),
                label_tensor,
            )

        return graph_dict, label_tensor

    @staticmethod
    def _label_tensor(raw_y):
        """Convert a scalar or size-1 array label into a float tensor of shape [1]."""
        if isinstance(raw_y, np.ndarray):
            # squeeze() also covers 0-d arrays; float() rejects multi-element labels.
            raw_y = raw_y.squeeze()
        return torch.tensor([float(raw_y)], dtype=torch.float)

    @staticmethod
    def _static_triplet(static_ext):
        """Return ``[Age, Gender, ICUType]`` reconstructed from 'extended_static'.

        Missing slots fall back to 0.0. Gender is 1.0 only when the 'Gender=1'
        slot (index 2) equals 1, else 0.0. ICUType is 1..4 from the one-hot
        slots 4..7 (position of the maximum), or 0.0 when none equals 1.
        """
        # asarray so list inputs compare elementwise: `list == 1` is a single
        # False, which previously forced ICUType to 0.0 for list inputs.
        static_ext = np.asarray(static_ext).reshape(-1)

        age = float(static_ext[0]) if static_ext.size > 0 else 0.0
        gender_val = 1.0 if (static_ext.size > 2 and static_ext[2] == 1) else 0.0

        icu_type = 0.0
        if static_ext.size >= 8:
            icu_one_hot = static_ext[4:8]
            if np.any(icu_one_hot == 1):
                icu_type = float(int(np.argmax(icu_one_hot)) + 1)

        return [age, gender_val, icu_type]

    def _embed_biomarker(self, biom_name):
        """Embed ``biom_name``'s one-hot code and return it as a Python list."""
        biom_one_hot = torch.tensor(self.biom_encoder.encode(biom_name).tolist(), dtype=torch.float)
        # The output is immediately turned into plain floats, so no autograd
        # graph is needed; no_grad avoids building one per biomarker per sample.
        with torch.no_grad():
            return self.biom_one_hot_embedder(biom_one_hot).tolist()

    @staticmethod
    def _node_features(biom_values, static_triplet, biom_embed):
        """Stack rows ``[value] + static_triplet + biom_embed`` into a float tensor."""
        shared = np.asarray(list(static_triplet) + list(biom_embed), dtype=np.float64)
        feats = np.empty((len(biom_values), 1 + shared.size), dtype=np.float64)
        feats[:, 0] = biom_values
        feats[:, 1:] = shared  # broadcast the per-sample constants to every row
        return torch.tensor(feats, dtype=torch.float)

    @classmethod
    def _inverse_time_weights(cls, times):
        """Return the float32 ``[N, N]`` inverse-time-difference weight matrix.

        No unit conversion is applied: weights use the raw units of
        ``sample['time']``. NOTE(review): an earlier comment described these
        as hours while the ``/ 60`` conversion was disabled -- confirm the
        intended unit. Assumes ``times`` is sorted ascending (TODO confirm for
        Raindrop's 'time'); otherwise below-diagonal differences can be negative.
        """
        diffs = times[:, None] - times[None, :]
        # Keep only the lower triangle: entry (i, j) with j <= i.
        diffs = np.tril(diffs, k=0).astype(np.float32)
        # Zeros (upper triangle and equal timestamps) become inf, so their
        # weight 1 / (inf + eps) is exactly 0; the diagonal is then reset to 0
        # so self-edges get the maximum weight 1 / eps.
        diffs[diffs == 0] = np.inf
        np.fill_diagonal(diffs, 0)
        diffs += cls._DISTANCE_EPS
        return 1.0 / diffs



