# Copyright (c) 2022 Tianyu Wen
# Licensed under the MIT License.

from collections import Counter
from utils import *
import copy


class WL_test:
    """Weisfeiler-Lehman-style label refinement with a shared label registry.

    Every aggregated signature seen so far is stored in ``unhashed_list`` and
    its integer replacement in ``hashed_list`` (same index).  The registry is
    kept on the instance so that several graphs refined by the same instance
    receive comparable labels.
    """

    def __init__(self):
        # Most recently issued integer label; the first one issued is 1.
        self.cur_max = 0
        # Parallel lists: unhashed_list[k] is a signature and hashed_list[k]
        # is the integer label assigned to it.
        self.unhashed_list = []
        self.hashed_list = []

    def _hash_label(self, aggr_label):
        """Return the integer label for *aggr_label*, registering it if new."""
        # Signatures may contain Counter objects, which are unhashable, so the
        # registry is searched by equality; one scan covers both the
        # membership test and the lookup.
        try:
            position = self.unhashed_list.index(aggr_label)
        except ValueError:
            self.cur_max += 1
            self.unhashed_list.append(aggr_label)
            self.hashed_list.append(self.cur_max)
            return self.cur_max
        return self.hashed_list[position]

    def iter(self, adj, L, num_iter):
        """Refine the node labels ``L`` of one graph for ``num_iter`` rounds.

        Args:
            adj: adjacency structure exposing ``shape[0]`` (node count) and
                ``adj[i][j]`` indexing; a truthy entry is treated as an edge.
                Diagonal entries are ignored.
            L: initial node labels, indexable by node id.  It is deep-copied
                and never modified.
            num_iter: number of refinement rounds.

        Returns:
            The copy of ``L`` with every entry replaced by its refined label.
            An empty ``L`` is returned as an empty copy.
        """
        labels = copy.deepcopy(L)
        # Seed the registry with the bare label of the first node.  Only
        # node 0 is seeded.
        # NOTE(review): this looks like it assumes uniform initial labels --
        # confirm with callers.
        # Guarded so that an empty graph no longer raises IndexError.
        if len(labels) > 0:
            self._hash_label((labels[0],))
        num_nodes = adj.shape[0]
        for _ in range(num_iter):
            new_labels = []
            for node in range(num_nodes):
                neighbour_labels = [
                    labels[other_node]
                    for other_node in range(num_nodes)
                    if other_node != node and adj[node][other_node]
                ]
                if neighbour_labels:
                    # Own label plus the multiset of neighbour labels.
                    aggr_label = (labels[node], Counter(neighbour_labels))
                else:
                    # Isolated node: the signature is the own label alone.
                    aggr_label = (labels[node],)
                new_labels.append(self._hash_label(aggr_label))

            # Write back only after the whole round, so every node in a round
            # is computed from the labels of the previous round.
            for j in range(num_nodes):
                labels[j] = new_labels[j]

        return labels

    def reset_hash(self):
        """Forget every registered signature and restart numbering at 1."""
        self.cur_max = 0
        self.unhashed_list = []
        self.hashed_list = []

    def iso_test2(self, adj1, adj2, L1, L2, num_iter):
        """Compare two graphs by the multisets of their refined labels.

        The registry is reset first, then both graphs are refined against the
        same fresh registry.

        Returns:
            True when both graphs end up with the same multiset of labels,
            False otherwise.
        """
        self.reset_hash()
        new_L1 = self.iter(adj1, L1, num_iter)
        new_L2 = self.iter(adj2, L2, num_iter)
        return Counter(new_L1) == Counter(new_L2)


