import os
import cv2
import glob
import numpy as np
import torch
from torch.utils.data import Dataset


class KMAradar4kmDataset(Dataset):
    """Sliding-window dataset over radar TIFF frames sampled every 10 minutes.

    A regular 10-minute timeline is generated for every calendar day between
    ``year_from`` and ``year_to`` (inclusive). Each timeline slot maps to the
    path ``<radar_4km_path>/YYYY/MM/DD/YYYYMMDDHHmm.tiff``. A slot becomes a
    sample (identified by its *last output frame*) only when every slot in
    the preceding ``history_length`` steps exists on disk and the window does
    not reach back into the previous year.

    Args:
        radar_4km_path: Root directory of the radar archive.
        year_from: First year (inclusive) of the timeline.
        year_to: Last year (inclusive) of the timeline.
        input_length: Number of input frames per sample.
        input_interval: Stride, in timeline slots, between input frames.
        output_length: Number of output frames per sample.
        output_interval: Stride, in timeline slots, between output frames.
    """

    def __init__(self, radar_4km_path='/data/KMA/Radar/HSR_4km', year_from=2014, year_to=2021, input_length=7, input_interval=1, output_length=6, output_interval=1):
        self.radar_path = radar_4km_path
        self.year_from = year_from
        self.year_to = year_to
        self.input_length = input_length
        self.input_interval = input_interval
        self.output_length = output_length
        self.output_interval = output_interval
        # Number of timeline slots between the first input frame and the last output frame.
        self.history_length = (input_length-1)*input_interval + output_length*output_interval

        # HHmm strings in chronological order (already sorted by construction).
        times = [f'{hour:02d}{minutes:02d}' for hour in range(0, 24) for minutes in range(0, 60, 10)]

        self.radar_list = []
        # Timeline index at which each following year starts; a set gives O(1) lookups below.
        year_boundaries = set()
        for year in range(year_from, year_to+1):
            for month in range(1, 13):
                # Leap-aware day count: Feb 29 must be on the timeline, otherwise
                # Feb 28 23:50 and Mar 1 00:00 would look like neighbouring slots.
                for day in range(1, self._days_in_month(year, month)+1):
                    YYYYMMDD = f"{year:04d}/{month:02d}/{day:02d}"
                    for time in times:
                        YYYYMMDDHHmm = int(f"{year:04d}{month:02d}{day:02d}{time}")
                        radar_filename = f"{YYYYMMDDHHmm}.tiff"
                        self.radar_list.append(f"{self.radar_path}/{YYYYMMDD}/{radar_filename}")
            year_boundaries.add(len(self.radar_list))
        self.time_list = list(range(len(self.radar_list)))

        self.history_indices = []
        last_nonexist_index = -1

        for idx, radar_name in enumerate(self.radar_list):
            # The first slot of a new year restarts the window so that no
            # sample spans two different years.
            if idx in year_boundaries:
                last_nonexist_index = idx-1

            if not os.path.exists(radar_name):
                last_nonexist_index = idx
            elif idx - last_nonexist_index > self.history_length:
                # Every slot of the window ending at idx exists on disk.
                self.history_indices.append(idx)

    @staticmethod
    def _days_in_month(year, month):
        """Return the number of days in ``month`` of ``year`` (Gregorian leap rules)."""
        days = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
        is_leap = (year % 4 == 0 and year % 100 != 0) or year % 400 == 0
        return 29 if (month == 2 and is_leap) else days[month-1]

    def __len__(self):
        """Return the number of valid samples."""
        return len(self.history_indices)

    def __getitem__(self, idx):
        """Return ``(inputs, outputs)`` as stacked arrays of frames read with OpenCV.

        ``inputs`` holds ``input_length`` frames ending at the "current" slot;
        ``outputs`` holds the ``output_length`` frames that follow it.
        """
        final_idx = self.history_indices[idx]
        current_idx = final_idx - self.output_length*self.output_interval
        init_idx = current_idx - (self.input_length-1)*self.input_interval

        input_radar_history = np.stack([cv2.imread(self.radar_list[time_idx], cv2.IMREAD_UNCHANGED) for time_idx in range(init_idx, current_idx+1, self.input_interval)])
        output_radar_history = np.stack([cv2.imread(self.radar_list[time_idx], cv2.IMREAD_UNCHANGED) for time_idx in range(current_idx+self.output_interval, final_idx+1, self.output_interval)])

        return input_radar_history, output_radar_history



