import os
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import torch.nn.functional as F
import utils as utils
import tarfile
import torch
import numpy as np
import urllib.request
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

def download_url(url, root, filename=None, md5=None):
    """Download a file from a url and place it in root.

    Args:
        url: Address to fetch.
        root: Destination directory; created if it does not exist.
        filename: Name of the local file. Defaults to the basename of ``url``
            with any ``?query`` suffix stripped.
        md5: Optional hex MD5 digest. When given, the local file (freshly
            downloaded or already present) is checked against it.

    Returns:
        Path of the local file.

    Raises:
        RuntimeError: If ``md5`` is given and does not match the file.
    """
    if filename is None:
        filename = os.path.basename(url.split('?')[0])
    fpath = os.path.join(root, filename)
    os.makedirs(root, exist_ok=True)
    if not os.path.exists(fpath):
        print(f'Downloading {url} to {fpath}')
        # Download next to the target and move into place only on success, so
        # an interrupted transfer never leaves a truncated file that a later
        # call would mistake for a finished download.
        tmp_path = fpath + '.part'
        try:
            urllib.request.urlretrieve(url, tmp_path)
            os.replace(tmp_path, fpath)
        finally:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    if md5 is not None:
        import hashlib

        digest = hashlib.md5()
        with open(fpath, 'rb') as f:
            for chunk in iter(lambda: f.read(1 << 20), b''):
                digest.update(chunk)
        if digest.hexdigest() != md5:
            raise RuntimeError(
                'MD5 mismatch for {}: expected {}, got {}'.format(fpath, md5, digest.hexdigest())
            )
    return fpath
# Adapted from: https://github.com/rtqichen/time-series-datasets

