import queue
import networkx as nx
import graph_methods as graph
import numpy as np
import pickle


class TreeNode:
    """A node of a rooted hierarchy (partition tree) over graph vertices.

    The structural fields are set by build()/add_child(); the aggregate
    fields (node_set, volume, g_val, cut_val) start at empty/zero and are
    filled in by dfs().
    """

    def __init__(self, value=None, depth=None):
        self.value = value  # label; dfs() treats a leaf's value as a vertex of G
        self.children = []  # child TreeNodes, in insertion order
        self.depth = depth  # build() gives the root depth 0 and a child parent.depth + 1
        self.node_set = set()  # leaf labels under this node (set by dfs)
        self.parent = None  # parent TreeNode, None for the root (set by add_child)
        self.volume = 0  # summed weighted degree of the leaves (set by dfs)
        self.g_val = 0  # volume minus twice the internal child-to-child cut (set by dfs)
        self.cut_val = 0  # summed cut between every pair of children (set by dfs)

    def add_child(self, node):
        """Append *node* as the last child and point its parent back at self."""
        self.children.append(node)
        node.parent = self

    def __str__(self):
        # __str__ must return a str: labels may be ints (e.g. group ids), and
        # returning them unconverted makes str(node) raise TypeError.
        return str(self.value)


def build(layer_list):
    """Build a TreeNode hierarchy, one level per entry of *layer_list*.

    Each layer maps a parent label to an iterable of child labels. The nodes
    of the current level are matched against the layer in the order they were
    created. A node whose label is not a key of the layer gets no children,
    and it is not considered again by later layers. Child order follows the
    iteration order of the child collection.

    Returns:
        The 'root' TreeNode (depth 0).
    """
    root = TreeNode('root', depth=0)
    frontier = [root]
    for layer in layer_list:
        next_frontier = []
        for parent in frontier:
            if parent.value not in layer:
                continue
            for label in layer[parent.value]:
                child = TreeNode(label, depth=parent.depth + 1)
                parent.add_child(child)
                next_frontier.append(child)
        frontier = next_frontier
    return root


def get_k_layer_ground_truth(node, k):
    """Map each node at depth *k* under *node* to its node_set.

    The traversal is pre-order. Subtrees deeper than k are pruned. The returned
    sets are the nodes' own node_set objects, not copies.
    """
    if node.depth > k:
        return {}
    found = {node.value: node.node_set} if node.depth == k else {}
    for child in node.children:
        found.update(get_k_layer_ground_truth(child, k))
    return found


def calc_node_cut_val(node, G):
    """Add the cut between every unordered pair of children to node.cut_val.

    Each pair (i, j) with i < j is visited once and contributes
    graph.cut_value(G, children[i].node_set, children[j].node_set). The value
    is accumulated with += onto the current node.cut_val, so the caller
    should start from 0.
    """
    kids = node.children
    for pos, first in enumerate(kids):
        for second in kids[pos + 1:]:
            node.cut_val += graph.cut_value(G, first.node_set, second.node_set)


def dfs(node, G):
    """Fill in node_set, volume, g_val and cut_val for every node under *node*.

    This is a post-order traversal: children are processed before their parent.

    * Leaf: node_set = {value}. volume is the sum of the 'weight' attribute
      over the edges of G incident to value, or 0 if value is not a vertex of
      G. g_val = volume.
    * Internal node: node_set, volume and g_val are the union/sums over the
      children. cut_val is the pairwise cut between the children (see
      calc_node_cut_val), and g_val is then reduced by 2 * cut_val.

    All aggregates are reset first, so running dfs twice on the same tree
    yields the same values instead of double-counting.

    Args:
        node: root of the (sub)tree to annotate; modified in place.
        G: graph whose edges carry a 'weight' attribute (the __main__ driver
            passes a networkx Graph).
    """
    node.volume = 0
    node.g_val = 0
    node.cut_val = 0
    if not node.children:
        node.node_set = {node.value}
        # networkx interprets a non-node argument to G.edges() as an iterable
        # of nodes: an absent str label like '22' would silently pick up the
        # edges of node '2', and an absent int label raises. Isolated or
        # missing vertices therefore get volume 0 explicitly.
        if node.value in G:
            node.volume = sum(edge_data['weight'] for _, _, edge_data in G.edges(node.value, data=True))
        node.g_val = node.volume
    else:
        node.node_set = set()
        for child in node.children:
            dfs(child, G)
            node.node_set.update(child.node_set)
            node.volume += child.volume
            node.g_val += child.g_val
        calc_node_cut_val(node, G)
        # Edges between two children were counted once in each child's g_val;
        # remove both counts.
        node.g_val -= 2 * node.cut_val