class GT_test:
    """Label refinement in which every node aggregates over all nodes.

    Each node's signature is the multiset of ``(label, encoding)`` pairs taken
    over every node of the graph (itself included), where the encoding of a
    node pair is selected by ``RSE``:

    * ``'SPD'``  -- the ``get_SPD(adj)`` entry for the pair (presumably the
      shortest-path distance; confirm against ``utils``).
    * ``'SPIS'`` -- that entry together with the two matching entries returned
      by ``get_SPD_subgraph_V_E(adj, SPD)``.

    Signatures are stored in ``unhashed_list`` and their integer replacements
    in ``hashed_list`` (same index), shared by all graphs refined with this
    instance.
    """

    def __init__(self, RSE):
        # Most recently issued integer label; the first one issued is 2.
        self.cur_max = 1
        # Parallel lists: unhashed_list[k] is a signature and hashed_list[k]
        # is the integer label assigned to it.
        self.unhashed_list = []
        self.hashed_list = []
        # Encoding selector, checked on every call to iter().
        self.RSE = RSE

    def _hash_label(self, aggr_label):
        """Return the integer label for *aggr_label*, registering it if new."""
        # Signatures are Counter objects (unhashable), so the registry is
        # searched by equality; one scan covers membership test and lookup.
        try:
            position = self.unhashed_list.index(aggr_label)
        except ValueError:
            self.cur_max += 1
            self.unhashed_list.append(aggr_label)
            self.hashed_list.append(self.cur_max)
            return self.cur_max
        return self.hashed_list[position]

    def iter(self, adj, L, num_iter):
        """Refine the node labels ``L`` of one graph for ``num_iter`` rounds.

        Args:
            adj: adjacency structure exposing ``shape[0]`` (node count); it is
                passed on to ``get_SPD`` and ``get_SPD_subgraph_V_E``.
            L: initial node labels, indexable by node id.  It is deep-copied
                and never modified.
            num_iter: number of refinement rounds.

        Returns:
            The copy of ``L`` with every entry replaced by its refined label.

        Raises:
            NotImplementedError: if ``self.RSE`` is neither ``'SPD'`` nor
                ``'SPIS'``.  Raised before any other work is done.
        """
        if self.RSE not in ('SPD', 'SPIS'):
            # ``raise NotImplemented`` would itself fail with a TypeError
            # because NotImplemented is a constant, not an exception class.
            raise NotImplementedError(
                "unsupported RSE {!r}; expected 'SPD' or 'SPIS'".format(self.RSE)
            )

        labels = copy.deepcopy(L)
        num_nodes = adj.shape[0]
        nodes = range(num_nodes)
        SPD = get_SPD(adj)

        # The pairwise encodings do not depend on the labels, so build them
        # once instead of once per refinement round.
        if self.RSE == 'SPIS':
            SPD_subgraph_V_num, SPD_subgraph_E_num = get_SPD_subgraph_V_E(adj, SPD)
            encodings = [
                [
                    (SPD[u][v], SPD_subgraph_V_num[u][v], SPD_subgraph_E_num[u][v])
                    for v in nodes
                ]
                for u in nodes
            ]
        else:
            encodings = [[SPD[u][v] for v in nodes] for u in nodes]

        for _ in range(num_iter):
            new_labels = []
            for node in nodes:
                aggr_label = Counter(
                    (labels[other_node], encodings[node][other_node])
                    for other_node in nodes
                )
                new_labels.append(self._hash_label(aggr_label))

            # Write back only after the whole round, so every node in a round
            # is computed from the labels of the previous round.
            for j in nodes:
                labels[j] = new_labels[j]

        return labels

    def reset_hash(self):
        """Forget every registered signature and restart numbering at 2."""
        self.cur_max = 1
        self.unhashed_list = []
        self.hashed_list = []

    def iso_test2(self, adj1, adj2, L1, L2, num_iter):
        """Compare two graphs by the multisets of their refined labels.

        The registry is reset first, then both graphs are refined against the
        same fresh registry.

        Returns:
            True when both graphs end up with the same multiset of labels,
            False otherwise.
        """
        self.reset_hash()
        new_L1 = self.iter(adj1, L1, num_iter)
        new_L2 = self.iter(adj2, L2, num_iter)
        return Counter(new_L1) == Counter(new_L2)
