import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo

import torchvision

from config import args

# Download locations of ResNet checkpoints, keyed by architecture name.
# Each URL is assembled from the architecture name and the hash suffix of
# its checkpoint file.
model_urls = {
    arch: 'https://download.pytorch.org/models/{}-{}.pth'.format(arch, digest)
    for arch, digest in (
        ('resnet18', '5c106cde'),
        ('resnet34', '333f7ec4'),
        ('resnet50', '19c8e357'),
        ('resnet101', '5d3b4d8f'),
        ('resnet152', 'b121ed2d'),
    )
}

def weights_init(w):
    """
    Re-initializes the parameters of the layer, w, in place.

    The layer kind is detected from its class name:
      * a name containing 'Conv': weights are drawn from a normal
        distribution with mean 0.0 and std 0.02;
      * a name containing 'BatchNorm': weights are drawn from a normal
        distribution with mean 1.0 and std 0.02, and the bias is set to 0.
    Any other layer is left untouched.
    """
    layer_kind = type(w).__name__

    if 'Conv' in layer_kind:
        nn.init.normal_(w.weight.data, 0.0, 0.02)
        return

    if 'BatchNorm' in layer_kind:
        nn.init.normal_(w.weight.data, 1.0, 0.02)
        nn.init.constant_(w.bias.data, 0)

class MLP(nn.Module):
    """
    Small fully connected network: Linear(2, 4) -> Linear(4, 2) -> Linear(2, 2).

    The selected activation is applied after the first two linear layers;
    the output of the last layer is returned without an activation.

    Args:
        activation: 'relu' selects nn.ReLU; any other value selects nn.ELU.
    """

    def __init__(self, activation='relu'):
        super(MLP, self).__init__()

        self.linear1 = nn.Linear(2, 4)  # input dimension: 2
        self.linear2 = nn.Linear(4, 2)
        self.linear3 = nn.Linear(2, 2)
        if activation == 'relu':
            self.active = nn.ReLU()
        else:
            self.active = nn.ELU()

    def forward(self, input):
        """Map a tensor whose last dimension is 2 to an output of last dimension 2."""
        x = self.active(self.linear1(input))
        x = self.active(self.linear2(x))
        x = self.linear3(x)
        return x

    def init_weights_glorot(self):
        """Apply Xavier/Glorot uniform initialization to every linear layer's weight."""
        # self._modules is a name -> module dict, so iterating it directly
        # yields name strings and never matches a layer type; self.modules()
        # yields the module objects themselves.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)

class CNN(nn.Module):
    """
    Three convolutional layers followed by a linear classifier with 10 outputs.

    Args:
        dataset: 'mnist' (1 input channel) or 'cifar10' (3 input channels).
        d: number of filters in the first convolution; the second and third
            convolutions use 2*d and 4*d filters.
        activation: 'relu' selects nn.ReLU; any other value selects nn.ELU
            (same rule as MLP in this file).

    Raises:
        NotImplementedError: if ``dataset`` is not one of the supported names.
    """

    def __init__(self, dataset, d=32, activation='elu'):
        super(CNN, self).__init__()

        # Spatial size of the last feature map, assuming 28x28 inputs for
        # 'mnist' and 32x32 inputs for 'cifar10':
        #   mnist:   28 -> 21 -> 16 -> pool 8  -> 4 -> pool 2
        #   cifar10: 32 -> 25 -> 20 -> pool 10 -> 6 -> pool 3
        if dataset == 'mnist':
            layer1_in_channels = 1
            final_map_size = 2
        elif dataset == 'cifar10':
            layer1_in_channels = 3
            final_map_size = 3
        else:
            raise NotImplementedError(
                "Unsupported dataset: {!r} (expected 'mnist' or 'cifar10')".format(dataset)
            )

        act = nn.ReLU if activation == 'relu' else nn.ELU

        self.conv = nn.Sequential(
            # Layer 1
            nn.Conv2d(in_channels=layer1_in_channels, out_channels=d, kernel_size=(8, 8)),
            nn.BatchNorm2d(d),
            act(),

            # Layer 2
            nn.Conv2d(in_channels=d, out_channels=2 * d, kernel_size=(6, 6)),
            nn.BatchNorm2d(2 * d),
            act(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=2),

            # Layer 3
            nn.Conv2d(in_channels=2 * d, out_channels=4 * d, kernel_size=(5, 5)),
            nn.BatchNorm2d(4 * d),
            act(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=2),
        )

        # Logistic regression on the flattened feature maps. The input size
        # is derived from d so that values other than 32 also work
        # (d=32 gives 512 for mnist and 1152 for cifar10).
        self.clf = nn.Linear(4 * d * final_map_size * final_map_size, 10)

    def init_weights(self, mean, std):
        """
        Re-initialize the network in place.

        Linear weights get Xavier/Glorot uniform initialization; convolution
        weights are drawn from a normal distribution with the given ``mean``
        and ``std`` and their biases are zeroed.
        """
        # self.modules() walks the module tree recursively, so the
        # convolutions nested inside self.conv are reached as well.
        # Iterating self._modules would only yield child-name strings.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
            elif isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
                m.weight.data.normal_(mean, std)
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, input):
        """Run the convolutional stack, flatten per sample, and classify."""
        x = self.conv(input)
        return self.clf(x.view(len(x), -1))
    

