import os
import numpy as np
import scipy.sparse
from itertools import combinations


class Graph:
    """Undirected graph held as an edge set plus degree and adjacency tables.

    Attributes:
        number_of_nodes: number of nodes (the generators below label them
            ``0 .. number_of_nodes - 1``).
        edges: set of ``(u, v)`` tuples; the generators below always store
            them with ``u < v``.
        degrees: integer numpy array, ``degrees[v]`` is the degree of ``v``.
        neighbors: dict mapping each node to the set of its adjacent nodes.
    """

    def __init__(self, number_of_nodes, edges, degrees, neighbors):
        self.number_of_nodes = number_of_nodes
        self.edges = edges
        self.degrees = degrees
        self.neighbors = neighbors

    def __len__(self):
        """Return the number of nodes in the graph."""
        return self.number_of_nodes

    def greedy_clique_partition(self):
        """Partition all nodes into cliques with a greedy, degree-first heuristic.

        Nodes are visited by decreasing degree. Each still-unassigned node
        seeds a clique, which is then extended with its unassigned neighbours
        (highest degree first) whenever the neighbour is adjacent to every
        current member of the clique.

        Returns:
            list of sets of nodes; every node lands in exactly one set.
        """
        partition = []
        unassigned = (-self.degrees).argsort().tolist()

        while unassigned:
            seed = unassigned.pop(0)
            clique = {seed}
            candidates = self.neighbors[seed].intersection(unassigned)
            # reverse=True keeps equal-degree candidates in their original
            # order, exactly like an ascending sort on the negated degree
            ranked = sorted(candidates, key=lambda v: self.degrees[v], reverse=True)
            for candidate in ranked:
                if all(candidate in self.neighbors[member] for member in clique):
                    clique.add(candidate)
            partition.append(clique)
            unassigned = [v for v in unassigned if v not in clique]

        return partition

    @staticmethod
    def _link(u, v, edges, degrees, neighbors):
        """Record the undirected edge ``(u, v)`` in all three tables, in place."""
        edges.add((u, v))
        degrees[u] += 1
        degrees[v] += 1
        neighbors[u].add(v)
        neighbors[v].add(u)

    @staticmethod
    def erdos_renyi(number_of_nodes, edge_probability, random):
        """Sample an Erdos-Renyi G(n, p) graph.

        Every pair ``(u, v)`` with ``u < v`` is visited in lexicographic order
        and kept when ``random.uniform()`` falls below ``edge_probability``.

        Parameters:
            number_of_nodes: number of nodes.
            edge_probability: probability of keeping each pair.
            random: numpy ``RandomState``-like generator (``uniform`` is used).
        """
        edges = set()
        degrees = np.zeros(number_of_nodes, dtype=int)
        neighbors = {v: set() for v in range(number_of_nodes)}
        for u, v in combinations(np.arange(number_of_nodes), 2):
            if random.uniform() < edge_probability:
                Graph._link(u, v, edges, degrees, neighbors)
        return Graph(number_of_nodes, edges, degrees, neighbors)

    @staticmethod
    def barabasi_albert(number_of_nodes, affinity, random):
        """Sample a Barabasi-Albert preferential-attachment graph.

        Node ``affinity`` is joined to all nodes ``0 .. affinity - 1`` (a
        star); every later node then attaches to ``affinity`` distinct earlier
        nodes drawn with probability proportional to their current degree.

        Parameters:
            number_of_nodes: total number of nodes.
            affinity: number of edges added per new node; must satisfy
                ``1 <= affinity < number_of_nodes`` (asserted).
            random: numpy ``RandomState``-like generator (``choice`` is used).
        """
        assert 1 <= affinity < number_of_nodes

        edges = set()
        degrees = np.zeros(number_of_nodes, dtype=int)
        neighbors = {v: set() for v in range(number_of_nodes)}
        for new_node in range(affinity, number_of_nodes):
            if new_node == affinity:
                # the first attached node links to every earlier node
                targets = np.arange(new_node)
            else:
                # sum of degrees equals twice the edge count
                attach_prob = degrees[:new_node] / (2 * len(edges))
                targets = random.choice(new_node, affinity, replace=False, p=attach_prob)
            for target in targets:
                Graph._link(target, new_node, edges, degrees, neighbors)
        return Graph(number_of_nodes, edges, degrees, neighbors)


