from utils import *
from models.gnn_models import *
from models.egnn_models import *
from sklearn.metrics import roc_auc_score
import argparse
from tqdm import tqdm

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Fix the random seed once, at import time, through the project's helper.
set_random_seed(0)


# Early-stopping settings. NOTE: `patience` is rebound for every run from the
# value returned by get_model(), so the number below is only an initial value.
patience = 100
max_epochs = 500

# Number of independent training runs and the test AUC collected from each.
num_runs = 3
test_aucs = []

# Command-line interface.
parser = argparse.ArgumentParser(
    description="Link Prediction with Various GNN Models"
)
parser.add_argument(
    '--model',
    type=str,
    default='gcn',
    choices=['gcn', 'sage', 'gat', 'egnn', 'equiformer'],
)
parser.add_argument(
    '--hid_dim',
    type=int,
    default=32,
)
parser.add_argument(
    '--r',
    type=float,
    default=0.1,
    help="Ratio of edge removal for link prediction",
)
parser.add_argument(
    '--dataset',
    type=str,
    default='nanofibres',
    choices=['nanofibres', 'nanowire', 'polymer'],
    help="Dataset to use",
)
args = parser.parse_args()




def train(train_data):
    """Run one full-batch optimisation step and return its loss.

    Operates on the module-level ``model`` and ``optimizer`` and moves every
    tensor it needs to the module-level ``device``.

    Args:
        train_data: Split object exposing ``x``, ``edge_index``,
            ``edge_label_index`` and ``edge_label``.

    Returns:
        The binary cross-entropy-with-logits loss of this step as a Python
        float.
    """
    model.train()
    optimizer.zero_grad()

    # Encode the nodes using the message-passing edges of the training split.
    node_features = train_data.x.to(device)
    message_edges = train_data.edge_index.to(device)
    embeddings = model.encode(node_features, message_edges)

    # Score the supervised candidate edges and compare them with their labels.
    candidate_edges = train_data.edge_label_index.to(device)
    logits = model.decode(embeddings, candidate_edges)
    targets = train_data.edge_label.to(device)

    step_loss = F.binary_cross_entropy_with_logits(logits, targets)
    step_loss.backward()
    optimizer.step()
    return step_loss.item()



@torch.no_grad()
def test(data, neg_per_pos=50, k=5, test=False, g=None):
    """Evaluate the module-level ``model`` on one edge split.

    Args:
        data: Split object exposing ``x``, ``edge_index``,
            ``edge_label_index`` and ``edge_label``.
        neg_per_pos: Unused; kept so existing call sites keep working.
        k: Unused; kept so existing call sites keep working.
        test: Unused; kept so existing call sites keep working.
        g: Unused; kept so existing call sites keep working.

    Returns:
        A tuple ``(accuracy, auc, 0)``. The third element is a constant
        placeholder; no ranking metric (MRR / Hit@K) is computed here.
    """
    model.eval()

    embeddings = model.encode(data.x.to(device), data.edge_index.to(device))
    scores = model.decode_prob(embeddings, data.edge_label_index.to(device))
    targets = data.edge_label.to(device)

    # Accuracy from a hard 0.5 threshold on the scores
    # (assumes decode_prob returns probabilities in [0, 1] -- TODO confirm).
    hard_preds = (scores > 0.5).float()
    n_correct = (hard_preds == targets).sum().item()
    acc = n_correct / hard_preds.size(0)

    # Binary ROC-AUC is computed on the CPU from the raw (unthresholded) scores.
    auc = roc_auc_score(targets.cpu().numpy(), scores.cpu().numpy())

    return acc, auc, 0

    

