# File: predict.py
# Description: Train an MLP-based binary classifier on hand-crafted edge features of a signed graph and write a predicted label for every edge (0 = "+", 1 = any other sign).


import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from sklearn.model_selection import train_test_split
import time
import psutil


class MLP(nn.Module):
    """Three-layer perceptron mapping an edge-feature vector to a single logit.

    Args:
        input_dim: Number of input features per edge.
        hidden_dim: Width of both hidden layers.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Return raw logits (last dimension 1) for a batch of feature rows."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # No sigmoid here: BCEWithLogitsLoss applies it internally during
        # training, and callers apply torch.sigmoid at inference time.
        return self.fc3(x)


def _load_signed_edges(path):
    """Read a signed edge list and return a list of ``(v1, v2, label)`` tuples.

    The first line of the file is skipped (treated as a header). Every other
    non-blank line must hold exactly three whitespace-separated fields
    ``v1 v2 sign``. Node IDs are parsed as ints. A sign of ``"+"`` becomes
    label 0; any other sign becomes label 1.
    """
    data = []
    with open(path, "r", encoding="utf-8") as f:
        f.readline()  # skip header line
        for line in f:
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines, e.g. a trailing newline
            v1, v2, sign = fields
            data.append((int(v1), int(v2), 0 if sign == "+" else 1))
    return data


def _edge_features(data):
    """Build the per-edge feature matrix for ``(v1, v2, label)`` tuples.

    Node IDs are remapped to dense indices ``0..num_nodes-1`` (sorted by ID),
    so 1-based or non-contiguous IDs can index the adjacency matrix. For IDs
    that are already ``0..n-1`` the mapping is the identity.

    NOTE(review): the adjacency is built from the labels of *every* edge,
    including those that later land in the test split. For a positive edge,
    ``is_connected`` is 1 by construction, so these features leak the target
    and the reported test accuracy is not a held-out estimate. Consider
    building the graph from training edges only.

    Returns:
        A ``(len(data), 4)`` float32 array with columns
        ``[deg(v1), deg(v2), common_neighbours(v1, v2), is_connected(v1, v2)]``,
        all computed over positive (label 0) edges.
    """
    nodes = sorted({v for v1, v2, _ in data for v in (v1, v2)})
    index = {v: i for i, v in enumerate(nodes)}
    num_nodes = len(nodes)

    # Dense symmetric adjacency over positive edges; memory is O(num_nodes**2).
    adj_matrix = np.zeros((num_nodes, num_nodes), dtype=np.float32)
    for v1, v2, label in data:
        if label == 0:
            i, j = index[v1], index[v2]
            adj_matrix[i, j] = 1
            adj_matrix[j, i] = 1

    degrees = adj_matrix.sum(axis=1)
    features = np.empty((len(data), 4), dtype=np.float32)
    for row, (v1, v2, _) in enumerate(data):
        i, j = index[v1], index[v2]
        features[row] = (
            degrees[i],
            degrees[j],
            # Dot product of 0/1 rows = number of shared positive neighbours
            # (exact in float32 while num_nodes < 2**24).
            adj_matrix[i] @ adj_matrix[j],
            adj_matrix[i, j],
        )
    return features


def main(
    training_file="../data/dblp/streaming-dblp10000.txt",
    output_file_binary="../data/dblp/dblp10000_prediction_binary.txt",
    epochs=10,
    seed=42,
):
    """Train the edge-sign classifier and write a prediction for every edge.

    Args:
        training_file: Signed edge list (one header line, then ``v1 v2 sign``
            rows). Another dataset used with this script:
            ``../data/sbm/nodes_2400/streaming-SBM-2400.txt``.
        output_file_binary: Destination for ``v1 v2 prediction`` lines, one per
            input edge in input order (0 = predicted "+", 1 = predicted other).
        epochs: Number of full-batch optimisation steps.
        seed: Seed for the train/test split and for model weight initialisation.

    Raises:
        ValueError: If ``training_file`` contains no edges.
    """
    data = _load_signed_edges(training_file)
    if not data:
        raise ValueError(f"No edges found in {training_file!r}")

    print("Data loading complete.\n")

    process = psutil.Process()
    start_memory = process.memory_info().rss
    # NOTE: the timer and memory delta below also cover feature construction,
    # not only the optimisation loop.
    start_time = time.perf_counter()

    features_tensor = torch.from_numpy(_edge_features(data))
    labels_tensor = torch.tensor(
        [label for _, _, label in data], dtype=torch.float32
    ).view(-1, 1)

    print("Step 1 complete.\n")

    X_train, X_test, y_train, y_test = train_test_split(
        features_tensor, labels_tensor, test_size=0.3, random_state=seed
    )

    print("Step 2 complete.\n")

    torch.manual_seed(seed)  # reproducible weight initialisation
    model = MLP(input_dim=features_tensor.shape[1], hidden_dim=128)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    # pos_weight > 1 up-weights label 1 (non-"+" edges); presumably this offsets
    # class imbalance -- TODO confirm the 2.0 ratio against the dataset.
    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([2.0], dtype=torch.float32))

    print("Step 3 complete.\n")

    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        loss = criterion(model(X_train), y_train)
        loss.backward()
        optimizer.step()

        model.eval()
        with torch.no_grad():
            test_preds = (torch.sigmoid(model(X_test)) > 0.5).float()
            accuracy = (test_preds == y_test).float().mean().item()
        print(f"Epoch {epoch + 1}, Loss: {loss.item():.4f}, Test Accuracy: {accuracy:.4f}")

    end_time = time.perf_counter()
    print(f"Training time: {end_time - start_time:.2f}s")
    end_memory = process.memory_info().rss
    print(f"Total memory used: {(end_memory - start_memory) / (1024 * 1024):.2f} MB")

    print("Step 4 complete.\n")

    # Run inference before opening the output, so a failure cannot leave a
    # truncated file behind.
    model.eval()
    with torch.no_grad():
        predictions = (torch.sigmoid(model(features_tensor)) > 0.5).int().view(-1).tolist()
    with open(output_file_binary, "w", encoding="utf-8") as f_binary:
        f_binary.writelines(
            f"{v1} {v2} {pred}\n" for (v1, v2, _), pred in zip(data, predictions)
        )

    print("Step 5 complete.\n")

# Script entry point: run the full pipeline only when executed directly, not on import.
if __name__ == "__main__":
    main()