import sys
import numpy as np
import torch
from skimage import transform as imtf
from skimage import exposure as imexp
from skimage import filters as imfilt

from training_utils.helpers import _to_repeated_list, pad_if_needed, fft2_np, ifft2_np, split_coils, stack_coils

sys.path.append('external/fastMRI')
from data import transforms as T
from common.subsample import create_mask_for_mask_type

class DataTransform:
    """
    Data Transformer for training Var Net models with data augmentation.
    """

    def __init__(self, 
                 mode, 
                 challenge, 
                 augmentor,
                 mask_func=None, 
                 use_seed=True, 
                 ):
        """
        Args:
            mask_func (common.subsample.MaskFunc): A function that can create a mask of
                appropriate shape.
            resolution (int): Resolution of the image.
            use_seed (bool): If true, this class computes a pseudo random number generator seed
                from the filename. This ensures that the same mask is used for all the slices of
                a given volume every time.
        """
        self.mask_func = mask_func
        self.use_seed = use_seed
        self.mode = mode
        self.challenge = challenge
        self.augmentor = augmentor
        
    def __call__(self, kspace, mask, target, attrs, fname, slice):
        """
        Args:
            kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
                data or (rows, cols, 2) for single coil data.
            mask (numpy.array): Mask from the test dataset
            target (numpy.array): Target image
            attrs (dict): Acquisition related information stored in the HDF5 object.
            fname (str): File name
            slice (int): Serial number of the slice.
        Returns:
            (tuple): tuple containing:
                masked_kspace (torch.Tensor): Masked k-space
                mask (torch.Tensor): Mask
                target (torch.Tensor): Target image converted to a torch Tensor.
                fname (str): File name
                slice (int): Serial number of the slice.
                max_value (numpy.array): Maximum value in the image volume
        """
        # Setting calibration lines to zero
        acq_start = attrs['padding_left']
        acq_end = attrs['padding_right']
        max_value = attrs['max']

        if self.mode == 'train':                
            kspace, target, desc = self.augmentor(kspace, target)

        kspace = T.to_tensor(kspace)

        if self.challenge == 'singlecoil':
            kspace = kspace.unsqueeze(axis=0)
            
        if target is not None:
            target = T.to_tensor(target)
        else:
            target = torch.tensor(0)
            max_value = 0.0
                
        seed = None if not self.use_seed else tuple(map(ord, fname))            

        if self.mask_func:
            masked_kspace, mask = T.apply_mask(
                kspace, self.mask_func, seed)
        else:
            masked_kspace = kspace
            shape = np.array(kspace.shape)
            num_cols = shape[-2]
            shape[:-3] = 1
            mask_shape = [1 for _ in shape]
            mask_shape[-2] = num_cols
            mask = torch.from_numpy(mask.reshape(
                *mask_shape).astype(np.float32))
            mask[:, :, :acq_start] = 0
            mask[:, :, acq_end:] = 0

        return masked_kspace, mask.byte(), target, fname, slice, max_value
        
class AugmentationTransform:
    
    def __init__(self):
        self.order = None
        self.tag = None

    def __call__(self, im, tags_to_apply, new_rand=True):
        if new_rand:
            self.generate_new_random()
        if self.tag in tags_to_apply:
            images = split_coils(im)
            augmented_images = []
            for image in images:
                aug_im, desc = self.augment_function(image)
                augmented_images.append(aug_im)
            return stack_coils(augmented_images), desc
        else:
            return im, {}

    
    def augment_fuction(self, im):
        raise NotImplementedError
    
    def generate_new_random(self):
        pass

        
