from __future__ import print_function, division
import numpy as np
import sys
# insert at 1, 0 is the script path (or '' in REPL)
sys.path.insert(1, './brain_data')
import torch
from torch.distributions.multivariate_normal import MultivariateNormal
from utils.util_funcs import matrix_polynomial, correlation_from_covariance, cov, corr


# Inputs: N       :: int number of vertices, num_signals :: int # signals to generate
# Output: signals :: 2D numpy array of signals
def gen_white_signals(N, num_signals):
    """Draw white Gaussian signals (zero mean, identity covariance) with NumPy.

    Args:
        N: number of vertices, i.e. the length of each signal.
        num_signals: number of signals to draw.

    Returns:
        2D numpy array of shape (N, num_signals); signal ``i`` is the column
        ``signals[:, i]``.
    """
    zero_mean = np.zeros(N)
    identity_cov = np.eye(N)

    # The sampler returns one signal per row; transpose so each is a column.
    samples_by_row = np.random.multivariate_normal(zero_mean, identity_cov, num_signals)
    return samples_by_row.T


def sample_white_signals(square_matrix_size, num_matrices, signals_per_matrix, dtype=torch.float32):
    """Sample batches of white Gaussian signals with torch.

    Args:
        square_matrix_size: length of each sampled vector.
        num_matrices: number of batches (one per matrix).
        signals_per_matrix: number of vectors sampled per batch.
        dtype: dtype of the mean/covariance used to build the sampler.

    Returns:
        Tensor of shape num_matrices x square_matrix_size x signals_per_matrix
        (e.g. 1064 x 68 x 500) whose columns are the sampled vectors.
    """
    standard_normal = MultivariateNormal(
        torch.zeros(square_matrix_size, dtype=dtype),
        covariance_matrix=torch.eye(square_matrix_size, dtype=dtype),
    )

    # Raw draw is num_matrices x signals_per_matrix x square_matrix_size
    # (e.g. 1064 x 500 x 68): every row is one sampled vector.
    row_samples = standard_normal.sample((num_matrices, signals_per_matrix))

    # Swap the last two axes so every column is one sampled vector.
    return row_samples.transpose(1, 2)


def compute_analytic_summary_stats(S, coeffs, filters=None):
    """Return the analytic covariance and correlation of diffused white signals.

    Args:
        S: batch of square matrices; must be 3-dimensional.
        coeffs: coefficients handed to ``matrix_polynomial`` when ``filters``
            is not supplied.
        filters: optional precomputed graph filters; when given,
            ``matrix_polynomial`` is not called.

    Returns:
        Tuple ``(cov_analytic, corr_analytic)``.

    Raises:
        AssertionError: if ``S`` is not 3-dimensional.
    """
    assert len(S.shape) == 3, 'must have batch, not single matrix'

    graph_filters = matrix_polynomial(S, coeffs) if filters is None else filters

    # Cov = H^2 when diffusing white signals over the graph with graph filter H.
    # NOTE(review): H @ H equals H @ H^T only for symmetric filters - confirm
    # that S (and hence H) is always symmetric.
    cov_analytic = torch.bmm(graph_filters, graph_filters)
    return cov_analytic, correlation_from_covariance(cov_analytic)


def compute_diffusion_summary_stats(S, coeffs, num_signals, dtype=torch.float32):
    """Compute analytic and sampled covariance/correlation for a batch of graphs.

    White signals are sampled, diffused through the graph filters built by
    ``matrix_polynomial(S, coeffs)``, and summarised with ``cov``/``corr``.

    Args:
        S: batch of square matrices; must be 3-dimensional.
        coeffs: coefficients handed to ``matrix_polynomial``.
        num_signals: number of white signals sampled per matrix.
        dtype: dtype used for sampling and for all returned tensors.

    Returns:
        ``((cov_analytic, corr_analytic), (covariance, correlation))``, every
        tensor cast to ``dtype``.

    Raises:
        AssertionError: if ``S`` is not 3-dimensional.
    """
    assert len(S.shape) == 3, 'must have batch, not single matrix'
    batch_size, num_vertices = S.shape[:2]

    # Build the filters once and reuse them for both the analytic and the
    # sampled statistics.
    graph_filters = matrix_polynomial(S, coeffs)
    analytic_cov, analytic_corr = compute_analytic_summary_stats(S, coeffs, filters=graph_filters)

    noise = sample_white_signals(
        square_matrix_size=num_vertices,
        num_matrices=batch_size,
        signals_per_matrix=num_signals,
        dtype=dtype,
    )
    diffused = torch.bmm(graph_filters, noise)
    sample_cov = cov(diffused)
    sample_corr = corr(diffused)

    analytic_stats = (analytic_cov.to(dtype), analytic_corr.to(dtype))
    sample_stats = (sample_cov.to(dtype), sample_corr.to(dtype))
    return analytic_stats, sample_stats

