import numpy as np

from scipy.sparse.csgraph import connected_components

def sigmoid(z):
    """Return the logistic function ``1 / (1 + exp(-z))`` element-wise.

    The value is computed in a numerically stable way: only ``exp(-|z|)`` is
    ever evaluated, so large negative inputs no longer overflow inside
    ``np.exp`` (which previously emitted a RuntimeWarning, or raised under
    ``np.errstate(over="raise")``).

    Parameters
    ----------
    z : scalar or array_like
        Input value(s).

    Returns
    -------
    Scalar or ndarray of the same shape as ``z`` with values in [0, 1].
    """
    # exp(-|z|) lies in (0, 1], so it can underflow to 0 but never overflow.
    decay = np.exp(-np.abs(z))
    # z >= 0: 1 / (1 + exp(-z));  z < 0: exp(z) / (1 + exp(z)) -- same value.
    numerator = np.where(np.greater_equal(z, 0), 1, decay)
    return numerator / (1 + decay)

def inverse_sigmoid(z):
    """Return ``log(z / (1 - z))`` element-wise (the inverse of the logistic).

    No clipping is applied: inputs are passed straight to ``np.divide`` and
    ``np.log``, so values outside the open interval (0, 1) follow numpy's
    usual inf/nan rules.
    """
    complement = 1 - z
    odds = np.divide(z, complement)
    return np.log(odds)

def state_action_pairs(S, A):
    """Enumerate every unordered pair of distinct indices in ``range(S * A)``.

    Parameters
    ----------
    S, A : int
        Only their product ``S * A`` is used, as the number of indices
        (presumably the number of states and actions -- confirm with callers).

    Returns
    -------
    ndarray of shape ``(S*A * (S*A - 1) // 2, 2)``
        Integer rows ``(i, j)`` with ``i < j``, in row-major order of the
        strict upper triangle: ``(0, 1), (0, 2), ..., (1, 2), ...``.
    """
    # The row/column indices of the strict upper triangle are exactly the
    # pairs we want, so no intermediate (S*A, S*A) index matrix is needed.
    first, second = np.triu_indices(S * A, 1)
    return np.column_stack((first, second))

def build_query_set(S, A, H, h):
    """Build one query row per pair returned by ``state_action_pairs(S, A)``.

    Each row has ``2 * H`` integer columns, all zero except column ``h``
    (first index of the pair) and column ``H + h`` (second index of the pair).

    Returns
    -------
    ndarray of shape ``(n_pairs, 2 * H)`` with integer dtype.
    """
    pairs = state_action_pairs(S, A)
    first, second = pairs[:, 0], pairs[:, 1]

    queries = np.zeros((len(pairs), 2 * H), dtype=int)
    # Write order matters only if the two columns alias (e.g. negative h).
    queries[:, h] = first
    queries[:, H + h] = second

    return queries

def best_basis_h(Q, W, H, h, q=None):
    """Prune the rows of ``Q`` by reverse-delete on the graph they induce.

    Each row of ``Q`` is read as an undirected edge between the labels in
    columns ``h`` and ``H + h``. Rows are visited in order of increasing
    weight; a row is dropped whenever the graph stays connected (exactly one
    connected component) without its edge. If the graph is not connected to
    begin with, no row is ever dropped.

    Parameters
    ----------
    Q : ndarray, shape (n, >= H + h + 1)
        Query matrix; only columns ``h`` and ``H + h`` are inspected.
    W : array_like with n elements
        One weight per row of ``Q``. It is NOT modified (a private copy is
        used), unlike earlier versions that overwrote ``W`` in place.
    H, h : int
        Column offsets selecting the two edge endpoints.
    q : array_like or None
        Optional row of ``Q`` whose weight is raised above every other weight
        so that it is examined last.

    Returns
    -------
    ndarray
        The surviving rows of ``Q``, ordered by increasing weight.

    Raises
    ------
    AssertionError
        If ``q`` is given but is not a row of ``Q``.
    """
    # Private copy: boosting the fixed query must not leak into the caller's W.
    weights = np.array(W)

    if q is not None:
        matches = np.all(q == Q, axis=1)
        assert np.any(matches), "Fixed query must either be None or a row of Q"
        weights[np.flatnonzero(matches)[0]] = np.max(weights) + 1

    order = np.argsort(weights.reshape(-1, 1), axis=0).ravel()
    ordered = Q[order, :]

    n = ordered.shape[0]
    keep = np.ones(n, dtype=bool)

    # Map endpoint labels onto 0..V-1 so that arbitrary (non-contiguous)
    # labels can index the adjacency matrix; for labels that already are
    # 0..V-1 this mapping is the identity.
    endpoints = ordered[:, [h, H + h]]
    labels, compact = np.unique(endpoints.ravel(), return_inverse=True)
    edges = compact.reshape(-1, 2)

    V = len(labels)
    G = np.zeros((V, V))
    G[edges[:, 0], edges[:, 1]] = 1

    # Rows of `edges` are already sorted by increasing weight.
    for i in range(n):
        u, v = edges[i]
        previous = G[u, v]
        G[u, v] = 0  # tentatively remove the edge in place (no full copy)

        if connected_components(G, directed=False)[0] == 1:
            keep[i] = False  # still connected: the edge is redundant
        else:
            G[u, v] = previous  # needed for connectivity: put it back

    return ordered[keep, :]