def generate_indset(graph, filename):
    """Write a maximum independent set instance for ``graph`` in LP format.

    Pairwise edge constraints ``x_u + x_v <= 1`` that are covered by a clique
    of ``graph.greedy_clique_partition()`` are replaced by a single clique
    constraint ``sum_{v in clique} x_v <= 1``. Edges not covered by any clique
    keep their pairwise constraint. Every node that ends up in no constraint
    gets a trivial ``x_v <= 1`` row. Variables are named ``x1 .. xN`` (node
    index + 1); the objective maximises their sum and all are binary.

    Parameters:
        graph: object exposing ``edges`` (set of ``(u, v)`` tuples; assumed
            stored with ``u < v`` as the ``Graph`` generators do, otherwise
            ``remove`` below raises ``KeyError``), ``greedy_clique_partition()``
            and ``len()`` (number of nodes).
        filename: path of the LP file to write (overwritten).
    """
    cliques = graph.greedy_clique_partition()
    inequalities = set(graph.edges)
    for clique in cliques:
        clique = tuple(sorted(clique))
        # the clique constraint dominates the pairwise constraints it covers
        for edge in combinations(clique, 2):
            inequalities.remove(edge)
        if len(clique) > 1:
            inequalities.add(clique)

    # Put trivial inequalities for nodes that didn't appear
    # in the constraints, otherwise SCIP will complain.
    # Must cover every node, not only the first few.
    used_nodes = set()
    for group in inequalities:
        used_nodes.update(group)
    for node in range(len(graph)):
        if node not in used_nodes:
            inequalities.add((node,))

    with open(filename, 'w') as lp_file:
        lp_file.write("maximize\nOBJ:" + "".join([f" + 1 x{node+1}" for node in range(len(graph))]) + "\n")
        lp_file.write("\nsubject to\n")
        for count, group in enumerate(inequalities):
            lp_file.write(f"C{count+1}:" + "".join([f" + x{node+1}" for node in sorted(group)]) + " <= 1\n")
        lp_file.write("\nbinary\n" + " ".join([f"x{node+1}" for node in range(len(graph))]) + "\n")


