import copy
import gc

import torch
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.loader import NeighborLoader
from tqdm import tqdm

from model.utils import init_base_model
from train.train import EarlyStopping

# Batch-based training function using NeighborLoader
def train_epoch_batch(model, optimizer, loader, device):
    """Run one training epoch over neighbour-sampled mini-batches.

    For each batch only the first ``batch.batch_size`` rows take part in the
    loss. NeighborLoader places a batch's seed nodes first; the remaining rows
    are sampled neighbours that only provide message-passing context.

    Args:
        model: Module called as ``model(x, edge_index)`` that returns per-node
            logits.
        optimizer: Optimizer over ``model``'s parameters.
        loader: Iterable of mini-batches. Each batch exposes ``x``,
            ``edge_index``, ``y``, ``batch_size`` and ``to(device)``.
        device: Device every batch is moved to before the forward pass.

    Returns:
        float: Mean cross-entropy loss per seed node over the epoch, or 0.0 if
        the loader yields no batches.
    """
    model.train()
    total_loss = 0.0
    total_examples = 0

    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        # Forward pass; keep only the seed-node rows.
        seed_count = batch.batch_size
        out = model(batch.x, batch.edge_index)[:seed_count]
        # reshape(-1) turns both (B,) and (B, 1) label layouts into (B,).
        # Unlike squeeze(), it keeps a one-node batch as shape (1,) rather than
        # a 0-d scalar, which cross_entropy would reject against (1, C) logits.
        y = batch.y[:seed_count].reshape(-1)
        loss = F.cross_entropy(out, y)

        # Backward pass and optimizer step.
        loss.backward()
        optimizer.step()

        # cross_entropy returns the batch mean. Weight it by the number of seed
        # nodes so the epoch figure is a true per-node mean, even when the
        # last batch is smaller than the others.
        total_loss += float(loss) * seed_count
        total_examples += seed_count

    return total_loss / total_examples if total_examples else 0.0

# Batched full-neighbourhood evaluation: returns accuracy on the train/val/test masks
def test_batch(model, data, device, batch_size=8092):
    """Evaluate ``model`` on the train, validation and test splits of ``data``.

    A single-hop NeighborLoader is built over every node and takes all of each
    node's neighbours (``num_neighbors=[-1]``). It is handed to
    ``model.inference_0`` together with the full ``data``.

    Args:
        model: Model exposing ``inference_0(data, device, loader)``, which
            returns one row of class scores per node in ``data``.
        data: Graph with ``edge_index``, ``x``, ``y`` and boolean ``train_mask``,
            ``val_mask`` and ``test_mask`` attributes.
        device: Device predictions are moved to before they are compared with
            ``data.y``.
        batch_size: Nodes per evaluation batch. The default of 8092 is kept
            from the original code (NOTE(review): possibly meant to be 8192).

    Returns:
        list[float]: ``[train_acc, val_acc, test_acc]``.
    """
    model.eval()
    with torch.no_grad():
        # The loader only needs graph structure, so drop x and y from its copy.
        # A shallow copy shares tensors with ``data`` but has its own attribute
        # store, so deleting attributes does not touch ``data``. It also avoids
        # the full feature-matrix duplication that clone() performed.
        loader_data = copy.copy(data)
        del loader_data.x, loader_data.y
        subgraph_loader = NeighborLoader(
            loader_data,
            num_neighbors=[-1],
            batch_size=batch_size,
            shuffle=False,
        )

        out = model.inference_0(data, device, subgraph_loader)
        # log_softmax only subtracts a per-row constant, so argmax of the raw
        # scores picks the same class.
        pred = out.argmax(dim=1).to(device)
        # Flatten labels so (N, 1)-shaped targets do not broadcast against the
        # (N,) predictions into an (N, N) comparison.
        labels = data.y.reshape(-1)

        accs = []
        for mask in (data.train_mask, data.val_mask, data.test_mask):
            correct = int((pred[mask] == labels[mask]).sum())
            accs.append(correct / int(mask.sum()))

        # Release the loader (which holds loader_data) before collecting.
        del subgraph_loader, loader_data
        gc.collect()  # Force garbage collection
    return accs