class SEVIRnowcastDataset(Dataset):
    """Dataset of ``.npz`` samples, each holding an ``IN`` and an ``OUT`` array.

    Files matching ``<data_dir>/*.npz`` are sorted by name and partitioned:
    ``train`` takes the first ``1 - pct_validation`` fraction, ``val`` takes
    the remainder, and ``test`` uses every file.

    Args:
        data_dir: Directory that contains the ``.npz`` files.
        split: One of ``'train'``, ``'val'`` or ``'test'``.
        end: Optional cap on the number of files considered.
        pct_validation: Fraction of the considered files assigned to ``val``.
        norm: When true, samples are shifted and scaled by fixed constants.
    """

    def __init__(self, data_dir, split='train', end=None, pct_validation=0.2, norm=True):
        assert split in ['train', 'val', 'test'], "split must be 'train', 'val', or 'test'"

        self.file_list = sorted(glob.glob(os.path.join(data_dir, '*.npz')))
        self.norm = norm

        num_files = len(self.file_list)
        # A negative cap would otherwise produce a negative dataset length.
        total_samples = num_files if end is None else max(0, min(end, num_files))
        val_idx = int((1 - pct_validation) * total_samples)

        if split == 'train':
            self.start = 0
            self.end = val_idx
        elif split == 'val':
            self.start = val_idx
            self.end = total_samples
        elif split == 'test':
            self.start = 0
            self.end = total_samples

        self.length = self.end - self.start

    def __len__(self):
        """Return the number of samples in this split."""
        return self.length

    def __getitem__(self, idx):
        """Return ``(x, y)`` float32 tensors with the last file axis moved first.

        Raises:
            IndexError: If ``idx`` is outside this split. Without the check an
                out-of-range index would silently read a file that belongs to
                another split.
        """
        if idx < 0:
            idx += self.length
        if not 0 <= idx < self.length:
            raise IndexError(f"index {idx} is out of range for a split of size {self.length}")

        # NpzFile keeps a file handle open; close it as soon as the tensors are built.
        with np.load(self.file_list[self.start + idx]) as data:
            x = torch.tensor(data['IN'], dtype=torch.float32).permute(2, 0, 1)  # (H, W, T) -> (T, H, W)
            y = torch.tensor(data['OUT'], dtype=torch.float32).permute(2, 0, 1)

        if self.norm:
            MEAN, SCALE = 33.44, 47.54
            x = (x-MEAN) / SCALE
            y = (y-MEAN) / SCALE

        return x, y



class SEVIRnowcastDataset_small(Dataset):
    """Spatially subsampled variant of the ``.npz`` nowcasting dataset.

    Files matching ``<data_dir>/*.npz`` are sorted by name and partitioned:
    ``train`` takes the first ``1 - pct_validation`` fraction, ``val`` takes
    the remainder, and ``test`` uses every file. Each sample keeps every
    third pixel along both spatial axes, and the frames are re-split so that
    the input holds the first 5 frames while the remaining input frames are
    prepended to the target.

    Args:
        data_dir: Directory that contains the ``.npz`` files.
        split: One of ``'train'``, ``'val'`` or ``'test'``.
        end: Optional cap on the number of files considered.
        pct_validation: Fraction of the considered files assigned to ``val``.
        norm: When true, samples are shifted and scaled by fixed constants.
    """

    def __init__(self, data_dir, split='train', end=None, pct_validation=0.2, norm=True):
        assert split in ['train', 'val', 'test'], "split must be 'train', 'val', or 'test'"

        self.file_list = sorted(glob.glob(os.path.join(data_dir, '*.npz')))
        self.norm = norm

        num_files = len(self.file_list)
        # A negative cap would otherwise produce a negative dataset length.
        total_samples = num_files if end is None else max(0, min(end, num_files))
        val_idx = int((1 - pct_validation) * total_samples)

        if split == 'train':
            self.start = 0
            self.end = val_idx
        elif split == 'val':
            self.start = val_idx
            self.end = total_samples
        elif split == 'test':
            self.start = 0
            self.end = total_samples

        self.length = self.end - self.start

    def __len__(self):
        """Return the number of samples in this split."""
        return self.length

    def __getitem__(self, idx):
        """Return ``(x, y)`` float32 tensors: 5 input frames and the remaining frames.

        Raises:
            IndexError: If ``idx`` is outside this split. Without the check an
                out-of-range index would silently read a file that belongs to
                another split.
        """
        if idx < 0:
            idx += self.length
        if not 0 <= idx < self.length:
            raise IndexError(f"index {idx} is out of range for a split of size {self.length}")

        # NpzFile keeps a file handle open; close it as soon as the tensors are built.
        with np.load(self.file_list[self.start + idx]) as data:
            x = torch.tensor(data['IN'], dtype=torch.float32).permute(2, 0, 1)  # (H, W, T) -> (T, H, W)
            y = torch.tensor(data['OUT'], dtype=torch.float32).permute(2, 0, 1)

        if self.norm:
            MEAN, SCALE = 33.44, 47.54
            x = (x-MEAN) / SCALE
            y = (y-MEAN) / SCALE

        # Keep every third pixel along both spatial axes.
        x = x[:, ::3, ::3]
        y = y[:, ::3, ::3]

        # Move all input frames after the fifth to the front of the target.
        y = torch.cat((x[5:], y), dim=0)
        x = x[:5]

        # expected final shape per the original author: x.shape=(5, 128, 128), y.shape=(20, 128, 128)

        return x, y


