import sys
sys.path.insert(0, '.')
from model.net_proxy_variable import StoNet_Proxy
from model.train_proxy_variable import train_proxy
from data import TwinsData, data_preprocess
from torch.utils.data import DataLoader, random_split, ConcatDataset
import torch
import numpy as np
import argparse
import os
import errno
from torch.optim import SGD
from pickle import dump
from sklearn.utils import class_weight


# Command-line interface of the Twins-data experiment.
#
# Every option is declared once in the table below as (flag, add_argument
# keyword arguments) and registered in table order, so the ``--help`` listing
# follows the grouping given by the section comments.
parser = argparse.ArgumentParser(description='Numerical Experiment for Twins Data')

_ARGUMENT_SPECS = (
    # ---- basic setting: dataset ----
    ('--cross_val_fold', dict(default=3, type=int, help='k-fold cross validation')),
    ('--cross_fit_no', dict(default=1, type=int,
                            help='the indicator for training set in three-fold cross-fitting')),
    ('--batch_size', dict(default=200, type=int, help='batch_size')),

    # ---- StoNet: model ----
    ('--unit', dict(default=[64, 32, 16, 8], type=int, nargs='+',
                    help='number of hidden unit in each layer')),
    ('--sigma', dict(default=[1e-3, 1e-5, 1e-7], type=float, nargs='+',
                     help='variance of each layer for the model')),
    ('--confounder_depth', dict(default=1, type=int,
                                help='number of layers before the latent confounder layer')),
    ('--treatment_depth', dict(default=2, type=int,
                               help='number of layers before the treatment layer')),
    # NOTE: the default is a bare int, while values given on the command line
    # arrive as a list because of nargs='+'.
    ('--treat_node', dict(default=1, type=int, nargs='+',
                          help='the hidden nodes that corresponds to the treatment')),
    ('--treat_loss_weight', dict(default=100, type=float, help='weight for the treatment loss')),
    # Both switches write to ``classification_flag``.  When neither is given,
    # argparse applies the default of the first-registered action for that
    # destination (store_false -> True), so classification is the default.
    # Keep '--regression' registered before '--classification'.
    ('--regression', dict(dest='classification_flag', action='store_false',
                          help='false for regression')),
    ('--classification', dict(dest='classification_flag', action='store_true',
                              help='true for classification')),

    # ---- StoNet: training ----
    ('--pretrain_epoch', dict(default=100, type=int, help='total number of pretraining epochs')),
    ('--train_epoch', dict(default=1000, type=int, help='total number of training epochs')),
    ('--mh_step', dict(default=1, type=int, help='number of SGHMC steps for imputation')),
    ('--impute_lr', dict(default=[3e-3, 1e-5], type=float, nargs='+', help='step size for SGHMC')),
    ('--para_lr_train', dict(default=[1e-3, 1e-5, 1e-10], type=float, nargs='+',
                             help='step size for parameter update during training stage')),
    ('--para_lr_decay', dict(default=1.2, type=float, help='decay factor for para_lr')),
    ('--impute_lr_decay', dict(default=1, type=float, help='decay factor for impute_lr')),

    # ---- sparsity ----
    ('--num_run', dict(default=10, type=int,
                       help='Number of different initialization used to train the model')),
    ('--fine_tune_epoch', dict(default=200, type=int, help='total number of fine tuning epochs')),
    ('--para_lr_fine_tune', dict(default=[1e-4, 1e-6, 1e-11], type=float, nargs='+',
                                 help='step size of parameter update for fine-tuning stage')),

    # ---- prior ----
    ('--sigma0', dict(default=1e-5, type=float,
                      help='variance of component 0 in the mixture gaussian prior')),
    ('--sigma1', dict(default=1e-2, type=float,
                      help='variance of component 1 in the mixture gaussian prior')),
    ('--lambda_n', dict(default=1e-6, type=float, help='lambda in the mixture gaussian prior')),
    ('--alpha', dict(default=1, type=float, help='shape parameter in the inverse gamma prior')),
    ('--beta', dict(default=1, type=float, help='scale parameter in the inverse gamma prior')),
)

