import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.nn.models import InnerProductDecoder, VGAE
from torch_geometric.nn.conv import GCNConv, SAGEConv
from torch_geometric.utils import negative_sampling, remove_self_loops, add_self_loops

from random import random 

from utils import device

class GCNEncoder(nn.Module):
    """Graph-convolutional encoder with a shared trunk and two output heads.

    A single ``GCNConv`` layer followed by ReLU feeds two parallel
    ``GCNConv`` heads; their outputs are returned as ``(mu, logvar)``.

    NOTE(review): torch_geometric's ``VGAE`` appears to interpret the
    encoder's second output as a log standard deviation rather than a log
    variance -- confirm that the ``logvar`` naming is intentional.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # Layers are created in a fixed order (trunk, mu head, logvar head)
        # so parameter ordering and seeded initialisation stay stable.
        self.gcn_shared = GCNConv(in_channels, hidden_channels)
        self.gcn_mu = GCNConv(hidden_channels, out_channels)
        self.gcn_logvar = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Return the ``(mu, logvar)`` pair for features ``x`` on ``edge_index``."""
        hidden = self.gcn_shared(x, edge_index)
        hidden = F.relu(hidden)
        mean_head = self.gcn_mu(hidden, edge_index)
        logvar_head = self.gcn_logvar(hidden, edge_index)
        return mean_head, logvar_head


class GraphSAGEEncoder(nn.Module):
    """GraphSAGE encoder with a shared trunk and two output heads.

    A single ``SAGEConv`` layer followed by ReLU feeds two parallel
    ``SAGEConv`` heads; their outputs are returned as ``(mu, logvar)``.

    NOTE(review): torch_geometric's ``VGAE`` appears to interpret the
    encoder's second output as a log standard deviation rather than a log
    variance -- confirm that the ``logvar`` naming is intentional.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # Fixed creation order: trunk first, then the mu and logvar heads.
        self.sage_shared = SAGEConv(in_channels, hidden_channels)
        self.sage_mu = SAGEConv(hidden_channels, out_channels)
        self.sage_logvar = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Return the ``(mu, logvar)`` pair for features ``x`` on ``edge_index``."""
        hidden = self.sage_shared(x, edge_index)
        hidden = F.relu(hidden)
        mean_head = self.sage_mu(hidden, edge_index)
        logvar_head = self.sage_logvar(hidden, edge_index)
        return mean_head, logvar_head


class LinearEncoder(nn.Module):
    """Graph-free (MLP) encoder with a shared trunk and two output heads.

    One ``nn.Linear`` layer followed by ReLU feeds two parallel ``nn.Linear``
    heads; their outputs are returned as ``(mu, logvar)``.

    NOTE(review): torch_geometric's ``VGAE`` appears to interpret the
    encoder's second output as a log standard deviation rather than a log
    variance -- confirm that the ``logvar`` naming is intentional.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # Fixed creation order: trunk first, then the mu and logvar heads.
        self.lin_shared = nn.Linear(in_channels, hidden_channels)
        self.lin_mu = nn.Linear(hidden_channels, out_channels)
        self.lin_logvar = nn.Linear(hidden_channels, out_channels)

    def forward(self, x):
        """Return the ``(mu, logvar)`` pair for the feature matrix ``x``."""
        hidden = self.lin_shared(x)
        hidden = F.relu(hidden)
        mean_head = self.lin_mu(hidden)
        logvar_head = self.lin_logvar(hidden)
        return mean_head, logvar_head


class GCNDecoder(nn.Module):
    """Three-layer GCN decoder mapping a latent code back to feature space.

    Channel sizes run ``in_channels -> hidden_channels -> hidden_channels_2
    -> out_channels``. ReLU follows the first two layers; the last layer's
    output is returned without an activation.
    """

    def __init__(self, in_channels, hidden_channels, hidden_channels_2, out_channels):
        super().__init__()
        self.gcn_z = GCNConv(in_channels, hidden_channels)
        self.gcn_middle = GCNConv(hidden_channels, hidden_channels_2)
        self.gcn_org_size = GCNConv(hidden_channels_2, out_channels)

    def forward(self, x, edge_index):
        """Decode ``x`` over ``edge_index`` and return the final layer's output."""
        hidden = x
        # The two hidden layers share the same "conv then ReLU" pattern.
        for conv in (self.gcn_z, self.gcn_middle):
            hidden = F.relu(conv(hidden, edge_index))
        # Final projection is left linear (no activation).
        return self.gcn_org_size(hidden, edge_index)

