import sys
sys.path.append("../../../src")

import os
from time import time
import numpy as np
import pandas as pd
import jax.numpy as jnp
from tqdm import tqdm

from utils.kernel_utils import RBF, ColumnwiseRBF
from causal_models.doubly_robust_pcl import DoublyRobustKernelProxyATE
from utils.experimental_data_functions import generate_synthetic_ATE_data
from utils.ml_utils import data_transform

# Ensure the results directory exists (path is relative to the current working
# directory). exist_ok=True avoids the check-then-create race of
# os.path.exists + os.mkdir when several experiment scripts start concurrently.
os.makedirs("../../Results", exist_ok=True)

# Sweep configuration: a single data size (1000), 30 seeds spaced 100 apart
# (0, 100, ..., 2900), and only the scaled-data variant.
data_size_list = [1000]
seed_list = np.arange(0, 3000, 100)
scale_data_list = [True]

# Column layout of the results table; one row is added per run.
result_columns = [
    "Algorithm", "Data_Size", "Seed",
    "DR_Causal_MSE", "Treatment_Causal_MSE", "Outcome_Causal_MSE",
    "DR_Prediction", "Treatment_Prediction", "Outcome_Prediction", "Slack_Prediction",
    "scale_data", "Algo_Run_Time",
]
df_results = pd.DataFrame(columns=result_columns)

# Fixed perturbation that is later added to the treatment bridge's fitted
# coefficients. It is drawn once, from seed 0, so every run sees the same jitter.
# NOTE(review): the 501 rows must match the shape of
# treatment_bridge_algo.alpha — confirm against the model's stage split.
np.random.seed(0)
random_jitter = np.random.randn(501, 1) * 0.2

def _to_original_scale(values, transformer, n_rows):
    """Map predictions back through ``transformer.inverse_transform``.

    ``values`` is reshaped to ``(n_rows, -1)`` before the inverse transform and
    the result is reshaped to ``(n_rows, -1)`` again, so the output is 2-D.
    """
    return transformer.inverse_transform(values.reshape(n_rows, -1)).reshape(n_rows, -1)


# Output file for the accumulated results. It is rewritten after every run, so
# completed runs survive a crash part-way through the sweep.
results_path = "../../Results/DoublyRobustPMMR_TreatmentBridgeMisspecified_SyntheticLowDim_Experiment_V1.pkl"

# Rows are accumulated in a plain list and the frame is rebuilt from it, rather
# than pd.concat-ing onto an initially empty DataFrame. Concatenating with an
# empty frame is deprecated in recent pandas (FutureWarning) and forces every
# column to object dtype; building from records keeps numeric columns numeric.
results_columns = list(df_results.columns)
results_rows = df_results.to_dict("records")

