import torch
from tqdm import tqdm

from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import DeepGraphInfomax, SAGEConv


class Encoder(torch.nn.Module):
    """Three-layer GraphSAGE encoder with a PReLU after every convolution.

    Args:
        in_channels: Feature width fed to the first ``SAGEConv``.
        hidden_channels: Output width of every ``SAGEConv``; also passed to
            every ``PReLU`` as its number of learnable slopes.
    """

    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        # Only the first layer consumes the raw features; the remaining
        # layers map hidden -> hidden.
        input_widths = (in_channels, hidden_channels, hidden_channels)
        self.convs = torch.nn.ModuleList(
            [SAGEConv(width, hidden_channels) for width in input_widths]
        )
        self.activations = torch.nn.ModuleList(
            [torch.nn.PReLU(hidden_channels) for _ in input_widths]
        )

    def forward(self, x, edge_index, batch_size):
        """Encode the nodes and keep only the leading ``batch_size`` rows.

        Args:
            x: Node feature matrix handed to the first convolution.
            edge_index: Graph connectivity, reused unchanged by every layer.
            batch_size: Number of leading rows of the result to return
                (presumably the seed nodes of a sampled mini-batch; verify
                against the loader in use).

        Returns:
            The first ``batch_size`` rows of the final layer's activations.
        """
        hidden = x
        for layer, nonlinearity in zip(self.convs, self.activations):
            hidden = nonlinearity(layer(hidden, edge_index))
        return hidden[:batch_size]


def corruption(x, edge_index, batch_size):
    """Corrupt the encoder inputs by randomly permuting the feature rows.

    Only the features are shuffled; the graph structure and batch size are
    passed through untouched so the result can be fed to the same encoder.

    Args:
        x: Node feature matrix; its rows (dim 0) are permuted.
        edge_index: Graph connectivity, returned as-is.
        batch_size: Returned as-is.

    Returns:
        A tuple ``(shuffled_x, edge_index, batch_size)``.
    """
    # Build the permutation directly on x's device so that indexing an
    # accelerator-resident tensor does not need an implicit transfer of a
    # CPU index tensor on every call.
    shuffled_rows = torch.randperm(x.size(0), device=x.device)
    return x[shuffled_rows], edge_index, batch_size


def train(epoch, train_loader, model, optimizer):
    """Run one optimisation pass over ``train_loader``.

    Args:
        epoch: Current epoch number. Not used by the computation; kept so
            existing callers do not need to change.
        train_loader: Iterable of mini-batches exposing ``x``,
            ``edge_index`` and ``batch_size`` attributes.
        model: Callable module returning ``(pos_z, neg_z, summary)`` and
            providing ``loss(pos_z, neg_z, summary)``.
        optimizer: Optimizer that updates the model's parameters.

    Returns:
        The mean loss, weighted by the number of rows in ``pos_z`` of each
        batch, or ``0.0`` when the loader yields no examples.
    """
    model.train()

    # Batches are not guaranteed to live where the model does (e.g. data
    # kept on the CPU while the model sits on an accelerator), so the
    # tensors handed to the model are moved to the model's device first.
    first_param = next(model.parameters(), None)
    device = None if first_param is None else first_param.device

    total_loss = 0.0
    total_examples = 0
    for batch in train_loader:
        x, edge_index = batch.x, batch.edge_index
        if device is not None:
            x = x.to(device)
            edge_index = edge_index.to(device)

        optimizer.zero_grad()
        pos_z, neg_z, summary = model(x, edge_index, batch.batch_size)
        loss = model.loss(pos_z, neg_z, summary)
        loss.backward()
        optimizer.step()

        num_examples = pos_z.size(0)
        total_loss += float(loss) * num_examples
        total_examples += num_examples

    # An empty loader would otherwise end in a division by zero.
    if total_examples == 0:
        return 0.0
    return total_loss / total_examples


@torch.no_grad()
def test(model, test_loader):
    """Compute embeddings for every batch of ``test_loader``.

    Runs with gradients disabled and the model in evaluation mode.

    Args:
        model: Callable module returning a triple whose first element is
            the embedding tensor for the batch.
        test_loader: Iterable of mini-batches exposing ``x``,
            ``edge_index`` and ``batch_size`` attributes.

    Returns:
        A CPU tensor with the per-batch embeddings concatenated along
        dim 0, in the order the loader yields them.

    Raises:
        Whatever ``torch.cat`` raises for an empty list when the loader
        yields no batches.
    """
    model.eval()

    # Batches are not guaranteed to live where the model does (e.g. data
    # kept on the CPU while the model sits on an accelerator), so the
    # tensors handed to the model are moved to the model's device first.
    first_param = next(model.parameters(), None)
    device = None if first_param is None else first_param.device

    zs = []
    for batch in tqdm(test_loader, desc='Evaluating'):
        x, edge_index = batch.x, batch.edge_index
        if device is not None:
            x = x.to(device)
            edge_index = edge_index.to(device)
        pos_z, _, _ = model(x, edge_index, batch.batch_size)
        # Collect on the CPU so accelerator memory is not held by results.
        zs.append(pos_z.cpu())

    return torch.cat(zs, dim=0)


def get_dgi_emb(data, device, batch_size=1024, *, hidden_channels=128,
                num_neighbors=(10, 10, 25), epochs=30, lr=0.0001):
    """Train Deep Graph Infomax on ``data`` and return node embeddings.

    Args:
        data: Graph object exposing ``x``, ``edge_index`` and ``num_nodes``.
            When ``data.x`` is ``None`` it is replaced **in place** by a
            single-column feature holding the degree computed from
            ``edge_index[0]``.
        device: Device the model (and the generated degree feature) is
            placed on.
        batch_size: Mini-batch size for both the training and the
            evaluation loader.
        hidden_channels: Width of the encoder layers and of the embedding.
        num_neighbors: Neighbours sampled per hop by both loaders. The
            ``Encoder`` in this file has three layers, which the
            three-entry default matches.
        epochs: Number of training passes over the training loader.
        lr: Learning rate of the Adam optimizer.

    Returns:
        The tensor produced by ``test`` on the evaluation loader: the
        per-batch embeddings concatenated on the CPU.
    """
    if data.x is None:
        from torch_geometric.utils import degree

        # No features available: fall back to the node degree as a
        # one-dimensional feature.
        data.x = degree(
            data.edge_index[0], num_nodes=data.num_nodes
        ).unsqueeze(-1).to(device)

    fan_out = list(num_neighbors)
    train_loader = NeighborLoader(data, num_neighbors=fan_out,
                                  batch_size=batch_size, shuffle=True)
    # No shuffling here, so the evaluation batches follow the loader's
    # default order.
    test_loader = NeighborLoader(data, num_neighbors=fan_out,
                                 batch_size=batch_size)

    model = DeepGraphInfomax(
        hidden_channels=hidden_channels,
        encoder=Encoder(data.x.shape[1], hidden_channels),
        summary=lambda z, *args, **kwargs: torch.sigmoid(z.mean(dim=0)),
        corruption=corruption).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in tqdm(range(1, epochs + 1), desc='Training'):
        train(epoch, train_loader, model, optimizer)

    return test(model, test_loader)