class GraphSageDecoder(nn.Module):
    """Three-layer GraphSAGE decoder mapping a latent code back to feature space.

    Channel sizes run ``in_channels -> hidden_channels -> hidden_channels_2
    -> out_channels``. ReLU follows the first two layers; the last layer's
    output is returned without an activation.
    """

    def __init__(self, in_channels, hidden_channels, hidden_channels_2, out_channels):
        super().__init__()
        # The attributes carry a ``gcn_`` prefix although the layers are
        # SAGEConv; the names are left untouched so state-dict keys stay the same.
        self.gcn_z = SAGEConv(in_channels, hidden_channels)
        self.gcn_middle = SAGEConv(hidden_channels, hidden_channels_2)
        self.gcn_org_size = SAGEConv(hidden_channels_2, out_channels)

    def forward(self, x, edge_index):
        """Decode ``x`` over ``edge_index`` and return the final layer's output."""
        hidden = x
        # The two hidden layers share the same "conv then ReLU" pattern.
        for conv in (self.gcn_z, self.gcn_middle):
            hidden = F.relu(conv(hidden, edge_index))
        # Final projection is left linear (no activation).
        return self.gcn_org_size(hidden, edge_index)

class LinearDecoder(nn.Module):
    """Three-layer MLP decoder mapping a latent code back to feature space.

    Layer sizes run ``in_channels -> hidden_channels -> hidden_channels_2
    -> out_channels``. ReLU follows the first two layers; the last layer's
    output is returned without an activation.
    """

    def __init__(self, in_channels, hidden_channels, hidden_channels_2, out_channels):
        super().__init__()
        self.linear_1 = nn.Linear(in_channels, hidden_channels)
        self.linear_2 = nn.Linear(hidden_channels, hidden_channels_2)
        self.linear_3 = nn.Linear(hidden_channels_2, out_channels)

    def forward(self, x):
        """Decode ``x`` and return the final layer's output."""
        hidden = x
        # The two hidden layers share the same "linear then ReLU" pattern.
        for layer in (self.linear_1, self.linear_2):
            hidden = F.relu(layer(hidden))
        # Final projection is left linear (no activation).
        return self.linear_3(hidden)



