import torch as tc

from .grad import grad
from .error import mse_TC, l2_err_TC
from .flatten import unflatten_TC


def infer_gradient(model, x, mode="forward"):
    """Compute ``grad(model, x, mode=mode)`` with ``model`` switched to eval mode.

    Note: if ``model`` is a ``torch.nn.Module``, ``.eval()`` switches it in
    place and returns it, so the caller's model is left in eval mode.
    """
    evaluated = model.eval()
    return grad(evaluated, x, mode=mode)

def _error_pair(true: tc.Tensor, pred: tc.Tensor, verbose):
    """Return ``(mse, rel2)`` of ``pred`` against ``true``.

    ``verbose`` is forwarded unchanged to ``mse_TC`` and ``l2_err_TC``
    (presumably it controls printing inside those helpers — confirm).
    """
    return mse_TC(true, pred, verbose), l2_err_TC(true, pred, verbose)


def _split_error_pairs(true: tc.Tensor, pred: tc.Tensor, n_obj, dof, verbose):
    """Score the q and p halves of flattened phase-space tensors separately.

    Both tensors are reshaped with ``unflatten_TC(t, n_obj, dof)``. On the last
    axis, ``[..., :dof]`` is treated as the q part and ``[..., dof:]`` as the
    p part. Each part is flattened before scoring.

    Returns ``(mse_q, rel2_q, mse_p, rel2_p)``.
    """
    pred = unflatten_TC(pred, n_obj, dof)
    q_pred, p_pred = pred[..., :dof], pred[..., dof:]

    true = unflatten_TC(true, n_obj, dof)
    q_true, p_true = true[..., :dof], true[..., dof:]

    # Tuple elements evaluate left to right, so the metric helpers run in the
    # same order as before: q-mse, q-rel2, p-mse, p-rel2.
    return (*_error_pair(q_true.flatten(), q_pred.flatten(), verbose),
            *_error_pair(p_true.flatten(), p_pred.flatten(), verbose))


def _compare(true: tc.Tensor, pred: tc.Tensor, model, as_separate, verbose):
    """Dispatch to the joint or the q/p-separated metrics.

    ``model.n_obj`` and ``model.dof`` are read only when ``as_separate`` is true.
    """
    if as_separate:
        return _split_error_pairs(true, pred, model.n_obj, model.dof, verbose)
    return _error_pair(true, pred, verbose)


def eval_dxdt(model, x: tc.Tensor, L: tc.Tensor, dxdt: tc.Tensor, mode="forward",
              as_separate=False, verbose=True):
    """
    Evaluate predicted time derivatives dx/dt against the reference ``dxdt``.

    The prediction is ``infer_gradient(model, x, mode) @ L.T``; ``L`` is
    presumably the structure matrix mapping dE/dx to dx/dt (confirm with callers).

    Returns (mse_dqdt, rel2_dqdt, mse_dpdt, rel2_dpdt) if as_separate,
    else (mse_dxdt, rel2_dxdt). The tuple is returned regardless of
    ``verbose`` (previously ``None`` was returned when ``verbose`` was true,
    contradicting this docstring); ``verbose`` is only forwarded to the
    metric helpers.
    """
    dedx_pred = infer_gradient(model, x, mode)
    dxdt_pred = tc.mm(dedx_pred, L.T)
    return _compare(dxdt, dxdt_pred, model, as_separate, verbose)


def eval_dedx(model, x: tc.Tensor, dedx: tc.Tensor, mode="forward", as_separate=False, verbose=True):
    """
    Evaluate predicted Hamiltonian gradients dE/dx against ``dedx`` (sanity check).

    The prediction is ``infer_gradient(model, x, mode)``.

    Returns (mse_dedq, rel2_dedq, mse_dedp, rel2_dedp) if as_separate,
    else (mse_dedx, rel2_dedx). The tuple is returned regardless of
    ``verbose`` (previously ``None`` was returned when ``verbose`` was true,
    contradicting this docstring); ``verbose`` is only forwarded to the
    metric helpers.
    """
    dedx_pred = infer_gradient(model, x, mode)
    return _compare(dedx, dedx_pred, model, as_separate, verbose)
