# %%
import numpy as np
import utils as u
import matplotlib.pyplot as plt


# %%
def weighted_linear_regression_stacked(X_list, y_list, omega, ridge_lambda=0.0):
    """
    Fit one linear model on several tasks, each task carrying its own weight.

    Every task j is rescaled by sqrt(omega_j / n_j) before all tasks are
    stacked, so the least-squares objective that is minimised is
    sum_j (omega_j / n_j) * ||X_j @ theta - y_j||^2, optionally plus
    ridge_lambda * ||theta||^2.

    Parameters
    ----------
    X_list : list of np.array
        Design matrices X_j, each of shape (n_j, d)
    y_list : list of np.array
        Target vectors y_j, each of shape (n_j,)
    omega : array-like
        One weight per task; a task with weight 0 contributes only zero rows
    ridge_lambda : float, optional
        Ridge penalty (default=0, which selects the plain least-squares solve)

    Returns
    -------
    theta : np.array
        Estimated coefficients of shape (d,)
    """
    # A single zip over the three inputs, so iteration stops at the shortest.
    tasks = list(zip(X_list, y_list, omega))
    factors = [np.sqrt(w_task / X_task.shape[0]) for X_task, _, w_task in tasks]

    # Stacked, rescaled design matrix and target vector.
    design = np.vstack([c * X_task for c, (X_task, _, _) in zip(factors, tasks)])
    targets = np.concatenate([c * y_task for c, (_, y_task, _) in zip(factors, tasks)])

    if ridge_lambda != 0:
        # Ridge normal equations: (X^T X + lambda I) theta = X^T y
        penalised_gram = design.T @ design + ridge_lambda * np.eye(design.shape[1])
        return np.linalg.solve(penalised_gram, design.T @ targets)

    # No penalty: least-squares solution via lstsq (first element of its result).
    return np.linalg.lstsq(design, targets, rcond=None)[0]


# %%data generation
def generate_data(
    d=20,
    n=10,
    T=100,
    prop=0.5,
    sig_coef=0.0,
    ntest=1000,
    random_state=None,
    randomcoef=False,
):
    """
    Draw synthetic regression tasks split into two groups with opposite
    coefficient vectors, plus a test set that follows task 0.

    Task j has features drawn from a normal distribution with mean 1 and
    identity covariance, and targets
    y_j = X_j @ coef[classind[j]] + X_j @ noise_coef[j] + standard normal noise.

    Parameters
    ----------
    d : int
        Feature dimension
    n : int
        Number of samples per task
    T : int
        Number of tasks
    prop : float
        Probability that a task is assigned to group 1 (task 0 is always
        forced into group 0)
    sig_coef : float
        Share of the coefficient energy given to the per-task perturbation:
        the group vectors are scaled by sqrt(1 - sig_coef) and each row of
        noise_coef has norm sqrt(sig_coef * d). Values outside [0, 1] make
        one of the square roots NaN.
    ntest : int
        Number of test samples
    random_state : int, optional
        If given, seeds the global np.random generator
    randomcoef : bool
        If True the group-0 vector points in a random direction, otherwise
        it lies along the first coordinate; in both cases it has norm
        sqrt(d) before the sqrt(1 - sig_coef) scaling and group 1 is its
        negation

    Returns
    -------
    X_list : list of ndarray, each (n, d)
    y_list : list of ndarray, each (n,)
    classind : ndarray, shape (T,)
    coef : ndarray, shape (2, d)
    noise_coef : ndarray, shape (T, d)
    Xtest : ndarray, shape (ntest, d)
    ytest : ndarray, shape (ntest,)
    """
    if random_state is not None:
        np.random.seed(random_state)

    # Group coefficient vectors: row 0 for group 0, row 1 its negation.
    coef = np.zeros((2, d))
    if not randomcoef:
        coef[0, 0] = np.sqrt(d)
        coef[1, 0] = -np.sqrt(d)
    else:
        direction = np.random.multivariate_normal(np.zeros(d), np.eye(d))
        coef[0, :] = np.sqrt(d) * direction / np.linalg.norm(direction)
        coef[1, :] = -coef[0, :]
    coef = np.sqrt(1 - sig_coef) * coef

    # Per-task perturbation: random directions rescaled to a common norm.
    raw_shift = np.random.multivariate_normal(np.zeros(d), np.eye(d), size=T)
    shift_scale = np.sqrt(sig_coef) * np.sqrt(d)
    noise_coef = shift_scale * raw_shift / np.linalg.norm(raw_shift, axis=1)[:, None]

    # Group membership; task 0 is pinned to group 0.
    classind = np.random.binomial(1, prop, size=T)
    classind[0] = 0

    X_list = []
    y_list = []
    for group, task_shift in zip(classind, noise_coef):
        features = np.random.multivariate_normal(np.ones(d), np.eye(d), size=n)
        label_noise = np.random.normal(0, 1, size=n)
        y_list.append(features @ coef[group] + features @ task_shift + label_noise)
        X_list.append(features)

    # The test set uses the exact parameter vector of task 0.
    Xtest = np.random.multivariate_normal(np.ones(d), np.eye(d), size=ntest)
    theta_task0 = coef[classind[0]] + noise_coef[0]
    ytest = Xtest @ theta_task0 + np.random.normal(0, 1, size=ntest)

    return X_list, y_list, classind, coef, noise_coef, Xtest, ytest


