import torch
import torch.nn as nn
import torch.nn.functional as F

class RotationPredictorCNN(nn.Module):
    """Small CNN that regresses one rotation-angle value per input image.

    The network is three 3x3 convolutions (16, 32 and 64 output channels),
    each followed by ReLU and a 2x2 max-pool, and then three fully connected
    layers (128 -> 32 -> 1). Every pooling step halves the spatial size, so
    the feature map entering ``fc1`` is ``image_size // 8`` pixels per side.

    Args:
        in_channels: Number of channels in the input image. Defaults to 1
            (grayscale).
        image_size: Height and width of the (square) input images. Defaults
            to 64, which yields ``64 * 8 * 8`` features going into ``fc1``.

    Raises:
        ValueError: If ``in_channels`` is not positive or ``image_size`` is
            smaller than 8 (the pooled feature map would be empty).
    """

    def __init__(self, in_channels: int = 1, image_size: int = 64):
        super().__init__()
        if in_channels < 1:
            raise ValueError(f"in_channels must be >= 1, got {in_channels}")
        if image_size < 8:
            raise ValueError(f"image_size must be >= 8, got {image_size}")

        # Three stride-2 pools with floor rounding shrink each side by 8.
        feature_side = image_size // 8

        # Attribute names and creation order are kept as-is so that
        # state_dict keys and parameter registration order are unchanged.
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(in_features=64 * feature_side * feature_side, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=32)
        self.fc3 = nn.Linear(in_features=32, out_features=1)  # Output: predicted rotation angle

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predict one rotation-angle value per image.

        Args:
            x: Image tensor of shape ``(batch, in_channels, H, W)``, or a
                single unbatched image of shape ``(in_channels, H, W)``.
                ``H`` and ``W`` must match the ``image_size`` the model was
                built with.

        Returns:
            Tensor of shape ``(batch, 1)`` (``(1, 1)`` for an unbatched
            image) holding the raw output of the last linear layer. The
            angle unit is whatever the training targets used; nothing in
            this module fixes it.

        Raises:
            RuntimeError: If the spatial size of ``x`` does not produce the
                number of features ``fc1`` expects.
        """
        if x.dim() == 3:
            # Promote a single (C, H, W) image to a batch of one so the
            # flatten below always keeps dim 0 as the batch dimension.
            x = x.unsqueeze(0)

        # Shape comments assume the default 64x64 input.
        x = self.pool(F.relu(self.conv1(x)))  # [batch_size, 16, 32, 32]
        x = self.pool(F.relu(self.conv2(x)))  # [batch_size, 32, 16, 16]
        x = self.pool(F.relu(self.conv3(x)))  # [batch_size, 64, 8, 8]

        # Flatten everything except the batch dimension. Unlike
        # view(-1, N), this can never fold surplus spatial features into
        # extra batch rows; a wrong input size fails loudly inside fc1. It
        # also works on non-contiguous feature maps.
        x = torch.flatten(x, start_dim=1)

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
