from decoder.decoder import Decoder
import torch
import torch.nn as nn

from kdtree import kdtree_batch_knn

class DotProductDecoder(nn.Module, Decoder):
    """Decode node embeddings into a (num_nodes, num_nodes) score matrix.

    Each score is the inner product of two node embeddings. Which node pairs
    get scored is chosen by ``config.construction.method``:

    * ``None`` (or ``"none"``): every pair, via a dense ``z @ z.T``.
    * ``"random"``: a random sample of off-diagonal pairs whose size is set
      by ``config.construction.sample_ratio``. Every other entry stays 0.
    * ``"kdtree"``: each node paired with the neighbours returned by
      ``kdtree_batch_knn``. Every other entry stays 0.

    If ``config.construction.sigmoid`` is truthy, ``torch.sigmoid`` is applied
    to the whole matrix before it is returned.
    """

    def __init__(self, config):
        super().__init__()
        # Whether construct() passes the scores through a sigmoid.
        self.sigmoid = config.construction.sigmoid
        # Construction sub-config. construct() reads method, sample_ratio,
        # topk and avg_degree from it at call time.
        self.cons_config = config.construction

    def construct(self, z, topk=None):
        """Build the score matrix from node embeddings ``z``.

        Args:
            z: 2-D tensor of node embeddings, one row per node.
            topk: Neighbour count for the ``"kdtree"`` method. ``None``
                (the default) falls back to ``config.construction.topk``.
                ``"auto"`` means twice ``int(config.construction.avg_degree)``.
                Other methods ignore it.

        Returns:
            A ``(num_nodes, num_nodes)`` tensor of scores, sigmoided if
            configured. The dense method keeps ``z``'s dtype. The
            ``"random"`` and ``"kdtree"`` methods return ``torch.double``.

        Raises:
            ValueError: If ``config.construction.method`` is not recognised.
        """
        method = self.cons_config.method
        num_nodes = z.shape[0]

        if method is None or method == "none":
            adj_constructed = torch.matmul(z, z.t())
        elif method == "random":
            adj_constructed = self._construct_random(z, num_nodes)
        elif method == "kdtree":
            adj_constructed = self._construct_kdtree(z, num_nodes, topk)
        else:
            raise ValueError(
                f"Unknown construction method {method!r}; "
                "expected None, 'random' or 'kdtree'"
            )

        if self.sigmoid:
            # NOTE(review): for the sparse methods, unscored entries are 0 and
            # therefore become sigmoid(0) = 0.5 here. Confirm this is intended.
            adj_constructed = torch.sigmoid(adj_constructed)

        return adj_constructed

    def _construct_random(self, z, num_nodes):
        """Score a random sample of off-diagonal (row, col) pairs; others stay 0."""
        sample_ratio = self.cons_config.sample_ratio
        sample_size = int(sample_ratio * (num_nodes * num_nodes))

        # Sample positions in the row-major list of off-diagonal pairs. Row r
        # has num_nodes - 1 entries: every column except r. This yields exactly
        # the pairs that permuting an explicit cartesian-product list would,
        # but without materialising all n^2 index pairs.
        sampled = torch.randperm(num_nodes * (num_nodes - 1))[:sample_size]
        # max(..., 1) avoids dividing by zero when num_nodes <= 1. In that
        # case the sample is empty anyway.
        per_row = max(num_nodes - 1, 1)
        row_indices = sampled // per_row
        col_offset = sampled % per_row
        # Skip the diagonal: offsets at or past the row index shift right by one.
        col_indices = col_offset + (col_offset >= row_indices).long()

        dot_products = (z[row_indices] * z[col_indices]).sum(dim=1)

        adj_constructed = torch.zeros(
            (num_nodes, num_nodes), device=z.device, dtype=torch.double
        )
        # Advanced-index assignment requires matching dtypes, so cast the
        # scores to the output dtype (z may be float32, for example).
        dot_products = dot_products.to(adj_constructed.dtype)
        adj_constructed[row_indices, col_indices] = dot_products
        adj_constructed[col_indices, row_indices] = dot_products
        return adj_constructed

    def _resolve_topk(self, topk):
        """Return the neighbour count: explicit ``topk``, else the config value."""
        if topk is None:
            topk = self.cons_config.topk
        # "auto" means 2x the configured average degree.
        if topk == "auto":
            topk = int(self.cons_config.avg_degree) * 2
        return topk

    def _construct_kdtree(self, z, num_nodes, topk):
        """Score each node against its k-d-tree neighbours; others stay 0."""
        k = self._resolve_topk(topk)

        # Presumably neighbors[i] holds the indices of node i's k nearest
        # neighbours, possibly including i itself. TODO: confirm against
        # kdtree_batch_knn.
        neighbors = kdtree_batch_knn(z, k)

        adj_constructed = torch.zeros(
            (num_nodes, num_nodes), device=z.device, dtype=torch.double
        )
        for i in range(num_nodes):
            for raw_j in neighbors[i]:
                # Plain int keeps the assignment on the basic-indexing path
                # whether neighbors yields Python ints, NumPy ints or 0-d tensors.
                j = int(raw_j)
                if i == j:
                    continue
                # Cast explicitly so writing into the float64 output is valid
                # for any z dtype.
                score = torch.dot(z[i], z[j]).to(adj_constructed.dtype)
                adj_constructed[i, j] = score
                adj_constructed[j, i] = score
        return adj_constructed