import torch
import numpy as np
import typing


###### corruption
class NoiseCorrupter(object):
    def __init__(self, mean:float=0, std:float=1, same_across_rgb:bool=False, 
                    noise_only:bool=False, source_save:bool=True, seed:typing.Union[None,int]=None):
        '''
        A class for corrupting vector data by adding gaussian noise.

        
        Arguments
        ---------

        - ```mean```: ```float``` (optional):
            Mean of the gaussian noise to add.
            Defaults to ```0```.
        
        - ```std```: ```float``` (optional):
            Standard deviation of the gaussian noise to add.
            Defaults to ```1```.

        - ```same_across_rgb```: ```bool``` (optional):
            For 3D inputs of shape ```(C, H, W)```, this dictates
            whether a single ```(H, W)``` noise map is drawn and
            added to every channel (```True```), or whether every
            element receives independent noise (```False```).
            This is not used when ```noise_only=True```.
            Defaults to ```False```.

        - ```noise_only```: ```bool``` (optional):
            This dictates whether the corrupted samples
            will have noise added or be replaced with noise.
            If ```True```, data will be replaced with noise.
            Defaults to ```False```.

        - ```source_save```: ```bool``` (optional):
            Whether to use the same corruption for 
            the same index, or whether to 
            randomise the corruption every time an image is streamed.
            Defaults to ```True```.
        
        - ```seed```: ```None | int``` (optional):
            Seed from which the internal random seed of this
            class is derived. If ```None```, fresh entropy is used.
            Defaults to ```None```.
        
        '''

        self.mean = mean
        self.std = std

        # default_rng(None) draws fresh OS entropy, so a single call
        # covers both the seeded and the unseeded case.
        rng = np.random.default_rng(seed)
        self.seed = int(rng.integers(low=1, high=1e9, size=1)[0])
        self.same_across_rgb = same_across_rgb
        self.noise_only = noise_only
        self.source_save = source_save

    def _sample_noise(self, shape, index):
        '''
        Draws gaussian noise of the given ```shape``` on the CPU.

        With ```source_save=True``` the torch generator is seeded with
        ```index + self.seed```, so a given index always receives the
        same noise. Otherwise ```self.seed``` is advanced one step
        through a numpy generator and used as the torch seed, so
        consecutive calls differ but remain reproducible from the
        initial seed.
        '''
        if self.source_save:
            g = torch.Generator().manual_seed(index + self.seed)
            return torch.zeros(shape).normal_(mean=self.mean, std=self.std, generator=g)

        self.seed = int(np.random.default_rng(self.seed).integers(low=1, high=1e9, size=1)[0])
        g = torch.Generator().manual_seed(self.seed)
        return torch.normal(mean=torch.tensor(self.mean).float(), 
                                std=torch.tensor(self.std).float(), 
                                size=shape, generator=g)

    @staticmethod
    def _match_device(noise, x):
        # Noise comes from a CPU generator; move it next to x so that
        # inputs living on an accelerator can be corrupted as well.
        # Non-tensor inputs are left to torch's operator handling.
        return noise.to(x.device) if isinstance(x, torch.Tensor) else noise
    
    def add_noise(self, x, source, index):
        '''
        Returns ```x``` plus element-wise independent gaussian noise
        of the same shape. ```source``` is ignored.
        '''
        noise = self._sample_noise(x.shape, index)
        return x + self._match_device(noise, x)
    
    def add_noise_rgb(self, x, source, index):
        '''
        Returns ```x``` (shape ```(C, H, W)```) plus a single
        ```(H, W)``` noise map broadcast over all channels.
        ```source``` is ignored.
        '''
        noise = self._sample_noise(x[0,:,:].shape, index)
        return x + self._match_device(noise, x)
    
    def __call__(self, x:torch.tensor, y, source, index:int, **kwargs):
        '''
        
        Arguments
        ---------

        - ```x```: ```torch.tensor```:
            The vector to be corrupted with noise.
        
        - ```y```: target
            This is the target of the input. This is ignored
            by the function and returned unchanged.
        
        - ```source```: ```hashable```:
            This is the source of the input. This is ignored
            by the function.
        
        - ```index```: ```int```:
            This is the index of the data point that is 
            being corrupted. The index is used to make
            sure that the same data point is corrupted 
            in the same way each time it is called.
            This is only used if ```source_save=True```.

        Returns
        -------

        - ```(x_out, y)```: the corrupted input and the unchanged target.

        Raises
        ------

        - ```NotImplementedError```: if ```noise_only=False``` and ```x```
            has more than 3 dimensions.
        
        '''

        if self.noise_only:
            x_out = self._match_device(self._sample_noise(x.shape, index), x)
            return x_out, y

        n_dims = len(x.shape)
        if n_dims <= 2 or (n_dims == 3 and not self.same_across_rgb):
            x_out = self.add_noise(x, source, index)
        elif n_dims == 3:
            x_out = self.add_noise_rgb(x, source, index)
        else:
            raise NotImplementedError('Please supply a 1D, 2D, or 3*2D x.')

        return x_out, y


