import numpy as np
import jax
import jax.numpy as jnp
from jax import jit
from jax.scipy.optimize import minimize as jax_minimize
from jaxopt import OSQP
from sklearn.base import BaseEstimator, RegressorMixin
from tqdm import tqdm

import sys
sys.path.append("..")
from utils.kernel_utils import Kernel, RBF
from utils.linalg_utils import make_psd, cartesian_product, remove_diagonal_elements, columns_mean_excluding_self

from typing import Callable, Tuple, List, Optional, Union, Dict, Any

from jax import config
config.update("jax_enable_x64", True)


class KernelAlternativeProxyATE_Nystorm(BaseEstimator, RegressorMixin):
    """
    Kernel alternative-proxy estimator of a structural function whose ridge
    systems are solved on Nystrom landmark subsets.

    ``fit`` splits the training rows into a stage-1 and a stage-2 part, solves a
    landmark-restricted ridge system on the stage-1 product kernel (A * W * X),
    solves for the coefficient vector ``alpha`` on the stage-2 part, and finally
    solves a landmark-restricted ridge system on the A kernel over the whole
    training sample (third stage). ``predict`` combines these pieces for new
    values of A; ``_predict_bridge_func`` evaluates the fitted bridge function.
    """

    def __init__(self,
                 kernel_A: Kernel,
                 kernel_W: Kernel,
                 kernel_Z: Kernel,
                 kernel_X: Kernel = RBF(),
                 lambda_: float = 0.1,
                 eta: float = 0.1, 
                 lambda2_: float = 0.1,
                 nystrom_first_stage_m: int = 500,
                 nystrom_third_stage_m: int = 500,
                 optimize_eta_parameter: bool = True,
                 eta_optimization_range: Tuple[float, float] = (1e-7, 1.0),
                 **kwargs) -> None:
        """
        Initialize the KernelAlternativeProxyATE_Nystorm estimator.

        Parameters:
        - kernel_A (Kernel): Kernel function for variable A.
        - kernel_W (Kernel): Kernel function for variable W.
        - kernel_Z (Kernel): Kernel function for variable Z.
        - kernel_X (Kernel, optional): Kernel function for variable X. Defaults to RBF().
        - lambda_ (float, optional): Regularization parameter of the first stage ridge system. Defaults to 0.1.
        - eta (float, optional): Regularization parameter used when solving for alpha. Defaults to 0.1.
        - lambda2_ (float, optional): Regularization parameter of the third stage ridge system. Defaults to 0.1.
        - nystrom_first_stage_m (int, optional): Number of landmarks drawn from the stage 1 rows
          (capped at the stage 1 size). Defaults to 500.
        - nystrom_third_stage_m (int, optional): Number of landmarks drawn from all training rows
          for the third stage (capped at the training size). Defaults to 500.
        - optimize_eta_parameter (bool, optional): Stored on the instance; not read by the
          fit/predict methods of this class. Defaults to True.
        - eta_optimization_range (Tuple[float, float], optional): Stored on the instance; not read
          by the fit/predict methods of this class. Defaults to (1e-7, 1.0).
        - **kwargs: Optional extras, read by key:
            stage1_perc (0.5), model_seed (0), label_variance_in_eta_opt (0.0),
            large_eta_option (False), selecting_biggest_eta_tol (1e-9),
            regularization_grid_points (150), make_psd_eps (1e-9),
            kernel_A_params, kernel_W_params, kernel_Z_params, kernel_X_params (None each).
          The legacy key 'kernel_Y_params' is still accepted as a fallback for
          'kernel_Z_params'. Unknown keys are ignored.

        Raises:
        - TypeError: If any of the four kernels is not a Kernel instance.
        """
        stage1_perc = kwargs.pop('stage1_perc', 0.5)
        model_seed = kwargs.pop('model_seed', 0)
        label_variance_in_eta_opt = kwargs.pop('label_variance_in_eta_opt', 0.0)
        large_eta_option = kwargs.pop('large_eta_option', False)
        selecting_biggest_eta_tol = kwargs.pop('selecting_biggest_eta_tol', 1e-9)
        regularization_grid_points = kwargs.pop('regularization_grid_points', 150)
        make_psd_eps = kwargs.pop('make_psd_eps', 1e-9)
        kernel_A_params = kwargs.pop('kernel_A_params', None)
        kernel_W_params = kwargs.pop('kernel_W_params', None)
        # The Z-kernel parameters used to be read only from 'kernel_Y_params', so a
        # caller passing 'kernel_Z_params' was silently ignored. Prefer the correctly
        # named key and keep the old one as a fallback for existing callers.
        legacy_kernel_Z_params = kwargs.pop('kernel_Y_params', None)
        kernel_Z_params = kwargs.pop('kernel_Z_params', None)
        if kernel_Z_params is None:
            kernel_Z_params = legacy_kernel_Z_params
        kernel_X_params = kwargs.pop('kernel_X_params', None)

        for label, kernel in (("A", kernel_A), ("W", kernel_W), ("Z", kernel_Z), ("X", kernel_X)):
            if not isinstance(kernel, Kernel):
                raise TypeError(f"Kernel for {label} must be callable Kernel class")

        self.kernel_A = kernel_A
        self.kernel_W = kernel_W
        self.kernel_Z = kernel_Z
        self.kernel_X = kernel_X

        if kernel_A_params is not None:
            self.kernel_A.set_params(**kernel_A_params)
        if kernel_W_params is not None:
            self.kernel_W.set_params(**kernel_W_params)
        if kernel_Z_params is not None:
            self.kernel_Z.set_params(**kernel_Z_params)
        if kernel_X_params is not None:
            self.kernel_X.set_params(**kernel_X_params)

        self.lambda_, self.eta, self.lambda2_ = lambda_, eta, lambda2_
        self.nystrom_first_stage_m = nystrom_first_stage_m
        self.nystrom_third_stage_m = nystrom_third_stage_m
        self.model_seed = model_seed
        self.optimize_eta_parameter = optimize_eta_parameter
        self.eta_optimization_range = eta_optimization_range
        self.large_eta_option = large_eta_option
        self.selecting_biggest_eta_tol = selecting_biggest_eta_tol
        self.label_variance_in_eta_opt = label_variance_in_eta_opt
        self.stage1_perc = stage1_perc
        self.regularization_grid_points = regularization_grid_points
        self.make_psd_eps = make_psd_eps

    def sample_landmarks(self, original_size, m, seed = 0):
        """
        Draw landmark indices for a Nystrom approximation.

        Parameters:
        - original_size (int): Number of rows to sample from.
        - m (int): Requested number of landmarks; capped at original_size.
        - seed (int, optional): Seed of the private generator. Defaults to 0.

        Returns:
        - np.ndarray: min(m, original_size) distinct indices from range(original_size).
        """
        m = min(m, original_size)
        # A private legacy generator yields the same indices as the former
        # np.random.seed(seed) + np.random.choice(...) pair, but does not reseed
        # the process-wide NumPy RNG as a side effect of fitting/predicting.
        rng = np.random.RandomState(seed)
        return rng.choice(original_size, m, replace=False)
    
    ########################################################################
    ###################### STATIC JIT FUNCTIONS ############################
    ########################################################################        

    @staticmethod
    @jit 
    def _eta_objective(eta, L, L_sub, M, N, L2, M2, stage1_data_size, label_variance_in_eta_opt, make_psd_eps = 1e-9):
        """
        Evaluate the scalar objective for a candidate eta.

        alpha is obtained from the same linear system that ``fit`` solves
        (L / stage2_data_size + eta * N), then scored against L2 and M2, plus a
        trace term weighted by label_variance_in_eta_opt. ``L_sub`` is accepted
        but not used in the computation. The stage 2 size is taken as
        L.shape[0] - 1. Not called by any method visible in this class.

        Returns:
        - jnp.ndarray: The cost reshaped to a 0-d array.
        """
        stage2_data_size = L.shape[0] - 1
        alpha = jnp.linalg.solve(make_psd(L / stage2_data_size + eta * N, make_psd_eps), M)
        cost = ((1 / stage1_data_size) * (alpha.T @ make_psd(L2, make_psd_eps) @ alpha) - 2 * (alpha.T @ M2)) 
        cost += label_variance_in_eta_opt * (2 / stage2_data_size) * jnp.trace(jnp.linalg.solve(make_psd(L + stage2_data_size * eta * N, make_psd_eps), L))
        return cost.reshape(())
    
    @staticmethod
    @jit 
    def _predict_structural_function(alpha: jnp.ndarray,
                                     B: jnp.ndarray,
                                     B_bar: jnp.ndarray,
                                     third_stage_KRR_weights: jnp.ndarray,
                                     K_ATraina: jnp.ndarray,
                                     K_ATildea: jnp.ndarray,
                                     ones_divided_by_m: jnp.ndarray) -> jnp.ndarray:
        """
        Predict the structural function at a single test point.

        Parameters:
        - alpha (jnp.ndarray): Coefficient array; the last entry multiplies the B_bar term.
        - B (jnp.ndarray): Matrix B from second stage.
        - B_bar (jnp.ndarray): Matrix B_bar from second stage.
        - third_stage_KRR_weights (jnp.ndarray): Weights from third stage kernel ridge regression.
        - K_ATraina (jnp.ndarray): Kernel values between the third stage landmark rows of A and the test point.
        - K_ATildea (jnp.ndarray): Kernel values between stage 2 set A and the test point.
        - ones_divided_by_m (jnp.ndarray): Array of ones divided by stage 2 data size.

        Returns:
        - jnp.ndarray: Predicted value.
        """
        pred = (alpha[:-1].T @ ((B.T @ (third_stage_KRR_weights @ K_ATraina)) * K_ATildea))
        pred += (alpha[-1] * ((B_bar.T @ (third_stage_KRR_weights @ K_ATraina)) * K_ATildea) @ ones_divided_by_m)
        return pred

    ########################################################################
    ###################### FIT AND PREDICT FUNCTIONS #######################
    ########################################################################
    def fit(self, 
            AWZX: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray]], 
            Y: jnp.ndarray,) -> None:
        """
        Fit the KernelAlternativeProxyATE_Nystorm model.

        The stage 1 / stage 2 split is drawn with np.random.permutation, i.e. from
        the global NumPy RNG; seed it beforehand for a reproducible split. The
        landmark subsets are drawn from private generators seeded with model_seed
        (stage 1) and model_seed + 1 (third stage).

        Parameters:
        - AWZX (Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]]): Tuple of data
          arrays (A, W, Z, X) or (A, W, Z). Without X the X kernel matrix is all ones.
        - Y (np.ndarray): Target values.

        Raises:
        - ValueError: If AWZX does not contain 3 or 4 arrays.
        """
        kernel_A, kernel_W, kernel_Z, kernel_X = self.kernel_A, self.kernel_W, self.kernel_Z, self.kernel_X
        lambda_, eta, lambda2_ = self.lambda_, self.eta, self.lambda2_
        stage1_perc = self.stage1_perc
        make_psd_eps = self.make_psd_eps
        
        if len(AWZX) == 4:
            A, W, Z, X = AWZX
        elif len(AWZX) == 3:
            A, W, Z = AWZX
            X = None
        else:
            # Previously fell through and failed later with an unbound-variable error.
            raise ValueError(f"AWZX must contain 3 (A, W, Z) or 4 (A, W, Z, X) arrays, got {len(AWZX)}")
        
        K_ATrainATrain = kernel_A(A, A)
        K_WTrainWTrain = kernel_W(W, W)
        K_ZTrainZTrain = kernel_Z(Z, Z)
        if X is None:
            K_XTrainXTrain = jnp.ones((W.shape[0], W.shape[0]))
        else:
            K_XTrainXTrain = make_psd(kernel_X(X, X), make_psd_eps)

        ############################# SPLIT DATA IN STAGE 1 AND STAGE 2 #####################################
        train_data_size = A.shape[0]
        train_indices = np.random.permutation(train_data_size)

        if 0. < stage1_perc < 1.:
            stage1_data_size = int(train_data_size * stage1_perc)
            stage2_data_size = train_data_size - stage1_data_size
            stage1_idx, stage2_idx = train_indices[:stage1_data_size], train_indices[stage1_data_size:]
        else:
            # No split requested: both stages use every (permuted) training row.
            stage1_data_size, stage2_data_size = train_data_size, train_data_size
            stage1_idx, stage2_idx = train_indices, train_indices

        # Landmarks are positions inside the stage 1 block / the permuted training block.
        nystrom_stage1_landmarks = self.sample_landmarks(stage1_data_size, self.nystrom_first_stage_m, self.model_seed)
        nystrom_stage3_landmarks = self.sample_landmarks(train_data_size, self.nystrom_third_stage_m, self.model_seed + 1)
        ################################ KERNEL MATRICES ####################################################
        K_AATilde = K_ATrainATrain[np.ix_(stage1_idx, stage2_idx)]
        K_ATildeATilde = K_ATrainATrain[np.ix_(stage2_idx, stage2_idx)]

        K_WWTilde = K_WTrainWTrain[np.ix_(stage1_idx, stage2_idx)]

        K_ZZ = K_ZTrainZTrain[np.ix_(stage1_idx, stage1_idx)]

        K_XXTilde = K_XTrainXTrain[np.ix_(stage1_idx, stage2_idx)]

        # Switch the length-scale heuristic off on kernels that expose the flag;
        # presumably so later kernel calls (e.g. in predict) reuse the length scale
        # picked on the training data above -- TODO confirm against the Kernel class.
        for kernel_ in [self.kernel_A, self.kernel_W, self.kernel_Z, self.kernel_X]:
            if hasattr(kernel_, 'use_length_scale_heuristic'):
                kernel_.use_length_scale_heuristic = False

        ########## STAGE 1 PRODUCT KERNEL AND ITS LANDMARK BLOCKS ###########################################
        K_AWX = (K_ATrainATrain * K_WTrainWTrain * K_XTrainXTrain)[np.ix_(stage1_idx, stage1_idx)]
        
        K_AWX_nm = K_AWX[np.ix_(np.arange(stage1_data_size), nystrom_stage1_landmarks)]
        K_AWX_mm = K_AWX[np.ix_(nystrom_stage1_landmarks, nystrom_stage1_landmarks)]
        ########### FIRST AND SECOND STAGE REGRESSION ########################################
        stage1_ridge_weights = (K_AWX_nm.T @ K_AWX_nm + stage1_data_size * lambda_ * K_AWX_mm)
        self.stage1_ridge_weights = stage1_ridge_weights
        K_A_m_ATilde = K_AATilde[np.ix_(nystrom_stage1_landmarks, np.arange(stage2_data_size))]
        K_WX_m_WXTilde = (K_WWTilde * K_XXTilde)[np.ix_(nystrom_stage1_landmarks, np.arange(stage2_data_size))]
        B = K_AWX_nm @ jnp.linalg.solve(make_psd(stage1_ridge_weights, make_psd_eps), K_WX_m_WXTilde * K_A_m_ATilde)
        B_bar = K_AWX_nm @ jnp.linalg.solve(make_psd(stage1_ridge_weights, make_psd_eps),  (columns_mean_excluding_self(K_WX_m_WXTilde) * K_A_m_ATilde))
        
        block_component1 = (B.T @ K_ZZ @ B) * K_ATildeATilde
        block_component2 = (B.T @ K_ZZ @ B_bar) * K_ATildeATilde
        block_component4 = (B_bar.T @ K_ZZ @ B_bar) * K_ATildeATilde
        ones_divided_by_m = jnp.ones((stage2_data_size)) / stage2_data_size

        L_sub = jnp.vstack((block_component1, ones_divided_by_m.T @ block_component2.T))
        L = L_sub @ L_sub.T
        self.L = L_sub.T
        M = jnp.vstack(((block_component2 @ ones_divided_by_m).reshape(-1, 1), (ones_divided_by_m.T @ block_component4 @ ones_divided_by_m).reshape(-1, 1)))
        
        # N is the bordered matrix [[block1, block2 @ 1/m], [(1/m)^T block2^T, (1/m)^T block4 (1/m)]].
        P = jnp.hstack((block_component1, (block_component2 @ ones_divided_by_m).reshape(-1, 1)))
        R = jnp.hstack(((ones_divided_by_m.T @ block_component2.T).reshape(1, -1), (ones_divided_by_m.T @ block_component4 @ ones_divided_by_m).reshape(-1, 1)))
        N = jnp.vstack((P, R))

        alpha = jnp.linalg.solve(make_psd(L / stage2_data_size + eta * N, make_psd_eps), M)
        ########### THIRD STAGE ########################################
        K_ZZTrain = K_ZTrainZTrain[np.ix_(stage1_idx, train_indices)]
        K_ATrainATrain_ = K_ATrainATrain[np.ix_(train_indices, train_indices)]
        K_ATrainATrain_nm = K_ATrainATrain_[np.ix_(np.arange(train_data_size), nystrom_stage3_landmarks)]
        K_ATrainATrain_mm = K_ATrainATrain_[np.ix_(nystrom_stage3_landmarks, nystrom_stage3_landmarks)]
        # Full (non-Nystrom) variant, kept for reference:
        # third_stage_KRR_weights = jnp.linalg.solve(make_psd(K_ATrainATrain_ + train_data_size * lambda2_ * jnp.eye(train_data_size), make_psd_eps), (K_ZZTrain.T * Y[train_indices])).T 
        third_stage_KRR_weights = jnp.linalg.solve(make_psd(K_ATrainATrain_nm.T @ K_ATrainATrain_nm + train_data_size * lambda2_ * K_ATrainATrain_mm, make_psd_eps), K_ATrainATrain_nm.T @ (K_ZZTrain.T * Y[train_indices])).T 

        self.alpha = alpha
        self.B, self.B_bar = B, B_bar
        self.third_stage_KRR_weights = third_stage_KRR_weights
        self.ones_divided_by_m = ones_divided_by_m
        self.ATrain, self.WTrain, self.XTrain, self.ZTrain = A, W, X, Z
        self.K_ZZ = K_ZZ
        # Original-row indices of the third stage landmarks (what predict needs).
        self.train_indices = train_indices[nystrom_stage3_landmarks]
        self.stage1_idx, self.stage2_idx = stage1_idx, stage2_idx

        ##### For debugging purpose, I might want to check the regularization values after optimization #######
        self.lambda_ = lambda_
        self.lambda2_ = lambda2_
        self.eta = eta

    def predict(self, A: jnp.ndarray) -> jnp.ndarray:
        """
        Predict outcomes for new data points.

        Parameters:
        - A (jnp.ndarray): New data points for variable A; anything that is not 2-D is
          reshaped to a single column.

        Returns:
        - jnp.ndarray: One predicted value per test point.
        """
        if A.ndim != 2:
            A_test = A.reshape(-1, 1)
        else:
            A_test = A
        K_ATrainATest = self.kernel_A(self.ATrain, A_test)

        test_indices = jnp.arange(A_test.shape[0])
        test_shape = test_indices.shape[0]

        # Rows of the train/test kernel for the third stage landmarks and for the stage 2 rows.
        K_ATrainATest_ = K_ATrainATest[tuple(cartesian_product(self.train_indices, test_indices).T)].reshape(self.train_indices.shape[0], test_shape)
        K_ATildeATest = K_ATrainATest[tuple(cartesian_product(self.stage2_idx, test_indices).T)].reshape(self.stage2_idx.shape[0], test_shape)

        ones_divided_by_m = self.ones_divided_by_m
        alpha = self.alpha
        B, B_bar = self.B, self.B_bar
        third_stage_KRR_weights = self.third_stage_KRR_weights


        f_struct_pred = jnp.array([self._predict_structural_function(alpha, B, B_bar, third_stage_KRR_weights, 
                                                                    K_ATrainATest_[:, jj], K_ATildeATest[:, jj], 
                                                                    ones_divided_by_m).item() for jj in range(K_ATildeATest.shape[1])])
        return f_struct_pred

    def _predict_bridge_func(self, A_test : jnp.ndarray, Z_test : jnp.ndarray, X_test = None):
        """
        Evaluate the fitted bridge function for every pair (A_test[j], Z_test[i]).

        Parameters:
        - A_test (jnp.ndarray): Test values of A; reshaped to a column if not 2-D.
        - Z_test (jnp.ndarray): Values of Z passed to the Z kernel against the stage 1 rows of Z.
        - X_test: Accepted for signature compatibility; not used.

        Returns:
        - jnp.ndarray: Array indexed by (test point of A, row of Z_test).
        """
        if A_test.ndim != 2:
            A_test = A_test.reshape(-1, 1)
        K_ZZTest = self.kernel_Z(self.ZTrain[self.stage1_idx, :], Z_test)
        K_ATildeATest = self.kernel_A(self.ATrain[self.stage2_idx, :], A_test)
        alpha, B, B_bar = self.alpha, self.B, self.B_bar
        ones_divided_by_m = self.ones_divided_by_m
        # These two products do not depend on the test point of A: compute them once.
        B_K_ZZTest = B.T @ K_ZZTest
        B_bar_K_ZZTest = B_bar.T @ K_ZZTest
        bridge_function = jnp.array([alpha[:-1].T @ (B_K_ZZTest * K_ATildeATest[:, jj].reshape(-1, 1)) + alpha[-1] * ones_divided_by_m.T @ (B_bar_K_ZZTest * K_ATildeATest[:, jj].reshape(-1,1)) for jj in range(K_ATildeATest.shape[1])])
        return bridge_function[:, 0, :]


