from typing import Any, Iterable, Mapping

import numpy as np
import torch
from torch.utils.data import Dataset


# Names exported by ``from <module> import *``.
__all__ = ["SWDataset"]

class SWDataset(Dataset):
    """Map-style dataset pairing per-sample parameters with a data array.

    Each item is a ``(gdr, data)`` tuple. ``gdr`` holds the sample's
    ``gravity``, ``depth`` and ``rossby_radius`` values, and ``data`` holds
    the sample's ``'data'`` array. Both are ``torch.float32`` tensors.

    Attributes:
        gdr: Stacked parameter tensor, shape ``(N, 3)`` when every parameter
            is a scalar (TODO confirm callers never pass array-valued
            parameters).
        data: Stacked sample arrays, shape ``(N, *sample_shape)``.
        len: Number of samples ``N``; always equal to ``len(self)``.
    """

    def __init__(self, data: Mapping[str, Any]):
        """Materialise every sample of ``data['data']`` into stacked tensors.

        Args:
            data: Container whose ``'data'`` entry is an iterable of samples.
                Each sample must be subscriptable by ``'gravity'``,
                ``'depth'``, ``'rossby_radius'`` and ``'data'``. All
                ``sample['data']`` arrays must share one shape so they can
                be stacked.

        Raises:
            KeyError: (for dict inputs) if ``data`` or a sample lacks a
                required key.
            ValueError: from ``np.stack`` if the sample arrays differ in shape.
        """
        gdr_rows, fields = [], []
        for sample in data['data']:
            gdr_rows.append(torch.tensor([
                sample['gravity'], sample['depth'], sample['rossby_radius']
            ]))
            fields.append(sample['data'])

        if gdr_rows:
            self.gdr = torch.stack(gdr_rows).to(torch.float32)
            self.data = torch.from_numpy(np.stack(fields)).to(torch.float32)
        else:
            # torch.stack / np.stack reject empty sequences; represent an empty
            # input as a valid zero-length dataset instead of crashing.
            self.gdr = torch.empty((0, 3), dtype=torch.float32)
            self.data = torch.empty((0,), dtype=torch.float32)

        # Count samples, not the keys of the outer ``data`` container, so this
        # attribute agrees with ``__len__``.
        self.len = len(self.gdr)

    def __len__(self):
        """Return the number of samples."""
        return len(self.gdr)

    def __getitem__(self, idx):
        """Return the ``(gdr, data)`` pair at ``idx``.

        ``idx`` may be anything accepted by tensor indexing (int, slice, ...).
        """
        return self.gdr[idx], self.data[idx]