import torch
import torch.utils.data as data
import pandas as pd
import numpy as np
import pickle
from uuid import uuid4

from utils import OneHotEncoder

import os

# Base directory that dataset file paths are joined onto (via os.path.join)
# before the CSV files are read.
PROJECT_ROOT: str = "."


class PhysioNet2012(data.Dataset):
    """PhysioNet 2012 ICU records exposed as one graph per biomarker.

    Each item is read from a ``|``-separated CSV file whose rows are sorted by
    ``self.time_variable``. For every biomarker column that has at least one
    observation, the item holds a graph whose nodes are that biomarker's
    measurements and whose edge weights decay with the time elapsed between
    measurements (see ``_edge_weights``).
    """

    def __init__(
        self,
        files,
        biom_one_hot_embedder,
        save_path=None,
        load_cached_dataset=False,
        predictive_label: str = 'mortality',
        los_threshold_days: int = 3,
    ):
        """Configure the dataset.

        Args:
            files: Sequence of CSV paths, joined onto ``PROJECT_ROOT`` when read.
            biom_one_hot_embedder: Callable applied to a biomarker's one-hot
                float tensor. Its output (converted with ``.tolist()``) is
                appended to every node's feature vector.
            save_path: When truthy, each built ``graph_dict`` is pickled to
                ``{save_path}/{text_label}/<uuid4>.pkl``.
            load_cached_dataset: When True and the CSV has a ``newlabel``
                column, that column supplies the label instead of
                ``predictive_label``.
            predictive_label: ``'mortality'`` (label read from the ``Survival``
                column) or ``'LoS'`` (length-of-stay threshold label).
            los_threshold_days: With ``'LoS'``, stays strictly longer than
                this many days are labelled 1.0, others 0.0.

        Raises:
            ValueError: If ``predictive_label`` is not a supported value.
        """
        # Validate with a real exception: ``assert`` is stripped under -O.
        if predictive_label not in ('mortality', 'LoS'):
            raise ValueError("predictive_label must be 'mortality' or 'LoS'")

        self.files = files

        self.biomarker_features = [
            'ALP', 'ALT', 'AST', 'Albumin', 'BUN', 'Bilirubin', 'Cholesterol', 'Creatinine',
            'DiasABP', 'FiO2', 'GCS', 'Glucose', 'HCO3', 'HCT', 'HR', 'K', 'Lactate', 'MAP',
            'MechVent', 'Mg', 'NIDiasABP', 'NIMAP', 'NISysABP', 'Na', 'PaCO2', 'PaO2',
            'Platelets', 'RespRate', 'SaO2', 'SysABP', 'Temp', 'TroponinI', 'TroponinT',
            'Urine', 'WBC', 'pH',
        ]

        # Taken from the first (earliest) row of each record.
        self.static_features = ['Age', 'Gender', 'ICUType']

        self.time_variable = "ICULOS"

        self.load_cached_dataset = load_cached_dataset

        self.biom_encoder = OneHotEncoder(self.biomarker_features)
        self.biom_one_hot_embedder = biom_one_hot_embedder

        self.save_path = save_path

        self.predictive_label = predictive_label
        self.los_threshold_days = los_threshold_days

    def __len__(self):
        """Return the number of record files."""
        return len(self.files)

    def __getitem__(self, idx):
        """Build and return the per-biomarker graphs for ``self.files[idx]``."""
        file_path = self.files[idx]
        return self._build_set_of_graphs(file_path)

    def _graph_label(self, dataframe):
        """Return ``(graph_label, text_label)`` for one time-sorted record.

        ``text_label`` names the sub-directory used when pickling the sample.
        """
        if self.load_cached_dataset and 'newlabel' in dataframe.columns:
            # NOTE(review): the cached path always uses stay-length names,
            # even when predictive_label == 'mortality' — confirm intended.
            graph_label = float(dataframe['newlabel'].iloc[0])
            return graph_label, ('long_stay' if graph_label == 1.0 else 'short_stay')

        if self.predictive_label == 'mortality':
            graph_label = dataframe['Survival'].iloc[-1]
            return graph_label, ('survived' if graph_label == 1 else 'no_survived')

        # Length-of-stay: ICULOS is in hours. ``max`` skips NaN, whereas the
        # last row of a time-sorted frame is NaN whenever any time is missing.
        los_days = float(dataframe[self.time_variable].max()) / 24.0
        graph_label = 1.0 if los_days > float(self.los_threshold_days) else 0.0
        return graph_label, ('long_stay' if graph_label == 1.0 else 'short_stay')

    @staticmethod
    def _edge_weights(times):
        """Return the ``(n, n)`` float32 edge-weight matrix for measurement times.

        Entry ``(i, j)`` with ``i > j`` is ``1 / (times[i] - times[j] + 0.1)``.
        The diagonal is ``1 / 0.1 = 10``. Entries above the diagonal, and
        below-diagonal entries whose time gap is exactly 0, are 0. Continuous
        hours are used (no integer rounding). Assumes ``times`` is
        non-decreasing; the caller sorts rows by time.
        """
        gaps = np.tril(times[:, None] - times[None, :], k=0).astype(np.float32)
        # Zero gaps (upper triangle and simultaneous measurements) -> weight 0.
        gaps[gaps == 0] = np.inf
        np.fill_diagonal(gaps, 0)
        return 1 / (gaps + 0.1)

    def _build_set_of_graphs(self, file_path):
        """Read one record and build a graph for every observed biomarker.

        Returns:
            ``(graph_dict, label)``. ``graph_dict`` maps each biomarker name to
            ``(node_features, edge_weights, label)``. Node feature rows are
            ``[value] + static feature values + embedded biomarker encoding``.
            ``label`` is a float tensor of shape ``(1,)``.
        """
        abs_file_path = os.path.join(PROJECT_ROOT, file_path)
        dataframe = pd.read_csv(abs_file_path, sep="|")
        dataframe = dataframe.sort_values(self.time_variable)
        dataframe = dataframe.dropna(axis=1, how='all')

        graph_label, text_label = self._graph_label(dataframe)

        # Same for every node of the record; compute once.
        static_values = dataframe[self.static_features].iloc[0].tolist()

        biomarkers = [col for col in dataframe.columns if col in self.biomarker_features]

        graph_dict = {}
        for biomarker in biomarkers:
            biom_data = dataframe[biomarker].dropna()

            one_hot = torch.tensor(self.biom_encoder.encode(biomarker).tolist()).to(torch.float)
            biomarker_encoded = self.biom_one_hot_embedder(one_hot).tolist()

            node_features = torch.tensor(
                [[x] + static_values + biomarker_encoded for x in biom_data],
                dtype=torch.float,
            )

            times = dataframe[self.time_variable][biom_data.index].to_numpy()
            edge_weights = self._edge_weights(times)

            graph_dict[biomarker] = (
                node_features,
                torch.tensor(edge_weights, dtype=torch.float),
                torch.tensor([graph_label], dtype=torch.float),
            )

        # Write once per record, after all biomarker graphs are added, so each
        # file holds exactly one complete sample.
        if self.save_path and graph_dict:
            out_dir = f"{self.save_path}/{text_label}"
            os.makedirs(out_dir, exist_ok=True)
            with open(f"{out_dir}/{uuid4()}.pkl", 'wb') as f:
                pickle.dump(graph_dict, f)

        return graph_dict, torch.tensor([graph_label], dtype=torch.float)


