import abc
import numpy as np
from PIL import Image

class Background(abc.ABC):
    """Abstract base class for background generators.

    To create your own background, subclass this and implement ``newBackground``.
    """

    @abc.abstractmethod
    def newBackground(self) -> Image.Image:
        """Create and return a freshly generated background image."""
        ...


class Blank(Background):
    """Solid background: a single-channel ("L") image filled with the value 0."""

    def __init__(self, x_size: int, y_size: int):
        """Remember the requested dimensions.

        x_size (int): X size (width) of image.
        y_size (int): Y size (height) of image.
        """
        self.x_size, self.y_size = x_size, y_size

    def newBackground(self) -> Image.Image:
        """Return a new all-zero "L" image whose PIL size is (x_size, y_size)."""
        dimensions = (self.x_size, self.y_size)
        fill_value = 0
        return Image.new("L", dimensions, fill_value)


class WhiteNoise(Background):
    """Uniform white-noise background with pixel values drawn from [0, 255)."""

    #x_size (int): X size (width) of image.
    #y_size (int): Y size (height) of image.
    def __init__(self, x_size: int, y_size: int):
        self.x_size = x_size
        self.y_size = y_size
        return

    #Returns a white noise background whose PIL size is (x_size, y_size).
    def newBackground(self) -> Image.Image:
        # NumPy shapes are (rows, cols) = (height, width), the reverse of PIL's
        # (width, height) size; build the array as (y_size, x_size) so the image
        # is x_size wide, matching Blank.
        noise = np.random.random((self.y_size, self.x_size))
        # NOTE(review): a float64 array is passed to Image.fromarray, which Pillow
        # maps to a float ("F") image -- unlike Blank's "L" mode. Confirm callers expect this.
        return Image.fromarray(noise * 255)


class PinkNoise(Background):
    """Random coloured-noise background with a 1 / f**alpha amplitude spectrum.

    Pink noise is generated by band-limiting white noise in the frequency domain.
    A white-noise image is transformed with a centred 2-D FFT. Every magnitude
    is multiplied by ``radius ** -alpha``, where radius is the distance from DC
    rounded to the nearest integer. The DC term is discarded and the original
    phases are kept. The inverse FFT is then rescaled to the range [0, 255].
    """

    #x_size (int): X size (width) of image.
    #y_size (int): Y size (height) of image.
    #alpha (float, [0,2]): Exponent applied to the amplitude spectrum.  0 is white, 1 is pink, 2 is brown, etc.
    #  Because magnitudes (not power) are scaled, the power spectrum falls off as 1 / f**(2 * alpha).
    def __init__(self, x_size: int, y_size: int, alpha: float = 1.0):
        self.x_size = x_size
        self.y_size = y_size
        self.alpha = alpha
        return

    #Returns a randomly generated pink noise background from previously supplied alpha.
    def newBackground(self) -> Image.Image:
        # NumPy shapes are (rows, cols) = (height, width) while PIL sizes are
        # (width, height); build the array as (y_size, x_size) so the image is x_size wide.
        shape = (self.y_size, self.x_size)

        # Random magnitudes and phase angles come from the centred spectrum of white noise.
        noise = np.random.random(shape)
        fft_white_noise = np.fft.fftshift(np.fft.fft2(noise))
        mag = np.abs(fft_white_noise)
        phase = np.angle(fft_white_noise)

        # Band-limit the magnitudes; the phase is preserved for reconstruction.
        output_magnitudes = mag * self._radialFilter(shape, self.alpha)
        fft_pink = output_magnitudes * np.exp(1j * phase)

        # Recreate the image (real part only), scale to [0, 1], then to [0, 255].
        pink_noise = np.fft.ifft2(np.fft.ifftshift(fft_pink)).real
        return Image.fromarray(self._rescaleUnit(pink_noise) * 255)

    @staticmethod
    def _radialFilter(shape, alpha):
        """Return the rounded-radius ** -alpha gain for a centred spectrum of `shape`, with DC set to 0."""
        rows, cols = shape
        # fftshift places DC at index n // 2 along each axis.
        center_r, center_c = rows // 2, cols // 2
        dr = np.arange(rows)[:, None] - center_r
        dc = np.arange(cols)[None, :] - center_c
        rad = np.sqrt(dr ** 2 + dc ** 2).round()
        # Avoid 0 ** -alpha at DC; that entry is zeroed right after.
        rad[center_r, center_c] = 1.0
        gain = np.power(rad, -alpha)
        gain[center_r, center_c] = 0.0
        return gain

    @staticmethod
    def _rescaleUnit(arr):
        """Shift `arr` so its minimum is 0 and scale its maximum to 1; a constant array becomes all zeros."""
        arr = arr - arr.min()
        peak = arr.max()
        # Guard against 0 / 0 (e.g. a 1x1 image), which would otherwise produce NaN pixels.
        if peak > 0:
            arr = arr / peak
        return arr


