
# Print a banner so the start of this script's output is easy to spot.
print("============Running test.py===============")
import os
import sys
#print(os.system('which python'))
import matplotlib
import matplotlib.pyplot as plt
import torch
import hypergrad as hg
import numpy as np
from sklearn.model_selection import train_test_split
import torch.nn.functional as F


# Report how many CUDA devices are visible and the name of device 0.
# Nothing is printed when CUDA is unavailable.
cuda_is_present = torch.cuda.is_available()
if cuda_is_present:
    print("Num GPUs found: ", torch.cuda.device_count())
    print("GPU name", torch.cuda.get_device_name(0))


# Helper functions to deal with cuda and double precision
#
# `cuda` selects GPU tensors and `double_precision` selects float64 instead
# of float32. Both flags are read by the conversion helpers below and are
# used here to pick the default tensor type for newly created tensors.
#
# Use the GPU only when one is actually present: a hard-coded True makes the
# CUDA tensor type (and every later `.cuda()` call) fail on CPU-only machines.
cuda = torch.cuda.is_available()
double_precision = False

# Resulting string is one of 'torch[.cuda].FloatTensor' / '.DoubleTensor'.
default_tensor_str = 'torch.cuda' if cuda else 'torch'
default_tensor_str += '.DoubleTensor' if double_precision else '.FloatTensor'
torch.set_default_tensor_type(default_tensor_str)

def frnp(x):
    """Convert a NumPy array into a torch tensor matching the global config.

    The tensor is moved to the GPU when the module-level ``cuda`` flag is
    set, and is cast to float64 when ``double_precision`` is set or to
    float32 otherwise.

    Args:
        x: NumPy array to convert.

    Returns:
        A torch tensor holding the values of ``x``.
    """
    t = torch.from_numpy(x)
    if cuda:
        t = t.cuda()
    # Cast explicitly in both modes. Previously the tensor was returned
    # untouched in double-precision mode, so a float32 or integer array kept
    # its dtype and no longer matched the DoubleTensor default.
    return t.double() if double_precision else t.float()

def tonp(x, cuda=cuda):
    """Return the values of tensor ``x`` as a NumPy array.

    The tensor is detached from the autograd graph first. When ``cuda`` is
    truthy it is also copied to host memory before conversion. The default
    for ``cuda`` is the value the module-level ``cuda`` flag had when this
    function was defined.
    """
    detached = x.detach()
    if cuda:
        detached = detached.cpu()
    return detached.numpy()

# Synthetic binary-classification data.
# Inputs x are standard normal with shape (n, d). Labels are the sign of a
# noisy linear score x @ w_oracle, encoded as 0.0 / 1.0.
n, d = 1000, 1
val_perc = 0.5  # fraction of the samples held out for validation
seed = 1
np.random.seed(seed)

w_oracle = np.random.randn(d)
x = np.random.randn(n, d)
y = x @ w_oracle + 0.1 * np.random.randn(n)
y = np.greater(y, 0.).astype(float)  # binary classification output

# Split into train/validation parts and convert every part to a torch tensor.
x_train, x_val, y_train, y_val = (
    frnp(part) for part in train_test_split(x, y, test_size=val_perc)
)

# problem definition
# inner_lr is the step size of the gradient step performed in fp_map.
# hparams[0] holds one L2 regularisation weight per model weight and is the
# quantity differentiated by the outer problem; it is re-initialised before
# the HPO loop further down.
inner_lr = 0.1
hparams = [torch.ones(d, requires_grad=True)]

# One entry is appended for every call to val_loss.
val_losses = []
def val_loss(params, hparams, x=x_val, y=y_val):
    """Outer objective: logistic loss of the weights on the validation data.

    Args:
        params: list whose first element is the weight vector.
        hparams: unused; kept in the signature (presumably required by the
            hypergrad outer-loss interface -- confirm).
        x: inputs, defaulting to the validation inputs.
        y: 0/1 targets, defaulting to the validation targets.

    Returns:
        The scalar loss tensor. Its NumPy value is also appended to the
        module-level ``val_losses`` list.
    """
    logits = x @ params[0]
    loss = F.binary_cross_entropy_with_logits(logits, y)
    val_losses.append(tonp(loss))
    return loss

