# Part of the spatiotemporal package for python
# Copyright 2022 Max Shinn <m.shinn@ucl.ac.uk>
# Available under the MIT license

from setuptools import setup

# Pull __version__ out of the version module by executing it, presumably so
# that setup.py does not have to import the package itself at build time.
with open("spatiotemporal/_version.py", "r", encoding="utf-8") as f:
    exec(f.read())

# Read the README with an explicit encoding: without it, the platform default
# (e.g. cp1252 on Windows) is used and non-ASCII text may be mis-decoded.
with open("README.md", "r", encoding="utf-8") as f:
    long_desc = f.read()


setup(
    name = 'spatiotemporal',
    version = __version__,
    description = 'Tools for spatial and temporal autocorrelation',
    long_description = long_desc,
    long_description_content_type='text/markdown',
    author = 'Max Shinn',
    author_email = 'm.shinn@ucl.ac.uk',
    maintainer = 'Max Shinn',
    maintainer_email = 'm.shinn@ucl.ac.uk',
    license = 'MIT',
    python_requires='>=3.6',
    url='https://github.com/mwshinn/spatiotemporal',
    packages = ['spatiotemporal'],
    install_requires = ['numpy', 'scipy', 'pandas'],
    classifiers = [
        'Development Status :: 4 - Beta',
        'Intended Audience :: Science/Research',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3',
        'Topic :: Scientific/Engineering',
        'Topic :: Scientific/Engineering :: Physics',
        'Topic :: Scientific/Engineering :: Mathematics',
        'Topic :: Scientific/Engineering :: Medical Science Apps.',
        'Topic :: Scientific/Engineering :: Bio-Informatics'],
)



# This is mostly just a simple smoke test.  Would be better to have more
# elaborate testing at some point in the future.  That being said, the
# functions in this package were all pulled from a different package which had
# better testing, and they have all been scientifically validated.  So the main
# risk here is if some numpy or scipy function changes its behavior in a way
# that passes invisibly.

import numpy as np
import spatiotemporal as st
import scipy.stats

# Shared fixtures: two deterministic sets of timeseries (rows = series,
# columns = timepoints).
tss = [
    # Brownian noise: Gaussian increments accumulated along the time axis
    np.random.RandomState(1000).randn(100, 200).cumsum(axis=1),
    # Gaussian white noise
    np.random.RandomState(999).randn(50, 400),
]

