# Command-line arguments (all optional, positional): hl, units, N_data, test_num, filename

import torch
import learned_optimizer
import timeit

import os
import psutil
import sys

import MNIST_dataset

# ---- Defaults; the first five may be overridden from the command line ----
hl = 1  # hidden layers
units = 20  # units in hidden layers
N_data = 1000  # number of tasks
test_num = 0  # which test

horizon = 800  # max number of iterations
psize = 1000  # number of images per task

filename = 'results_test_'  # prefix of the results file name

# Optional positional overrides: argv[1]=hl, argv[2]=units, argv[3]=N_data,
# argv[4]=test_num, argv[5]=filename
if len(sys.argv) > 1:
    hl = int(sys.argv[1])

if len(sys.argv) > 2:
    units = int(sys.argv[2])

if len(sys.argv) > 3:
    N_data = int(sys.argv[3])

if len(sys.argv) > 4:
    test_num = int(sys.argv[4])

if len(sys.argv) > 5:
    filename = sys.argv[5]

print(f' hl={hl}\n units={units}\n N_data={N_data}\n test_num={test_num}\n filename={filename}')

# Run all computations in float64. torch.set_default_tensor_type() is
# deprecated since PyTorch 2.1; set_default_dtype() gives the same CPU double
# default (tensors are moved to the GPU explicitly via w()).
torch.set_default_dtype(torch.float64)

# Handle on this process, used to print resident memory (RSS) during the run.
process = psutil.Process(os.getpid())

filename_save = f'results/{filename}{test_num}.pt'


