import sys
sys.path.append("../../../src")

import os
from time import time
import numpy as np
import pandas as pd
import jax.numpy as jnp

from utils.kernel_utils import Kernel, ColumnwiseRBF, RBF
from causal_models.causal_learning import KernelATT
from causal_models.doubly_robust_pcl import DoublyRobustKernelProxyATT
from causal_models.proxy_causal_learning import KernelAlternativeProxyATT, KernelProxyVariableATT
from utils.experimental_data_functions import generate_synthetic_ATE_data
from utils.ml_utils import data_transform

# Make sure the output directory exists before the experiment starts writing
# checkpoints into it. `exist_ok = True` turns this into a single idempotent
# call, avoiding the check-then-create pattern that can raise FileExistsError
# if the directory appears between the check and the creation.
os.makedirs("../../Results", exist_ok = True)

# Experiment grid: the main loop below visits every combination of
# (data size, seed, a_prime).
data_size_list = [500, 1000, 2000]
seed_list = np.arange(start = 0, stop = 3000, step = 100)  # 0, 100, ..., 2900
a_prime_list = [-1.0, -0.5, 0.25, 0.5, 1.0, 1.5]

# Empty results table; the main loop appends one row per algorithm for each
# grid point and pickles the table as it goes.
df_results = pd.DataFrame(
    columns = [
        "Algorithm",
        "Data_Size",
        "Seed",
        "a_prime",
        "ATT_Estimation",
        "MSE_dist_to_oracle",
    ]
)

def _as_2d(values, n_rows):
    """Return *values* as a JAX array reshaped to ``(n_rows, -1)``."""
    return jnp.array(values).reshape(n_rows, -1)


def _to_outcome_scale(y_transformer, prediction, n_rows):
    """Undo the outcome transform on *prediction*; the result has ``n_rows`` rows."""
    return y_transformer.inverse_transform(prediction.reshape(n_rows, -1)).reshape(n_rows, -1)


def _result_row(algorithm, data_size, seed, a_prime, estimate, mse_to_oracle):
    """Build one record whose keys match the columns of ``df_results``."""
    return {
        "Algorithm" : algorithm,
        "Data_Size" : data_size,
        "Seed" : seed,
        "a_prime" : a_prime,
        "ATT_Estimation" : estimate,
        "MSE_dist_to_oracle" : mse_to_oracle,
    }


def _rbf():
    """Return a new RBF kernel configured the way every model in this script uses it."""
    return RBF(use_length_scale_heuristic = True, use_jit_call = True)


def _treatment_bridge_params():
    """Return a fresh treatment-bridge parameter dict (new kernel objects on every call)."""
    return {
        "kernel_A" : _rbf(),
        "kernel_W" : ColumnwiseRBF(use_length_scale_heuristic = True, use_jit_call = True),
        "kernel_Z" : _rbf(),
        "kernel_X" : _rbf(),
        "lambda_" : 1e-3,
        "eta" : 1e-3,
        "lambda2_" : 1e-3,
        "zeta" : 1e-3, # Only required for ATT algorithm
        "optimize_lambda_parameters" : True,
        "optimize_eta_parameter" : True,
        "optimize_zeta_parameter" : True, # Only required for ATT algorithm
        "lambda_optimization_range" : (5e-5, 1.0),
        "eta_optimization_range" : (5e-5, 1.0),
        "zeta_optimization_range" : (5e-5, 1.0), # Only required for ATT algorithm
        "stage1_perc" : 0.5,
        "regularization_grid_points" : 25,
        "make_psd_eps" : 1e-9,
        "label_variance_in_lambda_opt" : 0.,
        "label_variance_in_eta_opt" : 1.0,
    }


def _outcome_bridge_kpv_params():
    """Return a fresh outcome-bridge (KPV) parameter dict (new kernel objects on every call)."""
    return {
        "algorithm_name" : "Kernel_Proxy_Variable",
        "kernel_A" : _rbf(),
        "kernel_W" : _rbf(),
        "kernel_Z" : _rbf(),
        "kernel_X" : _rbf(),
        "lambda1_" : 0.1,
        "lambda2_" : 0.1,
        "zeta" : 1e-3, # Only required for ATT algorithm
        "optimize_lambda1_parameter" : True,
        "optimize_lambda2_parameter" : True,
        "optimize_zeta_parameter" : True, # Only required for ATT algorithm
        "lambda1_optimization_range" : (5e-4, 1.0),
        "lambda2_optimization_range" : (5e-4, 1.0),
        "zeta_optimization_range" : (5e-5, 1.0), # Only required for ATT algorithm
        "stage1_perc" : 0.5,
        "regularization_grid_points" : 25,
        "make_psd_eps" : 1e-9,
    }