class KernelProxyVariableATE_Nystorm(BaseEstimator, RegressorMixin):
    """
    Two-stage kernel proxy variable estimator whose ridge systems are solved on
    Nystrom landmark subsets.

    ``fit`` splits the training rows into a stage-1 and a stage-2 part, solves a
    landmark-restricted ridge system on the stage-1 product kernel (Z * A * X),
    and then a landmark-restricted ridge system on the stage-2 rows for the
    coefficient vector ``alpha``. ``predict`` evaluates the fitted function of A;
    ``_predict_bridge_func`` evaluates the fitted bridge function of (A, W, X).
    """

    def __init__(self, 
                 kernel_A : Kernel,
                 kernel_W : Kernel,
                 kernel_Z : Kernel,
                 kernel_X : Kernel = RBF(),
                 lambda1_ : float = 0.1,
                 lambda2_ : float = 0.1,
                 nystrom_first_stage_m: int = 500,
                 nystrom_second_stage_m: int = 500,
                 **kwargs) -> None:
        """
        Initialize the KernelProxyVariableATE_Nystorm estimator.

        Parameters:
        - kernel_A (Kernel): Kernel function for variable A.
        - kernel_W (Kernel): Kernel function for variable W.
        - kernel_Z (Kernel): Kernel function for variable Z.
        - kernel_X (Kernel, optional): Kernel function for variable X. Defaults to RBF().
        - lambda1_ (float, optional): First stage regularization parameter. Defaults to 0.1.
        - lambda2_ (float, optional): Second stage regularization parameter. Defaults to 0.1.
        - nystrom_first_stage_m (int, optional): Number of landmarks drawn from the stage 1 rows
          (capped at the stage 1 size). Defaults to 500.
        - nystrom_second_stage_m (int, optional): Number of landmarks drawn from the stage 2 rows
          (capped at the stage 2 size). Defaults to 500.
        - **kwargs: Optional extras, read by key: stage1_perc (0.5), make_psd_eps (1e-9),
          model_seed (0), kernel_A_params, kernel_W_params, kernel_Z_params, kernel_X_params
          (None each). The legacy key 'kernel_Y_params' is still accepted as a fallback for
          'kernel_Z_params'. Unknown keys are ignored.

        Raises:
        - TypeError: If any of the four kernels is not a Kernel instance.
        """
        stage1_perc = kwargs.pop('stage1_perc', 0.5)
        make_psd_eps = kwargs.pop('make_psd_eps', 1e-9)
        model_seed = kwargs.pop('model_seed', 0)
        kernel_A_params = kwargs.pop('kernel_A_params', None)
        kernel_W_params = kwargs.pop('kernel_W_params', None)
        # The Z-kernel parameters used to be read only from 'kernel_Y_params', so a
        # caller passing 'kernel_Z_params' was silently ignored. Prefer the correctly
        # named key and keep the old one as a fallback for existing callers.
        legacy_kernel_Z_params = kwargs.pop('kernel_Y_params', None)
        kernel_Z_params = kwargs.pop('kernel_Z_params', None)
        if kernel_Z_params is None:
            kernel_Z_params = legacy_kernel_Z_params
        kernel_X_params = kwargs.pop('kernel_X_params', None)

        for label, kernel in (("A", kernel_A), ("W", kernel_W), ("Z", kernel_Z), ("X", kernel_X)):
            if not isinstance(kernel, Kernel):
                raise TypeError(f"Kernel for {label} must be callable Kernel class")

        self.kernel_A = kernel_A
        self.kernel_W = kernel_W
        self.kernel_Z = kernel_Z
        self.kernel_X = kernel_X

        if kernel_A_params is not None:
            self.kernel_A.set_params(**kernel_A_params)
        if kernel_W_params is not None:
            self.kernel_W.set_params(**kernel_W_params)
        if kernel_Z_params is not None:
            self.kernel_Z.set_params(**kernel_Z_params)
        if kernel_X_params is not None:
            self.kernel_X.set_params(**kernel_X_params)

        self.lambda1_, self.lambda2_ = lambda1_, lambda2_
        self.nystrom_first_stage_m = nystrom_first_stage_m
        self.nystrom_second_stage_m = nystrom_second_stage_m
        self.stage1_perc = stage1_perc
        self.model_seed = model_seed
        self.make_psd_eps = make_psd_eps 

    def sample_landmarks(self, original_size, m, seed = 0):
        """
        Draw landmark indices for a Nystrom approximation.

        Parameters:
        - original_size (int): Number of rows to sample from.
        - m (int): Requested number of landmarks; capped at original_size.
        - seed (int, optional): Seed of the private generator. Defaults to 0.

        Returns:
        - np.ndarray: min(m, original_size) distinct indices from range(original_size).
        """
        m = min(m, original_size)
        # A private legacy generator yields the same indices as the former
        # np.random.seed(seed) + np.random.choice(...) pair, but does not reseed
        # the process-wide NumPy RNG as a side effect of fitting.
        rng = np.random.RandomState(seed)
        return rng.choice(original_size, m, replace=False)
    
    def fit(self,             
            AWZX: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray]], 
            Y: jnp.ndarray,) -> None:
        """
        Fit the KernelProxyVariableATE_Nystorm model.

        The stage 1 / stage 2 split is drawn with np.random.permutation, i.e. from
        the global NumPy RNG; seed it beforehand for a reproducible split. Both
        landmark subsets are drawn from private generators seeded with model_seed.

        Parameters:
        - AWZX (Tuple): Data arrays (A, W, Z, X) or (A, W, Z). Without X the X kernel
          matrix is all ones.
        - Y (jnp.ndarray): Target values.

        Raises:
        - ValueError: If AWZX does not contain 3 or 4 arrays.
        """
        kernel_A, kernel_W, kernel_Z, kernel_X = self.kernel_A, self.kernel_W, self.kernel_Z, self.kernel_X
        lambda1_, lambda2_ = self.lambda1_, self.lambda2_
        stage1_perc = self.stage1_perc
        make_psd_eps = self.make_psd_eps

        if len(AWZX) == 4:
            A, W, Z, X = AWZX
        elif len(AWZX) == 3:
            A, W, Z = AWZX
            X = None
        else:
            # Previously fell through and failed later with an unbound-variable error.
            raise ValueError(f"AWZX must contain 3 (A, W, Z) or 4 (A, W, Z, X) arrays, got {len(AWZX)}")
        
        K_ATrainATrain = kernel_A(A, A)
        K_WTrainWTrain = kernel_W(W, W)
        K_ZTrainZTrain = kernel_Z(Z, Z)
        if X is None:
            K_XTrainXTrain = jnp.ones((W.shape[0], W.shape[0]))
        else:
            K_XTrainXTrain = make_psd(kernel_X(X, X), make_psd_eps)

        ############################# SPLIT DATA IN STAGE 1 AND STAGE 2 #####################################
        train_data_size = A.shape[0]
        train_indices = np.random.permutation(train_data_size)

        if 0. < stage1_perc < 1.:
            stage1_data_size = int(train_data_size * stage1_perc)
            stage2_data_size = train_data_size - stage1_data_size
            stage1_idx, stage2_idx = train_indices[:stage1_data_size], train_indices[stage1_data_size:]
        else:
            # No split requested: both stages use every (permuted) training row.
            stage1_data_size, stage2_data_size = train_data_size, train_data_size
            stage1_idx, stage2_idx = train_indices, train_indices

        nystrom_stage1_landmarks = self.sample_landmarks(stage1_data_size, self.nystrom_first_stage_m, self.model_seed)
        # Stage 2 landmarks index the stage 2 block (the stage-2-by-stage-2 ridge matrix,
        # A[stage2_idx], the columns of B), so they must be drawn from
        # range(stage2_data_size). Drawing from range(train_data_size) produced
        # out-of-range positions whenever the data is really split.
        nystrom_stage2_landmarks = self.sample_landmarks(stage2_data_size, self.nystrom_second_stage_m, self.model_seed)
        ################################ KERNEL MATRICES ####################################################
        K_AATilde = K_ATrainATrain[np.ix_(stage1_idx, stage2_idx)]
        K_ATildeATilde = K_ATrainATrain[np.ix_(stage2_idx, stage2_idx)]

        K_WW = K_WTrainWTrain[np.ix_(stage1_idx, stage1_idx)]

        K_ZZTilde = K_ZTrainZTrain[np.ix_(stage1_idx, stage2_idx)]

        K_XXTilde = K_XTrainXTrain[np.ix_(stage1_idx, stage2_idx)]
        K_XTildeXTilde = K_XTrainXTrain[np.ix_(stage2_idx, stage2_idx)]

        # Switch the length-scale heuristic off on kernels that expose the flag;
        # presumably so later kernel calls (e.g. in predict) reuse the length scale
        # picked on the training data above -- TODO confirm against the Kernel class.
        for kernel_ in [self.kernel_A, self.kernel_W, self.kernel_Z, self.kernel_X]:
            if hasattr(kernel_, 'use_length_scale_heuristic'):
                kernel_.use_length_scale_heuristic = False
                
        ########### FIRST STAGE REGRESSION ###########################
        K_ZAX = (K_ZTrainZTrain * K_ATrainATrain * K_XTrainXTrain)[np.ix_(stage1_idx, stage1_idx)]
        K_ZAX_nm = K_ZAX[np.ix_(np.arange(stage1_data_size), nystrom_stage1_landmarks)]
        K_ZAX_mm = K_ZAX[np.ix_(nystrom_stage1_landmarks, nystrom_stage1_landmarks)]
        del K_ZAX  # only the landmark blocks are needed from here on
        YTilde = Y[stage2_idx]

        ########### SECOND STAGE REGRESSION ###########################
        # Full (non-Nystrom) variant: stage1_ridge_weights = (K_ZAX + stage1_data_size * lambda1_ * I_n)
        stage1_ridge_weights = K_ZAX_nm.T @ K_ZAX_nm + stage1_data_size * lambda1_ * K_ZAX_mm
        K_ZAX_m_ZAXTilde = (K_ZZTilde * K_AATilde * K_XXTilde)[np.ix_(nystrom_stage1_landmarks, np.arange(stage2_data_size))]
        # make_psd now receives the configured make_psd_eps (it was previously called
        # with its own default here, so the constructor parameter had no effect).
        B = K_ZAX_nm @ jnp.linalg.solve(make_psd(stage1_ridge_weights, make_psd_eps), K_ZAX_m_ZAXTilde)
        self.B = B
        stage2_ridge_weights = K_ATildeATilde * (B.T @ K_WW @ B) * K_XTildeXTilde
        del K_ATildeATilde
        del K_XTildeXTilde
        stage2_ridge_weights_mn = stage2_ridge_weights[np.ix_(nystrom_stage2_landmarks, np.arange(stage2_data_size))]
        stage2_ridge_weights_mm = stage2_ridge_weights[np.ix_(nystrom_stage2_landmarks, nystrom_stage2_landmarks)]

        x_mean_vec = jnp.mean(K_XXTilde, axis=0)[:, jnp.newaxis]

        # Full (non-Nystrom) variant: stage2_ridge_weights += stage2_data_size * lambda2_ * I_m
        alpha = jnp.linalg.solve(make_psd(stage2_ridge_weights_mn @ stage2_ridge_weights_mn.T + stage2_data_size * lambda2_ * stage2_ridge_weights_mm, make_psd_eps), stage2_ridge_weights_mn @ YTilde)
        w_mean_vec = jnp.mean(K_WW, axis=0)[:, jnp.newaxis]

        self.alpha = alpha
        self.w_mean_vec = w_mean_vec
        self.x_mean_vec = x_mean_vec
        # Only the stage 2 landmark rows are needed at prediction time.
        self.ATilde = A[stage2_idx][nystrom_stage2_landmarks]
        self.nystrom_stage2_landmarks = nystrom_stage2_landmarks
        if X is not None:
            self.XTilde = X[stage2_idx][nystrom_stage2_landmarks]
        else:
            self.XTilde = None
        self.W = W[stage1_idx]
        # Per-landmark weight: mean over the stage 1 rows of (B^T K_WW) * K_XXTilde^T.
        self.upweight = jnp.mean((B[:, nystrom_stage2_landmarks].T @ K_WW) * K_XXTilde[:, nystrom_stage2_landmarks].T, axis = 1).reshape(-1, 1)

    def predict(self, A: jnp.ndarray) -> jnp.ndarray:
        """
        Predict outcomes for new data points.

        Parameters:
        - A (jnp.ndarray): New data points for variable A; anything that is not 2-D is
          reshaped to a single column.

        Returns:
        - jnp.ndarray: Predicted values, one row per test point.
        """
        if A.ndim != 2:
            A_test = A.reshape(-1, 1)
        else:
            A_test = A
        K_ATildeATest = self.kernel_A(self.ATilde, A_test)

        pred = (K_ATildeATest * self.upweight).T @ self.alpha
        ## The following line is based on the implementation in https://github.com/liyuan9988/DeepFeatureProxyVariable/blob/master/src/models/kernelPV/model.py
        ## However, it is wrong. The correct version is given in the line above.
        # pred = (K_ATildeATest * (self.B.T @ self.w_mean_vec) * self.x_mean_vec).T @ self.alpha
        return pred
    
    def _predict_bridge_func(self, A_test : jnp.ndarray, W_test : jnp.ndarray, X_test : Optional[jnp.ndarray] = None):
        """
        Evaluate the fitted bridge function for every pair (A_test[j], W_test[i]).

        Parameters:
        - A_test (jnp.ndarray): Test values of A; reshaped to a column if not 2-D.
        - W_test (jnp.ndarray): Values of W passed to the W kernel against the stage 1 rows of W.
        - X_test (jnp.ndarray, optional): Values of X. Used only when the model was also
          fitted with X; otherwise the X factor is all ones.

        Returns:
        - jnp.ndarray: Array indexed by (test point of A, row of W_test).
        """
        if A_test.ndim != 2:
            A_test = A_test.reshape(-1, 1)
        K_ATildeATest = self.kernel_A(self.ATilde, A_test)
        K_WWTest = self.kernel_W(self.W, W_test)
        if X_test is not None and self.XTilde is not None:
            K_XTildeXTest = self.kernel_X(self.XTilde, X_test)
        else:
            K_XTildeXTest = jnp.ones((self.ATilde.shape[0], W_test.shape[0]))
        
        # This factor does not depend on the test point of A: compute it once.
        W_X_factor = (self.B[:, self.nystrom_stage2_landmarks].T @ K_WWTest) * K_XTildeXTest
        bridge_pred = jnp.array([(K_ATildeATest[:, jj].reshape(-1, 1) * W_X_factor).T @ self.alpha for jj in range(A_test.shape[0])])
        return bridge_pred[:, :, 0]
   
