import fast_algorithm_for_2OC
import networkx as nx
import datetime
import random
import time
import numpy as np
from scipy.spatial.distance import pdist, squareform
from sklearn.decomposition import PCA
from sklearn.datasets import fetch_openml
import subprocess
import re
import os


def read_network(path):
    """Read a whitespace-separated edge list and return its nodes and edges.

    Every non-blank line must contain exactly two integer node IDs, ``"u v"``.
    Blank lines (for example a trailing empty line) are skipped.

    :param path: path of the edge-list file.
    :return: ``(nodes, edges)`` -- a list of the distinct int node IDs and a
        list of the distinct ``(u, v)`` int tuples, both in set order.
    """
    edges = set()
    nodes = set()
    # `with` guarantees the file handle is released (the original leaked it).
    with open(path, "r") as f:
        for line in f:
            if not line.strip():
                continue
            a, b = map(int, line.split())
            edges.add((a, b))
            nodes.update((a, b))
    return list(nodes), list(edges)


def calc_delta(G: nx.Graph, cluster):
    """Score the relative gain of splitting ``cluster`` into two overlapping parts.

    ``G`` is split with ``fast_algorithm_for_2OC.greedy_select_overlap_greedy``
    into node lists C1 and C2. These define three disjoint node sets:
    ``A = C1 - C2``, ``B = C1 & C2`` (the overlap) and ``C = C2 - C1``.
    ``H1`` is computed for ``cluster`` as a whole and ``H2`` for the A/B/C
    split. The score is ``(H1 - H2) / H1``.

    :param G: the graph to split. ``find_k_clusters`` always passes the
        subgraph induced by ``cluster``.
    :param cluster: node IDs of the cluster being scored.
    :return: ``(H1 - H2) / H1``. Returns ``0.0`` when ``H1`` is zero (e.g. the
        cluster has no edges) instead of dividing by zero.
    """
    # NOTE(review): nx.volume is called without `weight=`, so the volume terms
    # count unweighted degrees, while wAB/wAC/wBC below use the edge "weight"
    # attribute. Confirm this mix is intended for weighted graphs.
    H1 = nx.volume(G, cluster) / 2 * len(G)
    if H1 == 0:
        return 0.0

    new_C1, new_C2 = fast_algorithm_for_2OC.greedy_select_overlap_greedy(G)
    S1, S2 = set(new_C1), set(new_C2)
    A = S1 - S2
    B = S1 & S2
    C = S2 - S1

    wA = nx.volume(G, A) / 2
    wB = nx.volume(G, B) / 2
    wC = nx.volume(G, C) / 2

    # Total weight of edges running between each pair of parts. An edge without
    # a "weight" attribute counts as 1. A, B and C are disjoint, so each edge
    # lands in at most one bucket.
    wAB = wAC = wBC = 0
    for u, v, w in G.edges(data="weight", default=1):
        if (u in A and v in B) or (u in B and v in A):
            wAB += w
        elif (u in A and v in C) or (u in C and v in A):
            wAC += w
        elif (u in B and v in C) or (u in C and v in B):
            wBC += w

    n_A, n_B, n_C = len(A), len(B), len(C)
    H2 = ((wA + wAB) * (n_A + n_B)
          + (wC + wBC) * (n_B + n_C)
          + wAC * len(cluster)
          + 1 / 2 * wB * (len(cluster) + n_B))
    return (H1 - H2) / H1


