import sys
sys.path.insert(0, '.')
from model.stonet_multicause import StoNet_MultiCause_Model
from experiment.data import BRCA, data_preprocess
from torch.utils.data import DataLoader, ConcatDataset
import torch
import numpy as np
import os
import errno
from sklearn.utils import class_weight
from pickle import dump, load
import matplotlib.pyplot as plt
import json
import pandas as pd
import matplotlib.pyplot as plt

def intersection(lst1, lst2):
    """Return the items of ``lst1`` that also occur in ``lst2``.

    The result keeps ``lst1``'s order and duplicates. Membership is checked
    with ``in`` against ``lst2`` itself, so items need not be hashable.
    """
    return list(filter(lambda item: item in lst2, lst1))

def train_stonet(configs, cross_fit_no):
    """Train StoNet on one cross-fitting fold of BRCA and return marginal effects.

    For fold ``cross_fit_no`` this function:
      1. loads BRCA and partitions it with ``data_preprocess`` (partition seed 1);
      2. increments the run counter pickled at
         ``<configs["counter_path"]>/counter.pkl`` (used to name the log dir);
      3. trains ``configs["num_run"]`` networks (pretrain -> train -> prune ->
         fine-tune), keeping the one with the smallest BIC;
      4. evaluates that best network's marginal effects on the held-out split;
      5. writes ``model.pt`` and ``causal_stoNet_results.json`` under
         ``<configs["path"]>/<cross_fit_no>``.

    Args:
        configs: experiment configuration dict (as loaded from the JSON config).
        cross_fit_no: 1-based index of the cross-fitting fold.

    Returns:
        np.ndarray whose column ``i`` holds
        ``get_marginal_effect(..., epsilon_idx=i)[:, 1]`` for
        ``i in range(input_size - 1)``. Rows from successive test batches are
        stacked, presumably one row per held-out sample.

    Raises:
        ValueError: if ``configs["num_run"]`` < 1, since no model could be selected.
    """
    num_runs = configs["num_run"]
    if num_runs < 1:
        raise ValueError("configs['num_run'] must be >= 1 to select a model")

    # task
    classification_flag = configs["classification"]

    # load dataset
    data = BRCA()
    input_size = len(data.a[0])

    # Balanced class weights for the cross-entropy loss. Always bind the name
    # so optim_args can also be built for regression (None there).
    # NOTE(review): assumes StoNet ignores/accepts CE_weight=None when the
    # outcome is continuous -- confirm.
    class_weights_out = None
    if classification_flag:
        class_weights_out = class_weight.compute_class_weight(class_weight='balanced',
                                                              classes=np.unique(data.y),
                                                              y=data.y.numpy())
        class_weights_out = torch.tensor(class_weights_out, dtype=torch.float)

    train_set, val_set1, val_set2, _, _, _, val_size, test_size = data_preprocess(data=data,
                                                                                 partition_seed=1,
                                                                                 x_scale=True,
                                                                                 y_scale=False,
                                                                                 cross_val=configs["cross_val_fold"],
                                                                                 cross_fit_no=cross_fit_no,
                                                                                 cat_var=True)
    # first held-out split drives validation, second is the evaluation set
    val_set = val_set1
    test_set = val_set2

    # load training, validation and test data (val/test as single full batches)
    batch_size = configs["batch_size"]
    train_data = DataLoader(train_set, batch_size=batch_size, shuffle=False)
    val_data = DataLoader(val_set, batch_size=val_size, shuffle=False)
    test_data = DataLoader(test_set, batch_size=test_size, shuffle=False)

    # path to save the results
    PATH = os.path.join(configs["path"], str(cross_fit_no))
    os.makedirs(PATH, exist_ok=True)

    # bump the persistent run counter; it makes each log directory unique
    counter_file = os.path.join(configs["counter_path"], "counter.pkl")
    with open(counter_file, "rb") as file:
        counter = load(file)
    counter += 1
    with open(counter_file, "wb") as file:
        dump(counter, file)

    # network args
    net_args = dict(hidden_dim=configs["hidden_dim"],
                    input_dim=input_size,
                    output_dim=len(data.y.unique()) if classification_flag else 1,
                    confounder_layer=configs["confounder_layer"],
                    sigma_list=configs["sigma"],
                    outcome_cat=classification_flag,
                    log_dir=os.path.join(PATH, "train_log", "train_" + str(counter)))

    # imputation parameters
    impute_lrs = configs["impute_lr"]
    impute_lr_decay = configs["impute_lr_decay"]
    itas = [0.1 / lr for lr in impute_lrs]
    mh_step = configs["mh_step"]

    # training setting
    para_lrs_train = configs["para_lr_train"]
    para_lrs_fine_tune = configs["para_lr_fine_tune"]
    training_epochs = configs["train_epoch"]
    pretrain_epochs = configs["pretrain_epoch"]
    fine_tune_epochs = configs["fine_tune_epoch"]
    para_lr_decay = configs["para_lr_decay"]

    # prior parameters
    prior_sigma_0 = configs["sigma0"]
    prior_sigma_1 = configs["sigma1"]
    lambda_n = configs["lambda_n"]

    # training args shared by every training phase
    optim_args = dict(train_data=train_data,
                      val_data=val_data,
                      batch_size=batch_size,
                      mh_step=mh_step,
                      itas=itas,
                      para_lrs_train=para_lrs_train,
                      para_lrs_fine_tune=para_lrs_fine_tune,
                      para_lr_decay=para_lr_decay,
                      impute_lr_decay=impute_lr_decay,
                      prior_sigma_0=prior_sigma_0,
                      prior_sigma_1=prior_sigma_1,
                      lambda_n=lambda_n,
                      CE_weight=class_weights_out)

    # sparsity threshold derived from the two-component prior (sigma0, sigma1, lambda_n)
    threshold = np.sqrt(np.log((1 - lambda_n) / lambda_n * np.sqrt(prior_sigma_1 / prior_sigma_0)) / (
            0.5 / prior_sigma_0 - 0.5 / prior_sigma_1))

    # training results containers
    results = dict(dim=0, BIC=0, num_selection=0, train_loss=0, val_loss=0)
    BIC_list = []  # BIC value for model selection
    if classification_flag:
        results.update(train_acc=0, val_acc=0)

    # Training starts here
    best_model = None
    for prune_seed in range(num_runs):
        print('number of runs', prune_seed)

        # initialize network
        net = StoNet_MultiCause_Model(seed=prune_seed, **net_args)

        # pretrain
        print("Pretrain")
        net.train(mode="pretrain", epochs=pretrain_epochs, impute_lrs=impute_lrs, **optim_args)

        # train
        print("Train")
        net.train(mode="train", epochs=training_epochs, impute_lrs=impute_lrs, **optim_args)
        num_gamma_train = net.input_gamma_path["num_selected"]
        # fine-tuning continues from the imputation step sizes reached in training
        impute_lrs_fine_tune = net.step_impute_lrs

        # prune the network
        net.prune_network(threshold)

        # refine non-zero network parameters
        print("Refine Weight")
        net.train(mode="finetune", epochs=fine_tune_epochs, impute_lrs=impute_lrs_fine_tune, **optim_args)
        likelihoods = net.hidden_likelihood
        performance_fine_tune = net.performance

        # calculate BIC
        BIC, num_non_zero_element = net.BIC_and_non_zero_para(train_set, likelihoods)
        BIC_list.append(BIC)

        print("number of non-zero connections:", num_non_zero_element)
        print('BIC:', BIC)

        # keep the model with the smallest BIC so far (ties go to the later run)
        if BIC == min(BIC_list):
            best_model = net
            results['num_selection'] = num_gamma_train[training_epochs - 1].item()
            results['train_loss'] = performance_fine_tune['train_loss']
            results['val_loss'] = performance_fine_tune['val_loss']
            if classification_flag:
                results['train_acc'] = performance_fine_tune['train_acc']
                results['val_acc'] = performance_fine_tune['val_acc']
            results['dim'] = num_non_zero_element.item()
            results['BIC'] = BIC

    # Marginal effects of each genetic variable (every input column except the
    # last) under the *selected* model. Columns are joined within a batch, and
    # batches are stacked row-wise, so the result is correct for any batching.
    # NOTE(review): outcome_cat=True and the [:, 1] column assume a binary
    # classification outcome even when configs["classification"] is False.
    batch_effects = []
    with torch.no_grad():
        sigma = best_model.model.sigma_z ** 0.5
        for _, treat in test_data:
            per_variable = [
                best_model.get_marginal_effect(treat, sigma, epsilon=1e-1, epsilon_idx=i,
                                               outcome_cat=True)[:, 1][:, None]  # P(Y(a)=1)
                for i in range(input_size - 1)
            ]
            batch_effects.append(np.concatenate(per_variable, axis=1))
    marginal_effect = np.concatenate(batch_effects, axis=0)

    # save the model
    torch.save(best_model.model.state_dict(), os.path.join(PATH, 'model.pt'))

    # save overall performance
    with open(os.path.join(PATH, 'causal_stoNet_results.json'), "w") as f:
        json.dump(results, f)

    return marginal_effect
            
            
