import scipy
import numpy as np
from scipy.sparse import rand


class SCInput:
    """An instance of the (weighted) set cover problem.

    Attributes
    ----------
    num_elems : int
        Number of elements (rows of ``connections``).
    num_sets : int
        Number of sets (columns of ``connections``).
    connections : scipy.sparse.csr_array
        Incidence matrix; entry (i, j) is non-zero when element i belongs
        to set j.
    set_prices : numpy.ndarray
        Price of each set (a private copy of the constructor argument).
    """

    def __init__(self, *, connections, set_prices):
        """Create a set cover input.

        Parameters
        ----------
        connections : array-like or scipy sparse matrix/array
            Incidence matrix of shape (num_elems, num_sets): rows represent
            elements and columns represent sets.
        set_prices : array-like
            Price of each set; must contain exactly num_sets entries.

        Raises
        ------
        ValueError
            If ``set_prices`` is not one-dimensional with one entry per set.
        """
        # Convert first so that array-likes without a ``.shape`` attribute
        # (e.g. nested lists) work as well as ndarrays and sparse inputs.
        self.connections = scipy.sparse.csr_array(connections)
        self.num_elems, self.num_sets = self.connections.shape
        # Copy so that later mutation of the caller's array cannot leak in.
        self.set_prices = np.array(set_prices, copy=True)
        if self.set_prices.shape != (self.num_sets,):
            raise ValueError(
                f"set_prices has shape {self.set_prices.shape}, "
                f"expected ({self.num_sets},) to match the number of sets")

    @classmethod
    def get_random_input(cls,
                         num_elems,
                         num_sets,
                         *,
                         uniform_conn_prob,
                         sigma,
                         seed=42):
        """Generate a random set cover instance that is always feasible.

        The last ``num_elems`` sets are singletons ({0}, {1}, ...), so every
        element can be covered. The remaining ``num_sets - num_elems`` sets
        contain each element independently-placed at the given density.

        Parameters
        ----------
        num_elems : int
            Number of elements.
        num_sets : int
            Total number of sets, including the ``num_elems`` singletons.
        uniform_conn_prob : float
            Density of the non-singleton part of the incidence matrix
            (passed to ``scipy.sparse.rand`` as ``density``).
        sigma : float
            ``sigma`` of the log-normal distribution the set prices are
            drawn from (``mean`` left at NumPy's default).
        seed : int, optional
            Seed for ``numpy.random.default_rng``; the same seed yields the
            same instance.

        Returns
        -------
        SCInput
            An instance of ``cls``.

        Raises
        ------
        ValueError
            If ``num_sets < num_elems`` (there would be no room for the
            singleton sets).
        """
        if num_sets < num_elems:
            raise ValueError(
                f"num_sets ({num_sets}) must be at least num_elems "
                f"({num_elems}): one singleton set per element is always added")

        rng = np.random.default_rng(seed=seed)

        # Draw order matters for reproducibility: prices first, then the
        # matrix, both from the same generator.
        set_prices = rng.lognormal(size=num_sets, sigma=sigma)

        num_sets_wo_singletons = num_sets - num_elems

        # Random sparse matrix at the requested density, turned into a 0/1
        # incidence matrix.
        connections = (scipy.sparse.rand(num_elems, num_sets_wo_singletons,
                                         density=uniform_conn_prob,
                                         random_state=rng) > 0).astype(int)

        # Append one singleton set per element. Use an int identity so the
        # stacked matrix keeps the integer dtype chosen above.
        connections = scipy.sparse.hstack(
            [connections, scipy.sparse.eye(num_elems, dtype=int)])
        return cls(connections=connections, set_prices=set_prices)
