import torch
from regress_utils import *
import argparse

# Command-line interface for the training script.
parser = argparse.ArgumentParser(description="train regression neural network")

# One row per option: (short flag, long flag, type, default, help text).
# Rows are registered in order, so `--help` lists them exactly as written here.
_CLI_OPTIONS = (
    ("-f", "--filename", str, "d15w15o10", "Name of the file"),
    ("-b", "--batch", int, 100, "Batch size"),
    ("-e", "--epochs", int, 5, "Number of epochs"),
    ("-w", "--width", int, 18, "Width"),
    ("-d", "--depth", int, 18, "Depth"),
    ("-o", "--dim_out", int, 10, "Output dimension"),
    ("-r", "--runs", int, 5, "Number of runs to average over"),
)
for _short, _long, _type, _default, _help in _CLI_OPTIONS:
    parser.add_argument(_short, _long, type=_type, default=_default, help=_help)

args = parser.parse_args()

# Plain module-level aliases used by the rest of the script:
#   filename -> tag in the 'teacher_<filename>/data_<filename>' data path
#   batch    -> training batch size (also part of the output directory name)
#   epochs   -> number of epochs handed to the trainer
#   width, depth, dim_out -> shape arguments handed to `ffn`
#   runs     -> how many times the whole optimizer sweep is repeated
filename, batch, epochs = args.filename, args.batch, args.epochs
width, depth, dim_out = args.width, args.depth, args.dim_out
runs = args.runs

# Load the pre-generated (train, test) pair that belongs to the chosen teacher.
# NOTE(review): torch.load unpickles the file, so only point it at trusted
# data. If the file holds pickled Dataset objects, newer PyTorch releases may
# need the `weights_only` default overridden -- TODO confirm against the
# PyTorch version in use.
_data_path = f'teacher_{filename}/data_{filename}'
train_data, test_data = torch.load(_data_path)

# Training batches have size `batch` and are shuffled; the test set is served
# in its stored order as one batch spanning the whole set.
train_loader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=batch,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_data,
    batch_size=len(test_data),
    shuffle=False,
)

# Network input size is the length of the first element of the first test
# sample (presumably samples are (input, target) pairs -- verify against the
# data generator).
_first_input = test_data[0][0]
dim_in = len(_first_input)


# Optimizer sweep: label -> constructor expression given to `trainer` as a
# string. The expressions mention `self.net`, so they are presumably evaluated
# inside the trainer object -- confirm in regress_utils. The labels also end up
# in the save directory names, so they must not be edited casually.
# The table does not depend on the run index, so it is built once here rather
# than on every run.
opt_names = {
    'AGNES, lr=1e-4, eta=1e-3, m=.99': 'AGNES(self.net.parameters(), lr=1e-4 , friction=.99, correction=1e-3)',
    'AGNES, lr=1e-4, eta=1e-2, m=.99': 'AGNES(self.net.parameters(), lr=1e-4 , friction=.99, correction=1e-2)',
    'NAG, lr=1e-3, m=.99': 'torch.optim.SGD(self.net.parameters(), lr=1e-3, momentum=0.99, nesterov=True)',
    'SGD, lr=1e-3, m=.99': 'torch.optim.SGD(self.net.parameters(), lr=1e-3, momentum=0.99)',
    'SGD, lr=1e-2, m=.9': 'torch.optim.SGD(self.net.parameters(), lr=1e-2, momentum=0.9)',
    'ADAM 1e-3': 'torch.optim.Adam(self.net.parameters(), lr=1e-3,)',
    'ADAM 1e-4': 'torch.optim.Adam(self.net.parameters(), lr=1e-4,)',
    'ADAM 1e-2': 'torch.optim.Adam(self.net.parameters(), lr=1e-2,)',
}

# All results of this invocation share one output directory; it only depends
# on the command-line settings, so it is computed once as well.
directory = f'teacher_{filename}/learner_d{depth}w{width}b{batch}'

for run in range(runs):
    for key, opt_name in opt_names.items():
        # A new network is built for every (run, optimizer) pair.
        model = ffn(dim_in, depth, width, dim_out, lrelu_slope=0.1)
        # (disabled) model.fc = nn.Linear(model.fc.in_features, 10, bias=True)
        net = trainer(
            model=model,
            opt_name=opt_name,
            train_loader=train_loader,
            test_loader=test_loader,
        )
        # Results go to '<directory>/<run><label>'.
        # NOTE: `os` is not imported directly by this script; it is expected
        # to arrive through `from regress_utils import *`.
        net.train(
            save_dir=os.path.join(directory, f'{run}{key}'),
            num_epochs=epochs,
            seed=False,
        )

#plot_results(batch,epochs,directory)
