import torch
from htssr.primitives import special_parameter_id
from htssr.utils import tc_fast_eval_expr


def mse(pred, y):
    """Return the mean squared error between ``pred`` and ``y``.

    The difference ``pred - y`` is squared element-wise (broadcasting as the
    operands allow) and averaged over every element.
    """
    squared_error = (pred - y) ** 2
    return squared_error.mean()

def full_hessian(loss, w):
    """Return the gradient and the dense Hessian of a scalar ``loss`` w.r.t. ``w``.

    The Hessian is assembled row by row by differentiating each gradient
    entry a second time, i.e. one extra backward pass per element of ``w``.

    Args:
        loss: scalar tensor computed from ``w``.
        w: tensor with ``requires_grad=True``. Any shape is accepted; it is
            flattened for the Hessian.

    Returns:
        ``(g, H)``:
        - ``g`` has the shape of ``w``. It is computed with
          ``create_graph=True``, so it may still be attached to the graph.
        - ``H`` has shape ``(w.numel(), w.numel())`` and the same dtype and
          device as ``w``.
        Rows whose gradient entry does not depend on ``w`` are zero.
    """
    (g,) = torch.autograd.grad(loss, w, create_graph=True)
    n = w.numel()
    # Match w's dtype/device so float64 or GPU parameters are not silently
    # downcast or moved to CPU.
    H = torch.zeros(n, n, dtype=w.dtype, device=w.device)
    if not g.requires_grad:
        # The gradient is constant in w (loss at most linear in w), so the
        # Hessian is identically zero. A second grad() call would raise here.
        return g, H
    g_flat = g.reshape(-1)  # scalar entries, so multi-dimensional w works
    for i in range(n):
        (row,) = torch.autograd.grad(
            g_flat[i], w, retain_graph=True, allow_unused=True
        )
        if row is not None:  # None means this entry is independent of w
            H[i] = row.reshape(-1)
    return g, H

def fit_constants(
    ids,
    variables,
    target,
    max_iter=5,
    tol_grad=1e-8,
    damping=1e-3,
    d_mult=10.0,
    n_inits=1,
):
    """Fit the free constants of an expression with damped-Newton (LM-style) steps.

    Every token in ``ids`` equal to ``special_parameter_id`` is a free constant.
    For each of ``n_inits`` random restarts, the constants are drawn uniformly
    from [-5, 5). Up to ``max_iter`` iterations then minimise the MSE between
    ``tc_fast_eval_expr(ids, variables, params)`` and ``target``. Each
    iteration solves ``(H + damping * I) delta = -grad`` and keeps the step
    only if it lowers the loss.

    Args:
        ids: token sequence of the expression.
        variables: inputs forwarded unchanged to ``tc_fast_eval_expr``.
        target: tensor the expression's output is fitted to.
        max_iter: maximum number of iterations per restart.
        tol_grad: a restart stops once ``max|grad| < tol_grad``.
        damping: initial diagonal damping; it is reset for every restart.
        d_mult: damping is divided by this on an accepted step and multiplied
            by it on a rejected one.
        n_inits: number of random restarts.

    Returns:
        ``(None, None, None, None)`` when ``ids`` contains no constants.
        Otherwise ``(best_params, best_it, min_loss, min_rel_loss)``:
        - ``best_params``: constants tensor of the best restart (still has
          ``requires_grad=True``).
        - ``best_it``: 0-based index of that restart's last iteration, so
          ``best_it + 1`` iterations ran.
        - ``min_loss``: MSE evaluated at ``best_params``.
        - ``min_rel_loss``: ``min_loss / mean(target**2)``. For an all-zero
          target it is ``0.0`` on an exact fit and ``inf`` otherwise.
        If no restart produced a loss that compares below ``inf`` (e.g. all
        NaN), ``best_params``/``best_it`` stay ``None`` and ``min_loss`` is
        ``inf``.
    """
    nparams = sum(_id == special_parameter_id for _id in ids)
    if nparams == 0:
        return None, None, None, None

    min_loss = float("inf")
    best_params = None
    best_it = None
    for _ in range(n_inits):
        # Uniform initial guess in [-5, 5).
        # NOTE(review): params use torch's default dtype/device; move them to
        # target.device if tc_fast_eval_expr needs GPU inputs — TODO confirm.
        params = torch.rand(nparams)
        params = 10.0 * params - 5.0
        params.requires_grad_(True)

        lam = damping  # each restart starts from the caller's damping
        cur_loss = None  # loss at the *current* params, updated on accept
        it = -1  # so that best_it + 1 == iterations run, even if none ran
        for it in range(max_iter):
            loss = mse(tc_fast_eval_expr(ids, variables, params), target)
            cur_loss = loss.item()
            grad, H = full_hessian(loss, params)
            grad = grad.detach()  # drop the graph; only values are needed below
            if grad.abs().max() < tol_grad:
                break
            # LM step: solve (H + lam * I) delta = -grad.
            H_lm = H + lam * torch.eye(nparams, dtype=H.dtype, device=H.device)
            try:
                delta = torch.linalg.solve(H_lm, -grad)
            except RuntimeError:  # includes torch.linalg.LinAlgError (singular)
                break
            # Try the step without tracking gradients.
            with torch.no_grad():
                new_loss = mse(
                    tc_fast_eval_expr(ids, variables, params + delta),
                    target,
                ).item()
            if new_loss < cur_loss:  # accepted: move and trust the model more
                with torch.no_grad():
                    params += delta
                cur_loss = new_loss
                lam /= d_mult
            else:  # rejected: damp harder, toward gradient descent
                lam *= d_mult

        if cur_loss is None:
            # max_iter <= 0: no iteration ran, so score the initial guess.
            with torch.no_grad():
                cur_loss = mse(
                    tc_fast_eval_expr(ids, variables, params), target
                ).item()

        if cur_loss < min_loss:
            min_loss = cur_loss
            best_params = params
            best_it = it

    target_power = (target ** 2).mean().item()
    if target_power == 0:
        # Relative error is undefined for an all-zero target.
        min_rel_loss = 0.0 if min_loss == 0 else float("inf")
    else:
        min_rel_loss = min_loss / target_power
    return best_params, best_it, min_loss, min_rel_loss