def generate_setcover(nrows, ncols, density, filename, rng, max_coef=100):
    """Write a random set-cover instance to ``filename`` in LP format.

    The constraint matrix has ``nrows`` rows, ``ncols`` columns and
    ``int(nrows * ncols * density)`` nonzero entries. Every row is covered by
    at least one column, and every column covers at least two distinct rows.
    Column costs are integers in ``[1, max_coef]``.

    The LP file minimises ``sum_j cost_j * x{j+1}``. It has one covering
    constraint ``C{i}`` per row, requiring the sum of the ``x`` variables of
    the columns covering row ``i`` to be ``>= 1``. All ``x`` are binary.

    Parameters:
        nrows, ncols: matrix dimensions.
        density: fraction of nonzero entries.
        filename: destination path (overwritten).
        rng: numpy ``RandomState``-like generator (``choice``, ``permutation``
            and ``randint`` are used, in that order).
        max_coef: largest possible column cost.

    Raises:
        AssertionError: if the density is too low to give each row one
            nonzero and each column two.
    """
    nnzrs = int(nrows * ncols * density)

    assert nnzrs >= nrows  # at least 1 col per row
    assert nnzrs >= 2 * ncols  # at least 2 rows per col

    # Draw a column for every nonzero, then force the first 2*ncols draws so
    # each column gets two or more nonzeros; count nonzeros per column.
    col_of_nz = rng.choice(ncols, size=nnzrs)
    col_of_nz[:2 * ncols] = np.repeat(np.arange(ncols), 2)
    col_counts = np.unique(col_of_nz, return_counts=True)[1]

    # The same buffer is reused for row indices in CSC (column-major) layout.
    # Its first nrows slots hold a permutation, so every row appears once.
    row_idx = col_of_nz
    row_idx[:nrows] = rng.permutation(nrows)
    col_ptr = [0]
    start = 0
    for count in col_counts:
        stop = start + count
        if start >= nrows:
            # column entirely past the permutation: draw fresh distinct rows
            row_idx[start:stop] = rng.choice(nrows, size=count, replace=False)
        elif stop > nrows:
            # column straddles the permutation end: top up with unused rows
            pool = np.setdiff1d(np.arange(nrows), row_idx[start:nrows], assume_unique=True)
            row_idx[nrows:stop] = rng.choice(pool, size=stop - nrows, replace=False)
        start = stop
        col_ptr.append(start)

    costs = rng.randint(max_coef, size=ncols) + 1

    # convert to row-major so each covering constraint can be read off directly
    csr = scipy.sparse.csc_matrix(
        (np.ones(len(row_idx), dtype=int), row_idx, col_ptr),
        shape=(nrows, ncols)).tocsr()
    cols_of_row, row_ptr = csr.indices, csr.indptr

    objective = "".join(f" +{costs[j]} x{j+1}" for j in range(ncols))
    constraints = []
    for r in range(nrows):
        terms = "".join(f" +1 x{j+1}" for j in cols_of_row[row_ptr[r]:row_ptr[r+1]])
        constraints.append(f"C{r}:{terms} >= 1\n")
    binaries = "".join(f" x{j+1}" for j in range(ncols))

    with open(filename, 'w') as lp_file:
        lp_file.write("minimize\nOBJ:" + objective)
        lp_file.write("\n\nsubject to\n")
        lp_file.writelines(constraints)
        lp_file.write("\nbinary\n" + binaries)