class VAE_FeatureGen_GCN(VGAE):
    """VGAE that reconstructs node features using GCN encoder and decoder.

    The encoder maps ``enc_in_channels -> enc_hidden_channels ->
    enc_out_channels``; the decoder maps the latent code back through two
    ``enc_hidden_channels``-wide layers to ``enc_in_channels``.

    Args:
        enc_in_channels: Size of the input (and reconstructed) node features.
        enc_hidden_channels: Hidden width of the encoder and decoder layers.
        enc_out_channels: Size of the latent code.
        weight: Multiplier applied to the reconstruction (MSE) term in ``loss``.
    """

    def __init__(self, enc_in_channels, enc_hidden_channels, enc_out_channels, weight):
        super().__init__(
            encoder=GCNEncoder(enc_in_channels,
                               enc_hidden_channels,
                               enc_out_channels),
            decoder=GCNDecoder(enc_out_channels,
                               enc_hidden_channels,
                               enc_hidden_channels,
                               enc_in_channels),
        )
        self.weight = weight

    def forward(self, x, edge_index):
        """Encode ``x`` and decode it again; return ``(z, x_pred)``."""
        z = self.encode(x, edge_index)
        x_pred = self.decoder(z, edge_index)
        return z, x_pred

    def decode_latent(self, x, edge_index):
        """Run only the decoder on an already-latent ``x``."""
        return self.decoder(x, edge_index)

    def sparse_loss(self, activations, sparsity_level=0.05, beta=3):
        """Return the KL sparsity penalty on the mean activation per dimension.

        For each latent dimension the mean over ``dim=0`` (``rho_hat``) is
        compared to ``sparsity_level`` with the Bernoulli KL divergence; the
        per-dimension terms are summed and scaled by ``beta``.

        Args:
            activations: Tensor whose first dimension is averaged out.
            sparsity_level: Target mean activation; must lie in (0, 1).
            beta: Scale applied to the summed penalty.

        Returns:
            Scalar tensor holding the weighted penalty.
        """
        # The latent code is not bounded to (0, 1), so its mean can be <= 0
        # or >= 1, which would send the logarithms below to NaN/inf. Clamping
        # into the open interval keeps the penalty finite; values already
        # strictly inside the interval are left untouched.
        eps = torch.finfo(activations.dtype).eps
        rho_hat = torch.mean(activations, dim=0).clamp(min=eps, max=1 - eps)
        sparsity_penalty = sparsity_level * torch.log(sparsity_level / rho_hat) + \
            (1 - sparsity_level) * torch.log((1 - sparsity_level) / (1 - rho_hat))
        return beta * torch.sum(sparsity_penalty)

    def loss(self, x, edge_index, epoch, total_epochs, sparsity_level=0.05, beta=3):
        """Return the training loss: weighted MSE + annealed KL + sparsity penalty.

        Args:
            x: Node feature matrix to reconstruct.
            edge_index: Graph connectivity passed to encoder and decoder.
            epoch: Current epoch, used for the KL weight schedule.
            total_epochs: Total number of epochs, used for the KL weight schedule.
            sparsity_level: Forwarded to ``sparse_loss``.
            beta: Forwarded to ``sparse_loss``.
        """
        z = self.encode(x, edge_index)
        x_pred = self.decoder(z, edge_index)

        feat_loss = F.mse_loss(x_pred, x)

        # KL term of the most recent ``encode`` call, divided by the node count.
        kl_loss = 1 / x.size(0) * self.kl_loss()
        sp_loss = self.sparse_loss(z, sparsity_level, beta)

        # KL weight grows linearly with training progress and is capped at 0.1.
        kl_weight = min(0.1, epoch / total_epochs)

        return self.weight * feat_loss + kl_weight * kl_loss + sp_loss