class LabelRandom(object):
    def __init__(self, 
                    labels:int=10,
                    seed:int=None,
                    source_save:bool=False,
                    ):
        '''
        This randomly assigns a new label to a given input.
        
        Arguments
        ---------

        - ```labels```: ```int``` or ```list``` (optional):
            If ```int```, then all integers smaller
            than this value are possibly assigned labels.
            If ```list```, then these labels are used
            for the randomly assigned labels.
            Defaults to ```10```.
        
        - ```seed```: ```int```, optional:
            The random seed to set the random labels. 
            Defaults to ```None```.

        - ```source_save```: ```bool```, optional:
            This saves the random label mapping
            corruption by source, so that a given source
            maps the labels in the same way.
            Defaults to ```False```.
        
        '''

        # default_rng(None) uses fresh OS entropy, so both the seeded
        # and the unseeded case draw the working seed the same way.
        rng = np.random.default_rng(seed)
        self.seed = int(rng.integers(low=1, high=1e9, size=1)[0])

        # A single generator is used for the lifetime of the object.
        # It is deliberately not re-seeded per call: re-seeding from a
        # small seed range turns the label stream into a short cycle.
        self.rng = np.random.default_rng(self.seed)

        self.labels = labels
        self.source_save = source_save
        # source -> {original label -> randomly assigned label}
        self.source_save_dict = {}

    @staticmethod
    def _hashable(value):
        '''
        Converts tensors/arrays to plain Python values so they can be
        used as dictionary keys. ```torch.Tensor``` hashes by object
        identity, so without this two equal label tensors would be
        treated as different labels. Other values are returned as-is.
        '''
        if isinstance(value, torch.Tensor):
            return value.item() if value.numel() == 1 else tuple(value.reshape(-1).tolist())
        if isinstance(value, np.ndarray):
            return value.item() if value.size == 1 else tuple(value.reshape(-1).tolist())
        return value
    
    def __call__(self, x:torch.tensor, y, source, **kwargs):
        '''
        
        Arguments
        ---------

        - ```x```: ```torch.tensor```:
            The input, which is returned unchanged.
        
        - ```y```: target
            This is the target of the input.
        
        - ```source```: ```hashable```:
            The source of the input. This is only used
            when ```source_save=True```, to keep a separate,
            fixed label mapping per source.

        Returns
        -------

        - ```(x, y_out)```: the unchanged input and the new label,
            drawn from ```labels```.
        
        '''

        if not self.source_save:
            return x, self.rng.choice(self.labels)

        mapping = self.source_save_dict.setdefault(self._hashable(source), {})
        key = self._hashable(y)
        if key not in mapping:
            mapping[key] = self.rng.choice(self.labels)

        return x, mapping[key]


