import sys
import os
import numpy as np
import matplotlib.pyplot as plt
import jax.numpy as jnp
import pandas as pd
sys.path.append("../../../src")

from utils.kernel_utils import Kernel, ColumnwiseRBF, RBF
from causal_models.causal_learning import KernelATT
from causal_models.doubly_robust_pcl import DoublyRobustKernelProxyATT
from causal_models.proxy_causal_learning import KernelAlternativeProxyATT, KernelProxyVariableATT
from utils.ml_utils import data_transform
from utils.linalg_utils import cartesian_product, make_psd
from utils.experimental_data_functions import generate_train_jobcorp_data

# Make sure the output directory for the pickled results exists.
# ``exist_ok`` avoids the check-then-create race of os.path.exists + os.mkdir,
# and makedirs also creates any missing parent directories.
os.makedirs("../../Results", exist_ok = True)

def Lambda(t, low = 0.1, high = 0.9):
    """Map ``t`` element-wise through a logistic sigmoid rescaled to ``(low, high)``.

    Computes ``low + (high - low) * sigmoid(t)`` via the identity
    ``sigmoid(t) = (1 + tanh(t / 2)) / 2``. The tanh form is used instead of
    ``exp(t) / (1 + exp(t))`` because the latter overflows to ``inf / inf = nan``
    for large positive ``t``; ``tanh`` saturates cleanly at +/-1.

    Args:
        t: scalar or array of real values.
        low: lower asymptote of the output (default 0.1).
        high: upper asymptote of the output (default 0.9).

    Returns:
        Values in ``[low, high]`` with the same shape as ``t``.
    """
    return low + (high - low) * 0.5 * (1.0 + np.tanh(0.5 * t))

# Experiment grid: random seeds 0, 100, ..., 400 and three values of a_prime.
seed_list = np.arange(0, 500, step = 100)
a_prime_list = list(range(500, 2000, 500))

# Results table: one row per (algorithm, seed, a_prime) combination.
df_results = pd.DataFrame(
    columns = [
        "Algorithm",
        "Data_Size",
        "Seed",
        "a_prime",
        "ATT_Estimation",
        "MSE_dist_to_oracle",
    ],
)

def generate_misspecified_job_corps_data(data_path = '../../../data/JCdata.csv',
                                         z_noise_scale = 1.,
                                         w_noise_scale = 1.,
                                         num_w_columns = 20):
    """Load the Job Corps data and build noisy, misspecified proxies Z and W from U.

    Args:
        data_path: CSV file passed to ``generate_train_jobcorp_data``.
        z_noise_scale: standard deviation of the Gaussian noise added to U to form Z.
        w_noise_scale: half-width of the uniform noise added to form W.
        num_w_columns: number of leading columns of W that are kept.

    Returns:
        Tuple ``(U, W, Z, A, Y)`` of JAX arrays. Z has the shape of U; W keeps
        only its first ``num_w_columns`` columns.
    """
    U, Y, A = generate_train_jobcorp_data(data_path)
    U = jnp.array(U, dtype = jnp.float64)
    Y = jnp.array(Y, dtype = jnp.float64)
    A = jnp.array(A, dtype = jnp.float64)

    # Noise must have the same shape as U. ``np.random.normal(*U.shape)`` would pass the
    # shape as (loc, scale) and return ONE scalar, shifting every entry by the same constant.
    Z = U + z_noise_scale * jnp.array(np.random.normal(size = U.shape))

    # Divide each row of U by its row maximum (plus 1e-3), squash it through Lambda,
    # then add element-wise Uniform[-w_noise_scale, w_noise_scale) noise.
    # NOTE(review): assumes row maxima of U are not close to -1e-3; confirm U's range.
    row_max = (U.max(1) + 1e-3).reshape(-1, 1)
    W = Lambda(U / row_max) + w_noise_scale * (2 * jnp.array(np.random.uniform(size = U.shape)) - 1)
    W = W[:, :num_w_columns]
    return U, W, Z, A, Y

