import collections
import glob
import logging
import math
import os
import warnings
from PIL import Image
import pandas as pd
import numpy as np
import soundfile as sf
from scipy import signal
from copy import deepcopy
import random
import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms

logger = logging.getLogger(__name__)


class AVE(Dataset):
    """Audio-visual event clips organised into class-incremental tasks.

    A task covers ``N_CLASSES_PER_TASK`` consecutive entries of one of the
    ``CLASS_ORDER`` permutations.  The label stored in ``targets`` is the
    position of the class inside that permutation, not the raw class id
    used as key in ``all_classId_vid_dict.npy``.

    Attributes set by ``__init__`` and used by the other methods:
        samples: list of ``(video_id, number_of_frame_files)`` tuples.
        targets: numpy array of labels, aligned with ``samples``.
        task_id: numpy array holding the task index of every sample.
        past_logits: ``None`` or an array aligned with ``samples``; it is
            only assigned from outside this class or by ``join``.
    """
    TYPE = 'CIL'
    CLASS_ORDER = [
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27],
        [26, 19, 24,  2,  8, 13, 10,  6, 15, 18, 14, 21, 12, 23, 11,  5,  0, 20,  4, 22,  9,  7, 16,  3, 27, 17, 25,  1],
        [21,  9,  2, 25, 23, 16,  6, 22,  0,  7, 20, 13, 24, 27, 18, 17, 26, 12,  4, 11, 10,  5,  3,  1, 14,  8, 19, 15]
    ]
    N_CLASSES = 28
    N_TASKS = 7
    N_CLASSES_PER_TASK = 4

    def __init__(self, args, task_id, train=True, augment=False):
        """Index the videos of one task and build the frame pipeline.

        Args:
            args: mapping with at least ``"data_path"``; ``args.get('model')``
                is consulted in ``__getitem__``.
            task_id: indexable pair ``(class_order_index, task_index)``.
            train: selects the ``'train'`` split when true, else ``'test'``.
            augment: enables random augmentation (only effective together
                with ``train``).
        """
        self.args = args
        self.root = args["data_path"]
        self.interval = 9  # number of spectrogram rows zeroed by the mask augmentation
        self.augment = augment
        self.split = 'train' if train else 'test'

        class_order = self.CLASS_ORDER[task_id[0]]
        start_class = self.N_CLASSES_PER_TASK * task_id[1]
        end_class = self.N_CLASSES_PER_TASK * (1 + task_id[1])

        # NOTE(review): allow_pickle=True unpickles the file content; only
        # load index files that come from a trusted source.
        all_classId_vid_dict = np.load(
            os.path.join(self.root, 'all_classId_vid_dict.npy'), allow_pickle=True
        ).item()[self.split]

        self.samples = []
        self.targets = []
        for idx in range(start_class, end_class):
            for vid in all_classId_vid_dict[class_order[idx]]:
                # The frame count is taken from the files present on disk.
                n_frames = len(glob.glob(f"{self.root}/AVE_Dataset/rgb/{vid}/frame_*.jpg"))
                self.targets.append(idx)
                self.samples.append((vid, n_frames))
        self.targets = np.array(self.targets)

        self.audio_samplerate = 16000

        self.visual_pipeline = self._build_visual_pipeline(bool(augment and train))

        self.task_id = np.ones(len(self.targets), dtype=int) * task_id[1]

        self.past_logits = None

    @staticmethod
    def _build_visual_pipeline(random_aug):
        """Compose the mmaction frame pipeline (random or deterministic variant)."""
        from mmaction.utils import register_all_modules
        register_all_modules()

        from mmengine.dataset.base_dataset import Compose
        # The imported names are not used directly; presumably importing them
        # registers the 'Collect' / 'Normalize3D' transforms -- TODO confirm.
        from .pipeline import Collect, Normalize3D  # noqa: F401

        from mmengine.logging import MMLogger
        MMLogger.get_current_instance().setLevel(logging.CRITICAL)

        clip_len, frame_interval = 32, 2
        mean, std = [123.675, 116.28, 103.53], [58.395, 57.12, 57.375]

        sampler = {'type': 'SampleFrames', 'clip_len': clip_len,
                   'frame_interval': frame_interval, 'num_clips': 1}
        if random_aug:
            spatial = [
                {'type': 'Resize', 'scale': (-1, 256)},
                {'type': 'RandomResizedCrop'},
                {'type': 'Resize', 'scale': (224, 224), 'keep_ratio': False},
                {'type': 'Flip', 'flip_ratio': 0.5},
            ]
        else:
            sampler['test_mode'] = True
            spatial = [
                {'type': 'Resize', 'scale': (-1, 256)},
                {'type': 'CenterCrop', 'crop_size': 224},
                {'type': 'Flip', 'flip_ratio': 0},
            ]

        return Compose([
            sampler,
            {'type': 'RawFrameDecode'},
            *spatial,
            {'type': 'Normalize3D', 'mean': mean, 'std': std, 'to_rgb': False},
            {'type': 'FormatShape', 'input_format': 'NCTHW'},
            {'type': 'Collect', 'keys': ['imgs'], 'meta_keys': []},
            {'type': 'ToTensor', 'keys': ['imgs']},
        ])

    @staticmethod
    def _loop_to_length(waveform, length, source):
        """Return ``waveform`` cut to ``length`` samples, repeating it if shorter.

        Raises:
            ValueError: if ``waveform`` is empty; repeating an empty signal
                can never reach ``length`` (this used to loop forever).
        """
        clip = waveform[:length]
        if len(clip) == 0:
            raise ValueError(f"Audio signal is empty: {source}")
        if len(clip) < length:
            # assumes a 1-D (mono) signal -- TODO confirm for multi-channel files
            clip = np.tile(clip, math.ceil(length / len(clip)))[:length]
        return clip

    @staticmethod
    def _normalized_log_spectrogram(waveform, samplerate):
        """Clip ``waveform`` to [-1, 1] in place and return its standardised log-spectrogram."""
        waveform[waveform > 1.] = 1.
        waveform[waveform < -1.] = -1.
        _, _, spectrogram = signal.spectrogram(waveform, samplerate, nperseg=512, noverlap=353)
        spectrogram = np.log(spectrogram + 1e-7)  # epsilon keeps log finite on silent bins

        mean = np.mean(spectrogram)
        std = np.std(spectrogram)
        return np.divide(spectrogram - mean, std + 1e-9)

    def _augment_spectrogram(self, spectrogram):
        """Return a noisy copy of ``spectrogram`` with ``self.interval`` rows zeroed."""
        noise = np.random.uniform(-0.05, 0.05, spectrogram.shape)
        augmented = spectrogram + noise
        # The band start is drawn from [0, 256 - interval).
        start = np.random.choice(256 - self.interval, (1,))[0]
        augmented[start:(start + self.interval), :] = 0
        return augmented

    @staticmethod
    def _to_audio_tensor(spectrogram):
        """Add a leading axis and convert to a float32 torch tensor."""
        return torch.from_numpy(np.expand_dims(spectrogram, 0).astype(np.float32))

    def __len__(self):
        """Number of indexed samples."""
        return len(self.samples)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            dict with keys ``'visual'``, ``'audio'``, ``'target'`` and
            ``'task_id'``, plus ``'past_logits'`` when those are set.  When
            ``args.get('model') == 'co2l'`` on an augmented training split,
            ``'visual'`` and ``'audio'`` are two-element lists holding two
            independently augmented views.

        Raises:
            ValueError: if the audio file contains no samples.
        """
        vid, n_frames = self.samples[index]
        video_path = f"{self.root}/AVE_Dataset/rgb/{vid}"
        audio_path = f"{self.root}/AVE_Dataset/audio/{vid}.wav"

        video_dict = dict(
            frame_dir=video_path,
            total_frames=n_frames,
            label=-1,
            start_index=1,
            filename_tmpl='frame_{:010}.jpg',
            modality='RGB')
        video_data = self.visual_pipeline(video_dict)

        target = self.targets[index]

        samples, samplerate = sf.read(audio_path)
        # Use a fixed window of 10 * audio_samplerate samples.
        waveform = self._loop_to_length(samples, self.audio_samplerate * 10, audio_path)
        # The spectrogram uses the rate reported by the file, as before.
        spectrogram = self._normalized_log_spectrogram(waveform, samplerate)

        random_aug = self.split == 'train' and self.augment
        if random_aug:
            spectrogram_out = self._to_audio_tensor(self._augment_spectrogram(spectrogram))
        else:
            spectrogram_out = self._to_audio_tensor(spectrogram)

        if self.args.get('model') == 'co2l' and random_aug:
            # Second view: same dict object is passed again, as in the
            # original implementation.
            video_data2 = self.visual_pipeline(video_dict)
            spectrogram_out2 = self._to_audio_tensor(self._augment_spectrogram(spectrogram))
            output = {'visual': [video_data['imgs'].squeeze(0), video_data2['imgs'].squeeze(0)],
                      'audio': [spectrogram_out, spectrogram_out2],
                      'target': target, 'task_id': self.task_id[index]}
        else:
            output = {'visual': video_data['imgs'].squeeze(0), 'audio': spectrogram_out,
                      'target': target, 'task_id': self.task_id[index]}

        if self.past_logits is not None:
            output['past_logits'] = self.past_logits[index]

        return output

    def resample(self, index):
        """Keep only the samples selected by ``index`` (a sequence of positions), in that order."""
        self.samples = [self.samples[i] for i in index]
        self.targets = self.targets[index]
        self.task_id = self.task_id[index]

        if self.past_logits is not None:
            self.past_logits = self.past_logits[index]

    def join(self, dataset):
        """Append the samples of ``dataset`` to this dataset.

        If only one side carries ``past_logits``, the other side is padded
        with NaN rows of width ``N_CLASSES``.  ``dataset`` itself is left
        unmodified.
        """
        assert isinstance(dataset, type(self))

        own_logits, other_logits = self.past_logits, dataset.past_logits
        if own_logits is not None or other_logits is not None:
            if own_logits is None:
                own_logits = np.full((len(self.samples), self.N_CLASSES), np.nan)
            if other_logits is None:
                other_logits = np.full((len(dataset.samples), self.N_CLASSES), np.nan)
            self.past_logits = np.concatenate([own_logits, other_logits]).astype(np.float32)

        self.samples = self.samples + dataset.samples
        self.targets = np.concatenate([self.targets, dataset.targets])

        self.task_id = np.concatenate([self.task_id, dataset.task_id])


class UESTC_MMEA(Dataset):
    """Egocentric video + inertial clips organised into class-incremental tasks.

    A task covers ``N_CLASSES_PER_TASK`` consecutive entries of one of the
    ``CLASS_ORDER`` permutations.  The label stored in ``targets`` is the
    position of the class inside that permutation, not the raw class id read
    from the split list file.

    Attributes set by ``__init__`` and used by the other methods:
        samples: list of ``(video_id, number_of_frame_files)`` tuples.
        targets: numpy array of labels, aligned with ``samples``.
        task_id: numpy array holding the task index of every sample.
        past_logits: ``None`` or an array aligned with ``samples``; it is
            only assigned from outside this class or by ``join``.
    """
    TYPE = 'CIL'
    # Every row must be a permutation of range(N_CLASSES).
    # NOTE(review): the second row previously listed class 2 twice and never
    # listed class 24; the trailing duplicate was replaced by 24.  Which of the
    # two occurrences was the typo cannot be told from this file -- confirm.
    CLASS_ORDER = [
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
        [26, 14, 23,  4, 11, 25, 31, 10, 29,  5,  6,  9, 17, 22,  2, 19, 13,  1, 21, 16,  8,  3, 27, 28, 15, 30,  0,  7, 12, 18, 20, 24],
        [22, 20,  2, 14, 23, 28,  8,  6, 15, 29, 10, 21, 13, 30, 12, 27, 11,  5,  0, 24, 18,  4, 19, 26,  9,  7, 16,  3, 31, 17, 25,  1]
    ]
    N_CLASSES = 32
    N_TASKS = 8
    N_CLASSES_PER_TASK = 4

    def __init__(self, args, task_id, train=True, augment=False):
        """Index the videos of one task and build the frame pipeline.

        Args:
            args: mapping with at least ``"data_path"``; ``args.get('model')``
                is consulted in ``__getitem__``.
            task_id: indexable pair ``(class_order_index, task_index)``.
            train: selects the ``'train'`` split when true, else ``'test'``.
            augment: enables random augmentation (only effective together
                with ``train``).
        """
        self.args = args
        self.root = args["data_path"]
        self.augment = augment
        self.split = 'train' if train else 'test'

        class_order = self.CLASS_ORDER[task_id[0]]
        start_class = self.N_CLASSES_PER_TASK * task_id[1]
        end_class = self.N_CLASSES_PER_TASK * (1 + task_id[1])

        # Read the split list inside a context manager so the handle is closed.
        with open(f"{self.root}/UESTC-MMEA-CL/{self.split}.txt") as list_file:
            records = [line.strip().split(' ') for line in list_file]
        # Entries whose second field is below 3 are dropped
        # (presumably a frame count -- TODO confirm).
        records = [record for record in records if int(record[1]) >= 3]

        # Group video ids by the class id stored in the last field.  A plain
        # dict is used so that a class without videos still raises KeyError.
        all_classId_vid_dict = {}
        for record in records:
            video_id = record[0].replace("yourdataset_path/data/", '')
            all_classId_vid_dict.setdefault(int(record[-1]), []).append(video_id)

        self.samples = []
        self.targets = []
        for idx in range(start_class, end_class):
            for vid in all_classId_vid_dict[class_order[idx]]:
                # The frame count is taken from the files present on disk.
                n_frames = len(glob.glob(f"{self.root}/UESTC-MMEA-CL/rgb/{vid}/frame_*.jpg"))
                self.targets.append(idx)
                self.samples.append((vid, n_frames))
        self.targets = np.array(self.targets)

        self.imu_samplerate = 30

        self.visual_pipeline = self._build_visual_pipeline(bool(augment and train))

        self.task_id = np.ones(len(self.targets), dtype=int) * task_id[1]

        self.past_logits = None

    @staticmethod
    def _build_visual_pipeline(random_aug):
        """Compose the mmaction frame pipeline (random or deterministic variant)."""
        from mmaction.utils import register_all_modules
        register_all_modules()

        from mmengine.dataset.base_dataset import Compose
        # The imported names are not used directly; presumably importing them
        # registers the 'Collect' / 'Normalize3D' transforms -- TODO confirm.
        from .pipeline import Collect, Normalize3D  # noqa: F401

        from mmengine.logging import MMLogger
        MMLogger.get_current_instance().setLevel(logging.CRITICAL)

        sampler = {'type': 'SampleFrames', 'clip_len': 32, 'frame_interval': 2, 'num_clips': 1}
        if random_aug:
            spatial = [
                {'type': 'Resize', 'scale': (-1, 256)},
                {'type': 'RandomResizedCrop'},
                {'type': 'Resize', 'scale': (224, 224), 'keep_ratio': False},
                {'type': 'Flip', 'flip_ratio': 0.5},
            ]
        else:
            sampler['test_mode'] = True
            spatial = [
                {'type': 'Resize', 'scale': (-1, 256)},
                {'type': 'CenterCrop', 'crop_size': 224},
                {'type': 'Flip', 'flip_ratio': 0},
            ]

        return Compose([
            sampler,
            {'type': 'RawFrameDecode'},
            *spatial,
            {'type': 'Normalize3D',
             'mean': [123.675, 116.28, 103.53],
             'std': [58.395, 57.12, 57.375],
             'to_rgb': False},
            {'type': 'FormatShape', 'input_format': 'NCTHW'},
            {'type': 'Collect', 'keys': ['imgs'], 'meta_keys': []},
            {'type': 'ToTensor', 'keys': ['imgs']},
        ])

    @staticmethod
    def _random_start(max_start):
        """Draw a crop offset in ``[0, max_start)``; 0 when no slack is left.

        ``np.random.choice(0)`` raises, which used to crash whenever the
        signal length matched the crop window exactly.
        """
        if max_start > 0:
            return np.random.choice(max_start)
        return 0

    def _crop_imu(self, imu, start_idx):
        """Cut the fixed-size window starting at ``start_idx`` and convert to a float32 tensor."""
        window = self.imu_samplerate * 10
        return torch.from_numpy(imu[:, start_idx:start_idx + window].astype(np.float32))

    def __len__(self):
        """Number of indexed samples."""
        return len(self.samples)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            dict with keys ``'visual'``, ``'inertial'``, ``'target'`` and
            ``'task_id'``, plus ``'past_logits'`` when those are set.  When
            ``args.get('model') == 'co2l'`` on an augmented training split,
            ``'visual'`` and ``'inertial'`` are two-element lists holding two
            independently augmented views.

        Raises:
            ValueError: if the inertial recording has zero length along axis 1.
        """
        vid, n_frames = self.samples[index]
        video_path = f"{self.root}/UESTC-MMEA-CL/rgb/{vid}"
        imu_path = f"{self.root}/UESTC-MMEA-CL/accelerator/{vid}.npy"

        video_dict = dict(
            frame_dir=video_path,
            total_frames=n_frames,
            label=-1,
            start_index=1,
            filename_tmpl='frame_{:010}.jpg',
            modality='RGB')
        video_data = self.visual_pipeline(video_dict)

        target = self.targets[index]

        window = self.imu_samplerate * 10
        imu = np.load(imu_path)
        if imu.shape[1] == 0:
            # Doubling an empty signal never reaches the window length.
            raise ValueError(f"Inertial recording is empty: {imu_path}")
        # Axis 1 is the one that is cropped below; repeat along it until the
        # recording is at least one window long.
        while imu.shape[1] < window:
            imu = np.tile(imu, (1, 2))
        max_start = imu.shape[1] - window

        random_aug = self.split == 'train' and self.augment
        if random_aug:
            start_idx = self._random_start(max_start)
        else:
            start_idx = max_start // 2  # deterministic centre crop
        imu_out = self._crop_imu(imu, start_idx)

        if self.args.get('model') == 'co2l' and random_aug:
            # Second view: same dict object is passed again, as in the
            # original implementation.
            video_data2 = self.visual_pipeline(video_dict)
            imu_out2 = self._crop_imu(imu, self._random_start(max_start))
            output = {'visual': [video_data['imgs'].squeeze(0), video_data2['imgs'].squeeze(0)],
                      'inertial': [imu_out, imu_out2],
                      'target': target, 'task_id': self.task_id[index]}
        else:
            output = {'visual': video_data['imgs'].squeeze(0), 'inertial': imu_out,
                      'target': target, 'task_id': self.task_id[index]}

        if self.past_logits is not None:
            output['past_logits'] = self.past_logits[index]

        return output

    def resample(self, index):
        """Keep only the samples selected by ``index`` (a sequence of positions), in that order."""
        self.samples = [self.samples[i] for i in index]
        self.targets = self.targets[index]
        self.task_id = self.task_id[index]

        if self.past_logits is not None:
            self.past_logits = self.past_logits[index]

    def join(self, dataset):
        """Append the samples of ``dataset`` to this dataset.

        If only one side carries ``past_logits``, the other side is padded
        with NaN rows of width ``N_CLASSES``.  ``dataset`` itself is left
        unmodified.
        """
        assert isinstance(dataset, type(self))

        own_logits, other_logits = self.past_logits, dataset.past_logits
        if own_logits is not None or other_logits is not None:
            if own_logits is None:
                own_logits = np.full((len(self.samples), self.N_CLASSES), np.nan)
            if other_logits is None:
                other_logits = np.full((len(dataset.samples), self.N_CLASSES), np.nan)
            self.past_logits = np.concatenate([own_logits, other_logits]).astype(np.float32)

        self.samples = self.samples + dataset.samples
        self.targets = np.concatenate([self.targets, dataset.targets])

        self.task_id = np.concatenate([self.task_id, dataset.task_id])


