import os
import pickle
from functools import partial
from multiprocessing import Pool

import networkx as nx
import torch
import torch_geometric.transforms as T
from torch_geometric import seed_everything
from torch_geometric.utils import degree
from tqdm import tqdm

from children_genre.goodreads_children_genre import Goodreads_children_genre
from text_graph import get_graph_prompt

# Instruction text intended to precede the two neighbourhood paragraphs of a
# link-prediction prompt. Nothing in the visible part of this file reads it.
# NOTE(review): the placeholders are spelled "[node_a]" and "[node b]" (space
# instead of underscore) -- confirm whether any consumer substitutes them
# before normalising the spelling.
prefix = "As an AI language model, we want to predict whether there will be a link between [node_a] and [node b]. We provide the neighborhood information of [node_a] and [node_b] in the following two paragraphs."
def process_edge(supervision_edge, train_graph):
    """Build the prompt text for one (user, book) supervision edge.

    Args:
        supervision_edge: Indexable pair whose elements expose ``.item()``
            (e.g. a 2-element tensor); element 0 is the user index and
            element 1 is the book index.
        train_graph: Graph whose nodes are named ``'user_<idx>'`` /
            ``'book_<idx>'``; passed through to ``get_graph_prompt``.

    Returns:
        A ``(key, value)`` tuple where ``key`` is ``'<user_idx>|<book_idx>'``
        and ``value`` is the user-side prompt followed by the book-side prompt.
    """
    # Derive the id strings directly from the indices. The previous code
    # rebuilt them with str.lstrip('user_') / str.lstrip('book_'), which
    # strips a *set of characters* rather than a prefix and only worked
    # because the ids happen to start with a digit.
    user_id = str(supervision_edge[0].item())
    book_id = str(supervision_edge[1].item())
    user = 'user_' + user_id
    book = 'book_' + book_id

    print(user, book)

    key = user_id + '|' + book_id
    # The trailing (3, 2) arguments are forwarded unchanged to the helper;
    # presumably sampling/hop limits -- TODO confirm against text_graph.
    user_side = get_graph_prompt(user, book, train_graph, 'user', 3, 2)
    book_side = get_graph_prompt(user, book, train_graph, 'book', 3, 2)
    value = user_side + book_side

    return key, value


