import torch

import MNIST_dataset
from learned_optimizer import LearnedLBFGS_pi

import psutil
import os


# Device helper: place objects on the GPU whenever one is present.
def w(x):
    """Return ``x`` moved to CUDA when a GPU is available, else ``x`` itself.

    Accepts anything exposing ``.cuda()``. In this script it is applied to the
    learned-optimiser object, and it works the same way on tensors.
    """
    use_gpu = torch.cuda.is_available()
    return x.cuda() if use_gpu else x

# ---- Experiment configuration ---------------------------------------------
# The number of training problems (n_train) is taken from the loaded dataset.
restart = 0          # first training-step index (set > 0 to resume a run)
epoch = 50           # number of passes over the training set
horizon = 80         # maximum inner optimisation steps per training problem
trunc = 50           # meta-update (and graph cut) every `trunc` inner steps
STOP_CRIT = 1e-10    # end an inner run once ||grad|| falls below this
m = 5                # memory size handed to LearnedLBFGS_pi as `m_in`
test_name = 'mnist'  # tag embedded in checkpoint file names


# Default floating-point dtype is float64. This is the non-deprecated
# equivalent of torch.set_default_tensor_type(torch.DoubleTensor).
torch.set_default_dtype(torch.float64)


def path_to_autosave_partial(x):
    """Return the checkpoint path written at the end of epoch ``x``."""
    return f'models/autosave_lbfgs_pi_{test_name}_epoch_{str(x)}.pt'


# Rolling checkpoint, overwritten after every training step.
path_to_autosave = f'models/autosave_lbfgs_pi_{test_name}.pt'
# Create the checkpoint directory up front. Otherwise the first torch.save
# fails only after a full (expensive) inner optimisation run.
os.makedirs(os.path.dirname(path_to_autosave), exist_ok=True)
# Handle on this process, used to log resident memory during training.
process = psutil.Process(os.getpid())


# ---- Data, learned optimiser and meta-optimiser ---------------------------
train_data = MNIST_dataset.load_training(problem_size=1000, n_data=60000)


# Number of training problems; one epoch visits each of them once.
n_train = len(train_data)

my_optim = w(LearnedLBFGS_pi(n_in=2, m_in=m, ulim=1.0))

# Resume from the rolling autosave when it exists. map_location='cpu' lets a
# checkpoint written on a GPU machine be read on a CPU-only one. Assuming
# my_optim is a torch.nn.Module, load_state_dict then copies the values onto
# its own device.
try:
    my_optim.load_state_dict(torch.load(path_to_autosave, map_location='cpu'))
except FileNotFoundError:
    print('load not possible!')

# Meta-optimiser that trains the learned optimiser's parameters.
optimizer = torch.optim.Adadelta(my_optim.parameters(), lr=1.0)

# NOTE: anomaly detection is a debugging aid and slows autograd noticeably.
# Consider turning it off for long production runs.
torch.autograd.set_detect_anomaly(True)

def _meta_step(window_loss):
    """Back-propagate ``window_loss`` and apply one meta-optimiser update."""
    optimizer.zero_grad()
    window_loss.backward()
    optimizer.step()


# ---- Meta-training --------------------------------------------------------
# For each training problem, unroll up to `horizon` steps of the learned
# optimiser, accumulating the objective. Every `trunc` steps the accumulated
# window loss is back-propagated into the learned optimiser and the graph is
# cut (truncated back-propagation through the unrolled steps).
for idx in range(restart, epoch * n_train):
    t = train_data[idx % n_train]
    my_optim.reset_dim(t.n)
    loss = 0
    # True while `loss` holds terms that have not been back-propagated yet.
    pending = False
    for k in range(horizon):
        aux = t.f()
        loss = loss + aux
        pending = True

        g = t.df().detach()
        # Stop on convergence (tiny gradient) or divergence (huge window loss).
        if g.norm() < STOP_CRIT or loss > 1e20:
            break
        if k == 0:
            print(f'idx={idx}/{epoch* n_train}\nIni: {loss:10.4}')
        else:
            print(f'{k:3}: F:{aux:10.4}\t g:{g.norm():10.4}')
        d = my_optim.forward(g)

        t.setPars(t.getPars() + d.view(-1))
        if (k+1) % trunc == 0:
            # End of a truncation window: update, then detach both the
            # problem state and the learned optimiser from the graph.
            _meta_step(loss)
            t.detachState()
            my_optim.detach()
            loss = 0
            pending = False
    print(f'Fin: {t.f()}')
    # Flush the last window. An early `break` can land on a window boundary
    # before that window's update ran, so decide on `pending` rather than on
    # the step index. This also avoids referencing `k` when horizon == 0.
    if pending:
        _meta_step(loss)
    t.reset()
    my_optim.flush()
    # Per-epoch checkpoint.
    if (idx+1) % n_train == 0:
        torch.save(my_optim.state_dict(), path_to_autosave_partial((idx+1)//n_train))


    # Resident memory in bytes, presumably logged to spot graph/memory leaks.
    print(process.memory_info().rss)
    # Rolling checkpoint after every training step.
    torch.save(my_optim.state_dict(), path_to_autosave)
