import numpy as np

from models.LinearModel import get_error

from sklearn.linear_model import Ridge
from sklearn.linear_model import LinearRegression


def linear_regression(dataset):
    """Fit an ordinary least-squares model (no intercept) and report its fit.

    Args:
        dataset: object exposing ``data`` (design matrix) and ``labels``
            (targets), each supporting ``.cpu().detach().numpy()``
            (presumably torch tensors -- confirm against callers).

    Returns:
        ``(params, resid, err)``:

        - ``params`` is the fitted ``coef_`` of sklearn's ``LinearRegression``.
        - ``resid`` is ``labels - prediction`` on the training data.
        - ``err`` is ``get_error(prediction, labels)``.

    Side effects:
        Prints the max/min residual and the error to stdout.
    """
    features, targets = (
        tensor.cpu().detach().numpy() for tensor in (dataset.data, dataset.labels)
    )

    # fit() returns the estimator itself, so construction and fitting chain.
    ols = LinearRegression(fit_intercept=False).fit(features, targets)

    predictions = ols.predict(features)
    residuals = targets - predictions
    error = get_error(predictions, targets)

    print('Resid: ', residuals.max(), residuals.min())
    print('Err: ', error)

    return ols.coef_, residuals, error


def ridge_regression(dataset):
    """Fit a ridge regression (tiny L2 penalty, no intercept) and report its fit.

    Args:
        dataset: object exposing ``data`` (design matrix) and ``labels``
            (targets), each supporting ``.cpu().detach().numpy()``
            (presumably torch tensors -- confirm against callers).

    Returns:
        ``(params, resid, err)``:

        - ``params`` is the fitted ``coef_`` of sklearn's ``Ridge``.
        - ``resid`` is ``labels - prediction`` on the training data.
        - ``err`` is ``get_error(prediction, labels)``.

    Side effects:
        Prints the max/min residual and the error to stdout.
    """
    X = dataset.data.cpu().detach().numpy()
    Y = dataset.labels.cpu().detach().numpy()

    # alpha=1e-9 is a very small penalty; presumably chosen to approximate an
    # unregularised fit while using Ridge's solvers -- confirm intent.
    ridge = Ridge(alpha=1e-9, tol=1e-9, fit_intercept=False)
    ridge.fit(X, Y)
    params = ridge.coef_
    y_pred = ridge.predict(X)
    resid = Y - y_pred
    # Pass (prediction, target), the same order as the other get_error calls
    # in this module, so the reported errors are comparable.
    # NOTE(review): verify against get_error's actual signature.
    err = get_error(y_pred, Y)

    print('resid: ', resid.max(), resid.min())
    print('err: ', err)

    return params, resid, err

def lin_alg(dataset, solver, tol=1e-1):
    """Solve the diagonally shifted square system ``(X + tol*I) @ params = Y``.

    Args:
        dataset: object exposing ``data`` and ``labels``, each supporting
            ``.cpu().detach().numpy()`` (presumably torch tensors -- confirm).
            ``data`` must be square because an identity of size
            ``data.shape[0]`` is added to it.
        solver: currently unused; kept so existing callers keep working.
        tol: amount added to the diagonal of ``data`` before solving.

    Returns:
        ``(params, resid, err)``:

        - ``params`` solves the shifted system.
        - ``resid`` is ``Y - X @ params``, using the *original*, unshifted X.
        - ``err`` is ``get_error(X @ params, Y)``.

    Raises:
        numpy.linalg.LinAlgError: if the shifted matrix is singular or not square.

    Side effects:
        Prints log10 of the shifted matrix's condition number to stdout.
    """
    x = dataset.data.cpu().detach().numpy()
    y = dataset.labels.cpu().detach().numpy()

    # Diagonal shift (Tikhonov-style damping) applied to X itself, not the
    # normal-equation form (X^T X + tol*I). It is only well-defined for a
    # square X, e.g. a kernel/Gram matrix as in kernel ridge regression
    # (assumption -- confirm). Keep the original x for evaluating the fit.
    x_reg = x + np.eye(x.shape[0]) * tol

    print("Cond:", np.log10(np.linalg.cond(x_reg)))

    # solve() avoids forming the explicit inverse: cheaper and more accurate.
    params = np.linalg.solve(x_reg, y)

    # Evaluate against the unshifted data. Against x_reg the residual would be
    # ~0 by construction and say nothing about the fit.
    y_pred = x @ params
    resid = y - y_pred
    err = get_error(y_pred, y)
    return params, resid, err
