import torch
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
from model.utils import init_base_model
from torch_geometric.loader import NeighborLoader
from ogb.nodeproppred import Evaluator
from train.train import EarlyStopping
from train.train_batch import train_epoch_batch
import gc

# Batch-based testing function using OGB Evaluator
def test_batch_ogbn(model, data, evaluator, device, batch_size=8092):
    """Evaluate ``model`` on the train/val/test splits of ``data``.

    Predictions are produced by ``model.inference_0`` over a
    ``NeighborLoader`` built from a feature-less copy of ``data``, then
    scored per split with ``evaluator.eval(...)['acc']``.

    Args:
        model: Module exposing ``eval()`` and
            ``inference_0(data, device, subgraph_loader)``. The result is
            reduced with ``argmax(dim=-1)``, so it is presumably a tensor of
            per-node class scores -- confirm against the model definition.
        data: Graph object providing ``y``, ``train_mask``, ``val_mask`` and
            ``test_mask`` and supporting ``clone().detach()``.
        evaluator: Object whose ``eval`` accepts a dict with ``'y_true'`` and
            ``'y_pred'`` and returns a mapping containing ``'acc'``.
        device: Forwarded unchanged to ``model.inference_0``.
        batch_size: Number of seed nodes per inference batch.

    Returns:
        Tuple ``(train_acc, val_acc, test_acc)``.
    """
    model.eval()
    with torch.no_grad():
        # The loader works on a copy without node features and labels; the
        # full ``data`` object is handed to ``inference_0`` separately.
        data_load = data.clone().detach()
        del data_load.x, data_load.y
        # No ``input_nodes`` is given, which per the PyG documentation means
        # every node is used as a seed; ``-1`` requests all neighbours for
        # the single sampled hop.
        subgraph_loader = NeighborLoader(
            data_load,
            num_neighbors=[-1],
            batch_size=batch_size,
            shuffle=False,
        )

        out = model.inference_0(data, device, subgraph_loader)
        # Ground truth comes from the dataset, predictions from the model.
        # The order matters: the evaluator treats ``y_true`` as the labels
        # (e.g. when deciding which rows are labelled at all).
        y_true = data.y.cpu()
        y_pred = out.argmax(dim=-1, keepdim=True).cpu()

        del data_load

        accuracies = []
        for mask in (data.train_mask, data.val_mask, data.test_mask):
            mask = mask.cpu()
            result = evaluator.eval({
                'y_true': y_true[mask],
                'y_pred': y_pred[mask],
            })
            accuracies.append(result['acc'])
        gc.collect()  # Force garbage collection

    train_acc, val_acc, test_acc = accuracies
    return train_acc, val_acc, test_acc


# Run training with batch processing and OGB Evaluator
def run_training_ogbn(model, optimizer, train_loader, data, evaluator, device, epochs, patience, save_path=None):
    """Train ``model`` for up to ``epochs`` epochs with early stopping.

    Each epoch calls ``train_epoch_batch`` once and then ``test_batch_ogbn``
    to obtain train/val/test accuracy. The validation accuracy is fed to an
    ``EarlyStopping`` instance; training stops as soon as it sets
    ``early_stop``. Afterwards the state held in
    ``early_stopping.best_model_state`` is optionally saved to ``save_path``,
    loaded back into ``model``, and the model is evaluated once more.

    Args:
        model: Model to train; modified in place.
        optimizer: Optimizer forwarded to ``train_epoch_batch``.
        train_loader: Batch loader forwarded to ``train_epoch_batch``.
        data: Full graph forwarded to ``test_batch_ogbn``.
        evaluator: Evaluator forwarded to ``test_batch_ogbn``.
        device: Device forwarded to both helpers.
        epochs: Maximum number of epochs.
        patience: Forwarded to ``EarlyStopping(patience=...)``.
        save_path: If not ``None``, where the best state is written with
            ``torch.save``.

    Returns:
        Tuple ``(train_losses, train_accuracies, val_accuracies,
        test_accuracies)`` with one entry per completed epoch.
    """
    train_losses, train_accuracies = [], []
    val_accuracies, test_accuracies = [], []
    early_stopping = EarlyStopping(patience=patience)

    with tqdm(total=epochs, desc="Training Progress") as pbar:
        for _ in range(epochs):
            loss = train_epoch_batch(model, optimizer, train_loader, device)
            train_acc, val_acc, test_acc = test_batch_ogbn(model, data, evaluator, device)

            train_losses.append(loss)
            train_accuracies.append(train_acc)
            val_accuracies.append(val_acc)
            test_accuracies.append(test_acc)

            pbar.set_postfix({
                'Loss': f'{loss:.4f}',
                'Train Acc': f'{train_acc:.4f}',
                'Val Acc': f'{val_acc:.4f}',
                'Test Acc': f'{test_acc:.4f}',
            })
            pbar.update(1)

            early_stopping(val_acc, model)
            if early_stopping.early_stop:
                print(f"No performance improvement for {patience} epochs. Stopping early.")
                break

    # NOTE(review): assumes EarlyStopping leaves ``best_model_state`` as None
    # until it has seen at least one epoch (e.g. when ``epochs`` is 0) --
    # confirm against train.train.EarlyStopping. In that case there is
    # nothing to save or restore, so the model keeps its current weights.
    best_state = early_stopping.best_model_state
    if best_state is not None:
        if save_path is not None:
            torch.save(best_state, save_path)
        # Load the best model state
        model.load_state_dict(best_state)

    _, val_acc, test_acc = test_batch_ogbn(model, data, evaluator, device)
    print(f"Best model validation accuracy: {val_acc:.4f}")
    print(f"Best model test accuracy: {test_acc:.4f}")

    return train_losses, train_accuracies, val_accuracies, test_accuracies

def train_ogbn(data, params, device, save_path, args, batch_size=1024, num_neighbors=None):
    """Build a model and train it on ``data`` with neighbour-sampled batches.

    Args:
        data: Graph object; moved to ``device`` and expected to provide
            ``train_mask``.
        params: Mapping read for ``'learning_rate'``, ``'epochs'`` and
            ``'patience'``; also forwarded to ``init_base_model``.
        device: Target device for the data and the model.
        save_path: Forwarded to ``run_training_ogbn`` as the checkpoint path.
        args: Object providing ``dataset`` (the name passed to the OGB
            ``Evaluator``); also forwarded to ``init_base_model``.
        batch_size: Number of seed nodes per training batch.
        num_neighbors: Neighbour sample sizes passed to ``NeighborLoader``.
            ``None`` (the default) selects ``[15, 10, 5]``.

    Returns:
        Tuple ``(model, results)`` where ``results`` is whatever
        ``run_training_ogbn`` returns.
    """
    # A fresh list per call avoids sharing one mutable default object
    # between invocations and the loaders built from it.
    if num_neighbors is None:
        num_neighbors = [15, 10, 5]

    data = data.to(device)

    # Initialize the model using init_base_model
    model = init_base_model(data, params, device, args)
    optimizer = optim.Adam(model.parameters(), lr=params['learning_rate'])

    # Create the NeighborLoader over the training nodes only; evaluation
    # builds its own loader.
    train_loader = NeighborLoader(
        data,
        input_nodes=data.train_mask,
        num_neighbors=num_neighbors,
        batch_size=batch_size,
        shuffle=True,
    )
    # Initialize OGB evaluator
    evaluator = Evaluator(name=args.dataset)

    # Run training with batch processing and OGB evaluator
    results = run_training_ogbn(
        model,
        optimizer,
        train_loader,
        data,
        evaluator,
        device,
        params['epochs'],
        params["patience"],
        save_path,
    )

    return model, results