# Copyright (c) 2022 Tianyu Wen
# Licensed under the MIT License.

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np

# Sentinel "distance" for node pairs with no connecting path. It is far larger
# than any shortest-path length in graphs of realistic size, and the sum of two
# sentinels (formed during Floyd relaxation) is still an ordinary finite number.
inf = float(10 ** 8)


def get_edges(adj):
    """List the edges ``(i, j)`` with ``i < j`` whose adjacency entry equals 1.

    Only the strict upper triangle of ``adj`` is scanned (row-major order), so
    each undirected edge of a symmetric matrix appears exactly once. Entries
    with any value other than 1 are not treated as edges.
    """
    size = adj.shape[0]
    return [
        (row, col)
        for row in range(size)
        for col in range(row + 1, size)
        if adj[row, col] == 1
    ]


def get_SPD(adj):
    """Return the all-pairs shortest-path-distance matrix of an unweighted graph.

    Every nonzero off-diagonal entry ``adj[i][j]`` is treated as an edge of
    length 1 from ``i`` to ``j``. The diagonal is 0, and pairs with no path keep
    the module sentinel ``inf`` (cast to the result dtype). Distances are found
    with the Floyd-Warshall algorithm, each relaxation step vectorised over all
    ``(i, j)`` pairs.

    Args:
        adj: square ``(n, n)`` adjacency matrix (ndarray or nested sequence).

    Returns:
        ``(n, n)`` ndarray with the dtype of ``adj``; a bool ``adj`` is
        promoted to int so that distances can be stored.
    """
    adj = np.asarray(adj)
    if adj.dtype == np.bool_:
        # A bool array cannot hold path lengths or the ``inf`` sentinel.
        adj = adj.astype(int)

    D = np.full_like(adj, inf)
    D[adj != 0] = 1
    # Applied after the edges so that self-loops in ``adj`` never yield 1.
    np.fill_diagonal(D, 0)

    # Floyd algorithm: relax every pair through intermediate node v.
    # Row v and column v do not change during step v (D[v, v] == 0), so
    # updating all pairs at once equals the classic in-place double loop.
    for v in range(adj.shape[0]):
        np.minimum(D, D[:, v:v + 1] + D[v:v + 1, :], out=D)

    return D


def get_nx_graph(num_nodes, edges):
    """Build an undirected ``networkx.Graph`` on nodes ``0 .. num_nodes - 1``.

    Every node index is added first, so isolated nodes are kept. Only the
    first two entries of each item in ``edges`` are used, as the endpoints of
    an edge; any further entries are ignored.
    """
    graph = nx.Graph()
    graph.add_nodes_from(range(num_nodes))
    graph.add_edges_from((edge[0], edge[1]) for edge in edges)
    return graph


def get_all_nonisomorphic(num_nodes, data_dir='./data'):
    """Load every adjacency matrix stored in ``<data_dir>/graph<num_nodes>cA.txt``.

    The file is read as a sequence of records. Each record is:
    - one leading line, whose content is ignored;
    - ``num_nodes`` rows, each holding ``num_nodes`` single-digit characters;
    - one trailing line, also ignored.

    Reading stops when the leading line of a record is empty, which means EOF.

    Args:
        num_nodes: number of nodes per graph; selects the file and row width.
        data_dir: directory holding the data file (default ``./data``,
            relative to the current working directory).

    Returns:
        int ndarray of shape ``(num_graphs, num_nodes, num_nodes)``. When the
        file holds no graph, the shape is ``(0, num_nodes, num_nodes)``.
    """
    path = '{}/graph{}cA.txt'.format(data_dir, num_nodes)
    graphs = []
    # ``with`` guarantees the handle is closed, even if parsing fails.
    with open(path) as f:
        while f.readline():  # leading line of a record; '' means EOF
            graph = []
            for _ in range(num_nodes):
                row = f.readline()
                graph.append([int(row[col]) for col in range(num_nodes)])
            f.readline()  # trailing line after each matrix
            graphs.append(graph)
    # An explicit reshape keeps the result 3-D even when ``graphs`` is empty.
    return np.array(graphs, dtype=int).reshape(len(graphs), num_nodes, num_nodes)