class PhysioNet(object):
    """PhysioNet Challenge 2012 time-series dataset built from sets a, b and c.

    Every item is a tuple ``(record_id, tt, vals, mask)`` produced by
    :meth:`download`:

    - ``record_id``: stem of the raw text file the record was read from,
    - ``tt``: 1-D tensor of observation times in hours (always starts at 0),
    - ``vals``: ``(len(tt), len(params))`` tensor of values, 0 where missing,
    - ``mask``: tensor of the same shape, 1 where a value was observed.

    Outcome labels are neither downloaded nor loaded by this class.

    Args:
        root: Directory holding the ``raw`` and ``processed`` sub-folders.
        download: If True, download and preprocess the data when the
            processed files are missing.
        quantization: Time-grid step in hours used to merge nearby time
            stamps; ``None`` or 0 keeps the original time stamps. It is also
            part of the processed file names.
        n_samples: If given, keep only the first ``n_samples`` records.
        device: When equal to ``torch.device("cpu")`` the processed files are
            loaded with ``map_location='cpu'``; otherwise they are loaded onto
            the device they were saved from.

    Raises:
        RuntimeError: If the processed files are missing after the optional
            download step.
    """

    urls = [
        'https://physionet.org/files/challenge-2012/1.0.0/set-a.tar.gz?download',
        'https://physionet.org/files/challenge-2012/1.0.0/set-b.tar.gz?download',
        'https://physionet.org/files/challenge-2012/1.0.0/set-c.tar.gz?download',
    ]

    # Names of the variables kept from the raw files, in column order.
    params = [
        'Age', 'Gender', 'Height', 'ICUType', 'Weight', 'Albumin', 'ALP', 'ALT', 'AST', 'Bilirubin', 'BUN',
        'Cholesterol', 'Creatinine', 'DiasABP', 'FiO2', 'GCS', 'Glucose', 'HCO3', 'HCT', 'HR', 'K', 'Lactate', 'Mg',
        'MAP', 'MechVent', 'Na', 'NIDiasABP', 'NIMAP', 'NISysABP', 'PaCO2', 'PaO2', 'pH', 'Platelets', 'RespRate',
        'SaO2', 'SysABP', 'Temp', 'TroponinI', 'TroponinT', 'Urine', 'WBC'
    ]

    # Variable name -> column index in ``vals`` / ``mask``.
    params_dict = {k: i for i, k in enumerate(params)}

    labels = [ "SAPS-I", "SOFA", "Length_of_stay", "Survival", "In-hospital_death" ]
    labels_dict = {k: i for i, k in enumerate(labels)}

    def __init__(self, root, download = False,
        quantization = None, n_samples = None, device = torch.device("cpu")):

        self.root = root
        self.reduce = "average"
        self.quantization = quantization

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found. You can use download=True to download it')

        # ``download`` may have saved the tensors on a GPU; force them onto
        # the CPU only when the caller asked for the CPU.
        map_location = 'cpu' if device == torch.device("cpu") else None

        # Concatenate the three processed sets into one list of records.
        self.data = []
        for set_file in (self.set_a, self.set_b, self.set_c):
            self.data = self.data + torch.load(
                os.path.join(self.processed_folder, set_file), map_location=map_location)

        if n_samples is not None:
            print('Total records:', len(self.data))
            self.data = self.data[:n_samples]

    def download(self):
        """Download, extract and preprocess all sets unless already processed.

        One ``<set>_<quantization>.pt`` file per archive is written to
        ``processed_folder``; each holds a list of
        ``(record_id, tt, vals, mask)`` tuples.
        """
        if self._check_exists():
            return

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        os.makedirs(self.raw_folder, exist_ok=True)
        os.makedirs(self.processed_folder, exist_ok=True)

        n_params = len(self.params)
        quantize = self.quantization is not None and self.quantization != 0

        for url in self.urls:
            filename = url.rpartition('/')[2]
            download_url(url, self.raw_folder, filename)
            with tarfile.open(os.path.join(self.raw_folder, filename), "r:gz") as tar:
                # The archive comes from the network: use the restrictive
                # 'data' extraction filter when this Python provides it.
                if hasattr(tarfile, 'data_filter'):
                    tar.extractall(self.raw_folder, filter='data')
                else:
                    tar.extractall(self.raw_folder)

            print('Processing {}...'.format(filename))

            dirname = os.path.join(self.raw_folder, filename.split('.')[0])
            patients = []
            cnt = 0  # number of rows with an empty parameter name
            # NOTE: os.listdir order is arbitrary, so the record order in the
            # saved file follows the filesystem.
            for txtfile in os.listdir(dirname):
                record_id = txtfile.split('.')[0]
                with open(os.path.join(dirname, txtfile)) as f:
                    lines = f.readlines()

                # Every record starts with an (initially empty) row at t = 0.
                prev_time = 0
                tt = [0.]
                vals = [torch.zeros(n_params)]
                mask = [torch.zeros(n_params)]
                nobs = [torch.zeros(n_params)]
                for l in lines[1:]:  # first line is the header
                    time, param, val = l.split(',')
                    # Time in hours
                    hh_mm = time.split(':')
                    time = float(hh_mm[0]) + float(hh_mm[1]) / 60.

                    # Snap the time stamp onto the quantization grid.
                    # used for speed -- we actually don't need to quantize it in Latent ODE
                    if quantize:
                        time = round(time / self.quantization) * self.quantization

                    if time != prev_time:
                        tt.append(time)
                        vals.append(torch.zeros(n_params))
                        mask.append(torch.zeros(n_params))
                        nobs.append(torch.zeros(n_params))
                        prev_time = time

                    if param in self.params_dict:
                        idx = self.params_dict[param]
                        n_observations = nobs[-1][idx]
                        if self.reduce == 'average' and n_observations > 0:
                            # Several readings fell on the same time stamp:
                            # keep their running mean.
                            prev_val = vals[-1][idx]
                            new_val = (prev_val * n_observations + float(val)) / (n_observations + 1)
                            vals[-1][idx] = new_val
                        else:
                            vals[-1][idx] = float(val)
                        mask[-1][idx] = 1
                        nobs[-1][idx] += 1
                    else:
                        assert (param == 'RecordID' or param ==''), 'Read unexpected param {}'.format(param)
                        if(param != 'RecordID'):
                            cnt += 1
                            print(cnt, param, l)

                tt = torch.tensor(tt).to(self.device)
                vals = torch.stack(vals).to(self.device)
                mask = torch.stack(mask).to(self.device)

                patients.append((record_id, tt, vals, mask))

            torch.save(
                patients,
                os.path.join(self.processed_folder,
                    filename.split('.')[0] + "_" + str(self.quantization) + '.pt')
            )

        print('Done!')

    def _check_exists(self):
        """Return True when the processed file of every set exists."""
        for url in self.urls:
            filename = url.rpartition('/')[2]

            if not os.path.exists(
                os.path.join(self.processed_folder,
                    filename.split('.')[0] + "_" + str(self.quantization) + '.pt')
            ):
                return False
        return True

    @property
    def raw_folder(self):
        """Directory holding the downloaded archives and extracted files."""
        return os.path.join(self.root, 'raw')

    @property
    def processed_folder(self):
        """Directory holding the preprocessed ``.pt`` files."""
        return os.path.join(self.root, 'processed')

    @property
    def set_a(self):
        """File name of processed set a for the current quantization."""
        return 'set-a_{}.pt'.format(self.quantization)

    @property
    def set_b(self):
        """File name of processed set b for the current quantization."""
        return 'set-b_{}.pt'.format(self.quantization)

    @property
    def set_c(self):
        """File name of processed set c for the current quantization."""
        return 'set-c_{}.pt'.format(self.quantization)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)

    def get_label(self, record_id):
        # NOTE(review): no per-record outcomes are loaded by this class, so
        # ``self.labels`` resolves to the class-level list of label names.
        # Confirm whether outcome loading should be restored before relying
        # on this method.
        return self.labels[record_id]

    def __repr__(self):
        lines = [
            'Dataset ' + self.__class__.__name__,
            '    Number of datapoints: {}'.format(self.__len__()),
        ]
        # ``train`` is not set by ``__init__``; report the split only when a
        # subclass or caller has defined it.
        if hasattr(self, 'train'):
            lines.append('    Split: {}'.format('train' if self.train is True else 'test'))
        lines.append('    Root Location: {}'.format(self.root))
        lines.append('    Quantization: {}'.format(self.quantization))
        lines.append('    Reduce: {}'.format(self.reduce))
        return '\n'.join(lines) + '\n'

    def visualize(self, timesteps, data, mask, plot_name):
        """Plot every variable with more than two observations and save it.

        Args:
            timesteps: 1-D tensor of observation times.
            data: 2-D tensor of values, one column per entry of ``params``.
            mask: Tensor shaped like ``data``; 1 marks an observed value.
            plot_name: Path the figure is written to.
        """
        width = 15
        height = 15

        # Columns with more than two observed values.
        keep = (torch.sum(mask, 0) > 2).tolist()
        non_zero_idx = [i for i, flag in enumerate(keep) if flag]
        n_non_zero = len(non_zero_idx)

        columns = torch.tensor(non_zero_idx, dtype=torch.long, device=mask.device)
        # Move to the CPU so the tensors can be converted to numpy for plotting.
        mask = mask[:, columns].cpu()
        data = data[:, columns].cpu()
        timesteps = timesteps.cpu()

        params_non_zero = [self.params[i] for i in non_zero_idx]

        n_col = 3
        # Ceil division, but always at least one row so a (possibly empty)
        # figure can still be written.
        n_row = max(1, -(-n_non_zero // n_col))
        # squeeze=False keeps ax_list 2-D even when there is a single row.
        fig, ax_list = plt.subplots(n_row, n_col, figsize=(width, height),
            facecolor='white', squeeze=False)

        for i, param in enumerate(params_non_zero):
            observed = mask[:, i].long() == 1

            ax = ax_list[i // n_col, i % n_col]
            ax.plot(timesteps[observed].numpy(), data[observed, i].numpy(), marker='o')
            ax.set_title(param)

        fig.tight_layout()
        fig.savefig(plot_name)
        plt.close(fig)


def get_data_min_max_mean(records, device):
    """Compute per-feature min, max and mean plus the largest time stamp.

    Only entries whose mask equals 1 contribute to the statistics.

    Args:
        records: Non-empty sequence of ``(record_id, tt, vals, mask)`` tuples,
            where ``vals`` and ``mask`` are 2-D tensors with one column per
            feature and ``tt`` is a tensor of time stamps.
        device: Device on which the initial ``time_max`` sentinel is created.

    Returns:
        ``(data_min, data_max, time_max, x_mean)``. For a feature that is
        never observed, ``data_min`` is ``+inf``, ``data_max`` is ``-inf`` and
        ``x_mean`` is 0.

    Raises:
        IndexError: If ``records`` is empty.
    """
    inf = torch.Tensor([float("Inf")])[0].to(device)

    data_min, data_max, time_max = None, None, -inf

    first_vals = records[0][2]
    n_features = first_vals.size(-1)
    # Tensor accumulators, so features without any observation are handled
    # like every other feature instead of staying plain Python ints.
    x_sum = torch.zeros(n_features, dtype=first_vals.dtype, device=first_vals.device)
    x_num = torch.zeros(n_features, dtype=torch.long, device=first_vals.device)

    for _, tt, vals, mask in records:
        observed = mask == 1
        missing = ~observed

        # Missing entries are replaced by the neutral element of each
        # reduction so they cannot influence the result.
        batch_min = vals.masked_fill(missing, float("inf")).min(dim=0).values
        batch_max = vals.masked_fill(missing, float("-inf")).max(dim=0).values
        x_sum = x_sum + vals.masked_fill(missing, 0).sum(dim=0)
        x_num = x_num + observed.sum(dim=0)

        if (data_min is None) and (data_max is None):
            data_min = batch_min
            data_max = batch_max
        else:
            data_min = torch.min(data_min, batch_min)
            data_max = torch.max(data_max, batch_max)

        time_max = torch.max(time_max, tt.max())

    # Clamp the count so never-observed features get mean 0 rather than NaN.
    x_mean = x_sum / x_num.clamp(min=1)
    print('data_max:', data_max)
    print('data_min:', data_min)
    print('time_max:', time_max)

    return data_min, data_max, time_max, x_mean

def get_data_min_max(records, device):
    """Compute per-feature min and max plus the largest time stamp.

    Only entries whose mask equals 1 are taken into account. A feature with
    no observation at all keeps ``+inf`` as minimum and ``-inf`` as maximum.

    Args:
        records: Iterable of ``(record_id, tt, vals, mask)`` tuples, where
            ``vals`` and ``mask`` are 2-D tensors with one column per feature
            and ``tt`` is a tensor of time stamps.
        device: Device on which the initial ``time_max`` sentinel is created.

    Returns:
        ``(data_min, data_max, time_max)``; the first two are ``None`` when
        ``records`` is empty.
    """
    pos_inf = torch.Tensor([float("Inf")])[0].to(device)

    data_min = None
    data_max = None
    time_max = -pos_inf

    for _, tt, vals, mask in records:
        missing = ~(mask == 1)

        # Neutral fill values keep missing entries out of each reduction.
        record_min = vals.masked_fill(missing, float("inf")).min(dim=0).values
        record_max = vals.masked_fill(missing, float("-inf")).max(dim=0).values

        first_record = (data_min is None) and (data_max is None)
        data_min = record_min if first_record else torch.min(data_min, record_min)
        data_max = record_max if first_record else torch.max(data_max, record_max)

        time_max = torch.max(time_max, tt.max())

    print('data_max:', data_max)
    print('data_min:', data_min)
    print('time_max:', time_max)

    return data_min, data_max, time_max

def get_seq_length(args, records):
    """Summarise how many time stamps fall before and after ``args.history``.

    Args:
        args: Object with a ``history`` attribute (the time threshold).
        records: Iterable of ``(record_id, tt, vals, mask)`` tuples; only the
            time tensor ``tt`` is used.

    Returns:
        ``(max_input_len, max_pred_len, median_len)``: the largest number of
        time stamps strictly below the threshold, the largest number of
        remaining time stamps, and the median of the per-record counts below
        the threshold.
    """
    longest_input = 0
    longest_pred = 0
    history_counts = []

    for _, tt, _vals, _mask in records:
        n_history = (tt < args.history).sum()
        history_counts.append(n_history)
        longest_input = max(longest_input, n_history)
        longest_pred = max(longest_pred, len(tt) - n_history)

    median_len = torch.stack(history_counts, dim=0).median()

    return longest_input, longest_pred, median_len


def patch_variable_time_collate_fn(batch, args, device = torch.device("cpu"), data_type = "train", 
    data_min = None, data_max = None, time_max = None):
    """
    Collate records onto a shared time grid and hand them to the patching helper.

    Expects a batch of time series data in the form of (record_id, tt, vals, mask) where
        - record_id is a patient id
        - tt is a (T, ) tensor containing T time values of observations.
        - vals is a (T, D) tensor containing observed values for D variables.
        - mask is a (T, D) tensor containing 1 where values were observed and 0 otherwise.

    Samples without any time stamp at or after ``args.history`` are dropped.
    The history part (time stamps below ``args.history``) is placed on the
    sorted union of all time stamps in the batch; the part to predict is
    zero-padded per sample to the longest one in the batch.

    Returns:
        None when no sample is left after filtering; otherwise the result of
        ``utils.split_and_patch_batch`` applied to a dict with the keys
        "data", "time_steps", "mask", "data_to_predict", "tp_to_predict" and
        "mask_predicted_data".
        NOTE(review): the final tensor shapes are decided by that helper,
        which is not visible here -- confirm against ``utils``.
    """
    kept = [item for item in batch if (item[1] >= args.history).sum() > 0]
    if len(kept) == 0:
        return None

    n_vars = kept[0][2].shape[1]

    # Sorted union of every time stamp in the batch; ``slot_of`` maps each
    # original time stamp to its position on that union grid.
    all_times = torch.cat([item[1] for item in kept])
    union_tt, slot_of = torch.unique(all_times, sorted=True, return_inverse=True)

    # Number of union time stamps that belong to the history window.
    n_observed_tp = (union_tt < args.history).sum()
    history_tt = union_tt[:n_observed_tp]

    # Indices of the history time stamps covered by each patch. Only the last
    # patch includes its right edge.
    patch_indices = []
    lo, hi = 0, args.patch_size
    last_patch = args.npatch - 1
    for patch_id in range(args.npatch):
        if patch_id == last_patch:
            inside = (history_tt >= lo) & (history_tt <= hi)
        else:
            inside = (history_tt >= lo) & (history_tt < hi)
        patch_indices.append(torch.where(inside)[0])

        lo += args.stride
        hi += args.stride

    grid_vals = torch.zeros([len(kept), len(union_tt), n_vars]).to(device)
    grid_mask = torch.zeros([len(kept), len(union_tt), n_vars]).to(device)
    future_tp, future_vals, future_mask = [], [], []

    cursor = 0  # start of the current sample inside ``slot_of``
    for row, (_, tt, vals, mask) in enumerate(kept):
        slots = slot_of[cursor:cursor + len(tt)]
        cursor += len(tt)
        grid_vals[row, slots] = vals
        grid_mask[row, slots] = mask

        # Everything from the first time stamp >= history is to be predicted.
        cut = (tt < args.history).sum()
        future_tp.append(tt[cut:])
        future_vals.append(vals[cut:])
        future_mask.append(mask[cut:])

    # Keep only the history part of the shared grid.
    combined_tt = union_tt[:n_observed_tp]
    combined_vals = grid_vals[:, :n_observed_tp]
    combined_mask = grid_mask[:, :n_observed_tp]

    predicted_tp = pad_sequence(future_tp, batch_first=True)
    predicted_data = pad_sequence(future_vals, batch_first=True)
    predicted_mask = pad_sequence(future_mask, batch_first=True)

    if args.dataset != 'ushcn':
        combined_vals = utils.normalize_masked_data(
            combined_vals, combined_mask, att_min = data_min, att_max = data_max)
        predicted_data = utils.normalize_masked_data(
            predicted_data, predicted_mask, att_min = data_min, att_max = data_max)

    combined_tt = utils.normalize_masked_tp(combined_tt, att_min = 0, att_max = time_max)
    predicted_tp = utils.normalize_masked_tp(predicted_tp, att_min = 0, att_max = time_max)

    data_dict = {
        "data": combined_vals,
        "time_steps": combined_tt,
        "mask": combined_mask,
        "data_to_predict": predicted_data,
        "tp_to_predict": predicted_tp,
        "mask_predicted_data": predicted_mask,
    }

    return utils.split_and_patch_batch(data_dict, args, n_observed_tp, patch_indices)


def variable_time_collate_fn(batch, args, device = torch.device("cpu"), data_type = "train", 
    data_min = None, data_max = None, time_max = None):
    """
    Split every record at ``args.history`` and zero-pad both halves per batch.

    Expects a batch of time series data in the form of (record_id, tt, vals, mask) where
        - record_id is a patient id
        - tt is a (T, ) tensor containing T time values of observations.
        - vals is a (T, D) tensor containing observed values for D variables.
        - mask is a (T, D) tensor containing 1 where values were observed and 0 otherwise.

    Samples without any time stamp at or after ``args.history`` are dropped.

    Returns:
        None when no sample is left after filtering; otherwise a dict with
            observed_tp / tp_to_predict: (B, L) padded time stamps,
            observed_data / data_to_predict: (B, L, D) padded values,
            observed_mask / mask_predicted_data: (B, L, D) padded masks,
        where L is the longest sequence of the respective half in the batch.
        Values are passed through ``utils.normalize_masked_data`` (skipped
        for the 'ushcn' dataset) and time stamps through
        ``utils.normalize_masked_tp``.
    """
    kept = [item for item in batch if (item[1] >= args.history).sum() > 0]
    if len(kept) == 0:
        return None

    history_tp, history_vals, history_mask = [], [], []
    future_tp, future_vals, future_mask = [], [], []

    for _, tt, vals, mask in kept:
        # Number of time stamps strictly before the history threshold.
        cut = (tt < args.history).sum()

        history_tp.append(tt[:cut])
        history_vals.append(vals[:cut])
        history_mask.append(mask[:cut])

        future_tp.append(tt[cut:])
        future_vals.append(vals[cut:])
        future_mask.append(mask[cut:])

    def pad(sequences):
        return pad_sequence(sequences, batch_first=True)

    observed_tp, observed_data, observed_mask = pad(history_tp), pad(history_vals), pad(history_mask)
    predicted_tp, predicted_data, predicted_mask = pad(future_tp), pad(future_vals), pad(future_mask)

    if args.dataset != 'ushcn':
        observed_data = utils.normalize_masked_data(
            observed_data, observed_mask, att_min = data_min, att_max = data_max)
        predicted_data = utils.normalize_masked_data(
            predicted_data, predicted_mask, att_min = data_min, att_max = data_max)

    observed_tp = utils.normalize_masked_tp(observed_tp, att_min = 0, att_max = time_max)
    predicted_tp = utils.normalize_masked_tp(predicted_tp, att_min = 0, att_max = time_max)

    return {
        "observed_data": observed_data,
        "observed_tp": observed_tp,
        "observed_mask": observed_mask,
        "data_to_predict": predicted_data,
        "tp_to_predict": predicted_tp,
        "mask_predicted_data": predicted_mask,
    }


def variable_time_collate_fn_max(batch, args, device = torch.device("cpu"), data_type = "train", 
    data_min = None, data_max = None, time_max = None):
    """
    Collate function that fixes the length of the observed part to ``args.maxlen``.

    Each record ``(record_id, tt, vals, mask)`` is split at ``args.history``.
    The observed half is copied into zero tensors of length ``args.maxlen``
    created on ``device`` (longer sequences are truncated, shorter ones are
    zero-padded). The half to predict is zero-padded to the longest sequence
    in the batch with ``pad_sequence``.

    Samples without any time stamp at or after ``args.history`` are dropped.

    Returns:
        None when no sample is left after filtering; otherwise a dict with the
        keys "observed_data", "observed_tp", "observed_mask",
        "data_to_predict", "tp_to_predict" and "mask_predicted_data".
    """
    kept = [item for item in batch if (item[1] >= args.history).sum() > 0]
    if len(kept) == 0:
        return None

    # Split every record into its observed and to-be-predicted halves.
    history_tp, history_vals, history_mask = [], [], []
    future_tp, future_vals, future_mask = [], [], []

    for _, tt, vals, mask in kept:
        cut = (tt < args.history).sum()

        history_tp.append(tt[:cut])
        history_vals.append(vals[:cut])
        history_mask.append(mask[:cut])

        future_tp.append(tt[cut:])
        future_vals.append(vals[cut:])
        future_mask.append(mask[cut:])

    # Fixed-size, zero-filled targets for the observed half.
    n_samples = len(kept)
    n_vars = kept[0][2].shape[1]  # feature dimension D of vals
    max_len = args.maxlen

    observed_data = torch.zeros(n_samples, max_len, n_vars, device=device)
    observed_mask = torch.zeros(n_samples, max_len, n_vars, device=device)
    observed_tp = torch.zeros(n_samples, max_len, device=device)

    # Copy each sequence into its row, truncating anything beyond max_len.
    for row, (seq_tp, seq_vals, seq_mask) in enumerate(zip(history_tp, history_vals, history_mask)):
        keep = min(seq_vals.shape[0], max_len)

        observed_data[row, :keep, :] = seq_vals[:keep]
        observed_mask[row, :keep, :] = seq_mask[:keep]
        observed_tp[row, :keep] = seq_tp[:keep]

    # The half to predict is padded to the longest sequence in the batch.
    predicted_tp = pad_sequence(future_tp, batch_first=True)
    predicted_data = pad_sequence(future_vals, batch_first=True)
    predicted_mask = pad_sequence(future_mask, batch_first=True)

    # Normalisation, same as in the other collate functions.
    if args.dataset != 'ushcn':
        observed_data = utils.normalize_masked_data(
            observed_data, observed_mask, att_min = data_min, att_max = data_max)
        predicted_data = utils.normalize_masked_data(
            predicted_data, predicted_mask, att_min = data_min, att_max = data_max)

    observed_tp = utils.normalize_masked_tp(observed_tp, att_min = 0, att_max = time_max)
    predicted_tp = utils.normalize_masked_tp(predicted_tp, att_min = 0, att_max = time_max)

    return {
        "observed_data": observed_data,
        "observed_tp": observed_tp,
        "observed_mask": observed_mask,
        "data_to_predict": predicted_data,
        "tp_to_predict": predicted_tp,
        "mask_predicted_data": predicted_mask,
    }



if __name__ == '__main__':
    # Smoke test: build the dataset (downloading it if necessary) and print
    # one collated batch.
    from functools import partial
    from types import SimpleNamespace

    torch.manual_seed(1991)

    cpu = torch.device("cpu")
    dataset = PhysioNet('../data/physionet', download=True)

    # The collate function needs the normalisation statistics and an ``args``
    # object exposing ``history`` and ``dataset``.
    data_min, data_max, time_max = get_data_min_max(dataset[:], cpu)
    # NOTE(review): history=24 (hours) is an assumed demo value -- adjust to
    # the experiment configuration.
    demo_args = SimpleNamespace(history=24, dataset='physionet')
    collate = partial(
        variable_time_collate_fn,
        args=demo_args,
        device=cpu,
        data_min=data_min,
        data_max=data_max,
        time_max=time_max,
    )

    dataloader = DataLoader(dataset, batch_size=10, shuffle=True, collate_fn=collate)
    print(next(iter(dataloader)))
 

