from utils import OurTimer
import collections
import itertools
import torch
import networkx as nx
import numpy as np
from copy import deepcopy

from config import FLAGS
from dvn_encoder import create_encoder
from dvn_decoder import create_decoder
from dvn_preencoder import create_preencoder
from dvn import DVN
from utils_nn import MLP, get_MLP_args

def create_u2v_li(nn_map, cs_map, candidate_map):
    """Build ``{u: candidate target list}`` for every query node ``u`` in ``cs_map``.

    For each key ``u`` of ``cs_map`` the first applicable source wins:

    1. ``u`` is already matched (``u in nn_map``) -> ``[nn_map[u]]``;
    2. ``u`` has an entry in ``candidate_map`` -> ``candidate_map[u]``;
    3. otherwise -> ``cs_map[u]``.

    Lists taken from ``candidate_map`` / ``cs_map`` are returned as-is (not
    copied), so they alias the input dicts' values.
    """
    def _candidates_for(query_node):
        if query_node in nn_map:
            return [nn_map[query_node]]
        if query_node in candidate_map:
            return candidate_map[query_node]
        return cs_map[query_node]

    return {query_node: _candidates_for(query_node) for query_node in cs_map.keys()}

def create_dvn(d_in_raw, d_in):
    """Assemble a ``DVN`` from its parts and return it wrapped in ``DVN_wrapper``.

    Args:
        d_in_raw: raw input feature size, passed to ``create_preencoder``.
        d_in: feature size after the pre-encoder, passed to ``create_encoder``.

    Returns:
        A ``DVN_wrapper`` holding the DVN plus a final MLP built from the
        layer sizes ``[64, 32, 16, 8, 4, 1]``.
    """
    # Sub-modules are created in the same order as before, in case any factory
    # draws from the global RNG during parameter initialisation.
    pre_encoder = create_preencoder(d_in_raw, d_in)
    encoder_gnn_consensus, d_out = create_encoder(d_in)
    decoder_policy, decoder_value = create_decoder()

    final_layer_dims = [64, 32, 16, 8, 4, 1]
    mlp_final = MLP(*get_MLP_args(final_layer_dims))

    # Two LayerNorms at the encoder-input width followed by two at its output width.
    norm_li = torch.nn.ModuleList(
        [torch.nn.LayerNorm(width) for width in (d_in, d_in, d_out, d_out)]
    )

    core = DVN(pre_encoder, encoder_gnn_consensus, decoder_policy, decoder_value, norm_li)
    return DVN_wrapper(core, mlp_final)

class DVN_wrapper(torch.nn.Module):
    """``torch.nn.Module`` that prepares search-state inputs and calls a ``DVN``.

    Attributes:
        dvn: the wrapped DVN model; ``forward`` delegates to it.
        mlp_final: module stored alongside the DVN (registered as a submodule
            when it is an ``nn.Module``); it is not used by ``forward``.
    """

    def __init__(self, dvn, mlp_final):
        super(DVN_wrapper, self).__init__()
        self.dvn = dvn
        self.mlp_final = mlp_final

    def reset_cache(self):
        """Reset the cache of the wrapped DVN's encoder (``self.dvn.encoder``)."""
        self.dvn.encoder.reset_cache()

    def forward(self, gq, gt, u, v_li, nn_map, cs_map, candidate_map,
                cache_embeddings, graph_filter=None, filter_key=None,
                execute_action=None, query_tree=None):
        """Run the wrapped DVN on one (query graph, target graph) state.

        Args:
            gq, gt: query / target graph objects; their ``.x``, ``.edge_index``
                and ``.nx_graph`` attributes are read.
            u, v_li: forwarded to the DVN as ``u=`` / ``v_li=`` (presumably the
                query node being matched and its candidate target nodes --
                TODO confirm against DVN).
            nn_map, cs_map, candidate_map: node-mapping dicts; combined into
                ``u2v_li`` by ``create_u2v_li`` and also forwarded unchanged.
            cache_embeddings, execute_action, query_tree: forwarded unchanged.
            graph_filter, filter_key: currently unused (target-graph node
                filtering is disabled); kept for interface compatibility.

        Returns:
            ``(out_policy, out_value, out_other)`` exactly as returned by the DVN.
        """
        # Decide once whether to time this call; every later timing step keys
        # on the timer itself, so it can never be used when it was not created.
        timer = OurTimer() if FLAGS.time_analysis else None

        Xq, edge_indexq = gq.x, gq.edge_index
        Xt, edge_indext = gt.x, gt.edge_index
        u2v_li = create_u2v_li(nn_map, cs_map, candidate_map)
        if timer is not None:
            timer.time_and_clear('create_u2v_li')

        # Target-graph node filtering via ``graph_filter`` is disabled: the DVN
        # always receives an empty mask and the unfiltered target edge index.
        node_mask = np.array([])
        if timer is not None:
            n_nodes = gt.nx_graph.number_of_nodes()
            timer.time_and_clear(
                f'node_mask {n_nodes - len(node_mask)} nodes selected; before {n_nodes}')

        out_policy, out_value, out_other = self.dvn(
            Xq, edge_indexq, Xt, edge_indext,
            gq.nx_graph, gt.nx_graph,
            nn_map, cs_map, candidate_map,
            u2v_li, node_mask, cache_embeddings,
            execute_action, query_tree, u=u, v_li=v_li,
        )
        if timer is not None:
            timer.time_and_clear('fast?')
            timer.print_durations_log()
        return out_policy, out_value, out_other