def print_tree(node, depth, f):
    """Write an indented outline of the subtree at *node* to the stream *f*.

    Each line is '-' * depth, the node value, a space, the number of children,
    and a newline. A node deeper than 100 is still written, but its children
    are not, which bounds the recursion. Passing None writes nothing.
    """
    if node is None:
        return
    f.write(f"{'-' * depth}{node.value!s} {len(node.children)}\n")
    if depth <= 100:
        for child in node.children:
            print_tree(child, depth + 1, f)


def calc_SE(node, G_volume):
    """Return the structural entropy of the hierarchy rooted at *node*.

    The result is the sum over every non-root node alpha of
        -(g_alpha / G_volume) * log2(volume(alpha) / volume(parent(alpha)))
    using the values filled in by dfs().

    Nodes with zero volume are skipped. With dfs-computed values such a node
    also has g_val == 0, so its term is 0 under the usual 0 * log 0 = 0
    convention. Evaluating it literally would give log2(0) = -inf and
    0 * -inf = nan, which would poison the whole sum. If the parent also has
    zero volume, it would raise ZeroDivisionError.
    """
    res = 0
    if node.parent is not None and node.volume > 0:
        res = - node.g_val / G_volume * np.log2(node.volume / node.parent.volume)
    for child in node.children:
        res += calc_SE(child, G_volume)
    return res


def calc_HME(node, G_volume):
    """Return the HME score of the subtree rooted at *node* (a recursive sum).

    Leaves contribute nothing. For an internal node alpha, let
    D = 2 * cut_val + 2 * g_val. With dfs-computed values, D equals
    g_alpha + sum of the children's g_val. The node then contributes:
        -(g_alpha / G_volume) * log2(g_alpha / D)   if g_alpha > 0
        -(g_child / G_volume) * log2(g_child / D)   for each child with g_child > 0
    plus the recursive score of every child.

    Args:
        node: root of the (sub)tree, annotated by dfs().
        G_volume: normaliser for the g values (the __main__ driver passes
            root.volume).
    """
    total = 0
    if not node.children:
        return total
    denom = 2 * node.cut_val + 2 * node.g_val
    if node.g_val > 0:
        total -= node.g_val / G_volume * np.log2(node.g_val / denom)
    for child in node.children:
        if child.g_val > 0:
            total -= child.g_val / G_volume * np.log2(child.g_val / denom)
        total += calc_HME(child, G_volume)
    return total


def calc_Das(node):
    """Return the sum of cut_val * len(node_set) over the internal nodes under *node*.

    cut_val is the summed cut between every pair of children (see
    calc_node_cut_val), so each edge is charged the size of the subtree
    where its endpoints are separated. Presumably this is Dasgupta's
    hierarchical-clustering cost, assuming graph.cut_value returns the total
    edge weight between two vertex sets.

    Leaves are skipped explicitly. The previous guard
    ``node.children is not None`` was always true, because children is a
    list.
    """
    res = 0
    if node.children:
        res = node.cut_val * len(node.node_set)
    for child in node.children:
        res += calc_Das(child)
    return res


