import torch
import torch.nn as nn
import numpy as np
from torch.optim import SGD


class LatentModule(nn.Module):
    """Stack of ``Linear`` layers separated by ``Tanh`` activations.

    The layers are stored, in application order, in ``self.module``
    (an ``nn.ModuleList``).

    Args:
        input_dim: number of input features of the first ``Linear`` layer.
        layer_dim: indexable sequence of output widths, one per ``Linear``
            layer; must contain at least one entry.
        first_module: when true the stack starts directly with a ``Linear``
            layer; otherwise a ``Tanh`` is applied to the input first.
    """

    def __init__(self, input_dim, layer_dim, first_module=False):
        super().__init__()

        # Only a module that is not the first one squashes its input before
        # the first affine map.
        stack = [] if first_module else [nn.Tanh()]
        stack.append(nn.Linear(input_dim, layer_dim[0]))

        # Every remaining width adds an activation followed by an affine map
        # from the previous width to the current one.
        for fan_in, fan_out in zip(layer_dim[:-1], layer_dim[1:]):
            stack.append(nn.Tanh())
            stack.append(nn.Linear(fan_in, fan_out))

        self.module = nn.ModuleList(stack)

    def forward(self, x):
        """Apply every stored layer to ``x`` in order and return the result."""
        out = x
        for step in self.module:
            out = step(out)
        return out
    
   
