from pathlib import Path
from typing import Dict
from collections import defaultdict
import random
import pickle
import numpy as np
import re
import json
import torch

from torch_geometric.data import Data, DataLoader
from data_loaders.load import _get_uniques_, _pad_statements_, count_stats, remove_dups

def load_clean_wikipeople_statements(subtype, maxlen=17) -> Dict:
    """
    Load the cleaned WikiPeople splits and encode every statement with integer ids.

    The files ``train.txt``, ``test.txt`` and ``valid.txt`` are read from
    ``./data/clean/wikipeople``. Each line is one comma-separated statement in which
    even positions are looked up in the entity vocabulary and odd positions in the
    relation vocabulary.

    :param subtype: "triples" forces the padded length to 3, "quints" to 5 (both also
        run ``remove_dups`` on every split); any other value keeps ``maxlen`` as passed
    :param maxlen: statement length handed to ``_pad_statements_`` when ``subtype`` is
        neither "triples" nor "quints"
    :return: dict with the encoded "train"/"valid"/"test" splits, the vocabulary sizes
        "n_entities"/"n_relations" and the lookup tables "e2id"/"r2id"
    """
    DIRNAME = Path('./data/clean/wikipeople')

    def _read_split(filename):
        # One statement per line, fields separated by commas.
        with open(DIRNAME / filename, 'r') as f:
            return [line.strip("\n").split(",") for line in f]

    raw_trn = _read_split('train.txt')
    raw_tst = _read_split('test.txt')
    raw_val = _read_split('valid.txt')

    # Get uniques
    statement_entities, statement_predicates = _get_uniques_(train_data=raw_trn,
                                                             test_data=raw_tst,
                                                             valid_data=raw_val)

    # Index 0 of both vocabularies is reserved for the '__na__' placeholder
    # (presumably the padding token used by _pad_statements_ -- TODO confirm).
    st_entities = ['__na__'] + statement_entities
    st_predicates = ['__na__'] + statement_predicates

    entoid = {ent: i for i, ent in enumerate(st_entities)}
    prtoid = {rel: i for i, rel in enumerate(st_predicates)}

    def _encode(statements):
        # Even positions hold entities, odd positions hold relations.
        # NOTE: compare with `==`; `is 0` only worked through CPython's small-int cache.
        return [[entoid[uri] if pos % 2 == 0 else prtoid[uri] for pos, uri in enumerate(st)]
                for st in statements]

    train, valid, test = _encode(raw_trn), _encode(raw_val), _encode(raw_tst)

    if subtype == "triples":
        maxlen = 3
    elif subtype == "quints":
        maxlen = 5

    train = _pad_statements_(train, maxlen)
    valid = _pad_statements_(valid, maxlen)
    test = _pad_statements_(test, maxlen)

    if subtype in ("triples", "quints"):
        train, valid, test = remove_dups(train), remove_dups(valid), remove_dups(test)

    return {"train": train, "valid": valid, "test": test, "n_entities": len(st_entities),
            "n_relations": len(st_predicates), 'e2id': entoid, 'r2id': prtoid}