# Pixel blitting        
class Translate(AugmentationTransform):
    
    def __init__(self, max_translation):
        super().__init__()
        
        self.max_translation = None
        self.set_max_translation(max_translation)
        
        self.current_translation = None
        self.generate_new_random()
        
        self.type = 'blit'
        self.tag = 'translate'
        
    def augment_function(self, im):
        scaled_t_x = int(self.current_translation[0] * im.shape[-2])
        scaled_t_y = int(self.current_translation[1] * im.shape[-1])
        translate_fn = imtf.AffineTransform(translation=(scaled_t_x, scaled_t_y))
        im_real = imtf.warp(np.real(im), inverse_map=translate_fn, mode='reflect')
        im_imag = imtf.warp(np.imag(im), inverse_map=translate_fn, mode='reflect')
        return (im_real + 1j * im_imag).astype(np.complex64), \
               {'translation': (self.current_translation[0], self.current_translation[1])}
    
    def generate_new_random(self):
        t_x = np.random.uniform(-self.max_translation[0], self.max_translation[0])
        t_y = np.random.uniform(-self.max_translation[1], self.max_translation[1])
        self.current_translation = (t_x, t_y)
                
    def set_max_translation(self, max_translation):
        max_translation = _to_repeated_list(max_translation, 2)
        assert len(max_translation) == 2
        assert max(max_translation) <= 1.0 and min(max_translation) >= 0.0
        self.max_translation = max_translation
        
class Rot90(AugmentationTransform):
    
    def __init__(self):
        super().__init__()
        
        self.current_rotation = None
        self.generate_new_random()
        
        self.type = 'blit'
        self.tag = 'rot90'           
            
    def augment_function(self, im):
        return np.rot90(im, k=self.current_rotation, axes=(-2, -1)), \
               {'rot90': self.current_rotation*90}
    
    def generate_new_random(self):
        self.current_rotation = np.random.randint(1, 4)
                
        
class FlipH(AugmentationTransform):
    
    def __init__(self):
        super().__init__()
        self.type = 'blit'
        self.tag = 'fliph'
          
    def augment_function(self, im):
        return np.flip(im, axis=-2), \
               {'fliph': True}
    
class FlipV(AugmentationTransform):
    
    def __init__(self):
        super().__init__()
        self.type = 'blit'
        self.tag = 'flipv'
       
    def augment_function(self, im):
        return np.flip(im, axis=-1), \
                {'flipv': True}
        
