from B_layer import FLGnnConv
from DataSetLoader import karatecClub
from torch import nn
from torch_geometric.nn import GCN
import torch
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.utils.convert import to_networkx
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

# Run on the GPU when one is available; otherwise fall back to the CPU instead
# of crashing on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
class Model(nn.Module):
    """Linear input projection followed by three stacked FLGnnConv layers.

    The stack of three layers is applied ``iteration`` times in ``forward``,
    reusing the same weights on every pass.

    Args:
        in_channel: Size of the last dimension of the node-feature input.
        hidden: Width of the projected features and of every FLGnnConv layer
            (defaults to 128, the previously hard-coded value).
    """

    def __init__(self, in_channel: int, hidden: int = 128):
        super().__init__()
        self.lin = nn.Linear(in_features=in_channel, out_features=hidden)
        # All three FLGnnConv layers share identical hyper-parameters.
        conv_kwargs = dict(
            in_channels=hidden,
            out_channels=hidden,
            norm=False,
            order=0,
            num_mf=3,
            windows_size=5,
            stride_size=5,
            extract_ratio=1,
            cross=0.9,
        )
        # Keep the individual attributes f1/f2/f3 (and their construction
        # order) so state_dict keys and parameter initialisation are unchanged.
        self.f1 = FLGnnConv(**conv_kwargs)
        self.f2 = FLGnnConv(**conv_kwargs)
        self.f3 = FLGnnConv(**conv_kwargs)
        self.md = nn.ModuleList([self.f1, self.f2, self.f3])

    def forward(self, x, edge_index, iteration, verbose: bool = False):
        """Project ``x`` and apply the f1 -> f2 -> f3 stack ``iteration`` times.

        Args:
            x: Node features; the last dimension must equal ``in_channel``.
            edge_index: Graph connectivity, passed unchanged to every layer.
            iteration: Number of passes through the three-layer stack.
            verbose: If True, print the pass index before each pass
                (previously this debug print was unconditional).

        Returns:
            The node features after the final pass.
        """
        x = self.lin(x)
        for e in range(iteration):
            if verbose:
                print(e)
            for m in self.md:
                x = m(x, edge_index)
        return x

class GCCN(nn.Module):
    """Baseline: linear projection followed by a 2-layer PyG ``GCN``.

    The same ``GCN`` module (shared weights) is applied ``iteration`` times
    in ``forward``.

    Args:
        in_channel: Size of the last dimension of the node-feature input.
        hidden: Width of the projection and of every GCN layer (defaults to 4,
            the previously hard-coded value).
    """

    def __init__(self, in_channel: int, hidden: int = 4):
        super().__init__()
        self.lin = nn.Linear(in_features=in_channel, out_features=hidden)
        self.GCN = GCN(in_channels=hidden, out_channels=hidden, num_layers=2, hidden_channels=hidden, act="tanh")

    def forward(self, x, edge_index, iteration):
        """Project ``x`` and apply the shared GCN ``iteration`` times.

        Returns:
            The node features after the final GCN pass.
        """
        x = self.lin(x)
        for _ in range(iteration):
            x = self.GCN(x, edge_index)
        return x

if __name__ == '__main__':
    # Colour per node label. Assumes labels are in 0..3 (TODO confirm against
    # karatecClub); any other label would raise KeyError below.
    color = {0: "red", 1: "blue", 2: "green", 3: "purple"}
    with torch.no_grad():
        m = Model(in_channel=34).to(device)
        graph = karatecClub(32)
        # The model's parameters live on `device`, so its inputs must too;
        # otherwise the first Linear layer fails with a device mismatch.
        # `.to` is a no-op when the tensors are already there.
        x = graph.x.to(device)
        y = graph.y
        edge_index = graph.edge_index.to(device)
        x = m(x, edge_index, 200)

        # Project the learned node embeddings to 2-D with t-SNE for plotting.
        x_ = x.cpu().numpy().copy()
        y = y.cpu().numpy()
        edge_index_ = edge_index.cpu().numpy()
        tsne = TSNE(2)
        x_ = tsne.fit_transform(x_)

        # Draw each edge as a faint line between its endpoints' 2-D positions.
        for src, tar in zip(edge_index_[0], edge_index_[1]):
            x_x = [x_[src][0], x_[tar][0]]
            x_y = [x_[src][1], x_[tar][1]]
            plt.plot(x_x, x_y, color="black", alpha=0.1)

        # Draw the nodes, coloured by label.
        plt.scatter(x_[:, 0], x_[:, 1], c=[color[i] for i in y])
        plt.show()