class AE_FeatureGen_Linear(VGAE):
    """Feature autoencoder with MLP encoder/decoder and a sparsity penalty.

    The graph structure is not used: every method accepts ``edge_index`` only
    so the interface matches the graph-based generators in this file.

    NOTE(review): despite the "AE" name this class derives from ``VGAE`` and
    calls its ``encode``, which presumably samples the latent code via the
    reparametrisation trick while training -- confirm this is intended.

    Args:
        enc_in_channels: Size of the input (and reconstructed) features.
        enc_hidden_channels: Hidden width of the encoder and decoder layers.
        enc_out_channels: Size of the latent code.
        weight: Stored on the instance; ``loss`` in this class does not use it.
    """

    def __init__(self, enc_in_channels, enc_hidden_channels, enc_out_channels, weight):
        super().__init__(
            encoder=LinearEncoder(enc_in_channels,
                                  enc_hidden_channels,
                                  enc_out_channels),
            decoder=LinearDecoder(enc_out_channels,
                                  enc_hidden_channels,
                                  enc_hidden_channels,
                                  enc_in_channels),
        )
        self.weight = weight

    def forward(self, x, edge_index):
        """Encode ``x`` and decode it again; return ``(z, x_pred)``.

        ``edge_index`` is accepted for interface parity and ignored.
        """
        z = self.encode(x)
        x_pred = self.decoder(z)
        return z, x_pred

    def decode_latent(self, x, edge_index):
        """Run only the decoder on an already-latent ``x`` (``edge_index`` ignored)."""
        return self.decoder(x)

    def sparse_loss(self, activations, sparsity_level=0.05, beta=3):
        """Return the KL sparsity penalty on the mean activation per dimension.

        For each latent dimension the mean over ``dim=0`` (``rho_hat``) is
        compared to ``sparsity_level`` with the Bernoulli KL divergence; the
        per-dimension terms are summed and scaled by ``beta``.

        Args:
            activations: Tensor whose first dimension is averaged out.
            sparsity_level: Target mean activation; must lie in (0, 1).
            beta: Scale applied to the summed penalty.

        Returns:
            Scalar tensor holding the weighted penalty.
        """
        # The latent code is not bounded to (0, 1), so its mean can be <= 0
        # or >= 1, which would send the logarithms below to NaN/inf. Clamping
        # into the open interval keeps the penalty finite; values already
        # strictly inside the interval are left untouched.
        eps = torch.finfo(activations.dtype).eps
        rho_hat = torch.mean(activations, dim=0).clamp(min=eps, max=1 - eps)
        sparsity_penalty = sparsity_level * torch.log(sparsity_level / rho_hat) + \
            (1 - sparsity_level) * torch.log((1 - sparsity_level) / (1 - rho_hat))
        return beta * torch.sum(sparsity_penalty)

    def loss(self, x, edge_index, epoch, total_epochs, sparsity_level=0.05, beta=3):
        """Return the training loss: MSE + annealed sparsity penalty.

        Args:
            x: Feature matrix to reconstruct.
            edge_index: Accepted for interface parity and ignored.
            epoch: Current epoch, used for the sparsity weight schedule.
            total_epochs: Total number of epochs, used for the schedule.
            sparsity_level: Forwarded to ``sparse_loss``; the default matches
                the value that was previously implied.
            beta: Forwarded to ``sparse_loss``; the default matches the value
                that was previously implied.
        """
        z = self.encode(x)
        x_pred = self.decoder(z)

        feat_loss = F.mse_loss(x_pred, x)

        sp_loss = self.sparse_loss(z, sparsity_level, beta)

        # Sparsity weight grows linearly with training progress, capped at 0.1.
        sp_weight = min(0.1, epoch / total_epochs)

        return feat_loss + sp_weight * sp_loss