class EmbeddingCache:
    """Single-slot cache: one stored list tagged with the key it is valid for."""

    def __init__(self):
        # List stored by callers (presumably target-graph embeddings, per the
        # class/attribute names -- TODO confirm); starts empty.
        self.Xt_li = []
        # Key the stored list belongs to; ``None`` means nothing is cached.
        self.key = None

    def get_Xt_li(self, key):
        """Return the stored list when ``key`` matches ``self.key``, else ``None``.

        A ``None`` key is always a miss, even though ``self.key`` starts as ``None``.
        """
        if key is None:
            return None
        if key != self.key:
            return None
        return self.Xt_li

class GraphFilter:
    """Compute node subsets / masks of a target graph and filter its edges.

    Results of :meth:`get_node_mask_diameter` are memoised in a bounded FIFO
    cache: ``self.cache`` maps filter key -> node set, and ``self.queue``
    records insertion order with exactly ``cache_size`` slots (``None`` slots
    are empty). Each insertion evicts the oldest slot.
    """

    def __init__(self, cache_size=10):
        # maxlen follows cache_size (it used to be hard-coded to 100, which
        # silently capped larger caches). Non-positive sizes disable caching.
        n_slots = max(cache_size, 0)
        self.queue = collections.deque([None] * n_slots, maxlen=n_slots)
        self.cache = {}

    def add_to_cache(self, key, val):
        """Store ``val`` under ``key``, evicting the oldest queued key (FIFO).

        An already-cached ``key`` keeps its existing value. With a zero-size
        cache this is a no-op.
        """
        if not self.queue:
            return
        key_rm = self.queue.popleft()
        if key_rm is not None and key_rm != key:
            # pop() instead of del: a key queued more than once may already
            # have been evicted by its earlier queue entry.
            self.cache.pop(key_rm, None)
        if key not in self.cache:
            self.cache[key] = val
        self.queue.append(key)

    def check_if_new_filter(self, cur_node, nn_map):
        """Return the filter key for the current state.

        With exactly one matched pair a new key is derived from ``nn_map``;
        otherwise ``cur_node.filter_key`` is reused.
        """
        if len(nn_map) == 1:
            return self.get_filter_key(nn_map)
        return cur_node.filter_key

    def get_filter_key(self, nn_map):
        """Return a hashable key: the frozenset of matched target nodes."""
        return frozenset(nn_map.values())

    def get_node_mask_diameter(self, gq, gt, root_li):
        """Return the target nodes within ``k`` hops of *every* root in ``root_li``.

        ``k = max(diameter(gq), 2)``. ``nx.diameter`` raises for a disconnected
        ``gq``.
        """
        assert len(root_li) >= 1
        k = max(nx.diameter(gq), 2)
        sel_nodes = None
        for root in root_li:
            if sel_nodes is None:
                # BFS up to depth k from the first root (cheaper than ego_graph).
                sel_nodes = {
                    v
                    for _, successors in nx.bfs_successors(gt, root, depth_limit=k)
                    for v in successors
                }
                sel_nodes.add(root)
            else:
                sel_nodes &= set(nx.ego_graph(gt, root, k).nodes())
            # Use an explicit "first root" sentinel rather than emptiness: an
            # empty intersection must stay empty, not restart from the next root.
            if not sel_nodes:
                break
        return set() if sel_nodes is None else sel_nodes

    def get_node_mask_circle(self, u2v_li):
        """Return the union of all candidate target nodes in ``u2v_li``."""
        return set(itertools.chain.from_iterable(u2v_li.values()))

    def get_node_mask(self, filter_key, gq, gt, root_li, u2v_li):
        """Split the nodes of ``gt`` into masked-out and kept nodes.

        Kept nodes are the candidate union (:meth:`get_node_mask_circle`),
        intersected with the diameter neighbourhood when
        ``FLAGS.use_node_mask_diameter`` is set. The diameter set is cached
        under ``filter_key`` unless the key is ``None``.

        Returns:
            ``(node_mask, node_mask_inv)`` as lists: nodes of ``gt`` NOT kept,
            and nodes kept.
        """
        if FLAGS.use_node_mask_diameter:
            if filter_key is not None and filter_key in self.cache:
                node_mask_diameter = self.cache[filter_key]
            else:
                node_mask_diameter = self.get_node_mask_diameter(gq, gt, root_li)
                # A None key can never be looked up again (see the check above),
                # so caching it would only evict a useful entry.
                if filter_key is not None:
                    self.add_to_cache(filter_key, node_mask_diameter)
        else:
            node_mask_diameter = None
        node_mask_circle = self.get_node_mask_circle(u2v_li)
        if node_mask_diameter is None:
            node_mask_inv = node_mask_circle
        else:
            node_mask_inv = node_mask_diameter.intersection(node_mask_circle)
        node_mask = set(gt.nodes()) - node_mask_inv
        return list(node_mask), list(node_mask_inv)

    def get_node_mask_oh_edge_index(self, edge_index, node_mask):
        """Keep only the columns of ``edge_index`` whose source node is in ``node_mask``.

        Only row 0 (the source endpoint) is tested. ``node_mask`` may be any
        iterable of node ids. The boolean column mask is moved to
        ``edge_index``'s device before indexing.
        """
        sources = edge_index[0].detach().cpu().numpy()
        keep = np.isin(sources, list(node_mask))
        return edge_index[:, torch.from_numpy(keep).to(edge_index.device)]


