import sys
sys.path.insert(0, '.')
from model.stonet_multicause import StoNet_MultiCause_Model
from experiment.data import MultipleSim_Separable, MultipleSim_NonSeparable, data_preprocess
from torch.utils.data import DataLoader, random_split
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
import torch
import numpy as np
import os
import errno
from pickle import dump, load
import json
import random
import matplotlib.pyplot as plt
from model.metrics import marginal_effect_eval


def train_stonet(configs, data_seed):
    """Train, prune and fine-tune StoNet models on one simulated dataset.

    For ``configs["num_run"]`` differently seeded networks the routine runs
    pretraining, training, pruning and fine-tuning, keeps the network with the
    smallest BIC, and then evaluates the marginal effect of every treatment
    on the test split with ``marginal_effect_eval``.

    Side effects (all below ``os.path.join(configs["path"], str(data_seed))``
    unless stated otherwise):
      * increments the integer stored in ``configs["counter_path"]/counter.pkl``
        (the file must already exist);
      * writes ``marginal_effect_error.png``, ``model.pt`` and
        ``causal_stoNet_results.json``.

    Args:
        configs: dict of experiment settings. Keys read here:
            ``classification``, ``train_size``, ``input_size``, ``sim_type``
            (``'separable'`` or ``'nonseparable'``), ``confounder_size``,
            ``path``, ``counter_path``, ``sigma``, ``hidden_dim``,
            ``confounder_layer``, ``impute_lr``, ``impute_lr_decay``,
            ``mh_step``, ``para_lr_train``, ``para_lr_fine_tune``,
            ``train_epoch``, ``pretrain_epoch``, ``fine_tune_epoch``,
            ``para_lr_decay``, ``sigma0``, ``sigma1``, ``lambda_n``,
            ``num_run``.
        data_seed: seed passed to the data simulator; also used as the name of
            the per-dataset output directory.

    Returns:
        A 7-tuple of lists, each holding one entry per treatment index
        (``input_size`` entries): error percent, mean error, direction
        indicator, mean of the estimated effects, mean of the true effects,
        std of the estimated effects, std of the true effects.

    Raises:
        ValueError: if ``configs["sim_type"]`` is not a supported simulator.
        RuntimeError: if no model was trained (``configs["num_run"]`` < 1).
    """
    # task
    classification_flag = configs["classification"]

    # load dataset
    train_size = configs["train_size"]
    # NOTE(review): the validation/test size is read from "train_size" too;
    # confirm this is intended rather than a separate "val_size" setting.
    val_size = configs["train_size"]
    input_size = configs["input_size"]
    sim_type = configs["sim_type"]

    # one training split plus equally sized validation and test splits
    total_size = train_size + val_size * 2
    if sim_type == 'separable':
        data = MultipleSim_Separable(input_size, total_size, configs["confounder_size"], data_seed)
    elif sim_type == 'nonseparable':
        data = MultipleSim_NonSeparable(input_size, total_size, configs["confounder_size"], data_seed)
    else:
        # Previously an unknown value fell through and crashed later with a NameError on `data`.
        raise ValueError("unknown sim_type %r; expected 'separable' or 'nonseparable'" % (sim_type,))

    train_set, val_set, test_set, _, y_scalar, _, _, _ = data_preprocess(data, partition_seed=1,
                                                                         x_scale=False, y_scale=True,
                                                                         train_size=train_size, val_size=val_size,
                                                                         test_size=val_size)
    y_scale = y_scalar.scale_[0]

    # load training data and validation data
    batch_size = int(train_size/50)
    train_data = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_data = DataLoader(val_set, batch_size=batch_size, shuffle=True)
    # the test loader uses a batch as large as the requested test split
    test_data = DataLoader(test_set, batch_size=val_size, shuffle=False)

    # path to save the results
    PATH = os.path.join(configs["path"], str(data_seed))
    os.makedirs(PATH, exist_ok=True)

    # load, increment and store the run counter (used to name the training log directory)
    counter_file = os.path.join(configs["counter_path"], "counter.pkl")
    with open(counter_file, "rb") as file:
        counter = load(file)
    counter += 1
    with open(counter_file, "wb") as file:
        dump(counter, file)

    # network args
    sigma_list = configs["sigma"]
    hidden_dim = configs["hidden_dim"]
    confounder_layer = configs["confounder_layer"]
    net_args = dict(hidden_dim=hidden_dim,
                    input_dim=input_size,
                    output_dim=len(data.y.unique()) if classification_flag else 1,
                    confounder_layer=confounder_layer,
                    sigma_list=sigma_list,
                    outcome_cat=classification_flag,
                    log_dir=os.path.join(PATH, "train_log", "train_"+str(counter)))

    # imputation parameters
    impute_lrs = configs["impute_lr"]
    impute_lr_decay = configs["impute_lr_decay"]
    itas = [0.1/lr for lr in impute_lrs]
    mh_step = configs["mh_step"]

    # training setting
    para_lrs_train = configs["para_lr_train"]
    para_lrs_fine_tune = configs["para_lr_fine_tune"]
    training_epochs = configs["train_epoch"]
    pretrain_epochs = configs["pretrain_epoch"]
    fine_tune_epochs = configs["fine_tune_epoch"]
    para_lr_decay = configs["para_lr_decay"]

    # prior parameters
    prior_sigma_0 = configs["sigma0"]
    prior_sigma_1 = configs["sigma1"]
    lambda_n = configs["lambda_n"]

    # training args
    optim_args = dict(train_data=train_data,
                      val_data=val_data,
                      batch_size=batch_size,
                      mh_step=mh_step,
                      itas=itas,
                      para_lrs_train=para_lrs_train,
                      para_lrs_fine_tune=para_lrs_fine_tune,
                      para_lr_decay=para_lr_decay,
                      impute_lr_decay=impute_lr_decay,
                      prior_sigma_0=prior_sigma_0,
                      prior_sigma_1=prior_sigma_1,
                      lambda_n=lambda_n,
                      y_scale=y_scale)

    # threshold for sparsity (passed to net.prune_network below)
    threshold = np.sqrt(np.log((1 - lambda_n) / lambda_n * np.sqrt(prior_sigma_1 / prior_sigma_0)) / (
            0.5 / prior_sigma_0 - 0.5 / prior_sigma_1))

    # training results containers
    results = dict(dim=0, BIC=0, num_selection=0, train_loss=0, val_loss=0)
    BIC_list = []  # BIC value for model selection
    if classification_flag:
        results.update([('train_acc', 0), ('val_acc', 0)])

    # Must live outside the loop: resetting it on every run discarded the best
    # model whenever the final run did not have the smallest BIC.
    best_model = None

    # training starts here!
    for prune_seed in range(configs["num_run"]):
        print('number of runs', prune_seed)

        # initialize network
        net = StoNet_MultiCause_Model(seed=prune_seed, **net_args)

        # pretrain
        print("Pretrain")
        net.train(mode="pretrain", epochs=pretrain_epochs, impute_lrs=impute_lrs, **optim_args)

        # train
        print("Train")
        net.train(mode="train", epochs=training_epochs, impute_lrs=impute_lrs, **optim_args)
        num_gamma_train = net.input_gamma_path["num_selected"]
        impute_lrs_fine_tune = net.step_impute_lrs

        # prune the network
        net.prune_network(threshold)

        # refine non-zero network parameters
        print("Refine Weight")
        net.train(mode="finetune", epochs=fine_tune_epochs, impute_lrs=impute_lrs_fine_tune, **optim_args)
        likelihoods = net.hidden_likelihood
        performance_fine_tune = net.performance

        # calculate BIC
        BIC, num_non_zero_element = net.BIC_and_non_zero_para(train_set, likelihoods)
        BIC_list.append(BIC)

        print("number of non-zero connections:", num_non_zero_element)
        print('BIC:', BIC)

        # save model training results for the model with the smallest BIC so far
        if BIC == min(BIC_list):
            best_model = net
            results['num_selection'] = num_gamma_train[training_epochs-1].item()
            results['train_loss'] = performance_fine_tune['train_loss']
            results['val_loss'] = performance_fine_tune['val_loss']
            if classification_flag:
                results['train_acc'] = performance_fine_tune['train_acc']
                results['val_acc'] = performance_fine_tune['val_acc']
            results['dim'] = num_non_zero_element.item()
            results['BIC'] = BIC

    if best_model is None:
        raise RuntimeError("no model was trained; configs['num_run'] must be at least 1")

    with torch.no_grad():
        # Only the values from the last batch are kept; the loader is built
        # with batch_size == val_size, the requested size of the test split.
        for y, treat in test_data:
            y_pred = best_model.get_y_cf(treat, best_model.model.sigma_z**0.5, classification_flag)
            y_pred = y_scalar.inverse_transform(y_pred.detach().numpy())
            y = y_scalar.inverse_transform(y.detach().numpy())

        # NOTE(review): confounder_size here comes from hidden_dim[confounder_layer],
        # while the simulator above used configs["confounder_size"] - confirm they agree.
        data_args = dict(input_size=input_size,
                         sample_size=total_size,
                         confounder_size=hidden_dim[confounder_layer],
                         seed=data_seed,
                         train_size=train_size,
                         val_size=val_size,
                         classification_flag=classification_flag,
                         sim_type=configs["sim_type"],
                         y_scalar=y_scalar)

        marginal_effect_est, marginal_effect_true = [], []
        marginal_effect_est_mean, marginal_effect_true_mean = [], []
        marginal_effect_est_std, marginal_effect_true_std = [], []
        marginal_effect_error_percent, marginal_effect_error_mean, marginal_effect_dir = [], [], []
        for idx in range(input_size):
            est, true, error_percent, error_mean, dir_cor = marginal_effect_eval(
                net=best_model, epsilon=1e-1, epsilon_idx=idx, y=y,
                y_pred=y_pred, z_sd=best_model.model.sigma_z**0.5, **data_args)
            marginal_effect_est.append(est)
            marginal_effect_true.append(true)
            marginal_effect_est_mean.append(est.mean())
            marginal_effect_true_mean.append(true.mean())
            marginal_effect_est_std.append(est.std())
            marginal_effect_true_std.append(true.std())
            marginal_effect_error_percent.append(error_percent)
            marginal_effect_error_mean.append(error_mean)
            marginal_effect_dir.append(dir_cor)

    # visualize the metrics: one histogram panel per treatment on a 3x3 grid
    fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(12, 12))
    # Never index past the available treatments (the grid holds at most 9).
    num_panels = min(input_size, axes.size)
    for i, panel in enumerate(axes.flat):
        if i >= num_panels:
            # hide panels that have no treatment to show
            panel.set_visible(False)
            continue

        est = marginal_effect_est[i]
        true = marginal_effect_true[i]

        # shared range so both histograms use the same bin edges
        range_max = max(est.max(), true.max())
        range_min = min(est.min(), true.min())
        panel.hist(est, alpha=0.5, label='est', range=(range_min, range_max))
        panel.hist(true, alpha=0.5, label='true', range=(range_min, range_max))

        panel.legend(loc='upper right')
        panel.set_title("treatment" + str(i+1))

    fig.savefig(os.path.join(PATH, 'marginal_effect_error.png'))
    plt.close(fig)

    # save the model
    torch.save(best_model.model.state_dict(), os.path.join(PATH, 'model.pt'))

    # save overall performance
    with open(os.path.join(PATH, 'causal_stoNet_results.json'), "w") as f:
        json.dump(results, f)

    return (marginal_effect_error_percent, marginal_effect_error_mean, marginal_effect_dir,
            marginal_effect_est_mean, marginal_effect_true_mean,
            marginal_effect_est_std, marginal_effect_true_std)
    