def load_clean_jf17k_statements(subtype, maxlen=15) -> Dict:
    """
    Load the cleaned JF17K dataset and encode every statement with integer ids.

    ``train.txt`` and ``test.txt`` are read from ``./data/clean/jf17k``; each line is
    one comma-separated statement in which even positions are looked up in the entity
    vocabulary and odd positions in the relation vocabulary. No validation file is
    read: the training statements are shuffled in place and the last 20% become the
    validation split. The shuffle uses the global ``random`` state, so seed it before
    calling if a reproducible split is required.

    :param subtype: "triples" forces the padded length to 3, "quints" to 5 (both also
        run ``remove_dups`` on every split); any other value keeps ``maxlen`` as passed
    :param maxlen: statement length handed to ``_pad_statements_`` when ``subtype`` is
        neither "triples" nor "quints"
    :return: dict with the encoded "train"/"valid"/"test" splits, the vocabulary sizes
        "n_entities"/"n_relations" and the lookup tables "e2id"/"r2id"
    """
    PARSED_DIR = Path('./data/clean/jf17k')

    with open(PARSED_DIR / 'train.txt', 'r') as train_file, \
            open(PARSED_DIR / 'test.txt', 'r') as test_file:
        training_statements = [line.strip("\n").split(",") for line in train_file]
        test_statements = [line.strip("\n").split(",") for line in test_file]

    # The test statements are passed twice because there is no separate validation file.
    st_entities, st_predicates = _get_uniques_(training_statements, test_statements, test_statements)
    # Index 0 of both vocabularies is reserved for the '__na__' placeholder.
    st_entities = ['__na__'] + st_entities
    st_predicates = ['__na__'] + st_predicates

    entoid = {ent: i for i, ent in enumerate(st_entities)}
    prtoid = {rel: i for i, rel in enumerate(st_predicates)}

    # sample valid as 20% of train
    random.shuffle(training_statements)
    split_at = int(0.8 * len(training_statements))
    tr_st = training_statements[:split_at]
    val_st = training_statements[split_at:]

    def _encode(statements):
        # Even positions hold entities, odd positions hold relations.
        # NOTE: compare with `==`; `is 0` only worked through CPython's small-int cache.
        return [[entoid[uri] if pos % 2 == 0 else prtoid[uri] for pos, uri in enumerate(st)]
                for st in statements]

    train, valid, test = _encode(tr_st), _encode(val_st), _encode(test_statements)

    if subtype == "triples":
        maxlen = 3
    elif subtype == "quints":
        maxlen = 5

    train = _pad_statements_(train, maxlen)
    valid = _pad_statements_(valid, maxlen)
    test = _pad_statements_(test, maxlen)

    if subtype in ("triples", "quints"):
        train, valid, test = remove_dups(train), remove_dups(valid), remove_dups(test)

    return {"train": train, "valid": valid, "test": test, "n_entities": len(st_entities),
            "n_relations": len(st_predicates), 'e2id': entoid, 'r2id': prtoid}


def load_clean_wd50k(name, subtype, maxlen=43) -> Dict:
    """
    Load one of the cleaned WD50K datasets and encode every statement with integer ids.

    ``train.txt``, ``test.txt`` and ``valid.txt`` are read from
    ``./data/clean/{name}/{subtype}``; each line is one comma-separated statement in
    which even positions are looked up in the entity vocabulary and odd positions in
    the relation vocabulary.

    :param name: dataset name, one of wd50k/wd50k_100/wd50k_33/wd50k_66
    :param subtype: triples/quints/statements. "triples" are not padded, "quints" are
        padded to length 5 and "statements" to ``maxlen``; "triples" and "quints" are
        additionally passed through ``remove_dups``
    :param maxlen: statement length handed to ``_pad_statements_`` for "statements"
    :return: dict with the encoded "train"/"valid"/"test" splits, the vocabulary sizes
        "n_entities"/"n_relations" and the lookup tables "e2id"/"r2id"
    :raises AssertionError: if ``name`` or ``subtype`` is not one of the accepted values
    """
    assert name in ['wd50k', 'wd50k_100', 'wd50k_33', 'wd50k_66'], \
        "Incorrect dataset"
    assert subtype in ["triples", "quints", "statements"], "Incorrect subtype: triples/quints/statements"

    DIRNAME = Path(f'./data/clean/{name}/{subtype}')

    def _read_split(filename):
        # One statement per line, fields separated by commas.
        with open(DIRNAME / filename, 'r') as f:
            return [line.strip("\n").split(",") for line in f]

    raw_trn = _read_split('train.txt')
    raw_tst = _read_split('test.txt')
    raw_val = _read_split('valid.txt')

    # Get uniques
    statement_entities, statement_predicates = _get_uniques_(train_data=raw_trn,
                                                             test_data=raw_tst,
                                                             valid_data=raw_val)

    # Index 0 of both vocabularies is reserved for the '__na__' placeholder.
    st_entities = ['__na__'] + statement_entities
    st_predicates = ['__na__'] + statement_predicates

    entoid = {ent: i for i, ent in enumerate(st_entities)}
    prtoid = {rel: i for i, rel in enumerate(st_predicates)}

    def _encode(statements):
        # Even positions hold entities, odd positions hold relations.
        # NOTE: compare with `==`; `is 0` only worked through CPython's small-int cache.
        return [[entoid[uri] if pos % 2 == 0 else prtoid[uri] for pos, uri in enumerate(st)]
                for st in statements]

    train, valid, test = _encode(raw_trn), _encode(raw_val), _encode(raw_tst)

    if subtype != "triples":
        if subtype == "quints":
            maxlen = 5
        train = _pad_statements_(train, maxlen)
        valid = _pad_statements_(valid, maxlen)
        test = _pad_statements_(test, maxlen)

    if subtype in ("triples", "quints"):
        train, valid, test = remove_dups(train), remove_dups(valid), remove_dups(test)

    return {"train": train, "valid": valid, "test": test, "n_entities": len(st_entities),
            "n_relations": len(st_predicates), 'e2id': entoid, 'r2id': prtoid}


