# Initialize parameters (weights and biases)
def init_params(key, inputsz, outputsz):
    """Draw random initial parameters for a single linear layer.

    Args:
        key: PRNG key; it is split once so that the weights and the biases
            are sampled from separate subkeys.
        inputsz: Number of rows of the weight matrix.
        outputsz: Number of columns of the weight matrix and length of the
            bias vector.

    Returns:
        Tuple ``(W, b)``. ``W`` has shape ``(inputsz, outputsz)`` and holds
        normal samples scaled by 0.01; ``b`` has shape ``(outputsz,)`` and
        holds unscaled normal samples.
    """
    weight_key, bias_key = random.split(key)
    weight_shape = (inputsz, outputsz)
    bias = random.normal(bias_key, (outputsz,))
    weights = 0.01 * random.normal(weight_key, weight_shape)
    return weights, bias
#########################################################################################################
# Define the model: simple linear function followed by softmax
def predict(params, x):
    """Return softmax probabilities of a linear model.

    Args:
        params: Pair ``(W, b)`` holding the weight matrix and bias vector.
        x: Input array that is multiplied with ``W`` through ``jnp.dot``.

    Returns:
        ``jax.nn.softmax`` applied to ``jnp.dot(x, W) + b`` (softmax uses its
        default axis, i.e. the last one).
    """
    weights, bias = params
    affine = jnp.dot(x, weights)
    affine = affine + bias
    return jax.nn.softmax(affine)

# Cross-entropy loss function with regularization
def cross_entropy_loss(params, x, y, mu, sigma, t, p, reg, lambda_reg):
    """Batch-mean cross-entropy of the linear-softmax model plus a penalty.

    Args:
        params: Pair ``(W, b)`` of weights and biases.
        x: Batch of inputs.
        y: Batch of targets; multiplied elementwise with the log-probabilities
            and summed over axis 1 (presumably one-hot rows -- confirm with
            the caller).
        mu: Forwarded unchanged to ``our_reg_loss``.
        sigma: Forwarded unchanged to ``our_reg_loss``.
        t: Forwarded unchanged to ``our_reg_loss``.
        p: Forwarded to both ``reg_loss_lp`` and ``our_reg_loss``.
        reg: Regulariser selector: ``1`` keeps ``reg_loss_lp``, ``2`` keeps
            ``our_reg_loss``, any other value means no penalty.
        lambda_reg: Regularisation strength forwarded to both penalties.

    Returns:
        The mean over the batch of the per-sample cross-entropy, plus the
        selected regularisation term.
    """
    W, b = params

    # Take log-probabilities straight from the logits. log(softmax(z))
    # becomes -inf as soon as a probability underflows to 0, and the
    # product 0 * -inf with the targets then turns the loss (and its
    # gradient) into NaN. log_softmax stays finite.
    logits = jnp.dot(x, W) + b  # same affine map that predict() applies
    log_probs = jax.nn.log_softmax(logits)

    ce_loss = -jnp.sum(y * log_probs, axis=1)  # shape (batchsize,)
    ce_loss_batch = jnp.mean(ce_loss)  # scalar

    # jnp.where is a select, not a branch: Python evaluates both penalty
    # calls on every invocation and `reg` only decides which value is kept.
    # Readable equivalent of the selection below:
    #   reg == 2  -> our_reg_loss(p, W, b, x, y, mu, sigma, t, lambda_reg)
    #   reg == 1  -> reg_loss_lp(p, W, lambda_reg)
    #   otherwise -> 0
    reg_loss = jnp.where(
        jnp.equal(reg, 2),
        our_reg_loss(p, W, b, x, y, mu, sigma, t, lambda_reg),
        jnp.where(
            jnp.equal(reg, 1),
            reg_loss_lp(p, W, lambda_reg),
            0,  # reg == 0 and any unrecognised value: no penalty
        ),
    )
    return ce_loss_batch + reg_loss

def accuracy(params, x, y):
    """Fraction of rows whose predicted class equals the target class.

    Args:
        params: Model parameters handed to ``predict``.
        x: Batch of inputs handed to ``predict``.
        y: Batch of targets; the class of each row is its argmax along axis 1.

    Returns:
        Mean of the elementwise equality between the argmax (axis 1) of the
        model output and the argmax (axis 1) of ``y``.
    """
    probabilities = predict(params, x)
    hits = jnp.argmax(probabilities, axis=1) == jnp.argmax(y, axis=1)
    return jnp.mean(hits)