def _heuristic_rbf():
    """Return a new RBF kernel with the length-scale heuristic and a jitted call.

    A fresh instance is created on every call so that no two kernel roles
    (A, W, Z, X) share the same kernel object.
    """
    return RBF(use_length_scale_heuristic = True, use_jit_call = True)


def _result_row(algorithm, data_size, seed, a_prime_value, estimate, mse):
    """Return one results-table row; the keys match the columns of ``df_results``."""
    return {
        "Algorithm" : algorithm,
        "Data_Size" : data_size,
        "Seed" : seed,
        "a_prime" : a_prime_value,
        "ATT_Estimation" : estimate,
        "MSE_dist_to_oracle" : mse,
    }


# Treatment grid on which every estimator is evaluated. It depends on neither
# the seed nor a_prime, so it is built once instead of once per run.
do_A = jnp.linspace(40, 2000, 1000)[:, jnp.newaxis]
do_A_size = do_A.shape[0]
results_path = "../../Results/DoublyRobustKPV_ATT_JobCorps_Misspecified_Setting1.2.pkl"

# Rows are gathered in a plain list and the DataFrame is rebuilt at every
# checkpoint. Repeated pd.concat onto an empty frame (and onto rows whose MSE
# is None) triggers pandas' FutureWarning about empty / all-NA entries.
result_records = []