def main(config_path, num_sim, data_seed=None):
    """Run ``train_stonet`` on ``num_sim`` simulated datasets and summarise the results.

    Datasets are generated with seeds ``0 .. num_sim - 1``. The per-treatment
    metrics returned by ``train_stonet`` are stacked into arrays of shape
    ``(num_sim, configs["input_size"])``, plotted (estimated vs. true mean
    marginal effect, at most nine treatments on a 3x3 grid) to
    ``marginal_effect_overall.png`` and pickled to
    ``marginal_effect_metrics.pkl``; both files go into ``configs["path"]``.

    Args:
        config_path: path of the JSON configuration file.
        num_sim: number of simulated datasets to run; must be at least 1.
        data_seed: currently unused - the seeds are always ``range(num_sim)``.
            Kept for backward compatibility with existing callers.

    Raises:
        ValueError: if ``num_sim`` is smaller than 1.
    """
    if num_sim < 1:
        # With no simulations the plotting code would fail on empty arrays.
        raise ValueError("num_sim must be at least 1")

    with open(config_path, 'r') as file:
        configs = json.load(file)

    input_size = configs["input_size"]
    shape = (num_sim, input_size)
    error_percent, error_mean, direction = np.zeros(shape), np.zeros(shape), np.zeros(shape)
    true_mean, est_mean = np.zeros(shape), np.zeros(shape)
    true_std, est_std = np.zeros(shape), np.zeros(shape)

    # A distinct loop name avoids silently overwriting the `data_seed` argument.
    for seed in range(num_sim):
        # stonet
        (marginal_effect_error_percent, marginal_effect_error_mean, marginal_effect_dir,
         mean_est, mean_true, std_est, std_true) = train_stonet(configs, data_seed=seed)
        true_mean[seed] = mean_true
        est_mean[seed] = mean_est
        true_std[seed] = std_true
        est_std[seed] = std_est
        error_percent[seed] = marginal_effect_error_percent
        error_mean[seed] = marginal_effect_error_mean
        direction[seed] = marginal_effect_dir

    fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(18, 18))
    # Never index past the available treatments (the grid holds at most 9).
    num_panels = min(input_size, axes.size)
    for i, panel in enumerate(axes.flat):
        if i >= num_panels:
            # hide panels that have no treatment to show
            panel.set_visible(False)
            continue

        est = est_mean[:, i]
        true = true_mean[:, i]

        range_max = max(est.max(), true.max())
        range_min = min(est.min(), true.min())

        panel.scatter(est, true)
        # The estimate is on the x-axis, so its spread is drawn as a horizontal bar.
        panel.errorbar(est, true, xerr=est_std[:, i], fmt='none')
        # dashed identity line: points on it have estimated == true
        panel.axline([range_min, range_min], [range_max, range_max], ls='--')
        panel.set_title("treatment " + str(i+1))
        panel.set_xlabel("estimated")
        panel.set_ylabel("true")

    fig.savefig(os.path.join(configs["path"], 'marginal_effect_overall.png'))
    plt.close(fig)

    marginal_effect_metrics = dict(true_mean=true_mean,
                                   est_mean=est_mean,
                                   true_std=true_std,
                                   est_std=est_std,
                                   marginal_effect_error_percent=error_percent,
                                   marginal_effect_dir=direction,
                                   marginal_effect_error_mean=error_mean,
                                   # std of the per-simulation mean errors divided by sqrt(num_sim)
                                   marginal_effect_error_sd=(error_mean.std(axis=0)/np.sqrt(num_sim)))

    with open(os.path.join(configs["path"], 'marginal_effect_metrics.pkl'), "wb") as f:
        dump(marginal_effect_metrics, f)
        

if __name__ == '__main__':
    # Script entry point: ten simulated datasets with the separable configuration.
    separable_configs = "result/multicause_sim/separable/multicause_sim_configs2.json"
    main(config_path=separable_configs, num_sim=10)
    # Non-separable alternative (disabled):
    # main(config_path="result/multicause_sim/non-separable/multicause_sim_configs2.json", num_sim = 10)