import fast_algorithm_for_2OC
import networkx as nx
import datetime
import random
import time
import numpy as np
import subprocess
import re
import os
import matplotlib.pyplot as plt
from HOC import HOC


def read_network(path):
    """Read a whitespace-separated edge list of integer node ids.

    Each non-blank line must contain exactly two integers ``u v``; blank lines
    are skipped.

    :param path: path of the edge-list file.
    :return: ``(nodes, edges)`` -- a list of distinct node ids and a list of
        distinct ``(u, v)`` tuples exactly as written in the file. Both are
        collected in sets, so their order is unspecified.
    """
    edges = set()
    nodes = set()
    # Context manager so the handle is closed even if a line fails to parse.
    with open(path, "r") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank / trailing empty lines
            a, b = line.split()
            u, v = int(a), int(b)
            edges.add((u, v))
            nodes.add(u)
            nodes.add(v)
    return list(nodes), list(edges)


def calc_delta(G: nx.Graph, cluster):
    """Score the relative cost reduction of splitting ``cluster`` in two.

    A trial split of ``G`` is computed with
    ``fast_algorithm_for_2OC.greedy_select_overlap_greedy``. Its two parts
    define the regions ``A`` (only in part 1), ``B`` (in both parts) and ``C``
    (only in part 2). ``H1`` is the cost of keeping the cluster whole (half
    its volume times ``len(G)``). ``H2`` is the cost of the split. The
    function returns ``(H1 - H2) / H1``.

    :param G: graph of the cluster. Callers in this file pass
        ``G.subgraph(cluster)``.
    :param cluster: the cluster's node collection.
    :return: relative cost reduction. Returns ``0.0`` when the cluster has no
        edge weight (``H1 == 0``), where the ratio would divide by zero.
    """
    # Volumes and the cross-region sums below both honour the 'weight'
    # attribute (missing weights count as 1), so H1/H2 use the same units.
    H1 = nx.volume(G, cluster, weight="weight") / 2 * len(G)
    if H1 == 0:
        # Edgeless (or empty) cluster: splitting gains nothing, and the
        # relative change would be 0/0.
        return 0.0

    part1, part2 = fast_algorithm_for_2OC.greedy_select_overlap_greedy(G)
    S1 = set(part1)
    S2 = set(part2)
    B = S1 & S2
    A = S1 - S2
    C = S2 - S1

    wA = nx.volume(G, A, weight="weight") / 2
    wB = nx.volume(G, B, weight="weight") / 2
    wC = nx.volume(G, C, weight="weight") / 2

    # Label each node with its region once, then sum the weight of every edge
    # whose endpoints lie in two different regions.
    region = {node: "A" for node in A}
    region.update((node, "B") for node in B)
    region.update((node, "C") for node in C)
    wAB = wBC = wAC = 0
    for u, v, w in G.edges(data="weight", default=1):
        pair = {region.get(u), region.get(v)}
        if pair == {"A", "B"}:
            wAB += w
        elif pair == {"A", "C"}:
            wAC += w
        elif pair == {"B", "C"}:
            wBC += w

    H2 = ((wA + wAB) * (len(A) + len(B))
          + (wC + wBC) * (len(B) + len(C))
          + wAC * len(cluster)
          + 1 / 2 * wB * (len(cluster) + len(B)))
    return (H1 - H2) / H1


def new_id(start_val: int = 1):
    """
    Generate consecutive node IDs.

    :param start_val: the first ID produced.
    :return: an endless generator yielding start_val, start_val + 1, ...
    """
    offset = 0
    while True:
        yield start_val + offset
        offset += 1


# Module-level placeholder. find_k_clusters binds its own local
# `id_generator` (a fresh new_id() generator), so it does not read this value.
id_generator = 1