class VAE_FeatureGen_Linear(VGAE):
    """VGAE that reconstructs features using MLP encoder and decoder.

    The graph structure is not used: every method accepts ``edge_index`` only
    so the interface matches the graph-based generators in this file.

    Args:
        enc_in_channels: Size of the input (and reconstructed) features.
        enc_hidden_channels: Hidden width of the encoder and decoder layers.
        enc_out_channels: Size of the latent code.
        weight: Multiplier applied to the reconstruction (MSE) term in ``loss``.
    """

    def __init__(self, enc_in_channels, enc_hidden_channels, enc_out_channels, weight):
        super().__init__(
            encoder=LinearEncoder(enc_in_channels,
                                  enc_hidden_channels,
                                  enc_out_channels),
            decoder=LinearDecoder(enc_out_channels,
                                  enc_hidden_channels,
                                  enc_hidden_channels,
                                  enc_in_channels),
        )
        self.weight = weight

    def forward(self, x, edge_index):
        """Encode ``x`` and decode it again; return ``(z, x_pred)``.

        ``edge_index`` is accepted for interface parity and ignored.
        """
        z = self.encode(x)
        x_pred = self.decoder(z)
        return z, x_pred

    def decode_latent(self, x, edge_index):
        """Run only the decoder on an already-latent ``x`` (``edge_index`` ignored)."""
        return self.decoder(x)

    def sparse_loss(self, activations, sparsity_level=0.05, beta=3):
        """Return the KL sparsity penalty on the mean activation per dimension.

        For each latent dimension the mean over ``dim=0`` (``rho_hat``) is
        compared to ``sparsity_level`` with the Bernoulli KL divergence; the
        per-dimension terms are summed and scaled by ``beta``.

        Args:
            activations: Tensor whose first dimension is averaged out.
            sparsity_level: Target mean activation; must lie in (0, 1).
            beta: Scale applied to the summed penalty.

        Returns:
            Scalar tensor holding the weighted penalty.
        """
        # The latent code is not bounded to (0, 1), so its mean can be <= 0
        # or >= 1, which would send the logarithms below to NaN/inf. Clamping
        # into the open interval keeps the penalty finite; values already
        # strictly inside the interval are left untouched.
        eps = torch.finfo(activations.dtype).eps
        rho_hat = torch.mean(activations, dim=0).clamp(min=eps, max=1 - eps)
        sparsity_penalty = sparsity_level * torch.log(sparsity_level / rho_hat) + \
            (1 - sparsity_level) * torch.log((1 - sparsity_level) / (1 - rho_hat))
        return beta * torch.sum(sparsity_penalty)

    def loss(self, x, edge_index, epoch, total_epochs, sparsity_level=0.05, beta=3):
        """Return the training loss: weighted MSE + annealed KL + sparsity penalty.

        Args:
            x: Feature matrix to reconstruct.
            edge_index: Accepted for interface parity and ignored.
            epoch: Current epoch, used for the KL weight schedule.
            total_epochs: Total number of epochs, used for the KL weight schedule.
            sparsity_level: Forwarded to ``sparse_loss``.
            beta: Forwarded to ``sparse_loss``.
        """
        z = self.encode(x)
        x_pred = self.decoder(z)

        feat_loss = F.mse_loss(x_pred, x)

        # KL term of the most recent ``encode`` call, divided by the row count.
        kl_loss = 1 / x.size(0) * self.kl_loss()
        sp_loss = self.sparse_loss(z, sparsity_level, beta)

        # KL weight grows linearly with training progress and is capped at 0.1.
        kl_weight = min(0.1, epoch / total_epochs)

        return self.weight * feat_loss + kl_weight * kl_loss + sp_loss



