# import geoopt
# import torch
# import math
# from geoopt import SymmetricPositiveDefinite, Sphere, Euclidean
# from utils import autograd, dot, compute_hypergrad
# from manifolds import EuclideanMod, SphereMod, LorentzMod

# def compute_jvp(loss, hparams, params, tangents):
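#     """
#     Compute the cross derivative of loss(hparams, params) applied to tangents, i.e., G_xy[tangents], where x is hparams and y is params.
#     :param loss: callable invoked as loss(hparams, params)
#     :param hparams: List[Tensors], the x variables
#     :param params: List[Tensors], the y variables; must have the same length as tangents
#     :param tangents: List[Tensors] of size params
#     :return: (gradA, gradxy), the pair returned by torch.autograd.functional.jvp
#     """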
#     assert len(params) == len(tangents)
#
#     def function(params):
#         grad = autograd(loss(hparams, [params]), hparams, create_graph=True)  # list of size hparams
#         return tuple([hparam.manifold.egrad2rgrad(hparam, gg) for hparam, gg in zip(hparams, grad)])
#
#     gradA, gradxy = torch.autograd.functional.jvp(function, params, tangents)
#
#     return gradA, gradxy


from scipy.io import loadmat



if __name__ == '__main__':
    # Scratch area: everything below is a collection of disabled (commented-out)
    # experiments kept for reference. The only live code is the "SPD dataset"
    # section at the very end, which loads a .mat file and prints its contents.

    # sphere = SphereMod()
    # spd = SymmetricPositiveDefinite()
    # euclidean = EuclideanMod(ndim=1)  # note the dimension here for Euclidean space
    #
    # A = geoopt.ManifoldParameter(spd.random(5, 5), manifold=spd)
    # x = geoopt.ManifoldParameter(sphere.random(5), manifold=sphere)
    #
    # Ci = spd.random(6, 5, 5)
    # Ctr = spd.random(5, 5)
    # Cval = spd.random(5, 5)
    #
    #
    # def loss_lower(hparams, params):
    #     A = hparams[0]
    #     x = params[0]
    #     return 0.5 * x @ (A + Ctr) @ x
    #
    #
    # def loss_upper(hparams, params):
    #     A = hparams[0]
    #     x = params[0]
    #     return 0.5 * x @ (A + Cval) @ x
    #
    # def true_hess_prod(u):
    #     A = hparams[0]
    #     x = params[0]
    #     return (torch.eye(5) - x.unsqueeze(1) @ x.unsqueeze(0)) @ (A + Ctr) @ u - (x @ (A + Ctr) @ x) * u
    #
    #
    # def rhess_prod(u):
    #     egrad = autograd(loss_lower(hparams, params), params, create_graph=True)
    #     ehess = autograd(dot(egrad, u), params)
    #     out = []
    #     with torch.no_grad():
    #         for idx, param in enumerate(params):
    #             out.append(param.manifold.ehess2rhess(param, egrad[idx], ehess[idx], u[idx]))
    #     return out
    #
    # ns_gamma = 0.01
    #
    # if ns_gamma > 0:
    #     Hinv_gy = [ns_gamma * hg for hg in Hinv_gy]
    #
    # hparams = (A,)
    # params = (x,)
    #
    # u = (torch.randn(5),)
    #
    # print(rhess_prod(u))
    # print(true_hess_prod(u[0]))





    # 1. < test over SPD quadratic function >
    # mfd = SymmetricPositiveDefiniteMod()
    # A = mfd.random(3,3)
    # X = geoopt.ManifoldParameter(mfd.random(3,3), manifold=mfd)
    # params = [X]
    #
    # def loss():
    #     return 0.5*torch.trace(X@A@X)
    #
    # # return a list of tensors
    # egrad = autograd(loss(), params)
    # # ehess_fn = lambda u: autograd(dot(egrad, u), params)
    # def rhess_fn(loss, params, u):
    #     # input is list of tensors
    #     egrad = autograd(loss(), params, create_graph=True)
    #     ehess = autograd(dot(egrad, u), params)
    #     # ehess = ehess_fn(u)
    #     out = []
    #     with torch.no_grad():
    #         for idx, param in enumerate(params):
    #             out.append(param.manifold.ehess2rhess(param, egrad[idx], ehess[idx], u[idx]))
    #     return out
    #
    # p = ts_conjugate_gradient(lambda u: rhess_fn(loss, params, u),
    #                           [mfd.egrad2rgrad(param.data, egrad[idx]) for idx,param in enumerate(params)],
    #                           params,
    #                           lam=0.01)
    #
    # print(rhess_fn(loss, params, p))
    # print([mfd.egrad2rgrad(param.data, egrad[idx]) for idx,param in enumerate(params)])


    # 2. < test on dot and autograd >
    # # need to input list of tensors
    # grad = autograd(loss, params, create_graph=True) # u needs to be on the tangent space (symmetric)
    # ehess_fn = lambda u: autograd(dot(grad, u), params)
    # u = [linalg.sym(torch.randn(3, 3))]
    # print(grad)
    # print(linalg.sym(A@X))
    # print(ehess_fn(u))
    # print(linalg.sym(A @ u[0]))


    # 3. < test over compute_hypergrad function >
    # sphere = SphereMod()
    # spd = SymmetricPositiveDefinite()
    # euclidean = EuclideanMod(ndim=1)  # note the dimension here for Euclidean space
    #
    # A = geoopt.ManifoldParameter(spd.random(5, 5), manifold=spd)
    # x = geoopt.ManifoldParameter(sphere.random(5), manifold=sphere)
    #
    # Ci = spd.random(6, 5, 5)
    # Ctr = spd.random(5, 5)
    # Cval = spd.random(5, 5)
    #
    #
    # def loss_lower(hparams, params):
    #     A = hparams[0]
    #     x = params[0]
    #     return 0.5 * x @ (A + Ctr) @ x
    #
    #
    # def loss_upper(hparams, params):
    #     A = hparams[0]
    #     x = params[0]
    #     return 0.5 * x @ (A + Cval) @ x
    #
    #
    # def true_hypergrad(hparams, params):
    #     A = hparams[0]
    #     x = params[0]
    #     u = torch.linalg.inv(A + Ctr) @ (A + Cval) @ x
    #     lhs = (torch.eye(5) - x.unsqueeze(1) @ x.unsqueeze(0)) @ (A + Ctr) - (x @ (A + Ctr) @ x) * torch.eye(5)
    #     rhs = (torch.eye(5) - x.unsqueeze(1) @ x.unsqueeze(0)) @ (A + Cval) @ x
    #     u = torch.linalg.solve(lhs, rhs)
    #     return 0.5 * A @ (
    #                 (x.unsqueeze(1) @ x.unsqueeze(0)) - u.unsqueeze(1) @ x.unsqueeze(0) - x.unsqueeze(1) @ u.unsqueeze(
    #             0)) @ A
    #
    #
    #
    # ns_gamma = 0.01
    # def reg_rhess_prod(u):
    #     egrad = autograd(loss_lower(hparams, params), params, create_graph=True)
    #     ehess = autograd(dot(egrad, u), params)
    #     out = []
    #     with torch.no_grad():
    #         for idx, param in enumerate(params):
    #             out.append(u[idx] - ns_gamma * param.manifold.ehess2rhess(param, egrad[idx], ehess[idx], u[idx]))
    #     return out
    #
    # def true_reg_rhess_prod(u):
    #     # u - ns_gamma * (A + Ctr) @ u
    #     egrad = (A + Ctr) @ x
    #     ehess = (A + Ctr) @ u
    #     rhess = ehess - x.unsqueeze(1) @ x.unsqueeze(0) @ ehess - x.manifold.inner(x, x, egrad) * u
    #     return u - ns_gamma*rhess
    #
    #
    # hparams = (A,)
    # params = (x,)
    # u = (torch.randn(5),)
    # hygrad1 = true_hypergrad(hparams, params)
    # hygrad2 = compute_hypergrad(loss_lower, loss_upper, hparams, params, option='cg', ns_gamma=0.001)
    #
    # print(hygrad1)
    # print(hygrad2)
    #
    # print(true_reg_rhess_prod(u[0]))
    # print(reg_rhess_prod(u))
    #
    # egrad = autograd(loss_lower(hparams, params), params, create_graph=True)
    # ehess = autograd(dot(egrad, u), params)
    #
    # print(ehess)
    # print((A + Ctr) @ u[0])


    # <test automatic differentiation euclidean space for hyperrep example>
    # double-check the hypergrad implementation against Kaiyi Ji's paper.
    # sphere = SphereMod()
    # spd = SymmetricPositiveDefinite()
    # vector = EuclideanMod(ndim=1)  # note the dimension here for Euclidean space
    # matrix = EuclideanMod(ndim=2)
    #
    # n1 = 50
    # n2 = 30
    # d = 10
    # r = 5
    #
    # W = geoopt.ManifoldParameter(matrix.random(d, r), manifold=matrix) # hparams
    # beta = geoopt.ManifoldParameter(vector.random(r), manifold=vector) # params
    #
    # X1 = torch.randn(n1, d)
    # X2 = torch.randn(n2, d)
    # Wstar = torch.randn(d,r)
    # betastar = torch.randn(r)
    #
    # y1 = X1 @ Wstar @ betastar + torch.randn(n1)
    # y2 = X2 @ Wstar @ betastar + torch.randn(n2)
    #
    # gamma = 0.1
    # def loss_lower(hparams, params):
    #     ww = hparams[0]
    #     bb = params[0]
    #     return 0.5 * torch.norm(X2 @ ww @ bb - y2)**2/n2 + 0.5 * gamma * torch.norm(bb)**2
    #
    #
    # def loss_upper(hparams, params):
    #     ww = hparams[0]
    #     bb = params[0]
    #     return 0.5 * torch.norm(X1 @ ww @ bb - y1)**2/n1
    #
    #
    # def true_hypergrad(hparams, params):
    #     ww = hparams[0]
    #     bb = params[0]
    #     Gxf = X1.T @ (X1 @ ww @ bb - y1).unsqueeze(1) @ bb.unsqueeze(0)/n1
    #     Gyf = ww.T @ X1.T @ (X1 @ ww @ bb - y1)/n1
    #     HinvGyf = torch.linalg.solve(ww.T @ X2.T @ X2 @ ww/n2 + gamma * torch.eye(r), Gyf)
    #     GxygHinvGyf = X2.T @ X2 @ ww @ HinvGyf.unsqueeze(1) @ bb.unsqueeze(0)/n2 + X2.T @ X2 @ ww @ bb.unsqueeze(1) @ HinvGyf.unsqueeze(0)/n2 - X2.T @ y2.unsqueeze(1) @ HinvGyf.unsqueeze(0)/n2
    #
    #     # print(Gxf)
    #     # print(Gyf)
    #     # print(HinvGyf)
    #     # print(GxygHinvGyf)
    #
    #     return Gxf - GxygHinvGyf
    #
    #
    # hparams = [W]
    # params = [beta]
    #
    #
    #
    # # def lower_update(hparams, params):
    # #     S = 50
    # #     eta_y = 0.01
    # #     for ii in range(S):
    # #         grad = autograd(loss_lower(hparams, params), params, create_graph=True)
    # #         # for param, egrad in zip(params, grad):
    # #             # rgrad = param.manifold.egrad2rgrad(param, egrad)
    # #             # param = param.manifold.retr(param, -eta_y * rgrad)
    # #             # param =
    # #         params = [param - eta_y * egrad for param, egrad in zip(params, grad)]
    # #     return params
    #
    # S = 500
    # eta_y = 0.01
    # mfd_params = [beta.manifold]
    # eta_x = 0.01
    # epoch = 10
    #
    # for ep in range(epoch):
    #     for ii in range(S):
    #         # grad = autograd(loss_lower(hparams, params), params, create_graph=True)
    #         # params = [param - eta_y * egrad for param, egrad in zip(params, grad)]
    #         grad = autograd(loss_lower(hparams, params), params, create_graph=True)
    #         rgrad = [mfd.egrad2rgrad(param, egrad) for mfd, egrad, param in zip(mfd_params, grad, params)]
    #         params = [mfd.retr(param, - eta_y * rg) for mfd, param, rg in zip(mfd_params, params, rgrad)]
    #         with torch.no_grad():
    #             print(f"Loss {loss_lower(hparams, params):.4f}")
    #
    # # egrad = autograd(loss_upper(hparams, params), hparams)
    #
    #     hypergrad = compute_hypergrad(loss_lower, loss_upper, hparams, params, option='ad')
    #     print(hypergrad[0] - true_hypergrad(hparams, params))
    #     print()
    #
    #     params = [param.detach().clone().requires_grad_(True) for param in params]
    #
    #     with torch.no_grad():
    #         for hparam, hg in zip(hparams, hypergrad):
    #             new_hparam = hparam.manifold.retr(hparam, - eta_x * hg)
    #             hparam.copy_(new_hparam)
    #
    #         print(f"Epoch {ep}: "
    #               f"loss upper: {loss_upper(hparams, params).item():.4f}, "
    #               f"hypergrad norm: {hparams[0].manifold.inner(hparams[0], hypergrad[0]).item():.2f}")





    # print(compute_hypergrad(loss_lower, loss_upper, hparams, params)[0])







    # --- unused -----
    # 4. < test over hyperbolic strongly convex inner problems >
    #


    # 5. < test over hyperbolic distance function and ns approximation >
    # lorentz = LorentzMod()
    #
    # d = 5
    # atr = lorentz.random( d)
    #
    # x = geoopt.ManifoldParameter(lorentz.random(d), manifold=lorentz)
    #
    # hparams = []
    # params = [x]
    #
    # ns_gamma = 0.01
    # ns_iter = 30
    #
    # def loss_lower(hparams, params):
    #     x = params[0]
    #     return lorentz.dist(x, atr) **2
    #
    # def true_hess(hparams, params):
    #     x = params[0]
    #
    #
    # def reg_rhess_prod(u):
    #     egrad = autograd(loss_lower(hparams, params), params, create_graph=True)
    #     ehess = autograd(dot(egrad, u), params)
    #     out = []
    #     with torch.no_grad():
    #         for idx, param in enumerate(params):
    #             out.append(u[idx] - ns_gamma * param.manifold.ehess2rhess(param, egrad[idx], ehess[idx], u[idx]))
    #     return out
    #
    #
    # rgradfy = lorentz.proju(params[0], torch.randn(d))
    #
    # with torch.no_grad():
    #     Hinv_gy_prev = [g.clone().detach() for g in rgradfy]
    #     Hinv_gy = [g.clone().detach() for g in rgradfy]
    #     for ins in range(ns_iter):
    #         with torch.enable_grad():
    #             Hinv_gy_new = reg_rhess_prod(Hinv_gy_prev)
    #         Hinv_gy = [hg + hg_new for hg, hg_new in zip(Hinv_gy, Hinv_gy_new)]
    #         Hinv_gy_prev = Hinv_gy_new
    #
    # if ns_gamma > 0:
    #     Hinv_gy = [ns_gamma * hg for hg in Hinv_gy]




    ####### some random tests #########
    # import torch
    # from torch import nn
    # from collections import OrderedDict
    # torch.manual_seed(42)
    #
    # def MiniimageNetFeats(hidden_size):
    #     def conv_layer(ic, oc):
    #         # return nn.Sequential(OrderedDict([
    #         #     ("conv", nn.Conv2d(ic, oc, 3, padding=1)),
    #         #     ("relu", nn.ReLU(inplace=True)),
    #         #     ("maxpool", nn.MaxPool2d(2)),
    #         #     ("bn", nn.BatchNorm2d(oc, momentum=1., affine=True, track_running_stats=False))
    #         # ]))
    #         return nn.Sequential(
    #             nn.Conv2d(ic, oc, 3, padding=1), nn.ReLU(inplace=True), nn.MaxPool2d(2),
    #             nn.BatchNorm2d(oc, momentum=1., affine=True,
    #                            track_running_stats=False
    #                            )
    #         )
    #
    #     net = nn.Sequential(
    #         conv_layer(3, hidden_size),
    #         conv_layer(hidden_size, hidden_size),
    #         conv_layer(hidden_size, hidden_size),
    #         conv_layer(hidden_size, hidden_size),
    #         nn.Flatten())
    #
    #     # initialize(net)
    #     return net
    #
    # model = MiniimageNetFeats(6)

    # for name, p in model.named_parameters():
    #     print(name, (p.size()))

    # params = [torch.randn(4,3,requires_grad=True)]
    # new_params = [params[0].view(2,2,3)]
    #
    # loss = torch.norm(new_params[0])**2
    #
    # g = torch.autograd.grad(loss, params, create_graph=True)
    # gg = torch.autograd.grad(loss, new_params)
    # print(g[0])
    # print(gg)

    # params = [torch.randn(4, requires_grad=True)]
    # hparams = [torch.randn(4, requires_grad=True)]
    #
    # # new_params = params
    # # new_params = [p**2+hparams[0]*p for p in new_params]
    # # new_params = [p*2/hparams[0] for p in new_params]
    # new_params = params
    # new_params = [p**2+hparams[0]*p for p in new_params]
    # new_params = [p*2/hparams[0] for p in new_params]
    #
    #
    # # print(new_params[0] * 4 * params[0])
    #
    # loss = torch.norm(new_params[0])**2/2
    # print(torch.autograd.grad(loss, hparams))
    #
    # p = params[0]
    # h = hparams[0]
    # print(2 * (p ** 2 + h * p) / h * (- 2*p**2/h**2))



    #### SPD dataset
    # The path is relative, so it is resolved against the current working
    # directory. Whatever loadmat returns is printed as-is.
    mat_path = 'data/spddb_afew_train_spd400_int_histeq.mat'
    data = loadmat(mat_path)
    print(data)








