import numpy as np

import scipy.fft as sp_fft


def sliding_window(X, size, overlap=0):
    """Create a sliding window of size `size` and overlap `overlap` on the data `X`.

    Windows start at ``0, step, 2 * step, ...`` with ``step = size - overlap``.
    Trailing samples that do not fill a whole window are dropped.

    Parameters
    ----------
    X : array, shape=(T,)
        Data.
    size : int
        Size of the window.
    overlap : int
        Overlap between two consecutive windows. Must be smaller than `size`.

    Returns
    -------
    array, shape=(n_windows, size)
        The windows stacked row-wise. If `X` is shorter than `size`, an empty
        array of shape ``(0,)`` is returned.

    Raises
    ------
    ValueError
        If ``overlap >= size``, i.e. the window would never advance.
    """
    step = size - overlap
    if step <= 0:
        # A non-positive step never moves the window forward; fail loudly
        # instead of looping forever.
        raise ValueError("overlap must be smaller than size")
    # The last valid start is len(X) - size, so the range stops just after it.
    starts = range(0, len(X) - size + 1, step)
    return np.array([X[start:start + size] for start in starts])


def get_barycenters(X, filter_size, weights=None):
    """Get the barycenter of the PSD of several sources.

    Each source's PSD is estimated with `get_psd` (half-overlapping windows of
    length `filter_size`). The barycenter is the square of the weighted sum of
    the square-rooted PSDs: ``(sum_k w_k * sqrt(psd_k)) ** 2``.

    Parameters
    ----------
    X : sequence of K arrays, each of shape (T_k,)
        Data of several sources.
    filter_size : int
        Size of the filter to compute (the PSD window length).
    weights : array, shape=(K,)
        Weight of each source, indexed in the same order as `X`. If None,
        all sources are weighted equally (``1 / K`` each).

    Returns
    -------
    array, shape=(filter_size // 2 + 1,)
        The barycenter of the PSD of the sources, on the ``rfft`` frequency
        grid.
    """
    n_sources = len(X)
    if weights is None:
        weights = np.ones(n_sources) / n_sources

    # Square-root PSD of each source; identical to the windowed, averaged
    # periodogram computed inline previously.
    sqrt_psds = [np.sqrt(get_psd(X[k], filter_size)) for k in range(n_sources)]
    weighted = [weights[k] * sqrt_psds[k] for k in range(n_sources)]
    return np.sum(weighted, axis=0) ** 2


def get_filter(X, barycenter, filter_size, C=0):
    """Compute the filter that maps the PSD of `X` onto `barycenter`.

    The filter's frequency response is ``sqrt(barycenter + C) / sqrt(psd + C)``.
    Here ``psd`` is the PSD of `X` estimated by `get_psd`, using windows of
    length `filter_size` that overlap by half a window.

    Parameters
    ----------
    X : array, shape=(T,)
        Signal of a single source.
    barycenter : array, shape=(filter_size // 2 + 1,)
        Barycenter of the PSD of the sources, on the ``rfft`` frequency grid.
    filter_size : int
        Size of the filter to compute.
    C : float
        Regularization added to both PSDs before taking their ratio.

    Returns
    -------
    array, shape=(filter_size,)
        Real impulse response of the filter.
    """
    psd_mean_ = get_psd(X, filter_size)

    # Frequency response of the mapping (real and non-negative).
    D = np.sqrt(barycenter + C) / np.sqrt(psd_mean_ + C)

    # Pass the output length explicitly. Without it, irfft assumes the even
    # length 2 * (len(D) - 1) and returns one sample too few for odd
    # filter_size.
    # NOTE(review): D is real, so the impulse response is circularly
    # symmetric around index 0 (not centered). Confirm this alignment is
    # intended for the 'same'-mode convolution used downstream.
    return sp_fft.irfft(D, n=filter_size, axis=-1)


def get_psd(X, filter_size):
    """Estimate the power spectral density of a 1-D signal.

    The signal is cut into windows of length `filter_size` that overlap by half
    a window. The squared magnitude of each window's real FFT is then averaged
    over all windows. No tapering window and no scaling are applied, so this
    is a Welch-style averaged periodogram with a rectangular window.

    Parameters
    ----------
    X : array, shape=(T,)
        Signal to analyse.
    filter_size : int
        Window length (the size of the filter this PSD is used for).

    Returns
    -------
    array, shape=(filter_size // 2 + 1,)
        The averaged periodogram on the ``rfft`` frequency grid.
    """
    segments = sliding_window(X, size=filter_size, overlap=int(filter_size / 2))

    # One spectrum per window, then average the power across windows.
    spectra = sp_fft.rfftn(segments, axes=-1)
    power = np.abs(spectra) ** 2
    return np.mean(power, axis=0)


def convolution(X, D, mode="valid"):
    """Convolve each row of `X` with the 1-D filter `D`.

    Parameters
    ----------
    X : array, shape=(N, T)
        Data; each of the N rows is convolved independently.
    D : array, shape=(filter_size,)
        Filter.
    mode : str
        Mode passed to ``np.convolve``: 'valid', 'same' or 'full'.

    Returns
    -------
    array, shape=(N, T')
        The convolved rows stacked together. T' depends on `mode` (see
        ``np.convolve``).
    """
    return np.array([np.convolve(row, D, mode=mode) for row in X])


def monge_mapping_transform(X, X_barycenter, filter_size=128, conv_mode="same"):
    """Apply the Monge mapping transform to some data.

    For every channel, a PSD barycenter is estimated from `X_barycenter` (equal
    weights, with each domain's trials concatenated in time). Then, for every
    domain in `X` and every channel, a filter mapping that domain's PSD onto
    the barycenter is computed and convolved with each trial.

    Parameters
    ----------
    X : sequence of K arrays, each of shape (N, C, T)
        Data to transform. The channel count is read from ``X[0]``.
    X_barycenter : sequence of Kb arrays, each of shape (N, C, T)
        Data used to compute the barycenter.
    filter_size : int
        Size of the filter to compute.
    conv_mode : str
        Mode of the convolution. Can be 'valid', 'same' or 'full'.

    Returns
    -------
    list of K arrays, each of shape (N, C, T')
        The transformed data, one array per domain of `X`. T' depends on
        `conv_mode`.
    """
    n_chan = X[0].shape[1]

    # One PSD barycenter per channel, shared by every domain of X.
    barycenters = [
        get_barycenters(
            [np.concatenate(source[:, chan]) for source in X_barycenter],
            filter_size=filter_size,
        )
        for chan in range(n_chan)
    ]

    transformed = []
    for domain in X:
        filtered_channels = []
        for chan, barycenter in enumerate(barycenters):
            trials = domain[:, chan]
            # The filter is estimated on all trials of this domain, joined in time.
            filt = get_filter(np.concatenate(trials), barycenter, filter_size)
            filtered = convolution(trials, filt, mode=conv_mode)
            filtered_channels.append(np.expand_dims(filtered, axis=1))
        transformed.append(np.concatenate(filtered_channels, axis=1))

    return transformed
