import numpy as np
import igraph as ig
import numpy as np
from scipy.sparse import csr_matrix, csc_matrix
from scipy.sparse.linalg import spsolve

def simulate_time_unrolled_dag(d, s0, graph_type, number_of_lags=0, average_degrees_per_lagged_node=None):
    """Simulate a random time-unrolled DAG.

    Args:
        d (int): number of nodes per time slice.
        s0 (int): expected number of edges of the instantaneous (intra-slice) adjacency matrix.
        graph_type (str): 'ER' (Erdos-Renyi), 'SF' (scale-free, Barabasi-Albert) or 'BP' (bipartite).
        number_of_lags (int): number of additional adjacency matrices for non-instantaneous
            (inter-slice) dependencies.
        average_degrees_per_lagged_node (sequence of float, optional): one entry per lag; entry
            ``k - 1`` multiplied by ``d`` is the edge budget of the lag-``k`` matrix. Required,
            with at least ``number_of_lags`` entries, whenever ``number_of_lags > 0``.

    Returns:
        A (list of np.ndarray): ``number_of_lags + 1`` binary [d, d] adjacency matrices.
            ``A[0]`` represents instantaneous (intra-slice) dependencies and ``A[k]`` the
            lag-``k`` (inter-slice) dependencies.

    Raises:
        ValueError: if ``graph_type`` is unknown or too few lagged degrees are supplied.
    """
    def _random_permutation(M):
        # np.random.permutation permutes the first axis only, so shuffling the rows of the
        # identity yields a random permutation matrix P; P.T @ M @ P relabels the nodes of M.
        P = np.random.permutation(np.eye(M.shape[0]))
        return P.T @ M @ P

    def _random_acyclic_orientation(B_und):
        # Only the strictly lower triangle of the randomly relabeled undirected graph is kept,
        # so the result is strictly triangular and therefore acyclic.
        return np.tril(_random_permutation(B_und), k=-1)

    def _graph_to_adjmat(G):
        return np.array(G.get_adjacency().data)

    # A None default instead of a shared mutable list; validate up front so a missing
    # degree produces a clear error rather than `None * d` / an IndexError mid-loop.
    if average_degrees_per_lagged_node is None:
        lagged_degrees = []
    else:
        lagged_degrees = list(average_degrees_per_lagged_node)
    if len(lagged_degrees) < number_of_lags:
        raise ValueError(
            f'average_degrees_per_lagged_node needs {number_of_lags} entries, '
            f'got {len(lagged_degrees)}'
        )

    A = []
    for t in range(number_of_lags + 1):
        # Edge budget: s0 for the instantaneous matrix, degree * d for lag t.
        n_edges = lagged_degrees[t - 1] * d if t > 0 else s0

        if graph_type == 'ER':
            # Erdos-Renyi: undirected random graph, then a random acyclic orientation.
            # igraph needs an integer edge count, while degree * d may be a float.
            G_und = ig.Graph.Erdos_Renyi(n=d, m=int(round(n_edges)))
            B_und = _graph_to_adjmat(G_und)
            B = _random_acyclic_orientation(B_und)
        elif graph_type == 'SF':
            # Scale-free, Barabasi-Albert, with m = round(n_edges / d) per added node.
            G = ig.Graph.Barabasi(n=d, m=int(round(n_edges / d)), directed=True)
            B = _graph_to_adjmat(G)
        elif graph_type == 'BP':
            # Bipartite, Sec 4.1 of (Gu, Fu, Zhou, 2018): 20% of the nodes form the top set.
            top = int(0.2 * d)
            G = ig.Graph.Random_Bipartite(top, d - top, m=int(round(n_edges)), directed=True, neimode=ig.OUT)
            B = _graph_to_adjmat(G)
        else:
            raise ValueError('unknown graph type')

        # NOTE(review): lagged matrices go through the same acyclic generation and DAG
        # check as the instantaneous one, although inter-slice edges alone cannot form
        # cycles across time; kept as-is to preserve the generated distribution.
        B_perm = _random_permutation(B)
        assert ig.Graph.Adjacency(B_perm.tolist()).is_dag()
        A.append(B_perm)

    return A


def simulate_parameter(B, w_ranges=((-2.0, -0.5), (0.5, 2.0))):
    """Turn a binary adjacency matrix into a weighted one.

    Every entry independently picks one of ``w_ranges`` uniformly at random and
    draws its weight uniformly from that interval; zero entries of ``B`` stay zero.

    Args:
        B (np.ndarray): [d, d] binary adjacency matrix.
        w_ranges (tuple): disjoint (low, high) weight intervals.

    Returns:
        np.ndarray: [d, d] weighted adjacency matrix.
    """
    range_choice = np.random.randint(len(w_ranges), size=B.shape)
    weighted = np.zeros(B.shape)
    for idx, bounds in enumerate(w_ranges):
        lo, hi = bounds
        candidate = np.random.uniform(low=lo, high=hi, size=B.shape)
        # Keep the candidate only where the entry exists and picked this interval.
        weighted += B * (range_choice == idx) * candidate
    return weighted


