from typing import Optional

import numpy as np
import torch as tc
import torch.nn as nn


def load_linear_params_from_numpy(
    linear_layer: nn.Linear,
    weight: np.ndarray,
    bias: Optional[np.ndarray] = None,
):
    """
    Loads given weights and biases of type numpy.ndarray into torch Linear module.

    The weight is given in (n_in, n_out) layout and is transposed into torch's
    (out_features, in_features) layout. Values are cast to the existing
    parameters' dtype and device by ``Tensor.copy_``. All inputs are validated
    before any parameter is modified, so a failed call leaves ``linear_layer``
    untouched.

    Args:
        linear_layer            pointer to torch.nn.Linear
        weight                  weight parameters of shape (n_in, n_out)
        bias                    bias parameters of shape (1, n_out) or (n_out,);
                                must be None exactly when the layer has no bias
                                (i.e. it was built with bias=False)

    Raises:
        ValueError              if weight/bias shapes do not match the layer, or if
                                a bias is given for a bias-less layer (or missing
                                for a layer that has one)
    """
    n_in, n_out = linear_layer.in_features, linear_layer.out_features

    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    weight = np.asarray(weight)
    if weight.shape != (n_in, n_out):
        raise ValueError(f"Expected weight shape {(n_in, n_out)} got {weight.shape}")

    if linear_layer.bias is None:
        if bias is not None:
            raise ValueError("linear_layer has no bias parameter, but a bias array was given")
    else:
        if bias is None:
            raise ValueError("linear_layer has a bias parameter, but no bias array was given")
        bias = np.asarray(bias)
        if bias.shape not in ((1, n_out), (n_out,)):
            raise ValueError(f"Expected bias shape {(1, n_out)} or {(n_out,)} got {bias.shape}")

    with tc.no_grad():
        # tc.from_numpy rejects negative strides (e.g. flipped views); a
        # contiguous array avoids that. copy_ then copies into the parameter,
        # so the parameter never aliases the caller's NumPy memory.
        linear_layer.weight.copy_(tc.from_numpy(np.ascontiguousarray(weight.T)))
        if bias is not None:
            linear_layer.bias.copy_(tc.from_numpy(np.ascontiguousarray(bias.reshape(n_out))))
