import sys
sys.path.insert(0, '.')
from model.stonet_proxy import StoNet_Proxy_Model
from experiment.data import ProxySim_Normal
from model.metrics import pehe
from torch.utils.data import DataLoader, random_split, ConcatDataset
import torch
import numpy as np
import os
import errno
import pickle
import json


def _increment_run_counter(counter_dir):
    """Increment the value pickled in ``<counter_dir>/counter.pkl``, write it back, and return it.

    NOTE(review): ``pickle.load`` can execute arbitrary code embedded in the file. This is only
    acceptable because the counter file is a local artifact of these experiments; never point
    ``counter_path`` at untrusted data.
    """
    counter_file = os.path.join(counter_dir, "counter.pkl")
    with open(counter_file, "rb") as file:
        counter = pickle.load(file)
    counter += 1
    with open(counter_file, "wb") as file:
        pickle.dump(counter, file)
    return counter


def _loader_pehe(model, loader, classification_flag, data_seed):
    """Sum ``pehe`` over every batch of ``loader`` and divide by the size of its dataset."""
    # NOTE(review): dividing the summed per-batch values by the dataset size assumes ``pehe``
    # returns a per-batch *sum* of squared errors; if it returns a batch mean this should be
    # weighted by batch size instead -- TODO confirm against model.metrics.pehe.
    total = 0
    for _, _, x, tau in loader:
        y_cf_list = model.get_y_cf(x, model.model.sigma_z**0.5, classification_flag,
                                   sample_size=100, seed=data_seed)
        total += pehe(y_cf_list, tau)
    return total / len(loader.dataset)


