import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    """Small three-stage convolutional image classifier.

    Architecture::

        conv1 (in_channels -> 32, 3x3) -> ReLU
        conv2 (32 -> 64, 3x3)          -> ReLU -> 2x2 max-pool
        conv3 (64 -> 128, 3x3)         -> ReLU -> 2x2 max-pool
        flatten -> fc1 (-> 256) -> ReLU -> fc2 (-> num_classes)

    Every convolution uses stride 1 and padding 1, so it preserves spatial
    size. Only the two max-pools shrink it, each halving height and width
    (floor division, since ``ceil_mode`` is left at its default). The number
    of features entering ``fc1`` therefore depends on the input resolution.
    This is why ``input_size`` must match the images later passed to
    :meth:`forward`.

    Args:
        num_classes: Number of logits produced per sample.
        in_channels: Number of channels in the input images (keyword-only).
        input_size: Side length of the square input images (keyword-only).
            The default of 32 gives the ``128 * 8 * 8`` flattened features
            this model has always used.

    Raises:
        ValueError: If ``input_size`` is smaller than 4. Such an input would
            shrink to zero spatial size after the two pooling stages.
    """

    def __init__(self, num_classes=10, *, in_channels=3, input_size=32):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Two 2x2/stride-2 pools, each flooring, shrink the side length by 4.
        pooled_side = input_size // 2 // 2
        if pooled_side < 1:
            raise ValueError(
                f"input_size must be at least 4 to survive two 2x2 pools, got {input_size}"
            )
        self.fc1 = nn.Linear(128 * pooled_side * pooled_side, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: Image batch of shape ``(N, in_channels, input_size, input_size)``.
                The spatial size must match the ``input_size`` given at
                construction, or ``fc1`` will reject the flattened features.

        Returns:
            Raw (un-normalized) logits of shape ``(N, num_classes)``.
        """
        x = F.relu(self.conv1(x))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # flatten (unlike .view) also accepts non-contiguous activations,
        # e.g. those produced from channels_last inputs.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)