def find_k_clusters(G: nx.Graph, k, batch_size=0.02, path="data/ex.txt", overlap=True):
    """Build a top-down hierarchy with ``k`` leaf clusters.

    Starting from the whole graph, split it into two parts with
    ``fast_algorithm_for_2OC.greedy_select_overlap_greedy``. Then repeatedly
    split the leaf with the highest ``calc_delta`` score until there are
    ``k`` leaves. The leaf clusters are written to *path*, one cluster per
    line, each node id followed by a space.

    :param G: input graph.
    :param k: number of leaf clusters. ``k <= 1`` yields a single leaf
        holding all nodes of ``G``.
    :param batch_size: forwarded to ``greedy_select_overlap_greedy``.
    :param path: output file for the leaf clusters.
    :param overlap: forwarded to ``greedy_select_overlap_greedy``.
    :return: ``(result, dag_edges)``. ``result`` is a list of
        ``(cluster_nodes, name)`` leaves. ``dag_edges`` maps every
        ``"internal_<id>"`` name to its children: child names for split
        nodes, graph nodes for leaves.
    """
    dag_edges = dict()
    id_generator = new_id()

    name_root = "internal_" + str(next(id_generator))
    if k > 1:
        dag_edges[name_root] = set()
        C1, C2 = fast_algorithm_for_2OC.greedy_select_overlap_greedy(G, batch_size, overlap)
        name_C1 = "internal_" + str(next(id_generator))
        dag_edges[name_C1] = set(C1)
        dag_edges[name_root].add(name_C1)
        name_C2 = "internal_" + str(next(id_generator))
        dag_edges[name_C2] = set(C2)
        dag_edges[name_root].add(name_C2)
        result = [(C1, name_C1), (C2, name_C2)]

        # Split score of each leaf, kept index-aligned with `result`. A dict
        # keyed by tuple(cluster) would conflate leaves with identical node
        # lists and could turn the wrong leaf into an internal node.
        # NOTE(review): calc_delta runs its trial split with the default
        # batch_size/overlap, not the values passed here -- confirm intended.
        deltas = [calc_delta(G.subgraph(cluster), cluster) for cluster, _ in result]

        # Each iteration replaces one leaf by two, so k - 2 more splits.
        for _ in range(k - 2):
            # Highest score wins; ties go to the earliest leaf.
            id_max = max(range(len(result)), key=deltas.__getitem__)
            cluster_max, name_max = result.pop(id_max)
            deltas.pop(id_max)
            new_G: nx.Graph = G.subgraph(cluster_max)
            # The chosen leaf becomes an internal node: its node set is
            # replaced by the names of its two children.
            dag_edges[name_max] = set()

            new_C1, new_C2 = fast_algorithm_for_2OC.greedy_select_overlap_greedy(new_G, batch_size, overlap)
            for part in (new_C1, new_C2):
                name_part = "internal_" + str(next(id_generator))
                dag_edges[name_part] = set(part)
                dag_edges[name_max].add(name_part)
                result.append((part, name_part))
                deltas.append(calc_delta(G.subgraph(part), part))
    else:
        result = [(G.nodes(), name_root)]
        dag_edges[name_root] = set(G.nodes())

    with open(path, "w") as f:
        for cluster, _ in result:
            f.write("".join(f"{node} " for node in cluster))
            f.write("\n")
    return result, dag_edges


def _add_random_edges(G, nodes, p, sparse):
    """Add random edges between members of *nodes* to *G*, each pair with probability ~*p*."""
    num_nodes = len(nodes)
    if sparse:
        # Approximation: draw int(p * n^2 / 2) endpoint pairs with
        # replacement. Self-pairs are skipped and repeated pairs collapse.
        for _ in range(int(p * num_nodes ** 2 / 2)):
            u = nodes[np.random.randint(low=num_nodes)]
            v = nodes[np.random.randint(low=num_nodes)]
            if u != v:
                G.add_edge(u, v)
    else:
        # Exact per-pair Bernoulli trials over the unordered pairs (j < i).
        for i, u in enumerate(nodes):
            for v in nodes[:i]:
                if np.random.rand() < p:
                    G.add_edge(u, v)


