#coding=utf-8
import torch.nn as nn
import torch
import torch.nn.functional as F

# Per-dataset settings, keyed by task name, read by the networks in this file:
#   in_size  - number of input channels given to the first Conv2d layer.
#   ker_size - width of the (1, ker_size) kernel used by every Conv2d layer.
#   fc_size  - flattened feature size that ActNetwork exposes as `in_features`.
#              Written as 32 * W, where 32 is the channel count of
#              ActNetwork's last conv layer; W is presumably the temporal
#              width left after the two conv/pool stages (TODO confirm per
#              dataset window length).
var_size = {
    'dsads': dict(in_size=45, ker_size=9, fc_size=32 * 25),
    'usc': dict(in_size=6, ker_size=6, fc_size=32 * 46),
    'pamap': dict(in_size=27, ker_size=9, fc_size=32 * 44),
    'emg': dict(in_size=8, ker_size=9, fc_size=32 * 44),
    'wesad': dict(in_size=8, ker_size=9, fc_size=32 * 44),
    'spcmd': dict(in_size=20, ker_size=9, fc_size=32 * 14),
    'har': dict(in_size=6, ker_size=9, fc_size=32 * 26),
    'shar': dict(in_size=3, ker_size=9, fc_size=32 * 31),
    'pshar': dict(in_size=3, ker_size=9, fc_size=32 * 31),
    'pdsads': dict(in_size=9, ker_size=9, fc_size=32 * 25),
    'pusc': dict(in_size=6, ker_size=6, fc_size=32 * 46),
    'ppamap': dict(in_size=9, ker_size=9, fc_size=32 * 44),
    'phar': dict(in_size=6, ker_size=9, fc_size=32 * 26),
    'cross_dataset': dict(in_size=6, ker_size=6, fc_size=32 * 8),
}

class SActNetwork(nn.Module):
    """Shallow feature extractor: one conv block followed by a flatten.

    The block is Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d.  The convolution
    maps ``var_size[taskname]['in_size']`` input channels to 16 channels with
    a ``(1, var_size[taskname]['ker_size'])`` kernel.

    Attributes:
        taskname: Key into ``var_size`` used to configure the convolution.
        conv1: The single conv block.
        in_features: Size of each flattened output row.  It is fixed at
            ``16 * 96`` regardless of ``taskname``; e.g. a height-1 input of
            width 200 with ``ker_size`` 9 gives width 192 after the
            convolution and 96 after pooling.
    """

    def __init__(self, taskname):
        """Build the conv block from the ``var_size[taskname]`` settings.

        Args:
            taskname: Key into ``var_size``.

        Raises:
            KeyError: If ``taskname`` is not a key of ``var_size``.
        """
        super(SActNetwork, self).__init__()
        self.taskname = taskname
        cfg = var_size[taskname]
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=cfg['in_size'],
                out_channels=16,
                kernel_size=(1, cfg['ker_size']),
            ),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(1, 2), stride=2),
        )
        self.in_features = 16*96

    def forward(self, x):
        """Run the conv block and flatten each sample into one feature row.

        Args:
            x: Input batch passed to ``conv1``; assumed to be laid out as
                (batch, in_size, height, width) -- TODO confirm with caller.

        Returns:
            Tensor of shape ``(batch, in_features)``.

        Raises:
            RuntimeError: If the per-sample feature count produced by the
                conv block differs from ``in_features``.
        """
        x = self.conv1(x)
        # Flatten per sample.  ``view(-1, in_features)`` would instead let the
        # batch dimension absorb a size mismatch and silently mix samples
        # (e.g. two samples of 768 features becoming one row of 1536).
        x = x.flatten(1)
        if x.size(1) != self.in_features:
            raise RuntimeError(
                "SActNetwork expected {} features per sample after conv1 "
                "but got {}; check the input shape for task '{}'".format(
                    self.in_features, x.size(1), self.taskname))
        return x

