import os
import networkx as nx
import random
import torch_geometric.transforms as T
import numpy as np
from torch_geometric.datasets import Planetoid
from torch_geometric.datasets import Amazon
from torch_geometric.datasets import WikipediaNetwork
from torch_geometric.datasets import AttributedGraphDataset, LastFMAsia

# Dataset names (as accepted by load_data) that the driver loop at the bottom
# of this file generates sampled triples for.
Dataset = ["computers", "photo"]

def load_data(dataset_name):
    """Load a graph benchmark by name and return ``(dataset, data)``.

    Args:
        dataset_name: one of "Cora", "CiteSeer", "PubMed", "computers",
            "photo", "chameleon", "squirrel", "texas", "ogbn_arxiv", or any
            name containing "cSBM", "Facebook" or "LastFM" (checked in that
            order).

    Returns:
        ``(dataset, data)`` where ``data`` is ``dataset[0]`` after the
        per-dataset adjustments below.

    Raises:
        ValueError: if ``dataset_name`` matches none of the supported names
            (previously this surfaced as an UnboundLocalError).
    """
    if dataset_name in ("Cora", "CiteSeer", "PubMed"):
        dataset = Planetoid(root='/data', name=dataset_name)
        data = dataset[0]
    elif dataset_name in ("computers", "photo"):
        dataset = Amazon(root=f'/data/{dataset_name}', name=dataset_name)
        data = dataset[0]
    elif dataset_name in ("chameleon", "squirrel"):
        # Everything except edge_index comes from the geom_gcn_preprocess=True
        # copy; edge_index is taken from the geom_gcn_preprocess=False copy.
        raw_ds = WikipediaNetwork(root='/data', name=dataset_name, geom_gcn_preprocess=False, transform=T.NormalizeFeatures())
        dataset = WikipediaNetwork(root='/data', name=dataset_name, geom_gcn_preprocess=True, transform=T.NormalizeFeatures())
        data = dataset[0]
        data.edge_index = raw_ds[0].edge_index
    elif dataset_name == "texas":
        # WebKB was used here without being imported.
        from torch_geometric.datasets import WebKB
        dataset = WebKB(root='/data', name="Texas")
        data = dataset[0]
    elif dataset_name == "ogbn_arxiv":
        # torch was used here without being imported.
        import torch
        # NOTE(review): PygNodePropPredDataset (ogb package) is not imported
        # in this file; this branch raises NameError unless it is provided.
        dataset = PygNodePropPredDataset(root='/data', name='ogbn-arxiv', transform=T.ToSparseTensor())
        data = dataset[0]
        # Symmetrise the sparse adjacency, then rebuild a [2, E] edge_index
        # from its COO row/col vectors.
        data.adj_t = data.adj_t.to_symmetric()
        edge_index = data.adj_t.coo()
        data.edge_index = torch.stack([edge_index[0], edge_index[1]], dim=0)
        # Drop the singleton dimension 1 of the label tensor.
        data.y = data.y.squeeze(1)
    elif 'cSBM' in dataset_name:
        # NOTE(review): dataset_ContextualSBM is not defined or imported in
        # this file; this branch raises NameError unless it is provided.
        dataset = dataset_ContextualSBM('/data/', name=dataset_name)
        data = dataset[0]
    elif 'Facebook' in dataset_name:
        dataset = AttributedGraphDataset('/data/Facebook', name=dataset_name)
        data = dataset[0]
        # Reduce y to one class index per node (argmax along dim 1);
        # presumably y is a multi-label indicator matrix -- confirm.
        data.y = data.y.max(1)[1]
    elif 'LastFM' in dataset_name:
        dataset = LastFMAsia('/data/LastFM')
        data = dataset[0]
    else:
        raise ValueError(f"Unknown dataset name: {dataset_name!r}")
    return dataset, data

def build_graph(edge_index, num_nodes=None):
    """Build an undirected networkx graph from a ``[2, E]`` edge list.

    The original implementation ignored ``edge_index`` and read a
    module-global ``data`` instead; the argument is now what is used.

    Args:
        edge_index: ``[2, E]`` edges, either an object with ``.tolist()``
            (tensor / ndarray) or a pair of index sequences
            ``(sources, targets)``.
        num_nodes: number of nodes ``0..num_nodes-1`` to create up front so
            isolated nodes are kept. Defaults to ``max node id + 1`` seen in
            ``edge_index`` (0 when there are no edges); isolated nodes with
            larger ids are then not included.

    Returns:
        ``nx.Graph`` with those nodes and edges; reverse/duplicate edges
        collapse into one undirected edge. Prints node and edge counts.
    """
    if hasattr(edge_index, "tolist"):
        edge_index = edge_index.tolist()
    sources, targets = edge_index
    if num_nodes is None:
        num_nodes = max(max(sources, default=-1), max(targets, default=-1)) + 1
    G = nx.Graph()
    # Insert nodes first so node iteration order is 0..num_nodes-1.
    G.add_nodes_from(range(num_nodes))
    G.add_edges_from(zip(sources, targets))
    print(G.number_of_nodes(), G.number_of_edges())
    return G

