from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
import torch
import torch.nn.functional as F

def to_synthetic(target, with_contrast=True, size=(512, 512), *, midpoint=0.5, steepness=20.0):
    """Resize a 2-D map, normalize it by its maximum, and optionally stretch its contrast.

    The map is bilinearly resized to ``size`` (``align_corners=True``) and
    divided by its maximum. With ``with_contrast`` the result also goes through
    a steep sigmoid centred on ``midpoint``, followed by min-max rescaling to
    [0, 1]. This turns a smooth map into a near-binary one.

    Every operation is out of place. The caller's tensor is never modified, and
    the function can be back-propagated through when ``target`` requires grad.

    Args:
        target: 2-D tensor ``(H, W)``. Two singleton dims are added for
            ``F.interpolate`` and then removed again. A floating-point dtype is
            expected, because bilinear interpolation may not accept integers.
        with_contrast: Apply the sigmoid contrast stretch.
        size: Output ``(height, width)``.
        midpoint: Value (after max-normalization) placed at the sigmoid's centre.
        steepness: Sigmoid slope. Larger values give a sharper threshold.

    Returns:
        A tensor of shape ``size``. A map whose maximum is 0 (e.g. all zeros) is
        not divided, and a contrast-stretched map with zero range is not
        rescaled. Both cases therefore return zeros instead of NaNs.
    """
    resized = F.interpolate(
        target.unsqueeze(0).unsqueeze(0), size=size, mode='bilinear', align_corners=True
    )[0, 0]

    peak = resized.max()
    # Skip the division for an all-zero map: 0/0 would turn every pixel into NaN.
    if peak != 0:
        resized = resized / peak

    if not with_contrast:
        return resized

    contrasted = torch.sigmoid((resized - midpoint) * steepness)
    contrasted = contrasted - contrasted.min()
    span = contrasted.max()
    # A constant map has zero range after shifting; leave it at zero instead of NaN.
    if span != 0:
        contrasted = contrasted / span
    return contrasted

class Synthetic2Real(Dataset):
    """Dataset that builds aligned pairs of maps from raw items on the fly.

    For each item, a plain map is produced by resizing the item with
    ``to_synthetic``'s default ``size`` and normalizing it by its maximum. A
    contrast-stretched copy of that map is made as well. Both go through one
    shared random affine transform, so they stay spatially aligned.
    NOTE(review): judging by the class name, the contrast-stretched map is
    presumably the "synthetic" input and the plain map the "real" target.
    Confirm against the training code.
    """

    def __init__(self, data):
        """
        Args:
            data (list or array): Indexable collection of 2-D maps. Each
                ``data[idx]`` must be convertible with ``torch.as_tensor``.
        """
        self.data = data
        # Random rotation in [-180, 180] degrees, no translation, scale in [0.8, 1.2].
        self.train_transform = transforms.Compose(
            [
                transforms.RandomAffine((-180, 180), (0, 0), (0.8, 1.2), interpolation=InterpolationMode.BILINEAR),
            ]
        )

    def test_transform(self):
        """Return a ``ToTensor`` pipeline (not used by this class's ``__getitem__``)."""
        return transforms.Compose([
            transforms.ToTensor()
        ])

    def __len__(self):
        """Return the number of raw items in ``data``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(contrasted, plain)`` for item ``idx``, warped by the same random affine.

        ``plain`` is the item resized and max-normalized by ``to_synthetic``.
        ``contrasted`` is ``plain`` after ``to_synthetic``'s sigmoid contrast
        stretch.
        """
        with torch.no_grad():
            # as_tensor avoids the copy-construct warning when the item is already a tensor.
            x = torch.as_tensor(self.data[idx])
            # Bilinear interpolation and the division inside to_synthetic need a
            # floating-point tensor; integer images (e.g. uint8) would fail there.
            if not x.is_floating_point():
                x = x.float()
            x = to_synthetic(x, with_contrast=False)
            y = to_synthetic(x, with_contrast=True)

        # Stack into a 2-channel image so a single random affine is applied to both maps.
        sample = self.train_transform(torch.stack([x, y], 0))

        return sample[1], sample[0]