import torch


def create_loss(A, b):
    """
    Build the least-squares objective attached to the linear system ``Ax = b``.

    The returned callable maps ``x`` to the squared norm of the residual,
    ``|Ax - b|**2``, computed as ``torch.norm(A @ x - b) ** 2``. For a
    matrix-valued residual, ``torch.norm`` defaults to the Frobenius norm.
    """

    def loss(x):
        residual = torch.matmul(A, x) - b
        residual_norm = torch.norm(residual)
        return residual_norm ** 2

    return loss


def create_loss_2(A,b):
    """
    Returns the quadratic loss function

        x -> <x, Ax> - 2<b, x>,

    evaluated as the single dot product <Ax - 2b, x>.

    When A is symmetric positive definite, this loss is uniquely minimised at
    the solution of Ax = b. The loss uses ``torch.dot``, so x, b and Ax must
    be 1-D tensors.
    """

    def loss(x):
        # <Ax - 2b, x> == <x, Ax> - 2<b, x>: one matmul and one dot product.
        return torch.dot(torch.matmul(A, x) - 2 * b, x)

    return loss


def quad_solver(A, b, n_epochs=1000, init=None, lr=0.1, return_history=False):
    """
    Approximately minimize ``x -> <x, Ax> - 2<b, x>`` with the Adam optimiser.

    The objective is the one returned by ``create_loss_2(A, b)``. When ``A`` is
    symmetric positive definite, its unique minimizer is the solution of
    ``Ax = b``.

    Constraint: ``torch.matmul(A, init)`` (or ``torch.matmul(A, b)`` when
    ``init`` is None) must be defined. Because the loss uses ``torch.dot``,
    ``x`` and ``b`` must be 1-D.

    Parameters
    ----------
    A : torch.Tensor
        Matrix of the quadratic form.
    b : torch.Tensor
        Vector of the linear term. When ``init`` is None, the starting point is
        ``torch.randn_like(b)``, so it has the shape and dtype of ``b``.
    n_epochs : int
        Number of Adam steps.
    init : torch.Tensor or None
        Optional starting point. It is detached and copied, so the caller's
        tensor is never modified and may itself require grad.
    lr : float
        Adam learning rate.
    return_history : bool
        If True, also return the per-step losses.

    Returns
    -------
    x : torch.Tensor
        Final iterate, with ``requires_grad`` turned off.
    loss_history : list
        Only returned when ``return_history`` is True. Entry ``i`` is a
        one-element list holding the detached scalar loss evaluated just
        before optimiser step ``i``.
    """

    if init is None:
        x = torch.randn_like(b, requires_grad=True)
    else:
        # detach() before clone(): cloning a tensor that is part of an autograd
        # graph yields a non-leaf tensor. Setting requires_grad on a non-leaf
        # raises, and Adam refuses to optimise non-leaf tensors.
        x = init.detach().clone().requires_grad_(True)

    opt = torch.optim.Adam([x], lr=lr)
    loss = create_loss_2(A, b)
    loss_history = []

    for _ in range(n_epochs):
        opt.zero_grad()
        loss_value = loss(x)
        loss_value.backward()
        opt.step()
        # Detach so the history does not keep every step's autograd graph alive.
        loss_history.append([loss_value.detach()])

    x.requires_grad_(False)
    if return_history:
        return x, loss_history
    return x