def train_stonet(configs, data_seed):
    """Fit StoNet proxy models on one simulated dataset and return their PEHE.

    ``configs["num_run"]`` networks are built, each with a different initialisation seed. Each
    one goes through pretrain -> train -> prune -> fine-tune, and the network with the smallest
    BIC is kept. Its weights and a summary of its training metrics are written to
    ``<configs["path"]>/<data_seed>/``. The pickled counter in ``configs["counter_path"]`` is
    incremented once per call.

    Args:
        configs: experiment configuration dict (keys are read below).
        data_seed: seed for ``ProxySim_Normal``. It also names the output sub-directory and is
            passed as ``seed`` to ``get_y_cf`` when sampling counterfactual outcomes.

    Returns:
        ``(pehe_in_sample, pehe_out_sample)``: PEHE of the best-BIC model over the
        train+validation split and over the held-out test split.

    Raises:
        ValueError: if ``configs["num_run"] < 1`` (no model would be selected), or if
            ``train_size`` is so small that the validation/test splits would be empty.
    """
    classification_flag = configs["classification"]
    num_run = configs["num_run"]
    if num_run < 1:
        raise ValueError(f"configs['num_run'] must be >= 1, got {num_run}")

    # --- data: train / val / test split; val and test each hold 25% of train_size samples
    train_size = configs["train_size"]
    val_size = int(train_size * 0.25)
    if val_size < 1:
        raise ValueError(
            f"train_size={train_size} is too small: validation/test splits would be empty")
    input_size = configs["input_size"]
    data = ProxySim_Normal(input_size, train_size + val_size * 2, data_seed)

    # Fixed generator so the split is the same across runs for a given dataset.
    train_set, val_set, test_set = random_split(data, [train_size, val_size, val_size],
                                                generator=torch.Generator().manual_seed(1))
    in_sample_set = ConcatDataset([train_set, val_set])

    # Roughly 50 mini-batches per epoch; clamp so a small train_size still gives a valid size.
    batch_size = max(1, int(train_size / 50))
    train_data = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_data = DataLoader(val_set, batch_size=batch_size, shuffle=True)
    # Evaluation-only loaders. A fixed batch order keeps batch composition stable, so the
    # per-batch sampling in get_y_cf (seeded with data_seed) is reproducible.
    test_data = DataLoader(test_set, batch_size=batch_size, shuffle=False)
    in_sample_data = DataLoader(in_sample_set, batch_size=batch_size, shuffle=False)

    # --- output directory for this data seed (no-op if it already exists)
    out_dir = os.path.join(configs["path"], str(data_seed))
    os.makedirs(out_dir, exist_ok=True)

    # --- global run counter (incremented for bookkeeping; the value is not used here)
    _increment_run_counter(configs["counter_path"])

    # --- network args
    sigma_list = configs["sigma"]
    net_args = dict(hidden_dim=configs["hidden_dim"],
                    input_dim=input_size,
                    output_dim=len(data.y.unique()) if classification_flag else 1,
                    confounder_layer=configs["confounder_layer"],
                    treat_layer=configs["treatment_layer"],
                    treat_node=configs["treatment_node"],
                    sigma_list=sigma_list,
                    outcome_cat=classification_flag,
                    log_dir=os.path.join(out_dir, "train_log"))

    # --- imputation parameters
    impute_lrs = configs["impute_lr"]
    impute_lr_decay = configs["impute_lr_decay"]
    treat_loss_weight = configs["treatment_loss_weight"]
    itas = [0.1 / lr for lr in impute_lrs]
    mh_step = configs["mh_step"]

    # --- training settings
    para_lrs_train = configs["para_lr_train"]
    para_lrs_fine_tune = configs["para_lr_fine_tune"]
    training_epochs = configs["train_epoch"]
    pretrain_epochs = configs["pretrain_epoch"]
    fine_tune_epochs = configs["fine_tune_epoch"]
    para_lr_decay = configs["para_lr_decay"]

    # --- prior parameters
    prior_sigma_0 = configs["sigma0"]
    prior_sigma_1 = configs["sigma1"]
    lambda_n = configs["lambda_n"]

    # --- arguments shared by every net.train(...) call
    optim_args = dict(train_data=train_data,
                      val_data=val_data,
                      batch_size=batch_size,
                      mh_step=mh_step,
                      itas=itas,
                      para_lrs_train=para_lrs_train,
                      para_lrs_fine_tune=para_lrs_fine_tune,
                      para_lr_decay=para_lr_decay,
                      impute_lr_decay=impute_lr_decay,
                      prior_sigma_0=prior_sigma_0,
                      prior_sigma_1=prior_sigma_1,
                      lambda_n=lambda_n,
                      treat_loss_weight=treat_loss_weight)

    # Sparsity threshold passed to prune_network. Algebraically, it is the |w| at which
    # (1 - lambda_n) * N(w; 0, sigma0) equals lambda_n * N(w; 0, sigma1). This assumes
    # sigma0/sigma1 are the variances of a two-component (spike/slab) prior -- confirm in the model.
    threshold = np.sqrt(np.log((1 - lambda_n) / lambda_n * np.sqrt(prior_sigma_1 / prior_sigma_0)) / (
            0.5 / prior_sigma_0 - 0.5 / prior_sigma_1))

    # --- summary of the currently selected (lowest-BIC) run
    results = dict(dim=0, BIC=0, num_selection_out=0, num_selection_treat=0, out_train_loss=0, out_val_loss=0,
                   treat_train_loss=0, treat_val_loss=0, treat_train_acc=0, treat_val_acc=0)
    if classification_flag:
        results.update([('out_train_acc', 0), ('out_val_acc', 0)])

    best_model = None
    best_score = None
    for prune_seed in range(num_run):
        print('number of runs', prune_seed)

        net = StoNet_Proxy_Model(seed=prune_seed, **net_args)

        print("Pretrain")
        net.train(mode="pretrain", epochs=pretrain_epochs, impute_lrs=impute_lrs, **optim_args)

        print("Train")
        net.train(mode="train", epochs=training_epochs, impute_lrs=impute_lrs, **optim_args)
        num_gamma_out_train = net.input_gamma_path["num_selected"]
        num_gamma_treat_train = net.input_gamma_path["num_selected_treat"]
        # Fine-tuning resumes from the imputation step sizes reached at the end of training.
        impute_lrs_fine_tune = net.step_impute_lrs

        net.prune_network(threshold)

        # Refine the surviving (non-zero) parameters.
        print("Refine Weight")
        net.train(mode="finetune", epochs=fine_tune_epochs, impute_lrs=impute_lrs_fine_tune, **optim_args)
        performance_fine_tune = net.performance
        likelihoods = net.hidden_likelihood

        BIC, num_non_zero_element = net.BIC_and_non_zero_para(train_set, likelihoods)
        print("number of non-zero connections:", num_non_zero_element)
        print('BIC:', BIC)

        # Plain float so it is JSON-serialisable whether BIC is a float, NumPy or 0-d tensor.
        bic = float(BIC)
        # A NaN BIC ranks worst, so it can never displace a finite one. Ties go to the later run.
        score = np.inf if np.isnan(bic) else bic
        if best_score is not None and score > best_score:
            continue

        # New best run: record its metrics and checkpoint it.
        best_score = score
        best_model = net
        results['num_selection_out'] = num_gamma_out_train[training_epochs - 1].item()
        results['num_selection_treat'] = num_gamma_treat_train[training_epochs - 1].item()
        results['out_train_loss'] = performance_fine_tune['train_loss']
        results['treat_train_loss'] = performance_fine_tune['treat_train_loss']
        results['out_val_loss'] = performance_fine_tune['val_loss']
        results['treat_val_loss'] = performance_fine_tune['treat_val_loss']
        results['treat_train_acc'] = performance_fine_tune['treat_train_acc']
        results['treat_val_acc'] = performance_fine_tune['treat_val_acc']
        if classification_flag:
            results['out_train_acc'] = performance_fine_tune['train_acc']
            results['out_val_acc'] = performance_fine_tune['val_acc']
        results['dim'] = num_non_zero_element.item()
        results['BIC'] = bic

        torch.save(best_model.model.state_dict(), os.path.join(out_dir, 'model.pt'))
        with open(os.path.join(out_dir, 'causal_stoNet_results.json'), "w") as f:
            json.dump(results, f)

    # PEHE = 1/n sum_i^n (tau_true - tau_est)^2, evaluated with the best-BIC model.
    with torch.no_grad():
        pehe_in_sample = _loader_pehe(best_model, in_sample_data, classification_flag, data_seed)
        pehe_out_sample = _loader_pehe(best_model, test_data, classification_flag, data_seed)

    return pehe_in_sample, pehe_out_sample


