from setuptools import setup, find_packages

# The README doubles as the long description rendered on PyPI. Decode it as
# UTF-8 explicitly: without `encoding`, open() falls back to the platform's
# preferred encoding (e.g. cp1252 on Windows), which can fail on non-ASCII
# characters in the README.
with open("README.md", "r", encoding="utf-8") as f:
    readme = f.read()

# Runtime dependencies installed alongside the package (unpinned)
requirements = [
    "numpy", "scikit-learn", "pandas", "scipy", "matplotlib", "nibabel",
    "joblib",
]

setup(
    name="brainsmash",
    version="0.11.0",
    author="Joshua Burt",
    author_email="joshua.burt@yale.edu",
    include_package_data=True,
    description="Brain Surrogate Maps with Autocorrelated Spatial Heterogeneity.",
    long_description=readme,
    long_description_content_type="text/markdown",
    url="https://github.com/murraylab/brainsmash",
    packages=find_packages(),
    install_requires=requirements,
    python_requires='>=3',
    classifiers=[
        "Programming Language :: Python :: 3.7",
        "License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
    ],
)


from os.path import join, dirname, abspath

__all__ = ['kernels', 'parcel_labels_lr']

# Kernel names accepted by mapgen.kernels.check_kernel. Every name listed here
# must match a function of the same name defined in mapgen/kernels.py.
kernels = ['exp', 'gaussian', 'invdist', 'uniform']

# Directory layout, resolved relative to this file:
#   <repo_root>/brainsmash/data
repo_root = dirname(dirname(abspath(__file__)))
package_root = join(repo_root, "brainsmash")
data = join(package_root, "data")

# Default CIFTI dense-label file. It is used to identify the MNI coordinates
# of each subcortical voxel and to map CIFTI indices onto gross anatomical
# structures.
parcel_labels_lr = join(
    data,
    "CortexSubcortex_ColeAnticevic_NetPartition_"
    "wSubcorGSR_netassignments_LR.dlabel.nii",
)


# Package version string (matches the version declared in setup.py)
__version__ = "0.11.0"


"""
Generate spatial autocorrelation-preserving surrogate maps.
"""
from .kernels import check_kernel
from ..utils.checks import check_map, check_distmat, check_deltas, check_pv
from ..utils.dataio import dataio
from sklearn.utils.validation import check_random_state
import numpy as np
from joblib import Parallel, delayed

# Hard upper bound on the number of surrogate maps generated per batch by
# Base.__call__; larger requests are clamped to this value.
MAX_ALLOWABLE_BATCH_SIZE = 500

# Public API of this module
__all__ = [
    'Base',
]


