import pandas as pd
from tqdm import tqdm
from scipy.io import loadmat
import numpy as np
from sklearn.preprocessing import StandardScaler
from collections import defaultdict as ddict
from Base_Solver import Base_Solver



def evaluate_Y(Y_value):
    """Return 1 if ``Y_value`` lies in the closed interval [1.0, 1.1], else 0.

    ``Y_value`` is converted with ``float()`` before the comparison, so it
    may be anything ``float()`` accepts. NaN compares false and yields 0.
    """
    return 1 if 1.0 <= float(Y_value) <= 1.1 else 0


def _load_standardized_data(mat_path):
    """Load the ``.mat`` file at ``mat_path`` and return it standardised.

    The MATLAB bookkeeping entries and the ``Site``/``Lat``/``Lon``/``Year``/
    ``Month``/``Day`` entries are dropped (a missing key raises ``KeyError``).
    Every remaining array is flattened into one column, and the columns are
    rescaled with ``StandardScaler``. Returns a ``pandas.DataFrame`` that
    keeps the original column names.
    """
    data_mat = loadmat(mat_path)
    for key in ('__header__', '__version__', '__globals__',
                'Site', 'Lat', 'Lon', 'Year', 'Month', 'Day'):
        data_mat.pop(key)
    frame = pd.DataFrame({key: value.reshape(-1) for key, value in data_mat.items()})
    scaled = StandardScaler().fit_transform(frame)
    return pd.DataFrame(scaled, columns=list(frame.columns))


def _edge_matrix(edges, idx):
    """Return a square matrix ``M`` with ``M[idx[child], idx[parent]] = weight``.

    ``edges`` maps ``(parent, child)`` node-name pairs to weights and ``idx``
    maps node names to row/column positions. Cells without an edge stay 0.0.
    A node name missing from ``idx`` raises ``KeyError``.
    """
    matrix = np.zeros((len(idx), len(idx)))
    for (parent, child), weight in edges.items():
        matrix[idx[child], idx[parent]] = weight
    return matrix


def _build_true_parameters(nodes):
    """Build the ``(A_true, B_true, C_true)`` matrices of the Bermuda model.

    ``A_true`` holds the non-lagged edge weights and ``B_true`` the lagged
    ones, both indexed ``[child, parent]`` following the order of ``nodes``.
    ``C_true`` is a fixed 11x11 diagonal matrix.
    """
    idx = {var: i for i, var in enumerate(nodes)}

    # (parent, child) -> weight of the non-lagged edges.
    theta_true = {
        ("Light", "Temp"): 0.08336954980497932,
        ("Temp", "Sal"): -0.4809837373167684,
        ("Sal", "DIC"): 0.4777168829735183,
        ("Sal", "TA"): 0.5457734531124397,
        ("Temp", "Omega"): 0.5182253589055726,
        ("Sal", "Omega"): 0.03507218718555735,
        ("DIC", "Omega"): -1.1056652215053286,
        ("TA", "Omega"): 1.6104231541803835,
        ("Light", "Chla"): -0.15106684218500258,
        ("Temp", "Chla"): -0.04451583134247557,
        ("Nutrients_PC1", "Chla"): -0.07690378415962325,
        ("Temp", "pHsw"): -0.7482789216296077,
        ("Sal", "pHsw"): 0.013001179873522933,
        ("TA", "pHsw"): 0.7676261914081877,
        ("DIC", "pHsw"): -0.5879618774787132,
        ("Temp", "CO2"): 0.8613318110706953,
        ("Sal", "CO2"): 0.04051812201172802,
        ("DIC", "CO2"): 0.5700488513842487,
        ("TA", "CO2"): -0.596251974686561,
        ("Light", "NEC"): 0.0322460348829162,
        ("Temp", "NEC"): 5.227658403563992,
        ("Omega", "NEC"): -2.343629162533968,
        ("Chla", "NEC"): 0.13182892043084415,
        ("Nutrients_PC1", "NEC"): 0.09881771775808317,
        ("pHsw", "NEC"): 2.0492558654639,
        ("CO2", "NEC"): -2.5146414696724295,
    }

    # (parent, child) -> weight of the lagged edges.
    lagged_theta_true = {
        ("Light", "Light"): 0.6,
        ("Temp", "Temp"): 0.6,
        ("Temp", "DIC"): -0.1,
        ("Sal", "Sal"): 0.6,
        ("Sal", "DIC"): 0.23,
        ("Sal", "TA"): 0.25,
        ("TA", "Chla"): -0.1,
        ("CO2", "NEC"): -1.1,
        ("NEC", "NEC"): 0.6,
    }

    A_true = _edge_matrix(theta_true, idx)
    B_true = _edge_matrix(lagged_theta_true, idx)

    # NOTE(review): presumably one noise scale per node, listed in the same
    # order as `nodes` -- confirm against Base_Solver.
    C_true = np.diag([1.2e-2, 1.6e-2, 1.6e-2, 1.0e-2, 2.0e-2, 1.6e-2,
                      1.6e-2, 1.8e-2, 0.001, 1.6e-3, 1.5e-1])
    return A_true, B_true, C_true


def main():
    """Run the baseline solver on the Bermuda model and print its success rate.

    Draws ``times`` distinct seeds from a fixed master seed, builds one
    ``Base_Solver`` per seed, and prints the mean and standard deviation of
    the values returned by ``AUF_prob()``.
    """
    np.random.seed(19991115)

    # NOTE(review): the standardised observations are not used anywhere
    # below. The load is kept so the script still requires (and validates)
    # the data file exactly as before -- confirm whether it can be dropped.
    _load_standardized_data('../data_bermuda/SEM_data.mat')

    vars_stage_1 = ['Light', 'Temp', 'Sal']
    vars_stage_2 = ['DIC', 'TA', 'Omega', 'Nutrients_PC1', 'Chla', 'pHsw', 'CO2']
    vars_stage_3 = ['NEC']

    nodes = vars_stage_1 + vars_stage_2 + vars_stage_3
    nodes_stage = [vars_stage_1, vars_stage_2, vars_stage_3]

    A_true, B_true, C_true = _build_true_parameters(nodes)

    # Noise family handed to the solver; 'gaussian' is the alternative that
    # was listed next to it.
    noise_type = 'laplace'
    Para_true = [A_true, B_true, C_true, noise_type]

    binary_edge = []

    task_para = [100, 9]  # t_0, T
    val_times = 1000
    times = 5

    # Per-run seeds come from the master seed set at the top of main().
    seed_list = np.random.choice(np.arange(times * 50), times, replace=False).tolist()

    succ_prob_list = []
    for rnd_seed in tqdm(seed_list):
        np.random.seed(rnd_seed)
        ba = Base_Solver(
            nodes=nodes,
            Para_true=Para_true,
            task_para=task_para,
            binary_edge=binary_edge,
            val_times=val_times,
            nodes_stage=nodes_stage,
            evaluate_func=evaluate_Y,
        )
        succ_prob_list.append(ba.AUF_prob())

    print("Data: Bermuda", "\tNoise type:", noise_type, "\tWindow length:", task_para[1], "\tApproach: Baseline")
    print("Success probability:\t", np.mean(succ_prob_list), '+-', np.std(succ_prob_list))


if __name__ == "__main__":
    main()












