import os
import networkx as nx
import random
import torch_geometric.transforms as T
import numpy as np
from torch_geometric.datasets import Planetoid
from torch_geometric.datasets import Amazon
from torch_geometric.datasets import WikipediaNetwork
from torch_geometric.datasets import AttributedGraphDataset, LastFMAsia

# Names of the datasets processed by the script loop at the bottom of this
# file; each entry is passed to load_data(), so it must be a name that
# load_data() recognises.
Dataset = ['PubMed', "computers", "photo"]

def load_data(dataset_name):
    """Load a graph dataset by name and return ``(dataset, data)``.

    ``dataset`` is the dataset object and ``data`` is its first element
    (``dataset[0]``), post-processed for some datasets:

    * ``"chameleon"`` / ``"squirrel"``: features, labels and splits come from
      the ``geom_gcn_preprocess=True`` variant, while ``edge_index`` is taken
      from the ``geom_gcn_preprocess=False`` variant.
    * ``"ogbn_arxiv"``: the adjacency is symmetrised, ``edge_index`` is
      rebuilt from it and ``y`` is squeezed along dimension 1.
    * names containing ``"Facebook"``: ``y`` is replaced by the argmax over
      dimension 1.

    All datasets are stored under ``/data``.

    Args:
        dataset_name: One of ``"Cora"``, ``"CiteSeer"``, ``"PubMed"``,
            ``"computers"``, ``"photo"``, ``"chameleon"``, ``"squirrel"``,
            ``"texas"``, ``"ogbn_arxiv"``, or a name containing ``"cSBM"``,
            ``"Facebook"`` or ``"LastFM"``.

    Returns:
        A ``(dataset, data)`` tuple.

    Raises:
        ValueError: If ``dataset_name`` matches none of the supported names.
    """
    if dataset_name in ("Cora", "CiteSeer", "PubMed"):
        dataset = Planetoid(root='/data', name=dataset_name)
        data = dataset[0]
    elif dataset_name in ("computers", "photo"):
        # Each Amazon dataset lives in its own sub-directory of /data.
        dataset = Amazon(root='/data/' + dataset_name, name=dataset_name)
        data = dataset[0]
    elif dataset_name in ("chameleon", "squirrel"):
        raw_dataset = WikipediaNetwork(root='/data', name=dataset_name, geom_gcn_preprocess=False, transform=T.NormalizeFeatures())
        dataset = WikipediaNetwork(root='/data', name=dataset_name, geom_gcn_preprocess=True, transform=T.NormalizeFeatures())
        data = dataset[0]
        # Keep the preprocessed node data but use the edges of the raw variant.
        data.edge_index = raw_dataset[0].edge_index
    elif dataset_name == "texas":
        # Imported here because the module-level imports do not provide WebKB.
        from torch_geometric.datasets import WebKB
        dataset = WebKB(root='/data', name="Texas")
        data = dataset[0]
    elif dataset_name == "ogbn_arxiv":
        # Imported here because the module-level imports do not provide torch.
        import torch
        # NOTE(review): PygNodePropPredDataset is not imported anywhere in the
        # visible part of this file; presumably it is
        # ogb.nodeproppred.PygNodePropPredDataset - confirm and import it.
        dataset = PygNodePropPredDataset(root='/data', name='ogbn-arxiv', transform=T.ToSparseTensor())
        data = dataset[0]
        data.adj_t = data.adj_t.to_symmetric()
        coo = data.adj_t.coo()
        data.edge_index = torch.stack([coo[0], coo[1]], dim=0)
        data.y = data.y.squeeze(1)
    elif 'cSBM' in dataset_name:
        path = '/data/'
        # NOTE(review): dataset_ContextualSBM is not defined or imported in
        # the visible part of this file - confirm where it comes from.
        dataset = dataset_ContextualSBM(path, name=dataset_name)
        data = dataset[0]
    elif 'Facebook' in dataset_name:
        path = '/data/Facebook'
        dataset = AttributedGraphDataset(path, name=dataset_name)
        data = dataset[0]
        # max(1) returns (values, indices); keep the indices as class labels.
        data.y = data.y.max(1)[1]
    elif 'LastFM' in dataset_name:
        path = '/data/LastFM'
        dataset = LastFMAsia(path)
        data = dataset[0]
    else:
        # Without this branch the return below fails with UnboundLocalError.
        raise ValueError(f"Unknown dataset name: {dataset_name!r}")
    return dataset, data