class PhysioNet2012New(data.Dataset):
    """PhysioNet 2012 records as per-biomarker graphs, labelled by ``newlabel``.

    Each item is read from a ``|``-separated CSV file whose rows are sorted by
    ``self.time_variable``. Unlike ``PhysioNet2012``, edges carry the raw
    (integer-truncated) time gap between measurements rather than an inverse
    weight (see ``_time_gaps``).
    """

    def __init__(
        self,
        files,
        biom_one_hot_embedder,
        save_path=None,
        load_cached_dataset=False,
    ):
        """Configure the dataset.

        Args:
            files: Sequence of CSV paths, joined onto ``PROJECT_ROOT`` when read.
            biom_one_hot_embedder: Callable applied to a biomarker's one-hot
                float tensor. Its output (converted with ``.tolist()``) is
                appended to every node's feature vector.
            save_path: When truthy, each built ``graph_dict`` is pickled to
                ``{save_path}/{text_label}/<uuid4>.pkl``.
            load_cached_dataset: Loading cached items is not implemented.
                When True, ``__getitem__`` raises NotImplementedError.
        """
        self.files = files

        self.biomarker_features = [
            'ALP', 'ALT', 'AST', 'Albumin', 'BUN', 'Bilirubin', 'Cholesterol', 'Creatinine',
            'DiasABP', 'FiO2', 'GCS', 'Glucose', 'HCO3', 'HCT', 'HR', 'K', 'Lactate', 'MAP',
            'MechVent', 'Mg', 'NIDiasABP', 'NIMAP', 'NISysABP', 'Na', 'PaCO2', 'PaO2',
            'Platelets', 'RespRate', 'SaO2', 'SysABP', 'Temp', 'TroponinI', 'TroponinT',
            'Urine', 'WBC', 'pH',
        ]
        # Taken from the first (earliest) row of each record. Weight/Height
        # and Unit1/Unit2/HospAdmTime are not used; per an earlier note,
        # Unit1 and Unit2 are sometimes missing.
        self.static_features = ['Age', 'Gender', 'ICUType']
        self.time_variable = "ICULOS"

        self.load_cached_dataset = load_cached_dataset

        self.biom_encoder = OneHotEncoder(self.biomarker_features)
        self.biom_one_hot_embedder = biom_one_hot_embedder

        self.save_path = save_path

    def __len__(self):
        """Return the number of record files."""
        return len(self.files)

    def __getitem__(self, idx):
        """Build and return the per-biomarker graphs for ``self.files[idx]``.

        Raises:
            NotImplementedError: If ``load_cached_dataset`` is set.
        """
        file_path = self.files[idx]
        if self.load_cached_dataset:
            # Fail loudly instead of silently returning None.
            raise NotImplementedError(
                "PhysioNet2012New does not implement loading cached items; "
                "construct it with load_cached_dataset=False."
            )
        return self._build_set_of_graphs(file_path)

    def _graph_label(self, dataframe):
        """Return ``(graph_label, text_label)`` for one time-sorted record.

        The label is read from the precomputed ``newlabel`` column (last row),
        not from ``Survival``. ``text_label`` names the pickling sub-directory.
        """
        graph_label = dataframe['newlabel'].iloc[-1]
        return graph_label, ('long' if graph_label == 1 else 'short')

    @staticmethod
    def _time_gaps(times):
        """Return the ``(n, n)`` float32 matrix of lower-triangular time gaps.

        Entry ``(i, j)`` with ``i > j`` is ``times[i] - times[j]``, truncated
        to an integer. The diagonal and the upper triangle are 0.
        """
        gaps = (times[:, None] - times[None, :]).astype(int)
        gaps = np.tril(gaps, k=0).astype(np.float32)
        np.fill_diagonal(gaps, 0)
        return gaps

    def _build_set_of_graphs(self, file_path):
        """Read one record and build a graph for every observed biomarker.

        Returns:
            ``(graph_dict, label)``. ``graph_dict`` maps each biomarker name to
            ``(node_features, time_gaps, label)``. Node feature rows are
            ``[value] + static feature values + embedded biomarker encoding``.
            ``label`` is a float tensor of shape ``(1,)``.
        """
        abs_file_path = os.path.join(PROJECT_ROOT, file_path)
        dataframe = pd.read_csv(abs_file_path, sep="|")
        dataframe = dataframe.sort_values(self.time_variable)
        dataframe = dataframe.dropna(axis=1, how='all')

        graph_label, text_label = self._graph_label(dataframe)

        # Same for every node of the record; compute once.
        static_values = dataframe[self.static_features].iloc[0].tolist()

        biomarkers = [col for col in dataframe.columns if col in self.biomarker_features]

        graph_dict = {}
        for biomarker in biomarkers:
            biom_data = dataframe[biomarker].dropna()

            one_hot = torch.tensor(self.biom_encoder.encode(biomarker).tolist()).to(torch.float)
            biomarker_encoded = self.biom_one_hot_embedder(one_hot).tolist()

            node_features = torch.tensor(
                [[x] + static_values + biomarker_encoded for x in biom_data],
                dtype=torch.float,
            )

            times = dataframe[self.time_variable][biom_data.index].to_numpy()
            gaps = self._time_gaps(times)

            graph_dict[biomarker] = (
                node_features,
                torch.tensor(gaps, dtype=torch.float),
                torch.tensor([graph_label], dtype=torch.float),
            )

        # Write once per record, after all biomarker graphs are added, so each
        # file holds exactly one complete sample.
        if self.save_path and graph_dict:
            out_dir = f"{self.save_path}/{text_label}"
            os.makedirs(out_dir, exist_ok=True)
            with open(f"{out_dir}/{uuid4()}.pkl", 'wb') as f:
                pickle.dump(graph_dict, f)

        return graph_dict, torch.tensor([graph_label], dtype=torch.float)

    def _z_normalise(self, x, mean, std):
        """Return ``(x - mean) / std``, or 0.0 when ``std`` is zero."""
        if float(std) == 0.0:
            return 0.0

        return (x - mean) / std

    def _feature_scale(self, x, x_min, x_max):
        """Min-max scale ``x`` into ``[x_min, x_max]`` units; 0.0 if the range is empty."""
        if x_min == x_max:
            return 0.0

        return (x - x_min) / (x_max - x_min)