
import os
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

# Define Dataset class
class GrayscaleDataset(Dataset):
    """Image/label dataset backed by a flat directory of ``.png`` files.

    Every ``.png`` file directly inside ``data_path`` is one sample. Its
    integer class label is parsed from the file name, which must look like
    ``<prefix>_label_<int>.png`` (e.g. ``scan_017_label_1.png``).

    NOTE: despite the class name, images are returned converted to 3-channel
    RGB (the RGB ``SimpleCNN`` in this file takes 3 input channels).

    Args:
        data_path: Directory containing the ``.png`` files.
        transform: Optional callable applied to the PIL image before it is
            returned.
    """

    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        self.transform = transform
        # Sorted so that index -> file mapping is reproducible across runs and
        # machines (os.listdir returns entries in arbitrary order).
        self.image_files = sorted(
            f for f in os.listdir(data_path) if f.endswith('.png')
        )

    def __len__(self):
        return len(self.image_files)

    @staticmethod
    def _parse_label(img_name):
        """Return the integer label encoded as ``..._label_<int>.png`` in *img_name*.

        Raises:
            ValueError: if the name has no ``_label_`` marker or the text after
                it is not an integer.
        """
        stem = img_name[:-len('.png')] if img_name.endswith('.png') else img_name
        _, sep, label_text = stem.rpartition('_label_')
        if not sep:
            raise ValueError(
                f"cannot find '_label_' in image file name {img_name!r}"
            )
        return int(label_text)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample *idx*.

        ``image`` is an RGB PIL image, or whatever ``transform`` returns when a
        transform is set; ``label`` is the int parsed from the file name.
        """
        img_name = self.image_files[idx]
        img_path = os.path.join(self.data_path, img_name)
        # Use the context manager so the underlying file is closed right away;
        # convert() returns an independent in-memory image.
        with Image.open(img_path) as img:
            image = img.convert("RGB")  # Convert to 3-channel RGB
        label = self._parse_label(img_name)

        if self.transform:
            image = self.transform(image)

        return image, label

"""
# Define the CNN Model
class Grayscale_SimpleCNN(nn.Module):
    def __init__(self):
        super(Grayscale_SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(32 * 5 * 64, 128)  # Adjust dimensions
        self.fc2 = nn.Linear(128, 2)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.pool(x)
        x = self.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(x.size(0), -1)  # Dynamically compute the flattened size
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x
"""

# Define the CNN Model for RGB Input
class SimpleCNN(nn.Module):
    """Small two-convolution CNN classifier.

    Layout: conv3x3(in_channels->16) -> ReLU -> maxpool2 ->
    conv3x3(16->32) -> ReLU -> maxpool2 -> flatten ->
    linear(fc_in_features->128) -> ReLU -> linear(128->num_classes).

    The 3x3 convolutions use padding=1 (size-preserving), and each 2x2
    max-pool halves both spatial sides. For an input of shape
    ``(N, in_channels, H, W)`` the flattened feature size is therefore
    ``32 * (H // 4) * (W // 4)``, which must equal ``fc_in_features``. With the
    default ``32 * 5 * 64`` that means ``(H // 4) * (W // 4) == 320``, e.g.
    H=20, W=256. The intended input size is not visible here (TODO confirm).

    Args:
        in_channels: Input channel count; 3 for RGB (default), 1 for grayscale.
        num_classes: Number of output logits.
        fc_in_features: Flattened feature size fed to ``fc1``; depends on the
            input image size as described above.
    """

    def __init__(self, in_channels=3, num_classes=2, fc_in_features=32 * 5 * 64):
        super().__init__()
        # Submodule names and registration order are kept stable so existing
        # state_dict checkpoints still load.
        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        # Fixed at construction time: only inputs whose pooled feature map
        # flattens to exactly fc_in_features are accepted by forward().
        self.fc1 = nn.Linear(fc_in_features, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return raw logits of shape ``(N, num_classes)`` for input ``x``."""
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        # Flatten all non-batch dimensions; the result width must match fc1.
        x = torch.flatten(x, 1)
        x = self.relu(self.fc1(x))
        return self.fc2(x)
    




    