@jit
def train_step(params, x, y, mu, sigma, t, p, reg, learning_rate, lambda_reg):
    """Apply one plain gradient-descent update to ``(W, b)``.

    Args:
        params: Pair ``(W, b)`` of current weights and biases.
        x: Input batch.
        y: Target batch.
        mu, sigma, t, p, reg, lambda_reg: Forwarded unchanged to
            ``cross_entropy_loss``.
        learning_rate: Step size multiplying each gradient.

    Returns:
        Tuple ``(new_W, new_b)`` where each entry is the old value minus
        ``learning_rate`` times its gradient.
    """
    weights, bias = params
    # grad differentiates with respect to the first argument, i.e. params,
    # so the gradient has the same (W, b) pair structure.
    loss_gradient = grad(cross_entropy_loss)
    grad_weights, grad_bias = loss_gradient(
        params, x, y, mu, sigma, t, p, reg, lambda_reg
    )
    return (
        weights - learning_rate * grad_weights,
        bias - learning_rate * grad_bias,
    )

# Training loop
def train(key, params, x_train, y_train, x_test, y_test, mu, sigma, t, p, reg, epochs=100, batch_size=32, learning_rate=0.1, lambda_reg=0.01, runs=3):
    """Train with mini-batch gradient descent and record accuracies.

    Args:
        key: PRNG key used to shuffle the training set; a fresh subkey is
            split off for every epoch.
        params: Initial model parameters passed to ``train_step``.
        x_train: Training inputs; the first axis indexes samples.
        y_train: Training targets, indexed like ``x_train``.
        x_test: Test inputs passed to ``accuracy``.
        y_test: Test targets passed to ``accuracy``.
        mu, sigma, t, p, reg: Forwarded unchanged to ``train_step``.
        epochs: Number of passes over the training set per run.
        batch_size: Number of samples per mini-batch; the last batch of an
            epoch is smaller when the training-set size is not a multiple.
        learning_rate: Step size forwarded to ``train_step``.
        lambda_reg: Regularisation strength forwarded to ``train_step``.
        runs: Number of times the epoch loop is repeated.

    Returns:
        Tuple ``(params, train_accuracies, test_accuracies)``; both accuracy
        arrays have shape ``(runs, epochs)``.
    """
    num_train = x_train.shape[0]
    # Tables of per-run, per-epoch accuracies.
    train_accuracies = jnp.empty((runs, epochs))
    test_accuracies = jnp.empty((runs, epochs))

    # NOTE(review): params is not reset between runs, so every run after the
    # first continues from the previous run's trained parameters. Confirm
    # whether independent runs were intended.
    for run in range(runs):
        train_acc_run = jnp.empty(epochs)
        test_acc_run = jnp.empty(epochs)
        print("Run ", run + 1)
        for epoch in range(epochs):
            # Shuffle data. A new subkey is derived each epoch: passing the
            # same key again would reproduce the identical permutation
            # instead of drawing a fresh one.
            key, shuffle_key = random.split(key)
            perm = random.permutation(shuffle_key, num_train)
            x_train = x_train[perm]
            y_train = y_train[perm]

            # Mini-batch gradient descent
            for i in range(0, num_train, batch_size):
                x_batch = x_train[i:i + batch_size]
                y_batch = y_train[i:i + batch_size]
                params = train_step(params, x_batch, y_batch, mu, sigma, t, p, reg, learning_rate, lambda_reg)

            train_acc = accuracy(params, x_train, y_train)
            test_acc = accuracy(params, x_test, y_test)

            train_acc_run = train_acc_run.at[epoch].set(train_acc)
            test_acc_run = test_acc_run.at[epoch].set(test_acc)
            print(f'Epoch {epoch+1}, Train accuracy: {train_acc:.4f}, Test accuracy: {test_acc:.4f}')

        train_accuracies = train_accuracies.at[run].set(train_acc_run)
        test_accuracies = test_accuracies.at[run].set(test_acc_run)
    return params, train_accuracies, test_accuracies