def main(config_path):
    """Run cross-fitted StoNet training on BRCA and pickle the pooled effects.

    Loads the JSON config at ``config_path`` and calls ``train_stonet`` once
    for each fold ``1..configs["cross_val_fold"]``. The per-fold
    marginal-effect matrices are stacked row-wise. The pickle written to
    ``<configs["path"]>/brca_result.pkl`` is a dict with keys
    ``"marginal_effect"`` (the stacked matrix) and ``"importance"`` (the
    absolute value of that matrix's column-wise mean).
    """
    with open(config_path, 'r') as cfg_file:
        configs = json.load(cfg_file)

    n_folds = configs["cross_val_fold"]
    per_fold_effects = [train_stonet(configs, fold) for fold in range(1, n_folds + 1)]

    # rows from every fold's held-out split, one column per genetic variable
    pooled_effects = np.concatenate(per_fold_effects, axis=0)
    importance = np.abs(np.mean(pooled_effects, axis=0))

    output = {"marginal_effect": pooled_effects, "importance": importance}
    out_path = os.path.join(configs["path"], "brca_result.pkl")
    with open(out_path, "wb") as out_file:
        dump(output, out_file)

    # fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 6))
    # fig.suptitle('dependence of potential outcome on genetic variables')
    # # ax1.hist(np.log(marginal_effect_rad))
    # ax[0].hist(np.log(importance_rad))
    # ax[0].set_title("with radiation")
    # # ax2.hist(np.log(marginal_effect_no_rad))
    # ax[1].hist(np.log(importance_no_rad))
    # ax[1].set_title("without radiation")
    # fig.savefig(os.path.join(configs["path"], 'variable_impact.png'))
    # plt.close()



if __name__ == '__main__':
    # Optional first CLI argument selects the config file; the default keeps
    # the previous behaviour. Other configs used with this script:
    # result/brca/brca_configs{2,3,4,6}.json.
    # No module-level plt.show(): nothing is plotted, and calling it at import
    # time only added a side effect for importers.
    config_file = sys.argv[1] if len(sys.argv) > 1 else "result/brca/brca_configs5.json"
    main(config_file)