import sys
sys.path.append("../../../src")

import os
from time import time
import numpy as np
import pandas as pd
import jax.numpy as jnp

from utils.kernel_utils import Kernel, BinaryKernel, ColumnwiseRBF, RBF
from causal_models.causal_learning import KernelCATE
from causal_models.doubly_robust_pcl import DoublyRobustKernelProxyCATE
from utils.experimental_data_functions import generate_synthetic_CATE_data
from utils.ml_utils import data_transform

# Make sure the output directory exists before any result is written.
# ``exist_ok=True`` avoids the check-then-create race of
# ``os.path.exists`` + ``os.mkdir`` (e.g. several runs started in parallel)
# and ``makedirs`` also creates missing parent directories.
os.makedirs("../../Results", exist_ok=True)

# Data sizes to sweep over, and the random seeds repeated for each of them.
data_size_list = [500, 1000, 2000]
seed_list = 100 * np.arange(30)  # 0, 100, ..., 2900

# Empty results table; the experiment loop appends one row per algorithm and run.
result_columns = ["Algorithm", "Data_Size", "Seed", "CATE_Estimation", "MSE"]
df_results = pd.DataFrame(columns=result_columns)

# ---------------------------------------------------------------------------
# Settings that do not depend on the data size or the seed. They are defined
# once here instead of being re-assigned on every loop iteration.
# ---------------------------------------------------------------------------
sigma = 1.
# NOTE: these used to be written as ``1.,`` / ``-1.,``; the trailing comma
# silently turned them into 1-tuples instead of floats.
uniform_noise_upper_bound = 1.
uniform_noise_lower_bound = -1.

# Regularisation settings of the KernelCATE baseline.
lambda_ = 1e-3
lambda2 = 1e-3
optimize_regularization_parameters = True
lambda_optimization_range = (1e-5, 1.0)
regularization_grid_points = 25


def _mse(prediction, target):
    """Return the mean squared error between ``prediction`` and ``target``.

    Both inputs are reshaped to column vectors first, so the subtraction is
    element-wise regardless of the layout in which they arrive.
    """
    return np.mean((prediction.reshape(-1, 1) - target.reshape(-1, 1)) ** 2)


def _result_row(algorithm, data_size, seed, estimate, target):
    """Build one results-table row for ``algorithm``.

    ``estimate`` is stored as-is under ``"CATE_Estimation"`` and its mean
    squared error against ``target`` is stored under ``"MSE"``.
    """
    return {
        "Algorithm": algorithm,
        "Data_Size": data_size,
        "Seed": seed,
        "CATE_Estimation": estimate,
        "MSE": _mse(estimate, target),
    }