# Geometric transformations   
class Affine(AugmentationTransform):
    
    def __init__(self, max_translation,
                       max_rotation,
                       max_shear, 
                       max_scale,  
                       isotropic_scaling=False,
                       order=1):
        super().__init__()
                 
        self.set_max(max_translation, max_rotation, max_shear, max_scale)
        self.isotropic_only = isotropic_scaling
        
        self.order = order
        
        self.current_params = {}
        self.generate_new_random()
        
        self.type = 'geometric'
        self.tag = ['translation', 'rotation','shearing', 'scaling']
        
        
    def __call__(self, im, tags_to_apply, new_rand=True):
        if new_rand:
            self.generate_new_random()
        
        for tag in self.tag:
            if not tag in tags_to_apply: 
                self.current_params[tag] = None
                
        images = split_coils(im)
        augmented_images = []
        for image in images:
            aug_im, desc = self.augment_function(image)
            augmented_images.append(aug_im)
            
        return stack_coils(augmented_images), desc

        
    def augment_function(self, im):
        descriptor = {}
        if self.current_params['translation'] is not None:
            scaled_t_x = int(self.current_params['translation'][1] * im.shape[-2]) # Image axes are flipped in skimage!
            scaled_t_y = int(self.current_params['translation'][0] * im.shape[-1])
            translation = (scaled_t_x, scaled_t_y)
            descriptor['translation'] = (self.current_params['translation'][1], self.current_params['translation'][0])
        else:
            translation = None
            
        if self.current_params['shearing'] is not None:
            shear = self.current_params['shearing']/180. * np.pi
            descriptor['shearing'] = self.current_params['shearing']
        else:
            shear = None
            
        if self.current_params['rotation'] is not None:
            rotation = self.current_params['rotation']/180. * np.pi
            descriptor['rotation'] = self.current_params['rotation']
        else:
            rotation = None
            
        if self.current_params['scaling'] is not None:
            scaling = self.current_params['scaling']
            descriptor['scaling'] = scaling
        else:
            scaling = None
        
        if len(descriptor) == 0:
            return im, {}
        
        im_center = (im.shape[-2]//2, im.shape[-1]//2)
        centering_fn = imtf.AffineTransform(translation=(-im_center[1], -im_center[0])) # Image axes are flipped in skimage!
        centering_fn_inv = imtf.AffineTransform(translation=(im_center[1], im_center[0]))

        affine_fn = imtf.AffineTransform(translation=translation,
                                         rotation=rotation,
                                         shear=shear,
                                         scale=scaling)
        
        inv_mapping = (centering_fn+ (affine_fn+centering_fn_inv)).inverse
        im_real = imtf.warp(np.real(im), 
                            inverse_map=inv_mapping, 
                            mode='reflect',
                            order=self.order)
        im_imag = imtf.warp(np.imag(im), 
                            inverse_map=inv_mapping,
                            mode='reflect',
                            order=self.order)
        
        return (im_real + 1j * im_imag).astype(np.complex64), descriptor
    
    def generate_new_random(self):
        self.current_params['rotation'] = np.random.uniform(low=-self.max_rotation, high=self.max_rotation)
        
        self.current_params['shearing'] = np.random.uniform(-self.max_shear, self.max_shear)
        
        if self.isotropic_only:
            self.current_iso = True
        else:
            self.current_iso = bool(np.random.randint(0, 2))
            
        scale_x = np.random.uniform(1-self.max_scale[0], 1 + self.max_scale[1])
        if self.current_iso:
            scale_y = scale_x
        else: 
            scale_y = np.random.uniform(1-self.max_scale[0], 1 + self.max_scale[1])
            
        self.current_params['scaling'] = (scale_x, scale_y)
        
        t_x = np.random.uniform(-self.max_translation[0], self.max_translation[0])
        t_y = np.random.uniform(-self.max_translation[1], self.max_translation[1])
        self.current_params['translation'] = (t_x, t_y)

    def set_max(self, max_translation, max_rotation, max_shear, max_scale):
        self.max_translation =  _to_repeated_list(max_translation, 2)
        self.max_rotation = max_rotation
        self.max_shear = max_shear
        self.max_scale =  _to_repeated_list(max_scale, 2)

class Rotate(AugmentationTransform):
    
    def __init__(self, max_rotation, order=1):
        super().__init__()
        
        self.max_rotation = None
        self.set_max_rotation(max_rotation)
        
        self.order = order
        
        self.current_rotation = None
        self.generate_new_random()
        
        self.type = 'geometric'
        self.tag = 'rotation'
             
    def augment_function(self, im):
        im_real = imtf.rotate(np.real(im), angle=self.current_rotation, mode='reflect', order=self.order)
        im_imag = imtf.rotate(np.imag(im), angle=self.current_rotation, mode='reflect', order=self.order)
        return (im_real + 1j * im_imag).astype(np.complex64), \
                {'rotation': self.current_rotation}
    
    def generate_new_random(self):
        self.current_rotation = np.random.uniform(low=-self.max_rotation, high=self.max_rotation)
                
    def set_max_rotation(self, max_rotation):
        self.max_rotation = max_rotation


class Shear(AugmentationTransform):
    
    def __init__(self, max_shear, order=1):
        super().__init__()
        
        self.max_shear = None
        self.set_max_shear(max_shear)
        
        self.order= order
        
        self.current_shear = None
        self.generate_new_random()
        
        self.type = 'geometric'
        self.tag = 'shearing'
       
    def augment_function(self, im):
        angle_in_rad = self.current_shear / 180. * np.pi
        shear_fn = imtf.AffineTransform(shear=angle_in_rad)
        im_real = imtf.warp(np.real(im), inverse_map=shear_fn, mode='reflect', order=self.order)
        im_imag = imtf.warp(np.imag(im), inverse_map=shear_fn, mode='reflect', order=self.order)
        return (im_real + 1j * im_imag).astype(np.complex64), \
                {'shearing': self.current_shear}
    
    def generate_new_random(self):
        self.current_shear = np.random.uniform(-self.max_shear, self.max_shear)
                
    def set_max_shear(self, max_shear):
        self.max_shear = max_shear
        
        
class Zoom(AugmentationTransform):
    
    def __init__(self, max_zoom_in, max_zoom_out, isotropic_only=False, order=1):
        super().__init__()
        
        self.isotropic_only = isotropic_only

        self.max_zoom_in = None
        self.max_zoom_out = None
        
        self.order=order

        self.set_max_zoom(max_zoom_in, max_zoom_out)
        
        self.current_scaling = None
        self.current_iso = None
        self.generate_new_random()
        
        self.type = 'geometric'
        self.tag = 'scaling'
        
        
    def augment_function(self, im):
        resized_shape = (int(im.shape[-2]*self.current_scaling[0]),
                         int(im.shape[-1]*self.current_scaling[1]))
        
        im_real = imtf.resize(np.real(im), resized_shape, order=self.order)
        im_real = pad_if_needed(im_real, min_shape=im.shape[-2:], mode='reflect')
        im_real = T.center_crop(im_real, im.shape[-2:])

        im_imag = imtf.resize(np.imag(im), resized_shape, order=self.order)
        im_imag = pad_if_needed(im_imag, min_shape=im.shape[-2:], mode='reflect')
        im_imag = T.center_crop(im_imag, im.shape[-2:])

        return (im_real + 1j * im_imag).astype(np.complex64), \
                {'scaling': (self.current_scaling[0], self.current_scaling[1])}
    
    def generate_new_random(self):
        if self.isotropic_only:
            self.current_iso = True
        else:
            self.current_iso = bool(np.random.randint(0, 2))
            
        scale_x = np.random.uniform(1-self.max_zoom_in, 1 + self.max_zoom_out)
        if self.current_iso:
            scale_y = scale_x
        else: 
            scale_y = np.random.uniform(1-self.max_zoom_in, 1 + self.max_zoom_out)
            
        self.current_scaling = (scale_x, scale_y)
                
    def set_max_zoom(self, max_zoom_in, max_zoom_out):
        assert max(max_zoom_in, max_zoom_out) <= 1.0 and min(max_zoom_in, max_zoom_out) >= 0.0
        self.max_zoom_in = max_zoom_in
        self.max_zoom_out = max_zoom_out

        
        
class AugmentationPipeline:
    
    def __init__(self, augmentations=None,
                 weight_dict=None,
                 gating_probability=1.0,
                 interpolation_order=1,
                 upsample_augment=False, 
                 upsample_factor=2, 
                 upsample_order=1,
                 augment_target=False):
        self.initialize_augmentations(augmentations, weight_dict, interpolation_order)
        self.upsample_augment = upsample_augment
        self.upsample_factor = upsample_factor
        self.upsample_order = upsample_order
        self.augmentation_strength = None
        self.gating_probability = gating_probability
        self.augment_target = augment_target
        
    def initialize_augmentations(self, augmentations,weight_dict, order):
        self.augmentations = augmentations
        self.augmentation_tags = []
        
        # Set transformation interpolation mode and gather all transformation names
        for a in self.augmentations:
            a.order = order
            if isinstance(a.tag, list):
                self.augmentation_tags.extend(a.tag)
            else:
                self.augmentation_tags.append(a.tag)
                
        self.weight_dict = weight_dict
        if self.weight_dict is None:
            self.weight_dict = {}
            
        # If a transformation doesn't have a specified weight it will be set to 1
        self.augmentation_tags = list(set(self.augmentation_tags))
        for tag in self.augmentation_tags:
            if tag not in self.weight_dict:
                self.weight_dict[tag] = 1.0
            
        
    def set_augmentation_strength(self, augmentation_strength):             
        self.augmentation_strength = augmentation_strength
        
    def augment_image(self, im, output_size=None):
        def has_common_element(a, b):
            return len(list(set(a) & set(b))) > 0
            
        gating = float(np.random.uniform(0., 1.) < self.gating_probability)
        
        if output_size is None:
            output_size = im.shape[-2:]
        else:
            output_size = _to_repeated_list(output_size, 2)
        if self.augmentations is None:
            print('Uninitialized augmentation pipeline.')
            return im

        # Randomly pick augmentations to be applied
        to_be_applied=[]
        for tag in self.augmentation_tags:
            probability = self.augmentation_strength * self.weight_dict[tag] * gating
            if np.random.uniform(0., 1.) < probability:
                to_be_applied.append(tag)

        # Upsample only if there is a transformation that needs interpolation
        need_upsample = self.upsample_augment and has_common_element(['rotation', 'shearing', 'scaling', 'translation'], to_be_applied)
        if need_upsample:
            im = self.upsample_image(im, factor=self.upsample_factor)

        augmentation_descriptor = []

        # Apply transformations that are in the list
        for t in self.augmentations:
            im, d = t(im, tags_to_apply=to_be_applied)
            if len(d) > 0:
                augmentation_descriptor.append(d)

        if need_upsample:
            im = self.downsample_image(im, factor=self.upsample_factor)
            
        im = pad_if_needed(im, min_shape=output_size, mode='reflect')
        im = T.center_crop(im, output_size)
        
        return im, augmentation_descriptor
    
    def augment_from_kspace(self, kspace, target_size, train_size=None):
        if train_size is None:
            train_size = kspace.shape[-2:]
        else: 
            train_size = _to_repeated_list(train_size, 2)
        target_size = _to_repeated_list(target_size, 2)
        
        im = ifft2_np(kspace)
        if self.augment_target:
            im = self.im_to_target(im)
            
        im, augmentation_descriptor = self.augment_image(im, output_size=train_size)
        
        target = self.im_to_target(im, target_size)
        
        kspace = fft2_np(im)
        return kspace, target, augmentation_descriptor
    
    
    def im_to_target(self, im, target_size=None):
        if target_size is None:
            target_size = [im.shape[-2], im.shape[-1]]
            
        if len(im.shape) == 2:
            target = np.abs(T.center_crop(im, target_size)).astype(np.float32)
        else:
            assert len(im.shape) == 3
            target = np.sqrt(np.sum(np.square(np.abs(im)), axis=0))
            target = T.center_crop(target, target_size).astype(np.float32)
        return target        
    
    def upsample_image(self, im, factor=2):
        upsampled_size = (im.shape[-2]*factor, im.shape[-1]*factor)
        images = split_coils(im)  
        upsampled_images = []
        for image in images:
            upsampled_images.append( 
                imtf.resize(np.real(image), 
                             output_shape=upsampled_size, 
                             order=self.upsample_order, 
                             mode='reflect') + \
            1j * imtf.resize(np.imag(image), 
                             output_shape=upsampled_size, 
                             order=self.upsample_order, 
                             mode='reflect') 
            )
        upsampled_images = stack_coils(upsampled_images)
        return upsampled_images
    
    def downsample_image(self, im, factor=2):
        downsampled_size = (im.shape[-2]//factor, im.shape[-1]//factor)
        images = split_coils(im)  
        downsampled_images = []
        for image in images:
            downsampled_images.append(  imtf.resize(np.real(image), 
                             output_shape=downsampled_size, 
                             order=self.upsample_order, 
                             mode='reflect',
                             anti_aliasing=True) + \
            1j * imtf.resize(np.imag(image), 
                             output_shape=downsampled_size, 
                             order=self.upsample_order, 
                             mode='reflect',
                             anti_aliasing=True)
                                     )
        downsampled_images = stack_coils(downsampled_images)
        return downsampled_images
    
class DataAugmentor:
    
    def __init__(self, hparams, current_epoch_fn):
        self.current_epoch_fn = current_epoch_fn
        self.hparams = hparams
        self.aug_on = hparams.aug_on
        if self.aug_on:
            self.augmentation_pipeline = self.create_augmentation_pipeline(hparams)
        if hparams.train_resolution is not None:
            self.hparams.train_resolution = _to_repeated_list(hparams.train_resolution, 2)
        self.hparams.resolution = _to_repeated_list(hparams.resolution, 2)
        
    def __call__(self, kspace, target):
        
        desc = []

        # Set augmentation probability
        if self.aug_on:
            current_epoch = self.current_epoch_fn()
            p = schedule_p(self.hparams, current_epoch)
            self.augmentation_pipeline.set_augmentation_strength(p)
        else:
            p = 0.0
        
        # Augment if needed
        if self.aug_on and p > 0.0:
                kspace, target, desc = self.augmentation_pipeline.augment_from_kspace(kspace,
                                                                          self.hparams.resolution,
                                                                          self.hparams.train_resolution)
        else:
            if self.hparams.train_resolution is not None:
                # Padding and center cropping to training size
                im = ifft2_np(kspace)
                im = pad_if_needed(im, self.hparams.train_resolution, 'reflect')                
                im = T.center_crop(im, self.hparams.train_resolution)
                kspace = fft2_np(im)
                target = T.center_crop(target, self.hparams.resolution)
        return kspace, target, desc
            

    def create_augmentation_pipeline(self, hparams):
        if hparams.aug_on:
            weight_dict = {'translation': hparams.aug_weight_translation,
                          'rotation': hparams.aug_weight_rotation,
                          'scaling': hparams.aug_weight_scaling,
                          'shearing': hparams.aug_weight_shearing,
                          'rot90': hparams.aug_weight_rot90,
                          'fliph': hparams.aug_weight_fliph,
                          'flipv': hparams.aug_weight_flipv}
            
            augmentations = [FlipH(),
                             FlipV(),
                             Rot90(),
                             Affine(max_translation=(hparams.aug_max_translation_x, hparams.aug_max_translation_y),
                                   max_rotation=hparams.aug_max_rotation,
                                   max_shear=hparams.aug_max_shearing,
                                   max_scale=(hparams.aug_max_scaling_x, hparams.aug_max_scaling_y))]
            
            augmentation_pipeline = AugmentationPipeline(augmentations=augmentations, 
                                                         weight_dict=weight_dict,
                                                         gating_probability=hparams.aug_gating, 
                                                         upsample_augment=hparams.aug_upsample, 
                                                         upsample_factor=hparams.aug_upsample_factor, 
                                                         upsample_order=hparams.aug_upsample_order,
                                                         interpolation_order=hparams.aug_interpolation_order,
                                                         augment_target=hparams.aug_target)
            return augmentation_pipeline
        else:
            return None
        
    
def schedule_p(hparams, epoch):
    D = hparams.aug_delay
    T = hparams.num_epochs
    t = epoch
    p_max = hparams.aug_strength

    if epoch < D:
        return 0.0
    else:
        if hparams.aug_schedule == 'constant':
            p = p_max
        elif hparams.aug_schedule == 'ramp':
            p = (t-D)/(T-D) * p_max
        elif hparams.aug_schedule == 'exp':
            c = hparams.aug_exp_decay/(T-D) # Decay coefficient
            p = p_max/(1-np.exp(-(T-D)*c))*(1-np.exp(-(t-D)*c))
        return p

def create_data_transform(hparams, mode, current_epoch_fn=None):
    is_train = (mode == 'train')
    mask = create_mask_for_mask_type(hparams.mask_type, hparams.center_fractions,
                                     hparams.accelerations)
    return DataTransform(mode=mode,
                        challenge=hparams.challenge,
                        mask_func=mask, 
                        use_seed=(not is_train), 
                        augmentor=DataAugmentor(hparams, current_epoch_fn))
        
def add_augmentation_specific_args(parser):
    parser.add_argument('--aug-on', default=False,
                    help='This switch turns data augmentation on.',action='store_true')
    parser.add_argument('--aug-target', default=False,
                    help='If set, augmentation will be applied to real target images instead of complex fully-sampled images. Seriously degrades reconstruction quality, only set it for experimentation.',action='store_true')
    # --------------------------------------------
    # Related to augmentation strenght scheduling
    # --------------------------------------------
    parser.add_argument('--aug-schedule', type=str, default='ramp',
                        help='Type of data augmentation strength scheduling. Options: constant, ramp, exp')
    parser.add_argument('--aug-delay', type=int, default=0,
                        help='Number of epochs at the beginning of training without data augmentation. The schedule in --aug-schedule will be adjusted so that at the last epoch the augmentation strength is --aug-strength.')
    parser.add_argument('--aug-strength', type=float, default=0.0, 
                        help='Augmentation strength, combined with --aug-schedule determines '
                                          'the augmentation strength in each epoch')
    parser.add_argument('--aug-gating', type=float, default=1.0, 
                        help='Probability that augmentation is applied to an image. First the augmentation is checked against this probability, then the augmentation transform is generated according to augmentation strength if gating is passed.')
    parser.add_argument('--aug-exp-decay', type=float, default=5.0, 
                        help='Exponential decay coefficient if --aug-schedule is set to exp. 1.0 is close to linear, 10.0 is close to step function')
    
    # --------------------------------------------
    # Related to interpolation 
    # --------------------------------------------
    parser.add_argument('--aug-interpolation-order', type=int, default=1,
                        help='Order of interpolation filter used in data augmentation, 1: bilinear, 3:bicubic')
    parser.add_argument('--aug-upsample', default=False,
                        help='Upsample before augmentation to improve quality of augmented images',action='store_true')
    parser.add_argument('--aug-upsample-factor', type=int, default=2,
                        help='Factor of upsampling before augmentation, if --aug-upsample is set')
    parser.add_argument('--aug-upsample-order', type=int, default=1,
                        help='Order of upsampling filter before augmentation, 1: bilinear, 3:bicubic')

    # --------------------------------------------
    # Related to transformation probability weights
    # --------------------------------------------
    parser.add_argument('--aug-weight-translation', type=float, default=1.0, 
                        help='Weight of translation probability. Augmentation probability will be multiplied by this constant')
    parser.add_argument('--aug-weight-rotation', type=float, default=1.0, 
                        help='Weight of arbitrary rotation probability. Augmentation probability will be multiplied by this constant')  
    parser.add_argument('--aug-weight-shearing', type=float, default=1.0, 
                        help='Weight of shearing probability. Augmentation probability will be multiplied by this constant')
    parser.add_argument('--aug-weight-scaling', type=float, default=1.0, 
                        help='Weight of scaling probability. Augmentation probability will be multiplied by this constant')
    parser.add_argument('--aug-weight-rot90', type=float, default=1.0, 
                        help='Weight of probability of rotation by multiples of 90 degrees. Augmentation probability will be multiplied by this constant')  
    parser.add_argument('--aug-weight-fliph', type=float, default=1.0, 
                        help='Weight of horizontal flip probability. Augmentation probability will be multiplied by this constant')
    parser.add_argument('--aug-weight-flipv', type=float, default=1.0, 
                        help='Weight of vertical flip probability. Augmentation probability will be multiplied by this constant') 
    
    # --------------------------------------------
    # Related to transformation limits
    # --------------------------------------------
    parser.add_argument('--aug-max-translation-x', type=float, default=0.125, 
                        help='Maximum translation applied along the x axis as fraction of image width')
    parser.add_argument('--aug-max-translation-y', type=float, default=0.125, 
                        help='Maximum translation applied along the y axis as fraction of image height')
    parser.add_argument('--aug-max-rotation', type=float, default=180., 
                        help='Maximum rotation applied in either clockwise or counter-clockwise direction in degrees.')
    parser.add_argument('--aug-max-shearing', type=float, default=15.0, 
                        help='Maximum shearing applied in either positive or negative direction in degrees.')
    parser.add_argument('--aug-max-scaling-x', type=float, default=0.25, 
                        help='Maximum scaling applied along x axis as fraction of image width. If set to s_x, a scaling factor between 1.0-s_x and 1.0+s_x will be applied.')
    parser.add_argument('--aug-max-scaling-y', type=float, default=0.25, 
                        help='Maximum scaling applied along y axis as fraction of image height. If set to s_y, a scaling factor between 1.0-s_y and 1.0+s_y will be applied.')
    
    return parser
