import torch
import numpy as np

# Flatten a tensor into a 1-D NumPy array.
def to_flattened_numpy(x):
    """Convert a tensor into a one-dimensional NumPy array.

    The tensor is detached from the autograd graph and moved to the CPU
    before conversion, so inputs that require grad or live on an
    accelerator are accepted.

    Args:
        x: tensor of any shape, e.g. ``[n, d]``.

    Returns:
        NumPy array of shape ``[x.numel()]`` (``[n*d]`` for a 2-D input).
    """
    host_tensor = x.detach().cpu()
    # reshape(-1) collapses every dimension into one, same as flatten().
    flat = host_tensor.reshape(-1)
    return flat.numpy()

# Reshape a flat NumPy array back into a float32 tensor of the given shape.
def from_flattened_numpy(x, shape):
    """Unflatten a flat array into a float32 tensor of the given shape.

    Args:
        x: flat NumPy array holding ``prod(shape)`` elements. Any array-like
            (list, tuple, strided or reversed array view) is also accepted.
        shape: target shape of the returned tensor, e.g. ``(n, d)``.

    Returns:
        A ``torch.float32`` CPU tensor of shape ``shape``. If ``x`` is already
        a C-contiguous float32 ndarray, the tensor shares memory with it (as
        ``torch.from_numpy`` does); otherwise it may be a fresh copy.

    Raises:
        ValueError: if ``x`` cannot be reshaped to ``shape``.
    """
    # torch.from_numpy rejects arrays with negative strides (e.g. x[::-1]) and
    # does not accept plain Python sequences, so normalise to a contiguous
    # ndarray first. For already-contiguous arrays this returns x itself.
    arr = np.ascontiguousarray(x).reshape(shape)
    return torch.from_numpy(arr).float()

def get_score_fn(model, if_training):
    """Put ``model`` into the mode matching ``if_training`` and return it.

    Args:
        model: a ``torch.nn.Module``.
        if_training: truthy to put the model (and all submodules) into
            training mode, falsy to put it into evaluation mode.

    Returns:
        The same ``model`` object. It is the module itself, not a wrapper
        function, despite the name.
    """
    # Set the mode explicitly in both directions. Previously only eval() was
    # applied, so a model left in eval mode stayed there during training.
    # Module.train() requires a real bool, hence the conversion.
    model.train(bool(if_training))
    return model