import torch.nn as nn
#from torchvision import models
from trainer import trainer
from resnet_kuangliu import *
import argparse

# Experiment driver: trains a ResNet50 with several optimizers, repeating the
# whole sweep `runs` times, and hands each (optimizer, run) pair to `trainer`
# together with its own save directory.
parser = argparse.ArgumentParser(description="train regression neural network")
parser.add_argument("-b", "--batch", type=int, default=100, help="Batch size")
parser.add_argument("-e", "--epochs", type=int, default=5, help="Number of epochs")
parser.add_argument("-r", "--runs", type=int, default=5, help="Number of runs to average over")
#parser.add_argument("-s", "--schedule", type=int, default=5, help="Number of epochs after which the lr should be dropped")

args = parser.parse_args()
batch = args.batch
epochs = args.epochs
# The value passed as `schedule_lr_epochs` is fixed at half of the requested
# epochs instead of being a CLI option (see the disabled argument above).
schedule = epochs//2
runs = args.runs

# Maps a human-readable label (also used in the save path) to the optimizer
# constructor expression that is passed to `trainer` as a string.
# The dict does not depend on the run index, so it is built once.
# NOTE(review): the two AGNES entries are currently identical, and the labels
# say "eta=.01" while lr is 1e-3 -- confirm this is intended.
opt_names = {
    'AGNES, eta=.01':'AGNES(self.net.parameters(), lr={} , friction={} , correction={}, weight_decay={})'.format(1e-3, 0.99, 0.01, 1e-5),
    'AGNES, eta=.01 fixed':'AGNES(self.net.parameters(), lr={} , friction={} , correction={}, weight_decay={})'.format(1e-3, 0.99, 0.01, 1e-5),
    'ADAM': 'torch.optim.Adam(self.net.parameters(), lr=1e-3, weight_decay=1e-5)',
    'SGD, m=.99': 'torch.optim.SGD(self.net.parameters(), lr=1e-3, momentum=0.99, weight_decay=1e-5)',
    'NAG, m=.99': 'torch.optim.SGD(self.net.parameters(), lr=1e-3, momentum=0.99, weight_decay=1e-5, nesterov=True)',
    }

for run in range(runs):

    model = ResNet50()
    #model.fc = nn.Linear(model.fc.in_features, 10, bias=True)

    # state_dict() returns references to the live parameter/buffer tensors, so
    # keeping it directly would make the "initial" snapshot follow the model as
    # it trains and turn the reset below into a no-op. Clone every tensor so
    # each optimizer really starts from the same initial weights.
    initial_state_dict = {
        name: tensor.detach().clone()
        for name, tensor in model.state_dict().items()
    }

    for key, opt_name in opt_names.items():

        # Restore the untouched initial weights before each optimizer's run.
        model.load_state_dict(initial_state_dict)
        net = trainer(model = model, opt_name = opt_name)
        #net.load_parameters(f'exp_results/rn18_lr1e-3/{key}_100.pth')
        net.train(save_dir = f'exp_results/rn50_batch{batch}/{key}_r{run}', batch_size=batch, num_epochs=epochs, seed=run, schedule_lr_epochs=schedule, lr_factor=0.1)