def get_leaf_depth(node, depth):
    """Return the depths of all leaves under *node*, left to right.

    Only branching levels are counted: passing through a node with exactly one
    child does not increase the depth. *depth* is the value assigned to *node*
    itself.
    """
    kids = node.children
    if not kids:
        return [depth]
    step = 0 if len(kids) == 1 else 1
    depths = []
    for kid in kids:
        depths.extend(get_leaf_depth(kid, depth + step))
    return depths


def calc_depth_balance_factor(depths):
    """Return the population standard deviation (ddof=0) of the leaf *depths*.

    0 means every leaf sits at the same depth. A normalised variant
    (std / mean, the coefficient of variation) is not used.
    """
    return np.std(depths, ddof=0)


def calc_size_balance_factor(node):
    """Sum |size(left) - size(right)| over all internal nodes of a binary tree.

    size(x) is len(x.node_set). A leaf contributes 0.

    The former expression ``n * |l - r| / n`` reduces to ``|l - r|``.
    Computing it directly avoids a nan (0 / 0) for an internal node whose
    node_set is still empty, e.g. before dfs() has run.

    Raises:
        ValueError: if a node has a number of children other than 0 or 2.
    """
    if not node.children:
        return 0
    if len(node.children) != 2:
        raise ValueError(f'Not a binary tree: node {node.value!r} has {len(node.children)} children')
    left_child, right_child = node.children
    factor = float(abs(len(left_child.node_set) - len(right_child.node_set)))
    factor += calc_size_balance_factor(left_child)
    factor += calc_size_balance_factor(right_child)
    return factor


def calc_volume_balance_factor(node):
    """Sum |volume(left) - volume(right)| over all internal nodes of a binary tree.

    A leaf contributes 0.

    The former expression ``v * |(l - r) / v|`` reduces to ``|l - r|``.
    Computing it directly avoids a ZeroDivisionError for an internal node of
    volume 0, e.g. when none of its leaves has an incident edge. It also
    avoids the rounding of the divide/multiply round trip.

    Raises:
        ValueError: if a node has a number of children other than 0 or 2.
    """
    if not node.children:
        return 0
    if len(node.children) != 2:
        raise ValueError(f'Not a binary tree: node {node.value!r} has {len(node.children)} children')
    left_child, right_child = node.children
    factor = float(abs(left_child.volume - right_child.volume))
    factor += calc_volume_balance_factor(left_child)
    factor += calc_volume_balance_factor(right_child)
    return factor


def get_internal_nodes_volume_sum(node):
    """Return the sum of volume over all internal nodes of a binary tree.

    Leaves contribute 0.

    Raises:
        ValueError: if a node has a number of children other than 0 or 2.
    """
    if not node.children:
        return 0
    if len(node.children) != 2:
        raise ValueError(f'Not a binary tree: node {node.value!r} has {len(node.children)} children')
    left_child, right_child = node.children
    volume_sum = node.volume
    volume_sum += get_internal_nodes_volume_sum(left_child)
    volume_sum += get_internal_nodes_volume_sum(right_child)
    return volume_sum


def get_internal_nodes_size_sum(node):
    """Return the sum of len(node_set) over all internal nodes of a binary tree.

    Leaves contribute 0.

    Raises:
        ValueError: if a node has a number of children other than 0 or 2.
    """
    if not node.children:
        return 0
    if len(node.children) != 2:
        raise ValueError(f'Not a binary tree: node {node.value!r} has {len(node.children)} children')
    left_child, right_child = node.children
    size_sum = len(node.node_set)
    size_sum += get_internal_nodes_size_sum(left_child)
    size_sum += get_internal_nodes_size_sum(right_child)
    return size_sum


from igraph import Graph, EdgeSeq


