import numpy as np
import jax.numpy as jnp
from jax import jit
from jax.scipy.optimize import minimize as jax_minimize
from sklearn.base import BaseEstimator, RegressorMixin

import sys
sys.path.append("..")

from utils.kernel_utils import RBF
from utils.linalg_utils import make_psd, cartesian_product

from typing import Callable, Tuple, Optional, Union, Dict

from jax import config
config.update("jax_enable_x64", True)

#### THE FOLLOWING PYTHON CLASS IS AN IMPLEMENTATION OF KERNEL INSTRUMENTAL VARIABLE REGRESSION.
#### THIS ALGORITHM IS PRESENTED IN THE FOLLOWING PAPER:
# Rahul Singh, Maneesh Sahani, and Arthur Gretton. Kernel instrumental variable
# regression. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d'Alché-Buc, E. Fox, and
# R. Garnett, editors, Advances in Neural Information Processing Systems, volume 32.
# Curran Associates, Inc., 2019.

class KernelIVRegression(BaseEstimator, RegressorMixin):
    """
    Kernel Instrumental Variable (KIV) regression estimator.

    The training data is randomly split into a "stage 1" and a "stage 2"
    subset. Kernel matrices on the treatment X and the instrument Z are
    combined in two regularized regressions (regularizers ``lambda_`` and
    ``xi``) to obtain the dual coefficients ``alpha``. Predictions for new
    X are ``kernel_X(X_train_stage1, X).T @ alpha``.

    Attributes set by ``fit``:
    - X_train_stage1: stage-1 rows of X, needed to build test kernels.
    - W: intermediate matrix of the stage-2 regression.
    - alpha: dual coefficients used by ``predict``.
    - lambda_, xi: overwritten with the optimized values when
      ``optimize_regularization_parameters`` is True.
    """

    def __init__(
        self,
        kernel_X: Callable,
        kernel_Z: Callable,
        lambda_: float = 0.01,
        xi: float = 0.01,
        optimize_regularization_parameters: bool = True,
        regularization_param_optimization_tol: float = 1e-8,
        **kwargs,):
        """
        Initialize KernelIVRegression object.

        Parameters:
        - kernel_X (Callable): Kernel function for X.
        - kernel_Z (Callable): Kernel function for Z.
        - lambda_ (float): Stage-1 regularization parameter (also the initial guess when optimized).
        - xi (float): Stage-2 regularization parameter (also the initial guess when optimized).
        - optimize_regularization_parameters (bool): Whether to optimize regularization parameters.
        - regularization_param_optimization_tol (float): Tolerance for optimization.
        - kwargs: Additional arguments:
            - stage1_perc: float, fraction of data to use in stage 1 (default: 0.5).
            - kernel_X_params (dict): Parameters for kernel_X (applied through kernel_X.set_params).
            - kernel_Z_params (dict): Parameters for kernel_Z (applied through kernel_Z.set_params).

        Raises:
        - TypeError: if kernel_X or kernel_Z is not callable.
        """
        stage1_perc = kwargs.pop('stage1_perc', 0.5)
        kernel_X_params = kwargs.pop('kernel_X_params', None)
        kernel_Z_params = kwargs.pop('kernel_Z_params', None)

        # TypeError is a subclass of Exception, so callers catching the
        # previously raised bare Exception keep working.
        if not callable(kernel_X):
            raise TypeError("Kernel for X must be callable")
        if not callable(kernel_Z):
            raise TypeError("Kernel for Z must be callable")

        self.kernel_X = kernel_X
        self.kernel_Z = kernel_Z
        # The kernel objects are configured in place; this requires them to
        # expose a set_params method whenever parameters are supplied.
        if kernel_X_params is not None:
            self.kernel_X.set_params(**kernel_X_params)
        if kernel_Z_params is not None:
            self.kernel_Z.set_params(**kernel_Z_params)

        self.lambda_ = lambda_
        self.xi = xi
        self.optimize_regularization_parameters = optimize_regularization_parameters
        self.regularization_param_optimization_tol = regularization_param_optimization_tol
        self.stage1_perc = stage1_perc
        self.kernel_X_params = kernel_X_params
        self.kernel_Z_params = kernel_Z_params

    ########################################################################
    ###################### STATIC JIT FUNCTIONS ############################
    ########################################################################

    @staticmethod
    @jit
    def KIV_lambda_objective(log_lambda_: float, K_XX: jnp.ndarray, K_XTildeX: jnp.ndarray,
                              K_XTildeXTilde: jnp.ndarray, K_ZZ: jnp.ndarray, K_ZZTilde: jnp.ndarray) -> float:
        """
        Kernel Instrumental Variable optimization objective for optimal lambda_ value. See Algorithm 2 in page 24.

        Parameters:
        - log_lambda_: natural logarithm of lambda_ (optimizing in log-space keeps lambda_ positive).
        - K_XX: kernel of stage-1 X against itself.
        - K_XTildeX: kernel of stage-2 X against stage-1 X.
        - K_XTildeXTilde: kernel of stage-2 X against itself. Unused here (see note below); kept so the
          signature stays stable.
        - K_ZZ: kernel of stage-1 Z against itself.
        - K_ZZTilde: kernel of stage-1 Z against stage-2 Z.

        Returns:
        - The lambda-dependent part of the stage-1 out-of-sample loss.
        """
        n = K_ZZTilde.shape[0]
        # Solve the regularized system directly instead of forming an explicit inverse.
        gamma_ZTilde = jnp.linalg.solve(make_psd(K_ZZ) + n * jnp.exp(log_lambda_) * jnp.eye(n), K_ZZTilde)
        # The cost stated in the paper is
        #   (1 / m) * trace(K_XTildeXTilde - 2 * K_XTildeX @ gamma + gamma.T @ K_XX @ gamma).
        # The 1/m factor and the K_XTildeXTilde term do not depend on lambda_, so they are dropped:
        # the minimizer is unchanged.
        L1_lambda = jnp.trace(- 2 * K_XTildeX @ gamma_ZTilde + gamma_ZTilde.T @ make_psd(K_XX) @ gamma_ZTilde)
        return L1_lambda

    @staticmethod
    @jit
    def KIV_xi_objective(log_xi: float, lambda_: float, Y_train_stage1: jnp.ndarray, Y_train_stage2: jnp.ndarray,
                         K_XX: jnp.ndarray, K_ZZ: jnp.ndarray, K_ZZTilde: jnp.ndarray) -> float:
        """
        Kernel Instrumental Variable optimization objective for optimal xi value. See Algorithm 2 in page 24.

        Parameters:
        - log_xi: natural logarithm of xi (optimizing in log-space keeps xi positive).
        - lambda_: stage-1 regularization parameter (already fixed).
        - Y_train_stage1: stage-1 targets, used to score the fit.
        - Y_train_stage2: stage-2 targets, used to estimate alpha.
        - K_XX: kernel of stage-1 X against itself.
        - K_ZZ: kernel of stage-1 Z against itself.
        - K_ZZTilde: kernel of stage-1 Z against stage-2 Z.

        Returns:
        - Mean squared error of the resulting estimator on the stage-1 data.
        """
        n, m = K_ZZTilde.shape

        W = K_XX @ jnp.linalg.solve(make_psd(K_ZZ) + n * lambda_ * jnp.eye(n), K_ZZTilde)
        alpha = jnp.linalg.solve(make_psd(W @ W.T) + m * jnp.exp(log_xi) * make_psd(K_XX), W @ Y_train_stage2)

        # Equivalent to (alpha.T @ K_XX).T
        Y_stage1_pred = K_XX.T @ alpha
        mse_loss = ((Y_stage1_pred - Y_train_stage1) ** 2).mean()
        return mse_loss

    @staticmethod
    def _kernel_block(K: jnp.ndarray, row_idx: np.ndarray, col_idx: np.ndarray) -> jnp.ndarray:
        """Return the (len(row_idx), len(col_idx)) sub-matrix of K selected by the two index sets."""
        # NOTE(review): relies on the project helper cartesian_product listing the index pairs
        # row-major (row index varying slowest) so the reshape below is valid — confirm in utils.
        pairs = cartesian_product(row_idx, col_idx)
        return K[tuple(pairs.T)].reshape(len(row_idx), len(col_idx))

    def fit(self, XZ: Tuple[jnp.ndarray, jnp.ndarray], Y: jnp.ndarray,
            stage1_perc: Optional[float] = None) -> 'KernelIVRegression':
        """
        Fit the Kernel Instrumental Variable Regression Model.

        Parameters:
        - XZ (Tuple[np.ndarray, np.ndarray]): Tuple (X, Z) of treatment and instrument data.
        - Y (np.ndarray): Output data.
        - stage1_perc (Optional[float]): Fraction of the samples used for stage 1. When None
          (default) the value given to the constructor is used.

        Returns:
        - KernelIVRegression: Trained model.

        Raises:
        - ValueError: if X, Y and Z do not have the same number of rows, or if the split leaves
          stage 1 or stage 2 without samples.
        """
        # Read out the model parameters.
        kernel_X = self.kernel_X
        kernel_Z = self.kernel_Z
        lambda_, xi = self.lambda_, self.xi
        if stage1_perc is None:
            stage1_perc = self.stage1_perc

        # Split data for Kernel IV stage 1 and stage 2 training.
        X, Z = XZ
        n_samples = X.shape[0]
        if Y.shape[0] != n_samples or Z.shape[0] != n_samples:
            raise ValueError(
                f"X, Y and Z must have the same number of samples, got "
                f"{X.shape[0]}, {Y.shape[0]} and {Z.shape[0]}."
            )

        stage1_data_size = int(n_samples * stage1_perc)
        stage2_data_size = n_samples - stage1_data_size
        if stage1_data_size < 1 or stage2_data_size < 1:
            raise ValueError(
                f"stage1_perc={stage1_perc} splits {n_samples} samples into "
                f"{stage1_data_size} (stage 1) and {stage2_data_size} (stage 2); "
                f"both stages need at least one sample."
            )

        train_indices = np.random.permutation(n_samples)
        stage1_idx, stage2_idx = train_indices[:stage1_data_size], train_indices[stage1_data_size:]

        # Keep one row per sample (reshape(-1, 1) would flatten multi-column inputs).
        X_train_stage1 = X[stage1_idx].reshape(stage1_data_size, -1)
        Y_train_stage1 = Y[stage1_idx].reshape(stage1_data_size, -1)
        Y_train_stage2 = Y[stage2_idx].reshape(stage2_data_size, -1)
        # This matrix is required for inference on new data. It is used to construct the kernel
        # matrix for test data points.
        self.X_train_stage1 = X_train_stage1

        # Construct the kernels. X represents the stage 1 X data matrix while XTilde corresponds to
        # the stage 2 data matrix. Similarly for Z and ZTilde.
        K_XTrainXTrain = kernel_X(X, X)
        K_ZTrainZTrain = kernel_Z(Z, Z)

        K_XX = self._kernel_block(K_XTrainXTrain, stage1_idx, stage1_idx)
        K_XTildeX = self._kernel_block(K_XTrainXTrain, stage2_idx, stage1_idx)
        K_XTildeXTilde = self._kernel_block(K_XTrainXTrain, stage2_idx, stage2_idx)
        K_ZZ = self._kernel_block(K_ZTrainZTrain, stage1_idx, stage1_idx)
        K_ZZTilde = self._kernel_block(K_ZTrainZTrain, stage1_idx, stage2_idx)

        # Switch the flag off on kernels that expose it, once the training kernels are built.
        # NOTE(review): presumably this stops the kernels from re-estimating their length scale
        # on test data in predict — confirm against the kernel implementation.
        if hasattr(self.kernel_X, 'use_length_scale_heuristic'):
            self.kernel_X.use_length_scale_heuristic = False
        if hasattr(self.kernel_Z, 'use_length_scale_heuristic'):
            self.kernel_Z.use_length_scale_heuristic = False

        if self.optimize_regularization_parameters:
            # Both parameters are optimized in log-space, lambda_ first since the xi objective
            # depends on it. 'BFGS' is the only method offered by jax.scipy.optimize.minimize.
            tol = self.regularization_param_optimization_tol
            log_lambda_res = jax_minimize(
                self.KIV_lambda_objective, jnp.array([jnp.log(lambda_)]),
                args=(K_XX, K_XTildeX, K_XTildeXTilde, K_ZZ, K_ZZTilde),
                method='BFGS', tol=tol,
            )
            lambda_ = jnp.exp(log_lambda_res.x.item())
            log_xi_res = jax_minimize(
                self.KIV_xi_objective, jnp.array([jnp.log(xi)]),
                args=(lambda_, Y_train_stage1, Y_train_stage2, K_XX, K_ZZ, K_ZZTilde),
                method='BFGS', tol=tol,
            )
            xi = jnp.exp(log_xi_res.x.item())
            self.lambda_, self.xi = lambda_, xi

        # Establish the estimator (Kernel IV estimator).
        # make_psd function is used for numerical stability.
        W = K_XX @ jnp.linalg.solve(
            make_psd(K_ZZ) + stage1_data_size * lambda_ * jnp.eye(stage1_data_size), K_ZZTilde
        )
        alpha = jnp.linalg.solve(
            make_psd(W @ W.T) + stage2_data_size * xi * make_psd(K_XX), W @ Y_train_stage2
        )
        self.W = W
        self.alpha = alpha

        return self

    def predict(self, XZ: Union[Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]) -> jnp.ndarray:
        """
        Predict outputs for new X values.

        Only the X values are required for prediction. To keep the argument consistent with
        ``fit``, the user may pass either a tuple (X, Z) (Z is ignored) or an array holding X only.

        Parameters:
        - XZ (Union[Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]): (X, Z) tuple or X array.

        Returns:
        - jnp.ndarray: Predicted values, one row per row of X.

        Raises:
        - TypeError: if XZ is neither a 2-tuple nor a (j)np.ndarray.
        - AttributeError: if the model has not been fitted yet.
        """
        if isinstance(XZ, tuple) and (len(XZ) == 2):
            X, _ = XZ
        elif isinstance(XZ, (np.ndarray, jnp.ndarray)):
            X = XZ
        else:
            raise TypeError("For inference, the parameter XZ must be either a tuple consisting both X and Z or a (j)np.ndarray consisting only X.")

        if not hasattr(self, 'alpha'):
            raise AttributeError("This KernelIVRegression instance is not fitted yet. Call 'fit' before 'predict'.")

        K_XTest = self.kernel_X(self.X_train_stage1, X)
        Y_pred = (self.alpha.T @ K_XTest).T
        return Y_pred

    def fit_predict(self, XZ: Union[Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray], Y: jnp.ndarray, stage1_perc: float = 0.5) -> jnp.ndarray:
        """
        Fit the model and make predictions on the same inputs.

        Parameters:
        - XZ (Union[Tuple[jnp.ndarray, jnp.ndarray], np.ndarray]): Input data XZ. Fitting unpacks
          it as (X, Z), so a tuple containing both X and Z is needed here.
        - Y (jnp.ndarray): Output data.
        - stage1_perc (float): Fraction of the samples used for stage 1. This argument takes
          precedence over the value given to the constructor.

        Returns:
        - jnp.ndarray: Predicted values.
        """
        self.fit(XZ, Y, stage1_perc)
        return self.predict(XZ)