for seed_ in seed_list:
    for a_prime_value in a_prime_list:
        a_prime = jnp.array([a_prime_value])
        # Reseed so that, for a given seed, every a_prime run draws the same noise.
        np.random.seed(seed_)

        U, W, Z, A, Y = generate_misspecified_job_corps_data()
        n_samples = A.shape[0]

        A_transformed, A_transformer = data_transform(A)
        do_A_transformed = (A_transformer.transform(do_A)).reshape(do_A_size, -1)
        a_prime_transformed = A_transformer.transform(a_prime.reshape(-1, 1))
        Z_transformed, Z_transformer = data_transform(Z)
        W_transformed, W_transformer = data_transform(W)
        U_transformed, U_transformer = data_transform(U)
        # Y is intentionally not transformed: every model below is fit on raw Y,
        # which presumably keeps all ATT estimates (and their MSEs) on one scale.

        data_size = A_transformed.shape[0]
        A_transformed = jnp.array(A_transformed).reshape(data_size, -1)
        Z_transformed = jnp.array(Z_transformed).reshape(data_size, -1)
        W_transformed = jnp.array(W_transformed).reshape(data_size, -1)
        U_transformed = jnp.array(U_transformed).reshape(data_size, -1)

        #######################################################
        #        Oracle Kernel Method                         #
        #######################################################
        # The oracle receives the confounder U directly as its covariate.
        KernelATT_model = KernelATT(kernel_X = _heuristic_rbf(),
                                    kernel_A = RBF(use_length_scale_heuristic = True, ),
                                    optimize_regularization_parameters = True,
                                    lambda_optimization_range = (5*1e-5, 1.0),
                                    regularization_grid_points = 25)
        KernelATT_model.fit((A_transformed, U_transformed), Y)

        # Reshape exactly like the DR predictions below so the MSEs compare
        # element-wise: a 1-D (do_A_size,) array minus a (do_A_size, 1) array
        # would silently broadcast to a (do_A_size, do_A_size) matrix.
        f_struct_pred_katt = KernelATT_model.predict(do_A_transformed, a_prime_transformed).reshape(do_A_size, -1)

        # The oracle is the reference, so it has no distance-to-oracle value.
        result_records.append(
            _result_row("Kernel_ATT", n_samples, seed_, a_prime.item(), f_struct_pred_katt, None)
        )

        #######################################################
        #        Doubly Robust Kernel Proxy Method            #
        #######################################################
        treatment_bridge_algo_param_dict_default = {
                                                    "kernel_A" : _heuristic_rbf(),
                                                    "kernel_W" : _heuristic_rbf(),
                                                    "kernel_Z" : _heuristic_rbf(),
                                                    "kernel_X" : _heuristic_rbf(),
                                                    "lambda_" : 1e-3,
                                                    "eta" : 1e-3,
                                                    "lambda2_" : 1e-3,
                                                    "zeta" : 1e-3, # Only required for ATT algorithm
                                                    "optimize_lambda_parameters" : True,
                                                    "optimize_eta_parameter" : True,
                                                    "optimize_zeta_parameter" : True, # Only required for ATT algorithm
                                                    "lambda_optimization_range" : (5e-5, 1.0),
                                                    "eta_optimization_range" : (5e-5, 1.0),
                                                    "zeta_optimization_range" : (5e-5, 1.0), # Only required for ATT algorithm
                                                    "stage1_perc" : 0.5,
                                                    "regularization_grid_points" : 25,
                                                    "make_psd_eps" : 1e-9,
                                                    "label_variance_in_lambda_opt" : 0.,
                                                    "label_variance_in_eta_opt" : 0.0,
                                                    }

        outcome_bridge_kpv_algo_param_dict_default = {
                                                    "algorithm_name" : "Kernel_Proxy_Variable",
                                                    "kernel_A" : _heuristic_rbf(),
                                                    "kernel_W" : _heuristic_rbf(),
                                                    "kernel_Z" : _heuristic_rbf(),
                                                    "kernel_X" : _heuristic_rbf(),
                                                    "lambda1_" : 0.1,
                                                    "lambda2_" : 0.1,
                                                    "zeta" : 1e-3, # Only required for ATT algorithm
                                                    "optimize_lambda1_parameter" : True,
                                                    "optimize_lambda2_parameter" : True,
                                                    "optimize_zeta_parameter" : True, # Only required for ATT algorithm
                                                    "lambda1_optimization_range" : (5*1e-4, 1.0),
                                                    "lambda2_optimization_range" : (5*1e-4, 1.0),
                                                    "zeta_optimization_range" : (5e-5, 1.0), # Only required for ATT algorithm
                                                    "stage1_perc" : 0.5,
                                                    "regularization_grid_points" : 25,
                                                    "make_psd_eps" : 1e-9,
                                                    }

        model_DR = DoublyRobustKernelProxyATT(
                                            treatment_bridge_algo_param_dict = treatment_bridge_algo_param_dict_default,
                                            outcome_bridge_algo_param_dict = outcome_bridge_kpv_algo_param_dict_default,
                                            lambda_DR_optimization_range = (5e-5, 1.),
                                            )

        model_DR.fit((A_transformed, W_transformed, Z_transformed), Y, a_prime_transformed)
        f_struct_pred = model_DR.predict(do_A_transformed).reshape(do_A_size, -1)

        # Per-bridge estimates exposed by the DR model after predict():
        # treatment bridge -> KAP, outcome bridge -> KPV.
        f_struct_pred_KAP = model_DR.treatment_bridge_algo_pred.reshape(do_A_size, -1)
        f_struct_pred_KPV = model_DR.outcome_bridge_algo_pred.reshape(do_A_size, -1)

        dr_mse = np.mean((f_struct_pred_katt - f_struct_pred)**2)
        kpv_mse = np.mean((f_struct_pred_katt - f_struct_pred_KPV)**2)
        kap_mse = np.mean((f_struct_pred_katt - f_struct_pred_KAP)**2)

        result_records.extend([
            _result_row("KPV", n_samples, seed_, a_prime.item(), f_struct_pred_KPV, kpv_mse),
            _result_row("KAP", n_samples, seed_, a_prime.item(), f_struct_pred_KAP, kap_mse),
            _result_row("DoublyRobust_KPV", n_samples, seed_, a_prime.item(), f_struct_pred, dr_mse),
        ])

        # Checkpoint after every (seed, a_prime) run so finished runs survive a crash.
        df_results = pd.DataFrame(result_records, columns = df_results.columns)
        df_results.to_pickle(results_path)