def find_k_clusters(G: nx.Graph, k, path="data/ex.txt"):
    """Split ``G`` into ``k`` overlapping clusters and write them to ``path``.

    ``G`` is first split in two with
    ``fast_algorithm_for_2OC.greedy_select_overlap_greedy``. Each step then
    picks one cluster by its ``calc_delta`` score and replaces it with its own
    two-way split, until there are ``k`` clusters. Progress is printed to
    stdout.

    :param G: the graph to cluster.
    :param k: target number of clusters. Values ``<= 2`` stop after the first
        split.
    :param path: output file. Each line holds one cluster's nodes, each
        followed by a single space.
    """
    C1, C2 = fast_algorithm_for_2OC.greedy_select_overlap_greedy(G)
    result = [C1, C2]
    cnt_cluster = 2

    # Cache of split scores, keyed by the cluster's node tuple.
    dic = {}
    for idx, cluster in enumerate(result):
        print(idx)
        dic[tuple(cluster)] = calc_delta(G.subgraph(cluster), cluster)
    print("step2")
    while cnt_cluster < k:
        # First cluster with the largest strictly positive score. If no score
        # is positive, this falls back to result[0].
        delta_max = 0
        id_max = 0
        for idx, cluster in enumerate(result):
            delta = dic[tuple(cluster)]
            if delta_max < delta:
                delta_max = delta
                id_max = idx

        new_G: nx.Graph = G.subgraph(result.pop(id_max))
        new_C1, new_C2 = fast_algorithm_for_2OC.greedy_select_overlap_greedy(new_G)
        result.append(new_C1)
        result.append(new_C2)
        dic[tuple(new_C1)] = calc_delta(G.subgraph(new_C1), new_C1)
        dic[tuple(new_C2)] = calc_delta(G.subgraph(new_C2), new_C2)
        cnt_cluster += 1  # one cluster removed, two added
        print(cnt_cluster)
        for item in result:
            print(len(item))
    print(result)
    for item in result:
        print(len(item), item)

    # `with` closes the file even if a write fails (the original could leak it).
    with open(path, "w") as f:
        for c in result:
            f.write("".join(f"{node} " for node in c) + "\n")


def new_id(start_val: int = 1):
    """
    Hand out node IDs lazily via ``yield``.

    :param start_val: the first ID produced.
    :return: an endless generator of consecutive integers starting at ``start_val``.
    """
    value = start_val
    while True:
        yield value
        value = value + 1


# NOTE(review): this module-level name is not read anywhere in the visible code;
# create_k_clusters_3 binds its own local `id_generator = new_id()` instead.
id_generator = 1

def create_k_clusters_3(k: int, num_inner_node: list, near_overlap: list, p1=0.7, p2=0.05,
                        edges_path="data/edges.txt",
                        gt_path="data/gt.txt"):
    """Generate a random overlapping-community graph and write it to disk.

    Community ``i`` gets ``random.randint(*num_inner_node)`` exclusive nodes.
    Every community pair ``(i, j)`` then shares ``random.randint(*near_overlap)``
    extra nodes, which belong to both. Node IDs are consecutive integers
    (stored as strings) drawn from ``new_id()``.

    Each node pair inside a community is linked with probability ``p1``. Each
    pair ``(u, v)`` with ``u`` in community ``i``, ``v`` in community ``j > i``
    and ``u != v`` is linked with probability ``p2``.

    :param k: number of communities.
    :param num_inner_node: ``[lo, hi]`` inclusive range for the number of
        exclusive nodes per community.
    :param near_overlap: ``[lo, hi]`` inclusive range for the number of nodes
        shared by each community pair.
    :param p1: intra-community edge probability.
    :param p2: inter-community edge probability.
    :param edges_path: output file, one ``"u v"`` edge per line, sorted.
    :param gt_path: output file, one line of space-separated node IDs per
        community (ground truth).
    """
    import itertools

    id_generator = new_id()
    gt = []
    for _ in range(k):
        num = random.randint(num_inner_node[0], num_inner_node[1])
        gt.append([str(next(id_generator)) for _ in range(num)])
    for i in range(k):
        for j in range(i + 1, k):
            cnt = random.randint(near_overlap[0], near_overlap[1])
            for _ in range(cnt):
                node = str(next(id_generator))
                gt[i].append(node)
                gt[j].append(node)

    edges = set()
    # p1: pairs inside a community (same visiting order as nested index loops).
    for c in gt:
        for u, v in itertools.combinations(c, 2):
            if random.random() < p1:
                edges.add((u, v))
    # p2: pairs across two communities.
    for c1, c2 in itertools.combinations(gt, 2):
        for u in c1:
            for v in c2:
                if u == v:
                    continue
                if random.random() < p2:
                    edges.add((u, v))

    # A pair of shared nodes can be drawn as (u, v) by one loop and (v, u) by
    # another. Keep a single orientation per undirected edge so the file (and
    # edge counts derived from it) has no duplicates. This is done after all
    # random draws, so the random stream is unchanged.
    seen = set()
    unique_edges = []
    for u, v in sorted(edges):
        key = frozenset((u, v))
        if key in seen:
            continue
        seen.add(key)
        unique_edges.append((u, v))

    with open(edges_path, 'w') as f:
        for u, v in unique_edges:
            f.write(f"{u} {v}\n")
    with open(gt_path, "w") as f:
        for c in gt:
            f.write(' '.join(c))
            f.write('\n')


