import numpy as np

from tqdm import tqdm

from utils.image import EFO, LCM, LFO, LFC, APT, CPT
from multiprocessing import Pool

def luminosity(
    imgs      : np.ndarray,
    RF        : dict,
    n_workers : int = None,
    total     : int = None,
    verbose   : bool = True
    ) -> dict:
    '''
        This function computes the unit-receptive-field specific average luminosity.

        Inputs:
        - imgs [np.ndarray]: Arrays containing the set of images. Expected to have shape:
                             [n_batches, batch_size, W, H]
        - RF [dict of signature {layer : unitRFs}]: dictionary containing the receptive
             fields of all the considered units in a given layer. Each unit RF is a
             pair ((x_start, x_end), (y_start, y_end)) used to slice the last two
             axes of imgs.

        Kwargs:
        - n_workers [int|None]: Number of parallel processes to use for computation.
                                None lets multiprocessing pick its default (CPU count);
                                1 runs serially in the current process.
        - total [int|None]: total number of layers to process for prettier formatting
        - verbose [bool]: whether to display a tqdm progress bar over the layers

        Returns:
        - avg_lum [dict of signature {layer : lum}]: for every layer, an array whose
                  first axis indexes the units (in unitRFs order) and whose remaining
                  axes are imgs.shape[:-2], holding the mean image value over each
                  unit's receptive-field window.

        Raises:
        - ValueError: if n_workers is neither None nor >= 1.
    '''

    # Validate up front so that no work is done for an invalid configuration
    if not (n_workers is None or n_workers >= 1):
        raise ValueError(f'Invalid n_workers parameter. Got {n_workers}.')

    # NOTE: __f is bound at module level so that Pool can pickle it by name (its
    # __qualname__ is plain '__f'). Workers look that name up in their own copy
    # of the module, so the closure over `imgs` is only visible to them when they
    # are forked from this process; hence the explicit 'fork' context below.
    global __f
    def __f(RF):
        # Unpack unit RF
        rfx, rfy = RF

        # Average the images over the RF window (last two axes)
        return imgs[..., rfx[0]:rfx[1], rfy[0]:rfy[1]].mean(axis = (-2, -1))

    def _layers():
        # Iterate over (layer, unit RFs), wrapped in a progress bar when verbose
        if verbose:
            return tqdm(RF.items(), total = total, desc = 'Computing Local Luminance Map', leave = False)
        return RF.items()

    if n_workers == 1:
        return {layer : np.array([__f(rf) for rf in uRF]) for layer, uRF in _layers()}

    # The default start method is not 'fork' on every platform/version (spawn on
    # macOS/Windows, forkserver on Linux from Python 3.14), and __f is only
    # reachable from forked workers, so request 'fork' explicitly.
    from multiprocessing import get_context

    with get_context('fork').Pool(n_workers) as P:
        return {layer : np.array(P.map(__f, uRF)) for layer, uRF in _layers()}

def contrast(
    imgs      : np.ndarray, 
    RF        : dict, 
    kind      : str = 'sobel',
    n_workers : int = None,
    chunk     : int = 10, 
    verbose   : bool = True, 
    total     : int = None
    ) -> dict:
    '''
        This function computes the unit-receptive-field specific image contrast.

        Inputs:
        - imgs [np.ndarray]: Arrays containing the set of images (passed to LCM).
        - RF [dict of signature {layer : unitRFs}]: dictionary containing the receptive
             fields of all the considered units in a given layer. Each unit RF is a
             pair ((x_start, x_end), (y_start, y_end)) used to slice the last two
             axes of the Local Contrast Map.

        Kwargs:
        - kind [str]: Which kind of contrast method to use for computation
                      (forwarded to LCM)
        - n_workers [int|None]: Number of parallel processes to use for computation.
                                None lets multiprocessing pick its default (CPU count);
                                1 runs serially in the current process.
        - chunk [int]: chunk size for parallel mapping
        - verbose [bool]: whether to display a tqdm progress bar over the layers
        - total [int|None]: total number of layers to process for prettier formatting

        Returns:
        - contrast [dict of signature {layer : contrast}]: for every layer, an array
                    stacking (along the first axis, in unitRFs order) the mean of the
                    Local Contrast Map over each unit's receptive-field window.

        Raises:
        - ValueError: if n_workers is neither None nor >= 1.
    '''

    # Validate before computing the (potentially expensive) contrast map
    if not (n_workers is None or n_workers >= 1):
        raise ValueError(f'Invalid n_workers parameter. Got {n_workers}.')

    # Compute the LCM for the complete set of images
    lcm_imgs = LCM(imgs, kind = kind)

    # NOTE: __f is bound at module level so that Pool can pickle it by name (its
    # __qualname__ is plain '__f'). Workers look that name up in their own copy
    # of the module, so the closure over `lcm_imgs` is only visible to them when
    # they are forked from this process; hence the explicit 'fork' context below.
    global __f
    def __f(RF):
        # Unpack unit RF
        rfx, rfy = RF

        # Average the Local Contrast Map over the RF relevant portion
        return lcm_imgs[..., rfx[0]:rfx[1], rfy[0]:rfy[1]].mean(axis = (-2, -1))

    def _layers():
        # Iterate over (layer, unit RFs), wrapped in a progress bar when verbose
        if verbose:
            return tqdm(RF.items(), total = total, desc = 'Computing Local Contrast Map', leave = False)
        return RF.items()

    if n_workers == 1:
        return {layer : np.array([__f(rf) for rf in uRF]) for layer, uRF in _layers()}

    # The default start method is not 'fork' on every platform/version (spawn on
    # macOS/Windows, forkserver on Linux from Python 3.14), and __f is only
    # reachable from forked workers, so request 'fork' explicitly.
    from multiprocessing import get_context

    with get_context('fork').Pool(n_workers) as P:
        return {layer : np.array(list(P.imap(__f, uRF, chunk))) for layer, uRF in _layers()}

