import torch



def torchOLS(X_train, Y_train, X_test, Y_test, W, lambda_reg=0):
    """Fit one weighted ridge regression per test point and predict it.

    For every row ``w`` of ``W`` this solves the normal equations
    ``(X^T diag(w) X + lambda_reg * I) beta = X^T diag(w) y`` on the
    mean-centred training targets, then evaluates ``beta`` on the matching
    row of ``X_test`` and adds the training mean back.

    Notation: b = number of test points, n = number of training points,
    p = feature dimension.

    Args:
        X_train: Training features, shape (n, p).
        Y_train: Training targets, shape (n,).
        X_test: Test features, expected shape (b, p); row i is paired with
            the weights in ``W[i]``.
        Y_test: Unused. Kept in the signature so existing callers still work.
        W: Sample weights, one row per test point, shape (b, n).
        lambda_reg: Ridge penalty added to the diagonal of every
            ``X^T W X`` matrix. Defaults to 0 (plain weighted least squares).

    Returns:
        Predictions for the test points, shape (b,).
    """
    # Centre the targets with the (unweighted) training mean; it is added
    # back to the predictions at the end.
    y_mean = Y_train.mean()
    y_centered = Y_train - y_mean

    # X^T diag(w_b) for every test point b.
    XTW_batch = torch.einsum('np,bn->bpn', X_train, W)  # Shape: (b, p, n)

    XTWX_batch = torch.matmul(XTW_batch, X_train)  # Shape: (b, p, p)
    XTWy_batch = torch.matmul(XTW_batch, y_centered)  # Shape: (b, p)

    # Ridge term: better conditioned -> more stable solve. The condition
    # number shrinks as lambda_reg grows. The identity is created in the
    # same dtype as the Gram matrices so that lambda_reg is not rounded to
    # float32 when the inputs are float64.
    identity = torch.eye(
        XTWX_batch.size(-1),
        device=XTWX_batch.device,
        dtype=XTWX_batch.dtype,
    )
    XTWX_batch = XTWX_batch + lambda_reg * identity  # broadcasts over b

    # A is (b, p, p) and the right-hand side is (b, p), which
    # torch.linalg.solve treats as a batch of vectors.
    beta_batch = torch.linalg.solve(XTWX_batch, XTWy_batch)  # Shape: (b, p)

    # Row-wise dot product between each test point and its own coefficients.
    Y_hat_test = (X_test * beta_batch).sum(-1)

    # Undo the centring of the targets.
    return Y_hat_test + y_mean


class WeightedLinearRegression:
    """Weighted least-squares linear regression backed by ``torch.linalg.lstsq``."""

    def __init__(self, fit_intercept: bool = True):
        """
        Weighted linear regression with optional intercept.

        Args:
            fit_intercept: If True, augment features with a column of 1s to learn an intercept.
        """
        self.fit_intercept = fit_intercept
        # Learned coefficients; None until fit() has been called. When
        # fit_intercept is True the intercept is stored as the last entry.
        self.coef_ = None

    def fit(self, X_train, Y_train, weights=None):
        """
        Fit the weighted linear regression model.

        Minimises ``sum_i weights[i] * (Y_train[i] - X_train[i] @ coef)**2`` by
        scaling every row with ``sqrt(weights[i])`` and solving the resulting
        ordinary least-squares problem. The result is stored in ``self.coef_``.

        Args:
            X_train (torch.Tensor): Training features of shape (n_samples, n_features).
            Y_train (torch.Tensor): Training targets of shape (n_samples,).
            weights (torch.Tensor, optional): Sample weights of shape (n_samples,).
                Defaults to all ones (unweighted fit).
        """
        # Unpacking also rejects inputs that are not 2-D.
        n_samples, _ = X_train.shape

        if weights is None:
            weights = torch.ones(n_samples, device=X_train.device, dtype=X_train.dtype)

        # If fitting intercept, augment X_train with a column of ones.
        if self.fit_intercept:
            ones = torch.ones(n_samples, 1, device=X_train.device, dtype=X_train.dtype)
            X_train = torch.cat([X_train, ones], dim=1)
            # now X_train has shape (n_samples, n_features+1)

        # Scaling rows by sqrt(w) turns weighted least squares into ordinary
        # least squares on the scaled system.
        sqrt_weights = torch.sqrt(weights)  # (n_samples,)
        X_weighted = X_train * sqrt_weights.unsqueeze(1)
        Y_weighted = Y_train * sqrt_weights  # (n_samples,)

        # lstsq requires both operands to share a dtype; promote so that mixed
        # float32/float64 inputs (e.g. float64 weights on float32 data) work.
        common_dtype = torch.promote_types(X_weighted.dtype, Y_weighted.dtype)
        X_weighted = X_weighted.to(common_dtype)
        Y_weighted = Y_weighted.to(common_dtype)

        # self.coef_ will have shape (n_features,) or (n_features+1,) if intercept is included
        self.coef_ = torch.linalg.lstsq(X_weighted, Y_weighted).solution

    def predict(self, X_test):
        """
        Predict targets for the test set.

        Args:
            X_test (torch.Tensor): Test features of shape (n_samples, n_features).

        Returns:
            torch.Tensor: Predicted targets of shape (n_samples,).

        Raises:
            RuntimeError: If called before ``fit``.
        """
        if self.coef_ is None:
            raise RuntimeError(
                "WeightedLinearRegression.predict() called before fit()."
            )

        # The matrix product does not type-promote, so bring the features and
        # the coefficients to a common dtype first. This is a no-op when they
        # already agree.
        common_dtype = torch.promote_types(X_test.dtype, self.coef_.dtype)
        X_test = X_test.to(common_dtype)
        coef = self.coef_.to(common_dtype)

        # If we fitted an intercept, augment X_test with a column of ones, same as in fit
        if self.fit_intercept:
            ones_test = torch.ones(
                X_test.shape[0], 1, device=X_test.device, dtype=X_test.dtype
            )
            X_test = torch.cat([X_test, ones_test], dim=1)

        # Compute predictions: shape (n_samples,)
        return X_test @ coef
