from torch.utils.data import Dataset
import torch

class MIPLL_Dataset(Dataset):
    """Dataset of M aligned input streams, a per-sample ``s`` value and optional preimages.

    Sample ``index`` yields ``weak_transform(x[i][index])`` for every stream ``i``,
    followed by ``s[index]``. When preimages are attached it additionally yields the
    integer labels ``y[i][index]`` and ``preimage[index]``.

    Args:
        x: Sequence of M indexable input collections.
        y: Sequence of M indexable label collections; ``len(y)`` defines M.
            Elements must support ``.item()`` (e.g. tensor elements).
        s: Indexable per-sample values; also used as lookup keys when
            ``preimage`` is a dict.
        l: Value reported by ``__len__``.
        weak_transform: Callable applied to each input element on access.
        preimage: ``None``, a list aligned with the samples, or a dict mapping
            each value of ``s`` to its preimage.

    Raises:
        TypeError: If ``preimage`` is not ``None``, a ``list`` or a ``dict``.
        KeyError: If ``preimage`` is a dict lacking an entry for some value of ``s``.
    """

    def __init__(
        self, x, y, s, l, weak_transform, preimage=None
    ):
        self.x = x
        self.y = y
        self.s = s
        self.l = l
        self.M = len(y)  # number of input streams
        self.weak_transform = weak_transform
        if isinstance(preimage, dict):
            # Expand the mapping into a list aligned sample-by-sample with ``s``.
            # NOTE(review): elements of ``s`` are used directly as dict keys;
            # presumably they are hashable plain values (e.g. ints) — confirm.
            self.preimage = [preimage[s_sample] for s_sample in self.s]
        elif preimage is None or isinstance(preimage, list):
            self.preimage = preimage
        else:
            # Explicit check rather than ``assert`` so validation survives ``python -O``.
            raise TypeError(
                "preimage must be None, a list or a dict, "
                f"got {type(preimage).__name__}"
            )

    def __len__(self):
        """Return the sample count supplied as ``l`` at construction."""
        return self.l

    def __getitem__(self, index):
        """Return the sample at ``index`` as a flat list.

        Layout without preimages: ``[x_0, ..., x_{M-1}, s]``.
        Layout with preimages: ``[x_0, ..., x_{M-1}, s, labels, preimage]``,
        where ``labels`` is a list of M Python ints.
        """
        each_x_w = [self.weak_transform(self.x[i][index]) for i in range(self.M)]
        each_s = self.s[index]
        # Labels are computed even when they are not returned, matching the
        # original evaluation order.
        labels = [int(self.y[i][index].item()) for i in range(self.M)]
        if self.preimage is None:
            return each_x_w + [each_s]
        return each_x_w + [each_s, labels, self.preimage[index]]

    def get_proofs(self):
        """Return the per-sample preimage list, or ``None`` if none is attached."""
        return self.preimage

    @staticmethod
    def collate_fn(batch):
        """Collate samples produced by ``__getitem__`` into a batch.

        Returns ``(columns, ys)`` for samples without preimages, or
        ``(columns, ys, gt, preimage)`` for samples with them. In both cases,
        ``columns`` is a list of M stacked input tensors and ``ys`` is a long
        tensor of the ``s`` values. ``gt`` is the per-sample list of label
        lists, and ``preimage`` holds the per-sample preimages.
        """
        L = len(batch[0])
        # A trailing ``list`` marks the preimage layout (see ``__getitem__``).
        # NOTE(review): relies on preimage entries being lists and ``s`` values
        # not being lists — confirm against the data producers.
        if isinstance(batch[0][L - 1], list):
            columns = [torch.stack([item[i] for item in batch]) for i in range(L - 3)]
            # ``as_tensor`` avoids the copy-construct warning that ``torch.tensor``
            # emits when given an existing tensor; the stacked result is identical.
            ys = torch.stack([torch.as_tensor(item[L - 3]).long() for item in batch])
            gt = [item[L - 2] for item in batch]
            preimage = [item[L - 1] for item in batch]
            return (columns, ys, gt, preimage)
        columns = [torch.stack([item[i] for item in batch]) for i in range(L - 1)]
        ys = torch.stack([torch.as_tensor(item[L - 1]).long() for item in batch])
        return (columns, ys)

    def update_proofs_samplewise(self, pre_image):
        """Return a new dataset sharing this one's data but with ``pre_image`` attached.

        ``pre_image`` accepts the same forms as the constructor's ``preimage``.
        """
        dataset = MIPLL_Dataset(
            self.x,
            self.y,
            self.s,
            self.l,
            self.weak_transform,
            preimage=pre_image,
        )
        return dataset

#TODO: consider deleting the Gold_Dataset class
class Gold_Dataset(Dataset):
    """Plain ``(input, label)`` dataset that transforms each input on access.

    Args:
        x: Indexable collection of raw inputs; its length is the dataset length.
        y: Indexable collection of labels aligned with ``x``; each is cast to ``int``.
        transform: Callable applied to every input when it is fetched.
    """

    def __init__(self, x, y, transform):
        self.data, self.targets, self.transform = x, y, transform

    def __len__(self):
        # Length follows the inputs; ``targets`` is assumed (not checked) to match.
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(transform(data[index]), int(targets[index]))``."""
        sample = self.transform(self.data[index])
        label = int(self.targets[index])
        return sample, label