import numpy as np
from scipy.special import softmax, log_softmax


def one_sample_F_and_grad(W, X, y, lam=0.0):
    """Return the loss and gradient for a single sample.

    The sample is promoted to a batch of size one (a leading axis is added
    to both ``X`` and ``y``) and evaluated with ``full_batch_F_and_grad``.

    Args:
        W: Weight array, passed through unchanged.
        X: Features of one sample (without a batch axis).
        y: Target for that sample (without a batch axis).
        lam: Regularization coefficient forwarded to
            ``full_batch_F_and_grad``. Defaults to ``0.0``, i.e. no
            regularization, so existing three-argument calls keep working.

    Returns:
        Tuple ``(loss, grad)`` as returned by ``full_batch_F_and_grad``.
    """
    # The batch routine requires `lam`; omitting it raised TypeError before.
    loss, grad = full_batch_F_and_grad(W, X[np.newaxis], y[np.newaxis], lam)
    return loss, grad

def reg(w):
    """Return the squared norm of ``w`` as computed by ``np.linalg.norm``.

    With the default ``ord`` this is the 2-norm for 1-D input and the
    Frobenius norm for 2-D input, i.e. the sum of squared entries.
    """
    magnitude = np.linalg.norm(w)
    squared = magnitude ** 2
    return squared
    
def full_batch_F_and_grad(W, X, y, lam=0.0):
    """Return the regularized softmax cross-entropy loss and its gradient.

    The loss is the mean over samples of ``-sum(y * log_softmax(X @ W))``
    plus ``lam`` times the squared Frobenius norm of ``W``.

    Args:
        W: 2-D weight matrix; ``np.dot(X, W)`` gives one row of scores per
            sample.
        X: Input batch; ``X.shape[0]`` is used as the number of samples.
        y: Targets with exactly the same shape as the scores. The gradient
            formula ``probs - y`` matches the loss only when every row of
            ``y`` sums to 1 (e.g. one-hot labels) -- presumably what callers
            pass; confirm against the call sites.
        lam: Regularization coefficient. Defaults to ``0.0`` (no
            regularization) so the function can be called with three
            arguments.

    Returns:
        Tuple ``(loss, grad)``: the scalar loss and an array shaped like
        ``W`` holding the gradient of the loss with respect to ``W``.

    Raises:
        AssertionError: If the scores and ``y`` have different shapes.
    """
    num_samples = X.shape[0]
    scores = np.dot(X, W)
    probs = softmax(scores, axis=1)
    assert probs.shape == y.shape, (
        f"y has shape {y.shape}, expected {probs.shape}"
    )

    # log_softmax is evaluated directly (rather than log(probs)) to stay
    # numerically stable when some probabilities underflow.
    data_loss = -np.sum(y * log_softmax(scores, axis=1)) / num_samples
    penalty = lam * np.linalg.norm(W, ord='fro')**2
    loss = data_loss + penalty

    # Gradient of the mean cross-entropy w.r.t. the scores, then chain rule
    # through scores = X @ W, plus the derivative of the penalty term.
    dscores = (probs - y) / num_samples
    grad = np.dot(X.T, dscores) + 2 * lam * W
    return loss, grad