from model.modules import LatentModule, StoNet_Model
import torch
import torch.nn as nn
import numpy as np
from torch.distributions.normal import Normal
#from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import recall_score


class StoNet_MultiCause(nn.Module):
    """Two-module stochastic network: treatments A -> latent confounder Z -> outcome Y.

    The layer stack ``hidden_dim + [output_dim]`` is split at ``confounder_layer``:

    * ``module_az`` maps the treatment vector A to the latent confounder Z
      (layers ``0 .. confounder_layer``);
    * ``module_zay`` maps the concatenation ``[Z, A]`` to the outcome
      (the remaining layers).

    Args:
        hidden_dim: list of hidden-layer widths (copied, not mutated).
        input_dim: dimension of the treatment vector A.
        output_dim: dimension of the outcome layer.
        confounder_layer: index into ``hidden_dim + [output_dim]`` of the layer
            treated as the latent confounder Z.
        sigma_list: noise scales used by ``likelihood_latent``; ``sigma_list[0]``
            for Z|A and ``sigma_list[-1]`` for Y|A,Z (the squared error is divided
            by ``2 * sigma``, so these act like variances). NOTE(review): when
            ``sigma_list`` is None, ``self.sigma_list`` is never set and
            ``likelihood_latent`` will raise AttributeError — confirm callers
            always pass it.
    """

    def __init__(self, hidden_dim, input_dim, output_dim, confounder_layer, sigma_list=None):
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.hidden_dim = hidden_dim
        self.confounder_layer = confounder_layer

        layer_dim = hidden_dim.copy()
        layer_dim.append(output_dim)

        self.module_az = LatentModule(input_dim, layer_dim[:confounder_layer+1], first_module=True)
        # module_zay consumes [Z, A], hence the widened input dimension
        self.module_zay = LatentModule(input_dim + layer_dim[confounder_layer],
                                       layer_dim[confounder_layer+1:], first_module=True)

        self.module_list = nn.ModuleList([self.module_az, self.module_zay])

        if sigma_list is not None:
            self.sigma_z = sigma_list[0]
            self.sigma_list = sigma_list

        self.sse = nn.MSELoss(reduction='sum')

    def forward(self, treat):
        """Deterministic forward pass: predict Z from ``treat``, then Y from ``[Z, treat]``."""
        z = self.module_az(treat)
        az = torch.concat([z, treat], dim=1)
        return self.module_zay(az)

    def likelihood_latent(self, hidden_list, module_index, outcome_loss, y, input=None, forward_z=None):
        """Return the (unnormalised) Gaussian log-likelihood term of one module.

        * ``module_index == 0``: log p(Z | A). The mean of Z is ``forward_z`` when
          given (used during imputation), otherwise ``self.module_list[0](input)``
          (used during parameter updates); it is compared to ``hidden_list[0]``.
        * ``module_index == 1``: log p(Y | A, Z). ``outcome_loss`` compares the
          prediction from ``[hidden_list[0], input]`` against ``y``.

        Raises:
            ValueError: if ``module_index`` is not 0 or 1, or the inputs needed for
                the requested term are missing.
        """
        if module_index == 0:  # log_likelihood(Z|A)
            if forward_z is not None:  # for imputation
                mean_z = forward_z
            elif input is not None:  # for parameter update
                mean_z = self.module_list[module_index](input)
            else:
                raise ValueError("likelihood_latent(module_index=0) requires either forward_z or input")
            return -self.sse(mean_z, hidden_list[module_index]) / (2 * self.sigma_list[module_index])

        if module_index == 1:  # log_likelihood(Y|A, Z)
            if input is None:
                raise ValueError("likelihood_latent(module_index=1) requires input (the treatment)")
            latent_za = torch.concat([hidden_list[module_index - 1], input], dim=1)
            return -outcome_loss(self.module_list[module_index](latent_za), y) / (2 * self.sigma_list[-1])

        raise ValueError(f"module_index must be 0 or 1, got {module_index}")

    def backward_imputation(self, mh_step, impute_lrs, outcome_loss, treat, y, itas):
        """Impute the latent confounder Z with ``mh_step`` SGHMC steps.

        Z starts at its forward-pass value ``module_az(treat)``; each step moves it
        along the gradient of log p(Y|A,Z) + log p(Z|A) plus Gaussian noise, with
        momentum decay ``alpha = lr * ita``.

        NOTE: the first ``backward()`` also accumulates gradients into the
        ``module_zay`` parameters; callers presumably zero them before updating.

        Returns:
            list holding one tensor, the imputed Z (a leaf with requires_grad=True).
        """
        # latent variable starts at the forward prediction; momentum starts at zero
        hidden_list = [self.module_az(treat).detach()]
        momentum_list = [torch.zeros_like(hidden_list[0])]

        for hidden in hidden_list:
            hidden.requires_grad = True
        # the forward prediction of Z stays fixed as the prior mean during imputation
        with torch.no_grad():
            forward_z = torch.clone(hidden_list[0])

        # backward imputation by SGHMC
        for _ in range(mh_step):
            for module_index in reversed(range(len(hidden_list))):
                hidden = hidden_list[module_index]
                hidden.grad = None

                hidden_likelihood1 = self.likelihood_latent(hidden_list=hidden_list, module_index=module_index + 1,
                                                            outcome_loss=outcome_loss, y=y, input=treat)
                hidden_likelihood2 = self.likelihood_latent(hidden_list=hidden_list, module_index=module_index,
                                                            outcome_loss=outcome_loss, y=y, forward_z=forward_z)

                # gradients of both terms accumulate into hidden.grad
                hidden_likelihood1.backward()
                hidden_likelihood2.backward()

                lr = impute_lrs[module_index]
                alpha = lr * itas[module_index]
                with torch.no_grad():
                    noise = torch.randn_like(hidden).mul(np.sqrt(2 * alpha))
                    momentum_list[module_index] = ((1 - alpha) * momentum_list[module_index]
                                                   + lr * hidden.grad + noise)
                    hidden.data += lr * momentum_list[module_index]

        return hidden_list


