from tqdm import tqdm
import torch

from utils_proof.graph import get_weight_matrix_2, get_ppr_graph, get_adjacency_matrix, max_singular_vectors

def encode_bayes_ppr2(self, train_graphs, test_graphs, arg):
    """Encode graphs with a PPR-mixed propagation and collect SVD statistics.

    Each graph's node labels are embedded with ``self.encoder``. The embeddings
    are then propagated ``h + 1`` times through a mix of the graph's weight
    matrix ``W`` and the adjacency matrix ``P`` of its PPR graph::

        Z <- (1 - alpha) * W.T @ Z + alpha * P.T @ Z

    ``max_singular_vectors`` is then applied to ``X.T @ Z``, where ``X`` holds
    the un-propagated embeddings.

    Args:
        self: Object providing ``encoder`` (with an ``encode(labels)`` method)
            and ``device``.
        train_graphs: List of graphs whose statistics form the training set.
        test_graphs: List of graphs whose ``X``/``Z`` features are returned.
        arg: Configuration string ``"<h>-<pooling>-<alpha>"``. ``h`` is an int,
            ``pooling`` must contain ``"mean"``, and ``alpha`` uses ``p`` as the
            decimal point (e.g. ``"2-mean-0p15"``).

    Returns:
        A ``(trains, tests)`` tuple:

        - ``trains`` is a dict whose ``'alpha2'`` and ``'beta2'`` entries are
          the stacked ``max_singular_vectors`` outputs of the training graphs.
        - ``tests`` is a list of ``{'X': ..., 'Z': ...}`` dicts (input and
          propagated embeddings), one per test graph.

    Raises:
        ValueError: If ``arg`` is not of the form ``"<h>-<pooling>-<alpha>"``,
            if ``h``/``alpha`` cannot be parsed, or if ``pooling`` does not
            request mean pooling.
    """
    graphs = train_graphs + test_graphs
    n_train = len(train_graphs)

    try:
        h, pooling, alpha = arg.split('-')
    except ValueError as err:
        raise ValueError(
            f"expected arg of the form '<h>-<pooling>-<alpha>', got {arg!r}"
        ) from err
    h = int(h)
    alpha = float(alpha.replace('p', '.'))

    # Only mean pooling is implemented. Fail fast instead of letting encode_g
    # return None and crash later when the per-graph results are stacked.
    if 'mean' not in pooling:
        raise ValueError(
            f"unsupported pooling {pooling!r}; only 'mean' is implemented"
        )

    # Feature width of the placeholder embedding used for empty graphs.
    # NOTE(review): assumed to match the encoder's output width -- confirm.
    n_feature = 768

    def encode_g(g):
        """Propagate one graph's node embeddings and return its statistics."""
        if g.number_of_nodes() == 0:
            # Work on a copy so the caller's graph is not mutated. Otherwise a
            # later call would see node 0 and encode it as a real label.
            g = g.copy()
            g.add_node(0)
            node_embs = torch.ones((1, n_feature), device=self.device)
        else:
            # Presumably (num_nodes, n_feature), row order matching g.nodes
            # -- TODO confirm against the encoder.
            node_embs = self.encoder.encode(list(g.nodes))
        X = node_embs

        W = get_weight_matrix_2(g, device=self.device, self_loop=True)
        P = get_adjacency_matrix(get_ppr_graph(g), device=self.device)

        # h + 1 propagation steps, each mixing W and the PPR adjacency P.
        # NOTE(review): h + 1 (not h) steps is kept from the original -- confirm intent.
        for _ in range(h + 1):
            node_embs = (1 - alpha) * W.T @ node_embs + alpha * P.T @ node_embs
        Z = node_embs

        alpha2, beta2 = max_singular_vectors(X.T @ Z)
        return {'alpha2': alpha2, 'beta2': beta2, 'X': X, 'Z': Z}

    output = [encode_g(g) for g in tqdm(graphs, total=len(graphs))]

    # Statistics are computed for every graph, but only the training
    # graphs' alpha2/beta2 are returned.
    alpha2s = torch.stack([x['alpha2'] for x in output], dim=0)
    beta2s = torch.stack([x['beta2'] for x in output], dim=0)

    trains = {'alpha2': alpha2s[:n_train], 'beta2': beta2s[:n_train]}
    tests = [{'X': x['X'], 'Z': x['Z']} for x in output[n_train:]]
    return trains, tests