def get_subgraph_edge_index(G, nodes):
    """Return the edge index of the subgraph of ``G`` induced by ``nodes``.

    Node ids are replaced by their position in ``nodes`` (for repeated ids
    the last position wins). Each edge found between ``nodes[i]`` and
    ``nodes[j]`` (i < j) is emitted in both directions.

    Returns:
        ``[sources, targets]`` -- two parallel lists of local positions.
    """
    position = {node: idx for idx, node in enumerate(nodes)}
    sources, targets = [], []
    for i, first in enumerate(nodes):
        for second in nodes[i + 1:]:
            if not G.has_edge(first, second):
                continue
            a, b = position[first], position[second]
            sources.extend((a, b))
            targets.extend((b, a))
    return [sources, targets]

def find_clique_and_2hop_paths(G, num):
    """Collect distinct triangles and open 2-hop paths by scanning ``G.edges()``.

    For every non-self-loop edge (u, v), each neighbour w of u or of v
    (other than u and v) forms the triple {u, v, w}. It is a triangle when
    w is adjacent to both endpoints, otherwise an open path. Each triple is
    recorded once, as its sorted node list plus the edge index returned by
    ``get_subgraph_edge_index``.

    Scanning stops after the first edge at which at least ``num`` triangles
    and at least ``num`` paths have been collected, so either list may hold
    more than ``num`` entries.

    Returns:
        ``(cliques, clique_edge_indices, paths_2hop, paths_edge_indices)``
    """
    cliques, clique_edge_indices = [], []
    paths_2hop, paths_edge_indices = [], []
    # A triangle is reachable from each of its three edges and an open path
    # from both of its edges; recording each triple once keeps duplicates out
    # of the later sampling without replacement.
    seen = set()

    for u, v in G.edges():
        if u == v:
            # A self-loop cannot form a 3-node subgraph; it used to yield
            # degenerate triples such as [u, u, w].
            continue
        for center, other in ((u, v), (v, u)):
            for w in G.neighbors(center):
                if w == u or w == v:
                    continue
                triple = tuple(sorted((u, v, w)))
                if triple in seen:
                    continue
                seen.add(triple)
                nodes = list(triple)
                edge_index = get_subgraph_edge_index(G, nodes)
                # u-v and center-w are edges, so the triple is closed exactly
                # when w is also adjacent to the other endpoint.
                if G.has_edge(w, other):
                    cliques.append(nodes)
                    clique_edge_indices.append(edge_index)
                else:
                    paths_2hop.append(nodes)
                    paths_edge_indices.append(edge_index)
        if len(cliques) >= num and len(paths_2hop) >= num:
            break
    return cliques, clique_edge_indices, paths_2hop, paths_edge_indices

def find_non_clique_non_2hop(G, num):
    """Enumerate 3-node subsets of ``G`` that contain exactly one edge.

    Triples are visited in lexicographic order of their positions in
    ``list(G.nodes())`` (i < j < k). Each qualifying triple is recorded as
    its sorted node list plus the edge index from
    ``get_subgraph_edge_index``. The scan returns as soon as ``num`` triples
    have been found; otherwise it visits all O(n^3) triples.

    NOTE(review): the order is deterministic and the scan stops early, so the
    results concentrate on the first nodes of ``G.nodes()`` rather than being
    spread over the graph. A ``num`` of 0 or less never triggers the early
    return.

    Returns:
        ``(nodes_list, edge_indices)``
    """
    found_nodes, found_edges = [], []
    order = list(G.nodes())
    for i, a in enumerate(order):
        for j in range(i + 1, len(order)):
            b = order[j]
            ab = G.has_edge(a, b)
            for c in order[j + 1:]:
                if ab + G.has_edge(b, c) + G.has_edge(a, c) != 1:
                    continue
                triple = sorted([a, b, c])
                found_nodes.append(triple)
                found_edges.append(get_subgraph_edge_index(G, triple))
                if len(found_nodes) == num:
                    return found_nodes, found_edges
    return found_nodes, found_edges


