import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.utils import dropout_adj
from models import NormLayer



class LogReg(nn.Module):
    """Linear classification head: one ``nn.Linear`` from ``hid_dim`` to ``n_classes``.

    No activation or normalisation is applied; the output is the raw result of
    the linear layer.
    """

    def __init__(self, hid_dim, n_classes):
        super().__init__()
        self.fc = nn.Linear(in_features=hid_dim, out_features=n_classes)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        return self.fc(x)


def drop_features(x, drop_prob):
    """Return a copy of ``x`` with randomly chosen columns set to zero.

    One uniform sample in [0, 1) is drawn per column (dimension 1 of ``x``);
    every column whose sample is below ``drop_prob`` is zeroed for all rows.
    The input tensor itself is left untouched.
    """
    num_columns = x.size(1)
    noise = torch.empty((num_columns,), dtype=torch.float32, device=x.device)
    noise.uniform_(0, 1)
    zeroed_columns = noise < drop_prob

    result = x.clone()  # work on a copy so the caller's tensor is not modified
    result[:, zeroed_columns] = 0
    return result


def index_to_mask(index, size=None):
    """Turn a tensor of indices into a 1-D boolean mask.

    ``index`` is flattened first. When ``size`` is not given, the mask length
    is the largest index plus one. Positions listed in ``index`` are ``True``,
    all others ``False``. The mask is created on the same device as ``index``.
    """
    flat = index.view(-1)
    if size is None:
        size = int(flat.max()) + 1
    mask = torch.zeros(size, dtype=torch.bool, device=flat.device)
    mask[flat] = True
    return mask


class GRACE(torch.nn.Module):
    """Two-view contrastive model: encoder, MLP projection head and InfoNCE loss.

    Two augmented views of a graph are produced by dropping edges and feature
    columns, both views are passed through the shared ``encoder``, and the
    resulting embeddings are compared with a temperature-scaled InfoNCE loss.

    Args:
        encoder: Module called as ``encoder(x, edge_index)``. It must also
            provide ``get_embeddings(x, edge_index, random=..., constant=...)``
            for :meth:`get_embedding` and, when ``l1_norm=True`` is used in
            :meth:`infonce_loss`, a ``lins`` sequence whose first entry has a
            ``weight`` tensor.
        input_dim: Input width of the first projection layer (``fc1``).
        num_hidden: Output width of the second projection layer (``fc2``).
        num_proj_hidden: Hidden width of the projection head; also the number
            of leading columns kept by :meth:`projection_direct`.
        tau: Temperature dividing the similarities inside the InfoNCE loss.
        drop_rate: Indexable with four entries: edge-drop probability for
            view 1 and view 2, then feature-drop probability for view 1 and
            view 2.
        args: Configuration object; ``args.reg_coef`` is read here and the
            object is forwarded to ``NormLayer``.
    """

    def __init__(self, encoder, input_dim, num_hidden, num_proj_hidden, tau, drop_rate, args):
        super(GRACE, self).__init__()
        self.encoder = encoder

        self.tau = tau
        self.drop_rate = drop_rate
        self.num_proj_hidden = num_proj_hidden
        self.fc1 = torch.nn.Linear(input_dim, num_proj_hidden)
        self.fc2 = torch.nn.Linear(num_proj_hidden, num_hidden)

        self.reg_coef = args.reg_coef
        # ``reduction='sum'`` is the supported spelling of the deprecated
        # ``size_average=False`` and yields the same summed L1 loss.
        # The attribute is kept for compatibility; this class does not use it.
        self.l1_penalty = torch.nn.L1Loss(reduction='sum')
        self.norm_layer = NormLayer(args, num_proj_hidden)

    def augmentation(self, edge_index, x):
        """Build two randomly perturbed views of the graph.

        Returns:
            ``(edge_index_1, edge_index_2, x1, x2)``: the two edge indices
            after edge dropping and the two feature matrices after feature
            column dropping.
        """
        # dropout_adj returns a tuple; only its first element (the edge index)
        # is used here.
        # NOTE(review): dropout_adj is deprecated in recent PyTorch Geometric
        # releases in favour of dropout_edge - confirm against the installed version.
        edge_index_1 = dropout_adj(edge_index, p=self.drop_rate[0])[0]
        edge_index_2 = dropout_adj(edge_index, p=self.drop_rate[1])[0]
        x1 = drop_features(x, self.drop_rate[2])
        x2 = drop_features(x, self.drop_rate[3])
        return edge_index_1, edge_index_2, x1, x2

    def forward(self, edge_index, x):
        """Encode two augmented views and return their embeddings ``(z1, z2)``."""
        edge_index_1, edge_index_2, x1, x2 = self.augmentation(edge_index, x)
        z1 = self.encoder(x1, edge_index_1)
        z2 = self.encoder(x2, edge_index_2)
        return z1, z2

    def get_embedding(self, edge_index, x, random=False, constant=False):
        """Return detached embeddings from ``encoder.get_embeddings``.

        ``random`` and ``constant`` are forwarded unchanged; their meaning is
        defined by the encoder.
        """
        z = self.encoder.get_embeddings(x, edge_index, random=random, constant=constant)
        return z.detach()

    def projection_mlp(self, z: torch.Tensor) -> torch.Tensor:
        """Project ``z`` through ``fc1`` -> ELU -> ``norm_layer`` -> ``fc2``."""
        z = F.elu(self.fc1(z))
        z = self.norm_layer(z)
        return self.fc2(z)

    def projection_direct(self, z: torch.Tensor) -> torch.Tensor:
        """Return the first ``num_proj_hidden`` columns of ``z`` (no learned projection)."""
        return z[:, :self.num_proj_hidden]

    def sim(self, z1: torch.Tensor, z2: torch.Tensor):
        """Return the matrix of cosine similarities between rows of ``z1`` and ``z2``."""
        z1 = F.normalize(z1)
        z2 = F.normalize(z2)
        return torch.mm(z1, z2.t())

    def infonce(self, z1, z2):
        """Return the per-row InfoNCE loss of ``z1`` against ``z2``.

        For row ``i`` the positive is ``z2[i]``; the negatives are every other
        row of ``z1`` and every other row of ``z2``.
        """
        # NOTE(review): exp(sim / tau) is computed directly, so a very small
        # tau can overflow in float32 - confirm the tau range used by callers.
        between_sim = torch.exp(self.sim(z1, z2) / self.tau)
        refl_sim = torch.exp(self.sim(z1, z1) / self.tau)
        alignment_loss = -torch.log(between_sim.diag())
        # The self-similarity on the diagonal of refl_sim is excluded from the
        # denominator.
        uniformity_loss = torch.log(refl_sim.sum(1) + between_sim.sum(1) - refl_sim.diag())
        return alignment_loss + uniformity_loss

    def infonce_loss(self, z1, z2, l1_norm=False):
        """Return the symmetric InfoNCE loss (scalar) of the projected embeddings.

        Both inputs go through :meth:`projection_mlp`; the loss is the mean
        over rows of the average of the two directions.

        Args:
            z1: Embeddings of the first view.
            z2: Embeddings of the second view.
            l1_norm: When true, ``reg_coef`` times the standard deviation of
                ``encoder.lins[0].weight`` is subtracted from the loss.
                Despite the flag's name, no L1 penalty is applied.
        """
        z1 = self.projection_mlp(z1)
        z2 = self.projection_mlp(z2)
        l1 = self.infonce(z1, z2)
        l2 = self.infonce(z2, z1)
        ret = (l1 + l2) * 0.5
        ret = ret.mean()
        if l1_norm:
            # Subtracting the std rewards a larger spread of the first-layer weights.
            ret = ret - self.reg_coef * self.encoder.lins[0].weight.std()
        return ret