def create_graph_from_edge_weights(edge_weights, k):
    """Build a k-nearest-neighbour graph from a dense pairwise weight matrix.

    For every row ``i``, the ``k`` columns ``j != i`` with the largest
    ``edge_weights[i, j]`` are connected to ``i`` with that weight. Ties are
    broken in favour of the smaller column index. Edges are added in
    descending-weight order per row. An edge chosen from both endpoints keeps
    the weight written last (from the larger row index).

    :param edge_weights: square 2-D array-like of pairwise weights (indexable
        as ``edge_weights[i, j]``).
    :param k: number of neighbours kept per node.
    :return: an undirected ``nx.Graph`` on integer node IDs ``0..n-1``. Nodes
        that receive no edge are not added.
    """
    num_nodes = edge_weights.shape[0]
    weights = np.asarray(edge_weights)
    all_nodes = np.arange(num_nodes)
    G = nx.Graph()

    for i in range(num_nodes):
        candidates = all_nodes[all_nodes != i]
        # A stable ascending sort on the negated weights gives descending weight
        # with equal weights kept in ascending column order. This matches the
        # original sorted(..., reverse=True), which is stable, without building
        # n Python tuples per row. The float cast keeps negation safe for
        # bool/unsigned inputs.
        sort_key = -weights[i, candidates].astype(np.float64)
        order = np.argsort(sort_key, kind="stable")
        G.add_weighted_edges_from(
            (i, int(j), edge_weights[i, int(j)]) for j in candidates[order[:k]]
        )

    return G


def load_minst(download=True):
    """Load the MNIST (``mnist_784``) features and labels.

    With ``download`` true, the dataset is fetched via ``fetch_openml`` and
    cached to ``mnist_features.txt`` / ``mnist_labels.txt`` in the current
    directory. Otherwise those cache files are read back.

    :param download: fetch from OpenML (true) or read the local cache (false).
    :return: ``(X, y)``. ``X`` is a 2-D float array with one row per sample.
        ``y`` is a 1-D int array of labels. Both branches return int labels, so
        files written from ``y`` downstream look the same either way.
    """
    if download:
        mnist = fetch_openml('mnist_784')  # load the dataset
        X, y = mnist["data"], mnist["target"]
        X = np.array(X)
        y = np.array(y).astype(int)

        # Write the features to the cache file
        with open('mnist_features.txt', 'w') as f:
            for row in X:
                f.write(' '.join(map(str, row)) + '\n')

        # Write the labels to the cache file
        with open('mnist_labels.txt', 'w') as f:
            for label in y:
                f.write(str(label) + '\n')
    else:
        # ndmin keeps the documented ranks even for a single-sample cache.
        X = np.loadtxt('mnist_features.txt', ndmin=2)
        # loadtxt parses floats; convert to int to match the download branch.
        y = np.loadtxt('mnist_labels.txt', ndmin=1).astype(int)
    return X, y


