import json
from torch.utils.data import Dataset
import monai.transforms as mt
from copy import deepcopy




class LabelledDS(Dataset):
    """Image/label dataset whose samples are listed in a JSON split file.

    The JSON file is expected to contain a ``'training'`` list (read when
    ``stage == 'train'``) and a ``'test'`` list (read when ``stage == 'val'``).
    Each entry is a dict whose ``'image'`` and ``'label'`` values are handed to
    ``monai.transforms.LoadImaged``. These are presumably file paths; confirm
    against the split files.

    Args:
        json_file: Path of the JSON split file. A missing file is reported on
            stdout and produces an empty dataset.
        image_size: Spatial size used for the random training crop and for the
            final pad/crop applied to every returned volume.
        stage: ``'train'`` or ``'val'``. Any other value is reported and
            produces an empty dataset.
        is_augmentation: Stored on the instance; not read by this class.

    ``__getitem__`` returns ``(image, label, image_aug1, image_aug2)`` in the
    ``'train'`` stage and ``(image, label)`` otherwise.
    """

    def __init__(self, json_file='', image_size=(80, 112, 112), stage='train', is_augmentation=False):
        super(LabelledDS, self).__init__()
        self.image_size = image_size
        self.stage = stage
        self.is_augmentation = is_augmentation
        self.augmentation = self.aug_transform()
        self.post_transform = self.init_post_transform()
        # Only training uses the randomised (flip + random crop) pre-transform.
        # 'val' and unknown stages share the deterministic validation pipeline.
        if stage == 'train':
            self.pre_transform = self.init_pre_transform()
        else:
            self.pre_transform = self.pre_transform_val()
        # Always assigned, so a missing file still leaves a usable (empty)
        # dataset rather than an object without a ``data`` attribute.
        self.data = self._read_split(json_file, stage)

    @staticmethod
    def _read_split(json_file, stage):
        """Return the list of sample dicts for ``stage`` stored in ``json_file``.

        A missing file is reported on stdout and treated as an empty split.
        An unknown ``stage`` is reported and also yields an empty list.

        Raises:
            json.JSONDecodeError: if the file exists but is not valid JSON.
            KeyError: if the file lacks the list required by ``stage``
                (``'training'`` for ``'train'``, ``'test'`` for ``'val'``).
        """
        try:
            with open(json_file, "r", encoding="utf-8") as f:
                data = json.load(f)
        except FileNotFoundError:
            # Deliberate best effort: report and continue with no samples.
            print(f"File {json_file} not found.")
            return []
        if stage == 'train':
            return data['training']
        if stage == 'val':
            return data['test']
        print('Invalid stage. Defaulting to unlabelled')
        return []

    def __getitem__(self, item):
        """Load, transform and return the sample at index ``item``.

        In the ``'train'`` stage, two deep copies of the pre-transformed sample
        are augmented independently. Each copy, and the un-augmented sample,
        then goes through ``post_transform``.
        """
        batch = self.data[item]
        sample = self.pre_transform(batch)
        if self.stage == 'train':
            sampleA1, sampleA2 = deepcopy(sample), deepcopy(sample)
            sampleA1, sampleA2 = self.augmentation(sampleA1), self.augmentation(sampleA2)
            sample = self.post_transform(sample)
            sampleA1 = self.post_transform(sampleA1)
            sampleA2 = self.post_transform(sampleA2)
            return sample['image'], sample['label'], sampleA1['image'], sampleA2['image']
        else:
            sample = self.post_transform(sample)
            return sample['image'], sample['label']

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.data)

    def init_pre_transform(self):
        """Build the training pipeline: load, channel-first, foreground crop,
        random flips on each spatial axis, then a random crop to ``image_size``."""
        return mt.Compose([
            mt.LoadImaged(keys=("image", "label"), image_only=False),
            mt.EnsureChannelFirstd(keys=("image", "label")),
            mt.CropForegroundd(['image', 'label'], source_key='image', margin=0, allow_smaller=False),
            mt.RandFlipd(keys=("image", "label"), spatial_axis=[0], prob=0.50),
            mt.RandFlipd(keys=("image", "label"), spatial_axis=[1], prob=0.50),
            mt.RandFlipd(keys=("image", "label"), spatial_axis=[2], prob=0.50),
            mt.RandSpatialCropd(keys=("image", "label"), roi_size=self.image_size, random_size=False),
        ])

    def pre_transform_val(self):
        """Build the deterministic pipeline: load, channel-first, foreground crop."""
        # NOTE(review): LoadImaged/EnsureChannelFirstd tolerate a missing
        # 'label', but CropForegroundd is not given allow_missing_keys and
        # __getitem__ reads sample['label'], so label-less entries still fail.
        return mt.Compose([
            mt.LoadImaged(keys=("image", "label"), allow_missing_keys=True, image_only=False),
            mt.EnsureChannelFirstd(keys=("image", "label"), allow_missing_keys=True),
            mt.CropForegroundd(['image', 'label'], source_key='image', margin=0, allow_smaller=False),
        ])

    def aug_transform(self):
        """Build the image-only intensity augmentation used for the two views."""
        return mt.Compose([
            mt.RandShiftIntensityd(keys="image", offsets=0.10, prob=0.20),
            mt.RandGibbsNoised('image', prob=0.1),
            mt.RandScaleIntensityd('image', 0.1, 0.2),  # Reduced range for scaling
            mt.RandAdjustContrastd('image', gamma=(0.8, 1.2), prob=0.2),
            mt.RandGaussianNoised('image', prob=0.2, std=0.2),
        ])

    def init_post_transform(self):
        """Build the final pipeline: pad/crop to ``image_size``, percentile-scale
        the image into [0, 1] (clipped), then convert to tensors."""
        return mt.Compose([
            mt.ResizeWithPadOrCropd(keys=("image", "label"), spatial_size=self.image_size, allow_missing_keys=True),
            mt.ScaleIntensityRangePercentilesd(keys="image", lower=0.5, upper=99.5, b_min=0.0, b_max=1.0, clip=True),
            mt.ToTensor()
        ])



