import sys
sys.path.append("../../../src")

import os
from time import perf_counter, time
import numpy as np
import pandas as pd
import jax.numpy as jnp

from utils.kernel_utils import RBF, ColumnwiseRBF
from causal_models.doubly_robust_pcl import DoublyRobustKernelProxyATE
from utils.experimental_data_functions import dSprite_ProxyVariable_DatasetV2
from utils.ml_utils import data_transform

# Run the doubly robust kernel proxy-variable (DRKPV) ATE estimator on the
# dSprite proxy-variable benchmark over a grid of sample sizes and seeds, and
# record the causal MSE/MAE against EY_do_A together with the run time.
# All paths are relative to the current working directory.

RESULTS_DIR = "../../Results"
RESULTS_FILE = os.path.join(RESULTS_DIR, "DoublyRobustKPV_Original_RunTime_dSprite_Experiment.pkl")

# exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
os.makedirs(RESULTS_DIR, exist_ok=True)

data_path = '../../../data/dsprite'
data_size_list = [1200, 2500, 5000, 7500, 10000, 15000]
seed_list = np.arange(0, 600, 100)

RESULT_COLUMNS = ["Algorithm", "Data_Size", "Seed", "Causal_MSE", "Causal_MAE", "Algo_Run_Time"]

# One dict per run; the DataFrame is rebuilt from this list. Repeatedly
# pd.concat-ing onto an empty, column-only DataFrame triggers a pandas
# FutureWarning (2.1+) and makes the result dtypes depend on deprecated
# behavior.
results_rows = []
df_results = pd.DataFrame(columns=RESULT_COLUMNS)


def _rbf_kernel():
    """Return a new RBF kernel with the length-scale heuristic and JIT calls enabled.

    A fresh instance is created on every call so that no two kernel slots
    share one object.
    """
    return RBF(use_length_scale_heuristic=True, use_jit_call=True)


def _as_2d_jax(array, n_rows):
    """Convert `array` to a JAX array reshaped to (n_rows, -1)."""
    return jnp.array(array).reshape(n_rows, -1)


for n_plus_m in data_size_list:
    for seed_ in seed_list:
        print("Data size", n_plus_m, "Seed", seed_)
        np.random.seed(seed_)

        dsprite_data_generator = dSprite_ProxyVariable_DatasetV2()
        A, Y, Z, W, do_A, EY_do_A = dsprite_data_generator.generate_dsprite_pv(
            data_path, n_sample=n_plus_m, generate_test=True, rand_seed=seed_
        )

        # Only the A and Y transformers are used later: A's to map do_A into
        # the model's input space, Y's to map predictions back.
        A_transformed, A_transformer = data_transform(A)
        Z_transformed, _ = data_transform(Z)
        W_transformed, _ = data_transform(W)
        Y_transformed, Y_transformer = data_transform(Y)

        data_size = A_transformed.shape[0]
        A_transformed = _as_2d_jax(A_transformed, data_size)
        Z_transformed = _as_2d_jax(Z_transformed, data_size)
        W_transformed = _as_2d_jax(W_transformed, data_size)
        Y_transformed = _as_2d_jax(Y_transformed, data_size)

        # Timed span: building the parameter dicts and model, fitting,
        # predicting on do_A and inverse-transforming the predictions.
        t0 = perf_counter()
        treatment_bridge_algo_param_dict_default = {
            "kernel_A": _rbf_kernel(),
            "kernel_W": _rbf_kernel(),
            "kernel_Z": _rbf_kernel(),
            "kernel_V": _rbf_kernel(),  # Only required for CATE algorithm
            "kernel_X": _rbf_kernel(),
            "lambda_": 5e-4,
            "eta": 5e-4,
            "lambda2_": 5e-4,
            "zeta": 1e-3,  # Only required for ATT or CATE algorithm
            "optimize_lambda_parameters": False,
            "optimize_eta_parameter": False,
            "stage1_perc": 0.5,
            "model_seed": seed_,
            "make_psd_eps": 1e-9,
        }

        outcome_bridge_kpv_algo_param_dict_default = {
            "algorithm_name": "Kernel_Proxy_Variable",
            "kernel_A": _rbf_kernel(),
            "kernel_W": _rbf_kernel(),
            "kernel_Z": _rbf_kernel(),
            "kernel_V": _rbf_kernel(),  # Only required for CATE algorithm
            "kernel_X": _rbf_kernel(),
            "lambda1_": 5e-4,
            "lambda2_": 5e-4,
            "zeta": 1e-3,  # Only required for ATT or CATE algorithm
            "optimize_lambda1_parameter": False,
            "optimize_lambda2_parameter": False,
            "stage1_perc": 0.5,
            "model_seed": seed_,
            "make_psd_eps": 1e-9,
        }

        model_DR = DoublyRobustKernelProxyATE(
            treatment_bridge_algo_param_dict=treatment_bridge_algo_param_dict_default,
            outcome_bridge_algo_param_dict=outcome_bridge_kpv_algo_param_dict_default,
            lambda_DR=5e-4,
            optimize_lambda_DR_parameter=False,
        )

        model_DR.fit((A_transformed, W_transformed, Z_transformed), Y_transformed)
        do_A_size = do_A.shape[0]
        do_A_transformed = (A_transformer.transform(do_A)).reshape(do_A_size, -1)
        f_struct_pred_transformed = model_DR.predict(do_A_transformed)
        f_struct_pred = Y_transformer.inverse_transform(
            f_struct_pred_transformed.reshape(do_A_size, -1)
        ).reshape(do_A_size, -1)

        t1 = perf_counter()
        algo_run_time = t1 - t0

        # Errors are computed in NumPy. With JAX's default configuration
        # (x64 disabled), jnp would silently downcast float64 inputs to float32.
        residual = np.asarray(f_struct_pred).reshape(-1, 1) - np.asarray(EY_do_A).reshape(-1, 1)
        structured_pred_mse = float(np.mean(residual ** 2))
        structured_pred_mae = float(np.mean(np.abs(residual)))

        results_rows.append({
            "Algorithm": "DRKPV",
            "Data_Size": n_plus_m,
            "Seed": seed_,
            "Causal_MSE": structured_pred_mse,
            "Causal_MAE": structured_pred_mae,
            "Algo_Run_Time": algo_run_time,
        })
        df_results = pd.DataFrame(results_rows, columns=RESULT_COLUMNS)

        # Checkpoint after every run so partial results survive an interrupted sweep.
        df_results.to_pickle(RESULTS_FILE)