import numpy as np
import torch

def eval_sindy_regressor(regressor, truth, model, threshold=0.05):
    """Score a fitted SINDy-style regressor against ground-truth coefficients.

    The regressor's coefficients ``W`` are zeroed outside its ``mask`` and then
    compared row by row (one row per equation) with ``truth``.

    Args:
        regressor: Object exposing tensors ``W`` (coefficients) and ``mask``
            (nonzero entries mark active library terms).
        truth: Ground-truth coefficient array whose nonzero entries define the
            true support. Must have the same 2-D shape as the (possibly
            extended) coefficient matrix.
        model: Model name. For ``'di-sindy'`` a final, always-active column
            with coefficient ``-1.0`` is appended and the coefficients are
            reshaped to a single row.
        threshold: Not used by this function; kept for call compatibility.

    Returns:
        ``(coef, correct_form, mse, correct_form_all, mse_all)``:

        * ``coef``: the masked coefficient array.
        * ``correct_form``: float array, 1.0 for rows whose learned support
          equals the true support and 0.0 otherwise.
        * ``mse``: per-row mean squared error over the true-support terms
          only. Active terms absent from ``truth`` do not contribute. A row
          with no nonzero truth entry yields nan (mean of an empty slice).
        * ``correct_form_all``: True when every row has the correct form.
        * ``mse_all``: the mean of ``mse``.

    Raises:
        ValueError: If the coefficients are not 2-D, or ``mask``/``truth`` do
            not have exactly the coefficient matrix's shape.
    """
    with torch.no_grad():
        # detach() keeps .numpy() safe even if W is a parameter requiring grad.
        coef = regressor.W.detach().cpu().numpy()
        mask = regressor.mask.detach().bool().cpu().numpy()

    if model == 'di-sindy':
        coef = np.append(coef, -1.0).reshape(1, -1)
        # Append True, not 1.0: np.append would otherwise promote the boolean
        # mask to float64.
        mask = np.append(mask, True).reshape(1, -1)

    truth = np.asarray(truth)
    # Fail loudly instead of relying on broadcasting/indexing errors (or, on
    # older NumPy, a silent scalar False from mismatched `==`).
    if coef.ndim != 2 or mask.shape != coef.shape or truth.shape != coef.shape:
        raise ValueError(
            f"shape mismatch: coef {coef.shape}, mask {mask.shape}, "
            f"truth {truth.shape} (expected identical 2-D shapes)"
        )

    coef = np.where(mask, coef, 0.0)
    truth_mask = truth != 0

    # A row has the correct form when learned and true supports coincide.
    correct_form = np.all(mask == truth_mask, axis=1).astype(float)
    # Error is measured only on the terms that are present in the truth.
    mse = np.array(
        [np.mean((c[t] - r[t]) ** 2) for c, r, t in zip(coef, truth, truth_mask)],
        dtype=float,
    )
    correct_form_all = np.all(correct_form)
    mse_all = np.mean(mse)

    return coef, correct_form, mse, correct_form_all, mse_all

# Ground-truth coefficients for the 'di-sindy' model: one (1, 5) float64 row
# per equation name. Every row ends in -1.0, which appears to correspond to
# the fixed -1.0 column eval_sindy_regressor appends for model == 'di-sindy'.
# NOTE(review): the column-to-term mapping comes from the candidate library,
# which is not defined here — confirm before editing entries.
di_sindy_truth = {
    name: np.array([row])
    for name, row in (
        ('KdV', [0.0, 0.0, -1.0, 0.0, -1.0]),
        ('KS', [0.0, -1.0, 0.0, -1.0, -1.0]),
        ('Burgers', [0.0, 0.01, 0.0, 0.0, -1.0]),
        ('nKdV', [0.0, 0.0, -1.0, 0.0, -1.0]),
    )
}

# Ground-truth coefficients for the equivariant SINDy model: one single-row
# float64 array per equation name, over the full candidate library.
# NOTE(review): the column-to-term mapping comes from the candidate library,
# which is not defined here — confirm before editing entries.
equivsindy_r_truth = {
    name: np.array([row])
    for name, row in (
        ('KdV', [0.0, 0.0, 0.0, -1.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]),
        ('KS', [0.0, 0.0, -1.0, 0.0, -1.0, 0.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]),
        ('Burgers', [0.0, 0.0, 0.01, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]),
        ('nKdV', [0.0, 0.0, 0.0, -1.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]),
    )
}
