import torch.nn as nn
from models.utils.activations import *

def get_activation_function(activation_type, tau=0, inplace_flag=True):
    '''
    Yield an endless stream of freshly built activation modules of one type.

    This is a generator function: nothing runs (not even the banner print)
    until the first ``next()`` call. Every ``next()`` constructs a new
    module instance, so each consumer gets its own activation object.

    Args:
        activation_type: One of 'relu', 'sigmoid', 'tanh', 'softsign',
            'elliot', 'lecun' or 'shifted_tanh'.
        tau: Passed to ``ShiftedTanh``; only used for 'shifted_tanh'.
        inplace_flag: Passed as ``inplace`` to ``nn.ReLU``; only used for
            'relu'.

    Yields:
        A newly constructed activation module.

    Raises:
        ValueError: On the first ``next()`` if ``activation_type`` is not one
            of the supported names. (Previously this printed a message and
            called ``exit()``, which ended the process with status 0.)
    '''
    print('#' * 60)
    print('Model Activation Function : ' + activation_type)
    print('#' * 60)

    # Factories are lambdas so that a class name is only looked up when its
    # activation is actually selected. Elliot, LeCun_tanh and ShiftedTanh are
    # not defined in this file; presumably they come from the star import of
    # models.utils.activations -- verify if that import changes.
    factories = {
        'relu': lambda: nn.ReLU(inplace=inplace_flag),
        'sigmoid': lambda: nn.Sigmoid(),
        'tanh': lambda: nn.Tanh(),
        'softsign': lambda: nn.Softsign(),
        'elliot': lambda: Elliot(),
        'lecun': lambda: LeCun_tanh(),
        'shifted_tanh': lambda: ShiftedTanh(tau),
    }

    build = factories.get(activation_type)
    if build is None:
        raise ValueError(
            'Invalid activation function type {!r}; expected one of {}'.format(
                activation_type, sorted(factories)))

    while True:
        yield build()


def make_conv_block(channels_in=None, channels_out=None, activation_generator=None, oper_order='cba',
                    kernel_size=3, stride=1, padding=1, bn_momentum=0.1, dropout_p=0.5):
    '''
    Build an nn.Sequential from the operations spelled out in "oper_order".

    Each character of "oper_order" appends one layer, in order:

    'c' denotes 2d Convolution operation,
    'd' denotes depthwise convolution operation,
    'p' denotes pointwise convolution operation,
    'b' denotes 2d BatchNorm operation,
    'a' denotes activation operation,
    'D' denotes dropout operation,

    Args:
        channels_in: Channel count of the block input.
        channels_out: Output channel count of every 'c', 'd' and 'p' layer.
        activation_generator: Iterator yielding one activation module per
            ``next()`` call; required only when "oper_order" contains 'a'.
        oper_order: String of operation codes listed above.
        kernel_size, stride, padding: Passed to 'c' and 'd' convolutions.
            The pointwise 'p' convolution always uses a kernel size of 1 and
            the Conv2d defaults for stride and padding.
        bn_momentum: Momentum of the BatchNorm2d layers.
        dropout_p: Drop probability of the Dropout2d layers.

    Returns:
        nn.Sequential holding the layers in "oper_order" order.

    Raises:
        ValueError: If "oper_order" contains an unknown operation code.
            (Previously this printed a message and called ``exit()``.)
    '''

    layers = []
    # Channel count seen by the next layer; it only changes after a
    # convolution, so 'b' sizes itself from whatever precedes it.
    channel = channels_in

    for operation_type in oper_order:
        if operation_type == 'c':
            layer = nn.Conv2d(channel, channels_out, kernel_size=kernel_size, stride=stride,
                              padding=padding, bias=False)
            channel = channels_out
        elif operation_type == 'd':
            # One group per input channel; Conv2d requires channels_out to be
            # divisible by the group count.
            layer = nn.Conv2d(channel, channels_out, kernel_size=kernel_size, stride=stride,
                              groups=channel, padding=padding, bias=False)
            channel = channels_out
        elif operation_type == 'p':
            layer = nn.Conv2d(channel, channels_out, 1, bias=False)
            channel = channels_out
        elif operation_type == 'b':
            layer = nn.BatchNorm2d(channel, momentum=bn_momentum)
        elif operation_type == 'D':
            layer = nn.Dropout2d(p=dropout_p)
        elif operation_type == 'a':
            layer = next(activation_generator)
        else:
            raise ValueError(
                'Invalid conv block operation type {!r} in oper_order {!r}'.format(
                    operation_type, oper_order))

        layers.append(layer)

    return nn.Sequential(*layers)


def make_fc_block(activations_in=None, activations_out=None, 
                activation_generator=None, oper_order='fba', bn_momentum=0.1, dropout_p=0.5):
    '''
    Build an nn.Sequential from the operations spelled out in "oper_order".

    Each character of "oper_order" appends one layer, in order:

    'f' denotes fully-connected operation ('c' is accepted as an alias),
    'b' denotes 1d BatchNorm operation,
    'a' denotes activation operation,
    'D' denotes dropout operation,

    Args:
        activations_in: Feature count of the block input.
        activations_out: Output feature count of every 'f'/'c' layer.
        activation_generator: Iterator yielding one activation module per
            ``next()`` call; required only when "oper_order" contains 'a'.
        oper_order: String of operation codes listed above.
        bn_momentum: Momentum of the BatchNorm1d layers.
        dropout_p: Drop probability of the Dropout layers.

    Returns:
        nn.Sequential holding the layers in "oper_order" order.

    Raises:
        ValueError: If "oper_order" contains an unknown operation code.
            (Previously this printed a message and called ``exit()``.)
    '''
    layers = []
    # Feature count seen by the next layer; it only changes after a linear
    # layer, so 'b' sizes itself from whatever precedes it.
    activations = activations_in

    for operation_type in oper_order:
        if operation_type in ('f', 'c'):
            # Both codes build the same biased linear layer.
            layer = nn.Linear(activations, activations_out, bias=True)
            activations = activations_out
        elif operation_type == 'b':
            layer = nn.BatchNorm1d(activations, momentum=bn_momentum)
        elif operation_type == 'D':
            layer = nn.Dropout(p=dropout_p)
        elif operation_type == 'a':
            layer = next(activation_generator)
        else:
            raise ValueError(
                'Invalid fc block operation type {!r} in oper_order {!r}'.format(
                    operation_type, oper_order))

        layers.append(layer)

    return nn.Sequential(*layers)