# %%
def compute_rff_from_Xy(X, y, D, SigmaRFF=None):
    """
    Build per-task [X | y] matrices and pass them to u.RFF.

    Parameters
    ----------
    X : ndarray, shape (T, n, d)
    y : ndarray, shape (T, n, 1) or (T, n)
        A trailing singleton axis is dropped before use
    D : int
        Number of random features, forwarded to u.RFF
    SigmaRFF : ndarray, shape (d+1, d+1), optional
        Forwarded to u.RFF as `variance`. When omitted, a diagonal matrix is
        used with 1 / (d + 1) on the d feature entries and 1 on the target
        entry.

    Returns
    -------
    RFFs : ndarray
        First of the three values returned by u.RFF
    """
    n_tasks = X.shape[0]
    d = X.shape[-1]

    targets = y[..., 0] if y.ndim == 3 else y

    # One (n, d + 1) matrix per task: the features with the target appended
    # as the last column.
    XY = [
        np.concatenate((X[t], targets[t][:, None]), axis=1) for t in range(n_tasks)
    ]

    if SigmaRFF is None:
        diagonal = np.full(d + 1, 1 / (d + 1))
        diagonal[d] = 1
        SigmaRFF = np.diag(diagonal)

    rff_features, _, _ = u.RFF(XY, D, variance=SigmaRFF)
    return rff_features


# %%
def compute_qaggregation_weights(RFFs, T, M=1, u0=None):
    """
    Call u.Qaggregation on the RFFs with constants derived from u0.

    Parameters
    ----------
    RFFs : ndarray or list
        Random Fourier Features (output of u.RFF)
    T : int
        Number of tasks; only used to build the default u0
    M : int, optional
        Aggregation parameter forwarded to u.Qaggregation (default: 1)
    u0 : float, optional
        Level from which both constants are derived: c0 = sqrt(u0) and
        cbs = u0. Defaults to log(T).

    Returns
    -------
    weights : ndarray
        Q-aggregation weights, exactly as returned by u.Qaggregation
    """
    level = np.log(T) if u0 is None else u0
    return u.Qaggregation(RFFs, M, c0=np.sqrt(level), cbs=level)


# %%
# Experiment configuration (read by the sweep over `sigmas` below).
d = 20  # feature dimension
n = 10  # samples per task
T = 100  # number of tasks
D = 500  # number of random features passed to u.RFF
prop = 0.5  # probability that a task is assigned to group 1
sigmas = np.logspace(-2, 0, 21)  # sig_coef grid; alternative tried: np.linspace(0, 1, 11)
n_repeats = 50  # independent repetitions averaged for each sigma
ridge = 0  # ridge_lambda for every regression; 0 selects the plain least-squares branch
ntest = 1000  # number of test samples

# Per-sigma mean test MSE of each estimator (one entry appended per sigma).
test_errors_fed = []
test_errors_naive = []
test_errors_gm = []
test_errors_or = []
# Per-sigma averages of three summaries of the Q-aggregation weights.
weights_fed_list = []

