from typing import OrderedDict

import einops
import torch
from torch.utils.data import Dataset


class DataSet(Dataset):
    """Dataset of channel-last images and/or ROI-split fMRI responses.

    Works both with supervised data (several inputs, e.g. images plus fMRI)
    and with unsupervised data (a single input). Each positional argument
    becomes one element of the list returned by ``__getitem__``:

    * a ``dict`` is treated as ROI-split fMRI. Every value is converted to a
      float32 tensor and indexed along its first axis. ROI keys keep their
      insertion order.
    * anything else is treated as a channel-last image stack ``(n, w, h, c)``.
      It is converted to float32, rearranged to channel-first
      ``(n, c, w, h)`` and divided by its global maximum (all-zero stacks are
      left as zeros instead of becoming NaN).

    All inputs must agree on the number of samples (first-axis length).

    Args:
        *x: One or more datasets (dicts of array-likes, or array-likes).
        device: Device every sample is moved to in ``__getitem__``. Defaults
            to ``"cuda"``. Use ``"cpu"`` (or ``None`` for no transfer) when
            no GPU is available or when loading inside DataLoader workers.

    Raises:
        AssertionError: if no dataset is given.
        ValueError: if the inputs disagree on the number of samples.
    """

    # Class-level default so subclasses that set up ``data``/``len`` without
    # calling ``DataSet.__init__`` still get the historical ``.cuda()`` behavior.
    device = "cuda"

    def __init__(self, *x, device="cuda"):
        super().__init__()
        assert len(x) > 0, "At least one dataset must be given"
        self.device = device
        self.data = []
        lengths = set()  # first-axis length of every tensor, must end up unique
        for xx in x:
            if isinstance(xx, dict):  # ROI-split fMRI
                data = OrderedDict()
                for roi in xx:
                    data[roi] = torch.tensor(xx[roi], dtype=torch.float32)
                    lengths.add(data[roi].shape[0])
            else:  # channel-last image stack
                data = torch.tensor(xx, dtype=torch.float32)
                # (n, w, h, c) -> (n, c, w, h); same as einops "n w h c -> n c w h"
                data = data.permute(0, 3, 1, 2)
                data = self._normalize(data)
                lengths.add(data.shape[0])
            self.data.append(data)
        if len(lengths) > 1:
            raise ValueError(
                f"All datasets must have the same number of samples, got {sorted(lengths)}"
            )
        # Only empty ROI dicts were given: there are no samples at all.
        self.len = lengths.pop() if lengths else 0

    @staticmethod
    def _normalize(images):
        """Divide by the global maximum; empty or all-zero stacks are returned as-is."""
        if images.numel() == 0:  # torch.max raises on empty tensors
            return images
        peak = torch.max(images)
        if peak == 0:  # avoid 0/0 -> NaN
            return images
        return images / peak

    def _to_device(self, tensor):
        """Move ``tensor`` to ``self.device``; ``None`` means leave it in place."""
        if self.device is None:
            return tensor
        return tensor.to(self.device)

    def __len__(self):
        return self.len

    def __getitem__(self, i):
        """Return sample ``i`` as a list with one entry per input dataset.

        Dict inputs yield an ``OrderedDict`` of ROI -> tensor. Image inputs
        yield a ``(c, w, h)`` tensor. Negative indices follow tensor indexing.
        """
        if i >= self.len:
            raise IndexError(f"index {i} out of range for dataset of length {self.len}")
        out = []
        for x in self.data:
            if isinstance(x, dict):
                sample = OrderedDict()
                for key in x:
                    sample[key] = self._to_device(x[key][i])
                out.append(sample)
            else:
                out.append(self._to_device(x[i]))
        return out