class StoNet_MultiCause_Model(StoNet_Model):
    """Training / inference wrapper around ``StoNet_MultiCause``.

    Seeds numpy and torch, builds the network on ``model.device`` and hands it to
    ``StoNet_Model``. Presumably ``StoNet_Model`` provides ``_train_prep``,
    ``_train``, ``prune_flag``, ``prune_masked_para``, ``para_gamma_path``,
    ``input_gamma_path`` and ``outcome_cat`` used below — TODO confirm.
    """

    def __init__(self, hidden_dim, input_dim, output_dim, confounder_layer, seed, log_dir = None, sigma_list=None, outcome_cat=False):
        """Build the seeded network. ``log_dir`` is currently unused (the
        TensorBoard writer is commented out)."""
        np.random.seed(seed)
        torch.manual_seed(seed)
        model = StoNet_MultiCause(hidden_dim, input_dim, output_dim, confounder_layer, sigma_list)
        model.to(model.device)
        #self.writer = SummaryWriter(log_dir=log_dir)

        super().__init__(model, outcome_cat)

    def train(self, mode, train_data, val_data, epochs, batch_size,
               impute_lrs, mh_step, itas, para_lrs_train, para_lrs_fine_tune, para_lr_decay, impute_lr_decay,
               prior_sigma_0, prior_sigma_1, lambda_n,
               y_scale=1, CE_weight=None):
        """Prepare the parameter learning rates, then run the ``StoNet_Model``
        training loop with the given imputation / parameter-update settings."""
        self._train_prep(para_lrs_train, para_lrs_fine_tune)
        self._train(mode=mode, train_data=train_data, val_data=val_data, epochs=epochs, batch_size=batch_size,
               impute_lrs=impute_lrs, mh_step=mh_step, itas=itas, para_lr_decay=para_lr_decay, impute_lr_decay=impute_lr_decay,
               prior_sigma_0=prior_sigma_0, prior_sigma_1=prior_sigma_1, lambda_n=lambda_n, y_scale=y_scale, CE_weight=CE_weight)
        # self.writer.flush()

    def performance_eval(self, eval_set, data, loss_func, y_scale, epoch, mode):
        """Compute and print the loss (and accuracy for categorical outcomes)
        on ``eval_set`` ('train' or 'val'); returns ``(loss, correct)``."""
        loss, correct, recall_neg, recall_pos = self.calculate_loss(data, loss_func, y_scale)

        if eval_set == 'train':
            # self.writer.add_scalar(mode + ": train_loss", loss, epoch)
            print(f"Avg train loss: {loss:>8f} \n")
            if self.outcome_cat:
                # self.writer.add_scalar(mode + ": train_acc", correct, epoch)
                print(f"train accuracy: {correct:>8f} \n")

        elif eval_set == 'val':
            # self.writer.add_scalar(mode + ": val_loss", loss, epoch)
            print(f"Avg val loss: {loss:>8f} \n")
            if self.outcome_cat:
                # self.writer.add_scalar(mode + ": val_acc", correct, epoch)
                print(f"val accuracy: {correct:>8f} \n")
                # self.writer.add_scalar(mode + ": val_recall_neg", recall_neg, epoch)
                # self.writer.add_scalar(mode + ": val_recall_pos", recall_pos, epoch)

        return loss, correct

    def calculate_loss(self, data, out_loss_sum, y_scale):
        """Evaluate the model over a data loader.

        Args:
            data: iterable of batches ``(y, treat, *rest)`` exposing ``.dataset``
                (e.g. a ``DataLoader``).
            out_loss_sum: loss function returning the summed loss of a batch.
            y_scale: factor applied to the RMSE for continuous outcomes.

        Returns:
            ``(loss, correct, recall_neg, recall_pos)``. For continuous outcomes
            ``loss`` is the RMSE ``sqrt(sum_loss / N) * y_scale``; for categorical
            outcomes ``loss`` is the mean loss and ``correct`` the accuracy.
            ``recall_neg`` / ``recall_pos`` are the recalls of labels 0 and 1 on
            the LAST batch only (0 for continuous outcomes).
        """
        loss, correct, recall_neg, recall_pos = 0, 0, 0, 0
        with torch.no_grad():
            for y, treat, *rest in data:
                pred = self.predict(treat)
                loss += out_loss_sum(pred, y).item()
                if self.outcome_cat:
                    pred = pred.argmax(1)
                    correct += (pred == y).type(torch.float).sum().item()
                    # recall - note that this is the recall for the last batch.
                    # sklearn needs host arrays (CUDA tensors cannot be converted),
                    # and labels=[0, 1] keeps recall[0]/recall[1] aligned with
                    # labels 0/1 even when a batch contains only one class.
                    recall = recall_score(y.cpu().numpy(), pred.cpu().numpy(),
                                          labels=[0, 1], average=None).tolist()
                    recall_neg, recall_pos = recall[0], recall[1]

        n_obs = len(data.dataset)
        if self.outcome_cat:
            loss /= n_obs
            correct /= n_obs
        else:
            # use RMSE as the model performance metric for regression tasks
            loss = np.sqrt(loss / n_obs) * y_scale

        return loss, correct, recall_neg, recall_pos

    def predict(self, treat):
        """Deterministic forward prediction, applying the pruning masks first
        when ``prune_flag == 1``."""
        if self.prune_flag == 1:
            self.prune_masked_para()

        return self.model(treat)

    def selected_variable(self, epoch):
        """Record which treatment variables the pruned network uses for the outcome.

        Propagates the connectivity masks in ``self.para_gamma_path[str(epoch)]``
        through the weight matrices of both modules. Right after the confounder
        layer the map is augmented with an identity block, so treatments that feed
        ``module_zay`` directly (through the ``[Z, A]`` concatenation) count too.
        An input is selected if it reaches any output unit through either path.

        Side effects: stores the selection under
        ``self.input_gamma_path['var_selected'][str(epoch)]``, appends the count
        to ``self.input_gamma_path['num_selected']`` and prints it.
        """
        gamma = self.para_gamma_path[str(epoch)]
        # select input variable: the input size is the column count of the first weight mask
        input_size = len(gamma['0.module.0.weight'][0])
        var_ind = np.identity(input_size, dtype=bool)
        for i, (name, _) in enumerate(self.model.module_list.named_parameters()):
            # assumes parameters alternate (weight, bias), so even i are weights and
            # i // 2 is the layer index — TODO confirm against LatentModule
            if i % 2 != 0:
                continue
            var_ind = np.matmul(gamma[name], var_ind)
            if i // 2 == self.model.confounder_layer:
                # rows become [Z, A] (matching the concat order); columns become
                # [dependence via Z | direct dependence]
                temp1 = np.concatenate((var_ind, np.zeros((var_ind.shape[0], input_size), dtype=bool)),
                                       axis=1)
                temp2 = np.concatenate((np.zeros((input_size, input_size), dtype=bool),
                                        np.identity(input_size, dtype=bool)), axis=1)
                var_ind = np.concatenate((temp1, temp2), axis=0)
        # combine the indirect and direct paths, then any output unit
        var_ind = var_ind[:, :input_size] + var_ind[:, input_size:]
        var_ind = np.max(var_ind, 0)
        num_selected = np.sum(var_ind)
        self.input_gamma_path['var_selected'][str(epoch)] = var_ind.tolist()
        self.input_gamma_path['num_selected'].append(num_selected.astype("float64"))
        print('number of selected input variable for outcome:', num_selected)

    # def get_z(self, treat, z_sd, seed=1, sample_size=100):
    #     """
    #     sample latent confounders z from p(z), using p(z|a)
    #     assume that p(z|a) is normal -- could modify to incorporate the discrete variable
    #     Note: treat has to be sampled from the whole dataset!!!! (if sample_size < 100, use sample_size, otherwise, sample 100 records from the wole sample)
    #     """
    #     if self.prune_flag == 1:
    #         self.prune_masked_para()

    #     # estimate E[z|x]
    #     z = self.model.module_az(treat)

    #     # sample z from p(z|x)
    #     np.random.seed(seed)
    #     torch.manual_seed(seed)
    #     normal = Normal(loc=z, scale=z_sd)
    #     z = normal.sample((sample_size,)) # dimension of z: sample_size * sample_batch_size * dim_confounder_layer

    #     # pull all the samples together to form sample from p(z)
    #     z = z.flatten(0, 1) # dimension of z: (sample_size * sample_batch_size) * dim_confounder_layer

    #     return z

    # def get_y_cf(self, treat, z, outcome_cat):
    #     """
    #     predict the counterfactual that corresponds to a treatment

    #     x: observed proxy variable
    #     z_sd: the standard deviation of p(z|x), can be either specified or estimated
    #     outcome_cat: the type of outcome variable, binary or continuous
    #     seed: seed to control randomness.
    #     """
    #     if self.prune_flag == 1:
    #         self.prune_masked_para()

    #     sample_size_z = z.shape[0]
    #     batch_size = treat.shape[0]

    #     # sample y_cf
    #     # shape: batch_size * (sample_size * sample_batch_size) * variable length
    #     treat_extend = treat[:, None, :].repeat(1, sample_size_z, 1)
    #     z_extend = z[None, :, :].repeat(batch_size, 1, 1)
    #     az = torch.concat([z_extend, treat_extend], dim=-1)
    #     y = self.model.module_zay(az)

    #     if outcome_cat:
    #         m = nn.Softmax(dim=-1)
    #         y = m(y).mean(dim=1)
    #     else:
    #         y = y.mean(dim=1)

    #     return y

    def get_y_cf(self, treat, z_sd, outcome_cat, sample_size_z=100, seed=1):
        """Predict the (counterfactual) outcome for a batch of treatments.

        Z is drawn ``sample_size_z`` times from ``Normal(module_az(treat), z_sd)``.
        Each draw is passed through ``module_zay`` together with ``treat``, and the
        predictions are averaged over the draws (after a softmax over the last
        dimension when ``outcome_cat`` is true).

        Args:
            treat: treatment tensor, 2-D (batch, input_dim) given the concat on dim 2.
            z_sd: standard deviation of p(z|a), either specified or estimated;
                scalar or tensor broadcastable to the confounder layer.
            outcome_cat: True for a categorical outcome (softmax-averaged
                probabilities), False for a continuous one (averaged directly).
            sample_size_z: number of Monte Carlo draws of z.
            seed: reseeds the global numpy and torch RNGs before sampling, so
                repeated calls with the same seed reuse the same noise.

        Returns:
            averaged prediction of shape (batch, output_dim).

        NOTE(review): this applies ``tanh`` to the sampled z before ``module_zay``,
        whereas ``StoNet_MultiCause.forward`` and the training likelihood feed the
        raw z. Confirm which is intended. Runs with autograd enabled.
        """
        if self.prune_flag == 1:
            self.prune_masked_para()

        # estimate E[z|a]
        z = self.model.module_az(treat)

        # sample z from p(z|a)
        np.random.seed(seed)
        torch.manual_seed(seed)
        normal = Normal(loc=z, scale=z_sd)
        z = normal.sample((sample_size_z,))  # dimension of z: sample_size * batch * dim_confounder_layer

        # sample y_cf: pair every draw of z with the same treatment rows
        temp = treat.unsqueeze(0).repeat(sample_size_z, 1, 1)
        az = torch.concat([torch.tanh(z), temp], dim=2)
        y = self.model.module_zay(az)

        if outcome_cat:
            m = nn.Softmax(dim=2)
            y = m(y).mean(dim=0)
        else:
            y = y.mean(dim=0)

        return y

    def get_marginal_effect(self, treat, z_sd, epsilon, epsilon_idx, outcome_cat):
        """Finite-difference marginal effect of the treatment column(s) ``epsilon_idx``.

        Returns ``(get_y_cf(treat_cf) - get_y_cf(treat)) / epsilon``, where
        ``treat_cf`` is a copy of ``treat`` with ``epsilon`` added to the columns
        ``epsilon_idx``. Both calls use ``get_y_cf``'s default seed, so they share
        the same standard-normal noise for z (common random numbers).
        """
        treat_cf = treat.clone()
        treat_cf[:, epsilon_idx] += epsilon

        # prediction
        y_pred = self.get_y_cf(treat, z_sd, outcome_cat)

        # counterfactual prediction
        y_cf_pred = self.get_y_cf(treat_cf, z_sd, outcome_cat)

        marginal_effect_est = (y_cf_pred - y_pred) / epsilon

        return marginal_effect_est