class SourceChunkSwap(object):
    def __init__(self, 
                    n_xpieces:int=10, 
                    source_save:bool=True, 
                    seed:int=None, 
                    same_across_rgb:bool=False,
                    ):
        '''
        A class for corrupting data by swapping chunks of the
        data with each other.
        
        Arguments
        ---------

        - ```n_xpieces```: ```int``` (optional):
            The number of chunks in the input to rearrange.
            For 2D shapes, this is the number of chunks in the
            x and y direction. This means that for a 2D shape,
            ```n_xpieces**2``` chunks will be rearranged.
            Defaults to ```10```.
        
        - ```source_save```: ```bool``` (optional):
            Whether to use the same corruption for 
            the same index, or whether to 
            randomise the corruption every time an input is streamed.
            Defaults to ```True```.
        
        - ```seed```: ```int``` (optional):
            This value determines the random process for 
            the swapping of chunks in the corruption process.
            Defaults to ```None```.
        
        - ```same_across_rgb```: ```bool``` (optional):
            For 3D inputs of shape ```(C, H, W)```, this dictates
            whether the same chunk shuffling is used for every
            channel (```True```), or whether each channel is
            shuffled independently (```False```).
            Defaults to ```False```.
        
        '''

        self.n_xpieces = n_xpieces
        self.source_save = source_save

        # default_rng(None) uses fresh OS entropy, so one call covers
        # both the seeded and the unseeded case.
        rng = np.random.default_rng(seed)
        self.seed = rng.integers(low=1, high=1e9, size=1)[0]
        
        self.rng = np.random.default_rng(self.seed)

        self.same_across_rgb = same_across_rgb

    def chunk_rearrange(self, data, chunk_sizes, new_order):
        '''
        Splits the 1D ```data``` into consecutive chunks of lengths
        ```chunk_sizes``` and concatenates them again in the order
        given by ```new_order```.

        Adapted from https://stackoverflow.com/a/62292488
        to require and work with ```torch.tensor```s.
        '''
        # Row i of the mask is True for its first chunk_sizes[i] slots, so
        # scattering data through it lays out one chunk per row.
        m = chunk_sizes[:,None] > torch.arange(chunk_sizes.max())
        d1 = torch.empty(m.shape, dtype=data.dtype)
        d1[m] = data
        # Reorder rows (and the mask with them), then read the valid slots.
        return d1[new_order][m[new_order]]

    def chunk_rearrange_2d(self, data, chunk_sizes, new_order):
        '''
        Rearranges the rows and columns of the 2D ```data``` chunk-wise.

        ```chunk_sizes[0]``` / ```chunk_sizes[1]``` hold the row / column
        chunk sizes. ```new_order[1]``` orders the row chunks and
        ```new_order[0]``` orders the column chunks.
        '''
        # Each index permutation is built over the axis whose chunk sizes
        # it uses, so non-square inputs are indexed correctly. The pairing
        # of new_order entries to axes keeps square results unchanged.
        row_idx = self.chunk_rearrange(torch.arange(data.shape[0]),
                                        chunk_sizes=chunk_sizes[0], new_order=new_order[1])
        col_idx = self.chunk_rearrange(torch.arange(data.shape[1]),
                                        chunk_sizes=chunk_sizes[1], new_order=new_order[0])
        return data[row_idx, :][:, col_idx]
    
    def call_1d(self, x, source, index):
        '''
        Swaps ```n_xpieces``` contiguous chunks of the 1D input ```x```.
        ```source``` is ignored.
        '''
        x_len = x.shape[0]
        box_bounds = (torch.linspace(0, x_len, self.n_xpieces+1, dtype=int)).reshape(-1,1)
        chunks = box_bounds[1:] - box_bounds[:-1]

        if self.source_save:
            self.rng = np.random.default_rng(index + self.seed)
        else:
            self.seed = self.rng.integers(low=1, high=1e9, size=1)[0]
            self.rng = np.random.default_rng(self.seed)

        new_order = self.rng.permutation(len(chunks))
        # Permute an index vector and gather, rather than scattering x into
        # a CPU buffer, so x may live on any device.
        order = self.chunk_rearrange(torch.arange(x_len), chunk_sizes=chunks, new_order=new_order)
        return x[order]

    def call_2d(self, x, source, index, channel=None):
        '''
        Swaps chunks of the 2D input ```x``` on an
        ```n_xpieces x n_xpieces``` grid. ```source``` is ignored.

        ```channel``` (optional) is mixed into the per-index seed when
        ```source_save=True```, so that different channels of the same
        index receive different, but reproducible, shufflings.

        Raises ```TypeError``` if ```n_xpieces``` exceeds either side of ```x```.
        '''
        xy_shape = x.shape[0]
        xx_shape = x.shape[1]

        if self.n_xpieces > min(xy_shape, xx_shape):
            raise TypeError('Please make sure that the number of pieces is smaller than '\
                            'both sides of the input. Array was shape {} and the number '\
                                'of pieces was {}.'.format(x.shape, self.n_xpieces))
        
        ybox_bounds = (torch.linspace(0, xy_shape, self.n_xpieces+1, dtype=int)).reshape(-1,1)
        xbox_bounds = (torch.linspace(0, xx_shape, self.n_xpieces+1, dtype=int)).reshape(-1,1)
        
        ychunks = ybox_bounds[1:] - ybox_bounds[:-1]
        xchunks = xbox_bounds[1:] - xbox_bounds[:-1]

        # Row 0: row (y) chunk sizes, row 1: column (x) chunk sizes.
        chunks = torch.cat([ychunks.reshape(1,-1), xchunks.reshape(1,-1)], dim=0)

        if self.source_save:
            seed = index + self.seed
            if channel is not None:
                seed = [int(seed), int(channel)]
            self.rng = np.random.default_rng(seed)
        else:
            self.seed = self.rng.integers(low=1, high=1e9, size=1)[0]
            self.rng = np.random.default_rng(self.seed)
        self.rng = np.random.default_rng(self.rng.integers(low=1, high=1e9, size=2))

        xnew_order = self.rng.permutation(len(xchunks))
        ynew_order = self.rng.permutation(len(ychunks))
        return self.chunk_rearrange_2d(x, chunk_sizes=chunks, new_order=[xnew_order, ynew_order])
    
    def call_rgb(self, x, source, index):
        '''
        Swaps chunks of every channel of a ```(C, H, W)``` input.
        The input tensor is not modified; a new tensor is returned.
        '''
        n_channels, height, width = x.shape
        # Shuffle a map of flat pixel positions and gather with it, so the
        # 2D routine serves every channel.
        idx = torch.arange(height * width).reshape(height, width)
        flat = x.reshape(n_channels, -1)

        if self.same_across_rgb:
            new_idx = self.call_2d(idx, source=source, index=index)
            return flat[:, new_idx.reshape(-1)].reshape(x.shape)

        channel_orders = [
            self.call_2d(idx, source=source, index=index, channel=c).reshape(-1)
            for c in range(n_channels)
        ]
        # Gather into a fresh tensor: x.reshape may return a view of x,
        # so writing into it would corrupt the caller's data.
        x_out = torch.stack([flat[c, order] for c, order in enumerate(channel_orders)])
        return x_out.reshape(x.shape)
    
    def __call__(self, x:torch.tensor, y, source, index, **kwargs):
        '''
        
        Arguments
        ---------

        - ```x```: ```torch.tensor```:
            The vector to have its chunks permutated.
            Leading singleton axes are dropped before corrupting
            (so ```(1, N)``` is treated as 1D and ```(1, H, W)``` as 2D),
            and the output has the same shape as ```x```.
        
        - ```y```: target
            This is the target of the input and is ignored.
        
        - ```source```: ```hashable```:
            The source of the input. This is ignored.

        - ```index```: ```int```:
            The index of the data point. When ```source_save=True```
            it seeds the corruption, so the same index is always
            corrupted in the same way.

        Raises
        ------

        - ```NotImplementedError```: if the (squeezed) input is not
            1D, 2D or 3D.
        
        '''
        x_out = x
        while len(x_out.shape) > 1 and x_out.shape[0] == 1:
            x_out = x_out.reshape(x_out.shape[1:])

        if len(x_out.shape) == 2:
            x_out = self.call_2d(x_out, source=source, index=index)
        elif len(x_out.shape) == 1:
            x_out = self.call_1d(x_out, source=source, index=index)
        elif len(x_out.shape) == 3:
            x_out = self.call_rgb(x_out, source=source, index=index)
        else:
            raise NotImplementedError('Please supply a 1D, 2D, or 3*2D x.')

        return x_out.reshape(x.shape), y