class LowPassNoiseFromMags(Background):
    """Background built from supplied mean magnitudes, low-pass filtered, with random phase.

    The supplied magnitudes are multiplied by ``radius ** -alpha``. Radius is the
    rounded distance from the array centre; ``newBackground`` treats that centre
    as DC because it applies ``ifftshift`` before the inverse FFT. The DC term is
    discarded and random phases are taken from white noise. The inverse FFT is
    then rescaled to [0, 255].
    """

    #magnitudes (np.ndarray): Averaged magnitudes of all images in a dataset.  Treated as 2-D with DC at the
    #  centre (fftshift-ed).  Use "utils.scanFolderMagnitudes()" to get mean magnitude for a folder of images.
    #alpha (float): The power to raise the inverse frequency to when multiplying by magnitude.  1 will give a low pass filter.
    # NOTE(review): x_size is magnitudes.shape[0] (rows), i.e. the output image's height, unlike Blank/WhiteNoise.
    def __init__(self, magnitudes, alpha: float = 1.0):
        self.x_size = magnitudes.shape[0]
        self.y_size = magnitudes.shape[1]
        self.mag = magnitudes
        self.alpha = alpha
        return

    #Returns a randomly generated low-pass noise background from previously supplied magnitudes and alpha.
    def newBackground(self) -> Image.Image:
        shape = (self.x_size, self.y_size)

        # Only the phase angle of the white noise is kept; magnitudes come from self.mag.
        noise = np.random.random(shape)
        phase = np.angle(np.fft.fftshift(np.fft.fft2(noise)))

        # Work in float64 so integer-typed magnitudes are not truncated by the filter.
        mags = np.asarray(self.mag, dtype=np.float64)
        output_magnitudes = mags * self._radialFilter(shape, self.alpha)
        fft_low = output_magnitudes * np.exp(1j * phase)

        # Recreate the image (real part only), scale to [0, 1], then to [0, 255].
        low_noise = np.fft.ifft2(np.fft.ifftshift(fft_low)).real
        return Image.fromarray(self._rescaleUnit(low_noise) * 255)

    @staticmethod
    def _radialFilter(shape, alpha):
        """Return the rounded-radius ** -alpha gain for a centred spectrum of `shape`, with DC set to 0."""
        rows, cols = shape
        # fftshift places DC at index n // 2 along each axis.
        center_r, center_c = rows // 2, cols // 2
        dr = np.arange(rows)[:, None] - center_r
        dc = np.arange(cols)[None, :] - center_c
        rad = np.sqrt(dr ** 2 + dc ** 2).round()
        # Avoid 0 ** -alpha at DC; that entry is zeroed right after.
        rad[center_r, center_c] = 1.0
        gain = np.power(rad, -alpha)
        gain[center_r, center_c] = 0.0
        return gain

    @staticmethod
    def _rescaleUnit(arr):
        """Shift `arr` so its minimum is 0 and scale its maximum to 1; a constant array becomes all zeros."""
        arr = arr - arr.min()
        peak = arr.max()
        # Guard against 0 / 0, which would otherwise produce NaN pixels.
        if peak > 0:
            arr = arr / peak
        return arr
    
class MeanNoiseFromMags(Background):
    """Background whose spectrum uses the supplied mean magnitudes unchanged, with random phase.

    Random phases are taken from the centred FFT of white noise and combined with
    the supplied magnitudes. The inverse FFT (after ``ifftshift``) is then rescaled
    to [0, 255].
    """

    #magnitudes (np.ndarray): Averaged magnitudes of all images in a dataset.  Treated as 2-D with DC at the
    #  centre (fftshift-ed).  Use "utils.scanFolderMagnitudes()" to get mean magnitude for a folder of images.
    #alpha (float): Stored on the instance but not used by newBackground.
    def __init__(self, magnitudes, alpha: float = 1.0):
        self.x_size = magnitudes.shape[0]
        self.y_size = magnitudes.shape[1]
        self.mag = magnitudes
        self.alpha = alpha
        return

    #Returns a randomly generated background from the previously supplied magnitudes and random phases.
    def newBackground(self) -> Image.Image:
        # Only the phase angle of the white noise is kept; magnitudes come from self.mag.
        noise = np.random.random((self.x_size, self.y_size))
        phase = np.angle(np.fft.fftshift(np.fft.fft2(noise)))

        # Combine the unmodified mean magnitudes with the random phases.
        fft_mean = self.mag * np.exp(1j * phase)

        # Recreate the image (real part only) and scale to [0, 1].
        mean_noise = np.fft.ifft2(np.fft.ifftshift(fft_mean)).real
        mean_noise = mean_noise - mean_noise.min()
        peak = mean_noise.max()
        # Guard against 0 / 0 (constant result, e.g. all-zero magnitudes), which would yield NaN pixels.
        if peak > 0:
            mean_noise = mean_noise / peak

        return Image.fromarray(mean_noise * 255)