def main(config_path, num_sim):
    """Repeat the StoNet experiment over ``num_sim`` simulated datasets and summarise PEHE.

    Simulation ``k`` calls ``train_stonet(configs, k)`` for ``k = 0 .. num_sim - 1``. The summary
    dict is printed and written to ``<configs["path"]>/pehe_stonet.json``.

    Note: the "... std" entries are ``ndarray.std()`` (ddof=0) divided by ``sqrt(num_sim)``.
    They are a standard-error-of-the-mean style quantity, not the raw standard deviation.
    """
    with open(config_path, 'r') as config_file:
        configs = json.load(config_file)

    # Preallocated float64 buffers; each scalar PEHE is converted on assignment.
    in_scores = np.zeros(num_sim)
    out_scores = np.zeros(num_sim)
    for seed in range(num_sim):
        in_scores[seed], out_scores[seed] = train_stonet(configs, seed)

    root_n = np.sqrt(num_sim)
    summary = {
        "in-sample mean": in_scores.mean(),
        "out-of-sample mean": out_scores.mean(),
        "in-sample std": in_scores.std() / root_n,
        "out-of-sample std": out_scores.std() / root_n,
    }
    print(summary)

    summary_path = os.path.join(configs["path"], 'pehe_stonet.json')
    with open(summary_path, "w") as out_file:
        json.dump(summary, out_file)
    

if __name__ == '__main__':
    # Default run: 10 simulated datasets using the proxy-simulation config file.
    main(num_sim=10, config_path="result/proxy_sim/proxy_sim_configs2.json")