def orientation(
    imgs : np.ndarray, 
    RF   : dict,
    off  : int = 2,
    verbose   : bool = True, 
    total     : int = None
    ) -> tuple:
    '''
        This function computes unit-receptive-field specific orientation statistics
        by applying APT to each unit's receptive-field crop of the images.

        Inputs:
        - imgs [np.ndarray]: Array containing the set of images. Its last two axes
                             are the spatial ones sliced by the receptive fields.
        - RF [dict of signature {layer : unitRFs}]: dictionary containing the receptive
             fields of all the considered units in a given layer. Each unit RF is a
             pair ((x_start, x_end), (y_start, y_end)).

        Kwargs:
        - off [int]: parameter forwarded to APT
        - verbose [bool]: whether to display a tqdm progress bar over the layers
        - total [int|None]: total number of layers to process for prettier formatting

        Returns:
        - (angs, msis, ents) [tuple of dicts of signature {layer : array}]: the first,
          second and third APT outputs of every unit of each layer, concatenated
          along their first axis in unitRFs order.
    '''

    # Compute the APT once on the full images: units whose RF spans the whole
    # image reuse this result instead of recomputing it
    full_angs, full_msi, full_ent = APT(imgs, off = off)

    img_xsize, img_ysize = imgs.shape[-2:]

    def _unit_apt(rf):
        # Unpack unit RF
        rfx, rfy = rf

        # If the receptive field covers the whole image, reuse the pre-computed
        # result (np.concatenate below always copies, so it is never aliased)
        if rfx[1] - rfx[0] >= img_xsize and rfy[1] - rfy[0] >= img_ysize:
            return full_angs, full_msi, full_ent

        # Otherwise run APT on the RF relevant portion only
        return APT(imgs[..., rfx[0]:rfx[1], rfy[0]:rfy[1]], off = off)

    iterator = tqdm(RF.items(), total = total, desc = 'Computing Orientation', leave = False) if verbose else\
                    RF.items()

    angs, msis, ents = {}, {}, {}
    for layer, uRF in iterator:
        # Per-unit results are kept as a plain list of 3-tuples rather than stacked
        # with np.array: stacking requires the three outputs to share a shape
        # (ragged input raises on NumPy >= 1.24) and casts them to one dtype.
        unit_out = [_unit_apt(rf) for rf in uRF]

        angs[layer] = np.concatenate([u[0] for u in unit_out])
        msis[layer] = np.concatenate([u[1] for u in unit_out])
        ents[layer] = np.concatenate([u[2] for u in unit_out])

    return angs, msis, ents


def corner(
    imgs : np.ndarray, 
    RF   : dict,
    off  : int = 2,
    verbose   : bool = True, 
    total     : int = None
    ) -> tuple:
    '''
        This function computes unit-receptive-field specific corner statistics
        by applying CPT to each unit's receptive-field crop of the images.

        Inputs:
        - imgs [np.ndarray]: Array containing the set of images. Its last two axes
                             are the spatial ones sliced by the receptive fields.
        - RF [dict of signature {layer : unitRFs}]: dictionary containing the receptive
             fields of all the considered units in a given layer. Each unit RF is a
             pair ((x_start, x_end), (y_start, y_end)).

        Kwargs:
        - off [int]: parameter forwarded to CPT
        - verbose [bool]: whether to display a tqdm progress bar over the layers
        - total [int|None]: total number of layers to process for prettier formatting

        Returns:
        - (peak1, peak2, bsi) [tuple of dicts of signature {layer : array}]: the first,
          second and third CPT outputs of every unit of each layer, concatenated
          along their first axis in unitRFs order.
    '''

    # Compute the CPT once on the full images: units whose RF spans the whole
    # image reuse this result instead of recomputing it
    full_peak1, full_peak2, full_bsi = CPT(imgs, off = off)

    img_xsize, img_ysize = imgs.shape[-2:]

    def _unit_cpt(rf):
        # Unpack unit RF
        rfx, rfy = rf

        # If the receptive field covers the whole image, reuse the pre-computed
        # result (np.concatenate below always copies, so it is never aliased)
        if rfx[1] - rfx[0] >= img_xsize and rfy[1] - rfy[0] >= img_ysize:
            return full_peak1, full_peak2, full_bsi

        # Otherwise run CPT on the RF relevant portion only
        return CPT(imgs[..., rfx[0]:rfx[1], rfy[0]:rfy[1]], off = off)

    iterator = tqdm(RF.items(), total = total, desc = 'Computing Corners', leave = False) if verbose else\
                    RF.items()

    peak1, peak2, bsi = {}, {}, {}
    for layer, uRF in iterator:
        # Per-unit results are kept as a plain list of 3-tuples rather than stacked
        # with np.array: stacking requires the three outputs to share a shape
        # (ragged input raises on NumPy >= 1.24) and casts them to one dtype.
        unit_out = [_unit_cpt(rf) for rf in uRF]

        peak1[layer] = np.concatenate([u[0] for u in unit_out])
        peak2[layer] = np.concatenate([u[1] for u in unit_out])
        bsi[layer]   = np.concatenate([u[2] for u in unit_out])

    return peak1, peak2, bsi