# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
import torch.nn as nn
from .frequency_filter import FrequencyFilter

# Per-dataset hyperparameters for ActNetwork, keyed by task name.
#   in_size  -- number of input channels; used as conv1's ``in_channels``.
#   ker_size -- kernel width shared by conv1 and conv2 (kernel_size=(1, ker_size)).
#   fc_size  -- number of features per sample after conv2 + pooling, i.e. the
#               width of ActNetwork's output (``in_features``). Written as
#               32 * <remaining spatial size> because conv2 has 32 output
#               channels. NOTE(review): the spatial factor must match the
#               input window length each dataset is fed with — confirm
#               against the data loaders when adding or changing an entry.
var_size = {
    'EMG': {
        'in_size': 8,
        'ker_size': 9,
        'fc_size': 32*44
    },
    'DSADS':{
        'in_size': 45,
        'ker_size': 9,
        'fc_size': 32*25
    },
    'PAMAP':{
        'in_size': 27,
        'ker_size': 9,
        'fc_size': 32*44
    },
    'USCHAD':{
        'in_size': 6,
        'ker_size': 6,
        'fc_size': 32*46
    },
    'UCIHAR':{
        'in_size': 9,
        'ker_size': 6,
        'fc_size': 32*28
    },
    'SHAR': {
        'in_size': 3,
        'ker_size': 6,
        'fc_size': 32*34
    },
    'OPP': {
        'in_size': 77,
        'ker_size': 6,
        'fc_size': 32*3
    },
    'PCL': {
        'in_size': 48,
        'ker_size': 6,
        'fc_size': 32*183
    },
    'HHAR': {
        'in_size': 6,
        'ker_size': 6,
        'fc_size': 32*121
    },
    'Spurious_Fourier': {
        'in_size': 1,
        'ker_size': 6,
        'fc_size': 32*8
    },
    'WESAD': {
        'in_size': 8,
        'ker_size': 9,
        'fc_size': 32*44
    },
    'EEG': {
        'in_size': 1,
        'ker_size': 9,
        'fc_size': 32*744
    }
}



class ActNetwork(nn.Module):
    """Two-block convolutional feature extractor for sensor windows.

    Both convolutions use kernels of shape ``(1, ker_size)``, so they slide
    only along the last input axis. Hyperparameters come from ``var_size``.

    Args:
        taskname: Key into ``var_size`` selecting ``in_size`` (input
            channels), ``ker_size`` (kernel width) and ``fc_size`` (number of
            output features per sample).

    Attributes:
        taskname: The task name passed in.
        in_features: Number of features per sample returned by ``forward``
            (``var_size[taskname]['fc_size']``).

    Raises:
        KeyError: If ``taskname`` is not a key of ``var_size``.
    """

    def __init__(self, taskname):
        super(ActNetwork, self).__init__()
        try:
            cfg = var_size[taskname]
        except KeyError:
            raise KeyError(
                f"Unknown taskname {taskname!r}; expected one of {sorted(var_size)}"
            ) from None
        self.taskname = taskname
        kernel_size = (1, cfg['ker_size'])
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=cfg['in_size'], out_channels=16,
                      kernel_size=kernel_size),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            # An int stride applies to both axes: (2, 2).
            nn.MaxPool2d(kernel_size=(1, 2), stride=2)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32,
                      kernel_size=kernel_size),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(1, 2), stride=2)
        )
        self.in_features = cfg['fc_size']

    def forward(self, x):
        """Run both conv blocks and flatten each sample to a feature vector.

        Args:
            x: Batched 4-D input ``(batch, in_size, height, width)``, as
                required by ``Conv2d``/``BatchNorm2d``.

        Returns:
            Tensor of shape ``(batch, in_features)``.

        Raises:
            RuntimeError: If the per-sample feature count produced by the conv
                stack does not equal ``in_features``. This usually means the
                input width does not match the dataset's ``fc_size``.
        """
        x = self.conv2(self.conv1(x))
        # Keep the batch axis explicit: view(-1, in_features) would silently
        # merge or split samples whenever the sizes happen to divide evenly.
        features = x.flatten(start_dim=1)
        if features.size(1) != self.in_features:
            raise RuntimeError(
                f"{self.taskname}: conv output has {features.size(1)} features "
                f"per sample (shape {tuple(x.shape)}), expected {self.in_features}"
            )
        return features


class ActFreqNetwork(ActNetwork):
    """ActNetwork variant that passes its input through a FrequencyFilter first.

    Args:
        taskname: Forwarded to ``ActNetwork``.
        mask_spectrum: Forwarded to ``FrequencyFilter`` and stored on the
            instance.
        freq_type: Forwarded to ``FrequencyFilter``.
    """

    def __init__(self, taskname, mask_spectrum, freq_type):
        super().__init__(taskname)
        self.mask_spectrum = mask_spectrum
        self.disentanglement = FrequencyFilter(mask_spectrum=mask_spectrum,
                                               freq_type=freq_type)

    def forward(self, x):
        """Filter ``x`` with ``self.disentanglement``, then extract conv features.

        Args:
            x: Input in ActNetwork's layout. Dim 2 is squeezed away, so this
                assumes shape ``(batch, channels, 1, width)`` — TODO confirm.

        Returns:
            The result of ``ActNetwork.forward`` on the filtered tensor.
        """
        # (N, C, 1, W) -> (N, W, C).
        # NOTE(review): assumes FrequencyFilter expects channels-last input.
        seq = x.squeeze(dim=2).permute(0, 2, 1)
        # The filter returns a pair; only its second element is used here.
        _, x_det = self.disentanglement(seq)
        # Back to (N, C, 1, W) for the conv stack, assuming the filter keeps
        # the (N, W, C) shape.
        filtered = x_det.permute(0, 2, 1).unsqueeze(dim=2)
        # Reuse the parent's conv + flatten pipeline rather than duplicating it.
        return super().forward(filtered)