def generate_overlapping_block_model(OC, probs, edges_path="data/edges.txt", gt_path="data/gt.txt", sparse=False):
    """Sample an overlapping block-model graph and write it to disk.

    :param OC: list of clusters, each an iterable of node ids. Clusters may
        share nodes.
    :param probs: dict mapping a cluster bitmask to an edge probability. Bit
        ``i`` set selects ``OC[i]``, and that entry's edges are drawn among the
        intersection of the selected clusters. Key ``0`` (empty mask) draws
        background edges among all nodes.
    :param edges_path: output file for the edge list, one ``u v`` per line.
    :param gt_path: output file for the ground-truth clusters, one
        space-separated cluster per line.
    :param sparse: if true, sample about ``int(p * n**2 / 2)`` random endpoint
        pairs instead of testing every pair.
    :return: the generated ``nx.Graph``.
    """
    G = nx.Graph()
    for C in OC:
        G.add_nodes_from(C)

    # Background edges first (as before), then every cluster-mask entry in
    # dict order.
    if 0 in probs:
        _add_random_edges(G, list(G.nodes()), probs[0], sparse)
    for state, p in probs.items():
        if state == 0:
            continue
        members = set(G.nodes())
        for idx, cluster in enumerate(OC):
            if (state >> idx) & 1:
                members &= set(cluster)
        _add_random_edges(G, list(members), p, sparse)

    with open(edges_path, "w") as f:
        for u, v in G.edges(data=False):
            f.write(f"{u} {v}\n")
    with open(gt_path, "w") as f:
        for c in OC:
            f.write(" ".join(str(item) for item in c))
            f.write("\n")
    return G