def save_data(file_path, data):
    """Write each item of ``data`` as one line of space-separated values.

    Array items are flattened via ``.tolist()`` into their elements. Any scalar
    (numpy or plain Python, including bools and strings) is written as a single
    value.

    :param file_path: output path. The file is overwritten.
    :param data: iterable of numpy arrays/scalars or plain Python values.
    """
    with open(file_path, 'w') as file:
        for item in data:
            values = item.tolist() if hasattr(item, "tolist") else item
            # Scalars are wrapped so that e.g. np.bool_ does not crash join()
            # and a string is not split into characters.
            if not isinstance(values, (list, tuple)):
                values = [values]
            file.write(' '.join(map(str, values)))
            file.write('\n')


def save_graph(file_path, G: nx.Graph):
    """Write every edge of ``G`` as a ``"u v weight"`` line.

    Edges without a ``weight`` attribute are written with weight 1. The
    original computed this default but then read ``G[u][v]['weight']``
    anyway, raising ``KeyError``.

    :param file_path: output path. The file is overwritten.
    :param G: graph whose edges are written, in ``G.edges`` order.
    """
    with open(file_path, 'w') as file:
        for u, v, w in G.edges(data='weight', default=1):
            file.write('{} {} {}\n'.format(u, v, w))


if __name__ == "__main__":
    # Experiment driver: build a k-NN similarity graph over the MNIST digits in
    # `target_labels`, split it into `num_clusters` overlapping clusters with
    # find_k_clusters, and log the graph size and wall-clock run time.
    st_time = time.strftime('%Y-%m-%d %H-%M-%S %A')
    num_clusters = 4
    os.makedirs('log', exist_ok=True)
    log_path = "log/log-eff-k-{}-{}.txt".format(num_clusters, time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime()))
    # `with` ensures the log file is flushed and closed even if the run fails.
    with open(log_path, mode="w", encoding="utf-8") as fp:
        suffix = '17'
        os.makedirs('data', exist_ok=True)
        ex_path = "data/ex_{}.txt".format(suffix)

        data, label = load_minst(download=False)

        # Keep only the samples whose label is one of the target digits.
        target_labels = [1, 7, 3, 8]
        filtered_data = []
        filtered_label = []
        for vec, lb in zip(data, label):
            if int(lb) in target_labels:
                filtered_data.append(vec)
                filtered_label.append(lb)
        save_data('mnist_features_{}.txt'.format(suffix), filtered_data)
        save_data('mnist_labels_{}.txt'.format(suffix), filtered_label)

        # Gaussian similarity exp(-d^2 / (2 sigma^2)) on 20-dim PCA projections.
        pca = PCA(n_components=20)
        pca_data = pca.fit_transform(filtered_data)
        distance_matrix = squareform(pdist(pca_data))
        sigma = 2500.0
        edge_weights = np.exp(-distance_matrix ** 2 / (2 * sigma ** 2))

        # Connect every sample to its 100 most similar samples.
        G = create_graph_from_edge_weights(edge_weights, 100)
        save_graph('data/mnist_total_{}_edges.txt'.format(suffix), G)

        find_k_clusters(G, num_clusters, path=ex_path)
        end_time = time.strftime('%Y-%m-%d %H-%M-%S %A')
        print("node size:{}, number of edges:{}".format(len(G.nodes), len(G.edges)), file=fp)
        print("start time:{}, end time:{}".format(st_time, end_time), file=fp)
        # Elapsed time at one-second resolution, derived from the logged stamps.
        time_difference = datetime.datetime.strptime(end_time, '%Y-%m-%d %H-%M-%S %A') - \
                          datetime.datetime.strptime(st_time, '%Y-%m-%d %H-%M-%S %A')
        seconds_difference = time_difference.total_seconds()
        print("seconds:{}".format(seconds_difference), file=fp)
    # NOTE: `fp` is closed at this point. The commented-out OSBM experiment below
    # also writes to `fp`, so move it inside the `with` block if re-enabled.

    # osbm
    # random.seed(23)
    # np.random.seed(23)
    # num_inner_node = [20, 20]
    # near_overlap = [2, 2]
    # print("inner noode{}, near_overlap{}".format(num_inner_node, near_overlap), file=fp)
    # for p1 in np.arange(0.6, 0.6 + 0.01, 0.05):
    #     for p2 in np.arange(0.1, 0.1 + 0.01, 0.05):
    #         for id12 in range(2, 3):
    #             st_time = time.strftime('%Y-%m-%d %H-%M-%S %A')
    #             edges_path = "data/{}/edges_p1{:.2f}_p2{:.2f}_{}.txt".format("test", p1, p2, 5)
    #             gt_path = "data/{}/gt_p1{:.2f}_p2{:.2f}_{}.txt".format("test", p1, p2, 5)
    #             ex_path = "data/{}/ex_p1{:.2f}_p2{:.2f}_{}.txt".format("test", p1, p2, 5)
    #             # create_k_clusters_3(k=num_clusters, num_inner_node=[100, 2000], near_overlap=[0, 100], p1=p1, p2=p2,
    #             #                     edges_path=edges_path, gt_path=gt_path)
    #             np.random.seed(23)
    #             create_k_clusters_3(k=num_clusters, num_inner_node=num_inner_node, near_overlap=near_overlap, p1=p1,
    #                                 p2=p2,
    #                                 edges_path=edges_path, gt_path=gt_path)
    #             # exit(0)
    #             nodes, edges = read_network(edges_path)
    #             print("=================", file=fp)
    #             print(p1, p2, id12, len(nodes), len(edges), file=fp)
    #             # gt_path = "stanford_dataset/com-amazon.top5000.cmty.txt"
    #             # gt_path = "stanford_dataset/com-dblp.top5000.cmty.txt"
    #             # S = set()
    #             # f = open(gt_path, "r").readlines()
    #             # for item in f:
    #             #     for u in list(item.split()):
    #             #         S.add(int(u))
    #             # sub_nodes = []
    #             # sub_edges = []
    #             # for u in nodes:
    #             #     if u in S:
    #             #         sub_nodes.append(u)
    #             # for u, v in edges:
    #             #     if u in S and v in S:
    #             #         sub_edges.append((u, v))
    #             # print(len(sub_nodes), len(sub_edges))
    #             # exit(0)
    #             # nodes, edges = read_true_network("social_data/amazon.eg2")
    #
    #             G = nx.Graph()
    #             np.random.shuffle(edges)
    #             G.add_weighted_edges_from([(str(u), str(v), 1) for u, v in edges])
    #             # G.add_nodes_from(nodes)
    #
    #             find_k_clusters(G, num_clusters, path=ex_path)
    #             end_time = time.strftime('%Y-%m-%d %H-%M-%S %A')
    #             # print(len(sub_nodes), len(sub_edges))
    #             print("node size:{}, number of edges:{}".format(len(nodes), len(edges)), file=fp)
    #             print("start time:{}, end time:{}".format(st_time, end_time), file=fp)
    #             time_difference = datetime.datetime.strptime(end_time, '%Y-%m-%d %H-%M-%S %A') - \
    #                               datetime.datetime.strptime(st_time, '%Y-%m-%d %H-%M-%S %A')
    #             seconds_difference = time_difference.total_seconds()
    #             print("seconds:{}".format(seconds_difference), file=fp)
    #             # result = subprocess.check_output(
    #             #     './Overlapping-NMI-master/onmi.exe ./data/9/ex_p10.1_p20.1_0.txt ./data/9/gt_p10.1_p20.1_0.txt',
    #             #     shell=False)
    #             # print('./Overlapping-NMI-master/onmi.exe {} {}'.format(ex_path, gt_path))
    #             result = subprocess.check_output(
    #                 './Overlapping-NMI-master/onmi.exe {} {}'.format(ex_path, gt_path),
    #                 shell=False)
    #
    #             # result = subprocess.call('dir', shell=True)
    #             print("result: {}".format(result), file=fp)
    #             matches = re.findall(r'\d*\.\d+|\d+', result.decode('utf-8'))
    #             print(matches, file=fp)