def compute_diffusion_summary_stats_individual(S, coeffs, num_signals, sum_stat='sample_cov', dtype=torch.float32):
    """Compute one selected summary statistic of signals diffused over graphs.

    Args:
        S: a single square matrix (2-D) or a batch of them (3-D).
        coeffs: coefficients handed to ``matrix_polynomial``.
        num_signals: number of white signals sampled per matrix; only used
            for the ``'sample_*'`` statistics.
        sum_stat: one of ``'sample_cov'``, ``'sample_corr'``,
            ``'analytic_cov'``, ``'analytic_corr'``.
        dtype: dtype used when sampling the white signals.

    Returns:
        The requested statistic, computed over the batch.

    Raises:
        AssertionError: if ``sum_stat`` is not one of the supported names.
    """
    assert sum_stat in ['sample_cov', 'sample_corr', 'analytic_cov', 'analytic_corr']
    # A single matrix is promoted to a batch of one so torch.bmm can be used.
    if S.ndim == 2:
        S = S.unsqueeze(dim=0)
    num_matrices, N = S.shape[0], S.shape[1]

    filters = matrix_polynomial(S, coeffs)
    # Analytic statistics need no sampling: Cov = H^2 for white input signals.
    if sum_stat == 'analytic_cov':
        return torch.bmm(filters, filters)
    elif sum_stat == 'analytic_corr':
        # This branch used to re-test 'analytic_cov', which made it
        # unreachable and caused 'analytic_corr' requests to return None.
        return correlation_from_covariance(torch.bmm(filters, filters))

    white_signals = sample_white_signals(square_matrix_size=N, num_matrices=num_matrices,
                                         signals_per_matrix=num_signals, dtype=dtype)
    diffused_signals = torch.bmm(filters, white_signals)

    if sum_stat == 'sample_cov':
        return cov(diffused_signals)
    elif sum_stat == 'sample_corr':
        return corr(diffused_signals)

if __name__ == "__main__":
    # Self-check: compare compute_diffusion_summary_stats against a manual
    # diffusion of independently generated white signals.

    N, num_signals = 3, 5000000
    # gen_white_signals returns a float64 NumPy array; convert it once to a
    # float32 tensor so all later matmuls are plain torch ops instead of
    # mixed tensor/ndarray arithmetic.
    white_signals = torch.from_numpy(gen_white_signals(N, num_signals)).to(torch.float32)

    # test compute_diffusion_summary_stats
    coeffs = torch.tensor([1, .5, 1])
    As = torch.zeros((2, 3, 3))
    As[0] = torch.tensor([[0, 1, 0], [1, 0, .5], [0, .5, 0]])
    As[1] = torch.tensor([[0, 0, .5], [0, 0, 1], [.5, 1, 0]])
    (cov_analytic, corr_analytic), (covariance, correlation) = compute_diffusion_summary_stats(As, coeffs, num_signals=num_signals)

    # compare to np cov when manually diffusing signals
    identity = torch.eye(N)
    test_analytic_cov = torch.zeros(len(As), N, N)
    diff_cov = torch.zeros(len(As), N, N)
    for i, A in enumerate(As):
        # Second-order polynomial graph filter, written out by hand.
        graph_filter = coeffs[0]*identity + coeffs[1]*A + coeffs[2]*(A @ A)
        test_analytic_cov[i] = graph_filter @ graph_filter
        diffused = graph_filter @ white_signals
        # np.cov yields float64; cast to the float32 storage explicitly.
        diff_cov[i] = torch.from_numpy(np.cov(diffused.numpy())).to(diff_cov.dtype)
    assert torch.allclose(test_analytic_cov, cov_analytic)
    assert torch.allclose(diff_cov, covariance, atol=.01)

    analytic_close_cov_max = (cov_analytic - covariance).abs().max()
    analytic_close_corr_max = (corr_analytic - correlation).abs().max()
    print(f'num_signals: {num_signals}')
    print(f'max diff between analytical and sample cov: {analytic_close_cov_max}')
    print(f'max diff between analytical and sample corr: {analytic_close_corr_max}')