# def plot_embeddings(self, embedding_logger, u, v_li, Xq_reg, Xt_reg, gq, gt, nn_map, cs_map):
#     # print('@@@ plot_embeddings embedding_logger', embedding_logger, ' len(nn_map)',  len(nn_map))
#     # exit()
#     if embedding_logger is not None and len(nn_map) > 0:
#     # if embedding_logger is not None:
#         override = \
#             {
#                 'g_q': {u : {'color':(0,1,1,1),'size':300}},
#                 'g_t': {
#                     v: {
#                         'color':(0, 1, 0, 1) if f'{u}_{v}' in embedding_logger.action_best_li else (1, 0, 0, 1),
#                         'size':300
#                     } for v in v_li
#                 }
#             }
#         embedding_logger.plot_embeddings(Xq_reg, Xt_reg, gq, gt, nn_map, cs_map, override=override, pn=embedding_logger.get_pn_state())
#
# def decode_dvn(self, u, v_li, gq, gt, nn_map, cs_map, query_tree, Xq_reg, Xt_reg, embedding_logger, detach):
#     Xq, edge_indexq, Xt, edge_indext = \
#         gq.x, gq.edge_index, gt.x, gt.edge_index
#     # Xq, Xt = self.dvn.pre_encoder(Xq, Xt, nn_map)
#
#     v_li_valid, nn_map_uv_li, candidate_map_li = \
#         self.execute_action_li(u, v_li, gq.nx_graph, gt.nx_graph, nn_map, cs_map, query_tree)
#     # self.filter_candidate_maps(candidate_map_li)
#     # t3 = time.time()
#
#     V_pred_li = []
#     # logger = Logger()
#     for v_valid, nn_map_uv, candidate_map in zip(v_li_valid, nn_map_uv_li, candidate_map_li):
#         if embedding_logger is not None:
#             embedding_logger.action = f'{u}_{v_valid}'
#         # assert len(candidate_map) > 0
#         if len(nn_map_uv) == gq.nx_graph.number_of_nodes():
#             V_pred = torch.tensor([1.0], device=FLAGS.device).view(-1)
#         else:
#             if FLAGS.dvn_config['encoder']['shared_encoder']:
#                 Xq, edge_indexq, Xt, edge_indext = None, None, None, None
#                 encoded_Xq_Xt = (Xq_reg, Xt_reg)
#             else:
#                 encoded_Xq_Xt = None
#
#             V_pred = \
#                 self.dvn(
#                     Xq, edge_indexq, Xt, edge_indext,
#                     gq.nx_graph, gt.nx_graph, nn_map_uv, cs_map,
#                     candidate_map=candidate_map,
#                     encoded_Xq_Xt=encoded_Xq_Xt,
#                     embedding_logger=embedding_logger)
#
#             if detach:
#                 V_pred = V_pred.detach()
#         V_pred_li.append(V_pred)
#     # t4 = time.time()
#
#     if len(V_pred_li) > 0:
#         # V_emb = torch.stack(V_pred_li).view(-1)
#         V_emb = torch.stack(V_pred_li, dim=0)
#         # V_emb = V_emb - torch.mean(V_emb, dim=0).view(1, -1)
#         V = self.mlp_final(V_emb).view(-1)
#         assert V.shape[0] == len(v_li_valid)
#     else:
#         V = []
#         assert len(V) == len(v_li_valid)
#     # t5 = time.time()
#
#     # if len(v_li_valid) > 1:
#     #     print(f'v_li_len: {len(v_li_valid)}')
#     #     print(f'time unpack: \t{t1-t0}s')
#     #     print(f'time encode: \t{t2-t1}s')
#     #     print(f'time exec a: \t{t3-t2}s')
#     #     print(f'time exec nn: \t{t4-t3}s')
#     #     print(f'time merge vec: \t{t5-t4}s')
#     #     exit(-1)
#     return V, v_li_valid
#
# def execute_action_li(self, u, v_li, gq, gt, nn_map, cs_map, query_tree):
#     # there are 2 v's here...
#     #   v is the (u,v) pair node id
#     #   V is the value function of the state after executing V(env(s,a=(u,v))
#     v_li_valid, nn_map_uv_li, candidate_map_li = [], [], []
#     for v in v_li:
#         nn_map_uv, candidate_map = self.execute_action(u, v, gq, gt, nn_map, cs_map, query_tree)
#         # if len(candidate_map) > 0:
#         v_li_valid.append(v)
#         nn_map_uv_li.append(nn_map_uv)
#         candidate_map_li.append(candidate_map)
#     return v_li_valid, nn_map_uv_li, candidate_map_li
#
# def execute_action(self, u, v, gq, gt, nn_map, cs_map, query_tree):
#     nn_map_new = deepcopy(nn_map)
#     nn_map_new[u] = v
#
#     nn_map_keys_set = set(nn_map_new.keys())
#     frontier_and_nn_map = set().union(*[set(nx.neighbors(gq, nid) for nid in nn_map_keys_set)])
#     unconnectedbd = set(gq.nodes()) - frontier_and_nn_map
#
#     candidate_map = {}
#     for u in get_u_next_li(gq, query_tree, nn_map_new):
#         candidate_map[u] = get_v_candidates_wrapper(u, nn_map_new, gq, gt, cs_map, query_tree)
#         if len(candidate_map[u]) == 0:
#             candidate_map = {}
#             break
#     for u in unconnectedbd:
#         candidate_map[u] = [v for v in cs_map[u] if len(set(nx.neighbors(gt, v)).intersection(nn_map_keys_set)) == 0]
#         if len(candidate_map[u]) == 0:
#             candidate_map = {}
#             break
#
#     # candidate_map = {}
#     # for (u, v_li) in cs_map.items():
#     #     if u not in nn_map_new.keys():
#     #         candidate_map[u] = \
#     #             [
#     #                 v for v in gt.nodes() if \
#     #                 {nn_map_new[bit] for bit in get_bidomain_bitvector(gq, u, nn_map_new.keys())} ==
#     #                 get_bidomain_bitvector(gt, v, nn_map_new.values())
#     #             ]
#     #         # assert len(candidate_map[u]) > 0
#     #         if len(candidate_map[u]) == 0:
#     #             candidate_map = {}
#     #             break
#     #         # assert len(candidate_map[u]) > 0
#     return nn_map_new, candidate_map
#
# def filter_candidate_maps(self, candidate_map_li):
#     u2key = {u:set(v_li) for (u,v_li) in candidate_map_li[0].items()}
#     for candidate_map in candidate_map_li[1:]:
#         for u, v_li in candidate_map.items():
#             if u in u2key and set(v_li) != u2key[u]:
#                 del u2key[u]
#     for i, candidate_map in enumerate(candidate_map_li):
#         for u in list(candidate_map.keys()):
#             if u in u2key:
#                 del candidate_map_li[i][u]
#     return candidate_map_li