def test_phase_randomize():
    """Phase randomization must keep the shape and the mean amplitude spectrum."""
    for ts in tss:
        sur = st.phase_randomize(ts)
        assert sur.shape == ts.shape
        # The FFT runs along the time axis (axis 1), so the positive-frequency
        # bins are 1 .. T/2 - 1, where T = ts.shape[1].  (len(ts) is the number
        # of rows, which previously made the test check only part of the band.)
        # The DC and Nyquist bins are excluded because their phases cannot be
        # randomized for a real signal.
        n_time = ts.shape[1]
        fft = np.mean(np.abs(np.fft.fft(ts)), axis=0)[1:n_time//2]
        fft_sur = np.mean(np.abs(np.fft.fft(sur)), axis=0)[1:n_time//2]
        assert np.max(np.abs(fft-fft_sur)) < .0001


def test_zalesky():
    """The Zalesky surrogate should preserve the mean correlation."""
    for seed, ts in enumerate(tss):
        cm = np.corrcoef(ts)
        surrogate = st.zalesky_surrogate(cm, seed=seed)
        # Compare means on a log-ratio scale.
        log_ratio = np.abs(np.log(np.mean(cm) / np.mean(surrogate)))
        assert log_ratio < .1
        # NOTE(review): this looser bound is implied by the assertion above, so
        # it can never fail on its own.  It may have been meant to check a
        # different statistic -- confirm intent.
        assert log_ratio < .25

def test_eigensurrogate():
    """The eigensurrogate matrix must reproduce the source eigenvalues."""
    for seed, ts in enumerate(tss):
        cm = np.corrcoef(ts)
        surrogate = st.eigensurrogate_matrix(cm, seed=seed)
        source_evs = st.tools.get_eigenvalues(cm)
        surrogate_evs = st.tools.get_eigenvalues(surrogate)
        worst = np.max(np.abs(np.log(source_evs / surrogate_evs)))
        assert worst < 1e-10

def test_eigensurrogate_timeseries():
    """Sampled timeseries should approximately reproduce the source eigenvalues."""
    for seed, ts in enumerate(tss):
        cm = np.corrcoef(ts)
        sampled = st.eigensurrogate_timeseries(cm, N_timepoints=1000, seed=seed)
        source_evs = st.tools.get_eigenvalues(cm)
        sampled_evs = st.tools.get_eigenvalues(np.corrcoef(sampled))
        # Sampling noise with 1000 timepoints allows a looser tolerance.
        assert np.max(np.abs(np.log(source_evs / sampled_evs))) < .3

def test_spatial_autocorrelation():
    """Fitting SA parameters to a noiseless exponential-floor matrix recovers them."""
    positions = np.random.rand(200, 3) * 100
    dists = st.tools.distance_matrix_euclidean(positions)
    for sa_lambda, sa_inf in [(10, .2), (50, .01), (5, .5), (30, -.4)]:
        cm = st.tools.spatial_exponential_floor(dists, sa_lambda, sa_inf)
        fitted = np.asarray(st.spatial_autocorrelation(cm, dists))
        relative_error = np.abs(np.log(fitted / np.asarray([sa_lambda, sa_inf])))
        assert np.max(relative_error) < .1

def test_temporal_autocorrelation():
    """Round trip: target TA-Δ₁ -> alpha -> spectrum -> timeseries -> measured TA-Δ₁."""
    single_node_cm = np.asarray([[1]])
    for seed, target_ta in enumerate([0, .2, .4, .6, .8]):
        alpha = st.models.ta_to_alpha_fast(1000, 1, .01, target_ta)
        spectrum = st.models.make_spectrum(1000, 1, alpha, .01)
        ts = st.models.correlated_spectral_sampling(single_node_cm, np.asarray([spectrum]), seed=seed)
        measured = st.temporal_autocorrelation(ts)
        assert np.abs(measured - target_ta) < .05

# We're not going to test long_memory because it requires such a complicated
# setup.

def test_fingerprint():
    """Rows 0/2 and rows 1/3 are identical, so the best matches pair them up."""
    values = np.asarray([[4, 4, 4, 5], [4, 5, 4, 3], [4, 4, 4, 5], [4, 5, 4, 3]])
    expected_scores = [
        ([1, 2, 2, 2], .5),
        ([1, 2, 2, 1], 0),
        ([1, 2, 1, 2], 1),
        ([1, 1, 1, 1], 1),
    ]
    for subjects, score in expected_scores:
        assert st.fingerprint(subjects, values) == score

def test_lin():
    """Lin's concordance: identity, reversal, constant input, matrix form, offsets."""
    assert st.lin([1, 2, 3], [1, 2, 3]) == 1
    assert st.lin(np.asarray([1, 2, 3]), np.asarray([3, 2, 1])) == -1
    # A constant vector has zero covariance with anything, hence zero concordance.
    assert st.lin([5, 5, 5], [3, 5, 7]) == 0
    x_rows = np.asarray([[5, 5, 5], [1, 2, 3]])
    y_rows = np.asarray([[1, 2, 3], [3, 2, 1], [1, 2, 3]])
    assert np.all(st.lin(x_rows, y_rows) == np.asarray([[0, 0, 0], [1, -1, 1]]))
    # A larger offset between otherwise identical vectors lowers concordance.
    assert 0 < st.lin([1, 2, 3], [12, 13, 14]) < st.lin([1, 2, 3], [2, 3, 4]) < 1

def test_cosine():
    """Cosine similarity: below 1 for non-parallel vectors, ±1 for (anti)parallel ones."""
    base = [1, 2, 3]
    assert st.cosine(base, [2, 3, 4]) < 1
    assert st.cosine(base, [1, 2, 3]) == 1
    assert st.cosine(base, [-1, -2, -3]) == -1

def test_pearson():
    """Pearson: exact cases, matrix form, and agreement with np.corrcoef."""
    assert st.pearson([1, 2, 3], [2, 3, 4]) == 1
    assert st.pearson(np.asarray([1, 2, 3]), np.asarray([10, 8, 6])) == -1
    # Undefined correlation (constant input) is reported as 0.
    assert st.pearson([5, 5, 5], [3, 5, 7]) == 0
    x_rows = np.asarray([[5, 5, 5], [1, 2, 3]])
    y_rows = np.asarray([[1, 2, 3], [3, 2, 1], [11, 12, 13]])
    assert np.all(st.pearson(x_rows, y_rows) == np.asarray([[0, 0, 0], [1, -1, 1]]))
    data = np.random.randn(100, 200)
    assert np.all(np.isclose(np.corrcoef(data), st.pearson(data, data)))

def test_spearman():
    """Spearman: exact cases, matrix form, and agreement with scipy.stats.spearmanr."""
    assert st.spearman([1, 2, 3], [2, 3, 4]) == 1
    assert st.spearman(np.asarray([1, 2, 3]), np.asarray([10, 8, 6])) == -1
    assert st.spearman([5, 5, 5], [3, 5, 7]) == 0
    x_rows = np.asarray([[5, 5, 5], [1, 2, 3]])
    y_rows = np.asarray([[1, 2, 3], [3, 2, 1], [11, 12, 13]])
    assert np.all(st.spearman(x_rows, y_rows) == np.asarray([[0, 0, 0], [1, -1, 1]]))
    # A monotone but non-linear relation: rank correlation beats linear correlation.
    assert st.spearman([1, 2, 3], [3, 5, 6]) > st.pearson([1, 2, 3], [3, 5, 6])
    data = np.random.randn(100, 200)
    # scipy works on columns, hence the transpose.
    reference = scipy.stats.spearmanr(data.T).correlation
    assert np.all(np.isclose(reference, st.spearman(data, data)))

def test_spatiotemporal_model():
    """Simulated spatiotemporal-model timeseries should hit their target TA-Δ₁."""
    poss = np.random.rand(50,3)*100
    distance_matrix = st.tools.distance_matrix_euclidean(poss)
    num_timepoints = 50000 # Big to get a better fit
    ta_delta1s = np.random.RandomState(100).rand(50)*.8
    sample_rate = 1
    highpass_freq = .01
    base_seed = 1
    for i, (sa_lambda, sa_inf) in enumerate([(20, .2), (50, .5), (10, 0)]):
        # Named `sim` so it does not shadow the module-level `tss` fixtures.
        sim = st.spatiotemporal_model_timeseries(distance_matrix, sa_lambda, sa_inf, ta_delta1s, num_timepoints, sample_rate, highpass_freq, seed=i+base_seed+1)
        newtas = st.temporal_autocorrelation(sim)
        assert np.max(np.abs(ta_delta1s-newtas)) < .05
        # TODO: there is no direct check of SA-λ / SA-∞ yet.  The previous
        # `assert (np.mean(np.corrcoef(tss)) - params[1])` only asserted that
        # the mean correlation was not *exactly* equal to SA-∞, so it could
        # never catch a regression.  Keep a sanity check instead.
        assert np.all(np.isfinite(np.corrcoef(sim)))

def test_intrinsic_timescale_sa_model():
    """Simulated intrinsic-timescale + SA timeseries should hit their target TA-Δ₁."""
    poss = np.random.rand(50,3)*100
    distance_matrix = st.tools.distance_matrix_euclidean(poss)
    num_timepoints = 50000 # Big to get a better fit
    ta_delta1s = np.random.RandomState(200).rand(50)*.3
    sample_rate = 1
    highpass_freq = .01
    base_seed = 1
    for i, (sa_lambda, sa_inf) in enumerate([(20, .05), (50, .1), (10, 0)]):
        # Named `sim` so it does not shadow the module-level `tss` fixtures.
        sim = st.intrinsic_timescale_sa_model_timeseries(distance_matrix, sa_lambda, sa_inf, ta_delta1s, num_timepoints, sample_rate, highpass_freq, seed=i+base_seed+1)
        newtas = st.temporal_autocorrelation(sim)
        assert np.max(np.abs(ta_delta1s-newtas)) < .05
        # TODO: there is no direct check of SA-λ / SA-∞ yet.  The previous
        # `assert (np.mean(np.corrcoef(tss)) - params[1])` only asserted that
        # the mean correlation was not *exactly* equal to SA-∞, so it could
        # never catch a regression.  Keep a sanity check instead.
        assert np.all(np.isfinite(np.corrcoef(sim)))


# Part of the spatiotemporal package for python
# Copyright 2022 Max Shinn <m.shinn@ucl.ac.uk>
# Available under the MIT license
import numpy as np
import scipy.stats

def fingerprint(subjects, values):
    """Fingerprinting performance, from [Finn et al (2015)](https://www.nature.com/articles/nn.4135).

    Note:
        This implementation is slightly different than that of Finn et al (2015).
        Here, instead of having separate databases, we use all other scans from all
        other subjects as the "database".  Then, we see if the match is from the
        same subject or different subjects.  This means that if there are more than
        two observations for each subject, there will be more than one "correct"
        best match.  However, it also means that there are many more possible
        incorrect matches for a given subject than there are in Finn et al (2015).

    Args:
      subjects (list or 1xN numpy array): a list of length N giving the subject ID.  N should be
          the total number of observations, e.g., if there are 10 subjects with 3
          scans each, N = 30.
      values (Nxk numpy array or nested list): matrix where N is as above and k is the size of
          the feature on which to perform the fingerprinting.  E.g., if we are
          fingerprinting based on TA-delta1 of each node in a 360 node atlas, k=360.

    Returns:
      float: The fraction of correct fingerprinting identifications
    """
    subjects = np.asarray(subjects)
    values = np.asarray(values)
    assert values.shape[0] == len(subjects)
    corrs = np.corrcoef(values)
    # Exclude self-matches.  -inf (rather than -1) guarantees an observation is
    # never chosen as its own best match, even if all of its correlations with
    # the other observations are exactly -1.
    np.fill_diagonal(corrs, -np.inf)
    # For each observation (column), the index of its most-correlated other observation.
    maxcorrs = np.argmax(corrs, axis=0)
    best_match_subject = subjects[maxcorrs]
    return np.mean(best_match_subject == subjects)

def _cross(x, y, func):
    """Utility function for correlation/covariance functions"""
    x = np.asarray(x)
    y = np.asarray(y)
    if x.ndim == 1:
        x = x[None]
    if y.ndim == 1:
        y = y[None]
    assert x.ndim == 2, "X must be two-dimensional"
    assert y.ndim == 2, "Y must be two-dimensional"
    assert x.shape[1] == y.shape[1], "X and Y must have compatible dimensions"
    res = func(x,y)
    res[np.isnan(res)] = 0
    if res.shape == (1,1):
        return res[0,0]
    return res


def lin(x, y):
    """Lin's concordance correlation coefficient

    Lin's concordance ranges from -1 to 1.  It penalizes both a lack of
    correlation and differences in mean or variance between the two vectors.
    The value is computed between every row of `x` and every row of `y`.

    Args:
      x (list of length N or KxN numpy array): the first vector on which to find Lin's concordance
      y (list of length N or MxN numpy array): the second vector on which to find Lin's concordance

    Returns:
      KxM numpy array: The Lin's concordance matrix between x and y.  If K and M are 1 or x and y are lists, return a float instead.

    """
    def _lin(a, b):
        a_dev = a - np.mean(a, axis=1, keepdims=True)
        b_dev = b - np.mean(b, axis=1, keepdims=True)
        numerator = 2 * a_dev @ b_dev.T / a.shape[1]
        mean_gap = np.mean(a, axis=1)[:, None] - np.mean(b, axis=1)[None, :]
        denominator = np.var(a, axis=1)[:, None] + np.var(b, axis=1)[None, :] + mean_gap ** 2
        return numerator / denominator
    return _cross(x, y, _lin)

def cosine(x, y):
    """Cosine similarity

    Cosine similarity ranges from -1 to 1 and is the cosine of the angle
    between two vectors.  The value is computed between every row of `x` and
    every row of `y`.

    Args:
      x (list of length N or KxN numpy array): the first vector on which to find cosine similarity
      y (list of length N or MxN numpy array): the second vector on which to find cosine similarity

    Returns:
      KxM numpy array: The cosine similarity matrix between x and y.  If K and M are 1 or x and y are lists, return a float instead.

    """
    def _cosine(a, b):
        a_norm = np.sqrt(np.sum(a ** 2, axis=1)[:, None])
        b_norm = np.sqrt(np.sum(b ** 2, axis=1)[None, :])
        return a @ b.T / a_norm / b_norm
    return _cross(x, y, _cosine)

def pearson(x, y):
    """Matrix of Pearson correlations

    Pearson correlation ranges from -1 to 1.  Unlike np.corrcoef, only the
    correlations between rows of `x` and rows of `y` are computed, not those
    within `x` or within `y`, which is much faster for large inputs.  (This
    operation is sometimes mistakenly called "cross-correlation".)

    Args:
      x (list of length N or KxN numpy array): the first vector on which to find Pearson correlation
      y (list of length N or MxN numpy array): the second vector on which to find Pearson correlation

    Returns:
      KxM numpy array: The Pearson correlation matrix between x and y.  If K and M are 1 or x and y are lists, return a float instead.

    """
    def _pearson(a, b):
        a_dev = a - np.mean(a, axis=1, keepdims=True)
        b_dev = b - np.mean(b, axis=1, keepdims=True)
        a_sd = np.sqrt(np.var(a, axis=1)[:, None])
        b_sd = np.sqrt(np.var(b, axis=1)[None, :])
        return a_dev @ b_dev.T / a.shape[1] / a_sd / b_sd
    return _cross(x, y, _pearson)

def spearman(x, y):
    """Matrix of Spearman correlations

    Spearman correlation ranges from -1 to 1.  It is the Pearson correlation of
    within-row ranks.  Unlike scipy.stats.spearmanr (which works on column
    vectors), this works on row vectors.  It also computes only the
    correlations between rows of `x` and rows of `y`, which is much faster for
    large inputs.

    Args:
      x (list of length N or KxN numpy array): the first vector on which to find Spearman correlation
      y (list of length N or MxN numpy array): the second vector on which to find Spearman correlation

    Returns:
      KxM numpy array: The Spearman correlation matrix between x and y.  If K and M are 1 or x and y are lists, return a float instead.

    """
    def _spearman(a, b):
        a_rank = scipy.stats.rankdata(a, axis=1)
        b_rank = scipy.stats.rankdata(b, axis=1)
        a_dev = a_rank - np.mean(a_rank, axis=1, keepdims=True)
        b_dev = b_rank - np.mean(b_rank, axis=1, keepdims=True)
        a_sd = np.sqrt(np.var(a_rank, axis=1)[:, None])
        b_sd = np.sqrt(np.var(b_rank, axis=1)[None, :])
        return a_dev @ b_dev.T / a_rank.shape[1] / a_sd / b_sd
    return _cross(x, y, _spearman)


# Part of the spatiotemporal package for python
# Copyright 2022 Max Shinn <m.shinn@ucl.ac.uk>
# Available under the MIT license
import numpy as np
import scipy.signal
import scipy.optimize
from .tools import spatial_exponential_floor

def spatiotemporal_model_timeseries(distance_matrix, sa_lambda, sa_inf, ta_delta1s, num_timepoints, sample_rate, highpass_freq, seed=None):
    """Simulate the spatiotemporal model from [Shinn et al (2023)](https://www.nature.com/articles/s41593-023-01299-3)

    Args:
      distance_matrix (NxN numpy array): the NxN distance matrix, representing the spatial distance
          between location of each of the timeseries.  This should usually be the
          output of the `distance_matrix_euclidean` function.
      sa_lambda (float): the SA-λ parameter
      sa_inf (float): the SA-∞ parameter
      ta_delta1s (length-N list of floats): a list of TA-Δ₁ values, of length N, for generating the timeseries
      num_timepoints (int): length of timeseries to generate (must be even)
      sample_rate (float): the spacing between timepoints (e.g. TR in fMRI)
      highpass_freq (float): if non-zero, apply a highpass filter above the
          given frequency.  A good default is 0.01 for fMRI timeseries.
      seed (int, optional): the random seed.  If None, fresh entropy is used
          and the output is not reproducible.

    Returns:
      NxT numpy array: For each of the N nodes, a timeseries of length T=`num_timepoints` according to the spatiotemporal model
    """
    assert num_timepoints % 2 == 0, "Must be even timeseries length"
    # Every node shares the same filtered brown-noise (alpha=2) spectrum.
    spectrum = make_spectrum(tslen=num_timepoints, sample_rate=sample_rate, alpha=2, highpass_freq=highpass_freq)
    spectra = np.asarray([spectrum]*len(ta_delta1s))
    # Spatial autocorrelation matrix
    corr = spatial_exponential_floor(distance_matrix, sa_lambda, sa_inf)
    # Create spatially embedded timeseries with the given spectra
    tss = correlated_spectral_sampling(corr, spectra, seed=seed)
    # Standard deviation of white noise needed to reach each target TA-Δ₁.
    # The .001 floor keeps how_much_noise from dividing by a zero target.
    noises = np.asarray([how_much_noise(spectrum, max(.001, ta_delta1)) for ta_delta1 in ta_delta1s])
    # correlated_spectral_sampling seeds its own RandomState with `seed`.  A
    # second RandomState(seed) here would start from the same state and reuse
    # the same underlying draws, so the added noise would not be independent
    # of the signal.  Derive a distinct stream from `seed` instead.
    noise_rng = np.random.RandomState(None if seed is None else [seed, 1])
    tss += noise_rng.randn(tss.shape[0], tss.shape[1]) * noises.reshape(-1,1)
    return tss

def intrinsic_timescale_sa_model_timeseries(distance_matrix, sa_lambda, sa_inf, ta_delta1s, num_timepoints, sample_rate, highpass_freq, seed=0):
    """Simulate the intrinsic timescale + spatial autocorrelation model from [Shinn et al (2023)](https://www.nature.com/articles/s41593-023-01299-3)

    Each node gets its own filtered pink-noise spectrum, with the exponent
    chosen so that its noiseless TA-Δ₁ matches the target.  The nodes are then
    spatially correlated via Correlated Spectral Sampling.

    Args:
      distance_matrix (NxN numpy array): the NxN distance matrix, representing the spatial distance
          between location of each of the timeseries.  This should usually be the
          output of the `distance_matrix_euclidean` function.
      sa_lambda (float): the SA-λ parameter
      sa_inf (float): the SA-∞ parameter
      ta_delta1s (length-N list of floats): a list of TA-Δ₁ values, of length N, for generating the
          timeseries.  Negative values are clamped to 0.
      num_timepoints (int): length of timeseries to generate (must be even)
      sample_rate (float): the spacing between timepoints (e.g. TR in fMRI)
      highpass_freq (float): if non-zero, apply a highpass filter above the
          given frequency.  A good default is 0.01 for fMRI timeseries.
      seed (int or None, optional): the random seed passed to
          `correlated_spectral_sampling`.  Defaults to 0, so repeated calls give
          identical output.  Pass None explicitly to use fresh entropy.

    Returns:
      NxT numpy array: For each of the N nodes, a timeseries of length T=`num_timepoints` according to the intrinsic timescale + spatial autocorrelation model
    """
    assert num_timepoints % 2 == 0, "Must be even timeseries length"
    node_spectra = []
    for ta_delta1 in ta_delta1s:
        # Pink-noise exponent that yields this TA-Δ₁ without noise.
        alpha = ta_to_alpha_fast(sample_rate=sample_rate, tslen=num_timepoints,
                                 highpass_freq=highpass_freq, target_ta=max(0, ta_delta1))
        node_spectra.append(make_spectrum(tslen=num_timepoints, sample_rate=sample_rate,
                                          alpha=alpha, highpass_freq=highpass_freq))
    spectra = np.asarray(node_spectra)
    target_corr = spatial_exponential_floor(distance_matrix, sa_lambda, sa_inf)
    return correlated_spectral_sampling(cm=target_corr, spectra=spectra, seed=seed)


def correlated_spectral_sampling(cm, spectra, seed=None):
    """Generate timeseries with given amplitude spectra and correlation matrices

    This implements Correlated Spectral Sampling, as described in [Shinn et al
    (2023)](https://www.nature.com/articles/s41593-023-01299-3).  Real and
    imaginary Fourier coefficients are drawn from a multivariate normal whose
    covariance is `cm` divided by the cosine similarity of the spectra.  They
    are then scaled by the spectra and inverse-FFT'd.

    Args:
      cm (NxN numpy array): The correlation matrix
      spectra (Nxk numpy array): A list of Fourier spectra generated by [make_spectrum][spatiotemporal.models.make_spectrum]
          Each of the N spectra are associated with a row/column of `cm`.
      seed (int, optional): the random seed.  If None, fresh entropy is used.

    Returns:
      NxT numpy array: N timeseries of length T = 2*(k-1).  Timeseries i will have a power spectrum given by `spectra[i]`,
          and will be correlated with the other timeseries with correlations cm[i].

    Raises:
      PositiveSemidefiniteError: if the implied covariance matrix has an
          eigenvalue below -1e-8.
    """
    n_regions = cm.shape[0]
    n_freqs = len(spectra[0])
    assert spectra.shape == (n_regions, n_freqs)
    # k one-sided frequency bins correspond to an even-length signal.
    n_timepoints = 2 * (n_freqs - 1)
    squared_norms = np.sum(spectra ** 2, axis=1)
    spectral_cosine = (spectra @ spectra.T) / np.sqrt(np.outer(squared_norms, squared_norms))
    covmat = cm / spectral_cosine
    if np.linalg.eigvalsh(covmat).min() < -1e-8:
        raise PositiveSemidefiniteError("Correlation matrix is not possible with those spectra using this method!")
    draws = np.random.RandomState(seed).multivariate_normal(np.zeros(n_regions), cov=covmat, size=2 * n_freqs)
    real_part = draws[:n_freqs].T * spectra
    imag_part = draws[n_freqs:].T * spectra
    # The DC coefficient must be zero (and therefore real).
    real_part[:, 0] = 0
    imag_part[:, 0] = 0
    # For an even-length signal the Nyquist coefficient must be real.
    imag_part[:, -1] = 0
    return np.fft.irfft(real_part + 1j * imag_part, n=n_timepoints, axis=-1)

def make_spectrum(tslen, sample_rate, alpha, highpass_freq):
    """Create a 1/f^alpha spectrum.

    The amplitude in each bin is f**(-alpha/2), so the power falls off as
    1/f**alpha.  The DC bin is always 0.

    Args:
      tslen (int): the length of the timeseries represented by the spectrum
      sample_rate (float): the spacing between timepoints (e.g. TR in fMRI)
      alpha (float): the pink noise exponent, between 0 and 2.
      highpass_freq (float): if non-zero, apply a highpass filter above the
          given frequency.  A good default is 0.01 for fMRI timeseries.

    Returns:
      length-k numpy array: Return the fourier spectrum (amplitude spectrum)
    """
    freqs = np.fft.rfftfreq(tslen, sample_rate)
    # Evaluate the power law only at positive frequencies.  At f=0 it is
    # infinite (a divide-by-zero warning, or a FloatingPointError under
    # np.seterr(all='raise')), and inf * H(0) = inf * 0 would then be nan once
    # the highpass response is applied.  The DC bin is zero anyway.
    spectrum = np.zeros_like(freqs)
    spectrum[1:] = freqs[1:] ** (-alpha / 2)
    if highpass_freq > 0:
        # 4th-order Butterworth highpass, evaluated on the same frequency grid.
        butter = scipy.signal.iirfilter(ftype='butter', N=4, Wn=highpass_freq, btype='highpass', fs=1/sample_rate, output='ba')
        butterresp = scipy.signal.freqz(*butter, fs=1/sample_rate, worN=len(freqs), include_nyquist=True)
        assert np.all(np.isclose(freqs, butterresp[0]))
        spectrum = spectrum * np.abs(butterresp[1])
    spectrum[0] = 0
    return spectrum

def how_much_noise(spectrum, target_ta):
    """Determine the standard deviation of noise to add to achieve a target TA-Δ₁.

    This function answers the following question: If I generate timeseries with
    frequency spectrum (amplitude spectrum) `spectrum`, and then add white
    noise to the generated timeseries, what should the standard deviation of
    this white noise be if I want the timeseries to have the TA-Δ₁ coefficient
    `target_ta`?

    Args:
      spectrum (length-k numpy array): the amplitude spectrum to generate from, e.g.,
          that generated from [make_spectrum][spatiotemporal.models.make_spectrum].
      target_ta (float): the desired TA-Δ₁.

    Returns:
      float: The standard deviation of white noise to add.  0 if the formula has
          no real solution (the noiseless spectrum's TA-Δ₁ is already at or
          below the target) or if `target_ta` is 0.
    """

    N = len(spectrum)
    weightedsum = np.sum(spectrum[1:]**2*np.cos(np.pi*np.arange(1, N)/N))
    total_power = np.sum(spectrum**2)
    # Under NumPy's default error state, a negative radicand gives nan and a
    # zero target gives inf, each with only a warning.  The fallback below
    # would then never run.  Raise locally so the result does not depend on
    # the caller's global np.seterr configuration.
    try:
        with np.errstate(divide='raise', invalid='raise'):
            sigma = np.sqrt((weightedsum - total_power*target_ta)/(target_ta*N**2))
    except FloatingPointError:
        sigma = 0
    return sigma

def ta_to_alpha(tslen, sample_rate, highpass_freq, target_ta):
    """Compute the (pink noise) alpha which would give, noiseless, the given TA-Δ₁.

    The spectrum for a candidate alpha is built with
    [make_spectrum][spatiotemporal.models.make_spectrum] (so it is highpass
    filtered when `highpass_freq` > 0).  Its expected TA-Δ₁ comes from
    [get_spectrum_ta][spatiotemporal.models.get_spectrum_ta].  Alpha is found
    by bounded minimization of the squared error over [0, 2], starting at 1.5.

    Args:
      tslen (int): the length of the timeseries represented by the spectrum
      sample_rate (float): the spacing between timepoints (e.g. TR in fMRI)
      highpass_freq (float): if non-zero, apply a highpass filter above the
          given frequency.  A good default is 0.01 for fMRI timeseries.
      target_ta (float): the desired TA-Δ₁.

    Returns:
      float: a value of alpha such that the filtered pink noise with
          this exponent has TA-Δ₁ coefficient `target_ta`.
    """
    def squared_error(params):
        candidate = make_spectrum(tslen, sample_rate, params[0], highpass_freq)
        return (get_spectrum_ta(candidate) - target_ta) ** 2

    fit = scipy.optimize.minimize(squared_error, 1.5, bounds=[(0, 2)])
    return float(fit.x[0])

def get_spectrum_ta(spectrum):
    """Given a fourier spectrum, return the expected TA-Δ₁ of a timeseries generated with that spectrum.

    The value is the power-weighted mean of cos(pi*k/N) over bins k = 0..N-1,
    where N = len(spectrum).

    NOTE(review): for N one-sided bins from an even-length rfft, bin k has a
    lag-1 phase of pi*k/(N-1).  Using pi*k/N is a close approximation for large
    N; confirm this is intentional.

    Args:
      spectrum (length-k numpy array): the amplitude spectrum to generate from, e.g.,
          that generated from [make_spectrum][spatiotemporal.models.make_spectrum].

    Returns:
      float: the TA-Δ₁ value that would be expected if a timeseries had the given power spectrum and random phases.

    """
    n_bins = len(spectrum)
    power = spectrum ** 2
    lag1_weights = np.cos(np.pi * np.arange(n_bins) / n_bins)
    return np.sum(power * lag1_weights) / np.sum(power)

# Memo table for ta_to_alpha_fast, keyed by
# (tslen, sample_rate, highpass_freq, target_ta rounded to 2 decimals).
ta_to_alpha_cache = {}
def ta_to_alpha_fast(tslen, sample_rate, highpass_freq, target_ta):
    """Identical to `ta_to_alpha`, but discretize and cache to increase speed.

    `target_ta` is rounded to two decimal places.  The alpha returned is the
    one computed for the rounded value, and it is reused for any later call
    with the same arguments.

    See [ta_to_alpha][spatiotemporal.models.ta_to_alpha] for documentation.
    """
    cache_key = (tslen, sample_rate, highpass_freq, round(target_ta, 2))
    if cache_key not in ta_to_alpha_cache:
        ta_to_alpha_cache[cache_key] = ta_to_alpha(*cache_key)
    return ta_to_alpha_cache[cache_key]

def make_noisy_spectrum(tslen, sample_rate, alpha, highpass_freq, target_ar1):
    """Like make_spectrum, but with a flat white-noise floor added.

    The noise level comes from how_much_noise for the target TA-Δ₁
    (`target_ar1`).  The same highpass filter as make_spectrum is applied to
    the underlying 1/f^alpha spectrum.

    Returns:
      length-k numpy array: the fourier spectrum (amplitude spectrum), with the
          DC bin set to 0.
    """
    clean = make_spectrum(tslen, sample_rate, alpha, highpass_freq)
    sigma = how_much_noise(clean, target_ar1)
    n_bins = len(clean)
    # Add a power of sigma**2 * n_bins to every bin, then convert back to amplitude.
    noisy = np.sqrt(clean ** 2 + sigma ** 2 * n_bins)
    noisy[0] = 0
    return noisy

class PositiveSemidefiniteError(Exception):
    """Raised when a requested correlation structure cannot be realized.

    correlated_spectral_sampling raises it when the covariance implied by the
    correlation matrix and the spectra is not positive semidefinite.
    """


# Part of the spatiotemporal package for python
# Copyright 2022 Max Shinn <m.shinn@ucl.ac.uk>
# Available under the MIT license
import numpy as np
import pandas
import scipy.spatial
import scipy.optimize

def spatial_autocorrelation(cm, dist, discretization=1):
    """Calculate the SA-λ and SA-∞ measures of spatial autocorrelation, defined in [Shinn et al (2023)](https://www.nature.com/articles/s41593-023-01299-3)

    Correlations are averaged within distance bins of width `discretization`.
    The curve ``exp(-d / SA-λ) * (1 - SA-∞) + SA-∞`` is then fit to the bin
    means by bounded least squares, with SA-λ in [0.1, 100] and SA-∞ in
    [-1, 1], starting from (10, 0.3).

    Args:
      cm (NxN numpy array): NxN correlation matrix of timeseries, where N is the number of
          timeseries
      dist (NxN numpy array): the NxN distance matrix, representing the spatial distance
          between location of each of the timeseries.  This should usually be the
          output of the `distance_matrix_euclidean` function.
      discretization (int): The size of the bins to use when computing the SA parameters.
          The bins should be large enough that each contains a sufficient number of
          observations, but small enough that there are enough bins for a meaningful estimate.
          Data with distances up to around 100 should be fine with the default; otherwise
          increase or decrease it according to the scale of your data.

    Returns:
      tuple of floats: tuple of (SA-λ, SA-∞)
    """
    pairs = pandas.DataFrame(np.asarray([dist.flatten(), cm.flatten()]).T, columns=["dist", "corr"])
    # Snap every distance to the nearest multiple of `discretization`.
    pairs['dist_bin'] = np.round(pairs['dist'] / discretization) * discretization
    binned = pairs.groupby('dist_bin').mean().reset_index().sort_values('dist_bin')
    bin_dist = binned['dist_bin']
    bin_corr = binned['corr']
    # The bin with index label 0 (the smallest distance) is given distance 1.
    # NOTE(review): presumably this de-emphasizes the self-pair bin at d=0; confirm.
    bin_dist[0] = 1

    def model(params):
        sa_lambda, sa_inf = params
        return np.exp(-bin_dist / sa_lambda) * (1 - sa_inf) + sa_inf

    def loss(params):
        return np.sum((bin_corr - model(params)) ** 2)

    with np.errstate(all='warn'):
        fit = scipy.optimize.minimize(loss, [10, .3], bounds=[(.1, 100), (-1, 1)])
    return (fit.x[0], fit.x[1])

def temporal_autocorrelation(x):
    """Compute TA-Δ₁ (lag-1 temporal autocorrelation) from the timeseries.

    TA-Δ₁ is the Pearson correlation (via np.corrcoef) between ``x[:-1]`` and
    ``x[1:]``.  This function exists so that TA-Δ₁ is computed the same way
    everywhere.

    Args:
      x: the timeseries of which to compute TA-Δ₁.  Either a single list or
          one-dimensional numpy array, or a nested list / NxT numpy array with
          one timeseries per row.

    Returns:
      float or numpy array: a float if `x` is a single timeseries.  Otherwise a
          numpy array with the TA-Δ₁ of each row of `x` (nested further for
          inputs with more than two dimensions).
    """
    if not isinstance(x[0], (list, np.ndarray)):
        return np.corrcoef(x[:-1], x[1:])[0, 1]
    return np.asarray([temporal_autocorrelation(row) for row in x])

def long_memory(x, minscale, multivariate=False, maxscale=11):
    """Estimate the long memory coefficient, from [Achard and Gannaz (2016)](https://doi.org/10.1111/jtsa.12170)

    See [Achard and Gannaz (2016)](https://doi.org/10.1111/jtsa.12170) for
    details of the coefficient and its estimator.

    Args:
      x (NxT numpy array): the matrix of timeseries (rows are regions, columns are timepoints)
      minscale (int): the minimum wavelet scale used to perform the estimation.
          As a rule of thumb, if data are low pass filtered, minscale should be the
          multiple of nyquist corresponding to the filter frequency, e.g. 2 if
          filtering is performed at half nyquist.
      multivariate: the type of estimation to perform, described in [Achard and Gannaz (2016)](https://doi.org/10.1111/jtsa.12170).
          Note that multivariate=True is extremely slow for any
          reasonably sized correlation matrix.
      maxscale (int, optional): the maximum wavelet scale used to perform the
          estimation.  Defaults to 11, the value previously hard-coded.

    Returns:
      list: in the univariate case, one long memory coefficient per row of `x`.
          In the multivariate case, the elements of the `d` component returned
          by multiwave's `mww`.

    Raises:
      ImportError: if rpy2 is not installed.

    Warning:
        This function is rather fragile: it requires rpy2 to be installed, as well
        as R, with the multiwave package.  It works on my computer but your results
        may vary.  If it doesn't work for you, export your data and use the
        "multiwave" package directly in R.

    """
    try:
        import rpy2.robjects.packages
        import rpy2.robjects.numpy2ri
    except ImportError as e:
        # Previously this only printed and then crashed with a NameError below.
        raise ImportError("Rpy2 not available, long memory coefficient estimation not available.  "
                          "Try using the 'multiwave' package in R instead.") from e
    x = x.transpose()
    scales = np.asarray([minscale, maxscale])
    rpy2.robjects.numpy2ri.activate()
    # Always undo the global numpy<->R conversion, even if R raises.
    try:
        try:
            multiwave = rpy2.robjects.packages.importr('multiwave')
        except Exception:
            print("Please install the multiwave package in R")
            raise
        filt = multiwave.scaling_filter("Daubechies", 8).rx2('h')
        if multivariate:
            res = list(multiwave.mww(x, filt, scales).rx2('d'))
        else:
            # After the transpose, column i is region i's timeseries.
            res = [multiwave.mww(x[:,i], filt, scales).rx2('d')[0] for i in range(0, x.shape[1])]
    finally:
        rpy2.robjects.numpy2ri.deactivate()
    return res



# Part of the spatiotemporal package for python
# Copyright 2022 Max Shinn <m.shinn@ucl.ac.uk>
# Available under the MIT license
import numpy as np
from .tools import get_eigenvalues, make_perfectly_symmetric
import scipy.stats

def eigensurrogate_matrix(cm, seed=None):
    """Eigensurrogate model, from [Shinn et al (2023)](https://www.nature.com/articles/s41593-023-01299-3)

    Determine the eigenvalues of the correlation matrix, and then sample a new
    correlation matrix with the same eigenvalues.

    Args:
      cm (NxN numpy array): a correlation matrix
      seed (int, optional): the random seed.  If None, fresh entropy is used.

    Returns:
      NxN numpy array: a correlation matrix with the same eigenvalues as `cm`
    """
    desired_evs = get_eigenvalues(cm)
    rng = np.random.RandomState(seed)
    if min(desired_evs) < 0:
        # Clip (numerically) negative eigenvalues to zero.  Then remove the
        # resulting excess so the eigenvalues sum to N again, as required for an
        # NxN correlation matrix.  The excess is taken from the largest
        # eigenvalue, so this does not depend on the order returned by
        # get_eigenvalues and cannot push a clipped eigenvalue below zero.
        desired_evs = np.maximum(0, desired_evs)
        excess = np.sum(desired_evs) - len(desired_evs)
        desired_evs[np.argmax(desired_evs)] -= excess
        print(f"Warning: eigenvalues were less than zero in source matrix by {excess}")
    m = scipy.stats.random_correlation.rvs(eigs=desired_evs, tol=1e-12, random_state=rng)
    # Force an exact unit diagonal and exact symmetry after sampling.
    np.fill_diagonal(m, 1)
    return make_perfectly_symmetric(m)

def eigensurrogate_timeseries(cm, N_timepoints, seed=None):
    """Timeseries from the eigensurrogate model.

    Sample timeseries which have the correlation matrix given by the
    eigensurrogate model.  Note that there are many ways to sample timeseries
    from the eigensurrogate model, but this is the simplest (a multivariate
    normal distribution).

    Args:
      cm (NxN numpy array): a correlation matrix
      N_timepoints (int): the length of the timeseries to sample
      seed (int, optional): the random seed.  It controls both the sampled
          surrogate correlation matrix and the sampled timeseries.  If not
          specified, fresh OS entropy is used, so the output is not
          reproducible.

    Returns:
      NxN_timepoints numpy array: timeseries generated from the eigensurrogate model

    """
    rng = np.random.RandomState(seed)
    # Derive the matrix seed from `rng`.  A single `seed` then makes the whole
    # output reproducible, while the surrogate matrix and the noise below are
    # drawn from different random streams.
    surrogate = eigensurrogate_matrix(cm, seed=rng.randint(2**31 - 1))
    N_regions = cm.shape[0]
    # sqrtm may return a complex array with negligible imaginary parts when the
    # surrogate has (numerically) zero eigenvalues; the timeseries are real.
    msqrt = np.real(scipy.linalg.sqrtm(surrogate))
    # For a symmetric PSD matrix, sqrtm(S) @ sqrtm(S).T == S, so white noise
    # mixed by msqrt has covariance equal to the surrogate matrix.
    return msqrt @ rng.randn(N_regions, N_timepoints)

def phase_randomize(tss, seed=None):
    """Phase-randomized surrogate timeseries.

    Scramble a set of timeseries independently by preserving the amplitudes in
    Fourier space but randomly sampling new phases from the uniform
    distribution [0, 2π].

    The zero-frequency (mean) bin, and the Nyquist bin when T is even, keep
    their original phase.  These bins are real-valued in the spectrum of a real
    signal, and the inverse real FFT discards their imaginary part.  Rotating
    them would therefore change their amplitude (and each timeseries' mean)
    instead of only their phase.

    Args:
      tss (NxT numpy array): should be a NxT matrix, where N is the number of timeseries and T is
          the number of samples in the timeseries.
      seed (int, optional): the random seed.  If not specified, the generator
          is seeded from fresh OS entropy, so the result is not reproducible.

    Returns:
        NxT numpy array: surrogate timeseries of the same shape as `tss`
    """
    tss = np.asarray(tss)
    (N, n_time) = tss.shape
    surrogates = np.fft.rfft(tss, axis=1)
    len_phase = surrogates.shape[1]
    # Generate random phases uniformly distributed in the interval [0, 2*Pi]
    phases = np.random.RandomState(seed).uniform(low=0, high=2 * np.pi, size=(N, len_phase))
    # Keep the real-valued bins (DC, and Nyquist for even lengths) unrotated so
    # that every amplitude, including the mean, is preserved exactly.
    phases[:, 0] = 0
    if n_time % 2 == 0:
        phases[:, -1] = 0
    surrogates *= np.exp(1j * phases)
    # irfft of a one-sided spectrum already returns a real array.
    return np.fft.irfft(surrogates, n=n_time, axis=1)

def zalesky_surrogate(cm, seed=None):
    """Zalesky matching surrogate, from [Zalesky et al (2012)](https://doi.org/10.1016/j.neuroimage.2012.02.001)

    Generate matrices whose off-diagonal mean and variance (mean-FC and
    var-FC) approximately match those of `cm`.  Adapted from code taken from
    [Zalesky et al (2012)](https://doi.org/10.1016/j.neuroimage.2012.02.001)

    Two nested bisections are used.  The outer one searches over the number
    of timepoints `n`, which controls the variance of the off-diagonal
    correlations.  The inner one searches over the weight of a shared signal,
    which controls their mean.

    Args:
      cm (NxN numpy array): a correlation matrix
      seed (int, optional): the random seed.  If not specified, the generator
          is seeded from fresh OS entropy, so the result is not reproducible.

    Returns:
      NxN numpy array: a correlation matrix with the same mean and variance as `cm`

    """
    N_regions = cm.shape[0]
    tri = np.triu_indices(N_regions, 1)
    rng = np.random.RandomState(seed)
    desired_mean = np.mean(cm[tri])
    desired_var = np.var(cm[tri])

    def fitmean(mu, n):
        """Sample a correlation matrix from `n` timepoints with mean off-diagonal correlation close to `mu`.

        Each region is independent noise plus `a` times one shared signal.
        `a` is found by bisection on [0, 10] to a tolerance of 0.001.
        """
        x = rng.randn(N_regions, n) # Each ROW is a different region.  This is inconsistent with the paper.
        y = rng.randn(n, 1)
        amax = 10
        amin = 0
        while np.abs(amax-amin) > .001:
            a = amin + (amax-amin)/2
            # y.T (1 x n) broadcasts the shared signal onto every region's row.
            rho = np.corrcoef(x + a*y.T)
            assert rho.shape[0] == N_regions, "Bad shape"
            muhat = np.mean(rho[tri])
            if muhat > mu:
                amax = a
            else:
                amin = a
        return rho

    nmax = 1000
    nmin = 2
    while nmax - nmin > 1:
        n = nmin + (nmax - nmin) // 2
        rho = fitmean(desired_mean, n)
        sigma2hat = np.var(rho[tri])
        # Sample correlations get less variable as n grows.  If the variance is
        # still too large, more timepoints are needed.
        if sigma2hat > desired_var:
            nmin = n
        else:
            nmax = n
    # The returned matrix comes from the final bisection step.
    return make_perfectly_symmetric(rho)


# Part of the spatiotemporal package for python
# Copyright 2022 Max Shinn <m.shinn@ucl.ac.uk>
# Available under the MIT license
import numpy as np
import scipy.linalg
import scipy.spatial

def get_eigenvalues(cm):
    """Return the eigenvalues of the symmetric correlation matrix `cm`.

    Only the lower triangle of `cm` is read, as it is assumed to be symmetric.
    The eigenvalues are real and are returned in ascending order.  A
    correlation matrix is positive semidefinite, so in exact arithmetic they
    are non-negative.  Numerically, however, the smallest ones can come out
    slightly below zero, so callers should not rely on non-negativity.
    """
    eigenvalues = scipy.linalg.eigh(cm, eigvals_only=True)
    return eigenvalues

def make_perfectly_symmetric(cm):
    """Remove numerical asymmetry from a correlation matrix.

    Each pair of mirrored entries is replaced by the larger of the two (an
    elementwise maximum of `cm` and its transpose).  The result is therefore
    exactly symmetric.
    """
    mirrored = cm.T
    return np.maximum(mirrored, cm)

def distance_matrix_euclidean(distances):
    """Build the pairwise Euclidean distance matrix between N points.

    Despite its name, `distances` holds coordinates: an Nx3 numpy array giving
    the xyz position of each of N brain regions.  The result is an NxN array
    whose (i, j) entry is the straight-line distance between regions i and j.
    """
    return scipy.spatial.distance.cdist(XA=distances, XB=distances, metric='euclidean')

def spatial_exponential_floor(distances, sa_lmbda, sa_inf):
    """Find a hypothetical spatial correlation matrix.

    Correlation decays exponentially with distance at length scale
    `sa_lmbda`, levelling off at the floor `sa_inf` instead of at zero.  A
    distance of zero gives a correlation of 1.
    """
    decay = np.exp(-distances/sa_lmbda)
    return sa_inf + (1-sa_inf)*decay


# Package version string; setup.py reads it by exec-ing this file.
__version__ = "1.0.1"


# Part of the spatiotemporal package for python
# Copyright 2022 Max Shinn <m.shinn@ucl.ac.uk>
# Available under the MIT license

from .stats import spatial_autocorrelation,temporal_autocorrelation,long_memory
from .models import spatiotemporal_model_timeseries, intrinsic_timescale_sa_model_timeseries
from .surrogates import eigensurrogate_matrix, eigensurrogate_timeseries, phase_randomize, zalesky_surrogate
from .extras import fingerprint, lin, cosine, pearson, spearman
