import torch
import torch.nn as nn


def print_pre_train_model_stats(model):
    """Print a short size summary of ``model`` before training starts.

    One line is printed for every ``nn.Linear`` submodule, reporting the size
    of the first dimension of its weight (its number of output units),
    followed by one line with the total element count over all parameters.

    Args:
        model: Module whose ``named_modules()`` and ``parameters()`` are
            inspected.
    """
    linear_layers = (
        (module_name, module)
        for module_name, module in model.named_modules()
        if isinstance(module, nn.Linear)
    )
    for module_name, module in linear_layers:
        # Row count of the weight matrix is reported as the neuron count.
        neuron_count = module.weight.size(0)
        print(f"Layer {module_name} has {neuron_count} neurons.")

    # Accumulate the element count of every parameter tensor.
    param_count = 0
    for param in model.parameters():
        param_count += param.numel()
    print(f"Total number of parameters in the model: {param_count}")


def _mean_or_nan(total, count):
    """Return ``total / count``, or NaN when ``count`` is zero (empty loader)."""
    return total / count if count else float("nan")


def train_model(net, train_gen, test_gen, device, num_epochs=30, lr=1e-3):
    """Train ``net`` and print training/validation metrics after every epoch.

    The model is optimised with AdamW (``weight_decay=0.01``) against a
    cross-entropy loss. Every batch of images is reshaped to ``(-1, 784)``
    before being fed to the network. After each epoch the model is evaluated
    on ``test_gen`` with gradients disabled and one summary line is printed.

    Reported losses are per-sample means over the whole epoch (each batch's
    mean loss is weighted by its batch size), so the training and validation
    figures are directly comparable and a smaller final batch does not skew
    them. If a loader yields no batches, its metrics are reported as ``nan``
    instead of raising ``ZeroDivisionError``.

    Args:
        net: Module to train; it is updated in place.
        train_gen: Iterable yielding ``(images, labels)`` training batches.
        test_gen: Iterable yielding ``(images, labels)`` validation batches.
        device: Device that each batch is moved to via ``.to(device)``.
        num_epochs: Number of passes over ``train_gen``.
        lr: Learning rate passed to AdamW.

    Returns:
        None. Progress is reported on standard output only.
    """
    loss_function = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(net.parameters(), lr=lr, weight_decay=0.01)
    print_pre_train_model_stats(net)

    for epoch in range(num_epochs):
        # ---- Training pass ----
        net.train()
        train_loss_sum = 0.0
        train_correct = 0
        train_total = 0
        for images, labels in train_gen:
            # Flatten each image to a 784-element vector and move the batch
            # to the target device.
            images = images.view(-1, 28 * 28).to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = net(images)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            # The criterion returns the batch mean; scale it back to a sum so
            # the epoch average is weighted per sample.
            train_loss_sum += loss.item() * batch_size
            train_correct += (outputs.argmax(dim=1) == labels).sum().item()
            train_total += batch_size

        train_loss = _mean_or_nan(train_loss_sum, train_total)
        train_accuracy = _mean_or_nan(100 * train_correct, train_total)

        # ---- Validation pass ----
        net.eval()  # Set the model to evaluation mode
        val_loss_sum = 0.0
        val_correct = 0
        val_total = 0
        with torch.no_grad():  # Disable gradient computation for validation
            for val_images, val_labels in test_gen:
                val_images = val_images.view(-1, 28 * 28).to(device)
                val_labels = val_labels.to(device)

                val_outputs = net(val_images)
                batch_size = val_labels.size(0)
                batch_loss = loss_function(val_outputs, val_labels).item()
                val_loss_sum += batch_loss * batch_size
                val_correct += (val_outputs.argmax(dim=1) == val_labels).sum().item()
                val_total += batch_size

        val_loss = _mean_or_nan(val_loss_sum, val_total)
        val_accuracy = _mean_or_nan(100 * val_correct, val_total)

        print(f"Epoch [{epoch + 1}/{num_epochs}], "
              f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, "
              f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")
