from functools import partial
from math import sqrt
from multiprocessing import Pool
from typing import Tuple

import numpy as np
from numpy.lib.stride_tricks import as_strided
from scipy.ndimage import convolve
from scipy.ndimage import sobel
from scipy.ndimage import prewitt
from scipy.signal import find_peaks
from scipy.fft import rfft2, fftshift, next_fast_len
from scipy.fft import fft2, fftshift

from skimage.transform import warp_polar
from skimage.measure import EllipseModel

from utils.information import H_p

def block_reshape(img : np.ndarray, kernel_size : Tuple[int, int] = (2, 2), stride : Tuple[int, int] = (1, 1)) -> np.ndarray:
    '''
        Expose every `kernel_size` window of `img` (over its last two axes) as a
        strided view, without copying the windows.

        The last two axes are first edge-padded by (kh - 1, kw - 1) samples (ceil
        half before, floor half after) so that, with unit stride, there is one
        window per input pixel. Leading axes (e.g. channels) are left untouched.

        img [np.ndarray]: array of shape [..., H, W] (at least 2D).
        kernel_size [Tuple[int, int]]: window size (kh, kw).
        stride [Tuple[int, int]]: window step (sy, sx) along rows and columns.

        Returns:
            View of shape [..., (H - 1) // sy + 1, (W - 1) // sx + 1, kh, kw] on a
            padded copy of `img`. Windows overlap in memory: treat it as read-only.
    '''
    img = np.asarray(img)
    kh, kw = (int(k) for k in kernel_size)
    sy, sx = (int(s) for s in stride)

    # NOTE: plain Python ints here; `np.int` was removed in NumPy 1.24
    ph, pw = kh - 1, kw - 1
    pad_width = [(0, 0)] * (img.ndim - 2) + [((ph + 1) // 2, ph // 2), ((pw + 1) // 2, pw // 2)]
    img = np.pad(img, pad_width, mode = 'edge')

    *lead, rows, cols = img.shape
    shape = (*lead, (rows - kh) // sy + 1, (cols - kw) // sx + 1, kh, kw)

    # The window steps must use the same stride components as `shape` above
    # (sy for rows, sx for columns); mixing them makes the view walk outside
    # the padded buffer whenever sy != sx.
    *lead_st, st_r, st_c = img.strides
    strides = (*lead_st, st_r * sy, st_c * sx, st_r, st_c)

    return as_strided(img, shape = shape, strides = strides)

def angular_slice(ang_min : float, ang_max : float, size : Tuple[int, int], remove_const : bool = True) -> np.ndarray:
    '''
        Build a boolean mask that selects the coefficients of a (fftshift-ed) 2D
        Fourier spectrum whose orientation lies strictly between `ang_min` and
        `ang_max` degrees.

        Orientation is measured on a [-1, 1] x [-1, 1] grid with the vertical axis
        pointing up. A coefficient is kept when either its angle or the angle of its
        point reflection falls inside the open window, so opposite directions are
        treated as the same orientation.

        size [Tuple[int, int]]: (rows, cols) of the mask.
        remove_const [bool]: if True, the central element [rows // 2, cols // 2]
            (the zero-frequency coefficient after fftshift) is always excluded.
    '''
    rows, cols = size
    lo, hi = np.deg2rad(ang_min), np.deg2rad(ang_max)

    # Flip the vertical axis so that angles grow anti-clockwise
    gx, gy = np.meshgrid(np.linspace(-1, 1, num = cols), -np.linspace(-1, 1, num = rows))

    def _inside(theta : np.ndarray) -> np.ndarray:
        return (theta > lo) & (theta < hi)

    mask = _inside(np.arctan2(gy, gx)) | _inside(np.arctan2(-gy, -gx))

    if remove_const:
        mask[rows // 2, cols // 2] = False

    return mask

def BSI(pwr_ang : np.ndarray, min_ang : int = 15):
    '''
        Compute the two dominant peaks and the bimodal index (BSI) of a batch of
        periodic angular power profiles.

        pwr_ang [np.ndarray]: array of shape [num_imgs, C, num_ang]; the last axis is
            treated as periodic with period num_ang.
        min_ang [int]: minimum distance, in samples of the angular axis, between two
            peaks (or two troughs); forwarded as `distance` to scipy's find_peaks.

        Returns:
            peak1, peak2: int arrays [num_imgs, C] with the position (mod num_ang) of
                the highest and second-highest peak.
            score: float array [num_imgs, C] equal to (p2 - t2) / (p1 - t1), where
                p1 >= p2 are the two highest peak values and t1 <= t2 the two lowest
                trough values.
            Where fewer than two peaks or two troughs are found, peak1 = peak2 = -1
            and score = -1 (implausible values, meant to be filtered out later).
    '''
    pwr_ang = np.asarray(pwr_ang)
    num_imgs, num_chn, num_ang = pwr_ang.shape

    # Tile the data to simulate the periodicity in the angular domain
    r_data = np.tile(pwr_ang, (1, 1, 3))

    def _f(sig : np.ndarray):
        peaks, p_prop = find_peaks( sig, distance = min_ang, prominence = .02, height = -1e5)
        thrgs, t_prop = find_peaks(-sig, distance = min_ang, prominence = .02, height = -1e5)

        # Only keep the extrema of the central copy, where boundary effects vanish
        p_mask = (peaks >= num_ang) & (peaks <= 2 * num_ang)
        t_mask = (thrgs >= num_ang) & (thrgs <= 2 * num_ang)

        # Fold positions back into [0, num_ang), dropping duplicated positions
        peaks, p_idxs = np.unique(peaks[p_mask] % num_ang, return_index = True)
        thrgs, t_idxs = np.unique(thrgs[t_mask] % num_ang, return_index = True)

        # If peaks search fails, just return implausible values that would
        # later be filtered out
        if len(peaks) < 2 or len(thrgs) < 2: return (-1, -1), -1.

        p_vals =  p_prop['peak_heights'][p_mask][p_idxs]
        t_vals = -t_prop['peak_heights'][t_mask][t_idxs]  # actual (signed) trough values

        # Two highest peaks and two deepest troughs
        p_order = np.argsort(p_vals)[::-1][:2]
        t_order = np.argsort(t_vals)[:2]

        p_Max, p_max = p_vals[p_order]
        t_Max, t_max = t_vals[t_order]
        first, second = peaks[p_order]

        return (int(first), int(second)), float((p_max - t_max) / (p_Max - t_Max + 1e-10))

    peak1 = np.empty((num_imgs, num_chn), dtype = int)
    peak2 = np.empty((num_imgs, num_chn), dtype = int)
    score = np.empty((num_imgs, num_chn), dtype = float)

    # Explicit loop instead of np.apply_along_axis: _f returns a ragged
    # (pair, scalar) result that modern NumPy refuses to turn into an array
    for idx in np.ndindex(num_imgs, num_chn):
        (peak1[idx], peak2[idx]), score[idx] = _f(r_data[idx])

    return peak1, peak2, score


def LCM (img : np.ndarray,                        # Input image of which to extract the Local Contrast Map
         kind : str = 'rms',                      # The kind of algorithm used for contrast estimation
         kernel_size : Tuple[int, int] = (3, 3),  # Size of kernel window to use (only for some algorithms)
         stride : Tuple[int, int] = (1, 1),       # Stride of the convolution (only for some algorithms)
         normalize : bool = True) -> np.ndarray:
    '''
        Compute a Local Contrast Map of `img` over its last two (spatial) axes.

        Supported kinds:
            - 'rms':           std of every `kernel_size` window (via block_reshape)
            - 'michelson':     (max - min) / (max + min) of every `kernel_size` window
            - 'sobel':         gradient magnitude from Sobel derivatives
            - 'prewitt':       gradient magnitude from Prewitt derivatives
            - 'roberts-cross': gradient magnitude from the 2x2 Roberts-Cross kernels
        `kernel_size` and `stride` are only used by 'rms' and 'michelson'. For
        'sobel' and 'prewitt' every 2D plane over the last two axes is filtered
        independently. Non-floating inputs are converted to float first.

        If `normalize` is True the map is divided by its global maximum; it is left
        unchanged when that maximum is 0 (e.g. for a constant image).

        Raises:
            ValueError: if `kind` is not one of the supported algorithms.
    '''
    supported = ('rms', 'sobel', 'prewitt', 'michelson', 'roberts-cross')

    img = np.asarray(img)

    # Integer inputs would silently wrap around: scipy filters keep the input
    # dtype (negative gradients in uint8) and `amax + amin` can overflow.
    if not np.issubdtype(img.dtype, np.inexact):
        img = img.astype(float)

    def _grad_mag(op):
        # scipy's sobel/prewitt also smooth along every non-derivative axis,
        # which would blend channels (or images) together when img has leading
        # axes: filter each 2D plane independently instead.
        planes = img.reshape(-1, *img.shape[-2:])
        mags = [np.hypot(op(p, axis = -1), op(p, axis = -2)) for p in planes]
        return np.reshape(mags, img.shape)

    # Here we select the type of contrast
    if kind == 'rms':
        out = np.std(block_reshape(img, kernel_size, stride), axis = (-2, -1))

    elif kind == 'sobel':
        out = _grad_mag(sobel)

    elif kind == 'prewitt':
        out = _grad_mag(prewitt)

    elif kind == 'michelson':
        img_blocks = block_reshape(img, kernel_size, stride)
        amax, amin = np.max(img_blocks, axis = (-2, -1)), np.min(img_blocks, axis = (-2, -1))
        out = (amax - amin) / (amax + amin + 1e-10)

    elif kind == 'roberts-cross':
        f1, f2 = np.array([[[[1, 0], [0, -1]]]]), np.array([[[[0, 1], [-1, 0]]]])
        # Drop leading singleton axes so that the kernels match img.ndim
        # (a 4D input keeps the full [1, 1, 2, 2] kernels)
        f1, f2 = (f1[0], f2[0]) if img.ndim == 3 else (f1, f2)
        f1, f2 = (f1[0, 0], f2[0, 0]) if img.ndim == 2 else (f1, f2)
        out = np.hypot(convolve(img, f1), convolve(img, f2))

    else:
        raise ValueError (f'Unknown Contrast {kind}. Supported algorithms are: {supported}')

    if normalize:
        peak = np.max(out)
        # Avoid 0 / 0 = NaN maps for featureless inputs
        if peak != 0:
            out = out / peak

    return out

def APT(imgs : np.ndarray, off : int = 2, smooth : Tuple[int, int] = (2, 2), chunk = 100):
    '''
        This function computes the dominant angle present in a scene based on a Polar
        Fourier Spectrum of the image.

        imgs [np.ndarray]: Tensor of expected shape [num_unit, num_imgs, C, H, W].
            NOTE: because of high memory consumption, one should take care of
            containing the number of units for which to compute this object.
        off [int]: number of leading radial bins of the polar spectrum excluded
            when summing the power over the radius.
        smooth [Tuple[int, int]]: number of low-frequency rows (removed from both
            ends of the unshifted spectrum) and leading columns discarded to
            attenuate FFT boundary artifacts. A 0 disables the cropping.
        chunk [int]: chunk size used by `Pool.imap` for the polar transforms.

        Returns:
            max_angs: argmax of the reversed 180-bin half-plane angular power
                profile, shape [num_unit, num_imgs, C].
            msi: (max - min) / (max + min) of the angular profile (0 where min <= 0).
            ent: `H_p` of the angular profile.
    '''

    # Get the optimal shape for efficient FFT computations
    h, w = imgs.shape[-2:]
    ffts = next_fast_len(h), next_fast_len(w)

    # Compute the real FFT of the images and extract the power spectra
    pwr_imgs = np.abs(rfft2(imgs, s = ffts, workers = -1, overwrite_x = True))

    # Before shifting we empirically try to "smooth-out" the edge artifact due
    # to periodic boundary conditions in FFTs (spurious edges appearing at very
    # low frequencies due to image-tiling)
    fx, fy = pwr_imgs.shape[-2:]
    sx, sy = min(smooth[0], fx - 2), min(smooth[1], fy - 2)
    # NOTE: explicit stop index (fx - sx) instead of -sx: with sx == 0 the
    #       slice `0:-0` would silently be empty
    ex = fx if sx == fx - 2 else fx - sx

    # Shift frequencies back, so to proper map them via polar transform. No need
    # to shift axis = -1 as it was transformed via real-FFT, so only half space
    # was computed in the first place
    pwr_imgs = fftshift(pwr_imgs[..., sx:ex, sy:], axes = -2)

    # Transform the spectra into polar coordinates. We map unit and batch
    # dimensions into a shared axis for faster parallel computation
    u, b, c, h, w = pwr_imgs.shape

    # A partial of a module-level function is picklable under every
    # multiprocessing start method (fork, spawn, forkserver)
    to_polar = partial(warp_polar, channel_axis = 0, center = (h / 2, 0), scaling = 'log')

    with Pool() as P:
        results  = P.imap(to_polar, pwr_imgs.reshape(-1, c, h, w), chunk)
        # NOTE: np.array(..., copy = False) raises on NumPy 2 when a copy is needed
        pwr_imgs = np.asarray(list(results))

    # Reshape the array back to [num_unit, num_imgs, C, num_angles, num_radii]
    # NOTE: num_angles is always equal to 360, while the num_radii should be
    #       inferred as it varies with the receptive field size
    pwr_imgs = pwr_imgs.reshape(u, b, c, 360, -1)

    # Compute the peak orientation by summing over the radius dimension
    # NOTE: We compute argmax with the [::-1] array order to match the
    #       anti-clockwise orientation
    hlf_pols = np.concatenate([pwr_imgs[..., 270:, :], pwr_imgs[..., :90, :]], axis = -2)
    pwr_angs = hlf_pols[..., off:].sum(-1)
    max_angs = np.argmax(pwr_angs[..., ::-1], axis = -1)

    # Compute a metric for the quality of edge detection
    amax, amin = np.max(pwr_angs, axis = -1), np.min(pwr_angs, axis = -1)

    msi = np.where(amin > 0, (amax - amin) / (amax + amin + 1e-10), 0)
    ent = H_p(pwr_angs)

    # Return the computed edge angles and associated scores
    return max_angs, msi, ent

def CPT(imgs : np.ndarray, off : int = 2, smooth : Tuple[int, int] = (2, 2), chunk = 100):
    '''
        This function computes the two dominant angles present in a scene based on a Polar
        Fourier Spectrum of the image. It is used for corner extraction.

        imgs [np.ndarray]: Tensor of expected shape [num_unit, num_imgs, C, H, W].
            NOTE: because of high memory consumption, one should take care of
            containing the number of units for which to compute this object.
        off [int]: number of leading radial bins of the polar spectrum excluded
            when summing the power over the radius.
        smooth [Tuple[int, int]]: number of low-frequency rows (removed from both
            ends of the unshifted spectrum) and leading columns discarded to
            attenuate FFT boundary artifacts. A 0 disables the cropping.
        chunk [int]: chunk size used by `Pool.imap` for the polar transforms.

        Returns:
            peak1, peak2, score: arrays of shape [num_unit, num_imgs, C] produced by
                BSI on the 180-bin half-plane angular power profile, normalized to
                [0, 1]. Peak positions index that profile (index 0 = polar bin 270).
    '''

    # Get the optimal shape for efficient FFT computations
    h, w = imgs.shape[-2:]
    ffts = next_fast_len(h), next_fast_len(w)

    # Compute the FFT of the images and extract the power spectra
    pwr_imgs = np.abs(rfft2(imgs, s = ffts, workers = -1, overwrite_x = True))

    # Before shifting we empirically try to "smooth-out" the edge artifact due
    # to periodic boundary conditions in FFTs (spurious edges appearing at very
    # low frequencies due to image-tiling)
    fx, fy = pwr_imgs.shape[-2:]
    sx, sy = min(smooth[0], fx - 2), min(smooth[1], fy - 2)
    # NOTE: explicit stop index (fx - sx) instead of -sx: with sx == 0 the
    #       slice `0:-0` would silently be empty
    ex = fx if sx == fx - 2 else fx - sx

    # Shift frequencies back, so to proper map them via polar transform. No need
    # to shift axis = -1 as it was transformed via real-FFT, so only half space
    # was computed in the first place
    pwr_imgs = fftshift(pwr_imgs[..., sx:ex, sy:], axes = -2)

    # Transform the spectra into polar coordinates. We map unit and batch
    # dimensions into a shared axis for faster parallel computation
    u, b, c, h, w = pwr_imgs.shape

    # A partial of a module-level function is picklable under every
    # multiprocessing start method (fork, spawn, forkserver)
    to_polar = partial(warp_polar, channel_axis = 0, center = (h / 2, 0), scaling = 'log')

    with Pool() as P:
        results  = P.imap(to_polar, pwr_imgs.reshape(-1, c, h, w), chunk)
        # NOTE: np.array(..., copy = False) raises on NumPy 2 when a copy is needed
        pwr_imgs = np.asarray(list(results))

    # Reshape the array back to [num_unit, num_imgs, C, num_angles, num_radii]
    # NOTE: num_angles is always equal to 360, while the num_radii should be
    #       inferred as it varies with the receptive field size
    pwr_imgs = pwr_imgs.reshape(u, b, c, 360, -1)

    # Build the half-plane angular profile by summing over the radius dimension
    # NOTE: unlike APT, the profile is NOT reversed here: peak positions follow
    #       the hlf_pols order
    hlf_pols = np.concatenate([pwr_imgs[..., 270:, :], pwr_imgs[..., :90, :]], axis = -2)
    pwr_angs = hlf_pols[..., off:].sum(-1)

    # Normalize the signal in [0, 1]
    amax = pwr_angs.max(-1).reshape(-1, c, 1)
    pwr_angs = pwr_angs.reshape(-1, c, 180) / (amax + 1e-10)
    peak1, peak2, score = BSI(pwr_angs)

    return peak1.reshape(u, b, c), peak2.reshape(u, b, c), score.reshape(u, b, c)

def LFAC(img          : np.ndarray,         # Input image of which to extract the Local Fourier Angular Contrast
         num_ang      : int = 25,           # The number of angular region in which to subdivide the spectrum
         kind         : str = 'michelson',  # The algorithm used to compute the 1D contrast of the Angular Spectrum
         remove_const : bool = True,        # Flag to signal whether to remove the constant Fourier component
         workers      : int  = -1           # Number of worker processes to use the FFT computation
        ) -> np.ndarray:
    '''
        Compute the Local Fourier Angular Contrast (LFAC).

        The centred FFT magnitude of `img` (over its last two axes) is averaged
        inside the angular slices delimited by `num_ang` equally spaced angles in
        [0, 180] degrees (see `angular_slice`; empty slices are skipped). The
        contrast of the resulting angular profile is then computed as:
            - 'michelson': (max - min) / (max + min)
            - 'rms':       standard deviation

        Returns:
            np.ndarray of shape img.shape[:-2] with the contrast of every 2D plane.

        Raises:
            ValueError: if `kind` is not supported.
    '''
    supported = ('michelson', 'rms')

    # Validate before paying for the FFT
    if kind not in supported:
        raise ValueError(f'Unknown Contrast {kind}. Supported algorithms are: {supported}')

    # Compute the 2D Fourier Transform of input image
    pwr_img = np.abs(fftshift(fft2(img, workers = workers), axes = (-2, -1)))

    # Get the angular slice masks
    angles = np.linspace(0, 180, num = num_ang)
    masks  = [angular_slice(ang1, ang2, img.shape[-2:], remove_const) for ang1, ang2 in zip(angles, angles[1:])]

    # Channel-specific angular power spectrum with the angle axis last,
    # for any number of leading axes: [..., num_slices]
    pwr_ang = np.stack([pwr_img[..., mask].mean(axis = -1) for mask in masks if mask.any()], axis = -1)

    # Compute the local contrast of the Angular Power Spectrum
    if kind == 'michelson':
        pmax, pmin = np.max(pwr_ang, axis = -1), np.min(pwr_ang, axis = -1)
        out = (pmax - pmin) / (pmax + pmin + 1e-10)
    else:
        out = np.std(pwr_ang, axis = -1)

    return out

def LFO(img          : np.ndarray,         # Input image of which to extract the Local Fourier Orientation
        num_ang      : int = 25,           # The number of angular region in which to subdivide the spectrum
        metric       : str = 'entropy',    # The algorithm used to compute the quality metric
        remove_const : bool = True,        # Flag to signal whether to remove the constant Fourier component
        workers      : int  = -1           # Number of worker processes to use the FFT computation
        ) -> Tuple[np.ndarray, np.ndarray]:
    '''
        Compute the Local Fourier Orientation (LFO).

        The centred FFT magnitude of `img` is averaged inside the angular slices
        delimited by `num_ang` equally spaced angles in [0, 180] degrees (empty
        slices are skipped).

        `img` is expected to be 4D, e.g. [num_imgs, C, H, W]: the (2, 1, 0)
        transpose below needs exactly three axes. Outputs then have shape
        [C, num_imgs].

        Returns:
            max_ang: index (among the non-empty slices) of the highest mean power.
            quality: `H_p` of the angular power profile.

        Raises:
            ValueError: if `metric` is not supported.
    '''
    supported = ('entropy',)

    spectrum = np.abs(fftshift(fft2(img, workers = workers), axes = (-2, -1)))

    # Mean power inside each non-empty angular slice
    edges   = np.linspace(0, 180, num = num_ang)
    profile = []
    for lo, hi in zip(edges[:-1], edges[1:]):
        mask = angular_slice(lo, hi, img.shape[-2:], remove_const)
        if mask.any():
            profile.append(spectrum[..., mask].mean(axis = -1))

    # [num_slices, A, B] -> [B, A, num_slices]
    pwr_ang = np.transpose(np.array(profile), (2, 1, 0))

    # Orientation slice with the highest power
    max_ang = pwr_ang.argmax(axis = -1)

    if metric != 'entropy':
        raise ValueError(f'Unknown Metric {metric}. Supported algorithms are: {supported}')

    # Returns both the highest-power orientation and its quality metric
    return max_ang, H_p(pwr_ang)

def LFC(img          : np.ndarray,         # Input image of which to extract the Local Fourier Corner
        num_ang      : int = 25,           # The number of angular region in which to subdivide the spectrum
        metric       : str = 'entropy',    # The algorithm used to compute the quality metric
        remove_const : bool = True,        # Flag to signal whether to remove the constant Fourier component
        workers      : int  = -1,          # Number of worker processes to use the FFT computation
        min_ang      : float = 15,         # Minimum angular separation (degrees) between detected peaks/troughs
        ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    '''
        This function computes a Local Fourier Corner estimation. It takes the input batch of images, (as
        seen through the unit RFs) and computes the Fourier Power spectrum. From this it extracts an estimate
        of the dominant corner-angle in the scene as the angular difference between the position of the
        peaks in spectrum (circular difference). It furthermore provides two quality metrics: the entropy
        of the angular power spectrum itself and the BSI (bimodal index), which estimates how
        multi-peaked the spectrum is, and it better be if a corner is expected to be there.

        img is expected to be 4D, e.g. [num_imgs, n_channels, H, W].

        Returns:
            ang: circular distance between the two dominant peaks, in angular-slice
                units (-1 where two peaks/troughs could not be found).
            bsi: bimodal index of the angular power profile (-1 on failure).
            quality: `H_p` of the angular power profile.
            All arrays have shape [num_imgs, n_channels].

        Raises:
            ValueError: if `metric` is not supported.
    '''
    supported = ('entropy',)

    # Validate before paying for the FFT
    if metric not in supported:
        raise ValueError(f'Unknown Metric {metric}. Supported algorithms are: {supported}')

    # Compute the 2D Fourier Transform of input image
    pwr_img = np.abs(fftshift(fft2(img, workers = workers), axes = (-2, -1)))

    # Get the angular slice masks
    angles = np.linspace(0, 180, num = num_ang)
    masks  = [angular_slice(ang1, ang2, img.shape[-2:], remove_const) for ang1, ang2 in zip(angles, angles[1:])]

    # Channel-specific angular power spectrum with the angle axis last:
    # [num_imgs, n_channels, num_slices], the layout expected by BSI and H_p
    pwr_ang = np.moveaxis(np.array([pwr_img[..., mask].mean(axis = -1) for mask in masks if mask.any()]), 0, -1)
    num_slc = pwr_ang.shape[-1]

    # Normalize each profile in [0, 1] (as CPT does) so that BSI's absolute
    # prominence threshold does not depend on the spectrum scale
    nrm_ang = pwr_ang / (pwr_ang.max(axis = -1, keepdims = True) + 1e-10)

    # BSI's peak distance is expressed in samples of the angular axis, which
    # covers (approximately, if empty slices were dropped) 180 degrees
    distance = max(1, int(round(min_ang * num_slc / 180)))
    peak1, peak2, bsi = BSI(nrm_ang, distance)

    # Angle of the putative corner: circular difference between the two peaks
    gap = np.abs(peak1 - peak2)
    ang = np.where(peak1 < 0, -1, np.minimum(gap, num_slc - gap))

    # Entropy is computed along the last (angle) axis of the raw profile
    quality = H_p(pwr_ang)

    # Return shape is [num_imgs, n_channels] for all the arrays
    return ang, bsi, quality

def EFO(
    img     : np.ndarray,           # Input image of which to extract the Ellipse Fourier Orientation
    qrt     : float = 75,           # The power-spectrum percentile used to set the threshold for ellipse fit
    workers : int  = -1             # Number of worker processes to use the FFT computation
    ) -> Tuple[np.ndarray, np.ndarray]:
    '''
        This function computes the Ellipse Fourier Orientation (EFO). For every image and
        channel it thresholds the centred FFT magnitude spectrum at its `qrt`-th
        percentile and fits an EllipseModel on the (x, y) = (col, row) coordinates
        of the coefficients above threshold. No logarithm is taken: being monotonic,
        it would not change which coefficients pass the threshold.

        If the fit fails (too few points, e.g. for very small receptive fields),
        the threshold is progressively lowered to include the top 5, 6, ...
        coefficients until the fit succeeds.

        Provided images are expected to have shape [n_imgs, n_channel, rows, cols].
        Computation is performed separately for each channel (any number of channels).

        Returns:
            metrics: [2-Tuple of np.arrays] The computed metrics. First component is the
                     orientation parameter of the fitted ellipse, second component is
                     the ellipse eccentricity, which can be used as a metric for
                     quality-of-edge. The returned arrays have shape [n_channel, n_imgs].

        Raises:
            RuntimeError: if no ellipse can be fitted before running out of coefficients.
    '''
    # Compute the 2D Fourier Transform of input image
    pwr_img = np.abs(fftshift(fft2(img, workers = workers), axes = (-2, -1)))

    # Per-image, per-channel threshold: shape [n_imgs, n_channel]
    thrs = np.percentile(pwr_img, qrt, axis = (-2, -1))

    # Fit the Ellipse model and extract parameters
    def _efit(pwr : np.ndarray, thr : float) -> Tuple[float, float]:
        points = np.argwhere(pwr > thr)[:, ::-1]

        ell = EllipseModel()
        ranked, k = None, 4
        while not ell.estimate(points):
            # If Ellipse estimation fails it is because too few points passed the
            # threshold and the problem design matrix is singular. This happens for
            # very small receptive fields, where few points exist at all. To
            # compensate, we gradually include more of the strongest coefficients
            # until a non-singular condition for ellipse estimation is reached.
            if ranked is None:
                ranked = np.sort(pwr, axis = None)[::-1]

            k += 1
            if k >= pwr.size:
                raise RuntimeError(f'Ellipse fit failed: ran out of coefficients ({k}/{pwr.size}, shape {pwr.shape})')

            points = np.argwhere(pwr >= ranked[k - 1])[:, ::-1]

        # Get the ellipse parameter and compute eccentricity
        _, _, a, b, phi = ell.params

        e = sqrt(1 - min(a, b)**2 / (max(a, b)**2 + 1e-20))

        return phi, e

    # Fit every image of every channel separately
    angles, eccs = [], []
    for ch in range(pwr_img.shape[1]):
        phis, ecc = zip(*[_efit(pwr, thr) for pwr, thr in zip(pwr_img[:, ch], thrs[:, ch])])
        angles.append(phis)
        eccs.append(ecc)

    return np.stack(angles), np.stack(eccs)