import matplotlib
import matplotlib.pyplot as plt
import torch
import hypergrad as hg
import numpy as np
import torch.nn.functional as F
from torchvision import datasets
#from sklearn.datasets import make_spd_matrix as spd
import time
from utils.options import args_parser
from utils.dataset_normal import load_data
# from models.ModelBuilder import build_model
from utils.my_logging import Logger

# Seed both RNG sources used by the script so repeated runs draw the same numbers.
np.random.seed(0)
torch.manual_seed(0)
# Disable the cuDNN autotuner and request deterministic cuDNN kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


# Problem sizes: p = input features, q = hidden width, d = feature-map outputs,
# n = number of samples drawn for each of the two datasets below.
p, q, d, n = 784, 200, 10, 50000  # CG was not working with n=10000
# Scale of the Gaussian noise added to the targets.
m = 0.1

# NOTE: the order of the torch.randn calls below determines the synthetic data
# under the fixed seed; reordering them changes every experiment.
Xg = torch.randn((n, p))  # inputs read by inner_func / inner_solver, shape (n, p) = (50000, 784)
H1_true = torch.randn((p, q))  # ground-truth layer 1, shape (p, q) = (784, 200)
H2_true = torch.randn((q, d))  # ground-truth layer 2, shape (q, d) = (200, 10)
w_true = torch.randn(d)  # ground-truth final regression layer, shape (d,) = (10,)

# Targets: ground-truth two-layer sigmoid network followed by the linear
# readout, plus noise of scale m.
yg = F.sigmoid(F.sigmoid(Xg @ H1_true) @ H2_true) @ w_true + m * torch.randn(n)
# Second, independently drawn dataset read by outer_func; same ground-truth model.
Xf = torch.randn((n, p))
yf = F.sigmoid(F.sigmoid(Xf @ H1_true) @ H2_true) @ w_true + m * torch.randn(n)

# Coefficient of the 0.5 * b * ||w||^2 penalty in the inner objective.
b = 0.001

# T(.)
def hypernet(hparams, X):
    """Map inputs through the two-layer sigmoid feature network T(.).

    Args:
        hparams: Sequence whose first two entries are the weight matrices
            ``H1`` and ``H2``; any further entries are ignored.
        X: Input batch; must be matrix-multipliable with ``H1``.

    Returns:
        ``sigmoid(sigmoid(X @ H1) @ H2)``.
    """
    H1, H2 = hparams[0], hparams[1]
    # torch.sigmoid is the supported spelling; F.sigmoid is the deprecated
    # alias and warns on older PyTorch releases. The values are identical.
    hidden = torch.sigmoid(X @ H1)
    return torch.sigmoid(hidden @ H2)


# Squared-error loss of the linear readout.
def regressor(params, Z, y):
    """Return the squared Euclidean norm of the residual ``Z @ w - y``.

    Args:
        params: Sequence whose first entry is the weight vector ``w``.
        Z: Feature matrix that is multiplied by ``w``.
        y: Targets subtracted from the predictions.

    Returns:
        The un-normalised squared residual norm (no 1/2 or 1/n factor;
        callers apply their own scaling).
    """
    residual = Z @ params[0] - y
    return torch.norm(residual) ** 2


def inner_func(params, hparams):
    """Inner objective g(w, H) evaluated on the dataset (Xg, yg).

    Computes ``0.5 * ||T(Xg) w - yg||^2 / n + 0.5 * b * ||w||^2`` where
    ``T`` is ``hypernet`` parameterised by ``hparams`` and ``w`` is
    ``params[0]``.

    Args:
        params: Sequence whose first entry is the weight vector ``w``.
        hparams: Hyperparameters forwarded to ``hypernet``.

    Returns:
        The scalar inner loss.
    """
    weights = params[0]
    features = hypernet(hparams, Xg)
    data_term = 0.5 * regressor(params, features, yg) / n
    ridge_term = 0.5 * b * (torch.norm(weights)) ** 2
    return data_term + ridge_term


def outer_func(params, hparams):
    """Outer objective f(w, H) evaluated on the dataset (Xf, yf).

    Computes ``0.5 * ||T(Xf) w - yf||^2 / n`` where ``T`` is ``hypernet``
    parameterised by ``hparams``. Unlike ``inner_func`` there is no penalty
    on ``w``.

    Args:
        params: Sequence whose first entry is the weight vector ``w``
            (read by ``regressor``).
        hparams: Hyperparameters forwarded to ``hypernet``.

    Returns:
        The scalar outer loss.
    """
    # Features of the outer dataset under the current hyperparameters.
    Zf = hypernet(hparams, Xf)
    return 0.5 * regressor(params, Zf, yf) / n


###########################################################
# Step size of the inner gradient-descent updates (read by map_func and inner_solver).
alpha = .001
# Initial inner parameters [w]; inner_solver starts every solve from these tensors.
p0 = [torch.randn(d)]
# Initial hyperparameters [H1, H2]; the outer loop optimises clones of these.
hp0 = [torch.randn((p, q)), torch.randn((q, d))]