class StoNet_Model():
    def __init__(self, model, outcome_cat):
        """Store the wrapped network and start with pruning disabled.

        Args:
            model: the network object; other methods of this class access
                its ``module_list`` attribute.
            outcome_cat: truthy when the outcome is categorical (selects the
                cross-entropy loss in ``_train``), falsy otherwise.
        """
        # model structure
        self.model, self.outcome_cat = model, outcome_cat

        # pruning state: no mask installed, masking switched off
        self.mask_prune = None
        self.prune_flag = 0

    def _train_prep(self, para_lrs_train, para_lrs_fine_tune):
        # containers for intermediate results
        self.para_path = {}
        self.para_grad_path = {}
        self.para_gamma_path = {}
        self.hidden_likelihood = np.zeros(len(self.model.module_list))
        self.input_gamma_path = dict(var_selected={}, num_selected=[])

        # parameters for backward imputation
        self.backward_imputation_args = dict()

        # parameters for latent likelihood
        self.likelihood_latent_args = dict()

        # initialize optimizer
        # note: in this setting, everytime you call model.train(), the optimizers will be reinitialized!
        # therefore, the calculation of momentum is also initialized; the para_lr will be set as the starting value 
        self.optimizer_list_train, self.optimizer_list_fine_tune = [], []
        for i in range(len(self.model.module_list)):
            self.optimizer_list_train.append(SGD(self.model.module_list[i].parameters(), lr=para_lrs_train[i],
                                                        momentum=0.9, maximize=True))
            self.optimizer_list_fine_tune.append(SGD(self.model.module_list[i].parameters(), lr=para_lrs_fine_tune[i],
                                                            momentum=0.9, maximize=True))

    def _train(self, mode, train_data, val_data, epochs, batch_size,
               impute_lrs, mh_step, itas, para_lr_decay, impute_lr_decay,
               prior_sigma_0, prior_sigma_1, lambda_n, 
               CE_weight, y_scale):
        
        if mode in ['pretrain', 'train']:
            optimizer_list = self.optimizer_list_train
        else:
            optimizer_list = self.optimizer_list_fine_tune

        # initial value of decaying impute_lrs and para_lrs
        self.step_impute_lrs = impute_lrs.copy()
        init_para_lrs = []
        for i in range(len(self.model.module_list)):
            init_para_lrs.append(optimizer_list[i].param_groups[0]['lr'])

        # loss functions for training
        if self.outcome_cat:
            out_loss_sum = nn.CrossEntropyLoss(weight=CE_weight, reduction='sum')
        else:
            out_loss_sum = nn.MSELoss(reduction='sum')

        # intermediate values for prior gradient calculation
        c1 = np.log(lambda_n) - np.log(1 - lambda_n) + 0.5 * np.log(prior_sigma_0) - 0.5 * np.log(prior_sigma_1)
        c2 = 0.5 / prior_sigma_0 - 0.5 / prior_sigma_1

        # threshold for sparsity
        threshold = np.sqrt(np.log((1 - lambda_n) / lambda_n * np.sqrt(prior_sigma_1 / prior_sigma_0)) / (
                0.5 / prior_sigma_0 - 0.5 / prior_sigma_1))
        
        # intermediate values for calculating the posterior mean of sigma_z
        d1, d2 = 0, 0
        prior_alpha, prior_beta = 1, 1  # a relatively flat prior for sigma_z
        
        # training loop
        for epoch in range(epochs):
            print("Epoch" + str(epoch))

            # tic = time.time()
            if mode in ["train", "finetune"]:
                # impute_lr decay and para_lr decay
                for i in range(len(self.model.module_list)-1):
                    self.step_impute_lrs[i] = impute_lrs[i]/(1+impute_lrs[i]*epoch**impute_lr_decay)
                    print("impute_lr", self.step_impute_lrs[i])
                    optimizer_list[i].param_groups[0]['lr'] = init_para_lrs[i]/(1+init_para_lrs[i]*epoch**para_lr_decay)
                    print("para_lr", optimizer_list[i].param_groups[0]['lr'])
                optimizer_list[-1].param_groups[0]['lr'] = init_para_lrs[-1]/(1+init_para_lrs[-1]*epoch**para_lr_decay)


            for y, treat, *rest in train_data:
                if len(rest) > 0:
                    x = rest[0]
                    self.backward_imputation_args.update(x=x)
                    input = x
                else:
                    input = treat
                # backward imputation
                self.backward_imputation_args.update(impute_lrs=self.step_impute_lrs, outcome_loss=out_loss_sum, treat=treat, y=y,
                                                     mh_step=mh_step, itas=itas)

                hidden_list = self.model.backward_imputation(**self.backward_imputation_args)

                # parameter update
                for para in self.model.module_list.parameters():
                    para.grad = None

                with torch.no_grad():
                    for para in self.model.module_list.parameters():
                        temp = para.pow(2).mul(c2).add(c1).exp().add(1).pow(-1)
                        temp = para.div(-prior_sigma_0).mul(temp) + para.div(-prior_sigma_1).mul(1 - temp)
                        prior_grad = temp.div(len(train_data)*batch_size)
                        para.grad = prior_grad

                for module_index in range(len(self.model.module_list)):
                    self.likelihood_latent_args.update(hidden_list=hidden_list, module_index=module_index,
                                                    outcome_loss=out_loss_sum, y=y, input=input)
                    likelihood = self.model.likelihood_latent(**self.likelihood_latent_args)/batch_size
                    optimizer = optimizer_list[module_index]
                    likelihood.backward()

                    if self.prune_flag == 1:
                        self.prune_masked_grad()
                    
                    optimizer.step()

                    if epoch == epochs-1:
                        with torch.no_grad():
                            # for calculating likelihood
                            likelihood = self.model.likelihood_latent(**self.likelihood_latent_args)
                            self.hidden_likelihood[module_index] += likelihood

                            # for update sigma_z
                            mu_z = self.model.module_list[0](input)
                            d1 += self.model.sse(hidden_list[0], mu_z).div(2)
                            d2 += len(y) * self.model.hidden_dim[self.model.confounder_layer]/2
                
            # evaluate model performance
            train_loss, train_acc, *train_other_result = self.performance_eval('train', train_data, out_loss_sum, y_scale, epoch, mode)
            val_loss, val_acc, *val_other_result = self.performance_eval('val', val_data, out_loss_sum, y_scale, epoch, mode)
            if epoch == epochs-1:
                self.performance = dict(train_loss=train_loss, val_loss=val_loss)
                if self.outcome_cat:
                    self.performance.update(train_acc = train_acc, val_acc = val_acc)
                if len(train_other_result) > 0:
                    self.performance.update(treat_train_loss = train_other_result[0], 
                                            treat_train_acc = train_other_result[1],
                                            treat_val_loss = val_other_result[0],
                                            treat_val_acc = val_other_result[1])

            # save intermediate training performance
            para_path_temp = {str(epoch): {}}
            para_grad_path_temp = {str(epoch): {}}
            para_gamma_path_temp = {str(epoch): {}}

            for name, para in self.model.module_list.named_parameters():
                para_path_temp[str(epoch)][name] = torch.clone(para).data.cpu().numpy().tolist()
                para_grad_path_temp[str(epoch)][name] = torch.clone(para.grad).data.cpu().numpy().tolist()
                para_gamma_path_temp[str(epoch)][name] = (para.abs() > threshold).data.cpu().numpy().tolist()
            self.para_path.update(para_path_temp)
            self.para_grad_path.update(para_grad_path_temp)
            self.para_gamma_path.update(para_gamma_path_temp)

            # prune the newtork
            if mode == "train":
                self.selected_variable(epoch)     
                                                                        
        # update sigma_Z
        d1 += prior_beta
        d2 += (prior_alpha-1)
        with torch.no_grad():
            self.model.sigma_z = d1/d2
        print("sigma_z", self.model.sigma_z)


    def set_prune(self, user_mask):
        """Install ``user_mask`` as the pruning mask and switch masking on.

        Args:
            user_mask: mapping from parameter name to the mask that the
                ``prune_masked_*`` methods use to index that parameter.
        """
        self.mask_prune, self.prune_flag = user_mask, 1

    def cancel_prune(self):
        """Switch masking off and drop the stored pruning mask."""
        self.prune_flag, self.mask_prune = 0, None

    def prune_masked_para(self):
        for name, para in self.model.module_list.named_parameters():
            para.data[self.mask_prune[name]] = 0

    def prune_masked_grad(self):
        for name, para in self.model.module_list.named_parameters():
            para.grad[self.mask_prune[name]] = 0
    
    def prune_network(self, threshold):
        user_mask = {}
        for name, para in self.model.module_list.named_parameters():
            user_mask[name] = para.abs() < threshold
        self.set_prune(user_mask)
        self.prune_masked_para()
    
    def BIC_and_non_zero_para(self, train_set, likelihoods):
        with torch.no_grad():
            num_non_zero_element = 0
            for name, para in self.model.module_list.named_parameters():
                num_non_zero_element = num_non_zero_element + para.numel() - self.mask_prune[name].sum()

            BIC = (np.log(train_set.__len__()) * num_non_zero_element - 2 * np.sum(likelihoods)).item()

        return BIC, num_non_zero_element
    
    def selected_variable(self, epoch=None):
        """Variable-selection hook; does nothing in this base class.

        ``_train`` invokes it as ``self.selected_variable(epoch)`` at the end
        of every epoch in ``'train'`` mode, so the hook accepts the epoch
        index.  The argument is optional so that calls without it keep
        working.

        Args:
            epoch: index of the epoch that just finished (unused here).
        """
        return None

    def performance_eval(self, *args):
        """Evaluation hook; this base implementation ignores its arguments.

        It returns ``None``.  NOTE(review): ``_train`` unpacks the result as
        ``loss, acc, *others``, so a working implementation presumably has
        to be supplied elsewhere (e.g. by a subclass) — confirm.
        """
        return None

    def predict(self):
        """Prediction hook; this base implementation does nothing.

        Returns ``None``.
        """
        return None