class Discriminator(nn.Module):
    """Bilinear scorer comparing node embeddings with a summary vector."""

    def __init__(self, dim):
        super().__init__()
        self.fn = nn.Bilinear(dim, dim, 1)

    def forward(self, h, h_neg, g):
        """Score ``h`` and ``h_neg`` against ``g`` and concatenate the scores.

        ``g`` is expanded to the shape of ``h`` and the same expanded tensor
        is used for both inputs. The returned 1-D tensor holds the scores for
        ``h`` followed by the scores for ``h_neg``.
        """
        summary = g.expand_as(h).contiguous()
        scores = [self.fn(nodes, summary).squeeze(1) for nodes in (h, h_neg)]
        return torch.cat(scores)
    

class DGI(nn.Module):
    """Encoder plus bilinear discriminator scoring real vs. shuffled-feature nodes.

    ``loss_fn`` (``BCEWithLogitsLoss``) is stored on the module for callers;
    it is not applied inside :meth:`forward`.
    """

    def __init__(self, encoder, out_dim):
        super().__init__()
        self.encoder = encoder
        self.disc = Discriminator(out_dim)

        self.act_fn = nn.ReLU()
        self.loss_fn = nn.BCEWithLogitsLoss()

    def get_embedding(self, edge_index, feat):
        """Return the encoder output for ``feat``, detached from the autograd graph."""
        return self.encoder(x=feat, edge_index=edge_index).detach()

    def forward(self, edge_index, feat, shuf_feat):
        """Return discriminator logits for the real and the shuffled features.

        Both feature matrices are encoded on the same ``edge_index``. The
        summary vector is the ReLU of the mean, over dimension 0, of the
        real-feature embeddings.
        """
        real = self.encoder(x=feat, edge_index=edge_index)
        corrupted = self.encoder(x=shuf_feat, edge_index=edge_index)
        summary = self.act_fn(real.mean(dim=0))
        return self.disc(real, corrupted, summary)