import mxnet as mx
import numpy as np
from sklearn import metrics
from sklearn.linear_model import Ridge, RidgeCV


def solve_ridge_regression(
    F,
    X,
    y,
    lmbd,
    mode="inverse",
    precision="double",
    sample_weights=None,
    rescale=True,
    bias=None,
):
    """Solve a batch of independent ridge-regression problems.

    For every leading index ``i`` this fits coefficients ``w_i`` for the design
    ``X[i]`` and target ``y[i]``. In the ``"inverse"`` / ``"np_solver"`` modes it
    solves the normal equations
    ``(X_i^T X_i + lmbd_i * I) w_i = X_i^T y_i (+ lmbd_i * bias)``.
    The sklearn modes delegate to ``Ridge`` / ``RidgeCV`` instead.

    The body calls ``.asnumpy()``, ``.tostype`` and ``len`` on its inputs, so it
    only works with imperative NDArrays (presumably ``F`` is ``mx.nd``).

    Args:
        F: MXNet operator namespace (presumably ``mx.nd``).
        X: Rank-3 design tensor. ``X.shape[2]`` is the number of features, so it
            is assumed to be (batch, n_samples, n_features) — TODO confirm.
        y: Targets. Presumably (batch, n_samples, 1), given how they are
            broadcast against per-sample weights; verify against callers.
        lmbd: Regularisation strength. In the solver modes this is an NDArray of
            length 1 (shared) or ``batch``. In ``"sklearn"`` mode it is passed
            as ``Ridge(alpha=...)``. ``"sklearn_cv"`` ignores it.
        mode: One of ``"inverse"``, ``"np_solver"``, ``"sklearn"``,
            ``"sklearn_cv"``.
        precision: ``"double"`` computes in float64 and casts the result to
            float32. Any other value computes in the input dtype and returns
            the result as computed.
        sample_weights: Optional per-sample weights, indexed like ``X[:, j]``.
        rescale: If true, divide ``X`` and ``y`` of each problem by the mean
            absolute target. A scale of 0 is replaced by 1.
        bias: Optional coefficient vector to shrink towards instead of zero.
            Only supported in ``"inverse"`` mode.

    Returns:
        The stacked coefficients, concatenated along axis 0. Presumably
        (batch, n_features, 1).

    Raises:
        NotImplementedError: For an unknown ``mode``, or for ``bias`` with a
            mode other than ``"inverse"``.
    """
    if mode not in ("inverse", "np_solver", "sklearn", "sklearn_cv"):
        raise NotImplementedError(f"mode {mode!r} not implemented")
    if mode != "inverse" and bias is not None:
        raise NotImplementedError("bias can be used only with inverse mode")

    if mode != "inverse":
        # Keep only the samples (axis 1) whose feature row sums to non-zero.
        # NOTE(review): for batch > 1 the CSR ``indices`` of the (batch, n) sum
        # concatenate the column indices of every row, so they can contain
        # duplicates and mix series — confirm this is intended.
        nonzero = F.sum(X, axis=-1).tostype("csr").indices
        X = X[:, nonzero]
        y = y[:, nonzero]
        if sample_weights is not None:
            sample_weights = sample_weights[:, nonzero]

    if rescale:
        y_scale = F.mean(y.abs(), axis=1)
        # Series with an all-zero target keep scale 1 to avoid dividing by 0.
        y_scale = F.where(
            condition=y_scale > 0, x=y_scale, y=F.ones_like(y_scale)
        )
        scale = y_scale.expand_dims(-1)
        y = F.broadcast_div(y, scale)
        X = F.broadcast_div(X, scale)

    use_double = precision == "double"
    if use_double:
        X = F.cast(X, dtype="float64")
        y = F.cast(y, dtype="float64")

    if mode in ("inverse", "np_solver"):
        w = _solve_normal_equations(
            F, X, y, lmbd, mode, use_double, sample_weights, bias
        )
    elif mode == "sklearn":
        w = _fit_sklearn_ridge(F, X, y, lmbd, sample_weights)
    else:  # "sklearn_cv"
        w = _fit_sklearn_ridge_cv(F, X, y, sample_weights)

    if use_double:
        return F.cast(w, dtype="float32")
    # Non-double precision: return the solution in the dtype it was computed in.
    return w


def _solve_normal_equations(F, X, y, lmbd, mode, use_double, sample_weights, bias):
    """Solve ``(X^T X + lmbd*I) w = X^T y (+ lmbd*bias)`` for every batch entry."""
    if sample_weights is not None:
        # Weighted least squares via sqrt(weight) row scaling. Cast only when
        # X was promoted to float64, so the broadcast sees a single dtype.
        if use_double:
            sample_weights = F.cast(sample_weights, dtype="float64")
        root_weights = F.sqrt(sample_weights).expand_dims(-1)
        X = F.broadcast_mul(X, root_weights)
        y = F.broadcast_mul(y, root_weights)

    batch_size = y.shape[0]
    eye = F.eye(X.shape[2])
    if use_double:
        eye = F.cast(eye, dtype="float64")
        lmbd = F.cast(lmbd, dtype="float64")
        if bias is not None:
            # bias is multiplied with lmbd below, so it must match its dtype.
            bias = F.cast(bias, dtype="float64")

    if len(lmbd) == 1:
        # A single regularisation strength is shared by all problems.
        lmbd = F.repeat(lmbd, repeats=batch_size, axis=0)

    # Stack of lmbd_i * I, one identity per batch entry.
    lmbd_term = lmbd.expand_dims(1).expand_dims(1) * F.repeat(
        eye.expand_dims(0), repeats=batch_size, axis=0
    )
    XTX = F.linalg.gemm2(X, X, transpose_a=True) + lmbd_term
    XTy = F.linalg.gemm2(X, y, transpose_a=True)
    if bias is not None:
        # Shrink towards `bias` instead of zero: add lmbd_i * bias to X^T y.
        XTy = XTy + (
            lmbd.expand_dims(1)
            * F.repeat(bias.expand_dims(0), repeats=lmbd.shape[0], axis=0)
        ).expand_dims(-1)
    return solve(F, XTX, XTy, mode=mode)