class VAE_FeatureGen_Sage(VGAE):
    """VGAE that reconstructs node features using GraphSAGE encoder and decoder.

    The encoder maps ``enc_in_channels -> enc_hidden_channels ->
    enc_out_channels``; the decoder maps the latent code back through two
    ``enc_hidden_channels``-wide layers to ``enc_in_channels``.

    Args:
        enc_in_channels: Size of the input (and reconstructed) node features.
        enc_hidden_channels: Hidden width of the encoder and decoder layers.
        enc_out_channels: Size of the latent code.
        weight: Multiplier applied to the reconstruction (MSE) term in ``loss``.
    """

    def __init__(self, enc_in_channels, enc_hidden_channels, enc_out_channels, weight):
        super().__init__(
            encoder=GraphSAGEEncoder(enc_in_channels,
                                     enc_hidden_channels,
                                     enc_out_channels),
            decoder=GraphSageDecoder(enc_out_channels,
                                     enc_hidden_channels,
                                     enc_hidden_channels,
                                     enc_in_channels),
        )
        self.weight = weight

    def forward(self, x, edge_index):
        """Encode ``x`` and decode it again; return ``(z, x_pred)``."""
        z = self.encode(x, edge_index)
        x_pred = self.decoder(z, edge_index)
        return z, x_pred

    def decode_latent(self, x, edge_index):
        """Run only the decoder on an already-latent ``x``."""
        return self.decoder(x, edge_index)

    def sparse_loss(self, activations, sparsity_level=0.05, beta=3):
        """Return the KL sparsity penalty on the mean activation per dimension.

        For each latent dimension the mean over ``dim=0`` (``rho_hat``) is
        compared to ``sparsity_level`` with the Bernoulli KL divergence; the
        per-dimension terms are summed and scaled by ``beta``.

        Args:
            activations: Tensor whose first dimension is averaged out.
            sparsity_level: Target mean activation; must lie in (0, 1).
            beta: Scale applied to the summed penalty.

        Returns:
            Scalar tensor holding the weighted penalty.
        """
        # The latent code is not bounded to (0, 1), so its mean can be <= 0
        # or >= 1, which would send the logarithms below to NaN/inf. Clamping
        # into the open interval keeps the penalty finite; values already
        # strictly inside the interval are left untouched.
        eps = torch.finfo(activations.dtype).eps
        rho_hat = torch.mean(activations, dim=0).clamp(min=eps, max=1 - eps)
        sparsity_penalty = sparsity_level * torch.log(sparsity_level / rho_hat) + \
            (1 - sparsity_level) * torch.log((1 - sparsity_level) / (1 - rho_hat))
        return beta * torch.sum(sparsity_penalty)

    def loss(self, x, edge_index, epoch, total_epochs, sparsity_level=0.05, beta=3):
        """Return the training loss: weighted MSE + annealed KL + sparsity penalty.

        Args:
            x: Node feature matrix to reconstruct.
            edge_index: Graph connectivity passed to encoder and decoder.
            epoch: Current epoch, used for the KL weight schedule.
            total_epochs: Total number of epochs, used for the KL weight schedule.
            sparsity_level: Forwarded to ``sparse_loss``.
            beta: Forwarded to ``sparse_loss``.
        """
        z = self.encode(x, edge_index)
        x_pred = self.decoder(z, edge_index)

        feat_loss = F.mse_loss(x_pred, x)

        # KL term of the most recent ``encode`` call, divided by the node count.
        kl_loss = 1 / x.size(0) * self.kl_loss()
        sp_loss = self.sparse_loss(z, sparsity_level, beta)

        # KL weight grows linearly with training progress and is capped at 0.1.
        kl_weight = min(0.1, epoch / total_epochs)

        return self.weight * feat_loss + kl_weight * kl_loss + sp_loss

    

class DeepVGAE(VGAE):
    """VGAE for link prediction: GraphSAGE encoder + inner-product decoder.

    Args:
        enc_in_channels: Size of the input node features.
        enc_hidden_channels: Hidden width of the encoder.
        enc_out_channels: Size of the latent node embedding.
    """

    def __init__(self, enc_in_channels, enc_hidden_channels, enc_out_channels):
        super().__init__(
            encoder=GraphSAGEEncoder(enc_in_channels,
                                     enc_hidden_channels,
                                     enc_out_channels),
            decoder=InnerProductDecoder(),
        )

    def forward(self, x, edge_index):
        """Encode the graph and return the decoder's ``forward_all`` output."""
        z = self.encode(x, edge_index)
        adj_pred = self.decoder.forward_all(z)
        return adj_pred

    def predict(self, x, edge_index, x_gen):
        """Score links between existing nodes and newly generated nodes.

        Args:
            x: Features of the existing nodes.
            edge_index: Connectivity of the existing graph.
            x_gen: Features of the generated nodes, which have no edges yet.

        Returns:
            Whatever ``self.decoder.forward_new(z, z_gen)`` returns.
        """
        z = self.encode(x, edge_index)

        # Generated nodes have no neighbours, so each one is connected only to
        # itself. The index is built on x_gen's own device instead of a
        # module-level global so the call works wherever the data lives.
        x_gen_edge_index = torch.arange(len(x_gen), device=x_gen.device).repeat(2, 1)
        z_gen = self.encode(x_gen, x_gen_edge_index)

        # NOTE(review): ``forward_new`` does not appear in the public
        # torch_geometric ``InnerProductDecoder`` API (only ``forward`` and
        # ``forward_all``); presumably a locally patched decoder -- confirm.
        pred = self.decoder.forward_new(z, z_gen)
        return pred

    def loss(self, x, pos_edge_index, all_edge_index):
        """Return reconstruction loss on positive/negative edges plus a KL term.

        Args:
            x: Node feature matrix.
            pos_edge_index: Edges used both for message passing and as
                positive examples.
            all_edge_index: Edges that must not be drawn as negatives.
        """
        z = self.encode(x, pos_edge_index)

        # 1e-15 keeps the logarithm finite when a probability saturates.
        pos_loss = -torch.log(
            self.decoder(z, pos_edge_index, sigmoid=True) + 1e-15).mean()

        # Ensure every node has exactly one self-loop in the exclusion set so
        # that self-loops are never sampled as negative edges.
        all_edge_index_tmp, _ = remove_self_loops(all_edge_index)
        all_edge_index_tmp, _ = add_self_loops(all_edge_index_tmp)

        # Request as many negatives as there are positive edges.
        neg_edge_index = negative_sampling(all_edge_index_tmp, z.size(0), pos_edge_index.size(1))
        neg_loss = -torch.log(1 - self.decoder(z, neg_edge_index, sigmoid=True) + 1e-15).mean()

        # KL term of the ``encode`` call above, divided by the node count.
        kl_loss = 1 / x.size(0) * self.kl_loss()

        return pos_loss + neg_loss + kl_loss

    def single_test(self, x, train_pos_edge_index, test_pos_edge_index, test_neg_edge_index):
        """Encode without gradients and return the two scores from ``self.test``."""
        with torch.no_grad():
            z = self.encode(x, train_pos_edge_index)
        auc, avg_precision = self.test(z, test_pos_edge_index, test_neg_edge_index)
        return auc, avg_precision


