# Test on the maximally stretched Fréchet mean of two sets with dimensionality reduction.
# Here hparams is W (on the Stiefel manifold) and params is X (SPD).

import torch
from torch import nn
import geoopt
import numpy as np

from geoopt import linalg
from geoopt import SymmetricPositiveDefinite
from geoopt import Stiefel, Euclidean
import time

from utils import autograd, compute_hypergrad, compute_jvp, batch_egrad2rgrad
from utils import ts_conjugate_gradient
from manifolds import EuclideanMod, SymmetricPositiveDefiniteMod
import argparse
from optimizer import RHGD



def loss_lower(hparams, params):
    """Lower-level objective: regularised squared-distance fit of X to the projected training set.

    Reads the module-level globals created in the ``__main__`` block:
    ``spd`` (manifold object providing ``dist``), ``Atr`` (the training
    half of the generated SPD matrices) and ``lam`` (weight of the
    log-determinant term).

    Args:
        hparams: list whose first entry is the projection W (a Stiefel
            ManifoldParameter).
        params: list whose first entry is the SPD matrix X.

    Returns:
        Scalar tensor ``mean(dist(X, W^T Atr W)^2) / 2 - lam * logdet(X)``.
    """
    W, X = hparams[0], params[0]
    projected = W.T @ Atr @ W
    fit = torch.mean(spd.dist(X, projected) ** 2) / 2
    return fit - lam * torch.logdet(X)

def loss_upper(hparams, params):
    """Upper-level objective: squared-distance fit of X to the projected validation set.

    Reads the module-level globals created in the ``__main__`` block:
    ``spd`` (manifold object providing ``dist``) and ``Aval`` (the
    validation half of the generated SPD matrices).

    Args:
        hparams: list whose first entry is the projection W (a Stiefel
            ManifoldParameter).
        params: list whose first entry is the SPD matrix X.

    Returns:
        Scalar tensor ``mean(dist(X, W^T Aval W)^2) / 2``.
    """
    W, X = hparams[0], params[0]
    projected = W.T @ Aval @ W
    squared = spd.dist(X, projected) ** 2
    return torch.mean(squared) / 2




if __name__ == '__main__':
    # Bilevel problem: the upper level learns the Stiefel projection W and
    # the lower level fits the SPD matrix X. The optimisation loop lives in
    # optimizer.RHGD, which takes its step sizes, iteration counts and
    # hypergradient options from the parsed ``args`` below.
    parser = argparse.ArgumentParser()
    parser.add_argument('--eta_x', type=float, default=0.01)
    parser.add_argument('--eta_y', type=float, default=0.01)
    parser.add_argument('--lower_iter', type=int, default=50)
    parser.add_argument('--epoch', type=int, default=200)
    parser.add_argument('--hygrad_opt', type=str, default='cg', choices=['hinv', 'cg', 'ns', 'ad'])
    parser.add_argument('--ns_gamma', type=float, default=0.01)
    parser.add_argument('--ns_iter', type=int, default=5)
    parser.add_argument('--seed', type=int, default=42)
    # Weight of the -logdet(X) term in loss_lower (previously hard-coded to 0.1).
    parser.add_argument('--lam', type=float, default=0.1)
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(args.seed)
    print(device)
    torch.random.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    # Problem size: N SPD matrices of size d x d, projected down to r x r.
    N = 100
    d = 50
    r = 10
    # Module-level global read by loss_lower; must stay at module scope.
    lam = args.lam

    mfd = SymmetricPositiveDefinite()
    A = mfd.random(N, d, d, device=device)  # N dxd SPD matrices

    # First half feeds the lower-level loss (Atr), second half the
    # upper-level loss (Aval). Both are read as module-level globals.
    ntr = N // 2
    Atr = A[:ntr]
    Aval = A[ntr:]

    # `spd` is also read as a module-level global by loss_lower / loss_upper.
    spd = SymmetricPositiveDefiniteMod()
    stiefel = Stiefel(canonical=False)

    # NOTE: keep the random draws in this order (A first, then X) so that
    # results for a given seed stay reproducible.
    params = [geoopt.ManifoldParameter(spd.random(r, r, device=device), manifold=spd)]
    # W starts as the first r columns of the d x d identity, a point with
    # orthonormal columns on the Stiefel manifold.
    hparams = [geoopt.ManifoldParameter(torch.eye(d, r, device=device), manifold=stiefel)]

    RHGD(loss_lower, loss_upper, hparams, params, args)