def get_model(args, in_dim=32):
    """Build the requested link-prediction model together with its optimizer.

    Args:
        args: Parsed CLI namespace. ``args.model`` selects the architecture
            and ``args.hid_dim`` sets the hidden width.
        in_dim: Number of input node features. It is passed as
            ``in_channels`` to whichever architecture is selected, so the
            function no longer depends on a module-level ``data`` object.

    Returns:
        A tuple ``(model, optimizer, patience)``: the model already moved to
        the module-level ``device``, an Adam optimizer over its parameters
        with a per-architecture learning rate, and the per-architecture
        early-stopping patience (in epochs).

    Raises:
        ValueError: If ``args.model`` is not one of the supported names.
    """
    if args.model == 'gcn':
        model = GCN(in_channels=in_dim, hidden_channels=args.hid_dim).to(device)
        lr = 0.001
        patience = 100

    elif args.model == 'sage':
        model = GraphSAGE(in_channels=in_dim, hidden_channels=args.hid_dim).to(device)
        lr = 0.001
        patience = 100

    elif args.model == 'gat':
        model = GAT(in_channels=in_dim, hidden_channels=args.hid_dim, heads=2).to(device)
        lr = 0.001
        patience = 100

    elif args.model == 'egnn':
        model = BatchedEGNNLinkPredictor(
            in_channels=in_dim,
            hidden_channels=args.hid_dim,
            batch_size=64,  # Adjust this based on your GPU memory
            dropout=0.2
        ).to(device)
        lr = 0.005
        patience = 20

    elif args.model == 'equiformer':
        model = BatchedEquiformerV2LinkPredictor(
            in_channels=in_dim,
            hidden_channels=args.hid_dim,
            num_layers=1,
            max_ell=2,
            dropout=0.1,
            batch_size=2048  # Adjust based on your GPU memory
        ).to(device)
        lr = 0.01
        patience = 50

    else:
        raise ValueError(f"Unknown model type: {args.model}")

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    return model, optimizer, patience



# Link prediction on a single graph: pick the dataset, then build the full
# graph (`data`), its train/val/test edge splits and a DGL copy (`dgl_g`).
if args.dataset == 'nanofibres':
    # This loader returns the graph, all three splits and the DGL graph itself.
    data, train_data, val_data, test_data, dgl_g = get_binarized_network_v1(ratio=args.r)

elif args.dataset in ('nanowire', 'polymer'):
    # Both datasets come as a list of graphs and share the same preparation.
    if args.dataset == 'nanowire':
        data_list = get_nanowire_network()
    else:
        data_list = get_polymer_network()

    # Use the largest graph (by node count) for link prediction.
    data = max(data_list, key=lambda graph: graph.num_nodes)

    # Hold out a fraction `args.r` of the edges for validation and the same
    # fraction for testing.
    transform = RandomLinkSplit(
        num_val=args.r,
        num_test=args.r,
        is_undirected=True,
        add_negative_train_samples=True,
        neg_sampling_ratio=1.0,
    )
    train_data, val_data, test_data = transform(data)
    dgl_g = to_dgl(data)

    # Replace each split's candidate edges with the output of the project's
    # 'hops' negative sampler.
    # NOTE(review): `edge_label` is left untouched, so sampling_neg_edges is
    # assumed to keep the number and order of candidate edges -- confirm.
    for split in (train_data, val_data, test_data):
        split.edge_label_index = sampling_neg_edges(split.edge_label_index, dgl_g, method='hops')

else:
    raise ValueError(f"Unknown dataset: {args.dataset}")

# Train `num_runs` models with early stopping on validation AUC, evaluate the
# best checkpoint of each run on the test split, then report mean and std of
# the test AUC.
for run in range(num_runs):
    model, optimizer, patience = get_model(args, in_dim=data.num_node_features)

    best_val_auc = 0.0
    epochs_no_improve = 0
    best_model_state = None

    for epoch in range(max_epochs):
        loss = train(train_data)
        val_acc, val_auc, _ = test(val_data)
        print(f'Run {run+1} | Epoch {epoch:03d}, Loss: {loss:.4f}, Val Acc: {val_acc:.4f}, Val AUC: {val_auc:.4f}')
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            # state_dict() returns a new dict whose values are references to
            # the live parameter/buffer tensors, which the optimizer keeps
            # updating in place. Clone them so the snapshot really holds the
            # weights of the best epoch and not those of the last one.
            best_model_state = model.state_dict()
            for name, value in best_model_state.items():
                if isinstance(value, torch.Tensor):
                    best_model_state[name] = value.detach().clone()
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(f"Early stopping at epoch {epoch:03d}. Best Val AUC: {best_val_auc:.4f}")
            break

    # Load best model and test
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    test_acc, test_auc, _ = test(test_data, test=True, g=dgl_g)
    test_aucs.append(test_auc)
    print(f'Run {run+1}: Test Accuracy = {test_acc:.4f}, Test AUC = {test_auc:.4f}')

print(
f'Avg Test AUC: {np.mean(test_aucs)*100:.2f} ± {np.std(test_aucs)*100:.2f} ')