# One entry is appended for every call to fp_map.
inner_losses = []
def fp_map(params, hparams):
    """Perform one gradient-descent step on the regularised training loss.

    The inner objective is the logistic loss on the training split plus a
    weighted L2 penalty ``0.5 * sum_i hparams[0][i] * w[i] ** 2``.

    Args:
        params: list whose first element is the current weight vector ``w``.
        hparams: list whose first element holds the per-weight
            regularisation coefficients.

    Returns:
        A one-element list with the updated weights
        ``w - inner_lr * d(loss)/dw``. The gradient is taken with
        ``create_graph=True`` so the step itself stays differentiable.
        The NumPy value of the loss is appended to ``inner_losses``.
    """
    w = params[0]
    data_term = F.binary_cross_entropy_with_logits(x_train @ w, y_train)
    # Elementwise form of the quadratic form w^T diag(hparams[0]) w / 2.
    # It yields the same value without materialising a d x d diagonal matrix.
    reg_term = 0.5 * (hparams[0] * w * w).sum()
    loss = data_term + reg_term
    inner_losses.append(tonp(loss))
    grad_w = torch.autograd.grad(loss, params, create_graph=True)[0]
    return [w - inner_lr * grad_w]


#****************************************************************************************

# # inner optimization
# T = 100
#
# inner_losses = []
# params_history = [[torch.zeros(d).requires_grad_(True)]]
# for t in range(T):
#     new_params = fp_map(params_history[-1], hparams)
#     params_history.append(new_params)
#
#     if t % 10 == 0 or t == T-1:
#         print("t={}: inner loss = {}".format(t, inner_losses[-1]))
#
#
# plt.title('inner loss')
# plt.xlabel('t')
# plt.plot(inner_losses)
# plt.savefig('inner_losses.png', bbox_inches='tight')
# plt.close()
# #plt.show()
#
# plt.title('params_history')
# plt.xlabel('t')
# plt.plot([tonp(p[0]) for p in params_history])
# plt.savefig('params_history.png', bbox_inches='tight')
# plt.close()
# #plt.show()
#
# # hypergradients comparison
# K = T
# final_params = params_history[-1]
# for hg_name, hg_alg in (('reverse', hg.reverse),
#                         ('fixed_point', hg.fixed_point),
#                         ('CG', hg.CG)):
#     h_grads = []
#     for k in range(1, K):
#         if hg_name == 'reverse':
#             h_grad = hg.reverse(params_history[-k-1:], hparams, [fp_map]*k, val_loss)
#         else:
#             h_grad = hg_alg(final_params, hparams, k, fp_map, val_loss)
#
#         h_grads.append(tonp(h_grad[0]))
#
#     plt.title(hg_name + ' approx hypergradient')
#     plt.xlabel('k')
#     plt.plot(h_grads)
#     plt.savefig(hg_name + '_hypergrads.png', bbox_inches='tight')
#     plt.close()
#     #plt.legend()
#     #plt.show()
#
# # check that reverse K=T and reverse_unroll have same output
# h_grad_reverse_unroll = hg.reverse_unroll(params_history[-1], hparams, val_loss)
# h_grad_reverse = hg.reverse(params_history[-T-1:], hparams, [fp_map]*T, val_loss)
# h_grad_diff = torch.sum(h_grad_reverse_unroll[0] - h_grad_reverse[0]).item()
# print('reverse_unroll - reverse (should be 0):', h_grad_diff)


#*******************************************************************************

# #HPO
# import time
# outer_steps = 1000
# eval_interval = 10
# T, K = 100, 10
#
# hparams = [torch.ones(d).requires_grad_(True)]
# outer_opt = torch.optim.SGD(lr=1., momentum=.9, params=hparams)
# total_time, val_losses = 0,  []
# hparams_history = [tonp(hparams[0].clone())]
# for o_step in range(outer_steps):
#
#     step_start_time = time.time()
#
#     inner_losses = []
#     params_history = [[torch.zeros(d).requires_grad_(True)]]
#     for t in range(T):
#         new_params = fp_map(params_history[-1], hparams)
#         params_history.append(new_params)
#
#     outer_opt.zero_grad()
#     #hg.reverse_unroll(params_history[-1], hparams, val_loss, set_grad=True)
#     #hg.reverse(params_history[-K-1:], hparams, [fp_map]*(K), val_loss, set_grad=True)
#     hg.CG(params_history[-1], hparams, K, fp_map, val_loss, set_grad=True)
#     #hg.fixed_point(params_history[-1], hparams, K, fp_map, val_loss, set_grad=True)
#     outer_opt.step()
#     hparams[0].data.clamp_(min=1e-8)
#
#     step_time = time.time()-step_start_time
#     total_time +=step_time
#     hparams_history.append(tonp(hparams[0].clone()))
#
#     if o_step % eval_interval == 0 or o_step == outer_steps:
#         print('o_step={}({:.2e}s) val loss={} '.format(o_step, step_time, val_losses[-1]))
#
# print('total time = {}'.format(total_time))
#
# plt.title('validation loss')
# plt.xlabel('outer steps')
# plt.plot(val_losses)
# plt.savefig('plots/val_loss.png', bbox_inches='tight')
# plt.close()
# #plt.show()
#
# plt.title('reg hparams')
# plt.xlabel('outer steps')
# plt.plot(hparams_history)
# plt.savefig('plots/hyperparams_history.png', bbox_inches='tight')
# plt.close()
# #plt.show()


