import torch
import torch.nn as nn
import numpy as np


class ConvNet(nn.Module):
	"""1D convolutional classifier with global average pooling.

	Architecture: ``num_blocks`` x (Conv1d -> BatchNorm1d -> ReLU), one extra
	Conv1d -> ReLU, average pooling over the whole length axis, then a single
	linear layer producing ``out_features`` outputs.

	Hidden channel widths start at 64 and double per block, capped at 256
	(64, 128, 256, 256, ...). The input channel count ``original_dim`` is
	used as-is and is not capped.
	"""

	def __init__(
		self,
		original_length=256,
		num_blocks=5,
		kernel_size=3,
		padding=1,
		original_dim=1,
		out_features=12,
	):
		"""Build the network.

		Args:
			original_length: Nominal input sequence length. Kept for backward
				compatibility; pooling is adaptive, so any input length that
				stays positive through the convolutions is accepted.
			num_blocks: Number of Conv1d/BatchNorm1d/ReLU blocks.
			kernel_size: Kernel size passed to every Conv1d.
			padding: Padding passed to every Conv1d.
			original_dim: Number of input channels (>1 if multimodal).
			out_features: Output size of the final linear layer.
		"""
		super().__init__()

		self.original_length = original_length
		self.out_features = out_features
		self.kernel_size = kernel_size
		self.padding = padding

		# Only the hidden widths are capped at 256. Capping original_dim as
		# well would build a first conv that cannot accept the real input.
		dims = [original_dim] + [min(2 ** (6 + i), 256) for i in range(num_blocks)]

		layers = []
		for in_ch, out_ch in zip(dims[:-1], dims[1:]):
			layers.extend([
				nn.Conv1d(in_ch, out_ch, kernel_size=kernel_size, padding=padding),
				nn.BatchNorm1d(out_ch),
				nn.ReLU(),
			])
		# Final conv keeps the width and has no BatchNorm.
		layers.extend([
			nn.Conv1d(dims[-1], dims[-1], kernel_size=kernel_size, padding=padding),
			nn.ReLU(),
		])
		self.layers = nn.Sequential(*layers)

		# Truly global average pooling: reduces whatever length reaches this
		# point to 1. A fixed AvgPool1d(original_length) crashes when the convs
		# shrink the sequence and averages only a prefix of longer inputs.
		self.GAP = nn.AdaptiveAvgPool1d(1)

		self.fc1 = nn.Sequential(
			nn.Linear(dims[-1], out_features)
		)

	def forward(self, x):
		"""Run the network.

		Args:
			x: Tensor of shape (batch_size, original_dim, length). With the
				default kernel_size=3 / padding=1 the convolutions preserve
				the length.

		Returns:
			Tensor of shape (batch_size, out_features).

		Note: original_dim is the number of channels in the input data,
		>1 if multimodal.
		"""
		out = self.layers(x)
		out = self.GAP(out)                 # (batch_size, dims[-1], 1)
		out = out.reshape(out.size(0), -1)  # (batch_size, dims[-1])
		return self.fc1(out)