from torch.utils.data import Dataset, DataLoader
import torch as torch

import numpy as np


class SimpleDataSet(Dataset):
    """Map-style dataset pairing inputs ``x`` with targets ``y``.

    Each item is returned as an ``(x_i, y_i)`` tuple of tensors built with
    ``torch.from_numpy(np.array(...))``: the element is first copied into a
    fresh numpy array, so its numpy dtype is kept and the returned tensor does
    not share memory with ``x`` / ``y``.

    Args:
        x: Indexable array-like with a ``.shape`` attribute whose first axis
            indexes samples. NOTE(review): presumably a numpy array of shape
            (n_samples, length) -- confirm against callers.
        y: Indexable targets, one entry per sample along ``x``'s first axis.

    Raises:
        ValueError: If ``x`` and ``y`` disagree on the number of samples.
    """

    def __init__(self, x, y):
        # Fail fast: a length mismatch would otherwise only surface as an
        # IndexError (or silently unused trailing targets) inside a DataLoader.
        if len(y) != x.shape[0]:
            raise ValueError(
                "x and y must have the same number of samples, "
                f"got {x.shape[0]} and {len(y)}"
            )
        self.x = x
        self.y = y

    def __len__(self):
        """Return the number of samples (size of ``x``'s first axis)."""
        return self.x.shape[0]

    def __dim__(self):
        """Return the per-sample length, i.e. ``x.shape[1]``.

        Not a Python protocol method despite the dunder name; callers must
        invoke it explicitly.

        Raises:
            ValueError: If ``x`` has more than two axes (multi-channel data)
                or fewer than two axes (no per-sample axis).
        """
        ndim = len(self.x.shape)
        if ndim > 2:
            raise ValueError("only handles single channel data")
        if ndim < 2:
            raise ValueError(
                "expected 2-D x of shape (n_samples, length), "
                f"got shape {tuple(self.x.shape)}"
            )
        return self.x.shape[1]

    def __getitem__(self, idx):
        """Return ``(x[idx], y[idx])`` as freshly-copied tensors."""
        return (
            torch.from_numpy(np.array(self.x[idx])),
            torch.from_numpy(np.array(self.y[idx])),
        )


class StudentDataSet(Dataset):
    """Map-style dataset of ``(x, y, y_robust_teacher)`` triples.

    Like a plain ``(x, y)`` dataset, but every item also carries the
    corresponding entry of ``y_robust_teacher``. Each item is returned as a
    tuple of tensors built with ``torch.from_numpy(np.array(...))``, so the
    numpy dtype is kept and the tensors do not share memory with the inputs.

    Args:
        x: Indexable array-like with a ``.shape`` attribute whose first axis
            indexes samples. NOTE(review): presumably a numpy array of shape
            (n_samples, length) -- confirm against callers.
        y: Indexable targets, one entry per sample.
        y_robust_teacher: Indexable per-sample values returned as the third
            element of each item. NOTE(review): presumably a teacher model's
            predictions for distillation -- confirm.

    Raises:
        ValueError: If ``y`` or ``y_robust_teacher`` does not have one entry
            per sample of ``x``.
    """

    def __init__(self, x, y, y_robust_teacher):
        n_samples = x.shape[0]
        # Fail fast on misaligned inputs instead of erroring mid-epoch.
        for name, values in (("y", y), ("y_robust_teacher", y_robust_teacher)):
            if len(values) != n_samples:
                raise ValueError(
                    f"{name} must have one entry per sample of x, "
                    f"got {len(values)} and {n_samples}"
                )
        self.x = x
        self.y = y
        self.y_robust_teacher = y_robust_teacher

    def __len__(self):
        """Return the number of samples (size of ``x``'s first axis)."""
        return self.x.shape[0]

    def __dim__(self):
        """Return the per-sample length, i.e. ``x.shape[1]``.

        Not a Python protocol method despite the dunder name; callers must
        invoke it explicitly.

        Raises:
            ValueError: If ``x`` has more than two axes (multi-channel data)
                or fewer than two axes (no per-sample axis).
        """
        ndim = len(self.x.shape)
        if ndim > 2:
            raise ValueError("only handles single channel data")
        if ndim < 2:
            raise ValueError(
                "expected 2-D x of shape (n_samples, length), "
                f"got shape {tuple(self.x.shape)}"
            )
        return self.x.shape[1]

    def __getitem__(self, idx):
        """Return ``(x[idx], y[idx], y_robust_teacher[idx])`` as copied tensors."""
        return (
            torch.from_numpy(np.array(self.x[idx])),
            torch.from_numpy(np.array(self.y[idx])),
            torch.from_numpy(np.array(self.y_robust_teacher[idx])),
        )


class MoMDataLoader:
    """Hold one ``DataLoader`` per ``(x_i, y_i)`` chunk of training data.

    Each pair yielded by ``training_data`` is wrapped in a ``SimpleDataSet``
    and given its own ``DataLoader``; loaders are kept in input order.

    Args:
        training_data: Iterable of ``(x_i, y_i)`` pairs.
        batch_size: Batch size used for every loader.
        shuffle: Passed to every ``DataLoader``. Defaults to ``True``, the
            previously hard-coded behavior.
    """

    def __init__(self, training_data, batch_size, *, shuffle=True):
        # Public attribute name kept as-is for existing callers.
        self.dataloader = [
            DataLoader(
                SimpleDataSet(x_i, y_i),
                batch_size=batch_size,
                shuffle=shuffle,
            )
            for x_i, y_i in training_data
        ]

    def get_ith_dataloader(self, i):
        """Return the ``DataLoader`` built from the ``i``-th chunk of ``training_data``."""
        return self.dataloader[i]