def sparse_rct_sem_suboptimal(W_full, T, n=1, sparsity=0.3, std=0.01, noise_type='gauss', noise_effect='spectral'):
    """Reference sampler: solve the whole time-unrolled SEM as one sparse system.

    Builds the (T*d) x (T*d) block-Toeplitz matrix W of ``W_full`` and solves
    X = X W + C + Nf for all time steps at once, then adds measurement noise Ns.

    Args:
        W_full: list of p + 1 [d, d] adjacencies (instantaneous first, then lags 1..p).
        T: number of time steps per sequence.
        n: number of sequences to produce.
        sparsity: probability that an entry of the root causes C is nonzero.
        std: noise standard deviation; 0 disables both noises.
        noise_type: 'gauss'; any other value draws Gumbel noise with matching std.
        noise_effect: 'spectral' keeps only the noise on C, 'signal' keeps only
            the noise on X, anything else keeps both.

    Returns:
        (X, C): arrays of shape [n, T, d] holding the sampled signal and the
        noiseless sparse root causes.
    """
    big_W = block_toeplitz(W_full, T)

    n_nodes = W_full[0].shape[0]
    width = n_nodes * T
    shape = (n, width)

    # Sparse spectrum: Bernoulli support times symmetric magnitudes times random signs.
    support = np.random.choice([0, 1], size=shape, p=[1 - sparsity, sparsity])
    signs = np.random.choice([-1, 1], size=shape)
    magnitudes = np.random.uniform(-1, 1, size=shape)
    C = support * magnitudes * signs

    def _draw_noise():
        # Gumbel scale chosen so that its standard deviation equals `std`.
        if noise_type == 'gauss':
            return np.random.normal(scale=std, size=shape)
        return np.random.gumbel(scale=np.sqrt(6) * std / np.pi, size=shape)

    if std == 0:
        Ns, Nf = np.zeros(shape), np.zeros(shape)
    else:
        Ns = _draw_noise()  # measurement noise on X
        Nf = _draw_noise()  # noise on the root causes

    if noise_effect == 'spectral':
        Ns = np.zeros(shape)
    elif noise_effect == 'signal':
        Nf = np.zeros(shape)

    # Row-vector form X (I - W) = C + Nf, solved as (I - W)^T X^T = (C + Nf)^T.
    system = csc_matrix(np.eye(width) - big_W.T)
    solution = spsolve(system, (C + Nf).T)
    X = solution.T + Ns

    return X.reshape(n, T, n_nodes), C.reshape(n, T, n_nodes)

