from gymnasium import spaces
import torch as th
from torch import nn
import numpy as np

from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

class ConvNet(nn.Module):
    """1-D convolutional network mapping a multichannel sequence to a fixed-size vector.

    Layers, in order (the order and wrapping match earlier versions, so saved
    ``state_dict`` checkpoints still load):

    * ``num_blocks`` blocks of ``Conv1d -> BatchNorm1d -> ReLU``; block ``i``
      outputs ``min(64 * 2**i, max_channels)`` channels
      (64, 128, 256, 256, 256 with the defaults);
    * one extra ``Conv1d -> ReLU`` that keeps the last channel width;
    * global average pooling over the length axis (``GAP``);
    * a linear layer producing ``out_features`` values (``fc1``).

    Args:
        original_length: Nominal input sequence length. Kept as an attribute for
            reference; pooling is adaptive, so inputs of other lengths also work.
        num_blocks: Number of Conv-BN-ReLU blocks.
        kernel_size: Kernel size of every convolution.
        padding: Padding of every convolution (with the default stride 1,
            ``padding=1`` and ``kernel_size=3`` preserve the sequence length).
        original_dim: Number of input channels (>1 if multimodal).
        out_features: Size of the output vector.
        max_channels: Upper bound on the channel width of any conv block.
    """

    def __init__(
        self,
        original_length: int = 128,
        num_blocks: int = 5,
        kernel_size: int = 3,
        padding: int = 1,
        original_dim: int = 1,
        out_features: int = 12,
        max_channels: int = 256,
    ):
        super().__init__()
        self.original_length = original_length
        self.num_blocks = num_blocks
        self.kernel_size = kernel_size
        self.padding = padding
        self.original_dim = original_dim

        # Channel widths as plain Python ints. Only the hidden widths are capped;
        # the input channel count must stay exactly ``original_dim`` or the first
        # convolution would not match the data.
        dims = [original_dim] + [
            min(64 * 2 ** i, max_channels) for i in range(num_blocks)
        ]

        layers = []
        for in_ch, out_ch in zip(dims[:-1], dims[1:]):
            layers += [
                nn.Conv1d(in_ch, out_ch, kernel_size=kernel_size, padding=padding),
                nn.BatchNorm1d(out_ch),
                nn.ReLU(),
            ]
        # Final conv keeps the width; no BatchNorm here (same as before).
        layers += [
            nn.Conv1d(dims[-1], dims[-1], kernel_size=kernel_size, padding=padding),
            nn.ReLU(),
        ]
        self.layers = nn.Sequential(*layers)

        # Averages over the entire length axis whatever its size. For inputs of
        # length ``original_length`` (with length-preserving convs) this equals
        # ``nn.AvgPool1d(original_length)``, but it does not break on other lengths.
        self.GAP = nn.AdaptiveAvgPool1d(1)

        # Wrapped in Sequential to keep the ``fc1.0.*`` state_dict keys.
        self.fc1 = nn.Sequential(nn.Linear(dims[-1], out_features))

    def forward(self, x: th.Tensor) -> th.Tensor:
        """Run the network.

        Args:
            x: Tensor of shape ``(batch_size, original_dim, length)``; ``length``
                does not have to equal ``original_length``.

        Returns:
            Tensor of shape ``(batch_size, out_features)``.
        """
        out = self.layers(x)
        out = self.GAP(out)  # (batch_size, channels, 1)
        out = out.flatten(start_dim=1)  # (batch_size, channels)
        return self.fc1(out)