if __name__ == "__main__":
    repeat_times = 5
    # cmp_times = []
    # cmp_costs = []
    # cmp_nmis = []
    # num_nodes_range = list(range(500, 5500, 500))
    # for num_nodes in num_nodes_range:
    #     k = 4
    #     num_in = int(0.9 * num_nodes)
    #     data = list(range(0, num_in))
    #     sizes = []
    #     remaining = num_in
    #     for _ in range(k - 1):
    #         size = random.randint(num_in // k - num_in // 100, num_in // k + num_in // 100)
    #         sizes.append(size)
    #         remaining -= size
    #     sizes.append(remaining)
    #     OC = []
    #     start_index = 0
    #     for size in sizes:
    #         OC.append(data[start_index:start_index + size])
    #         start_index += size
    #     dic = dict()
    #     tmp = 0
    #     for i in range(k):
    #         for j in range(i + 1, k):
    #             dic[tmp] = (i, j)
    #             tmp += 1
    #     for node in range(int(0.9 * num_nodes), num_nodes):
    #         a, b = dic[node % tmp]
    #         OC[a].append(node)
    #         OC[b].append(node)
    #     for p_in, p_out in [(5e-1, 1e-3)]:
    #         print("============")
    #         print("num_nodes:{}, p_in:{}, p_out:{}".format(num_nodes, p_in, p_out))
    #         num_edges = []
    #         das_times = []
    #         das_costs = []
    #         das_nmis = []
    #         hoc_times = []
    #         hoc_costs = []
    #         hoc_nmis = []
    #         probs = {0: p_out}
    #         for i in range(k):
    #             probs[1 << i] = p_in
    #         for _ in range(repeat_times):
    #             # print(G)
    #             G = generate_overlapping_block_model(OC, probs)
    #             num_edges.append(len(G.edges()))
    #             now = time.time()
    #             result, dag_edges = find_k_clusters(G, k, overlap=False)
    #             das_times.append(time.time() - now)
    #             # print(das_times)
    #             dag = HOC()
    #             for u in dag_edges.keys():
    #                 for v in dag_edges[u]:
    #                     dag.add_edge(u, v)
    #             das_costs.append(dag.HOC_cost([(u, v, 1) for (u, v) in G.edges()]))
    #             nmi = subprocess.check_output(
    #                 './Overlapping-NMI-master/onmi.exe {} {}'.format("data/ex.txt", "data/gt.txt"),
    #                 shell=False)
    #             matches = re.findall(r'\d*\.\d+|\d+', nmi.decode('utf-8'))
    #             das_nmis.append(float(matches[-1]))
    #             now = time.time()
    #             result, dag_edges = find_k_clusters(G, k, overlap=True)
    #             hoc_times.append(time.time() - now)
    #             dag = HOC()
    #             for u in dag_edges.keys():
    #                 for v in dag_edges[u]:
    #                     dag.add_edge(u, v)
    #             hoc_costs.append(dag.HOC_cost([(u, v, 1) for (u, v) in G.edges()]))
    #             nmi = subprocess.check_output(
    #                 './Overlapping-NMI-master/onmi.exe {} {}'.format("data/ex.txt", "data/gt.txt"),
    #                 shell=False)
    #             matches = re.findall(r'\d*\.\d+|\d+', nmi.decode('utf-8'))
    #             hoc_nmis.append(float(matches[-1]))
    #         print("num_edges:{}".format(np.average(num_edges)))
    #         print("Dasgupta time:{}, cost:{}, nmi:{}".format(
    #             np.average(das_times),
    #             np.average(das_costs),
    #             np.average(das_nmis))
    #         )
    #         print("HOC time:{}, cost:{}, nmi:{}".format(
    #             np.average(hoc_times),
    #             np.average(hoc_costs),
    #             np.average(hoc_nmis))
    #         )
    #         cmp_times.append((das_times, hoc_times))
    #         cmp_costs.append((das_costs, hoc_costs))
    #         cmp_nmis.append((das_nmis, hoc_nmis))
    # plt.figure(figsize=(18, 6))
    #
    # plt.subplot(1, 3, 1)
    # data = np.array([das for (das, hoc) in cmp_times])
    # plt.plot(num_nodes_range, np.mean(data, axis=1), marker='o', linestyle='-', color='b', label='Dasgupta')
    # plt.errorbar(num_nodes_range, np.mean(data, axis=1), yerr=np.std(data, axis=1),
    #              fmt='o', color='blue',
    #              ecolor='lightgray', elinewidth=3, capsize=5)
    # data = np.array([hoc for (das, hoc) in cmp_times])
    # plt.plot(num_nodes_range, np.mean(data, axis=1), marker='*', linestyle='-', color='green', label='HOC')
    # plt.errorbar(num_nodes_range, np.mean(data, axis=1), yerr=np.std(data, axis=1),
    #              fmt='*', color='green',
    #              ecolor='gray', elinewidth=3, capsize=5)
    # plt.xlabel('num of nodes')
    # plt.ylabel('times(s)')
    # plt.xticks(num_nodes_range)
    # plt.legend()
    #
    # plt.subplot(1, 3, 2)
    # data = np.array([[np.log(x) for x in das] for (das, hoc) in cmp_costs])
    # plt.plot(num_nodes_range, np.mean(data, axis=1), marker='o', linestyle='-', color='b', label='Dasgupta')
    # plt.errorbar(num_nodes_range, np.mean(data, axis=1), yerr=np.std(data, axis=1),
    #              fmt='o', color='blue',
    #              ecolor='lightgray', elinewidth=3, capsize=5)
    # data = np.array([[np.log(x) for x in hoc] for (das, hoc) in cmp_costs])
    # plt.plot(num_nodes_range, np.mean(data, axis=1), marker='*', linestyle='-', color='green', label='HOC')
    # plt.errorbar(num_nodes_range, np.mean(data, axis=1), yerr=np.std(data, axis=1),
    #              fmt='*', color='green',
    #              ecolor='gray', elinewidth=3, capsize=5)
    # plt.xlabel('num of nodes')
    # plt.ylabel('log(HOC cost)')
    # plt.xticks(num_nodes_range)
    # plt.legend()
    #
    # plt.subplot(1, 3, 3)
    # data = np.array([das for (das, hoc) in cmp_nmis])
    # plt.plot(num_nodes_range, np.mean(data, axis=1), marker='o', linestyle='-', color='b', label='Dasgupta')
    # plt.errorbar(num_nodes_range, np.mean(data, axis=1), yerr=np.std(data, axis=1),
    #              fmt='o', color='blue',
    #              ecolor='lightgray', elinewidth=3, capsize=5)
    # data = np.array([hoc for (das, hoc) in cmp_nmis])
    # plt.plot(num_nodes_range, np.mean(data, axis=1), marker='*', linestyle='-', color='green', label='HOC')
    # plt.errorbar(num_nodes_range, np.mean(data, axis=1), yerr=np.std(data, axis=1),
    #              fmt='*', color='green',
    #              ecolor='gray', elinewidth=3, capsize=5)
    # plt.xlabel('num of nodes')
    # plt.ylabel('nmi')
    # plt.xticks(num_nodes_range)
    # plt.legend()

    cmp_times = []
    cmp_costs = []
    cmp_nmis = []
    num_nodes_range = list(range(10000, 110000, 10000))
    for num_nodes in num_nodes_range:
        k = 4
        num_in = int(0.9 * num_nodes)
        data = list(range(0, num_in))
        sizes = []
        remaining = num_in
        for _ in range(k - 1):
            size = random.randint(num_in // k - num_in // 100, num_in // k + num_in // 100)
            sizes.append(size)
            remaining -= size
        sizes.append(remaining)
        OC = []
        start_index = 0
        for size in sizes:
            OC.append(data[start_index:start_index + size])
            start_index += size
        dic = dict()
        tmp = 0
        for i in range(k):
            for j in range(i + 1, k):
                dic[tmp] = (i, j)
                tmp += 1
        for node in range(int(0.9 * num_nodes), num_nodes):
            a, b = dic[node % tmp]
            OC[a].append(node)
            OC[b].append(node)
        for p_in, p_out in [(2e-3, 2e-4)]:
            print("============")
            print("num_nodes:{}, p_in:{}, p_out:{}".format(num_nodes, p_in, p_out))
            num_edges = []
            das_times = []
            das_costs = []
            das_nmis = []
            hoc_times = []
            hoc_costs = []
            hoc_nmis = []
            for _ in range(repeat_times):
                probs = {0: p_out}
                for i in range(k):
                    probs[1 << i] = p_in
                G = generate_overlapping_block_model(OC, probs, sparse=True)
                # print(G)
                num_edges.append(len(G.edges()))
                now = time.time()
                result, dag_edges = find_k_clusters(G, k, overlap=False)
                das_times.append(time.time() - now)
                dag = HOC()
                for u in dag_edges.keys():
                    for v in dag_edges[u]:
                        dag.add_edge(u, v)
                das_costs.append(dag.HOC_cost([(u, v, 1) for (u, v) in G.edges()]))
                nmi = subprocess.check_output(
                    './Overlapping-NMI-master/onmi.exe {} {}'.format("data/ex.txt", "data/gt.txt"),
                    shell=False)
                matches = re.findall(r'\d*\.\d+|\d+', nmi.decode('utf-8'))
                das_nmis.append(float(matches[-1]))
                now = time.time()
                result, dag_edges = find_k_clusters(G, k, overlap=True)
                hoc_times.append(time.time() - now)
                dag = HOC()
                for u in dag_edges.keys():
                    for v in dag_edges[u]:
                        dag.add_edge(u, v)
                hoc_costs.append(dag.HOC_cost([(u, v, 1) for (u, v) in G.edges()]))
                nmi = subprocess.check_output(
                    './Overlapping-NMI-master/onmi.exe {} {}'.format("data/ex.txt", "data/gt.txt"),
                    shell=False)
                matches = re.findall(r'\d*\.\d+|\d+', nmi.decode('utf-8'))
                hoc_nmis.append(float(matches[-1]))
            print("num_edges:{}".format(np.average(num_edges)))
            print("Dasgupta time:{}, cost:{}, nmi:{}".format(
                np.average(das_times),
                np.average(das_costs),
                np.average(das_nmis))
            )
            print("HOC time:{}, cost:{}, nmi:{}".format(
                np.average(hoc_times),
                np.average(hoc_costs),
                np.average(hoc_nmis))
            )
            cmp_times.append((das_times, hoc_times))
            cmp_costs.append((das_costs, hoc_costs))
            cmp_nmis.append((das_nmis, hoc_nmis))

    num_nodes_range = [x // 10 ** 4 for x in num_nodes_range]
    plt.subplot(1, 3, 1)
    data = np.array([das for (das, hoc) in cmp_times])
    plt.plot(num_nodes_range, np.mean(data, axis=1), marker='o', linestyle='-', color='b', label='Dasgupta')
    plt.errorbar(num_nodes_range, np.mean(data, axis=1), yerr=np.std(data, axis=1),
                 fmt='o', color='blue',
                 ecolor='lightgray', elinewidth=3, capsize=5)
    data = np.array([hoc for (das, hoc) in cmp_times])
    plt.plot(num_nodes_range, np.mean(data, axis=1), marker='*', linestyle='-', color='green', label='HOC')
    plt.errorbar(num_nodes_range, np.mean(data, axis=1), yerr=np.std(data, axis=1),
                 fmt='*', color='green',
                 ecolor='gray', elinewidth=3, capsize=5)
    plt.xlabel('num of nodes($10^4$)')
    plt.ylabel('times(s)')
    plt.xticks(num_nodes_range)
    plt.legend()

    plt.subplot(1, 3, 2)
    data = np.array([[np.log(x) for x in das] for (das, hoc) in cmp_costs])
    plt.plot(num_nodes_range, np.mean(data, axis=1), marker='o', linestyle='-', color='b', label='Dasgupta')
    plt.errorbar(num_nodes_range, np.mean(data, axis=1), yerr=np.std(data, axis=1),
                 fmt='o', color='blue',
                 ecolor='lightgray', elinewidth=3, capsize=5)
    data = np.array([[np.log(x) for x in hoc] for (das, hoc) in cmp_costs])
    plt.plot(num_nodes_range, np.mean(data, axis=1), marker='*', linestyle='-', color='green', label='HOC')
    plt.errorbar(num_nodes_range, np.mean(data, axis=1), yerr=np.std(data, axis=1),
                 fmt='*', color='green',
                 ecolor='gray', elinewidth=3, capsize=5)
    plt.xlabel('num of nodes($10^4$)')
    plt.ylabel('$log_{10}(cost)$')
    plt.xticks(num_nodes_range)
    plt.legend()

    plt.subplot(1, 3, 3)
    data = np.array([das for (das, hoc) in cmp_nmis])
    plt.plot(num_nodes_range, np.mean(data, axis=1), marker='o', linestyle='-', color='b', label='Dasgupta')
    plt.errorbar(num_nodes_range, np.mean(data, axis=1), yerr=np.std(data, axis=1),
                 fmt='o', color='blue',
                 ecolor='lightgray', elinewidth=3, capsize=5)
    data = np.array([hoc for (das, hoc) in cmp_nmis])
    plt.plot(num_nodes_range, np.mean(data, axis=1), marker='*', linestyle='-', color='green', label='HOC')
    plt.errorbar(num_nodes_range, np.mean(data, axis=1), yerr=np.std(data, axis=1),
                 fmt='*', color='green',
                 ecolor='gray', elinewidth=3, capsize=5)
    plt.xlabel('num of nodes($10^4$)')
    plt.ylabel('nmi')
    plt.xticks(num_nodes_range)
    plt.legend()

    plt.show()

    # num_in = 5000
    # k = 4
    # data = list(range(0, num_in))
    # sizes = []
    # remaining = num_in
    # for _ in range(k - 1):
    #     size = random.randint(num_in // k - num_in // 100, num_in // k + num_in // 100)
    #     sizes.append(size)
    #     remaining -= size
    # sizes.append(remaining)
    # OC = []
    # start_index = 0
    # for size in sizes:
    #     OC.append(data[start_index:start_index + size])
    #     start_index += size
    # dic = dict()
    # tmp = 0
    # for i in range(k):
    #     for j in range(i + 1, k):
    #         dic[tmp] = (i, j)
    #         tmp += 1
    # for node in range(20000, 20000 + num_in // 10):
    #     a, b = dic[node % tmp]
    #     OC[a].append(node)
    #     OC[b].append(node)
    # p_in, p_out = (6e-1, 1e-4)
    # probs = {0: p_out}
    # for i in range(k):
    #     probs[1 << i] = p_in
    # G = generate_overlapping_block_model(OC, probs)
    # k_values = range(1, 20)
    # res = [len(G.nodes()) * len(G.edges())]
    # times = [-1]
    # for k in k_values:
    #     if k > 1:
    #         print(k)
    #         now = time.time()
    #         result, dag_edges = find_k_clusters(G, k, overlap=True)
    #         times.append(time.time() - now)
    #         print(time.time() - now)
    #         dag = HOC()
    #         for u in dag_edges.keys():
    #             for v in dag_edges[u]:
    #                 dag.add_edge(u, v)
    #         cost = dag.HOC_cost([(u, v, 1) for (u, v) in G.edges()])
    #         res.append(cost)
    # plt.figure(figsize=(12, 6))
    # plt.subplot(1, 2, 1)
    # plt.plot(k_values, res, marker='o', linestyle='-', color='b')
    # plt.title('Plot of HOC cost')
    # plt.xlabel('k')
    # plt.ylabel('HOC')
    # plt.xticks(k_values)  # 设置 x 轴刻度
    #
    # plt.subplot(1, 2, 2)
    # plt.plot(k_values[1:], times[1:], marker='o', linestyle='-', color='r')
    # plt.title('Execution Time of HOC')
    # plt.xlabel('k')
    # plt.ylabel('Time(s)')
    # plt.xticks(k_values)
    # plt.grid()
    # plt.grid()
    # plt.show()
