#!/opt/conda/bin/python3
import os
import random

import matplotlib.pyplot as plt
import numpy as np

# Code modified from https://github.com/delta2323/gnn-asymptotics/tree/master/gnn_dynamics

# Seed both RNGs so the sampled graph and weights (and hence the figures) are reproducible.
random.seed(0)
np.random.seed(0)

# Plot styling: start from the ggplot style, then override sizes and axis colours.
# The style must be applied before the overrides, otherwise it would reset them.
plt.style.use('ggplot')
plt.rcParams.update({
    "figure.figsize": (16, 9),
    "font.size": 30,
    "xtick.color": 'black',
    "ytick.color": 'black',
    "axes.edgecolor": 'black',
    "axes.linewidth": 1,
})

#----------------------------------------------------------------------------------------------------------------------------------------------------
# Helper
#----------------------------------------------------------------------------------------------------------------------------------------------------

def make_qp(lambda_):
    """Build a random symmetric matrix with the given eigenvalues.

    Args:
        lambda_: 1-D sequence of N eigenvalues.

    Returns:
        (Q, P): Q is an N x N orthogonal matrix from `_sample_orthogonal_matrix`,
        sign-normalised so that Q[0, 0] >= 0. P = Q @ diag(lambda_) @ Q.T, so
        column i of Q is an eigenvector of P with eigenvalue lambda_[i].
    """
    N = len(lambda_)
    Q = _sample_orthogonal_matrix(N)
    # Flip the global sign so the first entry of q_1 (= Q[:, 0]) is non-negative.
    # This does NOT make every entry of q_1 positive; that only happens to hold
    # for the seed used in this script.
    # An explicit flip is used instead of multiplying by np.sign(Q[0, 0]):
    # np.sign(0) == 0 would wipe out the whole matrix.
    if Q[0, 0] < 0:
        Q = -Q
    P = Q @ np.diag(lambda_) @ Q.T
    return Q, P


def make_w(s):
    """Build a random C x C matrix W = U @ diag(s) @ V from the values in `s`.

    U and V are independent random orthogonal matrices (U is drawn first). The
    entries of `s` play the role of singular values; negative entries are
    accepted as-is (their sign is absorbed into W).
    """
    size = len(s)
    left = _sample_orthogonal_matrix(size)
    right = _sample_orthogonal_matrix(size)
    return left @ np.diag(s) @ right


def _sample_orthogonal_matrix(N):
    """Return an N x N orthogonal matrix: the Q factor of a random N x N matrix.

    The random matrix has entries drawn from U(-1, 1) with NumPy's global RNG, so
    the result is reproducible under `np.random.seed`.
    NOTE(review): QR of uniform entries without sign correction is not
    Haar-uniform over the orthogonal group; that is fine for illustration here.
    """
    random_matrix = np.random.uniform(low=-1, high=1, size=(N, N))
    orthogonal, _upper = np.linalg.qr(random_matrix, mode='complete')
    return orthogonal

#----------------------------------------------------------------------------------------------------------------------------------------------------

def make_sample_points(L, T, N, C):
    """Return a regular grid of sample states with shape (T ** (N * C), N, C).

    Each of the N * C coordinates takes T evenly spaced values in [-L, L], and
    the grid is their Cartesian product built with `np.meshgrid` (default 'xy'
    indexing). For two coordinates this means the first one varies fastest.
    """
    axis = np.linspace(-L, L, T)
    grids = np.meshgrid(*([axis] * (N * C)))
    flat = np.column_stack([grid.ravel() for grid in grids])
    return flat.reshape(-1, N, C)

def make_dynamics(P, W, b):
    """Return the one-step map f(x) = ReLU(P x W + b).

    `x` must be a 3-D array laid out as (batch, N, C): P (N x N) mixes the node
    axis (axis 1) and W multiplies the last axis from the right. `b` is added
    with NumPy broadcasting against the (batch, N, C_out) result. A 1-D `b`
    therefore lines up with the last (channel) axis.
    """
    def step(x):
        # Contract P's column axis with x's node axis -> (N, batch, C),
        # then move the batch axis back to the front -> (batch, N, C).
        aggregated = np.tensordot(P, x, axes=(1, 1)).transpose(1, 0, 2)
        pre_activation = aggregated @ W + b
        return np.where(pre_activation > 0, pre_activation, 0)
    return step

def make_leaky_dynamics(P, W, b, *, negative_slope=0.001):
    """Return the one-step map f(x) = LeakyReLU(P x W + b).

    Same propagation as the ReLU variant: `x` must be a 3-D array laid out as
    (batch, N, C), P (N x N) mixes the node axis and W multiplies the channel
    axis from the right. `b` is added with NumPy broadcasting.

    Args:
        P: node-mixing matrix.
        W: channel weight matrix.
        b: bias, broadcast against the (batch, N, C_out) pre-activation.
        negative_slope: multiplier applied to non-positive pre-activations
            (default 0.001, the previously hard-coded value).
    """
    def f(x):
        # (N, batch, C) -> (batch, N, C)
        y = np.tensordot(P, x, axes=(1, 1)).transpose((1, 0, 2))
        x_ = np.matmul(y, W) + b
        return np.where(x_ > 0, x_, negative_slope * x_)
    return f

#----------------------------------------------------------------------------------------------------------------------------------------------------