def map_func(params, hparams):
    """Apply one gradient-descent step on the inner objective.

    Side effect: the current inner loss is appended (as a Python float) to
    the module-level list ``inner_losses``, which the outer loop rebinds at
    the start of every outer step.

    Args:
        params: Sequence whose first entry is the weight vector ``w``.
        hparams: Hyperparameters forwarded to ``inner_func``.

    Returns:
        A one-element list holding ``w - alpha * dg/dw``. The gradient is
        taken with ``create_graph=True`` so the update stays differentiable.
    """
    inner_loss = inner_func(params, hparams)
    inner_losses.append(inner_loss.item())

    grad_w = torch.autograd.grad(inner_loss, params, create_graph=True)[0]
    updated = params[0] - alpha * grad_w
    return [updated]


def inner_solver(hparams, steps=20, params0=None, optim=None):
    """Approximately solve the inner problem with plain gradient descent.

    Every call starts from the module-level initial point ``p0`` (whose
    tensors are switched to ``requires_grad=True`` in place) and performs
    ``steps`` updates with step size ``alpha`` on
    ``0.5 * ||T(Xg) w - yg||^2 / n + 0.5 * b * ||w||^2``.

    Args:
        hparams: Hyperparameters forwarded to ``hypernet``.
        steps: Number of gradient-descent updates to run.
        params0: Accepted but not read by this implementation.
        optim: Accepted but not read by this implementation.

    Returns:
        A one-element list with the final weight vector. Gradients are taken
        with ``create_graph=True``, so the result stays attached to the
        autograd graph of the iterations.
    """
    weights = [init.requires_grad_(True) for init in p0]

    # The features depend only on hparams, so compute them once for all steps.
    features = hypernet(hparams, Xg)
    for _ in range(steps):
        fit_term = 0.5 * regressor(weights, features, yg) / n
        ridge_term = 0.5 * b * (torch.norm(weights[0])) ** 2
        grad_w = torch.autograd.grad(fit_term + ridge_term, weights, create_graph=True)[0]
        weights = [weights[0] - alpha * grad_w]

    return weights


# ESJ method: configuration and state for the outer optimisation loop.
K = 100  # number of outer (hyperparameter) updates
eval_interval = 10  # print progress every this many outer steps
T = 10  # inner gradient steps per solve; also forwarded to hg.hgvesj
mu = 0.1  # forwarded to hg.hgvesj as ``mu``
beta = 0.01  # learning rate of the outer Adam optimiser

# NOTE(review): none of the objects returned by load_data are referenced again
# in the visible part of this script -- confirm whether this call is still needed.
args = args_parser()
(dataset_train, dataset_test, dict_users,
 args.img_size, dataset_train_real) = load_data(args)

# Optimise independent copies so that hp0 itself is left untouched.
hparams = [hp.clone().requires_grad_(True) for hp in hp0]

outer_opt = torch.optim.Adam(hparams, lr=beta)

# Per-run bookkeeping filled in by the outer loop.
total_time = 0
val_losses = []
running_time = []
hg_norms = []

# Outer loop: K hyperparameter updates, each preceded by a T-step inner solve.
for k in range(K):

    # perf_counter is monotonic, so the measured intervals cannot be distorted
    # by system clock adjustments the way time.time() differences can.
    step_start_time = time.perf_counter()
    # Rebound every outer step; map_func appends inner losses to this list.
    inner_losses = []
    # Approximate inner solution after T gradient steps.
    params = inner_solver(hparams, steps=T)
    t1 = time.perf_counter() - step_start_time  # inner loop time

    outer_opt.zero_grad()
    # NOTE(review): with set_grad=True, hgvesj presumably stores the hypergradient
    # estimate in each hparams[i].grad; the optimiser step and the norm below
    # rely on that -- confirm against the hypergrad module.
    _, cost = hg.hgvesj(params, hparams, outer_func, inner_solver, mu=mu, T=T, p=1, set_grad=True)
    t2 = time.perf_counter() - step_start_time - t1  # hypergrad estimation time
    val_losses.append(cost.item())
    outer_opt.step()

    step_time = time.perf_counter() - step_start_time
    total_time += step_time
    running_time.append(total_time)
    # Gradient norm of the first hyperparameter tensor only; computed once and
    # reused for the progress line below.
    hg_norms.append(torch.norm(hparams[0].grad))

    if k % eval_interval == 0 or k == K - 1:
        print('outer step={} ({:.2e}s)({:.2e}, {:.2e}) | val loss={} | hypergrad norm = {:.3e}'.format(
            k, step_time, t1, t2, val_losses[-1], hg_norms[-1]))

print('total time = {}'.format(total_time))

# The curve is drawn against cumulative wall-clock time (running_time), so the
# x-axis is labelled as time rather than as an outer-step index.
plt.title('validation loss')
plt.xlabel('running time (s)')
plt.plot(running_time, val_losses)
plt.show()

# Aliases of this run's histories (same list objects, not copies).
# NOTE(review): the _zoj suffix presumably tags this method's results for a
# later comparison -- confirm against downstream code.
norm_zoj = hg_norms
val_zoj = val_losses
run_zoj = running_time