def get_srgs_by_file(file_num, data_dir='./data'):
    """Load the adjacency matrices stored in ``<data_dir>/srg<file_num>.txt``.

    The file layout is as follows:
    - The first line is comma-separated integers. The first five are parsed as
      ``dim, degree, lamb, mu, num_graphs`` and printed; only ``dim`` is used
      afterwards.
    - Then comes a sequence of records. Each record is one leading line (its
      content is ignored) followed by ``dim`` rows of ``dim`` single-digit
      characters.

    Reading stops when a record's leading line is empty, which means EOF.

    Args:
        file_num: number embedded in the file name.
        data_dir: directory holding the data file (default ``./data``,
            relative to the current working directory).

    Returns:
        Tuple ``(dim, graphs)``. ``graphs`` is an int ndarray of shape
        ``(num_read, dim, dim)``; the shape is ``(0, dim, dim)`` when no
        record follows the first line.
    """
    path = '{}/srg{}.txt'.format(data_dir, file_num)
    graphs = []
    # ``with`` guarantees the handle is closed, even if parsing fails.
    with open(path) as f:
        info = f.readline().split(sep=',')
        # Explicit indexing keeps the original IndexError on a short first line.
        dim, degree, lamb, mu, num_graphs = [int(info[k]) for k in range(5)]
        print('dim={}, degree={}, lamb={}, mu={}, num_graphs={}'.format(dim, degree, lamb, mu, num_graphs))

        while f.readline():  # leading line of a record; '' means EOF
            graph = []
            for _ in range(dim):
                row = f.readline()
                graph.append([int(row[col]) for col in range(dim)])
            graphs.append(graph)
    # An explicit reshape keeps the result 3-D even when ``graphs`` is empty.
    return dim, np.array(graphs, dtype=int).reshape(len(graphs), dim, dim)


def get_SPD_subgraph_V_E(adj, SPD):
    """Count nodes and edges of the shortest-path subgraph for every node pair.

    For each ordered pair ``(source, target)`` with ``source != target``, the
    subgraph is induced by every ``node`` satisfying
    ``SPD[source][node] + SPD[node][target] == SPD[source][target]``, i.e. every
    node lying on some shortest path. Its edges are taken from the graph built
    from ``get_edges(adj)``.

    With ``SPD`` from ``get_SPD``, a pair that has no path (distance ``inf``)
    matches only ``source`` and ``target`` themselves.

    Args:
        adj: ``(n, n)`` adjacency matrix (numpy array).
        SPD: ``(n, n)`` shortest-path distance matrix for ``adj``, e.g. as
            returned by ``get_SPD``.

    Returns:
        Tuple ``(V_num, E_num)`` of ``(n, n)`` arrays:
        - ``V_num`` is an int array with a diagonal of 1.
        - ``E_num`` has the dtype of ``adj`` and a diagonal of 0.
    """
    num_nodes = adj.shape[0]
    SPD = np.asarray(SPD)
    G = get_nx_graph(num_nodes, get_edges(adj))

    SPD_subgraph_V_num = np.eye(num_nodes)
    SPD_subgraph_E_num = np.zeros_like(adj)

    for source in range(num_nodes):
        for target in range(num_nodes):
            if source == target:
                continue
            # Test every candidate node at once instead of one at a time.
            on_path = SPD[source, :] + SPD[:, target] == SPD[source, target]
            V_in_subgraph = np.flatnonzero(on_path).tolist()
            SPD_subgraph_V_num[source][target] = len(V_in_subgraph)
            SPD_subgraph_E_num[source][target] = G.subgraph(V_in_subgraph).size()

    # ``np.int`` (an alias of builtin ``int``) was removed in NumPy 1.24.
    return SPD_subgraph_V_num.astype(int), SPD_subgraph_E_num
