import os
import argparse
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import grad
from tensorboardX import SummaryWriter

from collections import OrderedDict

#from cifar10 import CIFAR10
from vgg import vgg16
from utils import *

# Script: save `--total_reps` independently initialized VGG16 state dicts into
# `--logdir`, named initialized_weight_<rep>.pth.tar, using a fixed `--seed` so
# the set of initializations is reproducible.
parser = argparse.ArgumentParser(
    description='Save several independently initialized VGG16 state dicts.')
parser.add_argument('--logdir', type=str, default='logs/VGG16_initialization',
                    help='directory that receives the saved weight files')
parser.add_argument('--total_reps', type=int, default=5,
                    help='number of initializations to save')
parser.add_argument('--seed', type=int, default=1,
                    help='seed for the numpy and torch global RNGs')
parser.add_argument('--num_classes', type=int, default=100,
                    help='number of output classes passed to vgg16')
args = parser.parse_args()

# Make sure the output directory exists before anything writes into it;
# exist_ok makes this a no-op when it is already present.
os.makedirs(args.logdir, exist_ok=True)
# LogSaver comes from the project's utils module and is kept for its side
# effects on args.logdir (presumably log setup there — confirm in utils).
logger = LogSaver(args.logdir)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

for idx_rep in range(args.total_reps):
    # A fresh model per repetition; assuming vgg16 initializes its parameters
    # from torch's global RNG, each rep differs but the sequence is
    # reproducible for a given --seed.
    model = vgg16(args.num_classes)
    save_path = os.path.join(
        args.logdir, 'initialized_weight_' + str(idx_rep) + '.pth.tar')
    torch.save(model.state_dict(), save_path)