def _line(L, S, a):
    """Sample the line y = a * x at S evenly spaced x values in [-L, L].

    Only points strictly inside the open square |x|, |y| < L - 1e-3 are kept.
    Returns the surviving (xs, ys) as two 1-D arrays.
    """
    margin = L - 1e-3
    xs = np.linspace(-L, L, S)
    ys = a * xs
    inside = (np.abs(xs) < margin) & (np.abs(ys) < margin)
    return xs[inside], ys[inside]


def streamplot(p, p_next, L, a, title, out, cbar, out_dir='/root/workspace/out/sct_gnn'):
    """Plot the one-step displacement field of a 2-node, 1-channel system and save it as a PDF.

    Args:
        p: sample states with trailing shape (2, 1). The leading points must form
            a square S x S grid, as produced by `make_sample_points(L, S, 2, 1)`.
        p_next: states after one step. Only channel 0 (`p_next[..., 0]`) is used.
        L: half-width of the sampling window. The reference line is sampled over
            [-5L, 5L] so that it runs past the (clipped) plot edges.
        a: slope of the red dashed reference line y = a * x.
        title: plot title (currently unused; the title call is disabled).
        out: file stem; the figure is written to `<out_dir>/<out>.pdf`.
        cbar: if true, add a colorbar for the speed colouring.
        out_dir: output directory, created if missing (default keeps the old path).
    """
    assert p.shape[-1] == 1, 'expected a single channel'
    assert p.shape[-2] == 2, 'expected two nodes (a 2-D phase plane)'
    p = p[..., 0]
    p_next = p_next[..., 0]

    delta = p_next - p

    n_points = len(delta)
    S = int(round(np.sqrt(n_points)))
    assert n_points == S * S, 'sample points must form a square grid'

    p = p.reshape(S, S, 2)
    delta = delta.reshape(S, S, 2)

    x_delta = delta[..., 0]
    y_delta = delta[..., 1]
    v = np.sqrt(x_delta ** 2 + y_delta ** 2)

    x = np.unique(p[..., 0])
    y = np.unique(p[..., 1])

    # Sample the reference line well beyond the visible window (5 * L == 5 for
    # L = 1, the previous hard-coded value); the fixed axis limits clip it.
    line_x, line_y = _line(5 * L, S, a)
    fig = plt.figure(figsize=(6, 5))
    plt.xlim(x.min(), x.max())
    plt.ylim(y.min(), y.max())
    strm = plt.streamplot(x, y, x_delta, y_delta, color=v, linewidth=2, arrowsize=2)
    if cbar:
        plt.colorbar(strm.lines)
    plt.plot(line_x, line_y, c='r', dashes=[2, 1], linewidth=5)
    plt.xticks([], [])
    plt.yticks([], [])
    # savefig fails if the directory does not exist, so create it first.
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, f'{out}.pdf'), format='pdf', bbox_inches='tight')
    # Each call opens a new figure; close it so figures do not accumulate.
    plt.close(fig)

#----------------------------------------------------------------------------------------------------------------------------------------------------
# Main
#----------------------------------------------------------------------------------------------------------------------------------------------------

def main():
    """Plot one-step GNN dynamics on a random 2-node graph for several bias strengths.

    One figure is produced per alpha, first with ReLU dynamics (`traj_sct_*`),
    then with leaky-ReLU dynamics (`traj_sct_leaky_*`). In each, the bias is
    alpha times the unit eigenvector of P for eigenvalue 1.
    """
    # Random symmetric node operator P = Q diag(1, 0.5, ...) Q^T.
    N = 2
    lambda_ = 0.5 * np.ones(N, dtype=np.float64)
    lambda_[0] = 1.0
    Q, P = make_qp(lambda_)
    print('trajectory graph:', P)

    C = 1
    s = np.ones(C) * -0.5
    s[0] = 1.2
    W = make_w(s)

    # q_1 is the eigenvector of P for eigenvalue lambda_[0] == 1.
    kv = Q[:, 0] / np.linalg.norm(Q[:, 0])
    assert np.allclose(P @ kv, kv)

    # Everything below is independent of alpha, so compute it once.
    T = 20
    L = 1
    p = make_sample_points(L, T, N, C)
    eigvals, eigvecs = np.linalg.eigh(P)  # P is symmetric by construction
    e1 = eigvecs[:, np.argmax(eigvals)]  # eigenvector of the largest eigenvalue
    slope = e1[1] / e1[0]  # sign/scale of e1 cancel out

    alphas = [-.5, -.25, 0, 1, 2]
    variants = ((make_dynamics, 'traj_sct'), (make_leaky_dynamics, 'traj_sct_leaky'))
    for make_fn, prefix in variants:
        for alpha in alphas:
            # Node-wise bias along kv, shaped (N, 1) so it broadcasts over the
            # (batch, N, C) state. A flat (N,) vector would line up with the
            # channel axis instead: with C == 1 the output becomes (batch, N, N)
            # and every node gets kv[0]'s bias. b is also recomputed per alpha
            # for both variants, so each figure uses the bias its label states.
            b = alpha * kv[:, None]
            f = make_fn(P, W, b)
            p_next = f(p)
            streamplot(p, p_next, L, slope,
                       f'$\\alpha=${alpha:.2f}', f'{prefix}_{alpha:.2f}', cbar=False)


if __name__ == '__main__':
    main()