import torch
import torch.nn as nn
import torch.nn.functional as F


class Graph_Remover(nn.Module):
    """Zero out sampled entries of an adjacency matrix using learnable logits.

    ``B`` stores ``K`` independent ``[n, n]`` logit tables. ``forward`` selects
    table ``k``, applies a softmax over all ``n * n`` positions and samples
    (without replacement) the positions that get zeroed in ``adj``.
    """

    def __init__(self, K, n, device):
        """Create the logit tables.

        Args:
            K: Number of independent logit tables.
            n: Side length of each (square) logit table.
            device: Device on which the parameter is allocated.
        """
        super().__init__()
        self.B = nn.Parameter(torch.empty(K, n, n, device=device))
        self.device = device
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise the logits from a normal distribution (mean 0, std 0.1)."""
        nn.init.normal_(self.B, mean=0.0, std=0.1)

    def forward(self, adj, p_remove, k):
        """Sample positions and zero them out in ``adj``.

        Sampling runs over every one of the ``n * n`` positions, not only over
        positions where ``adj`` is non-zero, so a sampled position may already
        be zero.

        Args:
            adj: Square matrix; its first dimension gives ``n`` and it is cast
                to float32.
            p_remove: Fraction of positions to sample;
                ``int(p_remove * n * n)`` positions are drawn.
            k: Index of the logit table in ``self.B`` to use.

        Returns:
            A tuple ``(adj, log_p)``. ``adj`` is a new float32 matrix with the
            sampled positions set to zero. ``log_p`` is the log of the total
            softmax mass on the sampled positions (not the exact joint
            log-probability of a without-replacement draw). When no position
            is drawn, ``log_p`` is 0 and ``adj`` is returned unmasked.
        """
        logits = self.B[k].view(-1)  # [n*n]
        n = adj.size(0)
        # Follow the parameter's current device; the constructor argument is
        # stale once the module has been moved with ``.to(...)``.
        device = logits.device

        adj = adj.to(torch.float32)

        num_edges_to_remove = int(p_remove * n * n)
        if num_edges_to_remove == 0:
            # Guard: multinomial is not asked for zero samples, and logsumexp
            # over an empty selection would be -inf. Nothing is removed, so
            # log_p is 0; it is derived from the logits so that it stays
            # attached to the autograd graph.
            sampled_idx = torch.empty(0, dtype=torch.long, device=device)
            log_p = logits.sum() * 0.0
        else:
            probs = torch.softmax(logits, dim=0)  # [n*n]
            sampled_idx = torch.multinomial(
                probs, num_samples=num_edges_to_remove, replacement=False
            )  # [num_edges_to_remove]
            log_p = torch.logsumexp(logits[sampled_idx], dim=0) - torch.logsumexp(logits, dim=0)

        # Deletion mask: 1 marks a position to zero out.
        mask = torch.zeros(n * n, device=device)
        mask[sampled_idx] = 1.0
        mask = mask.view(n, n)

        adj = adj * (1 - mask)

        return adj, log_p


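# NOTE: Disabled alternative implementation of Graph_Remover, kept for reference.
# It samples an independent Bernoulli deletion (sigmoid of the logit) for each
# existing edge instead of drawing a fixed number of positions.
# class Graph_Remover(nn.Module):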
#     def __init__(self, K, n, device):
#         super(Graph_Remover, self).__init__()
#         self.B = nn.Parameter(torch.FloatTensor(K, n, n).to(device))  # Learnable weights
#         self.device = device
#         self.reset_parameters()
#
#     def reset_parameters(self):
#         nn.init.normal_(self.B, mean=0.0, std=0.1)
#
#     def forward(self, adj, p_remove, k):
#         Bk = self.B[k]  # [n, n]
#         n = adj.size(0)
#         device = self.device
#         adj = adj.to(torch.float32)
#
#         # Find real edges in the graph
#         edge_indices = (adj > 0).nonzero(as_tuple=False)  # [E, 2]
#         if edge_indices.size(0) == 0:
#             return adj, torch.tensor(0.0, device=device)
#
#         # Get Bk values for corresponding edges
#         flat_indices = edge_indices[:, 0] * n + edge_indices[:, 1]
#         Bk_flat = Bk.view(-1)
#         Bk_edges = Bk_flat[flat_indices]  # [E]
#
#         # Each edge independently gets a deletion probability
#         P = torch.sigmoid(Bk_edges)
#
#         # Sample delete_mask
#         delete_mask = torch.bernoulli(P).to(device)
#
#         # Build deletion mask matrix
#         M = torch.zeros(n, n, device=device)
#         M[edge_indices[:, 0], edge_indices[:, 1]] = delete_mask
#
#         # Execute deletion
#         adj = adj * (1 - M)
#
#         # Complete log_prob, can be used for REINFORCE or other policy gradient methods
#         log_p = torch.sum(
#             delete_mask * torch.log(P + 1e-8) +
#             (1 - delete_mask) * torch.log(1 - P + 1e-8)
#         )
#
#         return adj, log_p


class Graph_Adder(nn.Module):
    """Add weighted entries to an adjacency matrix using learnable logits.

    ``B`` holds the logits that decide *which* positions are selected and
    ``A`` holds the logits that decide *how much* weight a selected position
    receives (softmax over all ``n * n`` entries).
    """

    def __init__(self, n, device):
        """Create the selection and weight logits.

        Args:
            n: Side length of the (square) logit tables.
            device: Device on which the parameters are allocated.
        """
        super().__init__()
        self.B = nn.Parameter(torch.empty(n, n, device=device))  # Selection logits
        self.A = nn.Parameter(torch.empty(n, n, device=device))  # Weight logits
        self.device = device
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise both logit tables uniformly in [0, 1)."""
        nn.init.uniform_(self.B)
        nn.init.uniform_(self.A)

    def _add_edges(self, adj, selected_idx):
        """Add softmax(A) weights to ``adj`` at the flat positions ``selected_idx``."""
        n = adj.size(0)
        weights = torch.softmax(self.A.view(-1), dim=0).view(n, n)

        # Addition mask: 1 marks a selected position. Allocated on the
        # parameter's current device, which stays correct after ``.to(...)``.
        mask = torch.zeros(n * n, device=self.B.device)
        mask[selected_idx] = 1.0
        mask = mask.view(n, n)

        # The increment is scaled by (1 - adj), so positions where adj is
        # already 1 are left unchanged.
        adj = adj + mask * (1 - adj) * weights
        adj.fill_diagonal_(0)  # never keep self-loops
        return adj

    def forward(self, adj, p_add):
        """Sample positions and add weighted entries there.

        Args:
            adj: Square matrix; its first dimension gives ``n`` and it is cast
                to float32.
            p_add: Fraction of positions to sample; ``int(p_add * n * n)``
                positions are drawn without replacement from softmax(B).

        Returns:
            A tuple ``(adj, log_p)``. ``adj`` is a new float32 matrix with the
            additions applied and a zeroed diagonal. ``log_p`` is the log of
            the total softmax mass on the sampled positions (not the exact
            joint log-probability of a without-replacement draw). When no
            position is drawn, ``log_p`` is 0.
        """
        logits = self.B.view(-1)
        n = adj.size(0)
        adj = adj.to(torch.float32)

        num_edges_to_add = int(p_add * n * n)
        if num_edges_to_add == 0:
            # Guard: multinomial is not asked for zero samples, and logsumexp
            # over an empty selection would be -inf. This matches
            # deterministic_forward, which already handles zero selections.
            # log_p is derived from the logits so that it stays attached to
            # the autograd graph.
            sampled_idx = torch.empty(0, dtype=torch.long, device=logits.device)
            log_p = logits.sum() * 0.0
        else:
            probs = torch.softmax(logits, dim=0)
            sampled_idx = torch.multinomial(
                probs, num_samples=num_edges_to_add, replacement=False
            )  # [num_edges_to_add]
            log_p = torch.logsumexp(logits[sampled_idx], dim=0) - torch.logsumexp(logits, dim=0)

        return self._add_edges(adj, sampled_idx), log_p

    def deterministic_forward(self, adj, p_add):
        """Add weighted entries at the highest-probability positions.

        Same update as ``forward`` but the ``int(p_add * n * n)`` positions
        are the top-k entries of softmax(B) instead of a random sample.

        Args:
            adj: Square matrix; its first dimension gives ``n`` and it is cast
                to float32.
            p_add: Fraction of positions to select.

        Returns:
            A new float32 matrix with the additions applied and a zeroed
            diagonal.
        """
        n = adj.size(0)
        adj = adj.to(torch.float32)

        probs = torch.softmax(self.B.view(-1), dim=0)
        num_edges_to_add = int(p_add * n * n)

        # Select the positions with the highest probabilities.
        _, topk_idx = torch.topk(probs, num_edges_to_add)

        return self._add_edges(adj, topk_idx)