def sample_combos(nodes_list, edge_indices, num_samples):
    """Keep at most ``num_samples`` aligned (nodes, edge index) pairs.

    When there are more than ``num_samples`` items, indices are drawn without
    replacement from NumPy's global RNG and the same indices are applied to
    both lists. Otherwise both input lists are returned as-is (not copied).
    """
    if len(nodes_list) <= num_samples:
        return nodes_list, edge_indices
    picks = np.random.choice(len(nodes_list), num_samples, replace=False)
    return [nodes_list[k] for k in picks], [edge_indices[k] for k in picks]


def save_to_npy(data, file_name):
    """Write ``data`` to ``file_name`` with ``np.save`` and print the path."""
    target = file_name
    np.save(target, data)
    print("File has been saved in {}".format(target))

def save_to_npy_nonregular(data, file_name):
    """Save possibly ragged (non-rectangular) nested data to ``file_name``.

    Rectangular data is stored exactly as ``np.save`` would store it. If
    ``data`` cannot be converted to a regular array (NumPy >= 1.24 raises
    ``ValueError`` for inhomogeneous nested lists), it is stored as a 1-D
    object array of its top-level items instead; reading that file back
    requires ``np.load(file_name, allow_pickle=True)``.
    """
    try:
        array = np.asarray(data)
    except ValueError:
        # allow_pickle=True alone (np.save's default anyway) cannot help:
        # ragged input fails during array conversion, so build the object
        # array explicitly, one top-level item per element.
        array = np.empty(len(data), dtype=object)
        for i, item in enumerate(data):
            array[i] = item
    np.save(file_name, array, allow_pickle=True)
    print(f"File has been saved in {file_name}")

def main(edge_index, output_dir, num_samples, num_candidates=1000):
    """Sample and save triangle, open 2-hop path and single-edge triples.

    Writes six .npy files into ``output_dir``: ``<stem>.npy`` (node triples)
    and ``<stem>_edgeindex.npy`` (their local edge indices) for the stems
    ``sampled_cliques``, ``sampled_2hop_paths`` and
    ``sampled_non_clique_non_2hop``.

    Args:
        edge_index: edge list handed to ``build_graph``.
        output_dir: existing directory the files are written into.
        num_samples: maximum number of triples kept per category.
        num_candidates: ``num`` passed to the two finder functions (how many
            candidates to collect before sampling); defaults to the
            previously hard-coded 1000.
    """
    G = build_graph(edge_index)

    cliques, cliques_edgeindex, paths_2hop, paths_2hop_edgeindex = find_clique_and_2hop_paths(G, num_candidates)
    print(len(cliques), len(paths_2hop))
    _sample_and_save(cliques, cliques_edgeindex, num_samples, output_dir, "sampled_cliques", save_to_npy)
    _sample_and_save(paths_2hop, paths_2hop_edgeindex, num_samples, output_dir, "sampled_2hop_paths", save_to_npy)

    non_clique_non_2hop, non_clique_non_2hop_edgeindex = find_non_clique_non_2hop(G, num=num_candidates)
    print(len(non_clique_non_2hop), len(non_clique_non_2hop_edgeindex))
    _sample_and_save(non_clique_non_2hop, non_clique_non_2hop_edgeindex, num_samples, output_dir,
                     "sampled_non_clique_non_2hop", save_to_npy_nonregular)


def _sample_and_save(nodes_list, edge_indices, num_samples, output_dir, stem, save_edge_index):
    """Subsample one category and write ``<stem>.npy`` / ``<stem>_edgeindex.npy``.

    Node triples are always written with ``save_to_npy``; the edge indices
    are written with the caller-supplied ``save_edge_index`` function.
    """
    sampled_nodes, sampled_edge_indices = sample_combos(nodes_list, edge_indices, num_samples=num_samples)
    save_to_npy(sampled_nodes, f"{output_dir}/{stem}.npy")
    save_edge_index(sampled_edge_indices, f"{output_dir}/{stem}_edgeindex.npy")

# Generate and save the sampled triples for every dataset named in ``Dataset``.
# NOTE(review): this loop runs at import time; an ``if __name__ == "__main__":``
# guard would let the helpers above be imported without side effects.
for dataset_name in Dataset:
    dataset, data = load_data(dataset_name)
    output_dir = os.path.join("/data/sample_result", dataset_name)
    # exist_ok replaces the racy os.path.exists check before makedirs.
    os.makedirs(output_dir, exist_ok=True)
    main(data.edge_index, output_dir, num_samples=500)
