import torch
import torch.nn as nn
import torch.nn.functional as F
import networkx as nx
import numpy as np
from decoder.decoder import Decoder

from kdtree import kdtree_batch_knn

class MLPDecoder(nn.Module, Decoder):
    """Two-layer MLP that scores node pairs from their embeddings.

    Each candidate edge ``(i, j)`` is described by four element-wise
    combinations of the two node embeddings (Hadamard product, absolute
    difference, squared difference, average), concatenated and passed through
    ``Linear -> ReLU -> Dropout -> Linear`` to produce one raw score per edge.
    """

    def __init__(self, config, embeddings):
        """Build the scorer.

        Args:
            config: Configuration object. Reads ``config.construction``,
                ``config.embedding_dim`` and
                ``config.decoder.mlp.train.tune_embeddings``.
            embeddings: Node embeddings. Wrapped in a trainable
                ``nn.Parameter`` when ``tune_embeddings`` is set, otherwise
                stored unchanged.
        """
        super().__init__()

        # Kept for callers that read it; chosen from CUDA availability only.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.cons_config = config.construction
        emb_size = config.embedding_dim

        if config.decoder.mlp.train.tune_embeddings == True:
            self.embeddings = torch.nn.Parameter(
                torch.tensor(embeddings, dtype=torch.float32, device=self.device),
                requires_grad=True,
            )
        else:
            self.embeddings = embeddings

        # Four embedding-sized feature groups are concatenated per edge.
        self.fc1 = torch.nn.Linear(4 * emb_size, 128)
        self.fc2 = torch.nn.Linear(128, 1)
        self.dropout = torch.nn.Dropout(p=0.5)

    def forward(self, edge_index, embeddings):
        """Score every edge in ``edge_index``.

        Args:
            edge_index: Pair of index sequences ``(sources, targets)``; it is
                unpacked into two and each is used to index ``embeddings``.
            embeddings: Node embeddings indexed along their first dimension.
                The features are concatenated along dim 1, so a 2-D
                ``(nodes, emb_size)`` layout is presumably expected.

        Returns:
            Flat tensor with one raw (un-activated) score per edge.
        """
        i1, i2 = edge_index
        e1, e2 = embeddings[i1], embeddings[i2]
        diff = e1 - e2

        edge_features = torch.cat(
            [
                e1 * e2,          # Hadamard
                torch.abs(diff),  # L1
                diff ** 2,        # L2
                (e1 + e2) / 2,    # Average
            ],
            dim=1,
        )

        x = self.fc1(edge_features)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)

        return x.view(-1)

    def construct(self, z, topk=None):
        """Build a dense matrix of predicted edge scores for the nodes in ``z``.

        Candidate edges are chosen according to ``self.cons_config.method``:
        ``None`` scores every ordered pair of distinct nodes, ``"random"``
        scores a random subset of those pairs, and ``"kdtree"`` scores each
        node against the neighbours returned by ``kdtree_batch_knn``.

        Args:
            z: Node embeddings; ``z.shape[0]`` is taken as the node count.
            topk: Currently unused. The "kdtree" branch derives its neighbour
                count from the construction config instead.
                NOTE(review): confirm whether callers expect this argument to
                override the config value.

        Returns:
            ``(num_nodes, num_nodes)`` tensor holding the raw score of each
            candidate edge and zero everywhere else.

        Raises:
            ValueError: If the configured method is not recognised.
        """
        method = self.cons_config.method  # candidate-edge strategy: None, "random" or "kdtree"
        num_nodes = z.shape[0]
        # Follow the layer weights rather than self.device so the inputs always
        # live where the module actually is.
        device = self.fc1.weight.device
        z = z.to(device).float()

        if method is None:
            edge_index = self._all_directed_pairs(num_nodes, device)
        elif method == "random":
            edge_index = self._random_pairs(num_nodes, device)
        elif method == "kdtree":
            edge_index = self._knn_pairs(z, num_nodes, device)
        else:
            raise ValueError(f"Unknown construction method: {method!r}")

        # NOTE(review): the module is left in eval mode after this call.
        self.eval()
        with torch.no_grad():
            predictions = self.forward(edge_index, z)

        adj_matrix = torch.zeros((num_nodes, num_nodes), device=device)
        adj_matrix[edge_index[0], edge_index[1]] = predictions

        return adj_matrix

    @staticmethod
    def _all_directed_pairs(num_nodes, device):
        """Return a ``(2, E)`` index tensor of all ordered pairs ``i != j``."""
        # reshape keeps the result two-dimensional even when there are no pairs.
        pairs = torch.combinations(torch.arange(num_nodes), r=2).reshape(-1, 2).T
        return torch.cat([pairs, pairs.flip(0)], dim=1).to(device)

    def _random_pairs(self, num_nodes, device):
        """Return a random subset of the ordered pairs, sized by ``sample_ratio``."""
        candidates = self._all_directed_pairs(num_nodes, device)
        num_candidates = candidates.shape[1]
        if num_candidates == 0:
            return candidates

        total_pairs = num_nodes * num_nodes
        requested = int(self.cons_config.sample_ratio * total_pairs)
        # Self-pairs are not candidates, so the request can exceed what is
        # available; sampling without replacement would then fail.
        sample_size = min(requested, num_candidates)

        indices = np.random.choice(num_candidates, sample_size, replace=False)
        columns = torch.as_tensor(indices, dtype=torch.long, device=device)
        return candidates[:, columns]

    def _knn_pairs(self, z, num_nodes, device):
        """Return ``(2, E)`` edges from each node to its ``kdtree_batch_knn`` neighbours."""
        cfg = self.cons_config
        k = int(cfg.avg_degree * 1.5) if cfg.topk == "auto" else cfg.topk

        neighbors = kdtree_batch_knn(z, k)

        edge_list = []
        for i in range(num_nodes):
            for j in neighbors[i]:
                j = int(j)
                if i != j:  # drop self-loops
                    edge_list.append((i, j))

        # reshape keeps the (2, E) layout when no edges were produced.
        return torch.tensor(edge_list, dtype=torch.long).reshape(-1, 2).T.to(device)
            
            

