import numpy as np
import torch
import torch.nn as nn
import time
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import recall_score


def _evaluate(net, data, loss_sum, outcome_cat, scalar_y, with_recall=False):
    """
    Evaluate net on every batch of data without tracking gradients.

    The loss is accumulated with a summed (reduction='sum') loss and divided by len(data.dataset), so the
    result is a per-sample average that does not depend on how the loader splits the data into batches.

    inputs:
    net: network exposing forward(treat)
    data: DataLoader yielding (y, treat, *rest)
    loss_sum: loss module constructed with reduction='sum'
    outcome_cat: bool
        True for a categorical outcome (classification), False for a numerical outcome (regression)
    scalar_y: float
        factor applied to the RMSE for regression
    with_recall: bool
        whether to compute per-class recall (classification only)

    output:
    loss: RMSE * scalar_y for regression; average (class-weighted) cross-entropy per sample for classification
    accuracy: fraction of correctly classified samples, or None for regression
    recall: per-class recall over the whole loader as a list, or None unless outcome_cat and with_recall
    """
    total_loss, correct = 0.0, 0.0
    labels, predictions = [], []
    with torch.no_grad():
        for y, treat, *_ in data:
            pred = net.forward(treat)
            total_loss += loss_sum(pred, y).item()
            if outcome_cat:
                pred_class = pred.argmax(1)
                correct += (pred_class == y).type(torch.float).sum().item()
                if with_recall:
                    # collect on CPU so recall is computed once over the whole loader, not per batch
                    labels.append(y.cpu())
                    predictions.append(pred_class.cpu())

    n_samples = len(data.dataset)
    if not outcome_cat:
        # RMSE, converted back to the original scale of y
        return np.sqrt(total_loss / n_samples) * scalar_y, None, None

    recall = None
    if with_recall:
        recall = recall_score(torch.cat(labels).numpy(), torch.cat(predictions).numpy(), average=None).tolist()
    return total_loss / n_samples, correct / n_samples, recall


def _select_input_variables(net, gamma):
    """
    Propagate the selected-connection indicators of every layer back to the input variables.

    Starting from the identity over the inputs, the current indicator is boolean-matrix-multiplied by the
    indicator matrix of every even-indexed parameter of net.module_list. The code assumes those are the weight
    matrices, with biases at the odd positions (TODO confirm against StoNet_Proxy). Right after the layer with
    index net.confounder_layer, the indicator is block-expanded with an identity over the inputs. Presumably
    this is because the following module receives the confounder units concatenated with the original inputs;
    verify against the network definition.

    inputs:
    net: network whose module_list parameter names match the keys of gamma
    gamma: dict mapping parameter name -> nested list of bool (|parameter| > threshold)

    output:
    numpy bool array of length input_size; True where the input variable has a selected path to an output unit
    """
    input_size = len(gamma['0.linear0.weight'][0])
    var_ind = np.identity(input_size, dtype=bool)
    for i, (name, _) in enumerate(net.module_list.named_parameters()):
        if i % 2 != 0:
            continue
        var_ind = np.matmul(gamma[name], var_ind)
        if i // 2 == net.confounder_layer:
            upper = np.concatenate((var_ind, np.zeros((var_ind.shape[0], input_size), dtype=bool)), axis=1)
            lower = np.concatenate((np.zeros((input_size, input_size), dtype=bool),
                                    np.identity(input_size, dtype=bool)), axis=1)
            var_ind = np.concatenate((upper, lower), axis=0)
    # an input is selected if it reaches an output through the confounder path or the direct path
    var_ind = var_ind[:, :input_size] + var_ind[:, input_size:]
    return np.max(var_ind, 0)