def generate_cauctions(random, filename, n_items=100, n_bids=500, min_value=1, max_value=100,
                       value_deviation=0.5, add_item_prob=0.9, max_n_sub_bids=5,
                       additivity=0.2, budget_factor=1.5, resale_factor=0.5,
                       integers=False, warnings=False):
    """Write a random combinatorial auction instance in LP format.

    Bids are generated one bidder at a time.

    1. An initial bundle is grown item by item according to the bidder's
       private interests and the pairwise item compatibilities.
    2. Up to ``max_n_sub_bids`` substitutable bundles of the same size are
       added. Each shares at least one item with the initial bundle.
    3. When a bidder ends up with more than two bids, a fresh dummy item is
       appended to all of them, so at most one of them can be accepted.

    Each item (real or dummy) yields the constraint "sum of the bids containing
    it <= 1". The objective maximises the total price of the accepted bids.
    Bid variables ``x1 .. x{n_bids}`` are binary. Exactly ``n_bids`` bids are
    written.

    Parameters:
        random: numpy ``RandomState``-like generator (``rand``/``choice``).
        filename: path of the LP file to write (overwritten).
        n_items: number of real items.
        n_bids: number of bids to generate.
        min_value, max_value: range of the common (resale) item values.
        value_deviation: scale of each bidder's deviation from common values.
        add_item_prob: probability of continuing to grow the initial bundle.
        max_n_sub_bids: maximum number of substitutable bids per bidder.
        additivity: bundle bonus ``len(bundle) ** (1 + additivity)``.
        budget_factor: substitutable bids priced above
            ``budget_factor * price`` of the initial bid are dropped.
        resale_factor: substitutable bundles whose common value is below
            ``resale_factor`` times that of the initial bundle are dropped.
        integers: truncate prices to ``int``.
        warnings: print a message whenever a candidate bundle is discarded.

    Raises:
        AssertionError: on inconsistent value range or probability.
    """
    assert min_value >= 0 and max_value >= min_value
    assert add_item_prob >= 0 and add_item_prob <= 1

    def choose_next_item(bundle_mask, interests, compats, random):
        """Sample an item outside the bundle, weighted by the bidder's interest
        times the item's mean compatibility with the items already in it."""
        n_items = len(interests)
        # bundle_mask is a 0/1 *integer* array; indexing with it directly would
        # be integer fancy indexing (rows 0 and 1). Convert to a boolean mask so
        # the compatibility rows of the bundle's items are the ones averaged.
        in_bundle = bundle_mask.astype(bool)
        prob = (1 - bundle_mask) * interests * compats[in_bundle, :].mean(axis=0)
        prob /= prob.sum()
        return random.choice(n_items, p=prob)

    # common item values (resale price)
    values = min_value + (max_value - min_value) * random.rand(n_items)

    # item compatibilities: random symmetric weights with a zero diagonal;
    # broadcasting divides column j by the sum of row j (equal by symmetry)
    compats = np.triu(random.rand(n_items, n_items), k=1)
    compats = compats + compats.transpose()
    compats = compats / compats.sum(1)

    bids = []
    n_dummy_items = 0

    # create bids, one bidder at a time
    while len(bids) < n_bids:

        # bidder item values (buy price) and interests
        private_interests = random.rand(n_items)
        private_values = values + max_value * value_deviation * (2 * private_interests - 1)

        # bids of this bidder: frozenset(bundle) -> price
        bidder_bids = {}

        # generate initial bundle, choose first item according to bidder interests
        prob = private_interests / private_interests.sum()
        item = random.choice(n_items, p=prob)
        bundle_mask = np.full(n_items, 0)
        bundle_mask[item] = 1

        # add additional items, according to bidder interests and item compatibilities
        while random.rand() < add_item_prob:
            # stop when bundle full (no item left)
            if bundle_mask.sum() == n_items:
                break
            item = choose_next_item(bundle_mask, private_interests, compats, random)
            bundle_mask[item] = 1

        bundle = np.nonzero(bundle_mask)[0]

        # compute bundle price with value additivity
        price = private_values[bundle].sum() + np.power(len(bundle), 1 + additivity)
        if integers:
            price = int(price)

        # drop negatively priced bundles
        if price < 0:
            if warnings:
                print("warning: negatively priced bundle avoided")
            continue

        # bid on initial bundle
        bidder_bids[frozenset(bundle)] = price

        # generate candidate substitutable bundles, one seeded by each item of
        # the initial bundle and grown to the same size
        sub_candidates = []
        for item in bundle:

            # at least one item must be shared with initial bundle
            bundle_mask = np.full(n_items, 0)
            bundle_mask[item] = 1

            # add additional items, according to bidder interests and item compatibilities
            while bundle_mask.sum() < len(bundle):
                item = choose_next_item(bundle_mask, private_interests, compats, random)
                bundle_mask[item] = 1

            sub_bundle = np.nonzero(bundle_mask)[0]

            # compute bundle price with value additivity
            sub_price = private_values[sub_bundle].sum() + np.power(len(sub_bundle), 1 + additivity)
            if integers:
                sub_price = int(sub_price)

            sub_candidates.append((sub_bundle, sub_price))

        # filter valid candidates, higher priced candidates first
        budget = budget_factor * price
        min_resale_value = resale_factor * values[bundle].sum()
        order = np.argsort([-cand_price for _, cand_price in sub_candidates])
        for sub_bundle, sub_price in [sub_candidates[i] for i in order]:

            # respect both the per-bidder cap and the global number of bids
            if len(bidder_bids) >= max_n_sub_bids + 1 or len(bids) + len(bidder_bids) >= n_bids:
                break

            if sub_price < 0:
                if warnings:
                    print("warning: negatively priced substitutable bundle avoided")
                continue

            if sub_price > budget:
                if warnings:
                    print("warning: over priced substitutable bundle avoided")
                continue

            if values[sub_bundle].sum() < min_resale_value:
                if warnings:
                    print("warning: substitutable bundle below min resale value avoided")
                continue

            if frozenset(sub_bundle) in bidder_bids:
                if warnings:
                    print("warning: duplicated substitutable bundle avoided")
                continue

            bidder_bids[frozenset(sub_bundle)] = sub_price

        # add XOR constraint if needed (dummy item)
        if len(bidder_bids) > 2:
            dummy_item = [n_items + n_dummy_items]
            n_dummy_items += 1
        else:
            dummy_item = []

        # place bids
        for bid_bundle, bid_price in bidder_bids.items():
            bids.append((list(bid_bundle) + dummy_item, bid_price))

    # generate the LP file
    with open(filename, 'w') as file:
        bids_per_item = [[] for item in range(n_items + n_dummy_items)]

        file.write("maximize\nOBJ:")
        for i, bid in enumerate(bids):
            bid_bundle, bid_price = bid
            file.write(f" +{bid_price} x{i+1}")
            for item in bid_bundle:
                bids_per_item[item].append(i)

        file.write("\n\nsubject to\n")
        for item_bids in bids_per_item:
            if item_bids:
                for i in item_bids:
                    file.write(f" +1 x{i+1}")
                file.write(f" <= 1\n")

        file.write("\nbinary\n")
        for i in range(len(bids)):
            file.write(f" x{i+1}")