#**************************************************************

def inner_solver(hparams, steps=100):
    """Approximately solve the inner problem for the given hyperparameters.

    Starts from an all-zero weight vector of length ``d`` and applies
    ``fp_map`` ``steps`` times.

    Args:
        hparams: hyperparameter list forwarded to ``fp_map``.
        steps: number of ``fp_map`` applications.

    Returns:
        The parameter list produced by the last ``fp_map`` call, or the
        zero initialisation when no step is performed.
    """
    current = [torch.zeros(d, requires_grad=True)]
    for _ in range(steps):
        current = fp_map(current, hparams)
    return current


#HPO
# Hyperparameter optimisation: every outer step solves the inner problem for
# T steps, asks hg.ZeroOrderHypergradient for a hypergradient and updates the
# regularisation weights with SGD + momentum.
import time
outer_steps = 400
eval_interval = 10   # print progress every eval_interval outer steps
T, K = 100, 10       # T: inner steps per outer step; K is not used below
mu = 0.05            # forwarded to hg.ZeroOrderHypergradient as `mu`

# Start from regularisation weights drawn uniformly from [0, 2).
hparams = [torch.ones(d).uniform_(0, 2).requires_grad_(True)]
outer_opt = torch.optim.SGD(lr=1., momentum=.9, params=hparams)
total_time, val_losses = 0,  []
hparams_history = [tonp(hparams[0].clone())]
for o_step in range(outer_steps):

    step_start_time = time.time()

    # Fresh list so the inner losses logged by fp_map belong to this step.
    inner_losses = []

    params = inner_solver(hparams, steps=T)

    outer_opt.zero_grad()

    # NOTE(review): presumably stores the estimated hypergradient in
    # hparams[i].grad (set_grad=True) and evaluates val_loss at least once,
    # which is what fills val_losses -- confirm against the hg implementation.
    hg.ZeroOrderHypergradient(params, hparams, val_loss, inner_solver, mu=mu, T=T, set_grad=True)

    outer_opt.step()
    # Keep the regularisation weights strictly positive after the update.
    hparams[0].data.clamp_(min=1e-8)

    step_time = time.time() - step_start_time
    total_time += step_time
    hparams_history.append(tonp(hparams[0].clone()))

    # The last iteration has index outer_steps - 1. Comparing against
    # outer_steps could never match, so the final step was never reported.
    if o_step % eval_interval == 0 or o_step == outer_steps - 1:
        print('o_step={}({:.2e}s) val loss={} '.format(o_step, step_time, val_losses[-1]))

print('total time = {}'.format(total_time))

# Save the diagnostic curves as PNG files under plots/.
# Create the output directory first: savefig does not create missing
# directories and would otherwise fail after the whole optimisation has run.
os.makedirs('plots', exist_ok=True)

# NOTE(review): val_losses gets one entry per val_loss call, so the x axis
# matches outer steps only if val_loss is evaluated once per step -- confirm.
plt.title('validation loss')
plt.xlabel('outer steps')
plt.plot(val_losses)
plt.savefig('plots/val_loss2.png', bbox_inches='tight')
plt.close()
#plt.show()

# hparams_history holds the initial value plus one snapshot per outer step.
plt.title('reg hparams')
plt.xlabel('outer steps')
plt.plot(hparams_history)
plt.savefig('plots/hyperparams_history2.png', bbox_inches='tight')
plt.close()