# Run batched training with early stopping, evaluating train/val/test accuracy after every epoch
def run_training_batch(model, optimizer, train_loader, data, device, epochs, patience, save_path=None):
    """Train ``model`` for up to ``epochs`` epochs with early stopping.

    Each epoch does one pass of ``train_epoch_batch`` and then a full
    evaluation with ``test_batch``. The validation accuracy is passed to an
    ``EarlyStopping`` instance, and the loop ends as soon as it sets
    ``early_stop``. Afterwards, ``early_stopping.best_model_state`` is
    optionally written to ``save_path``, loaded back into ``model``, and
    evaluated once more.

    Args:
        model: Model being trained; its weights are replaced by the best state.
        optimizer: Optimizer over ``model``'s parameters.
        train_loader: Mini-batch loader over the training nodes.
        data: Full graph used for evaluation.
        device: Device used for training and evaluation.
        epochs: Maximum number of epochs.
        patience: Patience passed to ``EarlyStopping``.
        save_path: If not None, path where the best state is saved with
            ``torch.save``.

    Returns:
        tuple: ``(train_losses, train_accuracies, val_accuracies,
        test_accuracies)``, one entry per completed epoch.
    """
    losses, train_hist, val_hist, test_hist = [], [], [], []
    stopper = EarlyStopping(patience=patience)

    with tqdm(total=epochs, desc="Training Progress") as progress:
        for _ in range(epochs):
            epoch_loss = train_epoch_batch(model, optimizer, train_loader, device)
            train_acc, val_acc, test_acc = test_batch(model, data, device)

            for history, value in (
                (losses, epoch_loss),
                (train_hist, train_acc),
                (val_hist, val_acc),
                (test_hist, test_acc),
            ):
                history.append(value)

            postfix = {}
            postfix['Loss'] = f'{epoch_loss:.4f}'
            postfix['Train Acc'] = f'{train_acc:.4f}'
            postfix['Val Acc'] = f'{val_acc:.4f}'
            postfix['Test Acc'] = f'{test_acc:.4f}'
            progress.set_postfix(postfix)
            progress.update(1)

            stopper(val_acc, model)
            if not stopper.early_stop:
                continue
            print(f"No performance improvement for {patience} epochs. Stopping early.")
            break

    if save_path is not None:
        torch.save(stopper.best_model_state, save_path)

    # NOTE(review): relies on EarlyStopping having recorded a best state
    # (e.g. it may still be unset if epochs == 0) — confirm in train.train.
    model.load_state_dict(stopper.best_model_state)
    _, best_val_acc, best_test_acc = test_batch(model, data, device)
    print(f"Best model validation accuracy: {best_val_acc:.4f}")
    print(f"Best model test accuracy: {best_test_acc:.4f}")

    return losses, train_hist, val_hist, test_hist

# The batched training and inference code below is based on the PyG team's example:
# https://github.com/pyg-team/pytorch_geometric/blob/master/examples/ogbn_products_sage.py

def train_batch(data, params, device, save_path, args, batch_size=1024, num_neighbors=None):
    """Build a model and train it with neighbour-sampled mini-batches.

    Args:
        data: Graph with a ``train_mask``; it is moved to ``device``.
        params: Hyper-parameter dict. ``'learning_rate'``, ``'epochs'`` and
            ``'patience'`` are read here, and the dict is also passed to
            ``init_base_model``.
        device: Device for the data and the model.
        save_path: Passed to ``run_training_batch``. When not None, the best
            model state is saved there.
        args: Extra configuration forwarded to ``init_base_model``.
        batch_size: Number of seed (training) nodes per mini-batch.
        num_neighbors: Neighbours sampled per hop, one entry per hop.
            ``None`` (the default) means ``[15, 10, 5]``.

    Returns:
        tuple: ``(model, results)``, where ``results`` is the tuple returned
        by ``run_training_batch``.
    """
    # A None default avoids sharing one mutable list across calls.
    if num_neighbors is None:
        num_neighbors = [15, 10, 5]

    data = data.to(device)

    # Initialize the model using init_base_model.
    model = init_base_model(data, params, device, args)
    optimizer = optim.Adam(model.parameters(), lr=params['learning_rate'])

    # NeighborLoader over the training nodes only. Evaluation builds its own
    # full-neighbourhood loader inside test_batch.
    train_loader = NeighborLoader(
        data,
        input_nodes=data.train_mask,
        num_neighbors=num_neighbors,
        batch_size=batch_size,
        shuffle=True,
    )

    # Run batched training with early stopping.
    results = run_training_batch(
        model, optimizer, train_loader, data, device,
        params['epochs'], params["patience"], save_path,
    )

    return model, results