# Per-sigma spread of the squared test residuals of each estimator.
test_std_fed = []
test_std_naive = []
test_std_gm = []
test_std_or = []

# Sweep over the perturbation level sig_coef. For each value, run n_repeats
# independent experiments and average the test metrics of four estimators:
# Q-aggregation weights ("fed"), task 0 only ("naive"), uniform weights over
# all tasks ("gm") and uniform weights over the tasks of task 0's group ("or").
for sig_coef in sigmas:
    # Running sums over repetitions; divided by n_repeats after the inner loop
    fed_mse_sum = 0.0
    naive_mse_sum = 0.0
    gm_mse_sum = 0.0
    or_mse_sum = 0.0

    fed_std_sum = 0.0
    naive_std_sum = 0.0
    gm_std_sum = 0.0
    or_std_sum = 0.0
    # [weight of task 0, mean weight of the other group-0 tasks,
    #  mean weight of the group-1 tasks]
    weight_aux = np.zeros(3)

    for i in range(n_repeats):
        # =========================
        # 1. Generating data
        # =========================
        # No random_state is passed, so every repetition draws fresh data
        # from the current global np.random state.
        X_list, y_list, classind, coef, noise_coef, Xtest, ytest = generate_data(
            d=d, n=n, T=T, prop=prop, sig_coef=sig_coef, ntest=ntest, randomcoef=True
        )

        # =========================
        # 2. RFF
        # =========================
        # One (n, d + 1) matrix per task: features with the target appended
        # as the last column.
        XY = [
            np.concatenate([Xj, yj[:, None]], axis=1) for Xj, yj in zip(X_list, y_list)
        ]

        # Diagonal matrix: 1 / (d + 1) on the feature entries, 1 on the
        # target entry.
        SigmaRFF = np.eye(d + 1) / (d + 1)
        SigmaRFF[d, d] = 1

        RFFs, _, _ = u.RFF(XY, D, variance=SigmaRFF)

        # =========================
        # 3. Q-aggregation
        # =========================
        weights_fed = compute_qaggregation_weights(RFFs, T)
        weight_aux[0] += weights_fed[0]
        # Masks are built on tasks 1..T-1 so task 0 is excluded from both
        # group averages.
        # NOTE(review): if a group has no task besides task 0, the mean is
        # taken over an empty selection and yields nan -- confirm this cannot
        # matter for the chosen T and prop.
        weight_aux[1] += np.mean(weights_fed[1:][(1 - classind)[1:].astype(bool)])
        weight_aux[2] += np.mean(weights_fed[1:][classind[1:].astype(bool)])

        # =========================
        # 4. Weighted linear regression
        # =========================
        theta_hat_fed = weighted_linear_regression_stacked(
            X_list, y_list, weights_fed, ridge_lambda=ridge
        )
        ypred_fed = Xtest @ theta_hat_fed
        fed_mse_sum += np.mean((ypred_fed - ytest) ** 2)
        # Standard deviation of the squared residuals across the test points
        fed_std_sum += np.std((ypred_fed - ytest) ** 2)

        # =========================
        # Naive approach
        # =========================
        # All the weight on task 0: the fit uses task 0's samples only.
        weights_naive = np.zeros(T)
        weights_naive[0] = 1
        theta_hat_naive = weighted_linear_regression_stacked(
            X_list, y_list, weights_naive, ridge_lambda=ridge
        )
        ypred_naive = Xtest @ theta_hat_naive
        naive_mse_sum += np.mean((ypred_naive - ytest) ** 2)
        naive_std_sum += np.std((ypred_naive - ytest) ** 2)

        # =========================
        # Grand mean
        # =========================
        # Uniform weights over all tasks, regardless of their group.
        weights_gm = np.ones(T) / T
        theta_hat_gm = weighted_linear_regression_stacked(
            X_list, y_list, weights_gm, ridge_lambda=ridge
        )
        ypred_gm = Xtest @ theta_hat_gm
        gm_mse_sum += np.mean((ypred_gm - ytest) ** 2)
        gm_std_sum += np.std((ypred_gm - ytest) ** 2)

        # =========================
        # Oracle
        # =========================
        # Uniform weights over the tasks with classind == 0, i.e. the group
        # of task 0 (generate_data forces classind[0] = 0), using the true
        # group labels.
        weights_or = (1 - classind) / np.sum(1 - classind)
        theta_hat_or = weighted_linear_regression_stacked(
            X_list, y_list, weights_or, ridge_lambda=ridge
        )
        ypred_or = Xtest @ theta_hat_or
        or_mse_sum += np.mean((ypred_or - ytest) ** 2)
        or_std_sum += np.std((ypred_or - ytest) ** 2)

    # Compute mean over repetitions
    test_errors_fed.append(fed_mse_sum / n_repeats)
    test_errors_naive.append(naive_mse_sum / n_repeats)
    test_errors_gm.append(gm_mse_sum / n_repeats)
    test_errors_or.append(or_mse_sum / n_repeats)
    weights_fed_list += [weight_aux / n_repeats]

    test_std_fed.append(fed_std_sum / n_repeats)
    test_std_naive.append(naive_std_sum / n_repeats)
    test_std_gm.append(gm_std_sum / n_repeats)
    test_std_or.append(or_std_sum / n_repeats)

    # Progress line for this sigma (only two of the four estimators are shown)
    print(
        f"sigma_coef={sig_coef:.2f} | Fed MSE={test_errors_fed[-1]:.4f} | Naive MSE={test_errors_naive[-1]:.4f}"
    )