def load_nodecl_dataset(name, subtype, task, maxlen=43) -> Dict:
    """
    Load a WD50K node-classification dataset as id-encoded statements plus label maps.

    Reads ``nc_edges.txt``, ``nc_entities.txt``, ``nc_rels.txt`` and the three
    ``nc_{train,val,test}_{task}_labels.json`` files from
    ``./data/clean/{name}/{subtype}``, and the node features from
    ``./data/clean/{name}/statements``.

    :param name: dataset name wd50k/wd50k_100/wd50k_33/wd50k_66
    :param subtype: triples/statements
    :param task: so/full predict entities at sub/obj positions (for triples/statements) or all nodes incl quals;
        always overridden with "so" when ``subtype`` is "triples"
    :param maxlen: max statement length, used to pad the graph when ``subtype`` is "statements"
    :return: dict with the id-encoded "graph", the node "features", the
        train/valid/test masks (lists of labelled entity ids), the per-split label
        dicts "train_y"/"val_y"/"test_y" (entity id -> list of label ids), the label
        vocabulary ("all_labels", "label2id", "id2label"), the vocabulary sizes and
        the "e2id"/"r2id" lookup tables
    :raises AssertionError: if ``name`` or ``subtype`` is not one of the accepted values
    """

    assert name in ['wd50k', 'wd50k_100', 'wd50k_33', 'wd50k_66'], "Incorrect dataset"
    assert subtype in ["triples", "statements"], "Incorrect subtype: triples/statements"

    DIRNAME = Path(f'./data/clean/{name}/{subtype}')

    def _read_lines(path):
        # Context manager so the handle is closed deterministically.
        with open(path, 'r') as f:
            return [line.strip("\n") for line in f]

    def _read_json(path):
        with open(path, 'r') as f:
            return json.load(f)

    edges = [line.split(",") for line in _read_lines(DIRNAME / 'nc_edges.txt')]
    statement_entities = _read_lines(DIRNAME / 'nc_entities.txt')
    statement_predicates = _read_lines(DIRNAME / 'nc_rels.txt')

    if subtype == "triples":
        task = "so"

    # Each label file is a JSON object mapping an entity to a list of its labels.
    train_labels = _read_json(DIRNAME / f'nc_train_{task}_labels.json')
    val_labels = _read_json(DIRNAME / f'nc_val_{task}_labels.json')
    test_labels = _read_json(DIRNAME / f'nc_test_{task}_labels.json')

    # load node features with the total index
    entity_index = {entity: i for i, entity in
                    enumerate(_read_lines(f'./data/clean/{name}/statements/{name}_entity_index.txt'))}
    # SECURITY NOTE: allow_pickle=True unpickles arbitrary objects -- only load trusted files.
    total_node_features = np.load(f'./data/clean/{name}/statements/{name}_embs.pkl', allow_pickle=True)

    # Select the feature rows of the entities of this dataset, in nc_entities.txt order.
    # NOTE(review): rows follow statement_entities, while the ids in `entoid` below are
    # shifted by one because of the '__na__' slot -- row i belongs to entity id i + 1.
    idx = np.array([entity_index[key] for key in statement_entities], dtype='int32')
    node_features = total_node_features[idx]

    # Index 0 of both vocabularies is reserved for the '__na__' placeholder.
    st_entities = ['__na__'] + statement_entities
    st_predicates = ['__na__'] + statement_predicates

    entoid = {ent: i for i, ent in enumerate(st_entities)}
    prtoid = {rel: i for i, rel in enumerate(st_predicates)}

    # Even positions hold entities, odd positions hold relations.
    # NOTE: compare with `==`; `is 0` only worked through CPython's small-int cache.
    graph = [[entoid[uri] if pos % 2 == 0 else prtoid[uri] for pos, uri in enumerate(st)]
             for st in edges]

    if subtype != "triples":
        graph = _pad_statements_(graph, maxlen)

    # NOTE: de-duplication via remove_dups is disabled for this loader.

    # The "masks" are lists with the ids of the labelled entities of each split.
    train_mask = [entoid[e] for e in train_labels]
    val_mask = [entoid[e] for e in val_labels]
    test_mask = [entoid[e] for e in test_labels]

    # Label vocabulary over all three splits, sorted for a deterministic id assignment.
    all_labels = sorted({
        label
        for split in (train_labels, val_labels, test_labels)
        for labels in split.values()
        for label in labels
    })
    label2id = {l: i for i, l in enumerate(all_labels)}
    id2label = {v: k for k, v in label2id.items()}

    def _encode_labels(split):
        return {entoid[k]: [label2id[vi] for vi in v] for k, v in split.items()}

    train_y = _encode_labels(train_labels)
    val_y = _encode_labels(val_labels)
    test_y = _encode_labels(test_labels)

    return {"train_mask": train_mask, "valid_mask": val_mask, "test_mask": test_mask,
            "train_y": train_y, "val_y": val_y, "test_y": test_y,
            "all_labels": all_labels, "label2id": label2id, "id2label": id2label,
            "n_entities": len(st_entities), "n_relations": len(st_predicates),
            "e2id": entoid, "r2id": prtoid, "graph": graph, "features": node_features
            }

