import torch
from torch.utils.data import Dataset


class DummyDataset(Dataset):
    """Constant-valued toy dataset with two input features and four targets.

    Every sample is the pair ``(x_0, y_0)``; every target is a zero tensor of
    shape ``(1,)``. Four task entries are registered in ``self.tasks``, each
    mapping a task name to ``[position, name]``.
    """

    # Number of features in each sample of ``self.data``.
    input_size = 2

    def __init__(self, tag, size=1, x_0=1, y_0=1):
        """Build ``size`` identical samples.

        Args:
            tag: Accepted for interface compatibility; not used by the code
                in this class.
            size: Number of samples to create.
            x_0: Value stored in the first feature column.
            y_0: Value stored in the second feature column.
        """
        super().__init__()

        # Task table: name -> [position, name], in declaration order.
        task_names = ('loss1-1', 'loss1-2', 'loss2-1', 'loss2-2')
        self.tasks = {
            name: [position, name]
            for position, name in enumerate(task_names)
        }

        # Deterministic inputs: start from ones and scale each column in
        # place. (The original also carried a commented-out variant that
        # drew the inputs from a normal distribution instead.)
        self.data = torch.ones([size, 2], dtype=torch.float)
        for column, scale in enumerate((x_0, y_0)):
            self.data[:, column].mul_(scale)

        # One all-zero target per task, split along the task axis into a
        # tuple of ``(size, 1)`` tensors.
        all_targets = torch.zeros((size, len(self.tasks), 1))
        self.target = torch.unbind(all_targets, dim=1)

    def __getitem__(self, index):
        """Return ``(sample, targets)`` for ``index``.

        ``targets`` is a list holding the entry at ``index`` of each
        per-task target tensor, in task order.
        """
        sample = self.data[index]
        targets = []
        for task_target in self.target:
            targets.append(task_target[index])
        return sample, targets

    def __len__(self):
        """Return the number of samples (rows of ``self.data``)."""
        return self.data.shape[0]

    @torch.no_grad()
    def plot(self, model, tasks, title):
        """Placeholder: currently does nothing and returns ``None``."""
        # TODO tag