import logging
import os
import numpy as np
import pickle as pkl
from torch.utils.data import Dataset

from preprocessing import ECGProcessing
from augmentation import ECGRandAugment

logger = logging.getLogger(__name__)


def load_dict(filename_):
    """Return the object unpickled from the file at ``filename_``.

    The file is opened in binary mode and closed again before returning.
    Callers in this module index the result by lead name, so the stored
    object is expected to behave like a dict.

    NOTE(review): ``pickle`` can execute arbitrary code while loading;
    only point this at files from a trusted source.
    """
    with open(filename_, 'rb') as stream:
        return pkl.load(stream)


# Dataset class definition
class ECGDataset(Dataset):
    """Dataset that serves preprocessed ECG signals loaded from pickled files.

    Each row of ``df`` describes one recording: the file to load, its
    sample rate and, when the task columns are present, its labels.
    """

    def __init__(self, path_x, df, params, mode):
        """Collect file paths, labels and processing objects for the dataset.

        Args:
            path_x: Directory that is joined onto every filename from ``df``.
            df: Table with one row per recording (presumably a pandas
                DataFrame; it is indexed by column name and exposes
                ``.values`` / ``.columns``).
            params: Nested configuration. The keys read here are
                ``preproc.filename``, ``preproc.lead``, ``preproc.task``,
                ``preproc.samplerate``, ``data_augmentation.rand_augment_use``
                and, when augmentation is enabled,
                ``data_augmentation.rand_augment_params.{op_names, level,
                num_layers}``.
            mode: Run mode; ``'train'`` enables augmentation in
                ``__getitem__`` and suppresses the ``'fname'`` entry.
        """
        logger.info(f'{mode} mode data size is {df.values.shape}')

        preproc_cfg = params['preproc']
        filenames = df[preproc_cfg['filename']].tolist()
        self.df = df
        self.filenames = filenames
        self.filenames_x = [os.path.join(path_x, name) for name in filenames]
        self.lead = preproc_cfg['lead']  # names of the leads to read from each record
        self.task = preproc_cfg['task']  # label columns to use
        self.mode = mode
        self.params = params
        # Labels are None when any task column is absent from df.
        self.Y_np_arr = self.get_label()
        self.sample_rates = df[preproc_cfg['samplerate']].values

        self.signal_processor = ECGProcessing.init_from_params(mode, params)

        aug_cfg = params['data_augmentation']
        self.rand_augment_use = aug_cfg['rand_augment_use']
        if self.rand_augment_use:
            # augment_fn only exists when augmentation is switched on.
            rand_cfg = aug_cfg['rand_augment_params']
            self.augment_fn = ECGRandAugment(
                params=params,
                op_names=rand_cfg['op_names'],
                level=rand_cfg['level'],
                num_layers=rand_cfg['num_layers'],
            )

    def __len__(self):
        """Return the number of recordings listed in the table."""
        return len(self.filenames)

    def get_label(self):
        """Return the task columns of ``df`` as an array, or None.

        None is returned when at least one configured task name is not a
        column of ``df``.
        """
        wanted = self.params['preproc']['task']
        present = [name in self.df.columns for name in wanted]
        if not all(present):
            return None
        return self.df[wanted].values

    def __getitem__(self, idx):
        """Load, preprocess and (in train mode) augment the recording at ``idx``.

        Returns:
            A dict with ``'input'`` (the processed signal), ``'label'``
            (cast to int64) when labels are available, and ``'fname'`` when
            the mode is not ``'train'``.
        """
        sample_rate = int(self.sample_rates[idx])
        record = load_dict(self.filenames_x[idx])

        # Stack the configured leads in the order given by self.lead.
        signal = np.stack([record[name] for name in self.lead])
        signal = self.signal_processor(signal, original_sample_rate=sample_rate)

        if self.rand_augment_use and self.mode == 'train':
            signal = self.augment_fn(signal)

        sample = {'input': signal}

        if self.Y_np_arr is not None:
            sample['label'] = self.Y_np_arr[idx].astype(np.int64)

        if self.mode != 'train':
            sample['fname'] = self.filenames[idx]
        return sample