class KITCHEN(Dataset):
    """EPIC-KITCHENS video + audio segments organised into domain-incremental tasks.

    Each task is one entry (a participant id) of a ``DOMAIN_ORDER`` row; all
    tasks share the same ``N_CLASSES`` verb classes.

    Attributes set by ``__init__`` and used by the other methods:
        samples: list of ``(clip_path, start_frame, stop_frame,
            start_timestamp, stop_timestamp)`` tuples.
        targets: numpy array of ``verb_class`` labels, aligned with ``samples``.
        task_id: numpy array holding the task index of every sample.
        past_logits: ``None`` or an array aligned with ``samples``; it is
            only assigned from outside this class or by ``join``.
    """
    TYPE = 'DIL'
    DOMAIN_ORDER = [
        ['P22','P01','P02','P30','P04'],
        ['P01','P02','P04','P22','P30'],
        ['P30','P02','P22','P04','P01'],
    ]
    N_CLASSES = 10
    N_TASKS = 5
    N_CLASSES_PER_TASK = 10

    def __init__(self, args, task_id, train=True, augment=False):
        """Read the annotation table of one domain and build the frame pipeline.

        Args:
            args: mapping with at least ``"data_path"``; ``args.get('model')``
                is consulted in ``__getitem__``.
            task_id: indexable pair ``(domain_order_index, task_index)``.
            train: selects the ``'train'`` split when true, else ``'test'``.
            augment: enables random augmentation (only effective together
                with ``train``).
        """
        self.args = args
        self.root = args["data_path"]
        self.interval = 9  # number of spectrogram rows zeroed by the mask augmentation
        self.augment = augment

        domain = self.DOMAIN_ORDER[task_id[0]][task_id[1]]
        self.split = 'train' if train else 'test'
        # NOTE(review): read_pickle unpickles the file content; only load
        # annotation files that come from a trusted source.
        annotations = pd.read_pickle(f"{self.root}/{domain}_{self.split}.pkl")

        self.samples = []
        self.targets = []
        for _, row in annotations.iterrows():
            self.targets.append(int(row['verb_class']))
            self.samples.append((
                domain + '/' + row['video_id'],
                row['start_frame'],
                row['stop_frame'],
                row['start_timestamp'],
                row['stop_timestamp'],
            ))
        self.targets = np.array(self.targets)

        self.audio_samplerate = 16000

        self.visual_pipeline = self._build_visual_pipeline(bool(augment and train))

        self.task_id = np.ones(len(self.targets), dtype=int) * task_id[1]

        self.past_logits = None

    @staticmethod
    def _build_visual_pipeline(random_aug):
        """Compose the mmaction frame pipeline (random or deterministic variant)."""
        from mmaction.utils import register_all_modules
        register_all_modules()

        from mmengine.dataset.base_dataset import Compose
        # The imported names are not used directly; presumably importing them
        # registers the 'Collect' / 'Normalize3D' transforms -- TODO confirm.
        from .pipeline import Collect, Normalize3D  # noqa: F401

        from mmengine.logging import MMLogger
        MMLogger.get_current_instance().setLevel(logging.CRITICAL)

        sampler = {'type': 'SampleFrames', 'clip_len': 32, 'frame_interval': 2, 'num_clips': 1}
        if random_aug:
            spatial = [
                {'type': 'Resize', 'scale': (-1, 256)},
                {'type': 'RandomResizedCrop'},
                {'type': 'Resize', 'scale': (224, 224), 'keep_ratio': False},
                {'type': 'Flip', 'flip_ratio': 0.5},
            ]
        else:
            sampler['test_mode'] = True
            spatial = [
                {'type': 'Resize', 'scale': (-1, 256)},
                {'type': 'CenterCrop', 'crop_size': 224},
                {'type': 'Flip', 'flip_ratio': 0},
            ]

        return Compose([
            sampler,
            {'type': 'RawFrameDecode'},
            *spatial,
            {'type': 'Normalize3D',
             'mean': [123.675, 116.28, 103.53],
             'std': [58.395, 57.12, 57.375],
             'to_rgb': False},
            {'type': 'FormatShape', 'input_format': 'NCTHW'},
            {'type': 'Collect', 'keys': ['imgs'], 'meta_keys': []},
            {'type': 'ToTensor', 'keys': ['imgs']},
        ])

    @staticmethod
    def _timestamp_to_seconds(timestamp):
        """Convert an ``'H:M:S'`` string (seconds may be fractional) to seconds."""
        fields = timestamp.split(':')
        hours = float(fields[0])
        minutes = float(fields[1])
        seconds = float(fields[2])
        return (hours * 60 + minutes) * 60 + seconds

    @staticmethod
    def _loop_to_length(waveform, length, source):
        """Return ``waveform`` cut to ``length`` samples, repeating it if shorter.

        Raises:
            ValueError: if ``waveform`` is empty; repeating an empty signal
                can never reach ``length`` (this used to loop forever).
        """
        clip = waveform[:length]
        if len(clip) == 0:
            raise ValueError(f"Audio segment is empty: {source}")
        if len(clip) < length:
            # assumes a 1-D (mono) signal -- TODO confirm for multi-channel files
            clip = np.tile(clip, math.ceil(length / len(clip)))[:length]
        return clip

    @staticmethod
    def _normalized_log_spectrogram(waveform, samplerate):
        """Clip ``waveform`` to [-1, 1] in place and return its standardised log-spectrogram."""
        waveform[waveform > 1.] = 1.
        waveform[waveform < -1.] = -1.
        _, _, spectrogram = signal.spectrogram(waveform, samplerate, nperseg=512, noverlap=353)
        spectrogram = np.log(spectrogram + 1e-7)  # epsilon keeps log finite on silent bins

        mean = np.mean(spectrogram)
        std = np.std(spectrogram)
        return np.divide(spectrogram - mean, std + 1e-9)

    def _augment_spectrogram(self, spectrogram):
        """Return a noisy copy of ``spectrogram`` with ``self.interval`` rows zeroed."""
        noise = np.random.uniform(-0.05, 0.05, spectrogram.shape)
        augmented = spectrogram + noise
        # The band start is drawn from [0, 256 - interval).
        start = np.random.choice(256 - self.interval, (1,))[0]
        augmented[start:(start + self.interval), :] = 0
        return augmented

    @staticmethod
    def _to_audio_tensor(spectrogram):
        """Add a leading axis and convert to a float32 torch tensor."""
        return torch.from_numpy(np.expand_dims(spectrogram, 0).astype(np.float32))

    def __len__(self):
        """Number of indexed samples."""
        return len(self.samples)

    def __getitem__(self, index):
        """Load one annotated segment.

        Returns:
            dict with keys ``'visual'``, ``'audio'``, ``'target'`` and
            ``'task_id'``, plus ``'past_logits'`` when those are set.  When
            ``args.get('model') == 'co2l'`` on an augmented training split,
            ``'visual'`` and ``'audio'`` are two-element lists holding two
            independently augmented views.

        Raises:
            ValueError: if the audio file is empty or the annotated time
                range selects no audio samples.
        """
        clip, start_frame, stop_frame, start_timestamp, stop_timestamp = self.samples[index]
        video_path = f"{self.root}/EPIC_KITCHENS/rgb/{self.split}/{clip}"
        audio_path = f"{self.root}/EPIC_KITCHENS/audio/{self.split}/{clip}.npy"

        video_dict = dict(
            frame_dir=video_path,
            total_frames=int(stop_frame - start_frame),
            label=-1,
            start_index=int(start_frame),
            filename_tmpl='frame_{:010}.jpg',
            modality='RGB')
        video_data = self.visual_pipeline(video_dict)

        target = self.targets[index]

        samples = np.load(audio_path)
        if len(samples) == 0:
            # Would otherwise surface as a ZeroDivisionError below.
            raise ValueError(f"Audio file is empty: {audio_path}")
        duration = len(samples) / self.audio_samplerate

        fr_sec = self._timestamp_to_seconds(start_timestamp)
        stop_sec = self._timestamp_to_seconds(stop_timestamp)

        # Map the annotated time range onto sample positions.
        first = int(np.round(fr_sec / duration * len(samples)))
        last = int(np.round(stop_sec / duration * len(samples)))
        segment = samples[first:last]

        # Use a fixed window of 10 * audio_samplerate samples.
        waveform = self._loop_to_length(
            segment, self.audio_samplerate * 10,
            f"{audio_path} [{start_timestamp} - {stop_timestamp}]")
        spectrogram = self._normalized_log_spectrogram(waveform, self.audio_samplerate)

        random_aug = self.split == 'train' and self.augment
        if random_aug:
            spectrogram_out = self._to_audio_tensor(self._augment_spectrogram(spectrogram))
        else:
            spectrogram_out = self._to_audio_tensor(spectrogram)

        if self.args.get('model') == 'co2l' and random_aug:
            # Second view: same dict object is passed again, as in the
            # original implementation.
            video_data2 = self.visual_pipeline(video_dict)
            spectrogram_out2 = self._to_audio_tensor(self._augment_spectrogram(spectrogram))
            output = {'visual': [video_data['imgs'].squeeze(0), video_data2['imgs'].squeeze(0)],
                      'audio': [spectrogram_out, spectrogram_out2],
                      'target': target, 'task_id': self.task_id[index]}
        else:
            output = {'visual': video_data['imgs'].squeeze(0), 'audio': spectrogram_out,
                      'target': target, 'task_id': self.task_id[index]}

        if self.past_logits is not None:
            output['past_logits'] = self.past_logits[index]

        return output

    def resample(self, index):
        """Keep only the samples selected by ``index`` (a sequence of positions), in that order."""
        self.samples = [self.samples[i] for i in index]
        self.targets = self.targets[index]
        self.task_id = self.task_id[index]

        if self.past_logits is not None:
            self.past_logits = self.past_logits[index]

    def join(self, dataset):
        """Append the samples of ``dataset`` to this dataset.

        If only one side carries ``past_logits``, the other side is padded
        with NaN rows of width ``N_CLASSES``.  ``dataset`` itself is left
        unmodified.
        """
        assert isinstance(dataset, type(self))

        own_logits, other_logits = self.past_logits, dataset.past_logits
        if own_logits is not None or other_logits is not None:
            if own_logits is None:
                own_logits = np.full((len(self.samples), self.N_CLASSES), np.nan)
            if other_logits is None:
                other_logits = np.full((len(dataset.samples), self.N_CLASSES), np.nan)
            self.past_logits = np.concatenate([own_logits, other_logits]).astype(np.float32)

        self.samples = self.samples + dataset.samples
        self.targets = np.concatenate([self.targets, dataset.targets])

        self.task_id = np.concatenate([self.task_id, dataset.task_id])



