import os
import numpy as np
import torch
from torch.utils.data import Dataset
from utils import *


class Hyper_Graph_Dataset_biomarker(Dataset):
	"""Map-style dataset of biomarker sequences plus a fixed hypergraph incidence matrix.

	Two on-disk layouts are supported, selected by whether the substring
	``'MIMIC'`` occurs in ``input_source``:

	* non-MIMIC: a single ``<data_type>_sequences.npy`` array whose last axis
	  holds the time stamp (column 0), the biomarkers (columns 1..21) and the
	  risk factors (columns 22 onwards).
	* MIMIC: ``<data_type>_matrix.npy`` (biomarkers) and
	  ``<data_type>_time.npy`` (time stamps); risk factors are replaced by an
	  all-zero placeholder.

	Every sample additionally carries the same node-by-hyperedge incidence
	matrix ``H`` built from the class-level tables below.
	"""

	# Feature ids that take part in at least one hyperedge. The position of an
	# id in this tuple is its row in the incidence matrix.
	_NODE_IDS = (
		2, 6, 12, 14, 15, 25, 26, 29, 32, 33, 34,
		36, 39, 40, 43, 45, 48, 55, 56, 58, 64,
	)

	# One tuple of feature ids per hyperedge (one column of the incidence
	# matrix each). The clinical labels are carried over from the original notes.
	# NOTE(review): node 6 is annotated as LDL_HighRisk in some chains and as
	# HDL_LowRisk in others -- confirm against the feature dictionary.
	_HYPEREDGES = (
		# HbA1c_low -> Hypoglycemia
		(2, 43),
		# HbA1c_high -> Neuropathy -> Foot Ulcer
		(45, 33, 15),
		# HbA1c_high -> Obesity -> Cancer
		(45, 14, 58),
		# HbA1c_high -> Nephropathy
		(45, 32),
		# HbA1c_high -> Poor Lipid (LDL_HighRisk) -> Hypertension -> Cardiac Revascularization
		(45, 6, 12, 55),
		# HbA1c_high -> Poor Lipid (LDL_HighRisk) -> Hypertension -> Atrial_fibrillation -> Congestive_heart_failure
		(45, 6, 12, 56, 29),
		# HbA1c_high -> Poor Lipid (LDL_HighRisk) -> Hypertension -> Cerebrovascular_Disease -> Stroke
		(45, 6, 12, 36, 39),
		# HbA1c_high -> Poor Lipid (HDL_LowRisk) -> Poor BP (BP_Diastolic_LowRisk) -> Cardiac Revascularization
		(45, 6, 48, 55),
		# HbA1c_high -> Poor Lipid (HDL_LowRisk) -> Poor BP (BP_Diastolic_LowRisk) -> Atrial_fibrillation -> Congestive_heart_failure
		(45, 6, 48, 56, 29),
		# HbA1c_high -> Poor Lipid (HDL_LowRisk) -> Poor BP (BP_Diastolic_LowRisk) -> Cerebrovascular_Disease -> Stroke
		(45, 6, 48, 36, 39),
		# HbA1c_high -> Retinopathy -> Visual_impairment -> Blindness_and_vision_loss
		(45, 34, 26, 25),
		# HbA1c_high -> Depression
		(45, 40),
		# HbA1c_high -> Ketoacidosis
		(45, 64),
	)

	def __init__(self, input_source, data_type='data'):
		"""Load the arrays for one split and build the incidence matrix.

		Args:
			input_source: Directory containing the ``.npy`` files. If the path
				contains ``'MIMIC'`` the MIMIC file layout is used.
			data_type: File-name prefix of the split, e.g. ``'data'`` loads
				``data_sequences.npy`` (or ``data_matrix.npy`` / ``data_time.npy``).
		"""
		# NOTE(security): allow_pickle=True lets np.load unpickle arbitrary
		# objects; only point this class at trusted files.
		if 'MIMIC' not in input_source:
			self.sequences = np.load(os.path.join(input_source, data_type + '_sequences.npy'), allow_pickle=True)
			# Column 0 is the time stamp; expected [N, 20] per the original notes.
			self.times = self.sequences[:, :, 0]
			# Columns 1..21 are the biomarkers; expected [N, 20, 21].
			self.biomarkers = self.sequences[:, :, 1:22].astype(int)
			# Remaining columns are the risk factors; expected [N, 20, 34].
			self.risk_factor = self.sequences[:, :, 22:].astype(int)
		else:
			self.biomarkers = np.load(os.path.join(input_source, data_type + '_matrix.npy'), allow_pickle=True)
			# Disabled alternative kept from the original: keep only the first '1'
			# self.biomarkers = (np.cumsum(self.biomarkers, axis=1) == 1).astype(int)
			# This layout has no risk factors: use a zero placeholder with the
			# same trailing shape as the non-MIMIC data.
			# NOTE(review): the 20 here is hard-coded, not taken from the loaded array.
			self.risk_factor = np.zeros((self.biomarkers.shape[0], 20, 34), dtype=int)
			self.times = np.load(os.path.join(input_source, data_type + '_time.npy'), allow_pickle=True)

		# Next-step setup: inputs are steps 0..T-2, labels are steps 1..T-1.
		# astype() copies, so the two arrays do not alias self.biomarkers.
		self.train_biomarkers = self.biomarkers[:, :-1, :].astype(int)
		self.labels = self.biomarkers[:, 1:, :].astype(int)

		# Entries whose time stamp is not the -1.0 sentinel.
		self.mask = (self.times != -1.0)

		self.H = self._build_incidence_matrix()

	@classmethod
	def _build_incidence_matrix(cls):
		"""Return a fresh float32 node-by-hyperedge incidence matrix.

		The shape is ``(len(_NODE_IDS), len(_HYPEREDGES))`` so that it always
		follows the tables. Entry ``[i, j]`` is 1.0 when node ``_NODE_IDS[i]``
		belongs to hyperedge ``j``. Ids not listed in ``_NODE_IDS`` are skipped.
		"""
		row_of = {node_id: row for row, node_id in enumerate(cls._NODE_IDS)}
		incidence = np.zeros((len(cls._NODE_IDS), len(cls._HYPEREDGES)), dtype=np.float32)
		for col, edge in enumerate(cls._HYPEREDGES):
			for node_id in edge:
				row = row_of.get(node_id)
				if row is not None:
					incidence[row, col] = 1.0
		return incidence

	def __len__(self):
		"""Return the number of sequences (first axis of ``self.biomarkers``)."""
		return len(self.biomarkers)

	def __getitem__(self, idx):
		"""Return ``(times, train_biomarkers, risk_factor, labels, mask, H)`` for ``idx``.

		The first five entries are indexed by ``idx``; ``H`` is the same
		incidence matrix object for every sample.
		"""
		return (
			self.times[idx],
			self.train_biomarkers[idx],
			self.risk_factor[idx],
			self.labels[idx],
			self.mask[idx],
			self.H,
		)

