import os
import json

import torch

# Module-level caches, each filled on first use by the matching ``load_*``
# function below and reused afterwards. They stay plain module globals so
# code that reads e.g. ``module.text_dict`` directly keeps working.
structure_dict = None
text_dict = None
ind_text_dict = None
node_text_dict = None
ind_node_text_dict = None
relation_text_dict = None
ind_relation_text_dict = None

# Global cache name -> absolute path of the file it was loaded from. Lets a
# call with a different ``base_data_path`` reload instead of silently
# returning data cached from an earlier path.
_cache_sources = {}


def _load_json(path):
    """Read *path* as UTF-8 text and return the decoded JSON value."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def _load_cached(name, base_data_path, filename, loader):
    """Return the module global *name*, loading it from disk when needed.

    ``loader(os.path.join(base_data_path, filename))`` is called and its
    result stored in the global *name* when that global is ``None``, or when
    this helper previously filled it from a different file. A value assigned
    to the global from outside (no recorded source path) is returned as is.
    """
    path = os.path.join(base_data_path, filename)
    # Normalise so "data" and "data/" (or "./data") count as the same source.
    source_key = os.path.abspath(path)
    module_globals = globals()
    cached = module_globals[name]
    previous_source = _cache_sources.get(name)
    if cached is None or (previous_source is not None and previous_source != source_key):
        cached = loader(path)
        module_globals[name] = cached
        _cache_sources[name] = source_key
    return cached


def load_structure_dict(base_data_path):
    """Return the object stored in ``structure.pt`` under *base_data_path*.

    The file is read with ``torch.load`` and cached in ``structure_dict``.
    """
    # NOTE(review): torch.load unpickles the file, so only point this at
    # trusted data. Newer torch releases default to ``weights_only=True``,
    # which may reject non-tensor objects; confirm against the saved format.
    return _load_cached('structure_dict', base_data_path, 'structure.pt', torch.load)


def load_text_dict(base_data_path):
    """Return the decoded ``text.json`` under *base_data_path* (cached in ``text_dict``)."""
    return _load_cached('text_dict', base_data_path, 'text.json', _load_json)


def load_ind_text_dict(base_data_path):
    """Return the decoded ``ind_text.json`` under *base_data_path* (cached in ``ind_text_dict``)."""
    return _load_cached('ind_text_dict', base_data_path, 'ind_text.json', _load_json)


def load_node_text_dict(base_data_path):
    """Return the decoded ``node_text.json`` under *base_data_path* (cached in ``node_text_dict``)."""
    return _load_cached('node_text_dict', base_data_path, 'node_text.json', _load_json)


def load_ind_node_text_dict(base_data_path):
    """Return the decoded ``ind_node_text.json`` under *base_data_path* (cached in ``ind_node_text_dict``)."""
    return _load_cached('ind_node_text_dict', base_data_path, 'ind_node_text.json', _load_json)


def load_relation_text_dict(base_data_path):
    """Return the decoded ``relation_text.json`` under *base_data_path* (cached in ``relation_text_dict``)."""
    return _load_cached('relation_text_dict', base_data_path, 'relation_text.json', _load_json)


def load_ind_relation_text_dict(base_data_path):
    """Return the decoded ``ind_relation_text.json`` under *base_data_path* (cached in ``ind_relation_text_dict``)."""
    return _load_cached('ind_relation_text_dict', base_data_path, 'ind_relation_text.json', _load_json)



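# NOTE: The embedding cache and whitening helpers below are disabled.
# Re-enabling them also requires `import numpy as np` and
# `import torch.nn.functional as F`, which this module does not import.
# node_embeddings = None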
# whitening_node_embeddings = None
# relation_embeddings = None
# whitening_relation_embeddings = None
# ind_node_embeddings = None
# whitening_ind_node_embeddings = None
# ind_relation_embeddings = None
# whitening_ind_relation_embeddings = None


# def _whitening(emb, whitening_dim):
#     emb = emb.numpy()
        
#     mean = np.mean(emb, axis=0, keepdims=True)
#     cov = np.cov(emb.T)
#     u, s, vh = np.linalg.svd(cov)
#     kernel, bias = np.dot(u, np.diag(1. / np.sqrt(s))), -mean
#     kernel = kernel[:, :whitening_dim]
#     emb_whitening = (emb + bias).dot(kernel)
        
#     emb_whitening = torch.from_numpy(emb_whitening)
#     emb_whitening = F.normalize(emb_whitening, p=2, dim=-1)
#     return emb_whitening


# def get_node_embeddings(ind, is_gnn=False):
#     if ind:
#         if is_gnn:
#             return whitening_ind_node_embeddings
#         return ind_node_embeddings
#     else:
#         if is_gnn:
#             return whitening_node_embeddings
#         return node_embeddings

# def get_relation_embeddings(ind, is_gnn=False):
#     if ind:
#         if is_gnn:
#             return whitening_ind_relation_embeddings
#         return ind_relation_embeddings
#     else:
#         if is_gnn:
#             return whitening_relation_embeddings
#         return relation_embeddings
    
# def set_node_embeddings(embeddings, ind, whitening_dim):
#     embeddings = embeddings.cpu()
    
#     if ind:
#         global ind_node_embeddings
#         global whitening_ind_node_embeddings
#         ind_node_embeddings = embeddings
#         whitening_ind_node_embeddings = _whitening(embeddings, whitening_dim)
#     else:
#         global node_embeddings
#         global whitening_node_embeddings
#         node_embeddings = embeddings
#         whitening_node_embeddings = _whitening(embeddings, whitening_dim)
        
# def set_relation_embeddings(embeddings, ind, whitening_dim):
#     embeddings = embeddings.cpu()
    
#     if ind:
#         global ind_relation_embeddings
#         global whitening_ind_relation_embeddings
#         ind_relation_embeddings = embeddings
#         whitening_ind_relation_embeddings = _whitening(embeddings, whitening_dim)
#     else:
#         global relation_embeddings
#         global whitening_relation_embeddings
#         relation_embeddings = embeddings
#         whitening_relation_embeddings = _whitening(embeddings, whitening_dim)