def build_graph(edge_index, num_nodes=None):
    """Build an undirected ``networkx.Graph`` from an edge index.

    The graph is built from the ``edge_index`` argument itself rather than
    from a module-level ``data`` object, so the function works for any edge
    index passed in.

    Args:
        edge_index: Object with shape ``[2, E]`` semantics that supports
            ``.t().tolist()`` (e.g. a torch tensor); column ``i`` holds the
            two endpoints of edge ``i``.
        num_nodes: Number of nodes to create, labelled ``0 .. num_nodes - 1``.
            When omitted it is inferred as ``largest endpoint + 1`` (``0`` for
            an empty edge index), which means isolated nodes whose id is
            larger than every endpoint are only included if ``num_nodes`` is
            given explicitly.

    Returns:
        The constructed graph. Node and edge counts are printed as a side
        effect.
    """
    pairs = edge_index.t().tolist()
    if num_nodes is None:
        num_nodes = max(max(pair) for pair in pairs) + 1 if pairs else 0

    G = nx.Graph()
    G.add_nodes_from(range(num_nodes))
    G.add_edges_from((src, dst) for src, dst in pairs)
    print(G.number_of_nodes(), G.number_of_edges())
    return G

def get_subgraph_edge_index(G, nodes):
    """Return the edge index of the subgraph of ``G`` induced by ``nodes``.

    Nodes are relabelled by their position in ``nodes``. Every edge between
    two listed nodes is emitted in both directions, so the result is a pair
    ``[sources, targets]`` of equally long lists of local indices.
    """
    local_id = {}
    for position, node in enumerate(nodes):
        local_id[node] = position

    sources = []
    targets = []
    count = len(nodes)
    for a in range(count):
        first = nodes[a]
        for b in range(a + 1, count):
            second = nodes[b]
            if not G.has_edge(first, second):
                continue
            src = local_id[first]
            dst = local_id[second]
            # Forward direction followed by the reverse direction.
            sources.extend((src, dst))
            targets.extend((dst, src))
    return [sources, targets]

def find_clique_and_3hop_paths(G, num):
    """Collect 4-node cliques and induced 3-hop paths from ``G``.

    For every edge ``(u, v)``, every neighbour ``w`` of ``u`` and every
    neighbour ``k`` of ``v``, the quadruple ``{u, v, w, k}`` is classified by
    the three remaining node pairs ``(w, v)``, ``(k, u)`` and ``(w, k)``:

    * none of them is an edge -> induced path ``w - u - v - k``;
    * all of them are edges   -> 4-clique.

    Each category is capped at ``num`` entries, contains every node set at
    most once, and only contains quadruples of four distinct nodes (so
    self-loops cannot produce a degenerate entry). The search stops as soon
    as both categories are full.

    Args:
        G: Graph exposing ``edges()``, ``neighbors(node)`` and
            ``has_edge(a, b)``.
        num: Maximum number of cliques and of paths to collect.

    Returns:
        ``(cliques, clique_edge_indices, paths_3hop, paths_edge_indices)``.
        The node lists hold sorted 4-node lists; the edge-index lists hold the
        matching results of ``get_subgraph_edge_index``.
    """
    cliques = []
    clique_edge_indices = []
    paths_3hop = []
    paths_edge_indices = []
    if num <= 0:
        return cliques, clique_edge_indices, paths_3hop, paths_edge_indices

    # The same node set is reached from several (edge, w, k) assignments,
    # so remember what has already been recorded.
    seen_cliques = set()
    seen_paths = set()

    for u, v in G.edges():
        if u == v:
            continue  # a self-loop cannot be the middle edge of a quadruple
        for w in G.neighbors(u):
            if w == u or w == v:
                continue
            w_adj_v = G.has_edge(w, v)  # independent of k, so computed once
            for k in G.neighbors(v):
                if k == u or k == v or k == w:
                    continue
                extra_edges = w_adj_v + G.has_edge(k, u) + G.has_edge(w, k)
                if extra_edges == 0:
                    found, found_edges, seen = paths_3hop, paths_edge_indices, seen_paths
                elif extra_edges == 3:
                    found, found_edges, seen = cliques, clique_edge_indices, seen_cliques
                else:
                    continue
                if len(found) >= num:
                    continue

                quad = sorted([u, v, w, k])
                key = tuple(quad)
                if key in seen:
                    continue
                seen.add(key)
                found.append(quad)
                found_edges.append(get_subgraph_edge_index(G, quad))

                if len(cliques) >= num and len(paths_3hop) >= num:
                    return cliques, clique_edge_indices, paths_3hop, paths_edge_indices
    return cliques, clique_edge_indices, paths_3hop, paths_edge_indices

def find_non_clique_non_3hop(G, num):
    """Collect 4-node sets of ``G`` that contain exactly one edge.

    All 4-combinations of ``G.nodes()`` are visited in index order
    (``i < j < k < l``) and the first ``num`` combinations whose induced
    subgraph has exactly one edge are returned.

    Args:
        G: Graph exposing ``nodes()`` and ``has_edge(a, b)``.
        num: Maximum number of node sets to collect; a value ``<= 0`` yields
            empty results.

    Returns:
        ``(nodes_list, edge_indices)`` where ``nodes_list`` holds sorted
        4-node lists and ``edge_indices`` the matching results of
        ``get_subgraph_edge_index``.
    """
    nodes_list = []
    edge_indices = []
    if num <= 0:
        return nodes_list, edge_indices

    all_nodes = list(G.nodes())
    total = len(all_nodes)
    for i in range(total):
        u = all_nodes[i]
        for j in range(i + 1, total):
            v = all_nodes[j]
            uv = G.has_edge(u, v)
            for k in range(j + 1, total):
                w = all_nodes[k]
                # Edges among (u, v, w) do not depend on the fourth node.
                triple_edges = uv + G.has_edge(u, w) + G.has_edge(v, w)
                if triple_edges > 1:
                    # Adding a fourth node can only add edges, never remove.
                    continue
                for m in range(k + 1, total):
                    x = all_nodes[m]
                    edge_count = (
                        triple_edges
                        + G.has_edge(u, x)
                        + G.has_edge(v, x)
                        + G.has_edge(w, x)
                    )
                    if edge_count != 1:
                        continue
                    quad = sorted([u, v, w, x])
                    nodes_list.append(quad)
                    edge_indices.append(get_subgraph_edge_index(G, quad))
                    if len(nodes_list) >= num:
                        return nodes_list, edge_indices
    return nodes_list, edge_indices

