import sys
sys.path.append("../../../src")

import os
from time import time
import numpy as np
import pandas as pd
import jax.numpy as jnp
from tqdm import tqdm

from utils.kernel_utils import ColumnwiseRBF, RBF
from causal_models.proxy_causal_learning import KernelProxyVariableATE
from utils.ml_utils import data_transform
from utils.experimental_data_functions import read_deaner_dataset

# Make sure the results directory exists before any checkpoint is written.
# os.makedirs(exist_ok=True) avoids the check-then-create race of
# exists()/mkdir() and also creates any missing parent directories.
os.makedirs("../../Results", exist_ok=True)

# Identifier of the Deaner dataset variant evaluated here, and where it lives.
id_ = "IM"
data_path = '../../../data/deaner'

# Seed grids: ten data-generation seeds (100, 200, ..., 1000) and three
# model seeds (0, 100, 200); every combination is run once.
data_seed_list = np.arange(100, 1001, 100)
seed_list = 100 * np.arange(3)

# Column layout of the results table written to disk.
result_columns = ["Algorithm", "Data_Size", "Seed", "Data_Seed", "Causal_MSE", "Causal_MAE", "Algo_Run_Time"]
df_results = pd.DataFrame(columns = result_columns)
# One dict per (data_seed_, seed_) run. The table is rebuilt from these
# records instead of being grown with pd.concat onto an initially empty frame:
# that pattern is quadratic and, in pandas >= 2.1, raises a FutureWarning
# because empty frames will stop being ignored when inferring column dtypes.
kpv_records = []

# Only actions strictly inside (a_min, a_max) are kept, both in the observed
# data and in the intervention grid.
a_min, a_max = -0.1, 2.0

# KPV hyperparameters; they are the same for every run, so bind them once.
lambda1_ = 0.01
lambda2_ = 1.2*1e-1
optimize_lambda1_parameter = True
optimize_lambda2_parameter = True
lambda1_optimization_range = (5*1e-5, 1.0)
lambda2_optimization_range = (5*1e-5, 1.0)
stage1_perc = 0.5
regularization_grid_points = 25
make_psd_eps = 5*1e-9

for data_seed_ in data_seed_list:
    for seed_ in tqdm(seed_list):

        np.random.seed(seed_)

        # Load one realisation of the dataset and convert every array to JAX.
        W, Z, A, Y, do_A, EY_do_A = read_deaner_dataset(data_path, id_, data_seed_)
        W, Z, A, Y, do_A, EY_do_A = jnp.array(W), jnp.array(Z), jnp.array(A), jnp.array(Y), jnp.array(do_A), jnp.array(EY_do_A)

        # Drop observations / intervention points whose action is outside (a_min, a_max).
        filtered_indices = ((A > a_min) & (A < a_max)).reshape(-1)
        doA_filtered_indices = ((do_A > a_min) & (do_A < a_max)).reshape(-1)
        W, Z, A, Y = W[filtered_indices], Z[filtered_indices], A[filtered_indices], Y[filtered_indices]
        do_A, EY_do_A = do_A[doA_filtered_indices], EY_do_A[doA_filtered_indices]

        # Rescale each variable with data_transform; the A and Y transformers are
        # reused below to map the intervention grid in and the predictions back out.
        A_transformed, A_transformer = data_transform(A)
        Z_transformed, Z_transformer = data_transform(Z)
        W_transformed, W_transformer = data_transform(W)
        Y_transformed, Y_transformer = data_transform(Y)

        # Reshape every training input to 2-D: (data_size, n_features).
        data_size = A_transformed.shape[0]
        A_transformed = jnp.array(A_transformed).reshape(data_size, -1)
        Z_transformed = jnp.array(Z_transformed).reshape(data_size, -1)
        W_transformed = jnp.array(W_transformed).reshape(data_size, -1)
        Y_transformed = jnp.array(Y_transformed).reshape(data_size, -1)

        # Timed region: kernel construction, model fitting and prediction.
        t0 = time()
        RBF_Kernel_Z = RBF(use_length_scale_heuristic = True, use_jit_call = True)
        RBF_Kernel_A = RBF(use_length_scale_heuristic = True, use_jit_call = True)
        RBF_Kernel_W = RBF(use_length_scale_heuristic = True, use_jit_call = True)

        model = KernelProxyVariableATE(
                                            kernel_A = RBF_Kernel_A,
                                            kernel_W = RBF_Kernel_W,
                                            kernel_Z = RBF_Kernel_Z,
                                            lambda1_ = lambda1_,
                                            lambda2_ = lambda2_,
                                            optimize_lambda1_parameter = optimize_lambda1_parameter,
                                            optimize_lambda2_parameter = optimize_lambda2_parameter,
                                            lambda1_optimization_range = lambda1_optimization_range,
                                            lambda2_optimization_range = lambda2_optimization_range,
                                            stage1_perc = stage1_perc,
                                            regularization_grid_points = regularization_grid_points,
                                            make_psd_eps = make_psd_eps,
                                            )

        model.fit((A_transformed, W_transformed, Z_transformed,), Y_transformed)
        # Predict on the intervention grid in the transformed A space, then map
        # the predictions back to the original Y scale.
        do_A_size = do_A.shape[0]
        do_A_transformed = (A_transformer.transform(do_A.reshape(-1, 1))).reshape(do_A_size, -1)
        f_struct_pred_transformed = model.predict(do_A_transformed)
        f_struct_pred = Y_transformer.inverse_transform(f_struct_pred_transformed.reshape(do_A_size, -1)).reshape(do_A_size, -1)
        t1 = time()
        algo_run_time = t1 - t0

        # Error of the predicted curve against the reference EY_do_A values from the loader.
        causal_error = f_struct_pred.reshape(-1, 1) - EY_do_A.reshape(-1, 1)
        structured_pred_mse = (np.mean(causal_error ** 2)).item()
        structured_pred_mae = (np.mean(np.abs(causal_error))).item()

        kpv_records.append({
            "Algorithm" : "KPV",
            "Data_Size" : W.shape[0],
            "Seed" : seed_,
            "Data_Seed" : data_seed_,
            "Causal_MSE" : structured_pred_mse,
            "Causal_MAE" : structured_pred_mae,
            "Algo_Run_Time" : algo_run_time
        })

        # Checkpoint after every run so partial results survive an interruption;
        # the file name follows id_ so other dataset variants do not overwrite it.
        df_results = pd.DataFrame(kpv_records, columns = result_columns)
        df_results.to_pickle(f"../../Results/KPV_Deaner_{id_}.pkl")