import torch
from torch.utils.data import Dataset

class ShortestPathDataset(Dataset):
    """Map-style dataset pairing map samples with flattened shortest-path targets.

    Targets are either looked up from a precomputed ``targets`` collection or,
    when ``targets`` is None, computed on every access by calling
    ``solver(weights[index])``. Solver results are not cached.
    """

    def __init__(self, maps, targets=None, weights=None, solver=None, transform=None):
        """
        Args:
            maps: Indexable collection of input samples (e.g. tensor, array or
                list). Its length defines the dataset length.
            targets (optional): Indexable collection of precomputed targets,
                presumably aligned index-by-index with ``maps``. When given, it
                takes precedence over ``weights``/``solver``.
            weights (optional): Indexable collection whose ``index``-th item is
                passed to ``solver`` when ``targets`` is None.
            solver (callable, optional): Called as ``solver(weights[index])`` to
                produce a target when ``targets`` is None.
            transform (callable, optional): Applied to each sample in
                ``__getitem__``. The target is not transformed.

        Raises:
            ValueError: If ``targets`` is None and either ``weights`` or
                ``solver`` is None.
        """
        if targets is None:
            if weights is None or solver is None:
                # ValueError subclasses Exception, so callers catching
                # Exception keep working.
                raise ValueError('No targets given')
            self.solve = True
        else:
            self.solve = False

        self.maps = maps
        self.weights = weights
        self.targets = targets
        self.solver = solver
        self.transform = transform

    def __len__(self):
        """Return the number of samples, i.e. ``len(maps)``."""
        # len() equals shape[0] for tensors/arrays and also supports plain lists.
        return len(self.maps)

    def __getitem__(self, index):
        """Return ``(sample, target)`` for ``index``.

        The target is converted to a 1-D tensor via ``torch.flatten``. The
        sample is passed through ``transform`` when one was provided.
        """
        sample = self.maps[index]
        if self.solve:
            target = self.solver(self.weights[index])
        else:
            target = self.targets[index]

        if isinstance(target, torch.Tensor):
            # torch.tensor(tensor) emits a UserWarning. detach().clone() yields
            # the same independent, grad-free copy without it.
            target = target.detach().clone()
        else:
            target = torch.tensor(target)
        target = torch.flatten(target)

        if self.transform is not None:
            sample = self.transform(sample)

        return sample, target