"""
Finds a large independent set in graph G where nodes are binary strings of length n.
Nodes in G are connected if they share a subsequence of length at least n - s, where s = 1.

The `priority` function assigns a priority to each node indicating its importance for inclusion in the independent set.

Desired properties of the `priority` function:
- **Efficiency**: The function should be computationally efficient.
- **Avoid Redundant Computations**: Do not perform unnecessary calculations or repeat work.
- **Clarity**: The code should be easy to understand, with appropriate comments.
- **Innovation**: Explore different strategies for calculating the priority. Consider specific characteristics of the binary strings, such as:
    - Patterns in the binary string.
    - The number of ones or zeros (Hamming weight).
    - Distribution of bits (e.g., runs of ones or zeros).

Improve the `priority_v1` function over its previous versions below.
"""

import itertools
import hashlib
import numpy as np
import networkx as nx
import lmdb
import json
import os

def load_graph(graph_db_path):
    """Load an undirected graph from an LMDB database.

    Each record is read as ``key -> JSON-encoded collection of neighbors``:
    the key, decoded as text, becomes a node, and every entry of the decoded
    value is connected to that node by an edge.

    Args:
        graph_db_path: Path handed to ``lmdb.open`` (opened read-only,
            without locking).

    Returns:
        A ``networkx.Graph`` containing every key as a node, together with
        the edges listed in its value.
    """
    G = nx.Graph()
    graph_env = lmdb.open(graph_db_path, readonly=True, lock=False)

    # try/finally guarantees the environment is closed even when a record
    # fails to decode or parse.
    try:
        with graph_env.begin() as txn:
            for key, value in txn.cursor():
                node = key.decode()
                # Register the node explicitly: add_edge alone would drop a
                # record whose neighbor list is empty.
                G.add_node(node)
                for neighbor in json.loads(value.decode()):
                    G.add_edge(node, neighbor)
    finally:
        graph_env.close()

    return G


def hash_priority_mapping(priorities, sequences):
    """Return a SHA-256 hex digest of the sequence -> priority mapping.

    The sequences are ordered by their own value, each one is rendered as
    ``sequence:score``, the entries are joined with commas, and the resulting
    text is hashed. The digest therefore does not depend on the order in
    which `sequences` is supplied.
    """
    digest = hashlib.sha256()
    entries = []
    for sequence in sorted(sequences):
        entries.append(f'{sequence}:{priorities[sequence]}')
    digest.update(','.join(entries).encode())
    return digest.hexdigest()


def evaluate(params, graph_dir):
    """Run `solve` for one ``(n, s)`` pair and summarize the outcome.

    Returns a tuple of the independent set's size and the hash value that
    `solve` reported.
    """
    n, s = params
    chosen_nodes, digest = solve(n, s, graph_dir)
    set_size = len(chosen_nodes)
    return (set_size, digest)


def solve(n, s, graph_dir):
    """Greedily build an independent set for the graph stored for ``(n, s)``.

    The graph is loaded from ``graph_s{s}_n{n}.lmdb`` inside `graph_dir`.
    Every node is scored with `priority`, and nodes are then visited from
    highest to lowest priority (ties broken by ascending node value). A node
    is selected unless it is adjacent to a node that was selected earlier.

    Args:
        n: Sequence length; used in the graph file name and passed to
            `priority`.
        s: Second graph parameter; used in the graph file name and passed to
            `priority`.
        graph_dir: Directory containing the LMDB graph databases.

    Returns:
        A tuple ``(independent_set, hash_value)``. ``hash_value`` is the
        digest of the priority mapping over all binary strings of length
        `n` when ``n == 9``, and ``None`` otherwise.
    """
    path = os.path.join(graph_dir, f"graph_s{s}_n{n}.lmdb")

    G = load_graph(path)
    # priority() receives its own copy, so whatever it does to that graph
    # cannot influence the graph used for the selection below.
    G_for_priority = G.copy()

    priorities = {node: priority(node, G_for_priority, n, s) for node in G.nodes}

    # Sort nodes first by priority (higher is better), then by lexicographic order (ascending)
    nodes_sorted = sorted(G.nodes, key=lambda x: (-priorities[x], x))

    independent_set = set()
    # Neighbors of already selected nodes. Tracking them in a set gives the
    # same selection as deleting them from the graph, without mutating it.
    blocked = set()
    for node in nodes_sorted:
        if node in blocked:
            continue
        independent_set.add(node)
        blocked.update(G.neighbors(node))

    hash_value = None
    if n == 9:
        # Only enumerate the 2**n binary strings when the hash is requested.
        sequences = [''.join(seq) for seq in itertools.product('01', repeat=n)]
        hash_value = hash_priority_mapping(priorities, sequences)

    return independent_set, hash_value


def priority(node, G, n, s):
    """Score `node` for inclusion in the independent set.

    This implementation is a constant baseline: it ignores all of its
    arguments and gives every node the same score.
    """
    baseline_score = 0.0
    return baseline_score