def create_igraph_tree(root):
    """Convert the TreeNode hierarchy rooted at *root* into a directed igraph Graph.

    One vertex is added per tree node, in pre-order (a node before its
    children), and one parent -> child edge is added per tree edge.

    Returns:
        (G, root_vertex_index): the igraph Graph and the vertex index of *root*.
    """
    G = Graph(directed=True)
    # Maps each tree node to its igraph vertex. The key is the node's identity
    # rather than (value, depth): two distinct nodes may share a label at the
    # same depth, and a (value, depth) key would merge them, attaching their
    # children's edges to the wrong vertex.
    node_to_vertex = {}

    def add_node_to_graph(node):
        node_to_vertex[id(node)] = G.add_vertex()

    def add_edge_to_graph(parent, child):
        G.add_edge(node_to_vertex[id(parent)], node_to_vertex[id(child)])

    def traverse_tree(node):
        for child in node.children:
            add_node_to_graph(child)
            add_edge_to_graph(node, child)
            traverse_tree(child)

    # Walk the tree starting from the root.
    add_node_to_graph(root)
    traverse_tree(root)
    # Record the vertex index assigned to the root.
    root_vertex_index = node_to_vertex[id(root)].index

    return G, root_vertex_index


if __name__ == '__main__':
    # Ground-truth hierarchy consumed by build(): one dict per tree level,
    # mapping a parent label to the collection of its children's labels.
    # - Level 1 hangs everything under 'root'. Its str entries are not keys of
    #   the level-2 dict, so build() leaves them as leaves. Those leaves are
    #   graph vertices, matched against the str node labels read from
    #   'graph.in'.
    # - Its int entries (5717, 6483, 6020, 5518) are group ids expanded at
    #   level 2, whose int entries are in turn expanded at level 3.
    # NOTE: the children are set literals, so sibling order follows set
    # iteration order, which for str elements can change between runs (hash
    # randomization). The metrics below are sums over the tree, but float
    # totals may differ in the last bits from run to run.
    layers = [{'root': {'583', '157', '603', '352', '74', '517', '532', '476', '548', '491', '513', '170', '361', '146',
                        '91', '22', '538', '141', '893', '914', '348', '503', '362', '886', '620', '174', '133', '447',
                        '350', '679', '536', '322', '84', '71', '99', 5717, '635', '85', '453', '857', '496', '161',
                        '149', '63', '368', '572', '899', '32', '848', '64', '28', '898', '610', '173', '817', '812',
                        '553', '615', '25', '499', '798', '880', '644', '918', '563', '605', '881', '897', '922', '813',
                        '518', '905', '868', '849', '693', '155', '408', '458', '547', '377', '49', '579', '505', '891',
                        '570', '120', '145', '519', '445', '909', '540', '468', '514', '526', '82', '906', '55', '677',
                        '591', '90', '129', '669', '872', '471', '522', '103', '530', '57', '473', '358', '566', '637',
                        '351', '698', '342', '19', '67', '181', '354', '61', '580', '574', '68', '43', '597', '497',
                        '30', '151', '89', '482', '689', '810', '442', '641', '54', '874', '655', '576', '130', '867',
                        '21', '508', '140', '618', '152', '81', '392', '40', '70', '799', '851', '588', '648', '825',
                        '4', '8', '829', '163', '188', '190', '594', '360', '105', 6483, '672', '124', '35', '59',
                        '822', '154', '187', '528', '51', '419', '195', '586', '815', '183', '381', '573', '336', '465',
                        '197', '502', 6020, '504', '908', '14', '39', '116', 5518, '148', '658', '811', '651', '869',
                        '710', '901', '894', '93', '143', '126', '673', '23', '533', '139', '622', '462', '395', '332',
                        '521', '210', '506', '804', '150', '467', '316', '331', '606', '341', '534', '697', '887', '17',
                        '479', '919', '706', '380', '112', '317', '843', '114', '441', '31', '3', '686', '495', '86',
                        '88', '196', '541', '42', '806', '509', '34', '692', '115', '802', '125', '363', '75', '448',
                        '542', '417', '36', '896', '879', '539', '537', '816', '621'}},
              {5518: {5507, 5509, 5510, 5512, 5516},
               6020: {6016, '235', '746', '727', '737', '719', '788', '722', '744', 6002, '734', '718', 6013, 6014,
                      6015}, 6483: {6481, 6474, 6476, 6478}, 5717: {5712, 5716, 5711}}, {
                  5512: {'9', '238', '38', '611', '98', '694', '595', '678', '106', '711', '634', '687', '234', '58',
                         '607', '47', '306', '653', '650', '237', '56', '111', '690', '707', '258', '50', '95', '46',
                         '662', '676', '612', '699', '599', '102', '701', '674', '684', '631', '37', '656', '600', '45',
                         '702', '628', '278', '312', '96', '596', '614', '670', '647', '619', '613', '661'},
                  5507: {'850', '854', '904', '827', '833', '838', '800', '902', '846', '805', '803', '903', '826',
                         '885', '823', '860', '844', '801', '858', '925', '832', '845', '856', '819', '834', '900',
                         '836', '920', '828', '863'},
                  5509: {'726', '260', '283', '274', '282', '236', '289', '292', '293', '218', '765', '785', '281',
                         '792', '291', '794', '261', '793', '233', '255', '296', '243', '287', '761', '257', '277',
                         '254', '759', '714', '787', '219', '760', '756', '242', '290', '275', '247'},
                  5510: {'818', '97', '41', '94', '107', '16', '875', '10', '18', '15', '5', '20', '44', '824', '2',
                         '48', '108', '104', '29', '13', '80', '820', '101', '72', '12', '11', '113', '26', '78', '27',
                         '110', '109', '883', '76', '33', '882', '60', '53', '52', '1', '807', '6'},
                  5516: {'888', '916', '333', '973', '335', '92', '924', '73', '66', '866', '877', '389', '87', '329',
                         '340', '864', '83', '917', '842', '781', '847', '77', '69', '79', '378', '876', '865', '870',
                         '940', '953', '941', '325', '364', '821', '65', '946', '789', '861', '892', '921', '878',
                         '974', '938', '912', '852', '971', '62', '326', '855', '913', '327', '923', '969', '945',
                         '328', '835', '7', '889', '841', '907', '337', '814', '808', '324', '330', '730', '809', '991',
                         '871', '840', '837', '911', '910', '873', '915', '345', '926', '937', '374', '895', '347',
                         '830', '853', '365', '0', '890', '859', '831', '24', '100', '724', '884', '338', '862', '839',
                         '382'},
                  6013: {'988', '954', '976', '970', '958', '956', '952', '975', '962', '989', '992', '999', '950',
                         '942', '930', '939', '990', '959', '997', '943', '965', '998', '964', '957', '944', '960',
                         '966', '994', '967', '951', '980', '955', '972', '936', '968'},
                  6014: {'751', '784', '776', '735', '750', '767', '774', '739', '748', '740', '743', '764', '745',
                         '771', '738', '770', '741', '725', '752', '749', '772', '780', '729', '768', '731', '753',
                         '769'},
                  6015: {'728', '773', '786', '762', '754', '715', '778', '742', '736', '791', '732', '747', '723',
                         '757', '721', '782', '717', '766', '797', '733', '716', '758', '763', '720', '777', '779',
                         '713', '775', '796', '795', '783', '790', '755'},
                  6016: {'280', '297', '295', '313', '240', '215', '273', '314', '284', '245', '299', '303', '253',
                         '226', '229', '217', '310', '301', '285', '232', '212', '213', '220', '223', '216', '228',
                         '308', '304', '214', '231', '251', '222'},
                  6002: {'307', '286', '268', '311', '264', '270', '221', '225', '156', '302', '263', '269', '279',
                         '266', '267'},
                  6481: {'996', '625', '667', '665', '987', '638', '696', '949', '632', '666', '683', '682', '609',
                         '985', '590', '657', '626', '986', '709', '680', '675', '623', '995', '589', '977', '660',
                         '630', '602', '627', '616', '593', '654', '645', '705', '604', '961', '668', '617', '691',
                         '948', '659', '963', '639', '681', '652', '981', '624', '601', '688', '929', '592', '649',
                         '671', '633', '935', '928', '984', '927', '640', '993', '979', '704', '978', '934', '643',
                         '932', '646', '983', '703', '636', '629', '664', '712', '933', '608', '708', '598', '663',
                         '642', '700', '931', '947', '982', '695', '685'},
                  6474: {'567', '562', '565', '552', '560', '550', '587', '186', '494', '569', '558', '543', '160',
                         '184', '182', '578', '581', '123', '568', '557', '178', '179', '551', '577', '561', '121',
                         '138', '564', '185', '554', '500', '117'},
                  6478: {'524', '349', '346', '211', '132', '545', '355', '383', '393', '527', '191', '385', '571',
                         '339', '353', '357', '531', '376', '323', '516', '319', '535', '512', '344', '556', '321',
                         '511', '373', '388', '507', '118', '525', '523', '207', '369', '119', '366', '320', '529',
                         '375', '546', '343', '370', '501', '367', '318', '391', '387', '315', '585', '356', '498',
                         '549', '493', '584', '544', '371', '390', '492', '359', '575', '386', '582', '520', '555',
                         '192', '334', '394', '372', '168', '384', '559', '510', '515', '379'},
                  6476: {'193', '166', '205', '180', '142', '208', '171', '203', '175', '128', '172', '198', '209',
                         '127', '147', '153', '200', '164', '176', '167', '144', '204', '206', '131', '137', '165',
                         '201', '158', '189', '134', '162', '202', '194', '169', '159', '177', '135', '136', '199',
                         '122'},
                  5712: {'449', '404', '415', '413', '412', '451', '407', '439', '416', '411', '414', '430', '437',
                         '424', '457', '409', '405'},
                  5711: {'239', '259', '246', '250', '305', '249', '252', '244', '298', '262', '276', '230', '248',
                         '294', '224', '309', '288', '227', '241', '300', '265', '271', '256', '272'},
                  5716: {'484', '489', '418', '464', '444', '421', '440', '488', '422', '436', '406', '480', '470',
                         '490', '433', '420', '463', '472', '474', '460', '426', '477', '487', '450', '400', '410',
                         '401', '478', '423', '402', '446', '398', '429', '431', '435', '454', '456', '455', '486',
                         '466', '452', '399', '459', '461', '428', '483', '434', '427', '485', '469', '397', '425',
                         '475', '432', '438', '481', '403', '443', '396'}}]
    # Load the unweighted edge list and give every edge weight 1; dfs() reads
    # the 'weight' attribute when computing leaf volumes.
    edge_list_graph = nx.read_edgelist('graph.in', nodetype=str, data=False)
    G = nx.Graph()
    G.add_weighted_edges_from((u, v, 1) for u, v in edge_list_graph.edges())

    root = build(layers)
    dfs(root, G)

    # Ground-truth partitions at the first two levels of the hierarchy.
    for level in (1, 2):
        print(get_k_layer_ground_truth(root, level))

    # root.volume is the total volume of the tree's leaves. It equals vol(G)
    # only if every vertex of 'graph.in' appears as a leaf in `layers`.
    # NOTE(review): confirm.
    G_volume = root.volume
    print(G_volume, 'G.volume')

    print('SE', calc_SE(root, G_volume))
    print('Das', calc_Das(root))
    print('HME', calc_HME(root, G_volume))

    # Optional diagnostics (the balance factors require a binary tree):
    # print('depth', get_leaf_depth(root, depth=0))
    # print(calc_depth_balance_factor(get_leaf_depth(root, depth=0)))
    # NOTE(review): the ratio below divides a size factor by a volume sum;
    # get_internal_nodes_size_sum may be the intended denominator.
    # print(calc_size_balance_factor(root), get_internal_nodes_volume_sum(root),
    #       calc_size_balance_factor(root) / get_internal_nodes_volume_sum(root))
    # for k in range(len(layers)):
    #     print(get_k_layer_ground_truth(root, k))
    # with open('tmp_res.txt', 'w') as f:
    #     print_tree(root, 0, f)