def __fit_constants(
    model,
    params,
    variables,
    target,
    max_iter=5,
    tol_grad=1e-8,
    damping=1e-3,
    d_mult=10.0,
):
    """Refine ``params`` in place with damped-Newton (LM-style) steps on an MSE loss.

    Each iteration solves ``(H + damping * I) delta = -grad`` for the MSE
    between ``model(params, variables)`` and ``target``. The step is applied
    to ``params`` in place only if it lowers the loss. Damping is divided by
    ``d_mult`` on an accepted step and multiplied by it on a rejected one.

    Args:
        model: callable ``model(params, variables)`` returning predictions.
        params: leaf tensor with ``requires_grad=True``, expected to be 1-D
            (its length sizes the damping identity). Updated in place.
        variables: inputs forwarded unchanged to ``model``.
        target: tensor the predictions are fitted to.
        max_iter: maximum number of iterations.
        tol_grad: stop once ``max|grad| < tol_grad``.
        damping: initial diagonal damping.
        d_mult: damping update factor.

    Returns:
        ``(params, n_iters, loss)``:
        - ``params``: the same tensor object that was passed in.
        - ``n_iters``: number of iterations run.
        - ``loss``: MSE at the returned ``params``.

    Errors raised by ``torch.linalg.solve`` (e.g. a singular damped system)
    propagate to the caller.
    """
    nparams = len(params)
    loss_value = None  # loss at the *current* params, updated on accept
    n_iters = 0
    for n_iters in range(1, max_iter + 1):
        loss = mse(model(params, variables), target)
        loss_value = loss.item()
        grad, H = full_hessian(loss, params)
        grad = grad.detach()  # drop the graph; only values are needed below
        if grad.abs().max() < tol_grad:
            break
        # LM step: solve (H + damping * I) delta = -grad.
        H_lm = H + damping * torch.eye(nparams, dtype=H.dtype, device=H.device)
        delta = torch.linalg.solve(H_lm, -grad)
        # Try the step without tracking gradients.
        with torch.no_grad():
            new_loss = mse(model(params + delta, variables), target).item()
        if new_loss < loss_value:  # accepted
            with torch.no_grad():
                params += delta
            loss_value = new_loss
            damping /= d_mult
        else:  # rejected
            damping *= d_mult
    if loss_value is None:
        # max_iter <= 0: no iteration ran, so report the loss at the input params.
        with torch.no_grad():
            loss_value = mse(model(params, variables), target).item()
    return params, n_iters, loss_value
