import pickle
import torch
import numpy as np


# def get_sorted_s_r_embeddings(s, r, ent_embs, hist):
#     hist_len = torch.LongTen(list(map(hist, len)))
#     lens, idx = hist_len.sort(0, descending=True)
#     num_nonzero = len(torch.nonzero(lens))
#     nonzero_hist_lens, nonzero_idx = lens[:num_nonzero], idx[:num_nonzero]
#
#     s_sorted = s[nonzero_idx]
#     r_sorted = r[nonzero_idx]
#
#     all_s = []
#     all_lens = []
#     for idx in nonzero_idx:
#         hist = hist[idx.item()]
#         for neighbors in hist:
#             all_lens.append(neighbors)
#             for n in neighbors:
#                 all_s.append(n)
#
#     embeds = ent_embs[torch.LongTensor(all_s).cuda()]
#     embeds_split = torch.split(embeds, all_lens)
#
#     return embeds, embeds_split, s_sorted, r_sorted, nonzero_hist_lens, all_lens
#
#
from code_utils.config import PATHS


def get_mask(batch, unknow_ent_id):
    """Unfinished debugging stub -- not a working mask builder.

    NOTE(review): as written this function:
      * computes ``np.mean(np.asarray(batch), axis=2)`` and throws the result
        away (``batch`` must have at least 3 dimensions or NumPy raises);
      * never uses ``unknow_ent_id``;
      * starts an interactive IPython shell via ``IPython.embed()``, which
        would block any non-interactive (training/eval/CI) run;
      * returns None once the shell is exited.
    TODO: implement the real mask logic (or delete) before calling it from
    pipeline code.

    Args:
        batch: array-like, converted with ``np.asarray``.
        unknow_ent_id: currently unused.
    """
    np.mean(np.asarray(batch), axis=2)
    # Debug hook: drops into an interactive shell for manual inspection.
    import IPython;
    IPython.embed()


def get_sorted_s_r_embeddings(s, r, ent_embs, rel_embs, hist, device, unkonwn_rel):
    """Sort queries by history length and embed their (entity, relation) histories.

    Queries with an empty history are dropped. The remaining queries are
    ordered by number of history steps, longest first.

    Args:
        s: tensor of subject ids, one per query in the batch.
        r: tensor of relation ids, aligned with ``s``.
        ent_embs: callable (e.g. ``torch.nn.Embedding``) mapping a LongTensor
            of entity ids to a 2-D tensor of embeddings.
        rel_embs: callable mapping a LongTensor of relation ids to a 2-D
            tensor of embeddings.
        hist: pair ``(s_hist, r_hist)``. ``s_hist[i]`` is a list of history
            steps for query ``i``, each step being a list of neighbor entity
            ids; ``r_hist[i]`` holds the matching relation-id lists.
            Assumes ``r_hist`` mirrors ``s_hist`` step-for-step and
            id-for-id -- TODO confirm with callers (the final ``torch.cat``
            fails otherwise).
        device: torch device the inputs/outputs are moved to.
        unkonwn_rel: relation id marking an "unknown" relation.

    Returns:
        Tuple ``(embeds, embeds_split, s_sorted, r_sorted, nonzero_hist_lens,
        all_lens, all_r, batch_mask)``:
          * embeds: entity embeddings of all flattened neighbor ids
            concatenated (dim 1) with relation embeddings of all flattened
            relation ids.
          * embeds_split: ``embeds`` split along dim 0 into one chunk per
            history step (chunk sizes given by ``all_lens``).
          * s_sorted / r_sorted: ``s`` / ``r`` restricted to queries with a
            non-empty history, in descending history-length order.
          * nonzero_hist_lens: number of history steps per kept query.
          * all_lens: neighbor count of every step, flattened over kept queries.
          * all_r: flattened relation ids (same order as ``embeds`` rows).
          * batch_mask: one float NumPy array per kept query, one entry per
            step; 0 where every relation in the step equals ``unkonwn_rel``,
            1 otherwise.
    """
    s_hist = hist[0]
    r_hist = hist[1]

    # Order queries by history length (longest first); lengths are
    # non-negative, so the non-zero entries form a prefix after sorting.
    hist_len = torch.LongTensor([len(steps) for steps in s_hist])
    lens, order = hist_len.sort(0, descending=True)
    num_nonzero = int((lens != 0).sum())
    nonzero_hist_lens, nonzero_idx = lens[:num_nonzero], order[:num_nonzero]

    s_sorted = s[nonzero_idx].to(device)
    r_sorted = r[nonzero_idx].to(device)

    all_s = []
    all_r = []
    all_lens = []
    batch_mask = []
    # Use distinct loop names so the `r` parameter and the sort order are
    # never shadowed by per-step values.
    for query_idx in nonzero_idx.tolist():
        node_hist = s_hist[query_idx]
        for neighbors in node_hist:
            all_lens.append(len(neighbors))
            all_s.extend(neighbors)

        rel_hist = r_hist[query_idx]
        mask = np.ones(len(node_hist))
        for step, rels in enumerate(rel_hist):
            # A step whose relations are all the "unknown" id carries no signal.
            if len(set(rels)) == 1 and rels[0] == unkonwn_rel:
                mask[step] = 0
            all_r.extend(rels)
        batch_mask.append(mask)

    node_embeds = ent_embs(torch.LongTensor(all_s).to(device)).to(device)
    rel_embeds = rel_embs(torch.LongTensor(all_r).to(device)).to(device)
    embeds = torch.cat([node_embeds, rel_embeds], dim=1).to(device)
    embeds_split = torch.split(embeds, all_lens, dim=0)

    return embeds, embeds_split, \
           s_sorted, r_sorted, nonzero_hist_lens, all_lens, all_r, batch_mask


def euclidean_dist(x, y):
    """Return pairwise *squared* Euclidean distances between rows of x and y.

    Args:
        x: 2-D tensor of shape (N, D).
        y: 2-D tensor of shape (M, D); must share D with ``x``.

    Returns:
        (N, M) tensor whose (i, j) entry is ``sum_k (x[i, k] - y[j, k]) ** 2``.
        No square root is taken.
    """
    rows_x = x.size(0)
    rows_y = y.size(0)
    dim = x.size(1)
    assert dim == y.size(1)

    # Materialize both operands on a common (N, M, D) grid.
    grid = (rows_x, rows_y, dim)
    diff = x.unsqueeze(1).expand(*grid) - y.unsqueeze(0).expand(*grid)
    return (diff ** 2).sum(dim=2)



