import jax
import jax.numpy as jnp
from functools import partial

@jax.jit
def projection(x, y_target):
    """Return the scalar projection coefficient of ``x`` onto ``y_target``.

    Computes ``(x · y_target) / (y_target · y_target)``. This is the
    coefficient only; multiply it by ``y_target`` to obtain the projected
    vector.

    Note: if ``y_target`` is the zero vector, this evaluates 0/0 and
    returns NaN.
    """
    numerator = jnp.dot(x, y_target)
    squared_norm = jnp.dot(y_target, y_target)
    return numerator / squared_norm

@jax.jit
def momentum_projection(x, momentum):
    """Return the vector projection of ``x`` onto ``momentum``.

    Result is ``((x · m) / (m · m)) * m``, i.e. the component of ``x`` that
    lies along ``momentum``.

    Note: if ``momentum`` is the zero vector, the coefficient is 0/0, so
    every entry of the result is NaN.
    """
    coefficient = jnp.dot(x, momentum) / jnp.dot(momentum, momentum)
    return coefficient * momentum

@jax.jit
def closed_form_linear_regression(x, y_target):
    """Fit ordinary-least-squares weights and bias in closed form.

    Args:
        x: 2-D feature matrix whose columns are samples, i.e. laid out as
            (n_features, n_samples) -- it is transposed before a column of
            ones is appended per sample.
        y_target: regression targets, one per sample (leading axis
            n_samples).

    Returns:
        ``theta`` with ``n_features + 1`` leading entries. The ones column is
        appended LAST, so ``theta[:-1]`` are the feature weights and
        ``theta[-1]`` is the bias.
    """
    n_samples = x.shape[1]
    # Design matrix (n_samples, n_features + 1); the trailing ones column
    # carries the bias term.
    design = jnp.hstack([x.T, jnp.ones((n_samples, 1))])

    # Solve min ||design @ theta - y_target||^2 directly with an SVD-based
    # least-squares solver instead of inv(X^T X) @ X^T y: forming X^T X
    # squares the condition number, and an explicit inverse blows up on
    # singular / collinear designs (lstsq returns the minimum-norm solution).
    theta, *_ = jnp.linalg.lstsq(design, y_target)
    return theta

@jax.jit
def linear_regression_loss(w, b, x, y_target):
    """Mean squared error of the linear model ``w · x + b`` against ``y_target``.

    ``jnp.dot(w, x)`` contracts ``w`` with the leading axis of ``x``; ``b`` is
    broadcast onto the resulting predictions before they are compared with
    ``y_target``.
    """
    predictions = jnp.dot(w, x) + b
    residuals = predictions - y_target
    squared_errors = residuals ** 2
    return jnp.mean(squared_errors)



@jax.jit
def gradient_descent_step(w, b, x, y_target, lr):
    """Take one gradient-descent step on ``linear_regression_loss``.

    Returns ``(new_w, new_b, loss)`` where ``loss`` is the loss evaluated at
    the incoming ``w`` and ``b`` (i.e. before this update is applied).
    """
    loss_and_grads = jax.value_and_grad(linear_regression_loss, argnums=(0, 1))
    loss, (grad_w, grad_b) = loss_and_grads(w, b, x, y_target)
    new_w = w - lr * grad_w
    new_b = b - lr * grad_b
    return new_w, new_b, loss
@partial(jax.jit, static_argnums=(3, 4))
def gradient_descent(x, y_target, lr, comp_iter, pop_size):
    """Run ``comp_iter`` full-batch gradient-descent steps from a zero start.

    Args:
        x: features; combined with the weights as ``jnp.dot(w, x)`` in the
            loss, so its leading axis must have length ``pop_size``.
        y_target: regression targets.
        lr: step size. Traced rather than static, so changing it does not
            trigger recompilation, and a JAX array may be passed.
        comp_iter: number of update steps (static Python int).
        pop_size: length of the weight vector (static; fixes the shape of w).

    Returns:
        ``(w, b, losses)``: final weights of shape ``(pop_size,)``, final bias
        of shape ``(1,)``, and ``losses`` of shape ``(comp_iter,)`` where
        ``losses[i]`` is the loss evaluated before the i-th update.
    """

    def _step(carry, _):
        # One iteration of the scanned loop: update params, emit the loss.
        w, b = carry
        w, b, loss = gradient_descent_step(w, b, x, y_target, lr)
        return (w, b), loss

    init = (jnp.zeros(pop_size), jnp.zeros(1))
    # lax.scan traces the step body once; a Python loop under jit would be
    # unrolled comp_iter times, making trace/compile time grow linearly.
    (w, b), losses = jax.lax.scan(_step, init, xs=None, length=comp_iter)
    return w, b, losses