def _heuristic_rbf() -> RBF:
    """Return a new RBF built with use_length_scale_heuristic=True and use_jit_call=True."""
    return RBF(use_length_scale_heuristic = True, use_jit_call = True)


# Default constructor arguments for the treatment bridge estimator.
# Every kernel entry is its own RBF instance.
treatment_bridge_algo_param_dict_default = {
    "kernel_A": _heuristic_rbf(),
    "kernel_W": _heuristic_rbf(),
    "kernel_Z": _heuristic_rbf(),
    "kernel_V": _heuristic_rbf(),  # Only required for CATE algorithm
    "kernel_X": _heuristic_rbf(),
    "lambda_": 1e-3,
    "eta": 1e-3,
    "lambda2_": 1e-3,
    "zeta": 1e-3,  # Only required for ATT or CATE algorithm
    "nystrom_first_stage_m": 500,
    "nystrom_third_stage_m": 500,
    "stage1_perc": 0.5,
    "model_seed": 0,
    "make_psd_eps": 1e-9,
}

# Default constructor arguments for the outcome bridge estimator;
# "algorithm_name" selects which outcome bridge algorithm is built.
outcome_bridge_kpv_algo_param_dict_default = {
    "algorithm_name": "Kernel_Proxy_Variable",
    "kernel_A": _heuristic_rbf(),
    "kernel_W": _heuristic_rbf(),
    "kernel_Z": _heuristic_rbf(),
    "kernel_V": _heuristic_rbf(),  # Only required for CATE algorithm
    "kernel_X": _heuristic_rbf(),
    "lambda1_": 0.1,
    "lambda2_": 0.1,
    "zeta": 1e-3,  # Only required for ATT or CATE algorithm
    "nystrom_first_stage_m": 500,
    "nystrom_second_stage_m": 500,
    "stage1_perc": 0.5,
    "model_seed": 0,
    "make_psd_eps": 1e-9,
}