def cuda_sync():
    """Wait for pending CUDA work to finish; do nothing when no GPU is present.

    Called before reading the timer so measured durations include
    asynchronous GPU work.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.synchronize()


def w(x):
    """Return ``x.cuda()`` when CUDA is available, otherwise ``x`` itself.

    Applied in this script both to tensors and to the learned optimizer
    objects.
    """
    return x.cuda() if torch.cuda.is_available() else x


# Gradient-norm threshold below which the learned optimizers stop early.
GRAD_LIM = 1e-8
ulim = 1.0  # forwarded as `ulim` to the learned L-BFGS constructors

if test_num == 0:
    # label -> (state-dict path or constructor kwargs, optimizer kind)
    optim_files = {  # 'L-BFGS Mod': ('autosave_lbfgs_v1.pt','lbfgs_v1'),
        #  'L-BFGS Base': ('', 'lbfgs_v1'),
        #   'L-BFGS LS++': ({'c': 0.75, 'tau': 0.75}, 'lbfgs_v1_line_search'),
        #   'L-BFGS LS00': ({'c': 0.50, 'tau': 0.50}, 'lbfgs_v1_line_search'),
        #   'L-BFGS LS--': ({'c': 0.25, 'tau': 0.25}, 'lbfgs_v1_line_search'),
        #   'L-BFGS LS+0': ({'c': 0.75, 'tau': 0.50}, 'lbfgs_v1_line_search'),
        #   'L-BFGS LS0+': ({'c': 0.50, 'tau': 0.75}, 'lbfgs_v1_line_search'),
        #   'L-BFGS LS+-': ({'c': 0.75, 'tau': 0.25}, 'lbfgs_v1_line_search'),
        #   'L-BFGS LS-+': ({'c': 0.25, 'tau': 0.75}, 'lbfgs_v1_line_search'),
        'L-BFGS LS-0': ({'c': 0.25, 'tau': 0.50}, 'lbfgs_v1_line_search'),
        #   'L-BFGS LS0-': ({'c': 0.50, 'tau': 0.25}, 'lbfgs_v1_line_search'),
        'L-BFGS cvx': ('models/autosave_lbfgs_pi_mnist_epoch_1.pt', 'lbfgs_pi')
    }

    # (kind, constructor kwargs, label) for the built-in torch.optim baselines
    optimizers = [('adam', {'lr': 0.03}, 'ADAM lr=0.03'),
                  ('rmsprop', {'lr': 0.01}, 'RMSprop lr=0.01')
                  ]
else:
    print('unknown test')
    # Non-zero exit status so calling scripts notice the invalid test number
    # (the original exit(0) reported success).
    sys.exit(1)

#################################################################
def _load_optimizer_state(optim, path, label):
    """Load the state dict stored at ``path`` into ``optim``.

    A missing file only prints a warning, so the optimizer is still
    benchmarked with its freshly constructed state. ``map_location='cpu'``
    allows checkpoints written from a GPU run to be loaded on a CPU-only
    machine; the optimizer is moved with w() later, before it is used.
    """
    try:
        optim.load_state_dict(torch.load(path, map_location='cpu'))
    except FileNotFoundError:
        print(f'Loading {label} failed. File {path} not found')


# Instantiate every learned optimizer listed in optim_files and append it to
# `optimizers` as a (kind, optimizer object, label) triple.
for label, (source, kind) in optim_files.items():
    if kind == 'lbfgs_v1':
        new_optim = learned_optimizer.LearnedLBFGS_v1()
        _load_optimizer_state(new_optim, source, label)
    elif kind == 'lbfgs_v1_line_search':
        # For this kind `source` is a dict of constructor kwargs (c, tau),
        # not a checkpoint path.
        new_optim = learned_optimizer.LearnedLBFGS_v1(line_search=True, **source, ulim=ulim)
    elif kind == 'lbfgs_pi':
        new_optim = learned_optimizer.LearnedLBFGS_pi(test=True, ulim=ulim)
        _load_optimizer_state(new_optim, source, label)
    else:
        # Previously such entries were dropped silently.
        print(f'Skipping {label}: unknown optimizer kind {kind!r}')
        continue
    optimizers.append((kind, new_optim, label))

# Build the evaluation task loader: N_data tasks of psize images each,
# for a network with `hl` hidden layers of `units` units.
train_data = MNIST_dataset.taskLoader(
    n_data=psize * N_data,
    problem_size=psize,
    hl=hl,
    units=units,
    test=True,
)

def _run_torch_optimizer(task, opt, out):
    """Run a torch.optim optimizer on ``task`` for ``horizon`` steps.

    For step k, ``out[k, 0]`` gets ``task.f()`` before the step, ``out[k, 1]``
    the seconds elapsed after the step, and ``out[k, 2]`` the norm of
    ``task.df()``. Column 3 is left untouched.
    """
    cuda_sync()
    t0 = timeit.default_timer()
    for k in range(horizon):
        # Autograd is needed only to evaluate f and its gradient; the context
        # manager restores no-grad mode automatically.
        with torch.enable_grad():
            out[k, 0] = task.f().detach()
            g = w(task.df())
        opt.step()
        cuda_sync()
        out[k, 1] = timeit.default_timer() - t0
        out[k, 2] = g.norm()


def _run_learned_optimizer(task, optim, out, step_args=(), record_step_size=False):
    """Run a learned optimizer on ``task`` for at most ``horizon`` steps.

    Each step adds ``optim(g, *step_args)`` (flattened) to the task
    parameters. Columns 0-2 of ``out`` are filled as in _run_torch_optimizer.
    Column 3 receives ``optim.last_step_size`` when ``record_step_size`` is
    true. When the gradient norm drops below GRAD_LIM the run stops. The
    parameters no longer change, so the remaining rows get the final value,
    the elapsed time and the gradient norm. ``optim.flush()`` is called at
    the end.
    """
    cuda_sync()
    t0 = timeit.default_timer()
    for k in range(horizon):
        with torch.enable_grad():
            out[k, 0] = task.f().detach()
            g = w(task.df())
        g_norm = g.norm()
        if g_norm < GRAD_LIM:
            cuda_sync()
            out[k:, 0] = out[k, 0]
            # Row k's time has not been written yet (still 0), so use the
            # current elapsed time instead of copying out[k, 1].
            out[k:, 1] = timeit.default_timer() - t0
            out[k:, 2] = g_norm
            break
        task.setPars(task.getPars() + optim(g, *step_args).view(-1))
        cuda_sync()
        out[k, 1] = timeit.default_timer() - t0
        out[k, 2] = g_norm
        if record_step_size:
            out[k, 3] = optim.last_step_size
    optim.flush()


def _line_search_objective(task):
    """Return the objective callback handed to line-search optimizers.

    ``func_x(None)`` forwards to ``task.fx(None)``. Otherwise it evaluates
    ``task.fx`` at the current parameters shifted by ``x``.
    """
    def func_x(x):
        if x is None:
            return task.fx(x)
        return task.fx(w(task.getPars()) + w(x))
    return func_x


t00 = timeit.default_timer()
results = []
with torch.no_grad():
    for n, t in enumerate(train_data):
        print(f'Test data: {n}')
        results_inner = []
        for opt_kind, par, label in optimizers:
            print(f'opt: {label}')
            print(process.memory_info().rss)
            # Per-step record: [f, elapsed seconds, grad norm, step size]
            curr_results = w(torch.zeros(horizon, 4, 1))
            if opt_kind == 'adam':
                opt = torch.optim.Adam(t.soft_parameters(), **par)
                _run_torch_optimizer(t, opt, curr_results)
            elif opt_kind == 'rmsprop':
                opt = torch.optim.RMSprop(t.soft_parameters(), **par)
                _run_torch_optimizer(t, opt, curr_results)
            else:
                par = w(par)
                par.reset_dim(t.n)
                if opt_kind == 'lbfgs_v1_line_search':
                    _run_learned_optimizer(t, par, curr_results,
                                           step_args=(None, _line_search_objective(t)),
                                           record_step_size=True)
                elif opt_kind == 'lbfgs_pi':
                    _run_learned_optimizer(t, par, curr_results, record_step_size=True)
                else:
                    _run_learned_optimizer(t, par, curr_results)
            t.reset()

            results_inner.append(curr_results.view(1, horizon, 4, -1))
        # Shape per task: (n_optimizers, horizon, 4, 1)
        results.append(torch.cat(results_inner, 0))

# Keep only the state dictionaries of optimizer objects in the saved file.
# Entries whose second field has no state_dict() (the kwargs dicts of the
# torch.optim baselines) are stored unchanged.
for i, (kind, par, label) in enumerate(optimizers):
    try:
        optimizers[i] = (kind, par.state_dict(), label)
    except AttributeError:
        pass

# Make sure the output directory exists; otherwise torch.save would fail only
# after the whole benchmark has already run.
os.makedirs(os.path.dirname(filename_save) or '.', exist_ok=True)

# 'res' has shape (n_optimizers, horizon, 4, n_tasks).
torch.save({'res': torch.cat(results, 3).cpu(), 'optimizers': optimizers}, filename_save)

print(f'Took {timeit.default_timer() - t00} s to run in {curr_results.device}')