class LActNetwork(nn.Module):
    """Deep feature extractor: four conv blocks followed by a flatten.

    Every block is Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d with a
    ``(1, var_size[taskname]['ker_size'])`` kernel.  Channel counts go
    ``in_size -> 16 -> 32 -> 32 -> 32``.

    Attributes:
        taskname: Key into ``var_size`` used to configure the convolutions.
        conv1, conv2, conv3, conv4: The conv blocks, applied in that order.
        in_features: Size of each flattened output row.  It is fixed at
            ``32 * 5`` regardless of ``taskname``; e.g. a height-1 input of
            width 200 with ``ker_size`` 9 shrinks to width 5 after the four
            conv/pool stages.
    """

    def __init__(self, taskname):
        """Build the four conv blocks from the ``var_size[taskname]`` settings.

        Args:
            taskname: Key into ``var_size``.

        Raises:
            KeyError: If ``taskname`` is not a key of ``var_size``.
        """
        super(LActNetwork, self).__init__()
        self.taskname = taskname
        cfg = var_size[taskname]
        kernel = (1, cfg['ker_size'])
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=cfg['in_size'], out_channels=16, kernel_size=kernel),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(1, 2), stride=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=kernel),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(1, 2), stride=2),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=kernel),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(1, 2), stride=2),
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=kernel),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(1, 2), stride=2),
        )
        self.in_features = 32*5

    def forward(self, x):
        """Run the four conv blocks and flatten each sample into one row.

        Args:
            x: Input batch passed to ``conv1``; assumed to be laid out as
                (batch, in_size, height, width) -- TODO confirm with caller.

        Returns:
            Tensor of shape ``(batch, in_features)``.

        Raises:
            RuntimeError: If the per-sample feature count produced by the
                conv blocks differs from ``in_features``.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        # Flatten per sample.  ``view(-1, in_features)`` would instead let the
        # batch dimension absorb a size mismatch and silently mix samples.
        x = x.flatten(1)
        if x.size(1) != self.in_features:
            raise RuntimeError(
                "LActNetwork expected {} features per sample after conv4 "
                "but got {}; check the input shape for task '{}'".format(
                    self.in_features, x.size(1), self.taskname))
        return x


class ActNetwork(nn.Module):
    """Feature extractor: two conv blocks followed by a flatten.

    Every block is Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d with a
    ``(1, var_size[taskname]['ker_size'])`` kernel.  Channel counts go
    ``in_size -> 16 -> 32``.

    Attributes:
        taskname: Key into ``var_size`` used to configure the network.
        conv1, conv2: The conv blocks, applied in that order.
        in_features: Size of each flattened output row, taken from
            ``var_size[taskname]['fc_size']``.
    """

    def __init__(self, taskname):
        """Build the two conv blocks from the ``var_size[taskname]`` settings.

        Args:
            taskname: Key into ``var_size``.

        Raises:
            KeyError: If ``taskname`` is not a key of ``var_size``.
        """
        super(ActNetwork, self).__init__()
        self.taskname = taskname
        cfg = var_size[taskname]
        kernel = (1, cfg['ker_size'])
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=cfg['in_size'], out_channels=16, kernel_size=kernel),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(1, 2), stride=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=kernel),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(1, 2), stride=2),
        )
        self.in_features = cfg['fc_size']

    def forward(self, x):
        """Run both conv blocks and flatten each sample into one feature row.

        Args:
            x: Input batch passed to ``conv1``; assumed to be laid out as
                (batch, in_size, height, width) -- TODO confirm with caller.

        Returns:
            Tensor of shape ``(batch, in_features)``.

        Raises:
            RuntimeError: If the per-sample feature count produced by the
                conv blocks differs from ``in_features``.
        """
        x = self.conv2(self.conv1(x))
        # Flatten per sample.  ``view(-1, in_features)`` would instead let the
        # batch dimension absorb a size mismatch (wrong task name or window
        # length) and silently mix samples.
        x = x.flatten(1)
        if x.size(1) != self.in_features:
            raise RuntimeError(
                "ActNetwork expected {} features per sample after conv2 "
                "but got {}; check the input shape for task '{}'".format(
                    self.in_features, x.size(1), self.taskname))
        return x