class DoublyRobustKernelProxyATE_Nystorm(BaseEstimator, RegressorMixin):
    """
    Doubly robust combination of a treatment bridge estimator and an outcome
    bridge estimator.

    ``fit`` fits both sub-estimators on the same data. ``predict`` returns the
    sum of their predictions minus a correction term obtained by a
    Nystrom-approximated kernel ridge regression of the product of the two
    bridge functions on A.
    """

    def __init__(self,
                 treatment_bridge_algo_param_dict : Dict = treatment_bridge_algo_param_dict_default,
                 outcome_bridge_algo_param_dict : Dict = outcome_bridge_kpv_algo_param_dict_default,
                 lambda_DR : float = 1e-3,
                 nystorm_m: int = 500,
                 **kwargs,
                 ):
        """
        Build the two bridge estimators from their parameter dictionaries.

        Parameters:
        - treatment_bridge_algo_param_dict (Dict): Keyword arguments for
          KernelAlternativeProxyATE_Nystorm.
        - outcome_bridge_algo_param_dict (Dict): Keyword arguments for the outcome bridge
          estimator; its "algorithm_name" entry selects the algorithm. Only
          "Kernel_Proxy_Variable" is implemented.
        - lambda_DR (float, optional): Regularization of the correction-term ridge regression.
          Defaults to 1e-3.
        - nystorm_m (int, optional): Number of landmarks for the correction-term regression
          (capped at the training size). Defaults to 500.
        - **kwargs: Optional extras, read by key: model_seed (0), make_psd_eps (5e-9).

        Raises:
        - KeyError: If outcome_bridge_algo_param_dict has no "algorithm_name" entry.
        - NotImplementedError: If the named outcome bridge algorithm is not implemented yet.
        - ValueError: If "algorithm_name" is not a recognised algorithm.

        NOTE(review): the default dictionaries are module-level objects, so estimators
        built from the defaults share the same kernel instances -- confirm this is intended.
        """
        self.treatment_bridge_algo_param_dict = treatment_bridge_algo_param_dict
        self.outcome_bridge_algo_param_dict = outcome_bridge_algo_param_dict
        self.lambda_DR = lambda_DR
        self.nystorm_m = nystorm_m
        self.model_seed = kwargs.pop('model_seed', 0)
        self.make_psd_eps = kwargs.pop('make_psd_eps', 5e-9)

        self.treatment_bridge_algo = KernelAlternativeProxyATE_Nystorm(**treatment_bridge_algo_param_dict)
        algorithm_name = outcome_bridge_algo_param_dict["algorithm_name"]
        if algorithm_name == "Kernel_Proxy_Variable":
            self.outcome_bridge_algo = KernelProxyVariableATE_Nystorm(**outcome_bridge_algo_param_dict)
        elif algorithm_name in ("Proxy_Maximum_Moment_Restriction", "Kernel_Negative_Control"):
            raise NotImplementedError("Not implemented yet!")
        else:
            # Previously an unknown name left outcome_bridge_algo unset and the
            # failure only surfaced as an AttributeError inside fit/predict.
            raise ValueError(f"Unknown outcome bridge algorithm_name: {algorithm_name!r}")

    def sample_landmarks(self, original_size, m, seed = 0):
        """
        Draw landmark indices for a Nystrom approximation.

        Parameters:
        - original_size (int): Number of rows to sample from.
        - m (int): Requested number of landmarks; capped at original_size.
        - seed (int, optional): Seed of the private generator. Defaults to 0.

        Returns:
        - np.ndarray: min(m, original_size) distinct indices from range(original_size).
        """
        m = min(m, original_size)
        # A private legacy generator yields the same indices as the former
        # np.random.seed(seed) + np.random.choice(...) pair, but does not reseed
        # the process-wide NumPy RNG on every predict call.
        rng = np.random.RandomState(seed)
        return rng.choice(original_size, m, replace=False)
    
    def fit(self, 
            AWZX: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, Optional[jnp.ndarray]], 
            Y: jnp.ndarray,) -> None:
        """
        Fit both bridge estimators and keep the training data for ``predict``.

        Parameters:
        - AWZX (Tuple): Data arrays (A, W, Z, X) or (A, W, Z).
        - Y (jnp.ndarray): Target values.

        Raises:
        - ValueError: If AWZX does not contain 3 or 4 arrays.
        """
        if len(AWZX) == 4:
            A, W, Z, X = AWZX
        elif len(AWZX) == 3:
            A, W, Z = AWZX
            X = None
        else:
            # Previously fell through and failed later with an unbound-variable error.
            raise ValueError(f"AWZX must contain 3 (A, W, Z) or 4 (A, W, Z, X) arrays, got {len(AWZX)}")
        
        self.treatment_bridge_algo.fit(AWZX, Y)
        self.outcome_bridge_algo.fit(AWZX, Y)
            
        self.W = W
        self.Z = Z
        self.X = X
        self.A = A
        # The correction-term regression reuses the treatment bridge's A kernel.
        self.kernel_A = self.treatment_bridge_algo.kernel_A

    def predict(self, A: jnp.ndarray):
        """
        Predict the doubly robust estimate for new values of A.

        Also stores the two individual predictions and the correction term on the
        instance (treatment_bridge_algo_pred, outcome_bridge_algo_pred,
        slack_prediction).

        Parameters:
        - A (jnp.ndarray): New data points for variable A; anything that is not 2-D is
          reshaped to a single column.

        Returns:
        - jnp.ndarray: treatment prediction + outcome prediction - correction term.
        """
        if A.ndim != 2:
            A = A.reshape(-1, 1)
        
        treatment_bridge_algo = self.treatment_bridge_algo
        outcome_bridge_algo = self.outcome_bridge_algo
        treatment_bridge_algo_pred = treatment_bridge_algo.predict(A).reshape(-1)
        outcome_bridge_algo_pred = outcome_bridge_algo.predict(A).reshape(-1)
        self.treatment_bridge_algo_pred, self.outcome_bridge_algo_pred = treatment_bridge_algo_pred, outcome_bridge_algo_pred
        # Bridge functions evaluated at every (test point of A, training row) pair.
        treatment_bridge_algo_bridge_pred = treatment_bridge_algo._predict_bridge_func(A, self.Z, self.X)
        outcome_bridge_algo_bridge_pred = outcome_bridge_algo._predict_bridge_func(A, self.W, self.X)
        lambda_DR = self.lambda_DR

        nystorm_landmarks = self.sample_landmarks(self.A.shape[0], self.nystorm_m, self.model_seed)
        A_landmarks = self.A[nystorm_landmarks]
        K_AA_nm = self.kernel_A(self.A, A_landmarks)
        K_AA_mm = self.kernel_A(A_landmarks, A_landmarks)
        # NOTE(review): unlike the bridge estimators, this system is solved without
        # make_psd, and self.make_psd_eps is never used -- confirm whether that is intended.
        DR_KRR_Weights = K_AA_nm @ jnp.linalg.solve(K_AA_nm.T @ K_AA_nm + self.A.shape[0] * lambda_DR * K_AA_mm, self.kernel_A(A_landmarks, A))
        # Full (non-Nystrom) variant, kept for reference:
        # identity_matrix = jnp.eye(self.A.shape[0])
        # DR_KRR_Weights = jnp.linalg.solve(make_psd(self.K_AA + self.A.shape[0] * lambda_DR * identity_matrix), outcome_bridge_algo.kernel_A(self.A, A))
        slack_prediction = ((treatment_bridge_algo_bridge_pred * outcome_bridge_algo_bridge_pred) * DR_KRR_Weights.T).sum(1)
        self.slack_prediction = slack_prediction
        f_struct_pred = treatment_bridge_algo_pred + outcome_bridge_algo_pred - slack_prediction
        return f_struct_pred
