from config import FLAGS

from collections import defaultdict
from copy import deepcopy
import networkx as nx

#########################################################################
# Bidomain
#########################################################################
class Bidomain(object):
    """A group of left-side and right-side node ids that are candidates for each other.

    Within this module, ``left`` is merged into the g1 node-id set and
    ``right`` into the g2 node-id set (see ``get_natts2g2abd_sg_nids``).

    Attributes:
        left: collection of left-side (g1) node ids.
        right: collection of right-side (g2) node ids.
        natts: attribute key for this bidomain; presumably the same key it is
            stored under in a ``natts2bds`` mapping — confirm at call sites.
        bid: integer id, or None until set (``assign_bids`` fills it in).
    """

    def __init__(self, left, right, natts, bid=None):
        self.left, self.right = left, right
        self.natts = natts
        self.bid = bid

    def __len__(self):
        """Return the number of (left, right) candidate pairs."""
        n_left = len(self.left)
        n_right = len(self.right)
        return n_left * n_right

def get_natts_hash(node):
    """Return a hashable tuple of the node's attribute values used for matching.

    When ``FLAGS.reward_calculator_mode`` contains ``'fuzzy_matching'``, no
    attributes are used and the result is the empty tuple. Otherwise the
    values of ``node[name]`` for each name in ``FLAGS.node_feats_for_mcs``
    are returned, in that order.
    """
    fuzzy = 'fuzzy_matching' in FLAGS.reward_calculator_mode
    attr_names = [] if fuzzy else FLAGS.node_feats_for_mcs
    return tuple(node[name] for name in attr_names)


def unroll_bidomains(natts2bds):
    """Flatten a mapping of key -> list of bidomains into one list.

    Order follows the mapping's value iteration order, then each list's order.
    """
    flattened = []
    for group in natts2bds.values():
        flattened.extend(group)
    return flattened


def get_natts2g2abd_sg_nids(natts2g2nids, natts2bds, nn_map):
    """Per attribute key, collect the g1/g2 node ids in its bidomains or already matched.

    Args:
        natts2g2nids: maps a natts key to a dict with ``'g1'`` and ``'g2'``
            entries, each an iterable of node ids of that graph for the key.
        natts2bds: maps a natts key to a list of bidomain-like objects whose
            ``left`` holds g1 node ids and ``right`` holds g2 node ids.
        nn_map: current node mapping, g1 node id -> g2 node id.

    Returns:
        ``defaultdict(dict)`` mapping every key of ``natts2g2nids`` to
        ``{'g1': set, 'g2': set}``. Each set is the union of that key's
        bidomain nodes on that side and that key's nodes already present in
        ``nn_map`` (keys for g1, values for g2).
    """
    natts2g2abd_sg_nids = defaultdict(dict)
    # Node ids already in the matched subgraph, on each side of the mapping.
    sg1, sg2 = set(nn_map.keys()), set(nn_map.values())
    for natts, g2nid in natts2g2nids.items():
        left_cum, right_cum = set(), set()
        if natts in natts2bds:
            for bd in natts2bds[natts]:
                left_cum.update(bd.left)
                right_cum.update(bd.right)
        # TODO: potential bottleneck when nn_map is large.
        left_cum.update(sg1.intersection(g2nid['g1']))
        # Matched g2 nodes belong to the g2 (right) side, not the g1 side.
        right_cum.update(sg2.intersection(g2nid['g2']))
        natts2g2abd_sg_nids[natts]['g1'] = left_cum
        natts2g2abd_sg_nids[natts]['g2'] = right_cum
    return natts2g2abd_sg_nids


def assign_bids(natts2bds):
    """Number every bidomain in place with consecutive ``bid`` values from 0.

    Keys are visited in sorted order, so the same ``natts2bds`` always yields
    the same numbering. Within a key, the list order is kept.
    """
    in_order = (bd for key in sorted(natts2bds) for bd in natts2bds[key])
    for next_bid, bd in enumerate(in_order):
        bd.bid = next_bid