def generate_capacited_facility_location(random, filename, n_customers, n_facilities, ratio, rng=None):
    """Write a random capacitated facility location instance in LP format.

    Instance generation:
    - Customers and facilities are placed uniformly in the unit square.
    - Customer demands are integers in [5, 35].
    - Raw facility capacities are integers in [10, 160]. They are rescaled so
      that total capacity is about ``ratio`` times total demand, then
      truncated to ``int``.
    - Fixed costs grow with the square root of the raw capacity.
    - The cost of ``x_{i}_{j}`` is the Euclidean distance times 10 times the
      demand of customer ``i``.

    Variables:
    - ``x_{i}_{j}`` is continuous in [0, 1].
    - ``y_{j}`` is binary (facility ``j`` open).

    Constraints:
    - demand: each customer is fully served.
    - capacity: served demand fits an open facility.
    - total capacity (optional tightening).
    - affectation: ``x_{i}_{j} <= y_{j}`` (optional tightening).

    Parameters:
        random: numpy ``RandomState``-like generator. It is used only when
            ``rng`` is not given; kept first for backward compatibility.
        filename: path of the LP file to write (overwritten).
        n_customers, n_facilities: instance size.
        ratio: target ratio of total capacity to total demand.
        rng: generator actually used for sampling; defaults to ``random``.
    """
    if rng is None:
        rng = random

    c_x = rng.rand(n_customers)
    c_y = rng.rand(n_customers)

    f_x = rng.rand(n_facilities)
    f_y = rng.rand(n_facilities)

    demands = rng.randint(5, 35+1, size=n_customers)
    capacities = rng.randint(10, 160+1, size=n_facilities)
    fixed_costs = rng.randint(100, 110+1, size=n_facilities) * np.sqrt(capacities) \
        + rng.randint(90+1, size=n_facilities)
    fixed_costs = fixed_costs.astype(int)

    total_demand = demands.sum()
    total_capacity = capacities.sum()

    # adjust capacities according to ratio
    capacities = capacities * ratio * total_demand / total_capacity
    capacities = capacities.astype(int)

    # transportation costs: distance * 10 * customer demand
    trans_costs = np.sqrt(
        (c_x.reshape((-1, 1)) - f_x.reshape((1, -1))) ** 2
        + (c_y.reshape((-1, 1)) - f_y.reshape((1, -1))) ** 2) * 10 * demands.reshape((-1, 1))

    # write problem
    with open(filename, 'w') as file:
        file.write("minimize\nobj:")
        file.write("".join([f" +{trans_costs[i, j]} x_{i+1}_{j+1}" for i in range(n_customers) for j in range(n_facilities)]))
        file.write("".join([f" +{fixed_costs[j]} y_{j+1}" for j in range(n_facilities)]))

        file.write("\n\nsubject to\n")
        for i in range(n_customers):
            file.write(f"demand_{i+1}:" + "".join([f" -1 x_{i+1}_{j+1}" for j in range(n_facilities)]) + " <= -1\n")
        for j in range(n_facilities):
            file.write(f"capacity_{j+1}:" + "".join([f" +{demands[i]} x_{i+1}_{j+1}" for i in range(n_customers)]) + f" -{capacities[j]} y_{j+1} <= 0\n")

        # optional constraints for LP relaxation tightening
        file.write("total_capacity:" + "".join([f" -{capacities[j]} y_{j+1}" for j in range(n_facilities)]) + f" <= -{total_demand}\n")
        for i in range(n_customers):
            for j in range(n_facilities):
                # one constraint per line, newline-terminated
                file.write(f"affectation_{i+1}_{j+1}: +1 x_{i+1}_{j+1} -1 y_{j+1} <= 0\n")

        file.write("\nbounds\n")
        for i in range(n_customers):
            for j in range(n_facilities):
                file.write(f"0 <= x_{i+1}_{j+1} <= 1\n")

        file.write("\nbinary\n")
        file.write("".join([f" y_{j+1}" for j in range(n_facilities)]) + "\n")


