"""Evaluation code (metrics, etc.)"""

import numpy as np
from src import data


def get_ols_error(X, Y, Xt, Yt, batch_size):
  """Average per-dimension squared error of an ordinary-least-squares baseline.

  For each of the first ``batch_size`` examples, fits weights by least
  squares (``np.linalg.lstsq``) on that example's ``X``/``Y`` and predicts
  the label of its query point ``Xt``. The squared prediction error is
  divided by ``Xt.shape[1]``, then averaged over ``batch_size``.

  Presumably X is (batch, n_points, dim), Y is (batch, n_points), Xt is
  (batch, dim) and Yt is (batch,) -- TODO confirm against callers.

  Args:
    X: Per-example design matrices, indexed as ``X[i, :, :]``.
    Y: Per-example targets, indexed as ``Y[i, :]``.
    Xt: Per-example query inputs; ``Xt.shape[1]`` is the normalizer.
    Yt: Per-example query targets, indexed as ``Yt[i]``.
    batch_size: Number of leading examples to evaluate.

  Returns:
    Mean over examples of (squared query error / Xt.shape[1]).

  Raises:
    ZeroDivisionError: if ``batch_size`` is 0.
  """
  n_features = Xt.shape[1]
  total = 0.0
  for sample in range(batch_size):
    # lstsq returns (solution, residuals, rank, singular_values); keep the solution.
    coef, *_ = np.linalg.lstsq(X[sample, :, :], Y[sample, :], rcond=None)
    residual = coef @ Xt[sample, :] - Yt[sample]
    # Per-dimension squared error for this example.
    total += residual ** 2 / n_features
  return total / batch_size


def get_model_error(trained_model, test_prompt, test_labels):
  """Per-dimension mean squared error of a model's final-position prediction.

  Calls ``trained_model(test_prompt, training=False)`` and takes
  ``[:, -1, 0]`` (last sequence position, first output channel) as the
  predicted label for each prompt. The squared L2 norm of the residual is
  divided by ``test_prompt.shape[2] * len(test_labels)``.

  Presumably test_prompt is (batch, seq_len, width) and test_labels is 1-D
  with one entry per prompt -- TODO confirm. A (batch, 1) label array would
  silently broadcast against the (batch,) predictions.

  NOTE(review): the normalizer is the prompt's last-axis width. If prompt
  tokens carry an extra label channel, this differs from the covariate
  dimension that get_ols_error divides by (Xt.shape[1]) -- confirm.

  Args:
    trained_model: Callable accepting ``(prompt, training=False)``.
    test_prompt: Batch of prompts; ``shape[2]`` is used as the dimension.
    test_labels: Target labels, one per prompt.

  Returns:
    ||predictions - test_labels||^2 / (dim * len(test_labels)).
  """
  outputs = trained_model(test_prompt, training=False)
  predictions = outputs[:, -1, 0]
  residuals = predictions - test_labels
  squared_norm = np.linalg.norm(residuals) ** 2
  denominator = test_prompt.shape[2] * len(test_labels)
  return squared_norm / denominator

def save_weights(model, fpath):
  """Persist ``model``'s weights to ``fpath``.

  Thin wrapper that delegates to ``model.save_weights(fpath)``. The file
  format and overwrite behavior are whatever that method implements.
  Returns None.
  """
  writer = model.save_weights
  writer(fpath)

def load_weights(model, fpath):
  """Restore ``model``'s weights from ``fpath``.

  Thin wrapper that delegates to ``model.load_weights(fpath)``. Any
  shape/format checking is done by that method. Returns None.
  """
  reader = model.load_weights
  reader(fpath)