from typing import Tuple

import torch
from torch import nn


class SimpleCNN(nn.Module):
    """Small convolutional network: two conv blocks followed by an MLP head.

    The feature extractor applies ``Conv2d -> BatchNorm2d -> LeakyReLU``
    twice (5x5 kernels, no padding, the second convolution with stride 2)
    and flattens the result. The classifier maps the flattened features
    through hidden layers of 128 and 64 units to ``num_outputs`` values;
    no activation is applied to the final layer.

    Args:
        obs_space: ``(channels, width, height)`` of a single observation.
            Inputs passed to :meth:`forward` must be laid out as
            ``(batch, channels, width, height)``. Both spatial sizes must be
            at least 9 so that the two unpadded 5x5 convolutions fit.
        num_outputs: Number of values produced per input sample.
    """

    def __init__(self, obs_space: Tuple[int, int, int], num_outputs: int):
        super().__init__()
        chan, width, height = obs_space
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(chan, 32, kernel_size=(5, 5)),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(),
            nn.Conv2d(32, 64, kernel_size=(5, 5), stride=(2, 2)),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(),
            nn.Flatten(),
        )

        # Probe the flattened feature size with a dummy observation.
        # The probe runs in eval mode and without autograd because BatchNorm
        # in training mode would (a) fold the all-zero dummy batch into its
        # running statistics and (b) reject a single-sample batch whose
        # feature map is 1x1.
        self.feature_extractor.eval()
        with torch.no_grad():
            probe = torch.zeros((1, chan, width, height))
            feature_size = self.feature_extractor(probe).numel()
        # Restore the default training mode of a freshly built module.
        self.feature_extractor.train()

        self.classifier = nn.Sequential(
            nn.Linear(feature_size, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 64),
            nn.LeakyReLU(),
            nn.Linear(64, num_outputs),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Run the network on a batch of observations.

        Args:
            input: Tensor of shape ``(batch, channels, width, height)``
                matching the ``obs_space`` given at construction.

        Returns:
            Tensor of shape ``(batch, num_outputs)``.
        """
        features = self.feature_extractor(input)
        return self.classifier(features)