def sparse_rct_sem(W_full, T, n=1, sparsity=0.3, std=0.01, noise_type='gauss', noise_effect='both'):
    """Sample sequences from a time-unrolled linear SEM driven by sparse root causes.

    --- Optimal implementation -----
    Solves X = X W + C + Nf slice by slice instead of building the full
    (T*d) x (T*d) block-Toeplitz system. For every time step t:

        x[t] = x[t] A + sum_{k=1..p} x[t-k] B_k + c[t] + nf[t]

    Lags that would reach before t = 0 are dropped.

    Args:
        W_full: list of [d, d] adjacencies of length p + 1,
            W_full = [A, B_1, B_2, ..., B_p]. A holds the instantaneous edges and B_k
            the lag-k edges; I - A.T must be invertible (e.g. A a weighted DAG).
        T: number of desired timesteps.
        n: number of sequences to produce.
        sparsity: probability that an entry of the root causes C is nonzero.
        std: noise standard deviation; 0 disables both noises.
        noise_type: 'gauss'; any other value draws Gumbel noise with matching std.
        noise_effect: 'root_causes' keeps only the noise on C, 'signal' keeps only
            the noise on X, anything else (default 'both') keeps both.

    Returns:
        (X, C): arrays of shape [n, T, d] holding the sampled signal (including
        measurement noise) and the noiseless sparse root causes.
    """
    d = W_full[0].shape[0]
    p = len(W_full) - 1
    shape = (n, d * T)

    # Per-slice system x[t] (I - A) = y, solved as (I - A)^T x[t]^T = y^T.
    lhs = csc_matrix(np.eye(d) - W_full[0].T)

    # Lag matrices stacked oldest first so that [x(t-p) ... x(t-1)] @ B = sum_k x(t-k) B_k:
    #   B = [B_p; B_{p-1}; ...; B_2; B_1]   (shape [p * d, d])
    # With no lags, B is empty, so the lagged term contributes zero.
    if p > 0:
        B = np.concatenate(W_full[1:][::-1], axis=0)
    else:
        B = np.zeros((0, d))

    # Sparse, strictly positive root causes.
    pos = np.random.choice([0, 1], size=shape, p=[1 - sparsity, sparsity])
    # Sign draw kept only so the RNG stream (and thus seeded results) is unchanged;
    # sign flipping of C is currently disabled.
    np.random.choice([-1, 1], size=shape)
    C = pos * np.random.uniform(0.1, 1, size=shape)

    # Independent noises: Ns on the signal, Nf on the root causes.
    if std == 0:
        Ns = np.zeros(shape)
        Nf = np.zeros(shape)
    elif noise_type == 'gauss':
        Ns = np.random.normal(scale=std, size=shape)
        Nf = np.random.normal(scale=std, size=shape)
    else:  # Gumbel, scaled so that its standard deviation equals `std`
        noise_scale = np.sqrt(6) * std / np.pi
        Ns = np.random.gumbel(scale=noise_scale, size=shape)
        Nf = np.random.gumbel(scale=noise_scale, size=shape)

    if noise_effect == 'root_causes':
        Ns = np.zeros(shape)
    elif noise_effect == 'signal':
        Nf = np.zeros(shape)

    C_noisy = C + Nf
    X = np.zeros(shape)
    for t in range(T):
        # Only the last min(t, p) slices exist; pair them with [B_{n_lags}; ...; B_1].
        first = max(t - p, 0)
        n_lags = t - first
        lagged = X[:, first * d : t * d] @ B[(p - n_lags) * d :, :]
        Y = lagged + C_noisy[:, t * d : (t + 1) * d]
        # x[t] = x[t] A + y -- also at t = 0, so instantaneous edges act on the first slice.
        X[:, t * d : (t + 1) * d] = spsolve(lhs, Y.T).T

    # Adding measurement noise.
    X = X + Ns

    return X.reshape(n, T, d), C.reshape(n, T, d)


def block_matrices(B, W, N=2):
    """Place B and W on the first block superdiagonal of an N x N block grid.

    Args:
        B: 2-D matrix to repeat in the first output.
        W: 2-D matrix to repeat in the second output.
        N: number of block rows/columns.

    Returns:
        (B_block, W_block): kron(S, B) and kron(S, W), where S = np.eye(N, k=1) is
        the N x N shift matrix. For example, with N = 4:

            |0 W 0 0 |
            |0 0 W 0 |
            |0 0 0 W |
            |0 0 0 0 | = W_block

        B_block has the same layout with B in place of W.
    """
    # Ones on the first superdiagonal. This equals rolling the identity by one
    # column and clearing the wrapped-around entry, but also works for N = 0.
    shift = np.eye(N, k=1)
    return np.kron(shift, B), np.kron(shift, W)

def block_toeplitz(W_full, T):
    """Build the block upper-triangular Toeplitz matrix of a time-unrolled SEM.

    Args:
        W_full: list of p + 1 [d, d] adjacencies [W_0, W_1, ..., W_p]; W_0 holds
            instantaneous dependencies and W_k the lag-k dependencies.
        T: number of desired timesteps (block rows/columns).

    Returns:
        np.ndarray of shape [T * d, T * d] whose block (s, s + k) equals W_k, and is
        zero where k > p. For example, with p = 2 and T = 4:

            |W_0 W_1 W_2 0   |
            |0   W_0 W_1 W_2 |
            |0   0   W_0 W_1 |
            |0   0   0   W_0 |
    """
    # Block-diagonal part: W_0 on every time slice.
    A = np.kron(np.eye(T), W_full[0])

    # np.eye(T, k=lag) has ones on the lag-th superdiagonal, which puts W_lag in
    # blocks (s, s + lag). Lags >= T give an all-zero shift and drop out. This
    # also avoids wrap-around fix-ups and repeated T x T matrix products.
    for lag, W_lag in enumerate(W_full[1:], start=1):
        A = A + np.kron(np.eye(T, k=lag), W_lag)

    return A


if __name__ == "__main__":
    # Smoke test of the block-Toeplitz construction and the slice-wise SEM sampler
    # on a single-node (d = 1) system with lag-1 and lag-3 self-dependencies.
    # Adjacencies must be 2-D [d, d] arrays: sparse_rct_sem stacks and slices the
    # lag matrices as 2-D, so 1-D arrays such as np.array([0]) make it fail.
    a = np.array([[0]])
    b = np.array([[1]])

    W_full = [a, b, a, b]
    W = block_toeplitz(W_full, 10)
    print(W)

    X, C = sparse_rct_sem(W_full, 10, n=1, sparsity=0.1, std=0)
    print(X, C)