for n_plus_m in data_size_list:
    for seed_ in seed_list:
        #######################################################
        #        Data Generation                              #
        #######################################################
        # NOTE(review): the sample size was previously hard-coded to 2000, so
        # every "Data_Size" setting was evaluated on identical data. This
        # assumes the first positional argument of
        # generate_synthetic_CATE_data is the number of samples -- confirm.
        U, W, Z, V, A, Y, covariate_v_test, do_A, EY_do_A_CATE = generate_synthetic_CATE_data(
            n_plus_m,
            sigma,
            uniform_noise_upper_bound,
            uniform_noise_lower_bound,
            seed=seed_,
        )

        U, Z, W, V, A, Y = jnp.array(U), jnp.array(Z), jnp.array(W), jnp.array(V), jnp.array(A), jnp.array(Y)
        covariate_v_test, do_A, EY_do_A_CATE = jnp.array(covariate_v_test), jnp.array(do_A), jnp.array(EY_do_A_CATE)

        # Only the V and Y transformers are needed later (to transform the
        # test covariates and to map predictions back to the original scale).
        # A_transformed only supplies the row count: both models below are
        # fit on the raw treatment A.
        A_transformed, _ = data_transform(A)
        U_transformed, _ = data_transform(U)
        Z_transformed, _ = data_transform(Z)
        W_transformed, _ = data_transform(W)
        V_transformed, V_transformer = data_transform(V)
        Y_transformed, Y_transformer = data_transform(Y)

        # Force every transformed variable into a 2-D (data_size, -1) layout.
        data_size = A_transformed.shape[0]
        A_transformed = jnp.array(A_transformed).reshape(data_size, -1)
        U_transformed = jnp.array(U_transformed).reshape(data_size, -1)
        Z_transformed = jnp.array(Z_transformed).reshape(data_size, -1)
        W_transformed = jnp.array(W_transformed).reshape(data_size, -1)
        V_transformed = jnp.array(V_transformed).reshape(data_size, -1)
        Y_transformed = jnp.array(Y_transformed).reshape(data_size, -1)

        # Test covariates go through the transformer returned for V.
        do_A_size = do_A.shape[0]
        covariate_v_transformed = (V_transformer.transform(covariate_v_test)).reshape(do_A_size, -1)

        #######################################################
        #        Oracle Kernel Method                         #
        #######################################################
        # This baseline is fit on (A, V, U), i.e. it is given U directly.
        RBF_Kernel_X = RBF(use_length_scale_heuristic = True, use_jit_call = True)
        RBF_Kernel_A = BinaryKernel()
        RBF_Kernel_V = RBF(use_length_scale_heuristic = True, use_jit_call = True)

        model_KCATE = KernelCATE(
            kernel_A = RBF_Kernel_A,
            kernel_V = RBF_Kernel_V,
            kernel_X = RBF_Kernel_X,
            lambda_ = lambda_,
            lambda2 = lambda2,
            optimize_regularization_parameters = optimize_regularization_parameters,
            lambda_optimization_range = lambda_optimization_range,
            regularization_grid_points = regularization_grid_points,
        )

        model_KCATE.fit((A, V_transformed, U_transformed), Y_transformed)

        # Predict in the transformed outcome space, then map back through
        # the Y transformer before comparing with EY_do_A_CATE.
        f_struct_pred_transformed_kcate = model_KCATE.predict(do_A, covariate_v_transformed)
        f_struct_pred_kcate = Y_transformer.inverse_transform(
            f_struct_pred_transformed_kcate.reshape(do_A_size, -1)
        ).reshape(do_A_size, -1)

        #######################################################
        #        Doubly Robust KPV                            #
        #######################################################
        treatment_bridge_algo_param_dict_default = {
            "kernel_A" : BinaryKernel(length_scale = 0.5, use_length_scale_heuristic = False, use_jit_call = True),
            "kernel_W" : RBF(use_length_scale_heuristic = True, use_jit_call = True),
            "kernel_Z" : RBF(use_length_scale_heuristic = True, use_jit_call = True),
            "kernel_V" : RBF(use_length_scale_heuristic = True, use_jit_call = True), # Only required for CATE algorithm
            "kernel_X" : RBF(use_length_scale_heuristic = True, use_jit_call = True),
            "lambda_" : 1e-3,
            "eta" : 1e-3,
            "lambda2_" : 1e-3,
            "zeta" : 1e-3, # Only required for ATT or CATE algorithm
            "optimize_lambda_parameters" : True,
            "optimize_eta_parameter" : True,
            "optimize_zeta_parameter" : True, # Only required for ATT or CATE algorithm
            "lambda_optimization_range" : (5*1e-5, 1.0),
            "eta_optimization_range" : (5e-5, 1.0),
            "zeta_optimization_range" : (5e-5, 1.0), # Only required for ATT or CATE algorithm
            "stage1_perc" : 0.5,
            "regularization_grid_points" : 25,
            "make_psd_eps" : 1e-9,
            "label_variance_in_lambda_opt" : 0.,
            "label_variance_in_eta_opt" : 1.0,
        }

        outcome_bridge_kpv_algo_param_dict_default = {
            "algorithm_name" : "Kernel_Proxy_Variable",
            "kernel_A" : BinaryKernel(length_scale = 0.5, use_length_scale_heuristic = False, use_jit_call = True),
            "kernel_W" : ColumnwiseRBF(use_length_scale_heuristic = True, use_jit_call = True),
            "kernel_Z" : RBF(use_length_scale_heuristic = True, use_jit_call = True),
            "kernel_V" : RBF(use_length_scale_heuristic = True, use_jit_call = True), # Only required for CATE algorithm
            "kernel_X" : RBF(use_length_scale_heuristic = True, use_jit_call = True),
            "lambda1_" : 0.1,
            "lambda2_" : 0.1,
            "zeta" : 1e-3, # Only required for ATT or CATE algorithm
            "optimize_lambda1_parameter" : True,
            "optimize_lambda2_parameter" : True,
            "optimize_zeta_parameter" : True, # Only required for ATT or CATE algorithm
            "lambda1_optimization_range" : (5e-4, 1.0),
            "lambda2_optimization_range" : (5e-4, 1.0),
            "zeta_optimization_range" : (5e-5, 1.0), # Only required for ATT or CATE algorithm
            "stage1_perc" : 0.5,
            "regularization_grid_points" : 25,
            "make_psd_eps" : 1e-9,
        }

        model_DR = DoublyRobustKernelProxyCATE(
            treatment_bridge_algo_param_dict = treatment_bridge_algo_param_dict_default,
            outcome_bridge_algo_param_dict = outcome_bridge_kpv_algo_param_dict_default,
            lambda_DR_optimization_range = (5e-5, 1.0),
            optimize_lambda_DR_parameter = True,
            lambda_DR = 1e-7,
        )

        model_DR.fit((A, W_transformed, Z_transformed, V_transformed), Y_transformed)

        # The doubly robust prediction comes from predict(); the treatment
        # bridge ("KAP") and outcome bridge ("KPV") predictions are read from
        # attributes of the fitted model afterwards.
        f_struct_pred_transformed = model_DR.predict(do_A, covariate_v_transformed)
        f_struct_pred_dr = Y_transformer.inverse_transform(f_struct_pred_transformed.reshape(-1, 1))

        f_struct_pred_KAP = Y_transformer.inverse_transform(
            model_DR.treatment_bridge_algo_pred.reshape(do_A_size, -1)
        ).reshape(do_A_size, -1)
        f_struct_pred_KPV = Y_transformer.inverse_transform(
            model_DR.outcome_bridge_algo_pred.reshape(do_A_size, -1)
        ).reshape(do_A_size, -1)

        #######################################################
        #        Collect and checkpoint results               #
        #######################################################
        # Row order matches the original script: Kernel_CATE, KPV, KAP, DR.
        new_rows = pd.DataFrame([
            _result_row("Kernel_CATE", n_plus_m, seed_, f_struct_pred_kcate, EY_do_A_CATE),
            _result_row("KPV", n_plus_m, seed_, f_struct_pred_KPV, EY_do_A_CATE),
            _result_row("KAP", n_plus_m, seed_, f_struct_pred_KAP, EY_do_A_CATE),
            _result_row("DoublyRobust_KPV", n_plus_m, seed_, f_struct_pred_dr, EY_do_A_CATE),
        ])
        df_results = pd.concat([df_results, new_rows], ignore_index=True)

        # Written after every run so partial results survive an interrupted sweep.
        df_results.to_pickle("../../Results/DoublyRobustKPV_CATE_SyntheticData_Comparison.pkl")