class DKD(Dataset):
    """Paired image + tabular records organised into domain-incremental tasks.

    Each task is one entry of a ``DOMAIN_ORDER`` row; the binary label is
    the ``CKD`` column of the per-domain CSV file.

    Attributes set by ``__init__`` and used by the other methods:
        samples: pandas DataFrame read from ``{domain}_{split}.csv``.
        targets: numpy array copy of the ``CKD`` column.
        task_id: numpy array holding the task index of every row.
        past_logits: ``None`` or an array aligned with ``samples``; it is
            only assigned from outside this class or by ``join``.
    """
    TYPE = 'DIL'
    DOMAIN_ORDER = [
        ['SiDRP','SEED','KTPH'],
        ['SEED','KTPH','SiDRP'],
        ['KTPH','SiDRP','SEED'],
    ]
    N_CLASSES = 2
    N_TASKS = 3
    N_CLASSES_PER_TASK = 2

    # (feature slot, CSV column) pairs whose missing values are replaced by
    # the column mean of the currently held table.
    _IMPUTED_COLUMNS = ((4, 'dbduration'), (5, 'hba1c'), (6, 'sbp'))

    def __init__(self, args, task_id, train=True, augment=False):
        """Read the record table of one domain and build the image transform.

        Args:
            args: mapping with at least ``"data_path"``; ``args.get('model')``
                is consulted in ``__getitem__``.
            task_id: indexable pair ``(domain_order_index, task_index)``.
            train: selects the ``'train'`` split when true, else ``'test'``.
            augment: enables random augmentation (only effective together
                with ``train``).
        """
        self.args = args
        self.root = args["data_path"]
        self.augment = augment

        domain = self.DOMAIN_ORDER[task_id[0]][task_id[1]]
        self.split = 'train' if train else 'test'

        self.samples = pd.read_csv(f"{self.root}/{domain}_{self.split}.csv")
        self.targets = np.array(self.samples['CKD'])

        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]

        if augment and train:
            # timm is only needed for the augmented training transform.
            from timm.data import create_transform

            self.visual_transform = create_transform(
                input_size=224,
                is_training=True,
                color_jitter=0.4,
                auto_augment=None,
                interpolation='bicubic',
                re_prob=0.25,
                re_mode='pixel',
                re_count=1,
                mean=mean,
                std=std,
            )
        else:
            self.visual_transform = transforms.Compose([
                transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ])

        self.task_id = np.ones(len(self.samples), dtype=int) * task_id[1]

        self.past_logits = None

    def _tabular_features(self, sample):
        """Encode one table row as a 7-element float32 tensor.

        Layout: age, gender == 2 flag, ethnicity 'Chinese' flag, ethnicity
        'Malay' flag, dbduration, hba1c, sbp.  Missing values of the last
        three are replaced by the mean of that column in ``self.samples``.
        """
        tabular = np.zeros(7).astype(np.float32)
        tabular[0] = sample['age']
        if sample['gender'] == 2:
            tabular[1] = 1

        if sample['ethnicity'] == 'Chinese':
            tabular[2] = 1
        elif sample['ethnicity'] == 'Malay':
            tabular[3] = 1

        for slot, column in self._IMPUTED_COLUMNS:
            value = sample[column]
            if pd.isna(value):
                tabular[slot] = self.samples[column].mean()
            else:
                tabular[slot] = value
        return torch.from_numpy(tabular)

    def _image_pair(self, imageL_img, imageR_img, allow_swap):
        """Transform both images and stack them; optionally swap them at random.

        The random draw happens after both transforms and only when
        ``allow_swap`` is true, matching the original call order.
        """
        imageL = self.visual_transform(imageL_img)
        imageR = self.visual_transform(imageR_img)
        if allow_swap and np.random.rand() < 0.5:
            return torch.stack([imageR, imageL], dim=0)
        return torch.stack([imageL, imageR], dim=0)

    def __len__(self):
        """Number of rows in the record table."""
        return len(self.samples)

    def __getitem__(self, index):
        """Load one record.

        Returns:
            dict with keys ``'visual'`` (the two stacked images),
            ``'tabular'``, ``'target'`` and ``'task_id'``, plus
            ``'past_logits'`` when those are set.  When
            ``args.get('model') == 'co2l'`` on an augmented training split,
            ``'visual'`` holds two independently augmented views and
            ``'tabular'`` the same tensor twice.
        """
        sample = self.samples.iloc[index]

        imageL_path = f"{self.root}/images/{sample['imageName_LM']}"
        imageR_path = f"{self.root}/images/{sample['imageName_RM']}"

        imageL_img = Image.open(imageL_path)
        imageR_img = Image.open(imageR_path)

        random_aug = self.split == 'train' and self.augment
        image = self._image_pair(imageL_img, imageR_img, allow_swap=random_aug)

        tabular = self._tabular_features(sample)

        target = sample['CKD']

        if self.args.get('model') == 'co2l' and random_aug:
            image2 = self._image_pair(imageL_img, imageR_img, allow_swap=True)
            output = {'visual': [image, image2],
                      'tabular': [tabular, tabular],
                      'target': target, 'task_id': self.task_id[index]}
        else:
            output = {'visual': image, 'tabular': tabular, 'target': target,
                      'task_id': self.task_id[index]}

        if self.past_logits is not None:
            output['past_logits'] = self.past_logits[index]
        return output

    def resample(self, index):
        """Keep only the rows selected by ``index`` (a sequence of positions), in that order."""
        self.samples = self.samples.iloc[index].reset_index(drop=True)
        self.targets = self.targets[index]
        self.task_id = self.task_id[index]

        if self.past_logits is not None:
            self.past_logits = self.past_logits[index]

    def join(self, dataset):
        """Append the rows of ``dataset`` to this dataset.

        If only one side carries ``past_logits``, the other side is padded
        with NaN rows of width ``N_CLASSES``.  ``dataset`` itself is left
        unmodified.
        """
        assert isinstance(dataset, type(self))

        own_logits, other_logits = self.past_logits, dataset.past_logits
        if own_logits is not None or other_logits is not None:
            if own_logits is None:
                own_logits = np.full((len(self.samples), self.N_CLASSES), np.nan)
            if other_logits is None:
                other_logits = np.full((len(dataset.samples), self.N_CLASSES), np.nan)
            self.past_logits = np.concatenate([own_logits, other_logits]).astype(np.float32)

        self.samples = pd.concat([self.samples, dataset.samples]).reset_index(drop=True)
        self.targets = np.concatenate([self.targets, dataset.targets])

        self.task_id = np.concatenate([self.task_id, dataset.task_id])