class GraphSAGEEdgePredictor(torch.nn.Module):
    """Deterministic two-layer GraphSAGE encoder with an inner-product decoder.

    Args:
        in_channels: Size of the input node features.
        hidden_channels: Width of the hidden layer.
        out_channels: Size of the node embedding.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)
        self.decoder = InnerProductDecoder()

    def forward(self, x, edge_index):
        """Return node embeddings: ``conv1`` -> ReLU -> ``conv2``."""
        x = F.relu(self.conv1(x, edge_index))
        z = self.conv2(x, edge_index)
        return z

    def predict(self, x, edge_index, x_gen):
        """Score links between existing nodes and newly generated nodes.

        Args:
            x: Features of the existing nodes.
            edge_index: Connectivity of the existing graph.
            x_gen: Features of the generated nodes, which have no edges yet.

        Returns:
            Whatever ``self.decoder.forward_new(z, z_gen)`` returns.
        """
        z = self.forward(x, edge_index)

        # Generated nodes have no neighbours, so each one is connected only to
        # itself. The index is built on x_gen's own device instead of a
        # module-level global so the call works wherever the data lives.
        x_gen_edge_index = torch.arange(len(x_gen), device=x_gen.device).repeat(2, 1)
        z_gen = self.forward(x_gen, x_gen_edge_index)

        # NOTE(review): ``forward_new`` does not appear in the public
        # torch_geometric ``InnerProductDecoder`` API (only ``forward`` and
        # ``forward_all``); presumably a locally patched decoder -- confirm.
        pred = self.decoder.forward_new(z, z_gen)
        return pred

    def loss(self, z, pos_edge_index, neg_edge_index):
        """Return the summed negative log-likelihood of positive and negative edges.

        Args:
            z: Node embeddings, e.g. the output of ``forward``.
            pos_edge_index: Edges that should be predicted as present.
            neg_edge_index: Edges that should be predicted as absent.
        """
        # 1e-15 keeps the logarithm finite when a probability saturates.
        pos_loss = -torch.log(self.decoder(z, pos_edge_index, sigmoid=True) + 1e-15).mean()
        neg_loss = -torch.log(1 - self.decoder(z, neg_edge_index, sigmoid=True) + 1e-15).mean()
        return pos_loss + neg_loss