def main():
    """Generate and pickle text prompts for the training supervision edges.

    Pipeline:
      1. Load the Goodreads children/genre heterogeneous graph and keep only
         4- and 5-star reviews as positive user-book edges.
      2. Split the review edges with ``RandomLinkSplit`` and keep the
         training split.
      3. Build a networkx graph from the message-passing edges, the positive
         supervision edges and the book-genre edges (with their stored
         texts), and add the endpoints of negative supervision edges as
         nodes.
      4. Run ``process_edge`` over every supervision edge in a process pool
         and pickle the resulting ``{'<user>|<book>': prompt}`` dict to
         ``0.15_train_text``.
    """
    review_type = ('user', 'review', 'book')
    rev_review_type = ('book', 'rev_review', 'user')
    description_type = ('book', 'description', 'genre')

    seed_everything(66)
    Dataset = Goodreads_children_genre(root='.')
    data = Dataset[0]

    num_reviews = data[review_type].num_edges
    num_descriptions = data[description_type].num_edges

    # Placeholder all-ones edge features.
    data[review_type].edge_attr = torch.ones(num_reviews, 64)  # TODO
    data[description_type].edge_attr = torch.ones(num_descriptions, 64)  # TODO

    # select 4-star or 5-star review as positive edge
    ratings = data[review_type].edge_label
    positive_edges_mask = (ratings == 5) | (ratings == 4)
    data[review_type].edge_index = data[review_type].edge_index[:, positive_edges_mask]
    data[review_type].edge_attr = data[review_type].edge_attr[positive_edges_mask]

    # Add a reverse ('book', 'rev_review', 'user') relation for message passing:
    data = T.ToUndirected()(data)
    # Drop the star-rating labels from both directions; the edge_label read
    # further down is the one present on the split output.
    del data[rev_review_type].edge_label
    del data[review_type].edge_label

    # Perform a link-level split into training, validation, and test edges:
    # NOTE(review): num_val + num_test = 0.95, which leaves 0.05 for the
    # remaining split, while the output file is named '0.15_train_text' --
    # confirm which fraction is intended.
    train_data, val_data, test_data = T.RandomLinkSplit(
        num_val=0.85,
        num_test=0.1,
        disjoint_train_ratio=0.3,
        neg_sampling_ratio=2.0,
        edge_types=[review_type],
        rev_edge_types=[rev_review_type],
    )(data)

    # TODO: use val_data as training data
    data = train_data
    print(data)
    # assert len(data['user', 'review', 'book'].edge_label_index[0]) == 18795

    edge_label = data[review_type].edge_label.long()
    edge_label_index = data[review_type].edge_label_index

    # message passing edges + positive supervision edges
    review_message_edge_index = data[review_type].edge_index
    review_supervision_edge_index = edge_label_index[:, edge_label == 1]  # only label == 1
    review_edge_index = torch.concat((review_message_edge_index, review_supervision_edge_index), dim=1)

    review_edge_index = review_edge_index.tolist()
    user_edges = ['user_' + str(idx) for idx in review_edge_index[0]]
    book_edges = ['book_' + str(idx) for idx in review_edge_index[1]]
    review_edge_list = list(zip(user_edges, book_edges))

    description_edge_list = data[description_type].edge_index.tolist()
    book_edges = ['book_' + str(idx) for idx in description_edge_list[0]]
    genre_edges = ['genre_' + str(idx) for idx in description_edge_list[1]]
    description_edge_list = list(zip(book_edges, genre_edges))

    edge_list = review_edge_list + description_edge_list

    # negative supervision edges
    negative_edge_index = edge_label_index[:, edge_label == 0]  # only label == 0

    # NOTE: pickle.load executes arbitrary code from the file; only use a
    # graph pickle produced by this project.
    with open('children_genre/raw/nx_graph.pkl', 'rb') as f:
        G = pickle.load(f)

    # Keyed by edge tuple, so duplicate edges collapse to a single entry.
    edge_texts = {edge: G[edge[0]][edge[1]]['text'] for edge in edge_list}

    # message passing edges + positive supervision edges as train_graph edges
    train_graph = nx.Graph()
    for edge, text in tqdm(edge_texts.items()):
        train_graph.add_edge(*edge, text=text)

    # negative edge as train_graph nodes
    for edge in negative_edge_index.numpy().T:
        train_graph.add_node('user_' + str(edge[0]))
        train_graph.add_node('book_' + str(edge[1]))

    # One (user_idx, book_idx) row per supervision edge, positives and negatives.
    supervision_edges = list(edge_label_index.T)

    # Pool.starmap only returns once every task has finished, so wrapping it
    # in tqdm never showed real progress. imap yields results in input order
    # as chunks complete; the chunk size follows the same "about four chunks
    # per worker" heuristic Pool.map uses, so the graph bound into `worker`
    # is shipped once per chunk rather than once per edge.
    worker = partial(process_edge, train_graph=train_graph)
    n_workers = os.cpu_count() or 1
    chunksize = max(1, -(-len(supervision_edges) // (n_workers * 4)))

    with Pool() as p:
        results = list(tqdm(p.imap(worker, supervision_edges, chunksize=chunksize),
                            total=len(supervision_edges)))

    document = dict(results)

    with open('0.15_train_text', 'wb') as f:
        pickle.dump(document, f)

    print(len(document))


# Run the pipeline only when executed as a script. main() creates a
# multiprocessing Pool, so the guard also keeps worker processes that
# re-import this module from starting the pipeline again.
if __name__ == "__main__":
    main()