class Reshape(torch.nn.Module):
    """Reshape a batch into the image layout associated with a dataset name."""

    def forward(self, x, dataset):
        """
        Return ``x`` viewed as (-1, 1, 28, 28) for 'mnist' or
        (-1, 3, 32, 32) for 'cifar10'.

        Raises:
            NotImplementedError: for any other dataset name.
        """
        if dataset == 'mnist':
            return x.view(-1, 1, 28, 28)
        elif dataset == 'cifar10':
            return x.view(-1, 3, 32, 32)
        else:
            # NotImplementedError is the exception class; the previously used
            # name NotImplementedType is not defined in this module.
            raise NotImplementedError(
                "Unsupported dataset: {!r} (expected 'mnist' or 'cifar10')".format(dataset)
            )

class LeNet5(nn.Module):
    """
    LeNet-5 style network: two 5x5 convolutions, each followed by ReLU and
    2x2 max pooling, then three fully connected layers.

    Args:
        dataset: 'mnist' (1 input channel) or 'cifar10' (3 input channels);
            also selects the layout the input is reshaped to in forward().
        num_classes: size of the output layer.
        init_weights: accepted for interface compatibility; currently unused
            by the constructor.

    Raises:
        NotImplementedError: if ``dataset`` is not one of the supported names.
    """

    def __init__(self, dataset, num_classes=10, init_weights=True):
        super(LeNet5, self).__init__()

        # Validate up front so an unknown dataset fails here instead of as a
        # missing-attribute error on the first forward pass.
        if dataset == 'mnist':
            in_channels = 1
            fc1_in_features = 16 * 4 * 4
        elif dataset == 'cifar10':
            in_channels = 3
            fc1_in_features = 16 * 5 * 5
        else:
            raise NotImplementedError(
                "Unsupported dataset: {!r} (expected 'mnist' or 'cifar10')".format(dataset)
            )

        self.dataset = dataset
        self.reshape = Reshape()

        self.conv1 = nn.Conv2d(in_channels, 6, kernel_size=5)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.relu = nn.ReLU()

        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)

        self.fc1 = nn.Linear(fc1_in_features, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Reshape ``x`` to the dataset's image layout and return class scores."""
        # Spatial sizes below are given as mnist (28x28) / cifar10 (32x32).
        x = self.reshape(x, self.dataset)
        x = self.conv1(x)     # 6 channels, 24x24 / 28x28
        x = self.relu(x)
        x = self.maxpool1(x)  # 6 channels, 12x12 / 14x14
        x = self.conv2(x)     # 16 channels, 8x8 / 10x10
        x = self.relu(x)
        x = self.maxpool2(x)  # 16 channels, 4x4 / 5x5
        x = torch.flatten(x, start_dim=1)  # (N, 16*4*4) / (N, 16*5*5), N = batch size
        x = self.fc1(x)       # (N, 120)
        x = self.relu(x)
        x = self.fc2(x)       # (N, 84)
        x = self.relu(x)
        x = self.fc3(x)       # (N, num_classes); no activation on the output

        return x

    def init_weights(self, mean, std):
        """
        Re-initialize the network in place.

        Linear weights get Xavier/Glorot uniform initialization; convolution
        weights are drawn from a normal distribution with the given ``mean``
        and ``std`` and their biases are zeroed.
        """
        # Iterate the module objects; self._modules would only yield their
        # attribute names (strings), which never match a layer type.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
            elif isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
                m.weight.data.normal_(mean, std)
                if m.bias is not None:
                    m.bias.data.zero_()


def get_model(dataset):
    """
    Build the model selected by ``args.model``.

    Supported values of ``args.model``:
      * 'MLP'      - MLP with its default activation (``dataset`` is not used).
      * 'CNN'      - CNN(dataset), then init_weights(mean=0.0, std=0.02).
      * 'LeNet'    - LeNet5(dataset), then init_weights(mean=0.0, std=0.02).
      * 'ResNet18' - torchvision ResNet-18 loaded with the checkpoint from
                     model_urls['resnet18'] (``dataset`` is not used).

    Args:
        dataset: dataset name forwarded to the CNN / LeNet5 constructors.

    Returns:
        The constructed nn.Module.

    Raises:
        NotImplementedError: if ``args.model`` is not one of the names above.
    """
    if args.model == 'MLP':
        model = MLP()
    elif args.model == 'CNN':
        model = CNN(dataset)
        model.init_weights(mean=0.0, std=0.02)
    elif args.model == 'LeNet':
        model = LeNet5(dataset)
        model.init_weights(mean=0.0, std=0.02)
    elif args.model == 'ResNet18':
        # Build the architecture without weights, then load the checkpoint
        # explicitly so it is cached under a local directory.
        model = torchvision.models.resnet18()
        model_path = "./model_path"
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18'], model_dir=model_path))
    else:
        # NotImplementedError is the exception class; the previously used
        # name NotImplementedType is not defined in this module.
        raise NotImplementedError("Unsupported model: {!r}".format(args.model))

    return model