class Base:
    """
    Base implementation of map generator.

    Parameters
    ----------
    x : (N,) np.ndarray or filename
        Target brain map
    D : (N,N) np.ndarray or filename
        Pairwise distance matrix
    deltas : np.ndarray or List[float], default [0.1,0.2,...,0.9]
        Proportion of neighbors to include for smoothing, in (0, 1]
    kernel : str, default 'exp'
        Kernel with which to smooth permuted maps:
          'gaussian' : Gaussian function.
          'exp' : Exponential decay function.
          'invdist' : Inverse distance.
          'uniform' : Uniform weights (distance independent).
    pv : int, default 25
        Percentile of the pairwise distance distribution at which to
        truncate during variogram fitting
    nh : int, default 25
        Number of uniformly spaced distances at which to compute variogram
    resample : bool, default False
        Resample surrogate maps' values from target brain map
    b : float or None, default None
        Gaussian kernel bandwidth for variogram smoothing. If None, set to
        three times the spacing between variogram x-coordinates.
    seed : None or int or np.random.RandomState instance (default None)
        Specify the seed for random number generation (or random state instance)
    n_jobs : int (default 1)
        Number of jobs to use for parallelizing creation of surrogate maps

    Notes
    -----
    Passing resample=True preserves the distribution of values in the target
    map, with the possibility of worsening the simulated surrogate maps'
    variograms fits.

    """

    def __init__(self, x, D, deltas=np.linspace(0.1, 0.9, 9),
                 kernel='exp', pv=25, nh=25, resample=False, b=None,
                 seed=None, n_jobs=1):

        self._rs = check_random_state(seed)
        self._n_jobs = n_jobs

        # Assignment order matters: the D setter validates D against the size
        # of x, and the b setter (further below) reads self._h
        self.x = x
        self.D = D
        n = self._x.size
        self.resample = resample
        self.nh = nh
        self.deltas = deltas
        self.pv = pv
        self.nmap = n
        self.kernel = kernel  # Smoothing kernel selection
        self._ikn = np.arange(n)[:, None]
        self._triu = np.triu_indices(self._nmap, k=1)  # upper triangular inds
        self._u = self._D[self._triu]  # variogram X-coordinate

        # Get indices of pairs with u < pv'th percentile
        self._uidx = np.where(self._u < np.percentile(self._u, self._pv))[0]
        self._uisort = np.argsort(self._u[self._uidx])

        # Row-wise argsort of the distance matrix; column 0 is skipped below
        # to prevent self-coupling (assumes each area is its own nearest
        # neighbour, i.e. a zero diagonal)
        self._disort = np.argsort(self._D, axis=-1)
        self._jkn = dict.fromkeys(deltas)
        self._dkn = dict.fromkeys(deltas)
        for delta in deltas:
            k = int(delta*n)
            # find index of k nearest neighbors for each area
            self._jkn[delta] = self._disort[:, 1:k+1]  # prevent self-coupling
            # find distance to k nearest neighbors for each area
            self._dkn[delta] = self._D[(self._ikn, self._jkn[delta])]

        # Distances at which the smoothed variogram is evaluated, then the
        # smoothing bandwidth (whose default is derived from self._h)
        utrunc = self._u[self._uidx]
        self._h = np.linspace(utrunc.min(), utrunc.max(), self._nh)
        self.b = b
        self._smvar = self.compute_smooth_variogram(self._x)

    def __call__(self, n=1, batch_size=1):
        """
        Randomly generate new surrogate map(s).

        Parameters
        ----------
        n : int, default 1
            Number of surrogate maps to randomly generate
        batch_size : int or 'max', default 1
            If generating n > 1 surrogates, how many to generate with each
            batch. An ideal batch_size for computation / memory tradeoffs seems
            to be around ~100. Values above MAX_ALLOWABLE_BATCH_SIZE, or the
            string 'max', are clamped to MAX_ALLOWABLE_BATCH_SIZE.

        Returns
        -------
        (n,N) np.ndarray
            Randomly generated map(s) with matched spatial autocorrelation.
            The result is squeezed, so n=1 yields an (N,) array.

        Raises
        ------
        ValueError : `n` or `batch_size` is smaller than 1

        Notes
        -----
        Chooses a level of smoothing that produces a smoothed variogram which
        best approximates the true smoothed variogram. Selecting resample='True'
        preserves the original map's value distribution at the expense of
        worsening the surrogate maps' variogram fit.

        """
        batches = self._batch_sizes(n, batch_size, MAX_ALLOWABLE_BATCH_SIZE)

        # One independent integer seed per batch so that batches (possibly run
        # in parallel workers) draw different random numbers
        seeds = self._rs.randint(np.iinfo(np.int32).max, size=len(batches))
        surrs = np.vstack(
            Parallel(self._n_jobs)(
                delayed(self._call_method)(i=size, rs=seed)
                for size, seed in zip(batches, seeds)
            )
        )
        return np.asarray(surrs.squeeze())

    @staticmethod
    def _batch_sizes(n, batch_size, max_batch_size):
        """ Split `n` requested surrogates into a list of batch sizes. """
        # Check for the 'max' sentinel before any numeric comparison: in
        # Python 3, comparing a str with an int raises TypeError
        if isinstance(batch_size, str) and batch_size == 'max':
            batch_size = max_batch_size
        elif batch_size > max_batch_size:  # hard limit
            batch_size = max_batch_size

        if n < 1:
            raise ValueError(
                "number of surrogates n must be >= 1, got {}".format(n))
        if batch_size < 1:
            raise ValueError(
                "batch_size must be >= 1, got {}".format(batch_size))

        # Full batches first, then a final partial batch for any remainder
        batches = [batch_size] * (n // batch_size)
        if n % batch_size != 0:
            batches.append(n % batch_size)
        return batches

    def _call_method(self, i=1, rs=None):
        """
        Generate one batch of `i` surrogate maps (used by .__call__() for
        parallelization purposes).

        Parameters
        ----------
        i : int, default 1
            Number of surrogate maps in this batch
        rs : None or int or np.random.RandomState instance
            Seed used to reset this generator's random state for the batch

        Returns
        -------
        (i,N) np.ndarray
            Surrogate maps, one per row

        """

        # Reset RandomState so parallel jobs yield different results
        self._rs = check_random_state(rs)

        xperm = self.permute_map(i)  # Randomly permute values, (N, i)

        # ndarray view of deltas: `deltas` may be a plain list, which cannot be
        # fancy-indexed by the array of best-fit indices below
        deltas = np.asarray(self._deltas)
        res = dict.fromkeys(self._deltas)

        for delta in self._deltas:  # foreach neighborhood size
            # Smooth the permuted map using delta proportion of
            # neighbors to reintroduce spatial autocorrelation
            sm_xperm = self.smooth_map(xperm, delta)

            # Calculate smoothed variogram of the smoothed permuted map
            smvar_perm = self.compute_smooth_variogram(sm_xperm)

            # Fit linear regression btwn smoothed variograms
            res[delta] = self.regress(smvar_perm, self._smvar)

        # Each of alphas/betas/residuals has shape (n_deltas, i)
        alphas, betas, residuals = np.array(
            [res[d] for d in self._deltas], dtype=float).transpose(1, 0, 2)

        # Select best-fit model and regression parameters per surrogate
        iopt = np.argmin(residuals, axis=0)[None]  # (1, i)
        dopt = deltas[iopt][0]  # best delta for each surrogate, (i,)
        aopt = np.take_along_axis(alphas, iopt, 0).squeeze()
        bopt = np.take_along_axis(betas, iopt, 0).squeeze()

        # Transform and smooth permuted map using best-fit parameters; each
        # surrogate may have a different best delta, so smooth column-wise
        sm_xperm_best = np.column_stack([
            self.smooth_map(x[:, None], d) for x, d in zip(xperm.T, dopt)
        ])
        surr = (np.sqrt(np.abs(bopt)) * sm_xperm_best +
                np.sqrt(np.abs(aopt)) * self._rs.randn(self._nmap, i))

        if self._resample:  # resample values from empirical map
            sorted_map = np.sort(self._x)[:, None]
            ii = np.argsort(surr, axis=0)
            np.put_along_axis(surr, ii, sorted_map, axis=0)
        else:
            surr -= surr.mean(axis=0)  # De-mean

        return np.asarray(surr).T

    def compute_smooth_variogram(self, x, return_h=False):
        """
        Compute smoothed variogram values (1/2 squared pairwise differences)

        Parameters
        ----------
        x : (N,) or (N,k) np.ndarray
            Brain map scalar array; with two dimensions, each column is
            treated as a separate map
        return_h : bool, default False
            Return distances at which the smoothed variogram was computed

        Returns
        -------
        (self.nh,) or (self.nh,k) np.ndarray
            Smoothed variogram values (singleton dimensions are squeezed)
        (self.nh,) np.ndarray
            Distances at which smoothed variogram was computed (returned only if
            `return_h` is True)

        """
        if x.ndim < 2:
            x = x[..., None]

        # Half squared differences for pairs below the pv'th distance percentile
        diff_ij = x[self._triu[1][self._uidx]] - x[self._triu[0][self._uidx]]
        v = 0.5 * np.square(diff_ij)
        u = self._u[self._uidx]
        if len(u) != len(v):
            raise ValueError(
                "argument v: expected size {}, got {}".format(len(u), len(v)))
        # Subtract each h from each pairwise distance u
        # Each row corresponds to a unique h
        du = np.abs(u - self._h[:, None])
        # Gaussian smoothing weights of bandwidth self._b
        w = np.exp(-np.square(2.68 * du / self._b) / 2)
        output = np.squeeze(np.dot(w, v) / np.nansum(w, axis=1)[:, None])
        if not return_h:
            return output
        return output, self._h

    def permute_map(self, i=1):
        """
        Return randomly permuted copies of the brain map.

        Parameters
        ----------
        i : int, default 1
            Number of independent permutations to draw

        Returns
        -------
        (N,i) np.ma.MaskedArray
            Random permutations of target brain map, one per column; the
            NaN mask is permuted along with the data

        """
        perm_idx = self._rs.random_sample((self._x.size, i)).argsort(axis=0)
        mask_perm = self._x.mask[perm_idx]
        x_perm = self._x.data[perm_idx]
        return np.ma.masked_array(data=x_perm, mask=mask_perm)

    def smooth_map(self, x, delta):
        """
        Smooth `x` using `delta` proportion of nearest neighbors.

        Parameters
        ----------
        x : (N,) or (N,k) np.ndarray
            Brain map scalars; with two dimensions, each column is smoothed
            independently
        delta : float
            Proportion of neighbors to include for smoothing, in (0, 1]. Must
            be one of the deltas this generator was constructed with.

        Returns
        -------
        (N,) or (N,k) np.ndarray
            Smoothed brain map, with the same number of dimensions as `x`

        """
        # Accept a single 1-D map by treating it as one column
        one_dim = x.ndim == 1
        if one_dim:
            x = x[:, None]

        weights = self._kernel(self._dkn[delta])
        weights /= weights.sum(axis=1, keepdims=True)  # rows sum to one

        # iterate over the lesser of the dimensions for this comprehension
        if weights.shape[1] > x.shape[1]:
            smoothed = np.sum([
                weights[:, [n]] * x[self._jkn[delta][:, n]]
                for n in range(weights.shape[1])
            ], axis=0)
        else:
            smoothed = np.column_stack([
                np.sum(weights * xp[self._jkn[delta]], axis=1) for xp in x.T
            ])

        return smoothed[:, 0] if one_dim else smoothed

    def regress(self, x, y):
        """
        Linearly regress `x` onto `y`.

        Parameters
        ----------
        x : (M,) or (M,k) np.ndarray
            Independent variable(s); each column is fit separately
        y : (M,) np.ndarray
            Dependent variable

        Returns
        -------
        alpha : (k,) np.ndarray
            Intercept term (offset parameter)
        beta : (k,) np.ndarray
            Regression coefficient (scale parameter)
        res : (k,) np.ndarray
            Sum of squared residuals

        Raises
        ------
        ValueError : `y` has more than one dependent variable

        """
        if x.ndim < 2:
            x = x[..., None]
        if y.ndim < 2:
            y = y[..., None]
        if y.squeeze().ndim > 1:
            raise ValueError('Provided `y` has multiple dependent variables')

        # Closed-form ordinary least squares for each column of x
        num = (x * y).sum(axis=0) - ((x.sum(axis=0) * y.sum()) / len(x))
        denom = (x ** 2).sum(axis=0) - ((np.sum(x, axis=0) ** 2) / len(x))
        beta = num / denom
        alpha = y.mean() - (beta * x.mean(axis=0))
        res = np.sum((y - ((x * beta) + alpha)) ** 2, axis=0)

        return alpha, beta, res

    @property
    def x(self):
        """ (N,) np.ndarray : brain map scalar array """
        return self._x

    @x.setter
    def x(self, x):
        x_ = dataio(x)
        check_map(x=x_)
        # NaN entries are masked out
        brain_map = np.ma.masked_array(data=x_, mask=np.isnan(x_))
        self._x = brain_map

    @property
    def D(self):
        """ (N,N) np.ndarray : Pairwise distance matrix """
        return self._D

    @D.setter
    def D(self, x):
        x_ = dataio(x)
        check_distmat(D=x_)
        n = self._x.size
        if x_.shape != (n, n):
            e = "Distance matrix must have dimensions consistent with brain map"
            e += "\nDistance matrix shape: {}".format(x_.shape)
            e += "\nBrain map size: {}".format(n)
            raise ValueError(e)
        self._D = x_

    @property
    def nmap(self):
        """ int : length of brain map """
        return self._nmap

    @nmap.setter
    def nmap(self, x):
        self._nmap = int(x)

    @property
    def pv(self):
        """ int : percentile of pairwise distances at which to truncate """
        return self._pv

    @pv.setter
    def pv(self, x):
        pv = check_pv(x)
        self._pv = pv

    @property
    def deltas(self):
        """ np.ndarray or List[float] : proportions of nearest neighbors """
        return self._deltas

    @deltas.setter
    def deltas(self, x):
        check_deltas(deltas=x)
        self._deltas = x

    @property
    def nh(self):
        """ int : number of variogram distance intervals """
        return self._nh

    @nh.setter
    def nh(self, x):
        self._nh = x

    @property
    def h(self):
        """ np.ndarray : distances at which smoothed variogram is computed """
        return self._h

    @property
    def kernel(self):
        """ Callable : smoothing kernel function

        Notes
        -----
        When setting kernel, use name of kernel as defined in ``config.py``.

        """
        return self._kernel

    @kernel.setter
    def kernel(self, x):
        kernel_callable = check_kernel(x)
        self._kernel = kernel_callable

    @property
    def resample(self):
        """ bool : whether to resample surrogate maps from target map """
        return self._resample

    @resample.setter
    def resample(self, x):
        if not isinstance(x, bool):
            e = "parameter `resample`: expected bool, got {}".format(type(x))
            raise TypeError(e)
        self._resample = x

    @property
    def b(self):
        """ numeric : Gaussian kernel bandwidth """
        return self._b

    @b.setter
    def b(self, x):
        if x is not None:
            try:
                self._b = float(x)
            except (ValueError, TypeError):
                e = "bandwidth b: expected numeric, got {}".format(type(x))
                raise ValueError(e)
        else:   # set bandwidth equal to 3x bin spacing
            self._b = 3.*(self._h[1] - self._h[0])


""" Evaluation metrics for randomly generated surrogate maps. """

from .base import Base
from .sampled import Sampled
from ..utils.dataio import dataio
import matplotlib.pyplot as plt
import numpy as np

# Public API of this module: variogram-fit evaluation helpers
__all__ = [
    'base_fit',
    'sampled_fit',
]


def base_fit(x, D, nsurr=100, return_data=False, **params):
    """
    Evaluate variogram fits for Base class.

    Parameters
    ----------
    x : (N,) np.ndarray or filename
        Target brain map
    D : (N,N) np.ndarray or filename
        Pairwise distance matrix between regions in `x`
    nsurr : int, default 100
        Number of simulated surrogate maps from which to compute variograms
    return_data : bool, default False
        if True, return: 1, the smoothed variogram values for the target
        brain map; 2, the distances at which the smoothed variograms values
        were computed; and 3, the surrogate maps' smoothed variogram values
    params
        Keyword arguments for :class:`brainsmash.mapgen.base.Base`

    Returns
    -------
    if and only if return_data is True:
    emp_var : (M,) np.ndarray
        empirical smoothed variogram values
    u0 : (M,) np.ndarray
        distances at which variogram values were computed
    surr_var : (nsurr, M) np.ndarray
        surrogate maps' smoothed variogram values

    Notes
    -----
    If `return_data` is False, this function generates and shows a matplotlib
    plot instance illustrating the fit of the surrogates' variograms to the
    target map's variogram. If `return_data` is True, this function returns the
    data needed to generate such a plot (i.e., the variogram values and the
    corresponding distances).

    """

    x = dataio(x)
    d = dataio(D)

    # Instantiate surrogate map generator
    generator = Base(x=x, D=d, **params)

    # Simulate surrogate maps. Base.__call__ squeezes its output, so a single
    # surrogate (nsurr=1) comes back 1-D; np.atleast_2d restores the
    # (nsurr, N) layout so that surrogate_maps[i] is a whole map, not a scalar
    surrogate_maps = np.atleast_2d(generator(n=nsurr))

    # Compute empirical variogram
    emp_var, u0 = generator.compute_smooth_variogram(x, return_h=True)

    # Compute surrogate map variograms, one row per surrogate
    surr_var = np.empty((nsurr, generator.nh))
    for i in range(nsurr):
        surr_var[i] = generator.compute_smooth_variogram(surrogate_maps[i])

    if return_data:
        return emp_var, u0, surr_var

    # # Create plot for visual comparison

    # Plot empirical variogram
    fig = plt.figure(figsize=(3, 3))
    ax = fig.add_axes([0.12, 0.15, 0.8, 0.77])
    ax.autoscale(axis='y', tight=True)

    ax.scatter(u0, emp_var, s=20, facecolor='none', edgecolor='k',
               marker='o', lw=1, label='Empirical')

    # Plot surrogate maps' variograms as mean +/- one standard deviation
    mu = surr_var.mean(axis=0)
    sigma = surr_var.std(axis=0)
    ax.fill_between(u0, mu-sigma, mu+sigma, facecolor='#377eb8',
                    edgecolor='none', alpha=0.3)
    ax.plot(u0, mu, color='#377eb8', label='SA-preserving', lw=1)

    # Make plot nice
    leg = ax.legend(loc=0)
    leg.get_frame().set_linewidth(0.0)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    plt.setp(ax.get_yticklabels(), visible=False)
    plt.setp(ax.get_yticklines(), visible=False)
    plt.setp(ax.get_xticklabels(), visible=False)
    plt.setp(ax.get_xticklines(), visible=False)
    ax.set_xlabel("Spatial separation\ndistance")
    ax.set_ylabel("Variance")

    plt.show()


def sampled_fit(x, D, index, nsurr=10, return_data=False, **params):
    """
    Evaluate variogram fits for Sampled class.

    Parameters
    ----------
    x : (N,) np.ndarray
        Target brain map
    D : (N,N) np.ndarray or np.memmap
        Pairwise distance matrix between elements of `x`
    index : (N,N) np.ndarray or np.memmap
        See :class:`brainsmash.mapgen.sampled.Sampled`
    nsurr : int, default 10
        Number of simulated surrogate maps from which to compute variograms
    return_data : bool, default False
        if True, return: 1, the smoothed variogram values for the target
        brain map; 2, the distances at which the smoothed variograms values
        were computed; and 3, the surrogate maps' smoothed variogram values
    params
        Keyword arguments for :class:`brainsmash.mapgen.sampled.Sampled`

    Returns
    -------
    if and only if return_data is True:
    emp_var : (M,) np.ndarray
        empirical smoothed variogram values
    u0 : (M,) np.ndarray
        distances at which variogram values were computed
    surr_var : (nsurr, M) np.ndarray
        surrogate maps' smoothed variogram values

    Notes
    -----
    If `return_data` is False, this function generates and shows a matplotlib
    plot instance illustrating the fit of the surrogates' variograms to the
    target map's variogram. If `return_data` is True, this function returns the
    data needed to generate such a plot (i.e., the variogram values and the
    corresponding distances).

    """

    # Instantiate surrogate map generator
    generator = Sampled(x=x, D=D, index=index, **params)

    # Simulate surrogate maps. np.atleast_2d keeps one row per surrogate even
    # if the generator squeezes a single map (nsurr=1) down to 1-D; it is a
    # no-op for an (nsurr, N) result.
    # NOTE(review): confirm Sampled.__call__'s return shape for n=1.
    surrogate_maps = np.atleast_2d(generator(n=nsurr))

    # Compute target & surrogate map variograms. The empirical variogram is
    # averaged over nsurr random subsamples of brain areas.
    surr_var = np.empty((nsurr, generator.nh))
    emp_var_samples = np.empty((nsurr, generator.nh))
    u0_samples = np.empty((nsurr, generator.nh))
    for i in range(nsurr):
        idx = generator.sample()  # Randomly sample a subset of brain areas
        v = generator.compute_variogram(generator.x, idx)
        u = generator.D[idx, :]
        # Keep only pairs below the pv'th percentile of sampled distances
        umax = np.percentile(u, generator.pv)
        uidx = np.where(u < umax)
        emp_var_i, u0i = generator.smooth_variogram(
            u=u[uidx], v=v[uidx], return_h=True)
        emp_var_samples[i], u0_samples[i] = emp_var_i, u0i
        # Surrogate
        v_null = generator.compute_variogram(surrogate_maps[i], idx)
        surr_var[i] = generator.smooth_variogram(
            u=u[uidx], v=v_null[uidx], return_h=False)

    u0 = u0_samples.mean(axis=0)
    emp_var = emp_var_samples.mean(axis=0)

    if return_data:
        return emp_var, u0, surr_var

    # Plot target variogram
    fig = plt.figure(figsize=(3, 3))
    ax = fig.add_axes([0.12, 0.15, 0.8, 0.77])
    ax.scatter(u0, emp_var, s=20, facecolor='none', edgecolor='k',
               marker='o', lw=1, label='Empirical')

    # Plot surrogate maps' variograms as mean +/- one standard deviation
    mu = surr_var.mean(axis=0)
    sigma = surr_var.std(axis=0)
    ax.fill_between(u0, mu-sigma, mu+sigma, facecolor='#377eb8',
                    edgecolor='none', alpha=0.3)
    ax.plot(u0, mu, color='#377eb8', label='SA-preserving', lw=1)

    # Make plot nice
    leg = ax.legend(loc=0)
    leg.get_frame().set_linewidth(0.0)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    plt.setp(ax.get_yticklabels(), visible=False)
    plt.setp(ax.get_yticklines(), visible=False)
    plt.setp(ax.get_xticklabels(), visible=False)
    plt.setp(ax.get_xticklines(), visible=False)
    ax.set_xlabel("Spatial separation\ndistance")
    ax.set_ylabel("Variance")
    plt.show()


""" Kernels used to smooth randomly permuted brain maps.
"""

import numpy as np
from ..config import kernels

# Public API: the smoothing kernels plus the name-based kernel lookup
__all__ = [
    'gaussian',
    'exp',
    'invdist',
    'uniform',
    'check_kernel',
]


def gaussian(d):
    """
    Gaussian kernel which truncates at one standard deviation.

    Distances are divided by their maximum (per row for 2-D input, globally
    for 1-D input) before the Gaussian profile exp(-1.25 * r**2) is applied.

    Parameters
    ----------
    d : (N,) or (M,N) np.ndarray
        one- or two-dimensional array of distances

    Returns
    -------
    (N,) or (M,N) np.ndarray
        Gaussian kernel weights

    Raises
    ------
    TypeError : `d` is not array_like

    """
    try:
        # Two-dimensional input: scale each row by its own maximum. For 1-D
        # input the per-row maximum is a scalar and indexing it raises
        # IndexError, which routes to the branch below.
        scaled = d / d.max(axis=-1)[:, np.newaxis]
        return np.exp(-1.25 * np.square(scaled))
    except IndexError:
        scaled = d / d.max()
        return np.exp(-1.25 * np.square(scaled))
    except AttributeError:
        raise TypeError("expected array_like, got {}".format(type(d)))


def exp(d):
    """
    Exponentially decaying kernel which truncates at e^{-1}.

    Parameters
    ----------
    d : (N,) or (M,N) np.ndarray
        one- or two-dimensional array of distances

    Returns
    -------
    (N,) or (M,N) np.ndarray
        Exponential kernel weights

    Notes
    -----
    Characteristic length scale is set to d.max(axis=-1), i.e. the maximum
    distance within each row, so the largest distance in each row receives
    weight e^{-1}.

    Raises
    ------
    TypeError : `d` is not array_like

    """
    try:
        # Two-dimensional input: per-row maximum as a column for broadcasting.
        # A 1-D input yields a scalar here, whose indexing raises IndexError.
        length_scale = d.max(axis=-1)[:, np.newaxis]
        return np.exp(-d / length_scale)
    except IndexError:
        length_scale = d.max()
        return np.exp(-d / length_scale)
    except AttributeError:
        raise TypeError("expected array_like, got {}".format(type(d)))


def invdist(d):
    """
    Inverse distance kernel.

    Parameters
    ----------
    d : (N,) or (M,N) np.ndarray
        One- or two-dimensional array of distances

    Returns
    -------
    (N,) or (M,N) np.ndarray
        Inverse distance, i.e. d^{-1}

    Raises
    ------
    ZeroDivisionError : `d` includes zero value
    TypeError : `d` is not array_like

    """
    # NumPy division by zero returns inf (with a RuntimeWarning) rather than
    # raising, so check explicitly to honour the documented contract. An
    # infinite weight would otherwise turn row-normalised smoothing weights
    # (weights / weights.sum) into NaN.
    if np.any(np.asarray(d) == 0):
        raise ZeroDivisionError(
            "inverse distance kernel received a zero distance")
    try:
        return 1. / d
    except AttributeError:
        raise TypeError("expected array_like, got {}".format(type(d)))


def uniform(d):
    """
    Uniform (i.e., distance independent) kernel.

    Parameters
    ----------
    d : (N,) or (M,N) np.ndarray
        One- or two-dimensional array of distances

    Returns
    -------
    (N,) or (M,N) np.ndarray
        Uniform kernel weights

    Notes
    -----
    Every element equals 1/N, where N is the length of the last axis, so
    each row sums to unity. A 0-d input yields a length-1 array of ones.

    Raises
    ------
    TypeError : `d` is not array_like

    """
    try:
        dims = d.shape
    except AttributeError:
        raise TypeError("expected array_like, got {}".format(type(d)))

    if not dims:  # 0-d input has no last axis to normalise over
        return np.ones(d.size) / d.size
    return np.ones(dims) / dims[-1]


def check_kernel(kernel):
    """
    Resolve a kernel name to the kernel function defined in this module.

    Parameters
    ----------
    kernel : 'exp' or 'gaussian' or 'invdist' or 'uniform'
        Kernel selection

    Returns
    -------
    Callable
        The module-level function whose name equals `kernel`

    Notes
    -----
    If `kernel` is included in ``config.py``, a function with the same name must
    be defined in ``mapgen.kernels.py``.

    Raises
    ------
    NotImplementedError : `kernel` is not included in `config.py`

    """
    if kernel in kernels:
        # Names in config.kernels double as function names in this module
        return globals()[kernel]

    message = "'{}' is not a valid kernel\n".format(kernel)
    message += "Valid kernels: {}".format(", ".join(kernels))
    raise NotImplementedError(message)


"""
Convert large data files written to disk to memory-mapped arrays for memory-
efficient data retrieval.
"""
from ..utils.dataio import dataio
from ..utils.checks import count_lines
import numpy.lib.format
from os import path
import numpy as np

# Public API: text-to-memmap conversion and memmap loading
__all__ = [
    'txt2memmap',
    'load_memmap',
]


def txt2memmap(dist_file, output_dir, maskfile=None, delimiter=' '):
    """
    Export distance matrix to memory-mapped array.

    Parameters
    ----------
    dist_file : filename
        Path to `delimiter`-separated distance matrix file
    output_dir : filename
        Path to directory in which output files will be written
    maskfile : filename or np.ndarray or None, default None
        Path to a neuroimaging/txt file containing a mask, or a mask
        represented as a numpy array. Mask scalars are cast to boolean, and
        all elements not equal to zero will be masked.
    delimiter : str
        Delimiting character in `dist_file`

    Returns
    -------
    dict
        Keys are 'distmat' and 'index'; values are the paths (joined onto
        `output_dir`) of the corresponding .npy files on disk.

    Notes
    -----
    Each row of the distance matrix is sorted before writing to file. Thus, a
    second mem-mapped array is necessary, the i-th row of which contains
    argsort(d[i]).
    If `maskfile` is not None, a text file mask.txt containing the mask as
    0/1 integers will also be written to the output directory.

    Raises
    ------
    IOError : `output_dir` doesn't exist
    ValueError : Mask image and distance matrix have inconsistent sizes
    RuntimeError : A kept row of `dist_file` does not have one value per line
        of the file, or fewer non-empty rows were found than expected

    """

    nlines = count_lines(dist_file)
    if not path.exists(output_dir):
        raise IOError("Output directory does not exist: {}".format(output_dir))

    # Load mask if one was provided
    if maskfile is not None:
        mask = dataio(maskfile).astype(bool)
        if mask.size != nlines:
            e = "Incompatible input sizes\n"
            e += "{} rows in {}\n".format(nlines, dist_file)
            e += "{} elements in {}".format(mask.size, maskfile)
            raise ValueError(e)
        mask_fileout = path.join(output_dir, "mask.txt")
        np.savetxt(  # Write to text file
            fname=mask_fileout, X=mask.astype(int), fmt="%i", delimiter=',')
        nv = int((~mask).sum())  # number of non-masked elements
        idx = np.arange(nlines)[~mask]  # indices of non-masked elements
    else:
        nv = nlines
        idx = np.arange(nlines)

    # Line numbers to keep, as a set: `il in idx` on an ndarray scans the
    # whole array for every line of the file, which is quadratic overall
    keep = set(idx.tolist())

    npydfile = path.join(output_dir, "distmat.npy")
    npyifile = path.join(output_dir, "index.npy")

    # Build memory-mapped arrays
    with open(dist_file, 'r') as fp:

        fpd = numpy.lib.format.open_memmap(
            npydfile, mode='w+', dtype=np.float32, shape=(nv, nv))
        fpi = numpy.lib.format.open_memmap(
            npyifile, mode='w+', dtype=np.int32, shape=(nv, nv))

        ifp = 0  # Build memory-mapped arrays one row of distances at a time
        for il, l in enumerate(fp):  # Loop over lines of file
            if il not in keep:  # Skip rows of masked-out elements
                continue
            line = l.rstrip()
            if not line:
                continue
            data = np.array(line.split(delimiter), dtype=np.float32)
            if data.size != nlines:
                raise RuntimeError(
                    "Distance matrix is not square: {}".format(dist_file))
            d = data[idx]
            sort_idx = np.argsort(d)
            fpd[ifp, :] = d[sort_idx]  # sorted row of distances
            fpi[ifp, :] = sort_idx  # sort indexes
            ifp += 1

        # Write memory changes to disk, then release the maps
        fpd.flush()
        fpi.flush()
        del fpd
        del fpi

    # A blank kept row would otherwise leave trailing all-zero rows behind
    if ifp != nv:
        raise RuntimeError(
            "Expected {} rows of distances in {}, found {}".format(
                nv, dist_file, ifp))

    return {'distmat': npydfile, 'index': npyifile}  # Return filenames


def load_memmap(filename):
    """
    Open an ``.npy`` file as a read-only memory-mapped array.

    Parameters
    ----------
    filename : str
        path to memory-mapped array saved as npy file

    Returns
    -------
    np.memmap
        Read-only view of the on-disk array; data is paged in lazily

    """
    memmapped = np.load(file=filename, mmap_mode='r')
    return memmapped


"""
Generate spatial autocorrelation-preserving surrogate maps from memory-mapped
arrays with random subsampling.
"""
from ..utils.dataio import dataio
from ..utils.checks import check_map, check_pv, check_deltas
from .kernels import check_kernel
from sklearn.linear_model import LinearRegression
from sklearn.utils.validation import check_random_state
import numpy as np
from joblib import Parallel, delayed

# Public API of this module
__all__ = [
    'Sampled',
]


class Sampled:
    """
    Sampling implementation of map generator.

    Parameters
    ----------
    x : filename or 1D np.ndarray
        Target brain map
    D : filename or (N,N) np.ndarray or np.memmap
        Pairwise distance matrix between elements of `x`. Each row of `D` should
        be sorted. Indices used to sort each row are passed to the `index`
        argument. See :func:`brainsmash.mapgen.memmap.txt2memmap` or the online
        documentation for more details (brainsmash.readthedocs.io)
    index : filename or (N,N) np.ndarray or np.memmap
        See above
    ns : int, default 500
        Take a subsample of `ns` rows from `D` when fitting variograms
    deltas : np.ndarray or List[float], default [0.3, 0.5, 0.7, 0.9]
        Proportions of neighbors to include for smoothing, in (0, 1]
    kernel : str, default 'exp'
        Kernel with which to smooth permuted maps
        - 'gaussian' : gaussian function
        - 'exp' : exponential decay function
        - 'invdist' : inverse distance
        - 'uniform' : uniform weights (distance independent)
    pv : int, default 70
        Percentile of the pairwise distance distribution (in `D`) at
        which to truncate during variogram fitting
    nh : int, default 25
        Number of uniformly spaced distances at which to compute variogram
    knn : int, default 1000
        Number of nearest regions to keep in the neighborhood of each region
    b : float or None, default None
        Gaussian kernel bandwidth for variogram smoothing. if None,
        three times the distance interval spacing is used.
    resample : bool, default False
        Resample surrogate map values from the target brain map
    verbose : bool, default False
        Print surrogate count each time new surrogate map created
    seed : None or int or np.random.RandomState instance (default None)
        Specify the seed for random number generation (or random state instance)
    n_jobs : int (default 1)
        Number of jobs to use for parallelizing creation of surrogate maps

    Notes
    -----
    Passing resample=True will preserve the distribution of values in the
    target map, at the expense of worsening simulated surrogate maps'
    variograms fits. This worsening will increase as the empirical map
    more strongly deviates from normality.

    Only columns 1..knn of each row of `D` and `index` are retained (column 0,
    the element itself in a row-sorted matrix, is dropped).

    Raises
    ------
    ValueError : `x` and `D` have inconsistent sizes

    """

    def __init__(self, x, D, index, ns=500, pv=70, nh=25, knn=1000, b=None,
                 deltas=np.arange(0.3, 1., 0.2), kernel='exp', resample=False,
                 verbose=False, seed=None, n_jobs=1):

        self._rs = check_random_state(seed)
        self._n_jobs = n_jobs

        self._verbose = verbose
        # Assignment order matters: the `knn` setter reads `self._nmap`, and
        # the `D`/`index` setters read `self._x` and `self._knn`.
        self.x = x
        n = self._x.size
        self.nmap = int(n)
        self.knn = knn
        self.D = D
        self.index = index
        self.resample = resample
        self.nh = int(nh)
        self.deltas = deltas
        self.ns = int(ns)
        self.b = b
        self.pv = pv
        self._ikn = np.arange(self._nmap)[:, None]

        self.kernel = kernel  # Smoothing kernel selection

        # Variogram is evaluated at `nh` distances spanning the smallest
        # retained neighbor distance up to the `pv`-th percentile
        self._dmax = np.percentile(self._D, self._pv)
        self.h = np.linspace(self._D.min(), self._dmax, self._nh)
        if not self._b:  # default bandwidth: 3x the distance spacing
            self.b = 3 * (self.h[1] - self.h[0])

        # Linear regression model
        self._lm = LinearRegression(fit_intercept=True)

    def __call__(self, n=1):
        """
        Randomly generate new surrogate map(s).

        Parameters
        ----------
        n : int, default 1
            Number of surrogate maps to randomly generate

        Returns
        -------
        (n,N) np.ndarray
            Randomly generated map(s) with matched spatial autocorrelation.
            Singleton dimensions are squeezed, so n=1 yields shape (N,).

        Notes
        -----
        Chooses a level of smoothing that produces a smoothed variogram which
        best approximates the true smoothed variogram. Selecting resample=True
        preserves the map value distribution at the expense of worsening the
        surrogate maps' variogram fits.

        """

        # One integer seed per surrogate so each (possibly parallel) job
        # draws from its own random stream
        rs = self._rs.randint(np.iinfo(np.int32).max, size=n)
        # np.vstack: np.row_stack is a deprecated alias (NumPy 2.0)
        surrs = np.vstack(
            Parallel(self._n_jobs)(
                delayed(self._call_method)(rs=i) for i in rs
            )
        )
        return np.asarray(surrs.squeeze())

    def _call_method(self, rs=None):
        """ Subfunction used by .__call__() for parallelization purposes """

        # Reset RandomState so parallel jobs yield different results.
        # NOTE: this replaces `self._rs` on the instance the job runs on.
        self._rs = check_random_state(rs)

        # Randomly permute map
        x_perm = self.permute_map()

        # Randomly select subset of regions to use for variograms
        idx = self.sample()

        # Compute empirical variogram
        v = self.compute_variogram(self._x, idx)

        # Variogram ordinates; use nearest neighbors because local effect
        u = self._D[idx, :]
        uidx = np.where(u < self._dmax)

        # Smooth empirical variogram
        smvar, u0 = self.smooth_variogram(u[uidx], v[uidx], return_h=True)

        res = dict.fromkeys(self._deltas)

        for d in self._deltas:  # foreach neighborhood size

            k = int(d * self._knn)

            # Smooth the permuted map using k nearest neighbors to
            # reintroduce spatial autocorrelation
            sm_xperm = self.smooth_map(x=x_perm, k=k)

            # Calculate variogram values for the smoothed permuted map
            vperm = self.compute_variogram(sm_xperm, idx)

            # Calculate smoothed variogram of the smoothed permuted map
            smvar_perm = self.smooth_variogram(u[uidx], vperm[uidx])

            # Fit linear regression btwn smoothed variograms
            res[d] = self.regress(smvar_perm, smvar)

        alphas, betas, residuals = np.array(
            [res[d] for d in self._deltas], dtype=float).T

        # Select best-fit model and regression parameters
        iopt = np.argmin(residuals)
        dopt = self._deltas[iopt]
        self._dopt = dopt
        kopt = int(dopt * self._knn)
        aopt = alphas[iopt]
        bopt = betas[iopt]

        # Transform and smooth permuted map using best-fit parameters
        sm_xperm_best = self.smooth_map(x=x_perm, k=kopt)
        surr = (np.sqrt(np.abs(bopt)) * sm_xperm_best +
                np.sqrt(np.abs(aopt)) * self._rs.randn(self._nmap))

        if self._resample:  # resample values from empirical map
            sorted_map = np.sort(self._x)
            ii = np.argsort(surr)
            np.put(surr, ii, sorted_map)
        else:
            surr = surr - np.nanmean(surr)  # De-mean

        if self._ismasked:
            return np.ma.masked_array(
                data=surr, mask=np.isnan(surr)).squeeze()
        return surr.squeeze()

    def compute_variogram(self, x, idx):
        """
        Compute variogram of `x` using pairs of regions indexed by `idx`.

        Parameters
        ----------
        x : (N,) np.ndarray
            Brain map
        idx : (ns,) np.ndarray[int]
            Indices of randomly sampled brain regions

        Returns
        -------
        v : (ns,knn) np.ndarray
            Variogram y-coordinates, i.e. 0.5 * (x_i - x_j) ^ 2, for each i in
            `idx` and each j among the retained nearest neighbors of i

        """
        diff_ij = x[idx][:, None] - x[self._index[idx, :]]
        return 0.5 * np.square(diff_ij)

    def permute_map(self):
        """
        Return a random permutation of the target brain map.

        Returns
        -------
        (N,) np.ndarray
            Random permutation of target brain map

        """
        perm_idx = self._rs.permutation(self._nmap)
        if self._ismasked:
            mask_perm = self._x.mask[perm_idx]
            x_perm = self._x.data[perm_idx]
            return np.ma.masked_array(data=x_perm, mask=mask_perm)
        return self._x[perm_idx]

    def smooth_map(self, x, k):
        """
        Smooth `x` using `k` nearest neighboring regions.

        Parameters
        ----------
        x : (N,) np.ndarray
            Brain map
        k : float
            Number of nearest neighbors to include for smoothing

        Returns
        -------
        x_smooth : (N,) np.ndarray
            Smoothed brain map

        Notes
        -----
        Assumes `D` provided at runtime has been sorted.

        """
        jkn = self._index[:, :k]  # indices of k nearest neighbors
        xkn = x[jkn]  # values of k nearest neighbors
        dkn = self._D[:, :k]  # distances to k nearest neighbors
        weights = self._kernel(dkn)  # distance-weighted kernel
        # Kernel-weighted sum
        return (weights * xkn).sum(axis=1) / weights.sum(axis=1)

    def smooth_variogram(self, u, v, return_h=False):
        """
        Smooth a variogram.

        Parameters
        ----------
        u : (N,) np.ndarray
            Pairwise distances, ie variogram x-coordinates
        v : (N,) np.ndarray
            Squared differences, ie variogram y-coordinates
        return_h : bool, default False
            Return distances at which smoothed variogram is computed

        Returns
        -------
        (nh,) np.ndarray
            Smoothed variogram samples
        (nh,) np.ndarray
            Distances at which smoothed variogram was computed (returned if
            `return_h` is True)

        Raises
        ------
        ValueError : `u` and `v` are not identically sized

        """
        if len(u) != len(v):
            raise ValueError("u and v must have same number of elements")

        # Subtract each element of h from each pairwise distance `u`.
        # Each row corresponds to a unique h.
        du = np.abs(u - self._h[:, None])
        w = np.exp(-np.square(2.68 * du / self._b) / 2)
        denom = np.nansum(w, axis=1)
        wv = w * v[None, :]
        num = np.nansum(wv, axis=1)
        output = num / denom
        if not return_h:
            return output
        return output, self._h

    def regress(self, x, y):
        """
        Linearly regress `x` onto `y`.

        Parameters
        ----------
        x : (N,) np.ndarray
            Independent variable
        y : (N,) np.ndarray
            Dependent variable

        Returns
        -------
        alpha : float
            Intercept term (offset parameter)
        beta : float
            Regression coefficient (scale parameter)
        res : float
            Sum of squared residuals

        """
        self._lm.fit(X=np.expand_dims(x, -1), y=y)
        beta = self._lm.coef_.item()
        alpha = self._lm.intercept_
        ypred = self._lm.predict(np.expand_dims(x, -1))
        res = np.sum(np.square(y-ypred))
        return alpha, beta, res

    def sample(self):
        """
        Randomly sample (without replacement) brain areas for variogram
        computation.

        Returns
        -------
        (self.ns,) np.ndarray
            Indices of randomly sampled areas

        """
        return self._rs.choice(
            a=self._nmap, size=self._ns, replace=False).astype(np.int32)

    @property
    def x(self):
        """ (N,) np.ndarray : brain map scalars """
        if self._ismasked:
            return np.ma.copy(self._x)
        return np.copy(self._x)

    @x.setter
    def x(self, x):
        self._ismasked = False
        x_ = dataio(x)
        check_map(x=x_)
        mask = np.isnan(x_)
        if mask.any():  # NaNs are treated as masked (missing) elements
            self._ismasked = True
            brain_map = np.ma.masked_array(data=x_, mask=mask)
        else:
            brain_map = x_
        self._x = brain_map

    @property
    def D(self):
        """ (N,knn) np.ndarray : copy of the retained sorted neighbor
        distances (columns 1..knn of the input distance matrix) """
        return np.copy(self._D)

    @D.setter
    def D(self, x):
        x_ = dataio(x)
        n = self._x.size
        if x_.shape[0] != n:
            raise ValueError(
                "D size along axis=0 must equal brain map size")
        self._D = x_[:, 1:self._knn + 1]  # prevent self-coupling

    @property
    def index(self):
        """ (N,knn) np.ndarray : copy of the retained indexes used to sort each
        row of the distance matrix (columns 1..knn) """
        return np.copy(self._index)

    @index.setter
    def index(self, x):
        x_ = dataio(x)
        n = self._x.size
        if x_.shape[0] != n:
            raise ValueError(
                "index size along axis=0 must equal brain map size")
        self._index = x_[:, 1:self._knn+1].astype(np.int32)

    @property
    def nmap(self):
        """ int : length of brain map """
        return self._nmap

    @nmap.setter
    def nmap(self, x):
        self._nmap = int(x)

    @property
    def pv(self):
        """ int : percentile of pairwise distances at which to truncate """
        return self._pv

    @pv.setter
    def pv(self, x):
        pv = check_pv(x)
        self._pv = pv

    @property
    def deltas(self):
        """ np.ndarray or List[float] : proportions of nearest neighbors """
        return self._deltas

    @deltas.setter
    def deltas(self, x):
        check_deltas(deltas=x)
        self._deltas = x

    @property
    def nh(self):
        """ int : number of variogram distance intervals """
        return self._nh

    @nh.setter
    def nh(self, x):
        # np.linspace requires an integer sample count
        self._nh = int(x)

    @property
    def kernel(self):
        """ Callable : smoothing kernel function

        Notes
        -----
        When setting kernel, use name of kernel as defined in ``config.py``.

        """
        return self._kernel

    @kernel.setter
    def kernel(self, x):
        kernel_callable = check_kernel(x)
        self._kernel = kernel_callable

    @property
    def resample(self):
        """ bool : whether to resample surrogate map values from target maps """
        return self._resample

    @resample.setter
    def resample(self, x):
        if not isinstance(x, bool):
            raise TypeError("expected bool, got {}".format(type(x)))
        self._resample = x

    @property
    def knn(self):
        """ int : number of nearest neighbors included in distance matrix """
        return self._knn

    @knn.setter
    def knn(self, x):
        if x > self._nmap:
            raise ValueError('knn must be less than len(X)')
        self._knn = int(x)

    @property
    def ns(self):
        """ int : number of randomly sampled regions used to construct map """
        return self._ns

    @ns.setter
    def ns(self, x):
        self._ns = int(x)

    @property
    def b(self):
        """ numeric : Gaussian kernel bandwidth """
        return self._b

    @b.setter
    def b(self, x):
        self._b = x

    @property
    def h(self):
        """ np.ndarray : distances at which variogram is evaluated """
        return self._h

    @h.setter
    def h(self, x):
        self._h = x


""" Functions for performing statistical inference using surrogate maps. """

import numpy as np
from scipy.stats import rankdata

__all__ = ['spearmanr', 'pearsonr', 'pairwise_r', 'nonparp']


def spearmanr(X, Y):
    """
    Spearman rank correlation between every row of `X` and every row of `Y`.

    Each row is converted to ranks (ties get average ranks) and the Pearson
    correlation of those ranks is returned. 1D inputs are treated as a single
    row.

    Parameters
    ----------
    X : (N,P) np.ndarray
    Y : (M,P) np.ndarray

    Returns
    -------
    (N,M) np.ndarray

    Raises
    ------
    TypeError : `X` or `Y` is not array_like
    ValueError : `X` and `Y` are not same size along second axis

    """
    both_arrays = isinstance(X, np.ndarray) and isinstance(Y, np.ndarray)
    if not both_arrays:
        raise TypeError('X and Y must be numpy arrays')

    X = X.reshape(1, -1) if X.ndim == 1 else X
    Y = Y.reshape(1, -1) if Y.ndim == 1 else Y

    if X.shape[1] != Y.shape[1]:
        raise ValueError('X and Y must be same size along axis=1')

    rank_x = rankdata(X, axis=1)
    rank_y = rankdata(Y, axis=1)
    return pearsonr(rank_x, rank_y)


def pearsonr(X, Y):
    """
    Multi-dimensional Pearson correlation between rows of `X` and `Y`.

    Parameters
    ----------
    X : (N,P) np.ndarray
    Y : (M,P) np.ndarray

    Returns
    -------
    (N,M) np.ndarray
        Element ``[i, j]`` is the Pearson correlation between ``X[i]`` and
        ``Y[j]``. 1D inputs are treated as a single row.

    Raises
    ------
    TypeError : `X` or `Y` is not array_like
    ValueError : `X` and `Y` are not same size along second axis

    Notes
    -----
    Rows are mean-centered before cross-products are formed. This avoids the
    catastrophic cancellation of the raw-moment form
    ``sum(x*y) - n*mean(x)*mean(y)`` when the data carry a large offset.

    """
    if not isinstance(X, np.ndarray) or not isinstance(Y, np.ndarray):
        raise TypeError('X and Y must be numpy arrays')

    if X.ndim == 1:
        X = X.reshape(1, -1)
    if Y.ndim == 1:
        Y = Y.reshape(1, -1)

    n = X.shape[1]
    if n != Y.shape[1]:
        raise ValueError('X and Y must be same size along axis=1')

    # Center each row, then correlate: r = <xc, yc> / (||xc|| * ||yc||)
    xc = X - X.mean(axis=1, keepdims=True)
    yc = Y - Y.mean(axis=1, keepdims=True)
    cov = np.dot(xc, yc.T)
    norm_x = np.sqrt(np.sum(np.square(xc), axis=1))
    norm_y = np.sqrt(np.sum(np.square(yc), axis=1))
    return cov / np.outer(norm_x, norm_y)


def pairwise_r(X, flatten=False):
    """
    Pearson correlation between every pair of rows in `X`.

    Parameters
    ----------
    X : (N,M) np.ndarray
    flatten : bool, default False
        When True, return only the strictly upper-triangular entries of the
        correlation matrix as a 1D array

    Returns
    -------
    (N*(N-1)/2,) or (N,N) np.ndarray
        Pearson correlation coefficients

    """
    corr = pearsonr(X, X)
    if flatten:
        upper = np.triu_indices_from(corr, k=1)  # exclude the diagonal
        return corr[upper].flatten()
    return corr


def nonparp(stat, dist):
    """
    Two-sided non-parametric p-value.

    The p-value is the fraction of null samples whose absolute value is
    strictly greater than the absolute value of `stat`.

    Parameters
    ----------
    stat : float
        Test statistic
    dist : (N,) np.ndarray
        Null distribution for test statistic

    Returns
    -------
    float
        Fraction of elements in `dist` which are more extreme than `stat`

    """
    exceeds = np.abs(dist) > abs(stat)
    return np.sum(exceeds) / float(len(dist))


from .base import Base
from .sampled import Sampled
from brainsmash.mapgen.eval import base_fit, sampled_fit

__all__ = ['Base', 'Sampled', 'base_fit', 'sampled_fit']


from pathlib import Path
import numpy as np

__all__ = ['check_extensions',
           'check_outfile',
           'check_pv',
           'check_deltas',
           'check_map',
           'check_distmat',
           'check_file_exists',
           'check_sampled',
           'is_string_like',
           'count_lines',
           'stripext']


def check_map(x):
    """
    Validate that a brain map is a one-dimensional numpy array.

    Parameters
    ----------
    x : np.ndarray
        Brain map

    Returns
    -------
    None

    Raises
    ------
    TypeError : `x` is not a np.ndarray object
    ValueError : `x` is not one-dimensional

    """
    if not isinstance(x, np.ndarray):
        raise TypeError(
            "Brain map must be array-like\n" + "got type {}".format(type(x)))
    if x.ndim == 1:
        return
    raise ValueError(
        "Brain map must be one-dimensional\n" + "got shape {}".format(x.shape))


def check_extensions(filename, exts):
    """
    Report whether the final suffix of `filename` is one of `exts`.

    Parameters
    ----------
    filename : str
        Path to file
    exts : List[str]
        Allowed suffixes, including the leading dot (e.g. ``'.npy'``)

    Returns
    -------
    bool
        True if `filename`'s extension is in `exts`

    Raises
    ------
    TypeError : `filename` is not string-like

    """
    if is_string_like(filename):
        return Path(filename).suffix in exts
    message = "Expected str, got {}".format(type(filename))
    message += "\nfilename: {}".format(filename)
    raise TypeError(message)


def check_distmat(D):
    """
    Validate that a distance matrix is (numerically) symmetric.

    Parameters
    ----------
    D : (N,N) np.ndarray
        Pairwise distance matrix

    Returns
    -------
    None

    Raises
    ------
    ValueError : `D` is not symmetric

    """
    symmetric = np.allclose(D, D.T)
    if symmetric:
        return
    raise ValueError("Distance matrix must be symmetric")


def check_sampled(D, index):
    """
    Check arguments provided to :class:`brainsmash.mapgen.sampled.Sampled`.

    Parameters
    ----------
    D : np.ndarray or np.memmap
        Pairwise distance matrix
    index : np.ndarray or np.memmap
        See :class:`brainsmash.mapgen.sampled.Sampled`

    Returns
    -------
    None

    Raises
    ------
    TypeError : `D` or `index` is not a np.ndarray (np.memmap included)
    ValueError : Arguments do not have identical dimensions
    ValueError : rows of `D` are not sorted (ascending)

    Notes
    -----
    For an in-memory array every row of `D` is checked. For a np.memmap only
    the first row is checked, to avoid reading a potentially huge on-disk
    matrix in full.

    """
    if not isinstance(D, np.ndarray) or not isinstance(index, np.ndarray):
        raise TypeError("'D' and 'index' must be array_like")
    if D.shape != index.shape:
        e = "`D` and `index` must have identical dimensions\n"
        e += "D.shape: {}\n".format(D.shape)
        e += "index.shape: {}".format(index.shape)
        raise ValueError(e)
    # np.memmap is a subclass of np.ndarray, so it must be detected explicitly
    # (an `isinstance(D, np.ndarray)` test would always match).
    if isinstance(D, np.memmap):  # just test the first row
        rows_sorted = np.all(D[0, 1:] >= D[0, :-1])
    else:
        rows_sorted = np.all(D[:, 1:] >= D[:, :-1])
    if not rows_sorted:
        raise ValueError("Each row of `D` must be sorted (ascending)")


def check_deltas(deltas):
    """
    Validate the `deltas` argument (neighborhood proportions).

    Parameters
    ----------
    deltas : np.ndarray or List[float]
        Proportions of neighbors to include for smoothing, in (0, 1]

    Returns
    -------
    None

    Raises
    ------
    TypeError : `deltas` is not a List or np.ndarray object
    ValueError : One or more elements of `deltas` lies outside (0,1]

    """
    if not isinstance(deltas, (list, np.ndarray)):
        raise TypeError("Parameter `deltas` must be a list or ndarray")
    out_of_range = (d <= 0 or d > 1 for d in deltas)
    if any(out_of_range):
        raise ValueError("Each element of `deltas` must lie in (0,1]")


def check_pv(pv):
    """
    Validate and coerce the percentile argument `pv`.

    Parameters
    ----------
    pv : int
        Percentile of the pairwise distance distribution at which to truncate
        during variogram fitting. Values are converted with ``int()``.

    Returns
    -------
    int

    Raises
    ------
    ValueError : `pv` cannot be converted to int, or lies outside (0, 100]

    """
    try:
        value = int(pv)
    except ValueError:
        raise ValueError("parameter 'pv' must be an integer in (0,100]")
    if not 0 < value <= 100:
        raise ValueError("parameter 'pv' must be in (0,100]")
    return value


def check_outfile(filename):
    """
    Validate an output path before writing to it.

    Prints a warning when the file already exists. Raises when the path is a
    directory or when its parent directory is missing.

    Parameters
    ----------
    filename : filename
        File to be written

    Returns
    -------
    None

    Raises
    ------
    IOError : Parent directory of `filename` does not exist
    ValueError : directory was provided instead of file

    """
    target = Path(filename)
    if target.is_dir():
        raise ValueError("expected filename, got dir: {}".format(filename))
    if target.exists():
        print("WARNING: overwriting {}".format(filename))

    parent = target.parent
    if parent.exists():
        return
    raise IOError("Output directory does not exist: {}".format(str(parent)))


def is_string_like(obj):
    """ Return True if `obj` can be concatenated with a ``str``. """
    try:
        _ = obj + ''
    except (TypeError, ValueError):
        return False
    else:
        return True


def stripext(f):
    """
    Remove the final extension from a file path.

    Parameters
    ----------
    f : filename
        Path to file with extension

    Returns
    -------
    str
        Same path with its last suffix removed (``a.b.c`` -> ``a.b``)

    """
    without_suffix = Path(f).with_suffix('')
    return str(without_suffix)


def check_file_exists(f):
    """
    Confirm that a file exists and is non-empty.

    Parameters
    ----------
    f : filename

    Returns
    -------
    None

    Raises
    ------
    IOError : file does not exist or has zero size

    """
    target = Path(f)
    missing_or_empty = (not target.exists()) or target.stat().st_size == 0
    if missing_or_empty:
        raise IOError("{} was not successfully written to".format(f))


def count_lines(filename):
    """
    Count newline characters in a file.

    The file is read in binary 1 MiB chunks. A final line without a trailing
    newline is therefore not counted.

    Parameters
    ----------
    filename : filename

    Returns
    -------
    int
        number of ``b'\\n'`` bytes in the file

    """
    chunk_size = 1024 * 1024
    newline_total = 0
    with open(filename, 'rb') as handle:
        reader = handle.raw.read  # unbuffered reads of large chunks
        while True:
            chunk = reader(chunk_size)
            if not chunk:
                break
            newline_total += chunk.count(b'\n')
    return newline_total


from ..config import parcel_labels_lr
from .checks import is_string_like
import tempfile
from os import path, system
import nibabel as nib
import pandas as pd
from pathlib import Path
import numpy as np


def dataio(x):
    """
    Data I/O for core classes.

    To facilitate flexible user inputs, this function loads data from:
        - neuroimaging files
        - txt files
        - npy files (memory-mapped arrays)
        - array_like data

    Parameters
    ----------
    x : filename or os.PathLike or np.ndarray or np.memmap
        Path-like objects (e.g. ``pathlib.Path``) are treated like filenames.

    Returns
    -------
    np.ndarray or np.memmap

    Raises
    ------
    FileExistsError : file does not exist (historical exception type, kept
        for backward compatibility)
    RuntimeError : file is empty
    ValueError : file type cannot be determined or is not implemented
    TypeError : input is not a filename or array_like object

    """
    import os

    if isinstance(x, os.PathLike):  # e.g. pathlib.Path -> str
        x = os.fspath(x)
    if is_string_like(x):
        filepath = Path(x)
        if not filepath.exists():
            raise FileExistsError("file does not exist: {}".format(x))
        if filepath.stat().st_size == 0:
            raise RuntimeError("file is empty: {}".format(x))
        suffix = filepath.suffix
        if suffix == ".npy":  # memmap
            return np.load(x, mmap_mode='r')
        if suffix == ".txt":  # text file
            return np.loadtxt(x).squeeze()
        try:
            return load(x)
        except TypeError:
            raise ValueError(
                "expected npy or txt or nii or gii file, got {}".format(
                    suffix))
    if not isinstance(x, np.ndarray):
        raise TypeError(
            "expected filename or array_like obj, got {}".format(type(x)))
    return x


def load(filename):
    """
    Load data contained in a CIFTI2-/GIFTI-format neuroimaging file.

    A GIFTI read is attempted first. If the loaded image has no ``darrays``
    attribute, the file is read as CIFTI-2 (NIFTI-2) data instead.

    Parameters
    ----------
    filename : filename
        Path to neuroimaging file

    Returns
    -------
    (N,) np.ndarray
        Brain map data stored in `filename`

    Raises
    ------
    TypeError : `filename` has unknown filetype (including files nibabel
        cannot identify at all)

    """
    from nibabel.filebasedimages import ImageFileError

    try:
        return _load_gifti(filename)
    except ImageFileError:
        # nibabel could not recognise the file at all; retrying as CIFTI-2
        # would fail identically, so report the documented TypeError now.
        raise TypeError(
            "This file cannot be loaded: {}".format(filename)) from None
    except AttributeError:
        pass  # loaded, but not a GIFTI image (no `darrays`): try CIFTI-2
    try:
        return _load_cifti2(filename)
    except AttributeError:
        raise TypeError("This file cannot be loaded: {}".format(filename))


def export_cifti_mapping(image=None):
    """
    Compute the map from CIFTI indices to surface vertices and volume voxels.

    Parameters
    ----------
    image : filename or None, default None
        Path to NIFTI-2 format (.nii) neuroimaging file. The metadata
        from this file is used to determine the CIFTI indices and voxel
        coordinates of elements in the image. This file must include all
        subcortical volumes and both cortical hemispheres.

    Returns
    -------
    maps : dict
        A dictionary containing the maps between CIFTI indices, surface
        vertices, and volume voxels. Keys include 'cortex_left',
        'cortex_right', and 'volume'.

    Raises
    ------
    FileNotFoundError : the ``wb_command`` executable cannot be found
    subprocess.CalledProcessError : a ``wb_command`` call exits non-zero

    Notes
    -----
    `image` must be a whole-brain NIFTI file for this function to work
    as-written. See the Workbench documentation here for more details:
    https://www.humanconnectome.org/software/workbench-command/-cifti-export-dense-mapping.

    """
    import subprocess

    if image is None:
        image = parcel_labels_lr

    basecmd = ["wb_command", "-cifti-export-dense-mapping", str(image),
               "COLUMN"]

    # Workbench writes into a private temporary directory. Concurrent calls
    # therefore cannot clobber each other, and output left by an earlier run
    # can never be read back if a command fails.
    with tempfile.TemporaryDirectory() as tempdir:
        volume = path.join(tempdir, "volume.txt")
        left = path.join(tempdir, "left.txt")
        right = path.join(tempdir, "right.txt")

        # Argument lists (no shell): paths containing quotes/spaces are passed
        # verbatim. check=True surfaces Workbench failures instead of
        # silently continuing.
        # Subcortex (volume)
        subprocess.run(basecmd + ["-volume-all", volume, "-structure"],
                       check=True)
        # Cortex left
        subprocess.run(basecmd + ["-surface", "CORTEX_LEFT", left], check=True)
        # Cortex right
        subprocess.run(basecmd + ["-surface", "CORTEX_RIGHT", right],
                       check=True)

        # Read eagerly while the temporary files still exist
        maps = dict()
        maps['volume'] = pd.read_table(
            volume, header=None, index_col=0, sep=' ',
            names=['structure', 'mni_i', 'mni_j', 'mni_k']).rename_axis(
            'index')
        maps['cortex_left'] = pd.read_table(
            left, header=None, index_col=0, sep=' ',
            names=['vertex']).rename_axis('index')
        maps['cortex_right'] = pd.read_table(
            right, header=None, index_col=0, sep=' ',
            names=['vertex']).rename_axis('index')

    return maps


def _load_gifti(filename):
    """
    Read the first data array from a GIFTI (.gii) neuroimaging file.

    Parameters
    ----------
    filename : filename
        Path to GIFTI-format (.gii) neuroimaging file

    Returns
    -------
    np.ndarray
        Data of the image's first data array (``darrays[0]``)

    """
    image = nib.load(filename)
    first_array = image.darrays[0]
    return first_array.data


def _load_cifti2(filename):
    """
    Read the data stored in a CIFTI-2 neuroimaging file (e.g. ``.dscalar.nii``
    or ``.dlabel.nii``).

    Parameters
    ----------
    filename : filename
        Path to CIFTI-2 format (.nii) file

    Returns
    -------
    np.ndarray
        Image data with singleton dimensions squeezed out

    Notes
    -----
    CIFTI-2 files follow the NIFTI-2 file format. CIFTI-2 files may contain
    surface-based and/or volumetric data.

    """
    image = nib.load(filename)
    values = np.asanyarray(image.dataobj)
    return values.squeeze()


from .checks import check_sampled, check_file_exists, check_distmat
from .checks import check_map, check_deltas, check_pv, check_outfile
from .checks import check_extensions, count_lines, is_string_like, stripext
from .dataio import dataio


""" Routines for constructing distance matrices from neuroimaging files. """

from ..utils.checks import *
from .io import check_surface, check_image_file
from ..utils.dataio import load, export_cifti_mapping, dataio
from ..config import parcel_labels_lr
from .surf import make_surf_graph
from scipy.spatial.distance import cdist
from scipy import ndimage, sparse
from tempfile import gettempdir, NamedTemporaryFile
from os import path
from os import system
import numpy as np
import nibabel as nib
import numpy.lib.format
# from warnings import warn
# from ..utils.dataio import dataio
# from os import remove
from joblib import Parallel, delayed

__all__ = ['cortex', 'subcortex', 'parcellate', 'volume']


def cortex(surface, outfile, euclid=False, dlabel=None, medial=None,
           use_wb=True, unassigned_value=0, verbose=True, n_jobs=None):
    """
    Calculates surface distances for `surface` mesh and saves to `outfile`.

    Parameters
    ----------
    surface : str or os.PathLike
        Path to surface file on which to calculate distance
    outfile : str or os.PathLike
        Path to which generated distance matrix should be saved
    euclid : bool, optional, default False
        Whether to compute Euclidean distance instead of surface distance
    dlabel : str or os.PathLike, optional, default None
        Path to file with parcel labels for provided `surface`. If provided,
        calculate and save parcel-parcel distances instead of vertex distances.
    medial : str or os.PathLike, optional, default None
        Path to file containing labels for vertices corresponding to medial
        wall. If provided and `use_wb=False`, will disallow calculation of
        surface distance along the medial wall.
    use_wb : bool, optional, default True
        Whether to use calls to `wb_command -surface-geodesic-distance` for
        computation of the surface distance matrix; this will involve
        significant disk I/O. If False, all computations will be done in memory
        using the `scipy.sparse.csgraph.dijkstra` function.
    unassigned_value : int, default 0
        Label value which indicates that a vertex/voxel is not assigned to
        any parcel. This label is excluded from the output. 0 is the default
        value used by Connectome Workbench, e.g. for ``-cifti-parcellate``.
    verbose : bool, optional, default True
        Whether to print status updates while distances are calculated. Only
        used when the vertex-vertex matrix is computed (`dlabel` is None).
    n_jobs : int, default None
        The number of parallel jobs to run for distance calculation. None means
        1 unless in a ``joblib.parallel_backend`` context. -1 means using all
        processors. Only used when `dlabel` is provided.

    Returns
    -------
    distance : str
        Path to generated `outfile`

    Notes
    -----
    The surface distance matrix computed with `use_wb=False` will have slightly
    lower values than when `use_wb=True` due to known estimation errors. These
    will be fixed at a later date. By default, `use_wb=True` for backwards-
    compatibility but this will be changed in a future update.

    Raises
    ------
    ValueError : inconsistent # of vertices in label, mask, and/or surface file

    """

    n_vert = len(load(surface))

    # get data from dlabel / medial wall files if provided
    labels, mask = None, np.zeros(n_vert, dtype=bool)
    if dlabel is not None:
        labels = check_image_file(dlabel)
        if len(labels) != n_vert:
            raise ValueError('Provided `dlabel` file does not contain same '
                             'number of vertices as provided `surface`')
    if medial is not None:
        mask = np.asarray(check_image_file(medial), dtype=bool)
        if len(mask) != n_vert:
            raise ValueError('Provided `medial` file does not contain same '
                             'number of vertices as provided `surface`')

    # define which function we'll be using to calculate the distances
    if euclid:
        func = _get_euclid_distance
        graph = check_surface(surface)  # vertex coordinates
    else:
        if use_wb:
            func = _get_workbench_distance
            graph = surface
        else:
            func = _get_graph_distance
            vert, faces = [darray.data for darray in nib.load(surface).darrays]
            # the medial-wall mask only affects this in-memory graph path
            graph = make_surf_graph(vert, faces, mask=mask)

    # if we want the vertex-vertex distance matrix we'll stream it to disk to
    # save on memory, a la `_geodesic()` or `_euclid()`
    # NOTE: streaming to disk takes a lot more _time_ than storing in memory
    if labels is None:
        with open(outfile, 'w') as dest:
            for n in range(n_vert):
                if verbose and n % 1000 == 0:
                    print('Running vertex {} of {}'.format(n, n_vert))
                np.savetxt(dest, func(n, graph))
    # we can store the temporary n_vert x label matrix in memory; running this
    # is much faster than trying to read through the giant vertex-vertex
    # distance matrix file
    else:
        # depends on size of parcellation, but assuming even a liberal 1000
        # parcel atlas this will be ~250 MB in-memory for the default fslr32k
        # resolution
        unique_parcels = np.unique(labels)
        par, func = Parallel(n_jobs=n_jobs), delayed(func)
        # each call yields a (1, n_parcels) row -> stack to (n_vert, n_parcels)
        # (np.vstack replaces np.row_stack, which is deprecated in NumPy 2.0)
        dist = np.vstack(par(func(n, graph, labels) for n in range(n_vert)))
        # average rows (vertices) into parcels; columns are already parcels
        dist = np.vstack([
            dist[labels == lab].mean(axis=0) for lab in unique_parcels])
        dist[np.diag_indices_from(dist)] = 0
        # NOTE: if `medial` is supplied and any of the parcel labels correspond
        # to the medial wall then those parcel-parcel distances will be `inf`!

        # remove unassigned parcel
        if unassigned_value in unique_parcels:
            idx = list(unique_parcels).index(unassigned_value)
            dist = np.delete(dist, idx, axis=0)
            dist = np.delete(dist, idx, axis=1)

        np.savetxt(outfile, dist)

    return outfile


# def cortex(surface, outfile, euclid=False):
#     """
#     Compute distance matrix for a cortical hemisphere.
#
#     Parameters
#     ----------
#     surface : filename
#         Path to a surface GIFTI (.surf.gii) from which to compute distances
#     outfile : filename
#         Path to output file
#     euclid : bool, default False
#         If True, compute Euclidean distances; if False, compute geodesic dist
#
#     Returns
#     -------
#     filename : str
#         Path to output distance matrix file
#
#     """
#
#     check_outfile(outfile)
#
#     # Strip file extensions and define output text file
#     outfile = stripext(outfile)
#     dist_file = outfile + '.txt'
#
#     # Load surface file
#     coords = check_surface(surface)
#
#     if euclid:  # Pairwise Euclidean distance matrix
#         of = _euclidean(dist_file=dist_file, coords=coords)
#     else:  # Pairwise geodesic distance matrix
#         of = _geodesic(
#             surface=surface, dist_file=dist_file, coords=coords)
#     return of


def subcortex(fout, image_file=None, dlabel=None, unassigned_value=0,
              verbose=True):
    """
    Compute inter-voxel Euclidean distance matrix.

    Parameters
    ----------
    fout : str or os.Pathlike
        Path to output text file; its extension is replaced with ``.txt``
    image_file : str or os.Pathlike or None, default None
        Path to a CIFTI-2 format neuroimaging file (eg .dscalar.nii). MNI
        coordinates for each subcortical voxel are read from this file's
        metadata. If None, uses dlabel file defined in ``brainsmash.config.py``.
    dlabel : str or os.PathLike, optional, default None
        Path to a parcel label file whose elements are indexed by the same
        CIFTI indices as `image_file`. If provided, calculate and save
        parcel-parcel distances instead of voxel-voxel distances.
    unassigned_value : int, default 0
        Label value which indicates that a vertex/voxel is not assigned to
        any parcel. This label is excluded from the output. 0 is the default
        value used by Connectome Workbench, e.g. for ``-cifti-parcellate``.
    verbose : bool, optional, default True
        Whether to print status updates while distances are calculated.

    Returns
    -------
    filename : str
        Path to output text file containing pairwise Euclidean distances

    Notes
    -----
    Voxel indices are used as a proxy for physical distance, since the two are
    related by a simple linear scaling. Note that this assumes voxels are
    cubic, i.e., that the scaling is equivalent along the x, y, and z dimension.

    Raises
    ------
    ValueError : `image_file` header does not contain volume information
    IndexError : Inconsistent number of elements in `image_file` and `dlabel`

    """
    # TODO add more robust error handling

    check_outfile(fout)

    # Strip file extensions and define output text file
    fout = stripext(fout)
    dist_file = fout + '.txt'

    # Load CIFTI mapping  (i.e., map from scalar index to 3-D MNI indices)
    if image_file is None:
        image_file = parcel_labels_lr
    maps = export_cifti_mapping(image_file)
    if "volume" not in maps.keys():
        e = "Subcortical information was not found in {}".format(image_file)
        raise ValueError(e)
    coords = maps['volume'].drop("structure", axis=1).values
    n_vert = coords.shape[0]

    # Get data from dlabel file if provided; keep only the volumetric entries
    labels = None
    if dlabel is not None:
        all_labels = check_image_file(dlabel)
        volume_indices = maps['volume'].index.values
        try:
            labels = all_labels[volume_indices]
        except IndexError as err:
            raise IndexError(
                'Volumetric CIFTI indices obtained from `image_file` could not '
                'be indexed from the provided `dlabel` file.') from err

    func = _get_euclid_distance

    # If we want the vertex-vertex distance matrix we'll stream it to disk to
    # save on memory.
    # NOTE: streaming to disk takes a lot more _time_ than storing in memory
    if labels is None:
        with open(dist_file, 'w') as dest:
            for n in range(n_vert):
                if verbose and n % 1000 == 0:
                    print('Running vertex {} of {}'.format(n, n_vert))
                np.savetxt(dest, func(n, coords))
    # We can store the temporary n_vert x label matrix in memory; running this
    # is much faster than trying to read through the giant vertex-vertex
    # distance matrix file
    else:
        unique_parcels = np.unique(labels)
        dist = np.zeros((n_vert, unique_parcels.size), dtype='float32')
        # NOTE: because this is being done in-memory it could be multiprocessed
        # for additional speed-ups, if desired!
        for n in range(n_vert):
            if verbose and n % 1000 == 0:
                print('Running vertex {} of {}'.format(n, n_vert))
            # the distance helper returns a 2-D (1, n_parcels) row
            dist[n] = func(n, coords, labels)[0]
        # average rows (voxels) into parcels; columns are already parcels
        # (np.vstack replaces np.row_stack, which is deprecated in NumPy 2.0)
        dist = np.vstack([
            dist[labels == lab].mean(axis=0) for lab in unique_parcels])
        dist[np.diag_indices_from(dist)] = 0

        # remove unassigned parcel
        if unassigned_value in unique_parcels:
            idx = list(unique_parcels).index(unassigned_value)
            dist = np.delete(dist, idx, axis=0)
            dist = np.delete(dist, idx, axis=1)

        np.savetxt(dist_file, dist)

    return dist_file


def parcellate(infile, dlabel_file, outfile, delimiter=' ', unassigned_value=0):
    """
    Parcellate a dense distance matrix.

    Parameters
    ----------
    infile : filename
        Path to `delimiter`-separated distance matrix file
    dlabel_file : filename
        Path to parcellation file  (.dlabel.nii)
    outfile : filename
        Path to output text file (to be created); its extension is replaced
        with ``.txt``
    delimiter : str, default ' '
        Delimiter between elements in `infile`
    unassigned_value : int, default 0
        Label value which indicates that a vertex/voxel is not assigned to
        any parcel. This label is excluded from the output. 0 is the default
        value used by Connectome Workbench, e.g. for ``-cifti-parcellate``.

    Returns
    -------
    filename : str
        Path to output parcellated distance matrix file

    Notes
    -----
    For two parcels A and B, the inter-parcel distance is defined as the mean
    distance between area i in parcel A and area j in parcel B, for all i,j.

    Inputs `infile` and `dlabel_file` should include the same anatomical
    structures, e.g. the left cortical hemisphere, and should have the same
    number of elements. If you need to isolate one anatomical structure from
    `dlabel_file`, see the following Workbench function:
    https://www.humanconnectome.org/software/workbench-command/-cifti-separate

    Raises
    ------
    ValueError : `infile` and `dlabel_file` have inconsistent sizes

    """

    print("\nComputing parcellated distance matrix\n")
    m = "For a 32k-vertex cortical hemisphere, this takes about 30 mins "
    m += "for the HCP MMP parcellation. For subcortex, this takes about an hour"
    m += " for the CAB-NP parcellation."
    print(m)

    check_outfile(outfile)

    # Strip file extensions and define output text file
    fout = stripext(outfile)
    dist_file = fout + '.txt'

    # Load parcel labels
    labels = check_image_file(dlabel_file)

    with open(infile, 'r') as fp:

        # Compare number of elements in distance matrix to dlabel file:
        # columns from the first row, rows = first row + non-blank rows after
        ncols = len(fp.readline().split(delimiter))
        nrows = 1 + sum(1 for row in fp if row.rstrip())
        if not (labels.size == nrows == ncols):
            e = "Files must contain same number of areas\n"
            e += "{} areas in {}\n".format(labels.size, dlabel_file)
            e += "{} rows and {} cols in {}".format(nrows, ncols, infile)
            raise ValueError(e)

        # Skip unassigned parcel label
        unique_labels = np.unique(labels)
        unique_labels = unique_labels[unique_labels != unassigned_value]
        nparcels = unique_labels.size

        # Create vertex-level mask for each unique parcel
        parcel_vertex_mask = {lab: labels == lab for lab in unique_labels}

        # Loop over pairwise parcels at the level of surface vertices; only
        # the upper triangle is filled here and mirrored at the end
        distance_matrix = np.zeros((nparcels, nparcels))

        for i, li in enumerate(unique_labels[:-1]):

            # Labels of parcels for which to compute mean geodesic distance
            labels_lj = unique_labels[i + 1:]

            # Initialize lists in which to store pairwise vertex-level distances
            parcel_distances = {lj: [] for lj in labels_lj}

            # Boolean mask gives O(1) membership per row (a linear scan of the
            # parcel's index array per row made this loop quadratic)
            li_mask = parcel_vertex_mask[li]
            last_vertex = np.flatnonzero(li_mask)[-1]

            fp.seek(0)
            for vi, row in enumerate(fp):
                if vi > last_vertex:
                    break  # no further rows belong to parcel li
                if li_mask[vi]:
                    # Load distance from vertex vi to every other vertex
                    d = np.array(row.split(delimiter), dtype=np.float32)
                    # Store dists from vertex vi to vertices in parcel j
                    for lj in labels_lj:
                        parcel_distances[lj].append(d[parcel_vertex_mask[lj]])

            # Compute average geodesic distances
            for j, lj in enumerate(labels_lj):
                mean_distance = np.mean(parcel_distances[lj])
                distance_matrix[i, i + j + 1] = mean_distance

            print("# Parcel label %s complete." % str(li))

        # Make final matrix symmetric
        i, j = np.triu_indices(nparcels, k=1)
        distance_matrix[j, i] = distance_matrix[i, j]

    # Write to file
    np.savetxt(fname=dist_file, X=distance_matrix)
    check_file_exists(dist_file)
    return dist_file


def volume(coord_file, outdir, chunk_size=1000):
    """
    Generate distance-matrix-related memory-mapped files for volumetric data.

    Parameters
    ----------
    coord_file : str or os.PathLike
        Path to text file in which rows correspond to voxels of interest, and
        *three* columns correspond to voxels' coordinates
    outdir : str or os.PathLike
        Path to the directory where outputs will be saved
    chunk_size : int, default 1000
        The number of voxels to process per chunk. For N voxels, this will
        impose a memory burden of N*`chunk_size` per iteration (in contrast to
        a memory burden of N*N for a single iteration, in the absence of
        chunking). Must be a positive integer.

    Returns
    -------
    dict
        Keys are 'D' and 'index'; values are absolute paths to the
        corresponding files on disk. These files are used as inputs to
        `brainsmash.mapgen.sampled.Sampled`.

    Raises
    ------
    IOError : `outdir` doesn't exist
    ValueError : `coord_file` does not contain three columns, or
        `chunk_size` is less than 1

    Notes
    -----
    This function computes 3D Euclidean distance between all pairs of voxels
    whose 3D coordinates are provided in `coord_file`. The distance matrix is
    not saved in its raw, symmetric form, but rather as a pair of memory-mapped
    arrays which are needed to create an instance of the
    `brainsmash.mapgen.sampled.Sampled` class. See
    `brainsmash.mapgen.memmap.txt2memmap` for more details. Note that the input
    file should only contain brain voxels --- i.e., voxels in your brain map
    of interest. Each row of `coord_file`, which indicates the coordinates of
    a voxel, should therefore correspond to one brain map value.

    """
    if not path.exists(outdir):
        raise IOError("Output directory does not exist: {}".format(outdir))
    # a non-positive step would never advance through the rows
    if chunk_size < 1:
        raise ValueError(
            "chunk_size must be a positive integer, got {}".format(chunk_size))

    # Load data
    X = dataio(coord_file)
    if X.ndim != 2 or X.shape[1] != 3:
        e = f"expected N rows by 3 columns, instead got shape {X.shape}"
        raise ValueError(e)

    n = X.shape[0]
    print(f"loading voxels coordinates from {coord_file}")
    print(f"file contains {n} voxels")
    print(f"saving memory-mapped distance matrix files to {outdir}")

    # Open memory-mapped arrays
    npydfile = path.join(outdir, "distmat.npy")
    npyifile = path.join(outdir, "index.npy")
    fpd = numpy.lib.format.open_memmap(
        npydfile, mode='w+', dtype=np.float32, shape=(n, n))
    fpi = numpy.lib.format.open_memmap(
        npyifile, mode='w+', dtype=np.int32, shape=(n, n))

    for start in range(0, n, chunk_size):
        stop = min(start + chunk_size, n)
        d = cdist(X[start:stop], X)  # compute 3D euclidean distance
        sort_idx = np.argsort(d, axis=1)  # sort each row
        # each row's distances in ascending order, plus the sorting indices
        fpd[start:stop] = np.take_along_axis(d, sort_idx, axis=1)
        fpi[start:stop] = sort_idx

    # Flush memory changes to disk and release the maps
    fpd.flush()
    fpi.flush()
    del fpd, fpi

    return {'D': npydfile, 'index': npyifile}  # Return filenames

# def euclidean(coords, fout, chunk_size=10000):
#     """
#     Generate a three-dimensional Euclidean distance matrix and write to file.
#
#     Parameters
#     ----------
#     coords : (N,3) np.ndarray or str or os.Pathlike
#         3-D coordinates for N brain regions (e.g. voxels, parcels), provided
#         either as a numpy array or as a filename
#     fout : filename
#         Outputs are written to this file; include .txt extension!
#     chunk_size : float
#         The distance matrix is written to file `chunk_size` rows at
#         a time, assuming N is larger than `chunk_size`.
#
#     Returns
#     -------
#     filename : str
#         Output text file
#
#     Raises
#     ------
#     TypeError : coordinates not a numpy array
#     ValueError : coordinates not a two-dimensional (N,3) matrix
#
#     """
#
#     X = dataio(coords)
#     if not isinstance(X, np.ndarray):
#         e = "Expected array-like, got type {}".format(type(X))
#         raise TypeError(e)
#     if X.ndim != 2 or X.shape[1] != 3:
#         e = "Expected (N,3)-shaped coordinates, got shape {}".format(X.shape)
#         raise ValueError(e)
#     if path.exists(fout):
#         warn("Output file already exists, will be overwritten")
#         if input("continue? [y/n] ") == "y":
#             remove(fout)
#         else:
#             print("terminating")
#             return
#
#     n = X.shape[0]
#     with open(fout, 'ab') as fp:
#         i = 0
#         while i < n:
#             j = min(i+chunk_size, n)
#             print(i, j)
#             Y = X[i:i+j]
#             d = cdist(Y, X)
#             np.save(fp, d)
#             i += chunk_size
#     check_file_exists(f=fout)
#     return fout


def _geodesic(surface, dist_file, coords):
    """
    Compute pairwise geodesic distance between rows of `coords`. Write results
    to `dist_file`.

    Parameters
    ----------
    surface : filename
        Path to a surface GIFTI file (.surf.gii) from which to compute distances
    dist_file : filename
        Path to output file, with .txt extension
    coords : (N,3) np.ndarray
        MNI coordinates for N voxels/vertices; only its row count is used, to
        determine how many vertices to process

    Returns
    -------
    filename : str
        Path to output distance matrix file

    Raises
    ------
    subprocess.CalledProcessError : a `wb_command` call exits with a non-zero
        status

    Notes
    -----
    This function uses command-line utilities included in Connectome Workbench.
    Each line of `dist_file` holds the space-separated distances from one
    vertex to every vertex on `surface`.

    """
    nvert = coords.shape[0]

    print("\nComputing geodesic distance matrix\n")
    m = "For a 32k-vertex cortical hemisphere, this may take up to two hours."
    print(m)

    # Workbench writes each vertex's distance map into a private temporary
    # directory, so concurrent runs cannot clobber each other's output and the
    # file is removed afterwards.
    with TemporaryDirectory() as temp, open(dist_file, 'w') as f:
        distance_metric_file = path.join(temp, "dist.func.gii")
        for ii in range(nvert):
            # No shell (paths are passed verbatim) and the exit status is
            # checked, so a failed call cannot silently re-read the previous
            # vertex's output file.
            subprocess.run(
                ['wb_command', '-surface-geodesic-distance', str(surface),
                 str(ii), distance_metric_file],
                check=True)
            distance_from_iv = load(distance_metric_file)
            line = " ".join([str(dij) for dij in distance_from_iv])
            f.write(line + "\n")
            if not (ii % 1000):
                print("Vertex {} of {} complete.".format(ii+1, nvert))
    check_file_exists(f=dist_file)
    return dist_file


def _get_workbench_distance(vertex, surf, labels=None):
    """
    Gets surface distance of `vertex` to all other vertices in `surf`.

    Parameters
    ----------
    vertex : int
        Index of vertex for which to calculate surface distance
    surf : str or os.PathLike
        Path to surface file on which to calculate distance
    labels : array_like, optional (default None)
        Labels indicating parcel to which each vertex belongs. If provided,
        distances will be averaged within distinct labels.

    Returns
    -------
    dist : (1, N) or (1, M) numpy.ndarray
        Distance of `vertex` to all other vertices in `surf` (or to all
        M unique parcels in `labels`, if provided), as float32

    Raises
    ------
    subprocess.CalledProcessError : `wb_command` exits with non-zero status

    Notes
    -----
    `wb_command` is run directly rather than through a shell, so paths with
    spaces or shell metacharacters are passed unchanged. Its output goes to a
    file inside a private temporary directory; unlike an open
    ``NamedTemporaryFile``, that file can be written by another process on
    every platform, and each (possibly parallel) call gets its own directory.

    """
    with TemporaryDirectory() as tmpdir:
        out = path.join(tmpdir, 'dist.func.gii')
        subprocess.run(
            ['wb_command', '-surface-geodesic-distance',
             str(surf), str(vertex), out],
            check=True)
        dist = load(out)

    return _get_parcel_distance(vertex, dist, labels)


def _get_graph_distance(vertex, graph, labels=None):
    """
    Shortest-path distance along `graph` from `vertex` to every other vertex.

    Parameters
    ----------
    vertex : int
        Index of the source vertex
    graph : array_like
        Weighted adjacency graph along which shortest paths are computed
        (treated as undirected)
    labels : array_like, optional
        Parcel label for each vertex. When given, distances are averaged
        within each unique label.

    Returns
    -------
    dist : np.ndarray
        float32 distances from `vertex` to all vertices in `graph` (or to all
        parcels in `labels`, if provided)

    Notes
    -----
    Shortest paths are found with Dijkstra's algorithm.

    """
    # dijkstra works in float64; the float32 cast happens in the helper, so
    # small rounding differences vs. the wb_command output are expected
    shortest_paths = sparse.csgraph.dijkstra(
        graph, directed=False, indices=[vertex])
    return _get_parcel_distance(vertex, shortest_paths, labels)


def _get_euclid_distance(vertex, coords, labels=None):
    """
    Straight-line (Euclidean) distance from `vertex` to every vertex.

    Parameters
    ----------
    vertex : int
        Index of the source vertex within `coords`
    coords : (N,3) array_like
        Vertex coordinates
    labels : (N,) array_like, optional (default None)
        Parcel label for each vertex. When given, distances are averaged
        within each of the M unique labels.

    Returns
    -------
    dist : np.ndarray
        float32 distances from `vertex` to all vertices in `coords` (or to all
        unique parcels in `labels`, if provided)

    """
    source_row = coords[[vertex]]  # list index keeps it 2-D for cdist
    all_distances = np.squeeze(cdist(source_row, coords))
    return _get_parcel_distance(vertex, all_distances, labels)


def _get_parcel_distance(vertex, dist, labels=None):
    """
    Average `dist` within `labels`, if provided

    Parameters
    ----------
    vertex : int
        Index of vertex used to calculate `dist`
    dist : (N,) array_like
        Distance of `vertex` to all other vertices
    labels : (N,) array_like, optional (default None)
        Labels indicating parcel to which each vertex belongs. If provided,
        `dist` will be average within distinct labels.

    Returns
    -------
    dist : np.ndarray
        Distance from `vertex` to all vertices/parcels, cast to float32

    """

    if labels is not None:
        dist = ndimage.mean(input=np.delete(dist, vertex),
                            labels=np.delete(labels, vertex),
                            index=np.unique(labels))

    return np.atleast_2d(dist).astype(np.float32)


""" Functions for Connectome Workbench-style neuroimaging file I/O. """

from ..utils.dataio import load
from ..utils.checks import *
import nibabel as nib
import numpy as np

__all__ = ['image2txt', 'check_surface', 'check_image_file']


def image2txt(image_file, outfile, maskfile=None, delimiter=' '):
    """
    Write the scalar map stored in a neuroimaging file out as a text file.

    Parameters
    ----------
    image_file : filename
        Path to input neuroimaging file
    outfile : filename
        Path to output text file
    maskfile : filename or None, default None
        Path to neuroimaging file containing a binary map where non-zero values
        indicate masked brain regions. Masked elements are left out of
        `outfile`.
    delimiter : str, default ' '
        String written between elements in `outfile`

    Notes
    -----
    More generally, this can be done via
    ``wb_command -cifti-convert -to-text <image_file> <outfile>``.

    """
    values = check_image_file(image_file)
    check_outfile(outfile)
    if maskfile is not None:
        keep = ~check_image_file(maskfile).astype(bool)
        values = values[keep]
    is_string_like(delimiter)
    np.savetxt(outfile, values, delimiter=delimiter)


def check_image_file(image):
    """
    Check a neuroimaging file and return internal scalar neuroimaging data.

    Files that nibabel cannot identify as an image are retried as plain text
    via ``np.loadtxt``.

    Parameters
    ----------
    image : filename
        Path to neuroimaging file or txt file

    Returns
    -------
    (N,) np.ndarray
        Scalar brain map values

    Raises
    ------
    FileNotFoundError : `image` does not exist
    IOError : filetype not recognized
    ValueError : `image` contains more than one neuroimaging map

    """
    try:
        x = load(image)
    except FileNotFoundError as err:
        raise FileNotFoundError("No such file: {}".format(image)) from err
    except nib.loadsave.ImageFileError:
        # not a recognized image format -> fall back to a text file
        try:
            x = np.loadtxt(image)
        except (TypeError, ValueError) as err:
            e = "Cannot work out file type of {}".format(image)
            raise IOError(e) from err
    if x.ndim > 1:
        raise ValueError("Image contains more than one map: {}".format(image))
    return x


def check_surface(surface):
    """
    Check and load MNI coordinates from a surface file.

    Parameters
    ----------
    surface : filename
        Path to GIFTI-format surface file (.surf.gii)

    Returns
    -------
    (N,3) np.ndarray
        MNI coordinates. columns 0,1,2 correspond to X,Y,Z coord, respectively

    Raises
    ------
    ValueError : `surface` does not hold a 2-D array with 3 columns (assumed
        to be X, Y, Z)

    """
    coords = load(surface)
    # report a clear error instead of a tuple-unpacking failure for non-2-D
    if coords.ndim != 2:
        e = "expected a two-dimensional (N,3) coordinate array in surface "
        e += "file but found {} dimension(s)".format(coords.ndim)
        raise ValueError(e)
    ncols = coords.shape[1]
    if ncols != 3:
        e = "expected three columns in surface file but found {}".format(ncols)
        raise ValueError(e)
    return coords


""" Functions for creating graphs from surface meshes. """

import numpy as np
from scipy import sparse


def _get_edges(faces):
    """
    Gets set of edges defined by `faces`.

    Parameters
    ----------
    faces : (F, 3) array_like
        Set of indices creating triangular faces of a mesh

    Returns
    -------
    edges : (F*3, 2) array_like
        All edges in `faces`

    """
    faces = np.asarray(faces)
    edges = np.sort(faces[:, [0, 1, 1, 2, 2, 0]].reshape((-1, 2)), axis=1)
    return edges


def get_direct_edges(vertices, faces):
    """
    Gets (unique) direct edges and weights in mesh described by inputs.

    Parameters
    ----------
    vertices : (N, 3) array_like
        Coordinates of `vertices` comprising mesh with `faces`
    faces : (F, 3) array_like
        Indices of `vertices` that compose triangular faces of mesh

    Returns
    -------
    edges : (E, 2) np.ndarray
        Indices of `vertices` comprising direct edges (without duplicates),
        smaller index first
    weights : (E,) np.ndarray
        Euclidean length of each edge in `edges`

    """
    vertices = np.asarray(vertices)  # allow list input for fancy indexing
    edges = np.unique(_get_edges(faces), axis=0)
    # explicit endpoint difference keeps weights 1-D even when E == 1
    # (squeezing an (E, 1) result would collapse a single edge to 0-d)
    weights = np.linalg.norm(
        vertices[edges[:, 1]] - vertices[edges[:, 0]], axis=-1)
    return edges, weights


def get_indirect_edges(vertices, faces):
    """
    Gets indirect edges and weights in mesh described by inputs.

    Indirect edges are between two vertices that participate in faces sharing
    an edge: for each pair of adjacent triangles, the edge joins the two
    vertices that are *not* on the shared edge.

    Parameters
    ----------
    vertices : (N, 3) np.ndarray
        Coordinates of `vertices` comprising mesh with `faces`; indexed with
        integer arrays below, so it must be a numpy array
    faces : (F, 3) np.ndarray
        Indices of `vertices` that compose triangular faces of mesh; also
        fancy-indexed below, so it must be a numpy array

    Returns
    -------
    edges : (E, 2) np.ndarray of int32
        Indices of `vertices` comprising indirect edges, one row per pair of
        adjacent (non-degenerate) faces
    weights : (E,) np.ndarray
        Approximate surface distance between the two vertices of each row of
        `edges`: the sum of each vertex's distance to the midpoint of the two
        triangles' feet on the shared edge

    References
    ----------
    https://github.com/mikedh/trimesh (MIT licensed)

    """
    # first generate the list of edges for the provided faces and the
    # index for which face the edge is from (which is just the index of the
    # face repeated thrice, since each face generates three direct edges)
    edges = _get_edges(faces)
    edges_face = np.repeat(np.arange(len(faces)), 3)

    # every edge appears twice in a watertight surface, so we'll first get the
    # indices for each duplicate edge in `edges` (this should, assuming all
    # goes well, have rows equal to len(edges) // 2)
    # sort edges lexicographically (first column primary) so duplicates are
    # adjacent, then keep only runs of length exactly two
    order = np.lexsort(edges.T[::-1])
    edges_sorted = edges[order]
    dupe = np.any(edges_sorted[1:] != edges_sorted[:-1], axis=1)
    dupe_idx = np.append(0, np.nonzero(dupe)[0] + 1)
    start_ok = np.diff(np.concatenate((dupe_idx, [len(edges_sorted)]))) == 2
    groups = np.tile(dupe_idx[start_ok].reshape(-1, 1), 2)
    # (E, 2) positions in `edges` of the two copies of each shared edge
    edge_groups = order[groups + np.arange(2)]

    # now, get the indices of the faces that participate in these duplicate
    # edges, as well as the edges themselves
    adjacency = edges_face[edge_groups]
    nondegenerate = adjacency[:, 0] != adjacency[:, 1]
    adjacency = np.sort(adjacency[nondegenerate], axis=1)
    adjacency_edges = edges[edge_groups[:, 0][nondegenerate]]

    # the non-shared vertex index is the same shape as adjacency, holding
    # vertex indices vs face indices
    # (initialized to -1, which remains for any row that fails `row_ok` below)
    indirect_edges = np.zeros(adjacency.shape, dtype=np.int32) - 1

    # loop through the two columns of adjacency
    for i, fid in enumerate(adjacency.T):
        # faces from the current column of adjacency
        face = faces[fid]
        # get index of vertex not included in shared edge
        unshared = np.logical_not(np.logical_or(
            face == adjacency_edges[:, 0].reshape(-1, 1),
            face == adjacency_edges[:, 1].reshape(-1, 1)))
        # each row should have one "uncontained" vertex; ignore degenerates
        row_ok = unshared.sum(axis=1) == 1
        unshared[~row_ok, :] = False
        indirect_edges[row_ok, i] = face[unshared]

    # get vertex coordinates of triangles pairs with shared edges, ordered
    # such that the non-shared vertex is always _last_ among the trio
    # NOTE: `face` and `unshared` are the values left over from the *last*
    # loop iteration (second face of each pair); only the two shared-edge
    # vertices are taken from them, and both faces contain those.
    # NOTE(review): rows where `row_ok` was False have `unshared` all False,
    # so all 3 of their vertices would be selected here and the reshape to
    # (-1, 1, 2) would misalign or fail -- confirm such rows cannot occur.
    shared = np.sort(face[np.logical_not(unshared)].reshape(-1, 1, 2), axis=-1)
    shared = np.repeat(shared, 2, axis=1)
    triangles = np.concatenate((shared, indirect_edges[..., None]), axis=-1)
    # `A.shape`: (3, N, 2) corresponding to (xyz coords, edges, triangle pairs)
    A, B, V = vertices[triangles].transpose(2, 3, 0, 1)

    # calculate the xyz coordinates of the foot of each triangle, where the
    # base is the shared edge
    # that is, we're trying to calculate F in the equation `VF = VB - (w * BA)`
    # where `VF`, `VB`, and `BA` are vectors, and `w = (AB * VB) / (AB ** 2)`
    w = (np.sum((A - B) * (V - B), axis=0, keepdims=True)
         / np.sum((A - B) ** 2, axis=0, keepdims=True))
    feet = B - (w * (B - A))
    # calculate coordinates of midpoint b/w the feet of each pair of triangles
    # -> shape (E, 1, 3), broadcast against the (E, 2, 3) vertex coords below
    midpoints = (np.sum(feet.transpose(1, 2, 0), axis=1) / 2)[:, None]
    # calculate Euclidean distance between non-shared vertices and midpoints
    # and add distances together for each pair of triangles
    norms = np.linalg.norm(vertices[indirect_edges] - midpoints, axis=-1)
    weights = np.sum(norms, axis=-1)

    # NOTE: weights won't be perfectly accurate for a small subset of triangle
    # pairs where either triangle has angle >90 along the shared edge. in these
    # the midpoint lies _outside_ the shared edge, so neighboring triangles
    # would need to be taken into account. that said, this occurs in only a
    # minority of cases and the difference tends to be in the ~0.001 mm range
    return indirect_edges, weights


def make_surf_graph(vertices, faces, mask=None):
    """
    Constructs a weighted adjacency graph from a triangular surface mesh.

    The graph combines the direct edges returned by `get_direct_edges` with
    the indirect edges returned by `get_indirect_edges`. Indirect edges join
    the two non-shared vertices of each pair of triangles that share an
    edge.

    Parameters
    ----------
    vertices : (N, 3) array_like
        Coordinates of `vertices` comprising mesh with `faces`
    faces : (F, 3) array_like
        Indices of `vertices` that compose triangular faces of mesh
    mask : (N,) array_like, optional (default None)
        Boolean mask indicating which vertices should be removed from generated
        graph. If not supplied, all vertices are used. Masked vertices keep
        their row/column in the output, but every edge touching them is
        dropped.

    Returns
    -------
    graph : (N, N) scipy.sparse.csr_matrix
        Sparse matrix representing graph of `vertices` and `faces`. Entry
        (i, j) holds the weight of the edge from vertex i to vertex j. Only
        the (i, j) orientation produced by the edge helpers is stored.
        NOTE(review): callers presumably treat the graph as undirected;
        confirm. Duplicate (i, j) pairs are summed by the sparse
        constructor.

    Raises
    ------
    ValueError : inconsistent number of vertices in `mask` and `vertices`
    """

    if mask is not None and len(mask) != len(vertices):
        raise ValueError('Supplied `mask` array has different number of '
                         'vertices than supplied `vertices`.')

    # get all (direct + indirect) edges from surface
    direct_edges, direct_weights = get_direct_edges(vertices, faces)
    indirect_edges, indirect_weights = get_indirect_edges(vertices, faces)
    # `np.vstack` rather than the `np.row_stack` alias (deprecated in NumPy 2)
    edges = np.vstack((direct_edges, indirect_edges))
    weights = np.hstack((direct_weights, indirect_weights))

    # remove edges that include a vertex in `mask`
    if mask is not None:
        masked_idx, = np.where(mask)
        keep = ~np.any(np.isin(edges, masked_idx), axis=1)
        edges, weights = edges[keep], weights[keep]

    # construct our graph on which to calculate shortest paths; `ravel` (not
    # `squeeze`) keeps the data 1-D even when only a single edge remains
    return sparse.csr_matrix((np.ravel(weights), (edges[:, 0], edges[:, 1])),
                             shape=(len(vertices), len(vertices)))


from .geo import cortex, subcortex, parcellate, volume
from .io import image2txt
from brainsmash.utils.dataio import load

# Public API of this subpackage: the names imported above from `.geo` and
# `.io`, plus `load`, re-exported from `brainsmash.utils.dataio`.
__all__ = ['cortex', 'subcortex', 'parcellate', 'volume', 'image2txt', 'load']

# TODO: "workbench" is now a misleading name for this subpackage.


# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
import os
import sys
import sphinx_rtd_theme

# Make the docs directory, its parents, and the package/subpackage
# directories importable for autodoc. Each entry is prepended to sys.path,
# so the last path listed here ends up first.
for _doc_path in (
    ".",
    "../",
    "../../",
    "../brainsmash/",
    "../brainsmash/mapgen/",
    "../brainsmash/workbench/",
    "../brainsmash/utils/",
):
    sys.path.insert(0, os.path.abspath(_doc_path))
del _doc_path

# -- Project information -----------------------------------------------------

project = 'BrainSMASH'
# No trailing period: the HTML theme footer appends its own.
copyright = '2020, Joshua B. Burt, John D. Murray'
author = 'Joshua B. Burt'

# The full version, including alpha/beta/rc tags. Keep this in sync with the
# package's `__version__` and the version passed to `setup()`.
release = '0.11.0'
# The short X.Y version, derived from `release`.
version = '.'.join(release.split('.')[:2])


# -- General configuration ---------------------------------------------------

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
# Sphinx extensions: autodoc, napoleon, and autosummary.
extensions = [
    'sphinx.ext.autodoc',
    'sphinx.ext.napoleon',
    'sphinx.ext.autosummary',
]

# Napoleon: parse NumPy-style docstrings only (Google style is disabled).
napoleon_google_docstring = False
napoleon_numpy_docstring = True
napoleon_include_init_with_doc = False
napoleon_include_special_with_doc = False
napoleon_use_param = True
napoleon_use_ivar = True
napoleon_use_rtype = True

add_function_parentheses = False
autosummary_generate = True

# Extra template directories, relative to this configuration directory.
templates_path = ['_templates']

source_suffix = '.rst'

# Files and directories (relative to the source directory) to ignore when
# collecting sources; also applied to html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']

# Pygments style used for syntax highlighting.
pygments_style = 'sphinx'

# Sidebar templates for every page. 'relations.html' only displays when the
# theme option 'show_related' is True.
html_sidebars = {'**': ['relations.html', 'searchbox.html']}

# Enable figure/table numbering, with numbers prefixed by section depth 1.
numfig = True
numfig_secnum_depth = 1

# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages.  See the documentation for
# a list of builtin themes.
#
# HTML theme and logo (logo path is relative to this configuration directory).
html_theme = 'sphinx_rtd_theme'
html_logo = 'images/logo.png'

# No custom static files are shipped. To add some (e.g. CSS overrides), set
# html_static_path = ['_static']; files there are copied after the builtin
# static files, so same-named files (like "default.css") override them.

# Output file base name for the HTML help builder.
htmlhelp_basename = 'BrainSMASHdoc'