def train_multiple_causes(mode, net, train_data, val_data, epochs, batch_size, optimizer_list, impute_lr, mh_step, ita,
                          prior_sigma_0, prior_sigma_1, prior_alpha, prior_beta, lambda_n,
                          para_lr_decay, impute_lr_decay, scalar_y=1, outcome_cat=False, CE_weight=None):
    """
    Train the network with backward imputation and a mixture-Gaussian sparsity prior.

    Each epoch performs these steps:
    1. For every training batch, impute the latent variables with net.backward_imputation.
    2. Update every module with the gradient of the negative log-prior (scaled by 1 / (len(train_data) * batch_size))
       plus the gradient of net.likelihood_latent averaged over the batch.
    3. Re-estimate the confounder variance net.sigma_list[0].
    4. Evaluate on train_data and val_data.
    5. Record parameters, gradients and selected connections.

    inputs:
    mode: str, "pretrain" or "train"
        for "pretrain", impute_lr and the parameter learning rates stay constant; input selection is not recorded.
        for "train", impute_lr and the parameter learning rates decay epoch by epoch; input selection is recorded.
    net: StoNet_Proxy object defined in net_proxy_variable.py
        the network to be trained
    train_data, val_data: DataLoader object
        training data and validation data, respectively; each batch is unpacked as (y, treat, *rest)
    epochs: int
        the number of training epochs
    batch_size: int
        nominal sample size of each batch (used to approximate the training-set size in the prior gradient)
    optimizer_list: dict of torch.optim objects
        maps 'module0', 'module1', ... to the optimizer of the corresponding entry of net.module_list
    impute_lr: float
        learning rate for SGHMC
    mh_step: int
        the number of backward imputation steps
    ita: float
        passed unchanged to net.backward_imputation (see that method for its meaning)
    prior_sigma_0, prior_sigma_1: float
        variances for mixture gaussian prior
    prior_alpha, prior_beta: float
        parameters for the inverse gamma prior for confound_sd.
        prior_alpha is the shape parameter; prior_beta is the scale parameter.
    lambda_n: float
        mixing proportion of the prior_sigma_1 component of the mixture gaussian prior
    para_lr_decay, impute_lr_decay: float
        decay exponents for para_lr and impute_lr, respectively: lr_epoch = lr / (1 + lr * epoch ** decay)
    scalar_y: float
        when the outcome is standardized, the RMSE is multiplied by scalar_y to convert it back to the original
        scale; since RMSE scales linearly, this should be the standard deviation used for standardization
    outcome_cat: bool
        the type of outcome variable
        if True, the outcome variable is categorical, and this is a classification task
        if False, the outcome variable is numerical, and this is a regression task
    CE_weight: None or torch tensor
        when the outcome variable is categorical, the weight assigned to each class.
        can be used to deal with unbalanced dataset.

    output (dictionary):
    para_path: dict
        str(epoch) -> {parameter name -> parameter values} at the end of each epoch
    para_grad_path: dict
        str(epoch) -> {parameter name -> last gradient} at the end of each epoch
    para_gamma_path: dict
        str(epoch) -> {parameter name -> indicator |parameter| > threshold} at the end of each epoch
    input_gamma_path: dict ("train" mode only)
        var_selected: str(epoch) -> indicator of selected input variables
        num_selected: list of the number of selected input variables per epoch
    performance: dict
        train_loss, val_loss: per-epoch RMSE (regression) or average cross-entropy per sample (classification)
        train_acc, val_acc, val_recall_neg, val_recall_pos: classification only; recall is for classes 0 and 1
    impute_lr: float ("train" mode only)
        impute_lr used in the final epoch; can serve as the starting value for refining the network weights
    likelihoods: numpy array ("train" mode only)
        for each module, net.likelihood_latent summed over the batches of the final epoch, evaluated right after
        that module's update; used to calculate BIC (computed on the standardized dataset)

    raises:
    ValueError: if mode is neither "pretrain" nor "train"
    """
    if mode not in ("pretrain", "train"):
        raise ValueError(f"mode must be 'pretrain' or 'train', got {mode!r}")

    # save training and validation performance
    train_loss_path = []
    val_loss_path = []
    performance = dict(train_loss=train_loss_path, val_loss=val_loss_path)
    if outcome_cat:
        train_acc_path = []
        val_acc_path = []
        # recall is recorded for classes 0 and 1 only (binary classification)
        val_recall_neg = []
        val_recall_pos = []
        performance.update(train_acc=train_acc_path, val_acc=val_acc_path,
                           val_recall_neg=val_recall_neg, val_recall_pos=val_recall_pos)

    para_path = {}
    para_grad_path = {}
    para_gamma_path = {}

    module_num = len(net.module_list)
    optimizer_keys = ['module' + str(i) for i in range(module_num)]
    # hidden likelihoods for calculating BIC
    hidden_likelihood = np.zeros(module_num)

    # initial values of the decaying impute_lr and para_lrs
    step_impute_lr = impute_lr
    init_para_lrs = [optimizer_list[key].param_groups[0]['lr'] for key in optimizer_keys]

    if mode == "train":
        # selected input variables for each epoch
        input_gamma_path = dict(var_selected={}, num_selected=[])

    # summed losses: used for imputation, likelihoods and evaluation (divided by the sample count afterwards)
    if outcome_cat:
        out_loss_sum = nn.CrossEntropyLoss(weight=CE_weight, reduction='sum')
    else:
        out_loss_sum = nn.MSELoss(reduction='sum')

    # intermediate values for prior gradient calculation
    c1 = np.log(lambda_n) - np.log(1 - lambda_n) + 0.5 * np.log(prior_sigma_0) - 0.5 * np.log(prior_sigma_1)
    c2 = 0.5 / prior_sigma_0 - 0.5 / prior_sigma_1
    # approximate training-set size used to scale the prior gradient
    n_train = len(train_data) * batch_size

    # threshold for sparsity: |theta| at which both mixture components have equal posterior weight
    threshold = np.sqrt(np.log((1 - lambda_n) / lambda_n * np.sqrt(prior_sigma_1 / prior_sigma_0)) / (
            0.5 / prior_sigma_0 - 0.5 / prior_sigma_1))

    for epoch in range(epochs):
        print("Epoch" + str(epoch))

        if mode == "train":
            # impute_lr decay and para_lr decay
            step_impute_lr = impute_lr / (1 + impute_lr * epoch ** impute_lr_decay)
            for key, init_lr in zip(optimizer_keys, init_para_lrs):
                optimizer_list[key].param_groups[0]['lr'] = init_lr / (1 + init_lr * epoch ** para_lr_decay)

        for y, treat, *_ in train_data:
            # actual size of this batch (the last batch may be smaller than batch_size)
            n_batch = y.shape[0]

            # backward imputation
            latent_z = net.backward_imputation(mh_step=mh_step, impute_lr=step_impute_lr,
                                               outcome_loss=out_loss_sum, treat=treat, y=y, ita=ita)

            # prior gradient: temp is d log(prior)/d theta, where p0 is the posterior weight of the
            # prior_sigma_0 component. The optimizers minimise, so the gradient of the NEGATIVE log-prior
            # is stored; assigning .grad also discards any gradient left from the previous step.
            with torch.no_grad():
                for para in net.module_list.parameters():
                    p0 = para.pow(2).mul(c2).add(c1).exp().add(1).pow(-1)
                    log_prior_grad = para.div(-prior_sigma_0).mul(p0) + para.div(-prior_sigma_1).mul(1 - p0)
                    para.grad = log_prior_grad.div(n_train).neg()

            # likelihood gradient, module by module
            for module_index, key in enumerate(optimizer_keys):
                likelihood = net.likelihood_latent(latent_z=latent_z, module_index=module_index,
                                                   outcome_loss=out_loss_sum, y=y, treat=treat) / n_batch
                likelihood.backward()

                if net.prune_flag == 1:
                    net.prune_masked_grad()

                optimizer_list[key].step()

                if epoch == epochs - 1:
                    with torch.no_grad():
                        # recalculate the likelihood after the last parameter update; it is left unnormalised
                        # so the treatment loss has the same weight as the outcome loss
                        likelihood = net.likelihood_latent(latent_z, module_index, out_loss_sum, y, treat)
                        hidden_likelihood[module_index] += likelihood.item()

            # update the variance of the hidden confounder (the first imputed variable) with
            # d1 / d2 = (sse/2 + prior_beta) / (n*dim/2 + prior_alpha - 1), the inverse-gamma posterior mean
            with torch.no_grad():
                mu_z = net.module_az(treat)
                d1 = net.sse(latent_z, mu_z).div(2).add(prior_beta)
                d2 = n_batch * net.hidden_dim[net.confounder_layer] / 2 + prior_alpha - 1
                net.sigma_list[0] = d1 / d2

        # training performance
        train_loss, train_acc, _ = _evaluate(net, train_data, out_loss_sum, outcome_cat, scalar_y)
        train_loss_path.append(train_loss)
        print(f"Avg train loss: {train_loss:>8f} \n")
        if outcome_cat:
            train_acc_path.append(train_acc)
            print(f"train accuracy: {train_acc:>8f} \n")

        # validation performance
        val_loss, val_acc, recall = _evaluate(net, val_data, out_loss_sum, outcome_cat, scalar_y,
                                              with_recall=True)
        val_loss_path.append(val_loss)
        print(f"Avg val loss: {val_loss:>8f} \n")
        if outcome_cat:
            val_acc_path.append(val_acc)
            print(f"val accuracy: {val_acc:>8f} \n")
            val_recall_neg.append(recall[0])
            val_recall_pos.append(recall[1])
            print("val recall", recall)

        # save parameter values, gradients and selected connections
        epoch_key = str(epoch)
        para_path[epoch_key] = {}
        para_grad_path[epoch_key] = {}
        para_gamma_path[epoch_key] = {}
        for name, para in net.module_list.named_parameters():
            para_path[epoch_key][name] = para.detach().cpu().numpy().tolist()
            para_grad_path[epoch_key][name] = para.grad.detach().cpu().numpy().tolist()
            para_gamma_path[epoch_key][name] = (para.detach().abs() > threshold).cpu().numpy().tolist()

        if mode == "train":
            var_ind = _select_input_variables(net, para_gamma_path[epoch_key])
            num_selected = np.sum(var_ind)
            input_gamma_path['var_selected'][epoch_key] = var_ind.tolist()
            input_gamma_path['num_selected'].append(num_selected.astype("float64"))
            print('number of selected input variable for outcome:', num_selected)

    if mode == "pretrain":
        return dict(para_path=para_path, para_grad_path=para_grad_path, para_gamma_path=para_gamma_path,
                    performance=performance)
    return dict(para_path=para_path, para_grad_path=para_grad_path, para_gamma_path=para_gamma_path,
                input_gamma_path=input_gamma_path, performance=performance,
                impute_lr=step_impute_lr, likelihoods=hidden_likelihood)