# Main experiment: for every (data size, seed, a_prime) grid point, fit the
# oracle KernelATT model and the doubly robust proxy model, record the four
# resulting estimates, and checkpoint the results table to disk.
for n_plus_m in data_size_list:
    for seed_ in seed_list:
        for a_prime in a_prime_list:
            # NOTE(review): as far as this file shows, the data and the fitted
            # transformers depend only on (n_plus_m, seed_). They are still
            # rebuilt for every a_prime so the sequence of RNG calls is left
            # untouched; confirm against generate_synthetic_ATE_data before
            # hoisting them out of this loop.
            np.random.seed(seed_)
            # EY_do_A is returned by the generator but not used in this script.
            U, W, Z, A, Y, do_A, EY_do_A = generate_synthetic_ATE_data(size = n_plus_m, seed = seed_)
            do_A = jnp.array(do_A)
            # Keep the loop variable a plain float; the array form is only
            # needed as model/transformer input.
            a_prime_arr = jnp.array([a_prime])

            A_transformed, A_transformer = data_transform(A)
            a_prime_transformed = A_transformer.transform(a_prime_arr.reshape(-1, 1))
            Z_transformed, _ = data_transform(Z)
            W_transformed, _ = data_transform(W)
            Y_transformed, Y_transformer = data_transform(Y)
            U_transformed, _ = data_transform(U)

            data_size = A_transformed.shape[0]
            A_transformed = _as_2d(A_transformed, data_size)
            Z_transformed = _as_2d(Z_transformed, data_size)
            W_transformed = _as_2d(W_transformed, data_size)
            Y_transformed = _as_2d(Y_transformed, data_size)
            U_transformed = _as_2d(U_transformed, data_size)

            #######################################################
            #        Oracle Kernel Method                         #
            #######################################################
            # The oracle is fitted on (A, U) directly, whereas the proxy
            # model below only receives (A, W, Z).
            do_A_size = do_A.shape[0]
            do_A_transformed = (A_transformer.transform(do_A)).reshape(do_A_size, -1)

            model_KernelATT = KernelATT(
                kernel_A = _rbf(),
                kernel_X = _rbf(),
                lambda_optimization_range = (5e-5, 1.0)
            )
            model_KernelATT.fit((A_transformed, U_transformed), Y_transformed)
            f_struct_pred_transformed_katt = model_KernelATT.predict(do_A_transformed, a_prime_transformed)
            f_struct_pred_katt = np.array(_to_outcome_scale(Y_transformer, f_struct_pred_transformed_katt, do_A_size))

            # The oracle row carries no distance-to-oracle value.
            oracle_row = _result_row("Kernel_ATT", n_plus_m, seed_, a_prime, f_struct_pred_katt, None)
            df_results = pd.concat([df_results, pd.DataFrame([oracle_row])], ignore_index=True)

            #######################################################
            #        Doubly Robust Kernel Proxy                   #
            #######################################################
            # Fresh parameter dicts (and therefore fresh kernel objects) are
            # built for every fit, as in the original per-iteration literals.
            model_DR = DoublyRobustKernelProxyATT(
                treatment_bridge_algo_param_dict = _treatment_bridge_params(),
                outcome_bridge_algo_param_dict = _outcome_bridge_kpv_params(),
                lambda_DR_optimization_range = (5e-5, 1.0),
            )

            model_DR.fit((A_transformed, W_transformed, Z_transformed), Y_transformed, a_prime_transformed)
            f_struct_pred_transformed = model_DR.predict(do_A_transformed)
            f_struct_pred = _to_outcome_scale(Y_transformer, f_struct_pred_transformed, do_A_size)

            # Predictions of the two bridge components, read from the fitted
            # model after predict() has been called.
            f_struct_pred_KAP = _to_outcome_scale(Y_transformer, model_DR.treatment_bridge_algo_pred, do_A_size)
            f_struct_pred_KPV = _to_outcome_scale(Y_transformer, model_DR.outcome_bridge_algo_pred, do_A_size)

            # Mean squared distance of each proxy-based estimate to the oracle estimate.
            dr_mse = np.mean((f_struct_pred_katt - f_struct_pred)**2)
            kpv_mse = np.mean((f_struct_pred_katt - f_struct_pred_KPV)**2)
            kap_mse = np.mean((f_struct_pred_katt - f_struct_pred_KAP)**2)

            proxy_rows = [
                _result_row("KPV", n_plus_m, seed_, a_prime, f_struct_pred_KPV, kpv_mse),
                _result_row("KAP", n_plus_m, seed_, a_prime, f_struct_pred_KAP, kap_mse),
                _result_row("DoublyRobust_KPV", n_plus_m, seed_, a_prime, f_struct_pred, dr_mse),
            ]
            # One concat for all three rows instead of three full-frame copies.
            df_results = pd.concat([df_results, pd.DataFrame(proxy_rows)], ignore_index=True)

            # Checkpoint after every grid point so partial results survive an
            # interrupted run.
            df_results.to_pickle("../../Results/DoublyRobustKPV_ATT_SyntheticData_Comparison.pkl")

