def _fit_sklearn_ridge(F, X, y, lmbd, sample_weights):
    """Fit one sklearn ``Ridge`` with ``alpha=lmbd`` per batch entry and stack the coefs."""
    coefs = []
    for i in range(y.shape[0]):
        sw = sample_weights[i].asnumpy() if sample_weights is not None else None
        # No `normalize`: sklearn ignores it when fit_intercept=False, and
        # sklearn >= 1.2 no longer accepts it.
        clf = Ridge(alpha=lmbd, fit_intercept=False, solver="auto", tol=1e-24)
        clf.fit(X[i].asnumpy(), y[i].asnumpy(), sample_weight=sw)
        coefs.append(mx.nd.array(clf.coef_).expand_dims(-1))
    return F.concat(*coefs, dim=0)


def _fit_sklearn_ridge_cv(F, X, y, sample_weights):
    """Choose alpha per batch entry with sMAPE-scored ``RidgeCV``, then refit ``Ridge``."""
    reg_params = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1, 1e1, 1e2, 1e3, 1e4, 1e5]
    scorer = metrics.make_scorer(score_func=smape, greater_is_better=False)
    coefs = []
    for i in range(y.shape[0]):
        sw = sample_weights[i].asnumpy() if sample_weights is not None else None
        X_i = X[i].asnumpy()
        y_i = y[i].asnumpy()
        # cv=None selects sklearn's built-in leave-one-out CV for short series.
        n_folds = None if len(y_i) < 10 else 10
        cv_model = RidgeCV(
            alphas=reg_params,
            fit_intercept=False,
            cv=n_folds,
            scoring=scorer,
        )
        cv_model.fit(X_i, y_i, sample_weight=sw)
        clf = Ridge(
            alpha=cv_model.alpha_,
            fit_intercept=False,
            solver="auto",
            tol=1e-24,
        )
        clf.fit(X_i, y_i, sample_weight=sw)
        coefs.append(mx.nd.array(clf.coef_).expand_dims(-1))
    return F.concat(*coefs, dim=0)


def solve(F, A, b, mode="inverse"):
    """Solve the batched linear systems ``A x = b``.

    Args:
        F: MXNet operator namespace (presumably ``mx.nd``) providing ``linalg``.
        A: Batch of square system matrices.
        b: Right-hand sides, batched like ``A``.
        mode: ``"inverse"`` forms ``A^{-1}`` with ``F.linalg.inverse`` and
            multiplies it with ``b``. ``"np_solver"`` delegates to
            ``mx.np.linalg.solve``.

    Returns:
        The solution ``x``. In ``"np_solver"`` mode it is converted back to a
        legacy NDArray.

    Raises:
        NotImplementedError: If ``mode`` is not one of the two above.
    """
    if mode == "inverse":
        return F.linalg.gemm2(F.linalg.inverse(A), b)
    if mode == "np_solver":
        # mx.np.linalg.solve expects numpy-compatible arrays; convert in and
        # back out so callers keep receiving legacy NDArrays.
        return mx.np.linalg.solve(
            A.as_np_ndarray(), b.as_np_ndarray()
        ).as_nd_ndarray()
    raise NotImplementedError(f"mode {mode!r} not implemented")


def print_diff_precision_solve(F, A, b, mode="inverse"):
    """Print the norm of the gap between a float32 and a float64 solve of ``A x = b``.

    ``solve`` has no precision argument, so the precision is selected by
    casting ``A`` and ``b`` before each call.

    Args:
        F: MXNet operator namespace (presumably ``mx.nd``).
        A: Batch of square system matrices.
        b: Right-hand sides.
        mode: Solver mode forwarded to ``solve``.
    """
    x_f = solve(
        F, F.cast(A, dtype="float32"), F.cast(b, dtype="float32"), mode=mode
    )
    x_d = solve(
        F, F.cast(A, dtype="float64"), F.cast(b, dtype="float64"), mode=mode
    )
    # Subtract in float64 so both operands share a dtype.
    diff = x_d - F.cast(x_f, dtype="float64")
    print("||x_float- x_double|| =", F.norm(diff).asnumpy()[0])


def smape(y_true, y_pred):
    """Return the symmetric mean absolute percentage error, in percent.

    Computes ``200 * mean(|y_true - y_pred| / (|y_true| + |y_pred|))``.
    Terms whose denominator is zero (both values zero) contribute 0, so the
    result lies in ``[0, 200]`` for finite inputs.

    Args:
        y_true: Observed values (array-like).
        y_pred: Predicted values (array-like), broadcastable against ``y_true``.

    Returns:
        The sMAPE as a NumPy scalar.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    denominator = np.abs(y_true) + np.abs(y_pred)
    # 0/0 terms are zeroed just below. Silence the "invalid value" warning they
    # would otherwise emit on every call (this is used as a CV scorer).
    with np.errstate(divide="ignore", invalid="ignore"):
        diff = np.abs(y_true - y_pred) / denominator
    diff = np.where(denominator == 0, 0.0, diff)
    return 200 * np.mean(diff)
