# from torchvision import datasets, transforms
# from torch.utils.data import Dataset, DataLoader
# from utils import *
import torch
import os
import argparse
import torch.nn as nn
import torch.nn.init as init

# Command-line options: which architecture to instantiate and snapshot.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--model',
    type=str,
    default='resnet',
    choices=['vgg', 'resnet'],
)
# Parsed at import time; the rest of the script reads args.model.
args = parser.parse_args()

def weight_init(m):
    '''
    Re-initialise the parameters of a single module in place.

    Intended to be passed to ``nn.Module.apply`` so that it visits every
    submodule::

        model = Model()
        model.apply(weight_init)

    Scheme per module type:
      * Conv1d          -- weight and bias ~ N(0, 1)
      * Conv2d          -- Xavier-normal weight, bias ~ N(0, 1)
      * BatchNorm1d/2d  -- weight ~ N(1, 0.02), bias = 0
      * Linear          -- Xavier-normal weight, bias ~ N(0, 1)

    Other module types are left untouched. Parameters that are ``None``
    (layers built with ``bias=False``, or ``affine=False`` batch-norm)
    are skipped instead of raising.
    '''
    if isinstance(m, nn.Conv1d):
        init.normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, nn.Conv2d):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        # affine=False batch-norm has weight and bias set to None.
        if m.weight is not None:
            init.normal_(m.weight.data, mean=1, std=0.02)
        if m.bias is not None:
            init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.Linear):
        init.xavier_normal_(m.weight.data)
        # nn.Linear(..., bias=False) sets bias to None.
        if m.bias is not None:
            init.normal_(m.bias.data)

# Directory that receives the freshly initialised (untrained) weights.
# Named init_dir rather than `dir` to avoid shadowing the builtin.
init_dir = './CIFAR100/init'
# exist_ok=True replaces the racy exists()-then-makedirs() check.
os.makedirs(init_dir, exist_ok=True)

# model
if args.model == 'resnet':
    from resnet import ResNet18
    model = ResNet18()
    ckpt_name = 'resnet.pth'
elif args.model == 'vgg':
    from vgg import vgg11
    model = vgg11()
    ckpt_name = 'vgg.pth'
else:
    # argparse `choices` should make this unreachable; fail loudly if not.
    raise ValueError('unknown model: {!r}'.format(args.model))

# To replace the architecture's default initialisation with the
# weight_init function defined in this file, uncomment:
# model.apply(weight_init)
torch.save(model.state_dict(), os.path.join(init_dir, ckpt_name))