def _instance_filenames(problem, split_sizes):
    """Prepare output directories and list the LP paths for each dataset split.

    For every ``(split, count)`` pair, in order:
    - print a summary line;
    - create ``./instances/<problem>/<split>/`` if needed;
    - append the ``count`` paths ``<problem>_1.lp .. <problem>_<count>.lp``
      inside it.
    """
    filenames = []
    for split, count in split_sizes:
        lp_dir = f'./instances/{problem}/{split}/'
        print(f"{count} instances in {lp_dir}")
        os.makedirs(lp_dir, exist_ok=True)
        filenames.extend(os.path.join(lp_dir, f'{problem}_{i+1}.lp') for i in range(count))
    return filenames


def gen_lp_file_main(problem, train_number, valid_number, test_numver, seed=0):
    """Generate train/valid/test LP instances for one problem family.

    Files are written to ``./instances/<problem>/{train,valid,test}/`` as
    ``<problem>_<k>.lp``. A single generator seeded with ``seed`` is shared
    by all instances, generated in train, valid, test order. Later splits
    therefore depend on the sizes of earlier ones.

    Parameters:
        problem: one of 'setcover', 'indset', 'cauctions', 'facilities'.
        train_number, valid_number, test_numver: instances per split.
        seed: seed of the shared ``np.random.RandomState`` (default 0).

    Raises:
        ValueError: if ``problem`` is not a supported family.
    """
    supported = ('setcover', 'indset', 'cauctions', 'facilities')
    if problem not in supported:
        raise ValueError(f"unknown problem {problem!r}; expected one of {supported}")

    rng = np.random.RandomState(seed)
    filenames = _instance_filenames(
        problem, (('train', train_number), ('valid', valid_number), ('test', test_numver)))

    # actually generate the instances
    for filename in filenames:
        print(f"  generating file {filename} ...")
        if problem == 'setcover':
            generate_setcover(nrows=500, ncols=1000, density=0.05, filename=filename,
                              rng=rng, max_coef=100)
        elif problem == 'indset':
            graph = Graph.barabasi_albert(500, 4, rng)  # 500 nodes, affinity 4
            generate_indset(graph, filename)
        elif problem == 'cauctions':
            generate_cauctions(rng, filename, n_items=100, n_bids=500, add_item_prob=0.7)
        else:  # 'facilities'
            generate_capacited_facility_location(rng, filename, n_customers=100,
                                                 n_facilities=100, ratio=5, rng=rng)

    print("done.")


if __name__ == '__main__':
    # Generate train/valid/test LP instances for every supported problem family.
    train_number = 10000
    valid_number = 400
    test_numver = 100
    for problem in ['setcover', 'indset', 'cauctions', 'facilities']:
        gen_lp_file_main(problem, train_number, valid_number, test_numver)