for n_plus_m in data_size_list:
    for seed_ in tqdm(seed_list):
        for scale_data in scale_data_list:
            np.random.seed(seed_)

            # The first returned array (U) is not used by the estimator here.
            _, W, Z, A, Y, do_A, EY_do_A = generate_synthetic_ATE_data(size = n_plus_m, seed = seed_)
            W, Z, A, Y, do_A, EY_do_A = jnp.array(W), jnp.array(Z), jnp.array(A), jnp.array(Y), jnp.array(do_A), jnp.array(EY_do_A)
            if scale_data:
                A_transformed, A_transformer = data_transform(A)
                Z_transformed, Z_transformer = data_transform(Z)
                W_transformed, W_transformer = data_transform(W)
                Y_transformed, Y_transformer = data_transform(Y)

                # Force every transformed variable to a 2-D (n_samples, n_features) layout.
                data_size = A_transformed.shape[0]
                A_transformed = jnp.array(A_transformed).reshape(data_size, -1)
                Z_transformed = jnp.array(Z_transformed).reshape(data_size, -1)
                W_transformed = jnp.array(W_transformed).reshape(data_size, -1)
                Y_transformed = jnp.array(Y_transformed).reshape(data_size, -1)

            # Timed section: model construction, fitting, perturbation and prediction.
            t0 = time()
            # Kernels are constructed fresh for every run rather than shared across runs.
            # This synthetic setting has no covariates X, so no "kernel_X" is configured.
            treatment_bridge_algo_param_dict_default = {
                                                        "kernel_A" : RBF(use_length_scale_heuristic = True, use_jit_call = True),
                                                        "kernel_W" : ColumnwiseRBF(use_length_scale_heuristic = True, use_jit_call = True),
                                                        "kernel_Z" : RBF(use_length_scale_heuristic = True, use_jit_call = True),
                                                        "lambda_" : 1e-3,
                                                        "eta" : 1e-3,
                                                        "lambda2_" : 1e-3,
                                                        "optimize_lambda_parameters" : True,
                                                        "optimize_eta_parameter" : True,
                                                        "lambda_optimization_range" : (5*1e-5, 1.0),
                                                        "eta_optimization_range" : (5*1e-5, 1.0),
                                                        "stage1_perc" : 0.5,
                                                        "regularization_grid_points" : 25,
                                                        "make_psd_eps" : 5e-9,
                                                        "label_variance_in_lambda_opt" : 0.,
                                                        "label_variance_in_eta_opt" : 1.0,
                                                        }
            outcome_bridge_pmmr_algo_param_dict_default = {
                                                        "algorithm_name" : "Proxy_Maximum_Moment_Restriction",
                                                        "kernel_A" : RBF(use_length_scale_heuristic = True, use_jit_call = True),
                                                        "kernel_W" : RBF(use_length_scale_heuristic = True, use_jit_call = True),
                                                        "kernel_Z" : RBF(use_length_scale_heuristic = True, use_jit_call = True),
                                                        "lambda_" : 1e-3,
                                                        "optimize_lambda_parameter" : True,
                                                        "validation_percentage" : 0.1,
                                                        "lambda_optimization_range" : (5*1e-5, 1e-3),
                                                        "regularization_grid_points" : 25,
                                                        "make_psd_eps" : 5e-9,
                                                        }

            model_DR = DoublyRobustKernelProxyATE(  treatment_bridge_algo_param_dict = treatment_bridge_algo_param_dict_default,
                                                    outcome_bridge_algo_param_dict = outcome_bridge_pmmr_algo_param_dict_default,
                                                    lambda_DR = 1*1e-3,
                                                    optimize_lambda_DR_parameter = True,
                                                    lambda_DR_optimization_range = (5*1e-5, 1.0),
                                                    regularization_grid_points = 25,
                                                    )
            if scale_data:
                model_DR.fit((A_transformed, W_transformed, Z_transformed), Y_transformed)
            else:
                model_DR.fit((A, W, Z), Y)

            # Deliberately misspecify the treatment bridge by adding the fixed
            # jitter to its fitted coefficients. The shapes must match exactly:
            # a mismatched but broadcast-compatible alpha (e.g. 1-D) would
            # silently turn into a matrix instead of being perturbed element-wise.
            alpha = model_DR.treatment_bridge_algo.alpha
            if jnp.shape(alpha) != random_jitter.shape:
                raise ValueError(
                    f"treatment_bridge_algo.alpha has shape {jnp.shape(alpha)} but "
                    f"random_jitter has shape {random_jitter.shape}; they must match"
                )
            model_DR.treatment_bridge_algo.alpha = jnp.array(alpha + random_jitter)

            if scale_data:
                do_A_size = do_A.shape[0]
                do_A_transformed = (A_transformer.transform(do_A)).reshape(do_A_size, -1)
                # Predictions come out on the transformed outcome scale; map them back.
                f_struct_pred = _to_original_scale(model_DR.predict(do_A_transformed), Y_transformer, do_A_size)
                f_struct_pred_kap, f_struct_pred_pmmr = model_DR.treatment_bridge_algo_pred, model_DR.outcome_bridge_algo_pred
                slack_prediction = _to_original_scale(model_DR.slack_prediction, Y_transformer, do_A_size)
                f_struct_pred_kap = _to_original_scale(f_struct_pred_kap, Y_transformer, do_A_size)
                f_struct_pred_pmmr = _to_original_scale(f_struct_pred_pmmr, Y_transformer, do_A_size)
            else:
                f_struct_pred = model_DR.predict(do_A)
                f_struct_pred_kap, f_struct_pred_pmmr = model_DR.treatment_bridge_algo_pred, model_DR.outcome_bridge_algo_pred
                slack_prediction = model_DR.slack_prediction

            t1 = time()
            algo_run_time = t1 - t0

            # MSE of each estimator against the ground-truth EY_do_A.
            EY_do_A_col = EY_do_A.reshape(-1, 1)
            structured_pred_mse = jnp.mean((f_struct_pred.reshape(-1, 1) - EY_do_A_col) ** 2).item()
            kap_mse = jnp.mean((EY_do_A_col - f_struct_pred_kap.reshape(-1, 1)) ** 2).item()
            pmmr_mse = jnp.mean((EY_do_A_col - f_struct_pred_pmmr.reshape(-1, 1)) ** 2).item()

            DRPMMR_Dict = {
                "Algorithm" : "DRKPV",
                "Data_Size" : n_plus_m,
                "Seed" : seed_,
                "DR_Causal_MSE" : structured_pred_mse,
                "Treatment_Causal_MSE" : kap_mse,
                "Outcome_Causal_MSE" : pmmr_mse,
                "DR_Prediction" : np.array(f_struct_pred).reshape(-1, 1),
                "Treatment_Prediction" : np.array(f_struct_pred_kap).reshape(-1, 1),
                "Outcome_Prediction" : np.array(f_struct_pred_pmmr).reshape(-1, 1),
                "Slack_Prediction" : np.array(slack_prediction).reshape(-1, 1),
                "scale_data" : scale_data,
                "Algo_Run_Time" : algo_run_time
            }

            results_rows.append(DRPMMR_Dict)
            df_results = pd.DataFrame(results_rows, columns=results_columns)

            # Checkpoint after every run.
            df_results.to_pickle(results_path)