def to_sparse_graph(edges, subtype, entoid, prtoid, maxlen):
    """
    Turn string statements into index arrays: edge index, edge types and qualifiers.

    :param edges: list of statements, each a list laid out as
        [subject, relation, object, qual_rel_1, qual_ent_1, qual_rel_2, qual_ent_2, ...]
    :param subtype: qualifiers are collected only when this equals 'statements'
    :param entoid: mapping entity -> integer id
    :param prtoid: mapping relation -> integer id
    :param maxlen: max statement length; at most ``(maxlen - 3) // 2`` qualifier pairs
        are kept per statement (none when ``maxlen`` is below 5)
    :return: tuple ``(edge_index, edge_type, quals)`` where ``edge_index`` is an int64
        array of shape (2, len(edges)) with subject ids in row 0 and object ids in
        row 1, ``edge_type`` is an int64 array of shape (len(edges),) and ``quals`` is
        ``None`` unless ``subtype == 'statements'``, in which case it is an int64 array
        of shape (3, n_qualifiers) whose rows are qualifier relation id, qualifier
        entity id and the index of the edge the qualifier belongs to
    """
    num_edges = len(edges)
    edge_index = np.zeros((2, num_edges), dtype='int64')
    edge_type = np.zeros(num_edges, dtype='int64')

    with_quals = subtype == 'statements'
    # Clamp at zero: a negative bound would act as a "drop the last item" slice.
    max_quals = max((maxlen - 3) // 2, 0)

    qualifier_rel = []
    qualifier_ent = []
    qualifier_edge = []

    for i, st in enumerate(edges):
        edge_index[:, i] = [entoid[st[0]], entoid[st[2]]]
        edge_type[i] = prtoid[st[1]]

        if not with_quals:
            continue

        # Look up every qualifier first (unknown keys still raise KeyError), then
        # cut to the max allowed qualifiers per statement.
        rels = [prtoid[r] for r in st[3::2]][:max_quals]
        ents = [entoid[e] for e in st[4::2]][:max_quals]
        # zip stops at the shorter list, so a dangling relation without entity is skipped.
        for rel, ent in zip(rels, ents):
            qualifier_rel.append(rel)
            qualifier_ent.append(ent)
            qualifier_edge.append(i)

    quals = None
    if with_quals:
        # Explicit dtype keeps the array int64 even when there are no qualifiers at all
        # (stacking empty lists would otherwise produce a float64 array).
        quals = np.array([qualifier_rel, qualifier_ent, qualifier_edge], dtype='int64')

    return edge_index, edge_type, quals


def load_clean_pyg(name, subtype, task, inductive="transductive", ind_v=None, maxlen=43, permute=False) -> Dict:
    """
    Load a WD50K node-classification dataset as torch_geometric ``Data`` graphs.

    :param name: dataset name wd50k/wd50k_33/wd50k_66/wd50k_100
    :param subtype: triples/statements
    :param task: so/full predict entities at sub/obj positions (for triples/statements) or all nodes incl quals;
        always overridden with "so" when ``subtype`` is "triples"
    :param inductive: whether to load transductive dataset (one graph for train/val/test) or inductive
    :param ind_v: v1 / v2 for the inductive dataset
    :param maxlen: max statement length
    :param permute: if True, the rows of the node feature matrix are randomly permuted
        (uses the global numpy random state)
    :return: train/valid/test splits for the wd50k datasets suitable for loading into TORCH GEOMETRIC dataset
    no reciprocal edges (as will be added in the gnn layer), create directly the edge index.
    The dict holds "train_graph"/"val_graph"/"test_graph" (the latter two are None in the
    transductive setting), the train/valid/test masks (lists of labelled entity ids), the
    label dicts "train_y"/"val_y"/"test_y", the label vocabulary, the vocabulary sizes and
    the "e2id"/"r2id" lookup tables
    :raises AssertionError: if one of the arguments is not among the accepted values
    """
    assert name in ['wd50k', 'wd50k_100', 'wd50k_33', 'wd50k_66'], "Incorrect dataset"
    assert subtype in ["triples", "statements"], "Incorrect subtype: triples/statements"
    assert inductive in ["transductive", "inductive"], "Incorrect ds type: only transductive and inductive accepted"
    is_inductive = inductive == "inductive"
    if is_inductive:
        assert ind_v in ["v1", "v2"], "Only v1 and v2 are allowed versions for the inductive task"

    def _read_lines(path):
        # Context manager so the handle is closed deterministically.
        with open(path, 'r') as f:
            return [line.strip("\n") for line in f]

    def _read_edges(path):
        return [line.split(",") for line in _read_lines(path)]

    def _read_json(path):
        with open(path, 'r') as f:
            return json.load(f)

    def _report_quals(prefix, edge_list):
        # Statements longer than a plain triple carry at least one qualifier.
        with_quals = sum(1 for t in edge_list if len(t) > 3)
        # Guard so that an empty edge file does not crash a purely diagnostic print.
        ratio = round((with_quals / len(edge_list)), 2) if edge_list else 0.0
        print(f"{prefix}: With quals: {with_quals} / {len(edge_list)}, Ratio: {ratio}")

    if not is_inductive:
        DIRNAME = Path(f'./data/clean/{name}/{subtype}')
        train_edges = _read_edges(DIRNAME / 'nc_edges.txt')
        _report_quals("Transductive", train_edges)
    else:
        DIRNAME = Path(f'./data/clean/{name}/inductive/nc/{subtype}/{ind_v}')
        train_edges = _read_edges(DIRNAME / 'nc_train_edges.txt')
        val_edges = _read_edges(DIRNAME / 'nc_val_edges.txt')
        test_edges = _read_edges(DIRNAME / 'nc_test_edges.txt')
        _report_quals("Inductive train", train_edges)
        _report_quals("Inductive val", val_edges)
        _report_quals("Inductive test", test_edges)

    statement_entities = _read_lines(DIRNAME / 'nc_entities.txt')
    statement_predicates = _read_lines(DIRNAME / 'nc_rels.txt')

    if subtype == "triples":
        task = "so"

    # Each label file is a JSON object mapping an entity to a list of its labels.
    train_labels = _read_json(DIRNAME / f'nc_train_{task}_labels.json')
    val_labels = _read_json(DIRNAME / f'nc_val_{task}_labels.json')
    test_labels = _read_json(DIRNAME / f'nc_test_{task}_labels.json')

    # load node features with the total index
    entity_index = {entity: i for i, entity in
                    enumerate(_read_lines(f'./data/clean/{name}/statements/{name}_entity_index.txt'))}
    # SECURITY NOTE: allow_pickle=True unpickles arbitrary objects -- only load trusted files.
    total_node_features = np.load(f'./data/clean/{name}/statements/{name}_embs.pkl', allow_pickle=True)

    # Select the feature rows of this dataset's entities in nc_entities.txt order, so
    # that row i of node_features belongs to the entity with id i in `entoid` below.
    idx = np.array([entity_index[key] for key in statement_entities], dtype='int32')
    node_features = total_node_features[idx]

    if permute:
        node_features = np.random.permutation(node_features)

    # No '__na__' slot here: ids are plain positions in the entity / relation files.
    entoid = {ent: i for i, ent in enumerate(statement_entities)}
    prtoid = {rel: i for i, rel in enumerate(statement_predicates)}

    print(f"Total Entities: {len(entoid)}")
    print(f"Total Rels: {len(prtoid)}")

    train_edge_index, train_edge_type, train_quals = to_sparse_graph(train_edges, subtype, entoid, prtoid, maxlen)
    if is_inductive:
        val_edge_index, val_edge_type, val_quals = to_sparse_graph(val_edges, subtype, entoid, prtoid, maxlen)
        test_edge_index, test_edge_type, test_quals = to_sparse_graph(test_edges, subtype, entoid, prtoid, maxlen)

    # The "masks" are lists with the ids of the labelled entities of each split.
    train_mask = [entoid[e] for e in train_labels]
    val_mask = [entoid[e] for e in val_labels]
    test_mask = [entoid[e] for e in test_labels]

    if is_inductive:
        print(f"Train Ents: {len(train_labels)}, Val Ents: {len(val_labels)}, Test Ents: {len(test_labels)}")

    # Label vocabulary over all three splits, sorted for a deterministic id assignment.
    all_labels = sorted({
        label
        for split in (train_labels, val_labels, test_labels)
        for labels in split.values()
        for label in labels
    })
    label2id = {l: i for i, l in enumerate(all_labels)}
    id2label = {v: k for k, v in label2id.items()}
    print(f"Total labels: {len(label2id)}")

    def _encode_labels(split):
        return {entoid[k]: [label2id[vi] for vi in v] for k, v in split.items()}

    train_y = _encode_labels(train_labels)
    val_y = _encode_labels(val_labels)
    test_y = _encode_labels(test_labels)

    def _build_graph(edge_index, edge_type, quals, labels):
        # Every graph gets its own copy of the (shared) node feature matrix.
        return Data(x=torch.tensor(node_features, dtype=torch.float),
                    edge_index=torch.tensor(edge_index, dtype=torch.long),
                    edge_type=torch.tensor(edge_type, dtype=torch.long),
                    quals=torch.tensor(quals, dtype=torch.long) if quals is not None else None,
                    y=labels)

    train_graph = _build_graph(train_edge_index, train_edge_type, train_quals, train_y)
    val_graph, test_graph = None, None

    if is_inductive:
        val_graph = _build_graph(val_edge_index, val_edge_type, val_quals, val_y)
        test_graph = _build_graph(test_edge_index, test_edge_type, test_quals, test_y)

    return {"train_graph": train_graph, "val_graph": val_graph, "test_graph": test_graph,
            "train_mask": train_mask, "valid_mask": val_mask, "test_mask": test_mask,
            "train_y": train_y, "val_y": val_y, "test_y": test_y,
            "all_labels": all_labels, "label2id": label2id, "id2label": id2label,
            "n_entities": len(statement_entities), "n_relations": len(statement_predicates),
            "e2id": entoid, "r2id": prtoid}



if __name__ == "__main__":
    # Manual smoke run: load the inductive (v2) WD50K statements split.
    smoke_options = dict(inductive="inductive", ind_v='v2', maxlen=15)
    data = load_clean_pyg('wd50k', 'statements', 'so', **smoke_options)
    print("nop")
    # Alternative check, currently disabled:
    # count_stats(load_clean_wd50k("wd50k","statements",43))