def sample_combos(nodes_list, edge_indices, num_samples):
    """Pick ``num_samples`` aligned entries from two parallel lists.

    When ``nodes_list`` has more than ``num_samples`` entries, positions are
    drawn without replacement via ``np.random.choice`` and the same positions
    are taken from both lists. Otherwise both input lists are returned
    unchanged.
    """
    available = len(nodes_list)
    if available <= num_samples:
        return nodes_list, edge_indices

    picked = np.random.choice(available, num_samples, replace=False)
    sampled_nodes = []
    sampled_edge_indices = []
    for position in picked:
        sampled_nodes.append(nodes_list[position])
        sampled_edge_indices.append(edge_indices[position])
    return sampled_nodes, sampled_edge_indices

def save_to_npy(data, file_name):
    """Write ``data`` to ``file_name`` with ``np.save`` and print the path."""
    np.save(file=file_name, arr=data)
    message = f"File has been saved in {file_name}"
    print(message)

def save_to_npy_nonregular(data, file_name):
    """Save possibly ragged ``data`` to ``file_name`` and print the path.

    Regular (rectangular) input is stored exactly as ``np.save`` would store
    it. Ragged input, e.g. a list of lists of different lengths, cannot be
    converted to a regular array by current NumPy releases (the conversion
    raises ``ValueError``), so it is stored as a 1-D object array with one
    element per top-level item. Load such files with
    ``np.load(..., allow_pickle=True)``.

    Args:
        data: Array-like or sized iterable of items to save.
        file_name: Destination path passed to ``np.save``.
    """
    try:
        array = np.asanyarray(data)
    except ValueError:
        # Inhomogeneous shape: keep each top-level item as one object entry.
        array = np.empty(len(data), dtype=object)
        for position, item in enumerate(data):
            array[position] = item
    np.save(file_name, array, allow_pickle=True)
    print(f"File has been saved in {file_name}")

def main(edge_index, output_dir, num_samples):
    """Sample three kinds of 4-node structures and save them under ``output_dir``.

    A graph is built with ``build_graph``; up to 2000 candidates each of
    cliques and 3-hop paths are collected first, then up to 2000 candidates
    of the "non-clique, non-3-hop" kind. From every candidate pool at most
    ``num_samples`` entries are drawn with ``sample_combos`` and written as
    two ``.npy`` files (node sets and edge indices).
    """
    def _sample_then_save(node_sets, edge_sets, nodes_path, edges_path):
        # Draw aligned samples from both lists and write one file per list.
        picked_nodes, picked_edges = sample_combos(node_sets, edge_sets, num_samples=num_samples)
        save_to_npy(picked_nodes, nodes_path)
        save_to_npy(picked_edges, edges_path)

    graph = build_graph(edge_index)

    found = find_clique_and_3hop_paths(graph, num = 2000)
    clique_nodes, clique_edges, path_nodes, path_edges = found
    print(len(clique_nodes), len(path_nodes))
    _sample_then_save(
        clique_nodes,
        clique_edges,
        f"{output_dir}/sampled_cliques_4smia.npy",
        f"{output_dir}/sampled_cliques_edgeindex_4smia.npy",
    )
    _sample_then_save(
        path_nodes,
        path_edges,
        f"{output_dir}/sampled_3hop_paths_4smia.npy",
        f"{output_dir}/sampled_3hop_paths_edgeindex_4smia.npy",
    )

    other_nodes, other_edges = find_non_clique_non_3hop(graph, num = 2000)
    print(len(other_nodes), len(other_edges))
    _sample_then_save(
        other_nodes,
        other_edges,
        f"{output_dir}/sampled_non_clique_non_3hop_4smia.npy",
        f"{output_dir}/sampled_non_clique_non_3hop_edgeindex_4smia.npy",
    )

# Script entry: for every configured dataset, load it, make sure its output
# directory exists and run the sampling pipeline on its edge index.
for dataset_name in Dataset:
    dataset, data = load_data(dataset_name)
    output_dir = "/data/sample_result/" + dataset_name
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_dir, exist_ok=True)
    main(data.edge_index, output_dir, num_samples = 500)
