# -*- coding: utf-8 -*-

import torch.nn as nn
import math

"""
Alexnet
"""

class AlexNet(nn.Module):
    """AlexNet-style CNN: two conv/pool/LRN stages plus a 3-layer classifier.

    ``args`` must provide the following attributes:

    * ``input_dim``   -- number of input channels of the first convolution.
    * ``num_classes`` -- width of the final linear layer.
    * ``activation``  -- ``'tanh'``, ``'sigmoid'`` or ``'leaky_relu'``; any
      other value selects ReLU.
    * ``dropout``     -- only read when ReLU is selected. ``'dropout'`` puts
      a ``Dropout`` after both hidden layers, ``'dropout_1'`` only after the
      second one, and any other value disables dropout.

    ``init_scale`` multiplies the standard deviation used for the random
    initialisation of every ``Conv2d`` and ``Linear`` weight.

    The classifier consumes exactly 256 values per sample, so the
    convolutional stack has to shrink every image to a 1x1 feature map
    (this holds for e.g. 28x28 and 32x32 inputs).
    """

    def __init__(self, args, init_scale=1.0):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(args.input_dim, 192, kernel_size=5, stride=1),
            nn.MaxPool2d(kernel_size=3),
            nn.LocalResponseNorm(1),

            nn.Conv2d(192, 256, kernel_size=5, stride=1),
            nn.MaxPool2d(kernel_size=3),
            nn.LocalResponseNorm(1),
        )
        self.classifier = self._build_classifier(args)
        self._init_weights(init_scale)

    @staticmethod
    def _build_classifier(args):
        """Return the 256 -> 384 -> 192 -> num_classes classifier head.

        Layers are created in the same order (and therefore end up at the
        same ``Sequential`` indices) for every configuration, which keeps
        ``state_dict`` keys stable.
        """
        drop_after_first = False
        drop_after_second = False

        activation = args.activation
        if activation == 'tanh':
            make_activation = nn.Tanh
        elif activation == 'sigmoid':
            make_activation = nn.Sigmoid
        elif activation == 'leaky_relu':
            def make_activation():
                return nn.LeakyReLU(inplace=True)
        else:
            def make_activation():
                return nn.ReLU(inplace=True)

            # Dropout is only offered together with ReLU, so ``args.dropout``
            # is deliberately not touched for the other activations.
            dropout_mode = args.dropout
            drop_after_first = dropout_mode == 'dropout'
            drop_after_second = (dropout_mode == 'dropout'
                                 or dropout_mode == 'dropout_1')

        layers = [nn.Linear(256, 384), make_activation()]
        if drop_after_first:
            layers.append(nn.Dropout())
        layers.extend([nn.Linear(384, 192), make_activation()])
        if drop_after_second:
            layers.append(nn.Dropout())
        layers.append(nn.Linear(192, args.num_classes))
        return nn.Sequential(*layers)

    def _init_weights(self, init_scale):
        """Re-initialise all Conv2d and Linear parameters in place.

        Conv2d weights use a He-style normal with fan-out
        (``std = sqrt(2 / (kh * kw * out_channels))``); their biases keep the
        PyTorch default. Linear weights use a Glorot-style normal
        (``std = sqrt(2 / (fan_in + fan_out))``) and zero biases. Both
        standard deviations are multiplied by ``init_scale``.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                kernel_h, kernel_w = module.kernel_size
                fan_out = kernel_h * kernel_w * module.out_channels
                std = math.sqrt(2.0 / fan_out)
                nn.init.normal_(module.weight, 0.0, init_scale * std)
            elif isinstance(module, nn.Linear):
                nn.init.constant_(module.bias, 0.0)
                fan_out, fan_in = module.weight.size()
                std = math.sqrt(2.0 / (fan_in + fan_out))
                nn.init.normal_(module.weight, 0.0, init_scale * std)

    def forward(self, x):
        """Run the network on ``x`` and return the class scores.

        Raises:
            ValueError: if a batched (4D) feature map does not hold exactly
                256 values per sample. Reshaping it anyway would silently
                fold spatial positions into the batch dimension.
        """
        x = self.features(x)
        if x.dim() == 4 and x.size(1) * x.size(2) * x.size(3) != 256:
            raise ValueError(
                "AlexNet needs the feature extractor to reduce each sample "
                "to 256 values (a 256x1x1 map), but got a per-sample "
                "feature map of shape {}".format(tuple(x.shape[1:]))
            )
        x = x.view(-1, 256)
        return self.classifier(x)