# %%
# Convert the per-sigma results collected by the experiment loop into arrays
# (one entry per value in `sigmas`) so they support element-wise arithmetic.
test_errors_fed = np.array(test_errors_fed)
test_errors_naive = np.array(test_errors_naive)
test_errors_gm = np.array(test_errors_gm)
test_errors_or = np.array(test_errors_or)

# Spread shown around each curve: for every sigma, the repetition-averaged
# standard deviation of the squared test residuals.
test_std_fed = np.array(test_std_fed)
test_std_naive = np.array(test_std_naive)
test_std_gm = np.array(test_std_gm)
test_std_or = np.array(test_std_or)

# NOTE: these per-sigma arrays must not be replaced by np.std(test_errors_*).
# That expression measures how much the mean MSE changes along the sigma grid
# and collapses to a single scalar, which would draw the same band width at
# every sigma instead of the spread measured at that sigma.
# %%
# Open the figure that the pyplot calls below draw onto.
plt.figure(figsize=(7, 4))


def plot_with_std(x, y, ystd, fmt, label, color):
    """
    Draw a curve on the current figure with a shaded band of half-width ystd.

    The band spans max(y - ystd, 0) to y + ystd, so it never goes below zero.
    Its two edges are additionally traced with faint lines. Only the main
    curve carries the legend label.
    """
    plt.plot(x, y, fmt, label=label, color=color)

    band_low = np.maximum(y - ystd, 0)  # lower edge, clipped at zero
    band_high = y + ystd  # upper edge
    plt.fill_between(x, band_low, band_high, color=color, alpha=0.2)

    # Outline the band: upper edge first, then lower edge.
    for edge in (band_high, band_low):
        plt.plot(x, edge, color=color, linewidth=1.5, alpha=0.2)


colors = ["C0", "C1", "C2", "C3"]

# One entry per estimator, in drawing order:
# (mean test errors, spread, marker/line format, legend label).
curves = [
    (test_errors_naive, test_std_naive, "o-", "Local"),
    (test_errors_fed, test_std_fed, "s-", "Fed / Q-aggregation"),
    (test_errors_gm, test_std_gm, "^-", "Grandmean"),
    (test_errors_or, test_std_or, "d-", "Neighbours (oracle)"),
]
for curve_color, (curve_mean, curve_std, curve_fmt, curve_label) in zip(colors, curves):
    plot_with_std(sigmas, curve_mean, curve_std, curve_fmt, curve_label, curve_color)

# Log-scaled x-axis to match the log-spaced sigma grid; the y-axis stays linear.
plt.xscale("log")
plt.xlabel(r"$\sigma^2_c$", fontsize=16)
plt.ylabel("Test Mean Square Error", fontsize=14)
plt.legend()
plt.grid(True)

# Fit the layout, write the figure to disk, then display it.
plt.tight_layout()
plt.savefig("Syn_concept_shift_mse_with_std_v18.pdf")
plt.show()