class UnlabelledDS(Dataset):
    """Image-only dataset read from the ``'unlabelled'`` list of a JSON file.

    Only the ``'image'`` field of each JSON entry is kept. Items are returned as
    ``(image, '_', image_aug1, image_aug2)`` when ``stage == 'train'`` and as
    ``(image, '_')`` otherwise. The string ``'_'`` occupies the label slot.

    Args:
        json_file: Path of the JSON file. It must contain an ``'unlabelled'``
            list; open/parse/key errors propagate to the caller.
        image_size: Spatial size for the random crop and the final pad/crop.
        stage: ``'train'`` enables the two augmented views.
        is_augmentation: Stored on the instance; not read by this class.
    """

    def __init__(self, json_file='', image_size=(80, 112, 112), stage='train', is_augmentation=False):
        super().__init__()
        self.image_size = image_size
        self.stage = stage
        self.is_augmentation = is_augmentation

        with open(json_file, "r") as handle:
            entries = json.load(handle)['unlabelled']
        # Strip every entry down to its 'image' field.
        self.data = [{'image': entry['image']} for entry in entries]

        # NOTE: this rebinds ``pre_transform`` on the instance, so from here on
        # ``self.pre_transform`` is the Compose pipeline rather than the builder.
        self.pre_transform = self.pre_transform()
        self.augmentation = self.aug_transform()
        self.post_transform = self.init_post_transform()

    def __getitem__(self, item):
        """Return the transformed image (and, when training, two augmented views)."""
        sample = self.pre_transform(self.data[item])
        if self.stage != 'train':
            return self.post_transform(sample)['image'], '_'

        # Augment two independent deep copies, first view before second.
        first_view, second_view = deepcopy(sample), deepcopy(sample)
        first_view = self.augmentation(first_view)
        second_view = self.augmentation(second_view)

        clean = self.post_transform(sample)
        first_view = self.post_transform(first_view)
        second_view = self.post_transform(second_view)
        return clean['image'], '_', first_view['image'], second_view['image']

    def __len__(self):
        """Return the number of unlabelled entries."""
        return len(self.data)

    def pre_transform(self):
        """Build the pipeline: load, channel-first, a random flip per spatial axis,
        an occasional random affine, then a random crop to ``image_size``."""
        key = "image"
        steps = [mt.LoadImaged(keys=key), mt.EnsureChannelFirstd(keys=key)]
        steps.extend(mt.RandFlipd(keys=key, spatial_axis=[axis], prob=0.50) for axis in range(3))
        steps.append(mt.RandAffined(keys=key, prob=0.2, shear_range=(0.0, 0.15), rotate_range=(0., 0.15),
                                    mode="bilinear", padding_mode="zeros"))
        steps.append(mt.RandSpatialCropd(keys=key, roi_size=self.image_size, random_size=False))
        return mt.Compose(steps)

    def aug_transform(self):
        """Build the intensity-only augmentation applied to each view."""
        key = 'image'
        return mt.Compose([
            mt.RandShiftIntensityd(keys=key, offsets=0.10, prob=0.20),
            mt.RandGibbsNoised(keys=key, prob=0.1),
            mt.RandScaleIntensityd(keys=key, factors=0.1, prob=0.2),  # reduced scaling range
            mt.RandAdjustContrastd(keys=key, gamma=(0.8, 1.2), prob=0.2),
            mt.RandGaussianNoised(keys=key, prob=0.2, std=0.2),
        ])

    def init_post_transform(self):
        """Build the final pipeline: pad/crop to ``image_size``, percentile-scale
        intensities into [0, 1] (clipped), then convert to tensors."""
        key = "image"
        resize = mt.ResizeWithPadOrCropd(keys=key, spatial_size=self.image_size)
        rescale = mt.ScaleIntensityRangePercentilesd(keys=key, lower=0.5, upper=99.5,
                                                     b_min=0.0, b_max=1.0, clip=True)
        return mt.Compose([resize, rescale, mt.ToTensor()])