for _flag, _options in _ARGUMENT_SPECS:
    parser.add_argument(_flag, **_options)


# Parsed once at module level; main() reads its settings from this namespace.
args = parser.parse_args()


def _build_optimizers(net, learning_rates):
    """Return one momentum-SGD optimizer per entry of ``net.module_list``.

    The optimizers are keyed ``'module0'``, ``'module1'``, ... and entry ``i``
    uses ``learning_rates[i]``.  ``maximize=True`` is set so that each step
    performs gradient ascent on the log likelihood.
    """
    return {
        'module' + str(idx): SGD(module.parameters(), lr=learning_rates[idx],
                                 momentum=0.9, maximize=True)
        for idx, module in enumerate(net.module_list)
    }


def _estimate_ate(net, loader, cross_fit_no, classification_flag):
    """Average the counterfactual contrast returned by ``net.get_y_cf`` over ``loader``.

    For every batch the contrast ``y_cf_list[0] - y_cf_list[1]`` is formed and
    its column 1 is summed; the total is divided by the number of rows that
    were actually accumulated, so the result is a true per-sample mean
    regardless of how the data were split.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    contrast_sum = 0.0
    sample_count = 0
    for _, _, x in loader:
        y_cf_list = net.get_y_cf(x, net.sigma_list[net.confounder_layer], cross_fit_no, classification_flag)
        y_cf_contrast = y_cf_list[0] - y_cf_list[1]
        contrast_sum += y_cf_contrast[:, 1].sum().item()
        sample_count += y_cf_contrast.shape[0]
    if sample_count == 0:
        raise ValueError('cannot estimate the ATE from an empty data loader')
    return contrast_sum / sample_count


def _write_selected_variables(path, selected):
    """Write ``selected`` to ``path`` as space-separated integers."""
    with open(path, 'w') as f:
        f.write(' '.join(str(int(x)) for x in selected))


def main():
    """Run the Twins-data experiment configured by the module-level ``args``.

    For each of ``args.num_run`` seeds a fresh ``StoNet_Proxy`` is pretrained,
    trained, pruned with the sparsity threshold derived from the mixture
    Gaussian prior, and fine-tuned.  A BIC value is computed for every run;
    whenever a run attains the smallest BIC seen so far, its metrics, selected
    variables, performance histories and model weights overwrite the saved
    artifacts in the result directory.  The summary of the best run is pickled
    to ``causal_stoNet_results.pkl`` at the end.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # task
    classification_flag = args.classification_flag

    # generate dataset
    data = TwinsData()
    y_np = data.y.numpy()
    class_weights_out = class_weight.compute_class_weight(class_weight='balanced', classes=np.unique(y_np),
                                                          y=y_np)
    class_weights_out = torch.tensor(class_weights_out, dtype=torch.float)
    cross_fit_no = args.cross_fit_no
    train_set, val_set, test_set, _, _ = data_preprocess(data, 1, cross_fit_no,
                                                         args.cross_val_fold, x_scale=False, y_scale=False)
    in_sample_set = ConcatDataset([train_set, val_set])

    # load training data and validation data
    batch_size = args.batch_size
    train_data = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_data = DataLoader(val_set, batch_size=batch_size, shuffle=True)
    test_data = DataLoader(test_set, batch_size=batch_size, shuffle=True)
    in_sample_data = DataLoader(in_sample_set, batch_size=batch_size)

    # network setup
    sigma_list = args.sigma
    net_args = dict(hidden_dim=args.unit, input_dim=data.x[0].size(dim=0),
                    output_dim=len(data.y.unique()) if classification_flag else 1,
                    confounder_layer=args.confounder_depth, treat_layer=args.treatment_depth, treat_node=args.treat_node,
                    sigma_list=sigma_list)

    # set number of independent runs for sparsity
    num_seed = args.num_run

    # training setting
    para_lrs_train = args.para_lr_train
    para_lrs_fine_tune = args.para_lr_fine_tune
    training_epochs = args.train_epoch
    pretrain_epochs = args.pretrain_epoch
    fine_tune_epochs = args.fine_tune_epoch
    para_lr_decay = args.para_lr_decay
    impute_lr_decay = args.impute_lr_decay
    treat_loss_weight = args.treat_loss_weight

    # imputation parameters
    impute_lrs = args.impute_lr
    mh_step = args.mh_step

    # prior parameters
    prior_sigma_0 = args.sigma0
    prior_sigma_1 = args.sigma1
    lambda_n = args.lambda_n
    prior_alpha = args.alpha
    prior_beta = args.beta

    # threshold for sparsity
    threshold = np.sqrt(np.log((1 - lambda_n) / lambda_n * np.sqrt(prior_sigma_1 / prior_sigma_0)) / (
            0.5 / prior_sigma_0 - 0.5 / prior_sigma_1))

    # training results containers
    results = dict(dim=0, BIC=0, num_selection_out=0, num_selection_treat=0, out_train_loss=0, out_val_loss=0,
                   treat_train_loss=0, treat_val_loss=0, treat_train_acc=0, treat_val_acc=0)
    BIC_list = []  # BIC value for model selection
    if classification_flag:
        results.update([('out_train_acc', 0), ('out_val_acc', 0)])

    # path to save the result
    base_path = os.path.join('.', 'twins', 'result', )
    basic_spec = str(sigma_list) + '_' + str(mh_step) + '_' + str(training_epochs) + '_' + str(treat_loss_weight)
    spec = str(impute_lrs) + '_' + str(para_lrs_train) + '_' + str(prior_sigma_0) + '_' + \
           str(prior_sigma_1) + '_' + str(lambda_n)
    decay_spec = str(impute_lr_decay) + '_' + str(para_lr_decay)
    base_path = os.path.join(base_path, basic_spec, spec, decay_spec, str(cross_fit_no))

    # The output directory is the same for every run, so create it once;
    # exist_ok makes this safe when it already exists.
    os.makedirs(base_path, exist_ok=True)

    # key under which train_proxy stores the state of the last training epoch
    last_epoch = training_epochs - 1

    for prune_seed in range(num_seed):
        print('number of runs', prune_seed)

        # initialize network
        np.random.seed(prune_seed)
        torch.manual_seed(prune_seed)
        net = StoNet_Proxy(**net_args)
        net.to(device)

        # optimizers for layer weight and bias (training and fine-tuning stage)
        optimizer_list_train = _build_optimizers(net, para_lrs_train)
        optimizer_list_fine_tune = _build_optimizers(net, para_lrs_fine_tune)

        optim_args = dict(train_data=train_data, val_data=val_data,
                          batch_size=batch_size, mh_step=mh_step,
                          prior_sigma_0=prior_sigma_0, prior_sigma_1=prior_sigma_1, lambda_n=lambda_n,
                          prior_alpha=prior_alpha, prior_beta=prior_beta,
                          para_lr_decay=para_lr_decay, impute_lr_decay=impute_lr_decay,
                          outcome_cat=classification_flag, treat_loss_weight=treat_loss_weight,
                          CE_weight=class_weights_out)

        # pretrain
        print("Pretrain")
        output_pretrain = train_proxy(mode="pretrain", net=net, epochs=pretrain_epochs,
                                      optimizer_list=optimizer_list_train,
                                      impute_lrs=impute_lrs, **optim_args)
        performance_pretrain = output_pretrain["performance"]

        # train
        print("Train")
        output_train = train_proxy(mode="train", net=net, epochs=training_epochs,
                                   optimizer_list=optimizer_list_train,
                                   impute_lrs=impute_lrs, **optim_args)
        para_train = output_train["para_path"]
        var_gamma_out_train = output_train["input_gamma_path"]["var_selected_out"]
        num_gamma_out_train = output_train["input_gamma_path"]["num_selected_out"]
        var_gamma_treat_train = output_train["input_gamma_path"]["var_selected_treat"]
        num_gamma_treat_train = output_train["input_gamma_path"]["num_selected_treat"]
        performance_train = output_train["performance"]
        impute_lrs_fine_tune = output_train["impute_lrs"]

        # prune network parameters: restore the parameters recorded for the
        # last training epoch, then mask every entry below the threshold
        with torch.no_grad():
            for name, para in net.module_list.named_parameters():
                para.data = torch.FloatTensor(para_train[str(last_epoch)][name]).to(device)

        user_mask = {}
        for name, para in net.module_list.named_parameters():
            user_mask[name] = para.abs() < threshold
        net.set_prune(user_mask)
        net.prune_masked_para()

        # refine non-zero network parameters
        print("Refine Weight")
        output_fine_tune = train_proxy(mode="train", net=net, epochs=fine_tune_epochs,
                                       optimizer_list=optimizer_list_fine_tune,
                                       impute_lrs=impute_lrs_fine_tune, **optim_args)
        performance_fine_tune = output_fine_tune["performance"]
        likelihoods = output_fine_tune["likelihoods"]

        # calculate BIC
        with torch.no_grad():
            num_non_zero_element = 0
            for name, para in net.module_list.named_parameters():
                num_non_zero_element = num_non_zero_element + para.numel() - net.mask_prune[name].sum()

            BIC = (np.log(len(train_set)) * num_non_zero_element - 2 * np.sum(likelihoods)).item()
            BIC_list.append(BIC)

            print("number of non-zero connections:", num_non_zero_element)
            print('BIC:', BIC)

        # estimate the ATE in sample (train + validation) and out of sample (test)
        with torch.no_grad():
            tau_est_in_sample = _estimate_ate(net, in_sample_data, cross_fit_no, classification_flag)
            tau_est_out_sample = _estimate_ate(net, test_data, cross_fit_no, classification_flag)

        # save model training results for the model with the smallest BIC
        if BIC == min(BIC_list):
            results['num_selection_out'] = num_gamma_out_train[last_epoch].item()
            results['num_selection_treat'] = num_gamma_treat_train[last_epoch].item()
            results['out_train_loss'] = performance_fine_tune['out_train_loss'][-1]
            results['treat_train_loss'] = performance_fine_tune['treat_train_loss'][-1]
            results['out_val_loss'] = performance_fine_tune['out_val_loss'][-1]
            results['treat_val_loss'] = performance_fine_tune['treat_val_loss'][-1]
            results['treat_train_acc'] = performance_fine_tune['treat_train_acc'][-1]
            results['treat_val_acc'] = performance_fine_tune['treat_val_acc'][-1]
            results['ate_in_sample'] = tau_est_in_sample
            results['ate_out_sample'] = tau_est_out_sample
            print('abs_in_sample', tau_est_in_sample)
            print('abs_out_sample', tau_est_out_sample)
            if classification_flag:
                results['out_train_acc'] = performance_fine_tune['out_train_acc'][-1]
                results['out_val_acc'] = performance_fine_tune['out_val_acc'][-1]
            results['dim'] = num_non_zero_element.item()
            results['BIC'] = BIC

            # os.path.join keeps these files inside the result directory,
            # next to the pickles and the model checkpoint
            _write_selected_variables(os.path.join(base_path, 'selected_variable_out.txt'),
                                      var_gamma_out_train[str(last_epoch)])
            _write_selected_variables(os.path.join(base_path, 'selected_variable_treat.txt'),
                                      var_gamma_treat_train[str(last_epoch)])

            with open(os.path.join(base_path, 'performance_pretrain.pkl'), 'wb') as f:
                dump(performance_pretrain, f)

            with open(os.path.join(base_path, 'performance_train.pkl'), 'wb') as f:
                dump(performance_train, f)

            with open(os.path.join(base_path, 'performance_fine_tune.pkl'), 'wb') as f:
                dump(performance_fine_tune, f)

            torch.save(net.state_dict(), os.path.join(base_path, 'model.pt'))

    # save overall performance
    with open(os.path.join(base_path, 'causal_stoNet_results.pkl'), "wb") as f:
        dump(results, f)


# Script entry point.
if __name__ == '__main__':
    main()