# NOTE: this whole block is disabled (`if False:`), so the import and the class
# below are never executed or defined at runtime.
if False:
    from mnist_iterator import MovingMNISTAdvancedIterator
    class MovingMNISTDataset(Dataset):
        """Dataset that draws sequences on the fly from ``MovingMNISTAdvancedIterator``.

        Each item is one sampled sequence split into a first half (inputs) and
        a second half (outputs). The ``idx`` passed to ``__getitem__`` is not
        used to select data; every access draws a new sample from the iterator.
        """
        def __init__(self, num_samples=40_000, seqlen=20, norm=True,
                    scale_velocity=3.6, scale_size=1.1, scale_rotation=15, scale_illumination_min=0.6):
            
            # The sequence is split into two equal halves in __getitem__.
            assert seqlen % 2 == 0, "seqlen must be even"

            # The scale_* arguments are forwarded as ranges to the iterator.
            self.iterator = MovingMNISTAdvancedIterator(max_velocity_scale=scale_velocity,
                                                        initial_velocity_range=(0.0, scale_velocity),
                                                        scale_variation_range=(1 / scale_size, scale_size),
                                                        rotation_angle_range=(-scale_rotation, scale_rotation),
                                                        illumination_factor_range=(scale_illumination_min, 1.0))
            self.num_samples = num_samples
            self.seqlen = seqlen
            self.norm = norm

        def __len__(self):
            # Nominal epoch length; samples are generated rather than stored.
            return self.num_samples

        def __getitem__(self, idx):
            # Draw a single sequence; the second returned value is discarded.
            seq, _ = self.iterator.sample(batch_size=1, seqlen=self.seqlen)
            seq = seq[:, 0, 0, :, :]    # shape: (seqlen, H, W)
            inputs = seq[:self.seqlen//2]
            outputs = seq[self.seqlen//2:]

            if self.norm:
                # presumably pixel values are in [0, 255] — TODO confirm against the iterator
                inputs = inputs / 255.
                outputs = outputs / 255.

            return inputs, outputs


class MovingMNISTExtremeDataset(Dataset):
    """Procedurally generated Moving-MNIST sequences with extra variations.

    Every item is a freshly sampled sequence of ``seq_len`` frames in which
    ``digit_num`` cropped MNIST digits translate, bounce off an inner border,
    rotate, rescale and change brightness periodically. The sequence is
    min-max normalized and split into a first half (inputs) and a second
    half (outputs). The ``idx`` passed to ``__getitem__`` does not select
    data; sampling uses NumPy's global random state.

    Args:
        data_path: ``.npz`` file with the arrays ``X``, ``Y``, ``X_test``, ``Y_test``.
        num_samples: Nominal dataset length; also caps which training images
            may be drawn (clamped to the number of images actually loaded).
        seq_len: Number of frames per generated sequence.
        digit_num: Number of digits drawn in every sequence.
        img_size: Side length of the square output frames.
        max_velocity_scale: Per-axis velocity is clipped to ``[-v, v]``.
        initial_velocity_range: Range of the initial speed magnitude.
        scale_variation_range: Range of the per-step scale multiplier.
        rotation_angle_range: Range of the per-step rotation increment.
        illumination_factor_range: Range of the per-digit brightness
            multiplier; digits periodically fade in or out by this factor.
        period_range: Inclusive range of the per-digit period (in frames) of
            the brightness/scale oscillation.
        overlapping: If true, digits are summed on the canvas; otherwise the
            per-pixel maximum is taken.
    """

    def __init__(self,
                 data_path='/data/SwinPreCast/mnist.npz',
                 num_samples=40000,
                 seq_len=20,
                 digit_num=3,
                 img_size=64,
                 max_velocity_scale=3.6,
                 initial_velocity_range=(0.0, 3.6),
                 scale_variation_range=(1 / 1.1, 1.1),
                 rotation_angle_range=(-15, 15),
                 illumination_factor_range=(0.6, 1.2),
                 period_range=(3, 6),
                 overlapping=True
                 ):

        def load_mnist(training_num=50000):
            # NpzFile keeps a file handle open; indexing materializes each
            # array, so the archive can be closed right after reading.
            with np.load(data_path) as dat:
                X = dat['X'][:training_num]
                Y = dat['Y'][:training_num]
                X_test = dat['X_test']
                Y_test = dat['Y_test']
            Y = Y.reshape((Y.shape[0],))
            Y_test = Y_test.reshape((Y_test.shape[0],))
            return X, Y, X_test, Y_test

        self.mnist_train_img, self.mnist_train_label, self.mnist_test_img, self.mnist_test_label = load_mnist()
        self.num_samples = num_samples
        self.seq_len = seq_len
        self._digit_num = digit_num
        self._img_size = img_size
        self._max_velocity_scale = max_velocity_scale
        self._initial_velocity_range = initial_velocity_range
        self._scale_variation_range = scale_variation_range
        self._rotation_angle_range = rotation_angle_range
        self._illumination_factor_range = illumination_factor_range
        self._period_range = period_range
        self._overlapping = overlapping
        # Digit indices are drawn from this half-open range. Clamp it to the
        # images actually loaded so a large num_samples cannot index past them.
        self._index_range = (0, min(self.num_samples, len(self.mnist_train_img)))

    def draw_imgs(self, base_img, affine_transforms):
        """Warp every digit with its affine transform and composite one frame."""
        canvas_img = np.zeros((self._img_size, self._img_size), dtype=np.float32)
        for i in range(self._digit_num):
            tmp_img = cv2.warpAffine(base_img[i], affine_transforms[i], (self._img_size, self._img_size))
            if self._overlapping:
                canvas_img += tmp_img
            else:
                canvas_img = np.maximum(canvas_img, tmp_img)
        return canvas_img

    def crop_mnist_digit(self, digit_img, tol=5):
        """Drop the rows and columns whose pixels never exceed ``tol / 255``."""
        tol = float(tol) / float(255)
        mask = digit_img > tol
        return digit_img[np.ix_(mask.any(1), mask.any(0))]

    def _bounce_border(self, inner_boundary, affine_transform, digit_shift, velocity, img_h, img_w):
        """Reflect a digit whose transformed center left the inner boundary.

        ``inner_boundary`` is ``[x_min, y_min, x_max, y_max]``. The center is
        clamped back onto the boundary and the matching velocity component is
        negated. ``affine_transform`` and ``digit_shift`` are updated in place
        and also returned, together with the new velocity.
        """
        center = affine_transform.dot(np.array([img_w / 2.0, img_h / 2.0, 1], dtype=np.float32))
        new_velocity = velocity.copy()
        new_center = center.copy()
        if center[0] < inner_boundary[0]:
            new_velocity[0] = -new_velocity[0]
            new_center[0] = inner_boundary[0]
        if center[0] > inner_boundary[2]:
            new_velocity[0] = -new_velocity[0]
            new_center[0] = inner_boundary[2]
        if center[1] < inner_boundary[1]:
            new_velocity[1] = -new_velocity[1]
            new_center[1] = inner_boundary[1]
        if center[1] > inner_boundary[3]:
            new_velocity[1] = -new_velocity[1]
            new_center[1] = inner_boundary[3]
        affine_transform[:, 2] += new_center - center
        digit_shift += new_center - center
        return affine_transform, digit_shift, new_velocity

    def __len__(self):
        """Return the nominal number of samples."""
        return self.num_samples

    def __getitem__(self, idx):
        """Generate one random sequence and return its two halves ``(inputs, outputs)``."""
        seq = np.zeros((self.seq_len, self._img_size, self._img_size), dtype=np.float32)
        # Digit centers are kept at least 10 pixels away from the frame edge.
        inner_boundary = np.array([10, 10, self._img_size - 10, self._img_size - 10], dtype=np.float32)

        affine_transforms   = np.zeros((self.seq_len, self._digit_num, 2, 3), dtype=np.float32)
        appearance_variants = np.ones((self.seq_len, self._digit_num), dtype=np.float32)
        scale               = np.ones((self.seq_len, self._digit_num), dtype=np.float32)
        rotation_angle      = np.zeros((self.seq_len, self._digit_num), dtype=np.float32)
        init_velocity       = np.zeros((self._digit_num, 2), dtype=np.float32)
        velocity            = np.zeros((self.seq_len, self._digit_num, 2), dtype=np.float32)
        digit_shift         = np.zeros((self.seq_len, self._digit_num, 2), dtype=np.float32)

        # Random draws; their order is kept fixed so seeded runs stay reproducible.
        digit_indices           = np.random.randint(low=self._index_range[0], high=self._index_range[1], size=(self._digit_num, ))
        appearance_mult         = np.random.uniform(low=self._illumination_factor_range[0], high=self._illumination_factor_range[1], size=self._digit_num)
        scale_variation         = np.random.uniform(low=self._scale_variation_range[0], high=self._scale_variation_range[1], size=(self._digit_num, ))
        base_rotation_angle     = np.random.uniform(low=self._rotation_angle_range[0], high=self._rotation_angle_range[1], size=(self._digit_num, ))
        init_velocity_angle     = np.random.uniform(size=(self._digit_num, )) * (2 * np.pi)
        init_velocity_magnitude = np.random.uniform(low=self._initial_velocity_range[0], high=self._initial_velocity_range[1], size=(self._digit_num, ))
        affine_transforms_multipliers = np.random.uniform(size=(self._digit_num, 2))

        # Per-digit oscillation period, upper bound inclusive.
        periods = np.random.randint(low=self._period_range[0], high=self._period_range[1]+1, size=(self._digit_num, ))

        base_digit_img = [self.crop_mnist_digit(self.mnist_train_img[i].reshape((28, 28))) for i in digit_indices]

        # The brightness factor is already applied to the first frame.
        appearance_variants[0, :] = appearance_variants[0, :] * appearance_mult
        for i in range(1, self.seq_len):
            # The exponent alternates between +1 and -1 every `periods` frames.
            appearance_variants[i, :] = appearance_variants[i - 1, :] * (appearance_mult ** -(2 * ((i // periods) % 2) - 1))

        for i in range(1, self.seq_len):
            base_factor = (2 * ((i // periods) % 2) - 1)
            scale[i, :] = scale[i - 1, :] * (scale_variation**base_factor)
            rotation_angle[i, :] = rotation_angle[i - 1, :] + base_rotation_angle

        # Frame 0: identity rotation/scale with a random translation that keeps
        # the cropped digit inside the frame.
        affine_transforms[0, :, 0, 0] = 1.0
        affine_transforms[0, :, 1, 1] = 1.0
        for i in range(self._digit_num):
            affine_transforms[0, i, 0, 2] = affine_transforms_multipliers[i, 0] * (self._img_size - base_digit_img[i].shape[1])
            affine_transforms[0, i, 1, 2] = affine_transforms_multipliers[i, 1] * (self._img_size - base_digit_img[i].shape[0])

        init_velocity[:, 0] = init_velocity_magnitude * np.cos(init_velocity_angle)
        init_velocity[:, 1] = init_velocity_magnitude * np.sin(init_velocity_angle)
        curr_velocity = init_velocity

        for i in range(self._digit_num):
            digit_shift[0, i, 0] = affine_transforms[0, i, 0, 2]
            digit_shift[0, i, 1] = affine_transforms[0, i, 1, 2]

        for i in range(self.seq_len - 1):
            velocity[i, :, :] = curr_velocity
            curr_velocity = np.clip(curr_velocity, a_min=-self._max_velocity_scale, a_max=self._max_velocity_scale)
            for j in range(self._digit_num):
                digit_shift[i + 1, j, :] = digit_shift[i, j, :] + curr_velocity[j]
                rotation_mat = cv2.getRotationMatrix2D(
                    center=(base_digit_img[j].shape[1] / 2.0,
                            base_digit_img[j].shape[0] / 2.0),
                    angle=rotation_angle[i + 1, j],
                    scale=scale[i + 1, j])
                affine_transforms[i + 1, j, :, :2] = rotation_mat[:, :2]
                affine_transforms[i + 1, j, :, 2] = digit_shift[i + 1, j, :] + rotation_mat[:, 2]
                affine_transforms[i + 1, j, :, :], digit_shift[i + 1, j, :], curr_velocity[j] =\
                    self._bounce_border(inner_boundary=inner_boundary,
                                        affine_transform=affine_transforms[i + 1, j, :, :],
                                        digit_shift=digit_shift[i + 1, j, :],
                                        velocity=curr_velocity[j],
                                        img_h=base_digit_img[j].shape[0],
                                        img_w=base_digit_img[j].shape[1])
        for i in range(self.seq_len):
            seq[i, :, :] = self.draw_imgs(base_img=[base_digit_img[j] * appearance_variants[i, j] for j in range(self._digit_num)], affine_transforms=affine_transforms[i])

        # Min-max normalization over the whole sequence. A constant sequence
        # has zero range; it is left at zero instead of dividing by zero.
        seq_min = seq.min()
        value_range = seq.max() - seq_min
        seq = seq - seq_min
        if value_range > 0:
            seq = seq / value_range

        inputs = seq[:self.seq_len // 2]
        outputs = seq[self.seq_len // 2:]

        return inputs, outputs
    



class MeteoNetDataset_old(Dataset):
    """Sliding-window dataset over radar ``.npz`` frames sampled every 5 minutes.

    A regular 5-minute timeline is generated for every calendar day between
    ``year_from`` and ``year_to`` (inclusive). Each slot maps to the path
    ``<radar_path>/YYYY/MM/DD/YYYYMMDDHHmm.npz``. A slot becomes a sample
    (identified by its *last output frame*) only when every slot in the
    preceding ``history_length`` steps exists on disk and the window does not
    reach back into the previous year.

    Args:
        radar_path: Root directory of the radar archive.
        year_from: First year (inclusive) of the timeline.
        year_to: Last year (inclusive) of the timeline.
        input_length: Number of input frames per sample.
        input_interval: Input stride in units of 10 minutes (0.5 = 5 minutes).
        output_length: Number of output frames per sample.
        output_interval: Output stride in units of 10 minutes (0.5 = 5 minutes).
        norm: When true, frames are divided by 70.0.
        crop: When true, only the last 384 rows and columns are kept.
        mask: When ``None``, pixels equal to 255 are set to 0.
    """

    def __init__(self, radar_path='/data/MeteoNet_unzip/Radar', year_from=2016, year_to=2018, 
                 input_length=13, input_interval=0.5, output_length=12, output_interval=0.5, 
                 norm=True, crop=True, mask=None):

        # input_interval=0.5 means 5 minutes time interval
        # input_interval=1 means 10 minutes time interval
        input_interval = int(input_interval * 2)
        output_interval = int(output_interval * 2)
        self.norm = norm
        self.mask = mask
        self.crop = crop

        self.radar_path = radar_path
        self.year_from = year_from
        self.year_to = year_to
        self.input_length = input_length
        self.input_interval = input_interval
        self.output_length = output_length
        self.output_interval = output_interval
        # Number of timeline slots between the first input frame and the last output frame.
        self.history_length = (input_length-1)*input_interval + output_length*output_interval

        # HHmm strings in chronological order (already sorted by construction).
        times = [f'{hour:02d}{minutes:02d}' for hour in range(0, 24) for minutes in range(0, 60, 5)]

        self.radar_list = []
        # Timeline index at which each following year starts; a set gives O(1) lookups below.
        year_boundaries = set()
        for year in range(year_from, year_to+1):
            for month in range(1, 13):
                # Leap-aware day count. Shifting every February day by one (the
                # previous approach) dropped Feb 1 from the timeline in leap years.
                for day in range(1, self._days_in_month(year, month)+1):
                    YYYYMMDD = f"{year:04d}/{month:02d}/{day:02d}"
                    for time in times:
                        YYYYMMDDHHmm = int(f"{year:04d}{month:02d}{day:02d}{time}")
                        radar_filename = f"{YYYYMMDDHHmm}.npz"
                        self.radar_list.append(f"{self.radar_path}/{YYYYMMDD}/{radar_filename}")
            year_boundaries.add(len(self.radar_list))
        self.time_list = list(range(len(self.radar_list)))

        self.history_indices = []
        last_nonexist_index = -1

        for idx, radar_name in enumerate(self.radar_list):
            # The first slot of a new year restarts the window so that no
            # sample spans two different years.
            if idx in year_boundaries:
                last_nonexist_index = idx-1

            if not os.path.exists(radar_name):
                last_nonexist_index = idx
            elif idx - last_nonexist_index > self.history_length:
                # Every slot of the window ending at idx exists on disk.
                self.history_indices.append(idx)

    @staticmethod
    def _days_in_month(year, month):
        """Return the number of days in ``month`` of ``year`` (Gregorian leap rules)."""
        days = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
        is_leap = (year % 4 == 0 and year % 100 != 0) or year % 400 == 0
        return 29 if (month == 2 and is_leap) else days[month-1]

    def _process(self, img):
        """Crop, clear the 255 marker, optionally normalize, and cast to float32."""
        if self.crop:
            img = img[-384:, -384:]
        if self.mask is None:
            img[img == 255] = 0    # set uncovered radar region to 0
        if self.norm:
            img = img / 70.0
        return img.astype(np.float32)

    def _load_frame(self, time_idx):
        """Read and preprocess the ``data`` array stored for one timeline slot."""
        # NOTE(review): allow_pickle=True is kept from the original call; it is
        # unsafe on untrusted archives. The context manager closes the file handle.
        with np.load(self.radar_list[time_idx], allow_pickle=True) as radar_data:
            return self._process(radar_data['data'])

    def __len__(self):
        """Return the number of valid samples."""
        return len(self.history_indices)

    def __getitem__(self, idx):
        """Return ``(inputs, outputs)`` as stacked float32 arrays of frames.

        ``inputs`` holds ``input_length`` frames ending at the "current" slot;
        ``outputs`` holds the ``output_length`` frames that follow it.
        """
        final_idx = self.history_indices[idx]
        current_idx = final_idx - self.output_length*self.output_interval
        init_idx = current_idx - (self.input_length-1)*self.input_interval

        input_radar_history = np.stack([
            self._load_frame(time_idx)
            for time_idx in range(init_idx, current_idx+1, self.input_interval)
        ])
        output_radar_history = np.stack([
            self._load_frame(time_idx)
            for time_idx in range(current_idx+self.output_interval, final_idx+1, self.output_interval)
        ])

        return input_radar_history, output_radar_history
    
    

class MeteoNetDataset(Dataset):
    """Sliding-window dataset over radar TIFF frames sampled every 5 minutes.

    A regular 5-minute timeline is generated for every calendar day between
    ``year_from`` and ``year_to`` (inclusive). Each slot maps to the path
    ``<radar_path>/YYYY/MM/DD/YYYYMMDDHHmm.tiff``. A slot becomes a sample
    (identified by its *last output frame*) only when every slot in the
    preceding ``history_length`` steps exists on disk and the window does not
    reach back into the previous year.

    Args:
        radar_path: Root directory of the radar archive.
        year_from: First year (inclusive) of the timeline.
        year_to: Last year (inclusive) of the timeline.
        input_length: Number of input frames per sample.
        input_interval: Input stride in units of 10 minutes (0.5 = 5 minutes).
        output_length: Number of output frames per sample.
        output_interval: Output stride in units of 10 minutes (0.5 = 5 minutes).
    """

    def __init__(self, radar_path='/data/MeteoNet_small/Radar/SE', year_from=2016, year_to=2018, 
                 input_length=13, input_interval=0.5, output_length=12, output_interval=0.5):

        # input_interval=0.5 means 5 minutes time interval
        # input_interval=1 means 10 minutes time interval
        input_interval = int(input_interval * 2)
        output_interval = int(output_interval * 2)

        self.radar_path = radar_path
        self.year_from = year_from
        self.year_to = year_to
        self.input_length = input_length
        self.input_interval = input_interval
        self.output_length = output_length
        self.output_interval = output_interval
        # Number of timeline slots between the first input frame and the last output frame.
        self.history_length = (input_length-1)*input_interval + output_length*output_interval

        # HHmm strings in chronological order (already sorted by construction).
        times = [f'{hour:02d}{minutes:02d}' for hour in range(0, 24) for minutes in range(0, 60, 5)]

        self.radar_list = []
        # Timeline index at which each following year starts; a set gives O(1) lookups below.
        year_boundaries = set()
        for year in range(year_from, year_to+1):
            for month in range(1, 13):
                # Leap-aware day count. Shifting every February day by one (the
                # previous approach) dropped Feb 1 from the timeline in leap years.
                for day in range(1, self._days_in_month(year, month)+1):
                    YYYYMMDD = f"{year:04d}/{month:02d}/{day:02d}"
                    for time in times:
                        YYYYMMDDHHmm = int(f"{year:04d}{month:02d}{day:02d}{time}")
                        radar_filename = f"{YYYYMMDDHHmm}.tiff"
                        self.radar_list.append(f"{self.radar_path}/{YYYYMMDD}/{radar_filename}")
            year_boundaries.add(len(self.radar_list))
        self.time_list = list(range(len(self.radar_list)))

        self.history_indices = []
        last_nonexist_index = -1

        for idx, radar_name in enumerate(self.radar_list):
            # The first slot of a new year restarts the window so that no
            # sample spans two different years.
            if idx in year_boundaries:
                last_nonexist_index = idx-1

            if not os.path.exists(radar_name):
                last_nonexist_index = idx
            elif idx - last_nonexist_index > self.history_length:
                # Every slot of the window ending at idx exists on disk.
                self.history_indices.append(idx)

    @staticmethod
    def _days_in_month(year, month):
        """Return the number of days in ``month`` of ``year`` (Gregorian leap rules)."""
        days = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
        is_leap = (year % 4 == 0 and year % 100 != 0) or year % 400 == 0
        return 29 if (month == 2 and is_leap) else days[month-1]

    def __len__(self):
        """Return the number of valid samples."""
        return len(self.history_indices)

    def __getitem__(self, idx):
        """Return ``(inputs, outputs)`` as stacked arrays of frames read with OpenCV.

        ``inputs`` holds ``input_length`` frames ending at the "current" slot;
        ``outputs`` holds the ``output_length`` frames that follow it.
        """
        final_idx = self.history_indices[idx]
        current_idx = final_idx - self.output_length*self.output_interval
        init_idx = current_idx - (self.input_length-1)*self.input_interval

        input_radar_history = np.stack([cv2.imread(self.radar_list[time_idx], cv2.IMREAD_UNCHANGED) for time_idx in range(init_idx, current_idx+1, self.input_interval)])
        output_radar_history = np.stack([cv2.imread(self.radar_list[time_idx], cv2.IMREAD_UNCHANGED) for time_idx in range(current_idx+self.output_interval, final_idx+1, self.output_interval)])

        return input_radar_history, output_radar_history


class MeteoNetDataset_small(Dataset):
    """Spatially subsampled sliding-window dataset over 5-minute radar TIFF frames.

    A regular 5-minute timeline is generated for every calendar day between
    ``year_from`` and ``year_to`` (inclusive). Each slot maps to the path
    ``<radar_path>/YYYY/MM/DD/YYYYMMDDHHmm.tiff``. A slot becomes a sample
    (identified by its *last output frame*) only when every slot in the
    preceding ``history_length`` steps exists on disk and the window does not
    reach back into the previous year. Returned frames keep every third pixel
    of the first 384 rows and columns.

    Args:
        radar_path: Root directory of the radar archive.
        year_from: First year (inclusive) of the timeline.
        year_to: Last year (inclusive) of the timeline.
        input_length: Number of input frames per sample.
        input_interval: Input stride in units of 10 minutes (0.5 = 5 minutes).
        output_length: Number of output frames per sample.
        output_interval: Output stride in units of 10 minutes (0.5 = 5 minutes).
    """

    def __init__(self, radar_path='/data/MeteoNet_small/Radar/SE', year_from=2016, year_to=2018, 
                 input_length=13, input_interval=0.5, output_length=12, output_interval=0.5):

        # input_interval=0.5 means 5 minutes time interval
        # input_interval=1 means 10 minutes time interval
        input_interval = int(input_interval * 2)
        output_interval = int(output_interval * 2)

        self.radar_path = radar_path
        self.year_from = year_from
        self.year_to = year_to
        self.input_length = input_length
        self.input_interval = input_interval
        self.output_length = output_length
        self.output_interval = output_interval
        # Number of timeline slots between the first input frame and the last output frame.
        self.history_length = (input_length-1)*input_interval + output_length*output_interval

        # HHmm strings in chronological order (already sorted by construction).
        times = [f'{hour:02d}{minutes:02d}' for hour in range(0, 24) for minutes in range(0, 60, 5)]

        self.radar_list = []
        # Timeline index at which each following year starts; a set gives O(1) lookups below.
        year_boundaries = set()
        for year in range(year_from, year_to+1):
            for month in range(1, 13):
                # Leap-aware day count. Shifting every February day by one (the
                # previous approach) dropped Feb 1 from the timeline in leap years.
                for day in range(1, self._days_in_month(year, month)+1):
                    YYYYMMDD = f"{year:04d}/{month:02d}/{day:02d}"
                    for time in times:
                        YYYYMMDDHHmm = int(f"{year:04d}{month:02d}{day:02d}{time}")
                        radar_filename = f"{YYYYMMDDHHmm}.tiff"
                        self.radar_list.append(f"{self.radar_path}/{YYYYMMDD}/{radar_filename}")
            year_boundaries.add(len(self.radar_list))
        self.time_list = list(range(len(self.radar_list)))

        self.history_indices = []
        last_nonexist_index = -1

        for idx, radar_name in enumerate(self.radar_list):
            # The first slot of a new year restarts the window so that no
            # sample spans two different years.
            if idx in year_boundaries:
                last_nonexist_index = idx-1

            if not os.path.exists(radar_name):
                last_nonexist_index = idx
            elif idx - last_nonexist_index > self.history_length:
                # Every slot of the window ending at idx exists on disk.
                self.history_indices.append(idx)

    @staticmethod
    def _days_in_month(year, month):
        """Return the number of days in ``month`` of ``year`` (Gregorian leap rules)."""
        days = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
        is_leap = (year % 4 == 0 and year % 100 != 0) or year % 400 == 0
        return 29 if (month == 2 and is_leap) else days[month-1]

    def __len__(self):
        """Return the number of valid samples."""
        return len(self.history_indices)

    def __getitem__(self, idx):
        """Return ``(inputs, outputs)`` as stacked, subsampled frames read with OpenCV.

        ``inputs`` holds ``input_length`` frames ending at the "current" slot;
        ``outputs`` holds the ``output_length`` frames that follow it.
        """
        final_idx = self.history_indices[idx]
        current_idx = final_idx - self.output_length*self.output_interval
        init_idx = current_idx - (self.input_length-1)*self.input_interval

        input_radar_history = np.stack([cv2.imread(self.radar_list[time_idx], cv2.IMREAD_UNCHANGED) for time_idx in range(init_idx, current_idx+1, self.input_interval)])
        output_radar_history = np.stack([cv2.imread(self.radar_list[time_idx], cv2.IMREAD_UNCHANGED) for time_idx in range(current_idx+self.output_interval, final_idx+1, self.output_interval)])

        # Keep every third pixel of the first 384 rows and columns.
        input_radar_history = input_radar_history[:, 0:384:3, 0:384:3]
        output_radar_history = output_radar_history[:, 0:384:3, 0:384:3]

        return input_radar_history, output_radar_history



if __name__ == '__main__':
    from torch.utils.data import Subset

    # Smoke test 1: count the valid samples found in the KMA archive.
    train_dataset = KMAradar4kmDataset(
        year_from=2014,
        year_to=2021,
        input_length=7,
        input_interval=1,
        output_length=6,
        output_interval=1,
    )
    print(len(train_dataset))

    # Smoke test 2: chronological 60/20/20 split of the MeteoNet dataset.
    dataset = MeteoNetDataset(
        year_from=2016,
        year_to=2018,
        input_length=12,
        input_interval=0.5,
        output_length=12,
        output_interval=0.5,
    )
    total_len = len(dataset)
    train_percent = 0.6
    train_len = int(total_len * train_percent)
    valid_len = int(total_len * (1 - train_percent) / 2)
    # The test split takes whatever is left after rounding the other two.
    test_len = total_len - (train_len + valid_len)

    valid_end = train_len + valid_len
    train_dataset = Subset(dataset, indices=list(range(train_len)))
    valid_dataset = Subset(dataset, indices=list(range(train_len, valid_end)))
    test_dataset = Subset(dataset, indices=list(range(valid_end, total_len)))

    print(len(train_dataset))