#coding=utf-8
import torch.nn as nn
import torch
import torch.nn.functional as F

# Per-task hyperparameters read by ActNetwork, keyed by task name.
#   in_size  -> input channel count of the first convolution
#   ker_size -> kernel width shared by both convolutions (kernel height is 1)
#   fc_size  -> length of each flattened feature row produced by forward()
var_size = dict(
    emg=dict(
        in_size=8,
        ker_size=9,
        # 32 matches the channel count of the second convolution; 44 is
        # presumably the width left after both conv/pool stages -- TODO
        # confirm against the input window length used by the data loader.
        fc_size=32 * 44,
    ),
)

class ActNetwork(nn.Module):
    def __init__(self, taskname):
        super(ActNetwork, self).__init__()
        self.taskname=taskname
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=var_size[taskname]['in_size'], out_channels=16, kernel_size=(1, var_size[taskname]['ker_size'])),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(1, 2), stride=2)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(1, var_size[taskname]['ker_size'])),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(1, 2), stride=2)
        )
        self.in_features = var_size[taskname]['fc_size']

    def forward(self, x):
        # print(x.shape)
        x=self.conv2(self.conv1(x))
        # print(x.shape)
        x=x.view(-1,self.in_features)
        # print(x.shape)
        return x
