import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.init import xavier_uniform_
import copy
from torch.autograd import Variable

class MinNormSolver:
    """Find the minimum-norm point in the convex hull of a set of vectors.

    Each "vector" is a sequence of tensors (for example the per-parameter
    gradients of one task); inner products are summed over those tensors.
    ``vecs`` may be a list or a dict keyed by the integers ``0 .. len(vecs)-1``.

    The public solvers return ``(sol_vec, cost)``: ``sol_vec`` is a numpy array
    of convex-combination weights (one per vector) and ``cost`` is the squared
    norm of the resulting combination.
    """

    # iteration cap and L1 stopping threshold of the iterative solvers (n >= 3)
    MAX_ITER = 250
    STOP_CRIT = 1e-5

    @staticmethod
    def _dot(a, b):
        """Return sum_k <a[k], b[k]> for two equal-length sequences of tensors, as a float."""
        return float(sum(torch.mul(x, y).sum().item() for x, y in zip(a, b)))

    @staticmethod
    def _min_norm_element_from2(v1v1, v1v2, v2v2):
        """Minimise |gamma*v1 + (1-gamma)*v2|^2 over gamma from the three inner products.

        Returns ``(gamma, cost)``. When the optimum lies at an end point, gamma is
        clamped to 0.999 (towards v1) or 0.001 (towards v2).
        """
        if v1v2 >= v1v1:
            return 0.999, v1v1
        if v1v2 >= v2v2:
            return 0.001, v2v2
        # interior optimum; here v1v2 < v1v1 and v1v2 < v2v2, so the
        # denominator |v1 - v2|^2 = v1v1 + v2v2 - 2*v1v2 is strictly positive
        gamma = -1.0 * ((v1v2 - v2v2) / (v1v1 + v2v2 - 2 * v1v2))
        cost = v2v2 + gamma * (v1v2 - v2v2)
        return gamma, cost

    @staticmethod
    def _min_norm_2d(vecs, dps):
        """Find the best two-vector combination and fill the inner-product cache.

        :param vecs: sequence (or int-keyed dict) of tensor sequences
        :param dps: dict cache mapping (i, j) -> <vecs[i], vecs[j]>; updated in place
        :return: ([(i, j), gamma, cost] for the pair with the lowest cost, or None
                 when there is no pair), dps
        """
        n = len(vecs)
        dmin = np.inf
        sol = None
        for i in range(n):
            if (i, i) not in dps:
                dps[(i, i)] = MinNormSolver._dot(vecs[i], vecs[i])
            for j in range(i + 1, n):
                if (i, j) not in dps:
                    dps[(i, j)] = MinNormSolver._dot(vecs[i], vecs[j])
                    dps[(j, i)] = dps[(i, j)]
                if (j, j) not in dps:
                    dps[(j, j)] = MinNormSolver._dot(vecs[j], vecs[j])
                c, d = MinNormSolver._min_norm_element_from2(dps[(i, i)], dps[(i, j)], dps[(j, j)])
                if d < dmin:
                    dmin = d
                    sol = [(i, j), c, d]
        return sol, dps

    @staticmethod
    def _projection2simplex(y):
        """Project the 1-D array y onto the probability simplex: max(y - t, 0) summing to 1."""
        m = len(y)
        sorted_y = np.flip(np.sort(y), axis=0)
        tmpsum = 0.0
        tmax_f = (np.sum(y) - 1.0) / m
        for i in range(m - 1):
            tmpsum += sorted_y[i]
            tmax = (tmpsum - 1) / (i + 1.0)
            if tmax > sorted_y[i + 1]:
                tmax_f = tmax
                break
        return np.maximum(y - tmax_f, 0.0)

    @staticmethod
    def _next_point(cur_val, grad, n):
        """Step from cur_val along grad (made zero-sum), then project back onto the simplex.

        The step length is the smallest positive ratio (> 1e-7) that drives a
        coordinate to 0 or 1, or 1 when no such ratio exists.
        """
        proj_grad = grad - (np.sum(grad) / n)
        tm1 = -1.0 * cur_val[proj_grad < 0] / proj_grad[proj_grad < 0]
        tm2 = (1.0 - cur_val[proj_grad > 0]) / proj_grad[proj_grad > 0]

        t = 1
        if len(tm1[tm1 > 1e-7]) > 0:
            t = np.min(tm1[tm1 > 1e-7])
        if len(tm2[tm2 > 1e-7]) > 0:
            t = min(t, np.min(tm2[tm2 > 1e-7]))

        next_point = proj_grad * t + cur_val
        return MinNormSolver._projection2simplex(next_point)

    @staticmethod
    def _gram_matrix(dps, n):
        """Return the n x n matrix of cached inner products dps[(i, j)]."""
        return np.array([[dps[(i, j)] for j in range(n)] for i in range(n)], dtype=float)

    @staticmethod
    def _init_solution(vecs):
        """Return (sol_vec, cost, dps) for the best pairwise (or single-vector) solution.

        :raises ValueError: if vecs is empty or no finite pairwise cost was found
        """
        n = len(vecs)
        if n == 0:
            raise ValueError('MinNormSolver needs at least one vector')
        init_sol, dps = MinNormSolver._min_norm_2d(vecs, {})
        if n == 1:
            return np.ones(1), dps[(0, 0)], dps
        if init_sol is None:
            raise ValueError('MinNormSolver could not compare the vectors (non-finite inner products?)')
        (i, j), gamma, cost = init_sol
        sol_vec = np.zeros(n)
        sol_vec[i] = gamma
        sol_vec[j] = 1 - gamma
        return sol_vec, cost, dps

    @staticmethod
    def find_min_norm_element(vecs):
        """Projected-gradient search for the min-norm convex combination of vecs.

        :return: (weights as a numpy array, squared norm of the combination)
        """
        sol_vec, cost, dps = MinNormSolver._init_solution(vecs)
        n = len(sol_vec)
        if n < 3:
            return sol_vec, cost

        grad_mat = MinNormSolver._gram_matrix(dps, n)
        nd = cost
        for _ in range(MinNormSolver.MAX_ITER):
            grad_dir = -1.0 * np.dot(grad_mat, sol_vec)
            new_point = MinNormSolver._next_point(sol_vec, grad_dir, n)

            # inner products between the current point and the candidate point
            v1v1 = sol_vec @ grad_mat @ sol_vec
            v1v2 = sol_vec @ grad_mat @ new_point
            v2v2 = new_point @ grad_mat @ new_point

            nc, nd = MinNormSolver._min_norm_element_from2(v1v1, v1v2, v2v2)
            new_sol_vec = nc * sol_vec + (1 - nc) * new_point
            if np.sum(np.abs(new_sol_vec - sol_vec)) < MinNormSolver.STOP_CRIT:
                return sol_vec, nd
            sol_vec = new_sol_vec
        # iteration cap reached: return the latest iterate
        return sol_vec, nd

    @staticmethod
    def find_min_norm_element_FW(vecs):
        """Frank-Wolfe search for the min-norm convex combination of vecs.

        :return: (weights as a numpy array, squared norm of the combination)
        """
        sol_vec, cost, dps = MinNormSolver._init_solution(vecs)
        n = len(sol_vec)
        if n < 3:
            return sol_vec, cost

        grad_mat = MinNormSolver._gram_matrix(dps, n)
        nd = cost
        for _ in range(MinNormSolver.MAX_ITER):
            # move towards the vertex with the smallest directional derivative
            t_iter = np.argmin(np.dot(grad_mat, sol_vec))

            v1v1 = np.dot(sol_vec, np.dot(grad_mat, sol_vec))
            v1v2 = np.dot(sol_vec, grad_mat[:, t_iter])
            v2v2 = grad_mat[t_iter, t_iter]

            nc, nd = MinNormSolver._min_norm_element_from2(v1v1, v1v2, v2v2)
            new_sol_vec = nc * sol_vec
            new_sol_vec[t_iter] += 1 - nc
            if np.sum(np.abs(new_sol_vec - sol_vec)) < MinNormSolver.STOP_CRIT:
                return sol_vec, nd
            sol_vec = new_sol_vec
        return sol_vec, nd

def gradient_normalizers(grads, losses, normalization_type):
    """Compute a per-task normalizer for the gradients in ``grads``.

    :param grads: dict task -> list of gradient tensors
    :param losses: dict task -> scalar loss value (used by 'loss' and 'loss+')
    :param normalization_type: one of 'l2', 'loss', 'loss+', 'none'
    :return: dict task -> normalizer (callers divide the task's gradients by it)
    :raises ValueError: if normalization_type is not one of the supported values
    """
    def grad_norm(grad_list):
        # L2 norm of all tensors in grad_list taken together
        return np.sqrt(sum(gr.pow(2).sum().item() for gr in grad_list))

    if normalization_type == 'l2':
        return {t: grad_norm(grads[t]) for t in grads}
    if normalization_type == 'loss':
        return {t: losses[t] for t in grads}
    if normalization_type == 'loss+':
        return {t: losses[t] * grad_norm(grads[t]) for t in grads}
    if normalization_type == 'none':
        return {t: 1.0 for t in grads}
    raise ValueError('Invalid normalization type: %r' % (normalization_type,))


class MainModel(object):
    """Wrap a SetBasedCLNTM network and its Adam optimizer behind numpy-facing train/inference helpers."""

    def __init__(self, config, alpha=1.0, learning_rate=0.001, init_embeddings=None, update_embeddings=True,
                 init_bg=None, update_background=True, adam_beta1=0.99, adam_beta2=0.999, device=None, seed=None,
                 classify_from_covars=True, model='scholar'):

        """
        Create the model
        :param config: a dictionary with the model configuration
        :param alpha: hyperparameter for the document representation prior (scalar or per-topic vector)
        :param learning_rate: learning rate for Adam
        :param init_embeddings: a matrix of embeddings to initialize the first layer of the bag-of-words encoder
        :param update_embeddings: if True, update word embeddings during training
        :param init_bg: a vector of empirical log background frequencies
        :param update_background: stored on the instance; NOTE(review): not forwarded to SetBasedCLNTM here
        :param adam_beta1: first hyperparameter for Adam
        :param adam_beta2: second hyperparameter for Adam
        :param device: (int) the number of the GPU to use; None means CPU
        :param seed: if not None, seed torch's RNG and make cuDNN deterministic
        :param classify_from_covars: if True, the label classifier also receives the covariates
        :param model: not used by this class
        """

        if seed is not None:
            torch.manual_seed(seed)
            torch.backends.cudnn.deterministic = True

        self.network_architecture = config
        self.learning_rate = learning_rate
        self.adam_beta1 = adam_beta1
        self.adam_beta2 = adam_beta2

        self.update_embeddings = update_embeddings
        self.update_background = update_background

        # create priors on the hidden state
        self.n_topics = config["n_topics"]

        if device is None:
            self.device = 'cpu'
        else:
            self.device = 'cuda:' + str(device)

        # interpret alpha as either a (symmetric) scalar prior or a vector prior
        if np.array(alpha).size == 1:
            # if alpha is a scalar, create a symmetric prior vector
            self.alpha = alpha * np.ones((1, self.n_topics)).astype(np.float32)
        else:
            # otherwise use the prior as given
            self.alpha = np.array(alpha).astype(np.float32)
            assert len(self.alpha) == self.n_topics

        # create the pyTorch model
        self._model = SetBasedCLNTM(config, self.alpha, update_embeddings, init_emb=init_embeddings, bg_init=init_bg,
                                    device=self.device, classify_from_covars=classify_from_covars).to(self.device)

        # set the criterion
        self.criterion = nn.BCEWithLogitsLoss()

        # create the optimizer over trainable parameters only
        grad_params = filter(lambda p: p.requires_grad, self._model.parameters())
        self.optimizer = optim.Adam(grad_params, lr=learning_rate, betas=(adam_beta1, adam_beta2))

    def _prepare_batch(self, X, Y, PC, TC, syn_X=None):
        """Convert each non-None input to a float tensor on the device, turning 1-D inputs into batches of one."""
        def to_batch_tensor(A):
            if A is None:
                return None
            T = A.float() if isinstance(A, torch.Tensor) else torch.Tensor(A)
            if T.dim() == 1:
                T = T.unsqueeze(0)
            return T.to(self.device)
        return tuple(to_batch_tensor(A) for A in (X, Y, PC, TC, syn_X))

    def fit(self, X, syn_X, Y, PC, TC, eta_bn_prop=1.0, l1_beta=None, l1_beta_c=None, l1_beta_ci=None):
        """
        Fit the model to a minibatch of data.

        The KLD and contrastive losses are weighted by the min-norm convex combination
        (MinNormSolver) of their normalized gradients; the reconstruction/label term
        has weight 1. NOTE(review): the regularizers and the 0.1 contrastive weight
        folded into the model's total loss are not part of this weighted objective.

        :param X: np.array of document word counts [batch size x vocab size]
        :param syn_X: np.array of synthetic counterparts of X used by the contrastive term [batch size x vocab size]
        :param Y: np.array of labels [batch size x n_labels]
        :param PC: np.array of prior covariates influencing the document-topic prior [batch size x n_prior_covars]
        :param TC: np.array of topic covariates to be associated with topical deviations [batch size x n_topic_covars]
        :param l1_beta: np.array of prior variances on the topic weights
        :param l1_beta_c: np.array of prior variances on the weights for topic covariates
        :param l1_beta_ci: np.array of prior variances on the weights for topic-covariate interactions
        :return: weighted loss; label pred probs; document representations; KLD; contrastive loss
        """
        X, Y, PC, TC, syn_X = self._prepare_batch(X, Y, PC, TC, syn_X)
        model_kwargs = dict(eta_bn_prop=eta_bn_prop, l1_beta=l1_beta, l1_beta_c=l1_beta_c, l1_beta_ci=l1_beta_ci)

        # tasks: 0 -> KLD, 1 -> contrastive (indices 1 and 2 of the model's loss list)
        n_tasks = 2
        grads = {}
        loss_data = {}
        for t in range(n_tasks):
            self.optimizer.zero_grad()
            _, _, _, (_, all_losses) = self._model(X, syn_X, Y, PC, TC, **model_kwargs)
            task_loss = all_losses[t + 1]
            loss_data[t] = task_loss.item()
            task_loss.backward()
            # gradients of all parameters except the decoder/batchnorm ones ('eta'/'beta' in the name)
            grads[t] = [param.grad.detach().clone()
                        for name, param in self._model.named_parameters()
                        if param.grad is not None and 'eta' not in name and 'beta' not in name]

        gn = gradient_normalizers(grads, loss_data, 'loss+')
        for t in range(n_tasks):
            for gr_i in range(len(grads[t])):
                grads[t][gr_i] /= gn[t]

        sol, _ = MinNormSolver.find_min_norm_element(grads)
        scale = {t: float(sol[t]) for t in range(n_tasks)}

        # final forward pass with the combined objective
        self.optimizer.zero_grad()
        thetas, X_recon, Y_probs, (_, all_losses) = self._model(X, syn_X, Y, PC, TC, **model_kwargs)
        tasks_losses = [all_losses[1], all_losses[2]]
        loss = all_losses[0]
        for t in range(n_tasks):
            loss = loss + scale[t] * tasks_losses[t]
        # update model
        loss.backward()
        self.optimizer.step()

        if Y_probs is not None:
            Y_probs = Y_probs.to('cpu').detach().numpy()
        return (loss.to('cpu').detach().numpy(), Y_probs, thetas.to('cpu').detach().numpy(),
                tasks_losses[0].to('cpu').detach().numpy(), tasks_losses[1].to('cpu').detach().numpy())

    def predict(self, X, PC, TC, eta_bn_prop=0.0):
        """
        Predict labels for a minibatch of data
        :return: (theta as a torch tensor, np.array of label probabilities)
        """
        # input a vector of all zeros in place of the labels that the model has been trained on
        batch_size = self.get_batch_size(X)
        Y = np.zeros((batch_size, self.network_architecture['n_labels'])).astype('float32')
        X, Y, PC, TC, _ = self._prepare_batch(X, Y, PC, TC)
        # no synthetic batch at prediction time, so no contrastive term is computed
        theta, _, Y_recon, _ = self._model(X, None, Y, PC, TC, do_average=False, var_scale=0.0,
                                           eta_bn_prop=eta_bn_prop)
        return theta, Y_recon.to('cpu').detach().numpy()

    def predict_from_topics(self, theta, PC, TC, eta_bn_prop=0.0):
        """
        Predict label probabilities from each topic
        """
        theta, _, PC, TC, _ = self._prepare_batch(theta, None, PC, TC)
        probs = self._model.predict_from_theta(theta, PC, TC)
        return probs.to('cpu').detach().numpy()

    def get_losses(self, X, Y, PC, TC, eta_bn_prop=0.0, n_samples=0, syn_X=None):
        """
        Compute the per-document loss values for X, Y, PC, TC (and optional syn_X)
        :param n_samples: 0 -> a single pass without sampling noise (var_scale=0);
            otherwise the average over n_samples passes with var_scale=1
        :return: np.array of per-document losses
        """
        X, Y, PC, TC, syn_X = self._prepare_batch(X, Y, PC, TC, syn_X)
        if n_samples == 0:
            _, _, _, (loss, _) = self._model(X, syn_X, Y, PC, TC, do_average=False, var_scale=0.0,
                                             eta_bn_prop=eta_bn_prop)
            return loss.to('cpu').detach().numpy()

        losses = 0.0
        for _ in range(n_samples):
            _, _, _, (loss, _) = self._model(X, syn_X, Y, PC, TC, do_average=False, var_scale=1.0,
                                             eta_bn_prop=eta_bn_prop)
            losses = losses + loss.to('cpu').detach().numpy()
        return losses / float(n_samples)

    def compute_theta(self, X, syn_X, Y, PC, TC, eta_bn_prop=0.0):
        """
        Return the latent document representation (mean of posterior of theta) for a given batch of X, Y, PC, and TC
        """
        X, Y, PC, TC, syn_X = self._prepare_batch(X, Y, PC, TC, syn_X)
        theta, _, _, _ = self._model(X, syn_X, Y, PC, TC, do_average=False, var_scale=0.0, eta_bn_prop=eta_bn_prop)
        return theta.to('cpu').detach().numpy()

    def get_weights(self):
        """
        Return the topic-vocabulary deviation weights [n_topics x vocab_size]
        """
        return self._model.beta_layer.weight.detach().cpu().numpy().T

    def get_bg(self):
        """
        Return the background terms
        """
        return self._model.beta_layer.bias.detach().cpu().numpy()

    def get_prior_weights(self):
        """
        Return the weights associated with the prior covariates
        """
        return self._model.prior_covar_weights.weight.detach().cpu().numpy().T

    def get_covar_weights(self):
        """
        Return the topic weight (deviations) associated with the topic covariates
        """
        return self._model.beta_c_layer.weight.detach().cpu().numpy().T

    def get_covar_interaction_weights(self):
        """
        Return the weights (deviations) associated with the topic-covariate interactions
        """
        return self._model.beta_ci_layer.weight.detach().cpu().numpy().T

    def get_batch_size(self, X):
        """
        Get the batch size for a minibatch of data
        :param X: the minibatch (a 1-D input counts as a single document)
        :return: the size of the minibatch
        """
        if len(X.shape) == 1:
            return 1
        batch_size, _ = X.shape
        return batch_size

    def eval(self):
        self._model.eval()

    def train(self):
        self._model.train()



class ContrastiveLearning(nn.Module):
    """Contrastive objective over paired 2-D set representations.

    Row i of ``*_x`` and row i of ``*_y`` are treated as a pair (the diagonal of
    the row-normalized similarity matrix).
    """

    def __init__(self):
        super().__init__()
        # registered for compatibility; not used by forward
        self.softplus = nn.Softplus()

    @staticmethod
    def _log_probs(a, b, temp):
        """Row-wise log-softmax of the temperature-scaled cosine-similarity matrix between rows of a and b."""
        sim = torch.matmul(F.normalize(a), F.normalize(b).T) / temp
        # log_softmax is shift-invariant and stable, unlike log(softmax(.)) which
        # underflows to -inf for strongly separated logits
        return F.log_softmax(sim, dim=-1)

    def forward(self, max_x, max_y, min_x, min_y, temp=0.2):
        """
        :param max_x, max_y: paired max-pooled set representations [n_sets x dim]
        :param min_x, min_y: paired min-pooled set representations [n_sets x dim]
        :param temp: softmax temperature
        :return: scalar loss tensor
        """
        min_log_probs = self._log_probs(min_x, min_y, temp)
        pos_logits = torch.trace(-min_log_probs)

        max_log_probs = self._log_probs(max_x, max_y, temp)
        neg_logits = torch.sum(max_log_probs) - torch.trace(max_log_probs) + pos_logits

        # NOTE(review): pos_logits appears in both terms and cancels, so the value
        # (and gradient) depends only on the off-diagonal max-set terms; confirm intent.
        return pos_logits - neg_logits


class SetBasedCLNTM(nn.Module):
    """VAE topic model (softmax-of-Gaussian document representation, softmax decoder) with
    optional covariates, an optional label classifier and a set-based contrastive term
    between a batch and its synthetic counterpart ``syn_X``."""

    def __init__(self, config, alpha, update_embeddings=True, init_emb=None, bg_init=None, device='cpu', classify_from_covars=False):
        super(SetBasedCLNTM, self).__init__()

        # load the configuration
        self.vocab_size = config['vocab_size']
        self.words_emb_dim = config['embedding_dim']
        self.n_topics = config['n_topics']
        self.n_labels = config['n_labels']
        self.n_prior_covars = config['n_prior_covars']
        self.n_topic_covars = config['n_topic_covars']
        self.classifier_layers = config['classifier_layers']
        self.use_interactions = config['use_interactions']
        self.l1_beta_reg = config['l1_beta_reg']
        self.l1_beta_c_reg = config['l1_beta_c_reg']
        self.l1_beta_ci_reg = config['l1_beta_ci_reg']
        self.l2_prior_reg = config['l2_prior_reg']
        self.device = device
        self.classify_from_covars = classify_from_covars

        # create a layer for prior covariates to influence the document prior
        if self.n_prior_covars > 0:
            self.prior_covar_weights = nn.Linear(self.n_prior_covars, self.n_topics, bias=False)
        else:
            self.prior_covar_weights = None

        # create the encoder
        self.embeddings_x_layer = nn.Linear(self.vocab_size, self.words_emb_dim, bias=False)
        emb_size = self.words_emb_dim
        classifier_input_dim = self.n_topics
        if self.n_prior_covars > 0:
            emb_size += self.n_prior_covars
            if self.classify_from_covars:
                classifier_input_dim += self.n_prior_covars
        if self.n_topic_covars > 0:
            emb_size += self.n_topic_covars
            if self.classify_from_covars:
                classifier_input_dim += self.n_topic_covars
        if self.n_labels > 0:
            emb_size += self.n_labels

        self.encoder_dropout_layer = nn.Dropout(p=0.2)

        if not update_embeddings:
            self.embeddings_x_layer.weight.requires_grad = False
        if init_emb is not None:
            self.embeddings_x_layer.weight.data.copy_(torch.from_numpy(init_emb))
        else:
            xavier_uniform_(self.embeddings_x_layer.weight)

        # create the mean and variance components of the VAE
        self.mean_layer = nn.Linear(emb_size, self.n_topics)
        self.logvar_layer = nn.Linear(emb_size, self.n_topics)

        # batchnorm layers with frozen unit scale
        self.mean_bn_layer = nn.BatchNorm1d(self.n_topics, eps=0.001, momentum=0.001, affine=True)
        self.mean_bn_layer.weight.data.fill_(1.0)
        self.mean_bn_layer.weight.requires_grad = False
        self.logvar_bn_layer = nn.BatchNorm1d(self.n_topics, eps=0.001, momentum=0.001, affine=True)
        self.logvar_bn_layer.weight.data.fill_(1.0)
        self.logvar_bn_layer.weight.requires_grad = False

        self.z_dropout_layer = nn.Dropout(p=0.2)

        # create the decoder; its bias holds the background term
        self.beta_layer = nn.Linear(self.n_topics, self.vocab_size)

        xavier_uniform_(self.beta_layer.weight)
        if bg_init is not None:
            self.beta_layer.bias.data.copy_(torch.from_numpy(bg_init))
            self.beta_layer.bias.requires_grad = False
        self.beta_layer = self.beta_layer.to(self.device)

        if self.n_topic_covars > 0:
            self.beta_c_layer = nn.Linear(self.n_topic_covars, self.vocab_size, bias=False).to(self.device)
            if self.use_interactions:
                self.beta_ci_layer = nn.Linear(self.n_topics * self.n_topic_covars, self.vocab_size, bias=False).to(self.device)

        # create the classifier (classifier_layers hidden softplus layers, then the output layer)
        if self.n_labels > 0:
            if self.classifier_layers == 0:
                self.classifier_layer_0 = nn.Linear(classifier_input_dim, self.n_labels).to(self.device)
            elif self.classifier_layers == 1:
                self.classifier_layer_0 = nn.Linear(classifier_input_dim, classifier_input_dim).to(self.device)
                self.classifier_layer_1 = nn.Linear(classifier_input_dim, self.n_labels).to(self.device)
            else:
                self.classifier_layer_0 = nn.Linear(classifier_input_dim, classifier_input_dim).to(self.device)
                self.classifier_layer_1 = nn.Linear(classifier_input_dim, classifier_input_dim).to(self.device)
                self.classifier_layer_2 = nn.Linear(classifier_input_dim, self.n_labels).to(self.device)

        # create a final batchnorm layer with frozen unit scale
        self.eta_bn_layer = nn.BatchNorm1d(self.vocab_size, eps=0.001, momentum=0.001, affine=True).to(self.device)
        self.eta_bn_layer.weight.data.fill_(1.0)
        self.eta_bn_layer.weight.requires_grad = False

        # create the document prior terms
        prior_mean = (np.log(alpha).T - np.mean(np.log(alpha), 1)).T
        prior_var = (((1.0 / alpha) * (1 - (2.0 / self.n_topics))).T + (1.0 / (self.n_topics * self.n_topics)) * np.sum(1.0 / alpha, 1)).T

        prior_mean = np.array(prior_mean).reshape((1, self.n_topics))
        prior_logvar = np.array(np.log(prior_var)).reshape((1, self.n_topics))
        self.prior_mean = torch.from_numpy(prior_mean).to(self.device)
        self.prior_mean.requires_grad = False
        self.prior_logvar = torch.from_numpy(prior_logvar).to(self.device)
        self.prior_logvar.requires_grad = False

        # group-based contrastive hyperparameters
        self.n_perm = config['n_perm']   # number of random groupings per batch
        self.K = config['K']             # documents per group
        self.contrastive_objective = ContrastiveLearning()

    def encode(self, X, Y, var_scale, PC, TC):
        """Encode a batch; return (z_do, theta, posterior_mean, posterior_mean_bn, posterior_logvar, posterior_logvar_bn)."""
        # embed the word counts
        en0_x = self.embeddings_x_layer(X)
        encoder_parts = [en0_x]

        # append additional components to the encoder, if given
        if self.n_prior_covars > 0:
            encoder_parts.append(PC)
        if self.n_topic_covars > 0:
            encoder_parts.append(TC)
        if self.n_labels > 0:
            encoder_parts.append(Y)

        if len(encoder_parts) > 1:
            en0 = torch.cat(encoder_parts, dim=1).to(self.device)
        else:
            en0 = en0_x

        encoder_output = F.softplus(en0)
        encoder_output_do = self.encoder_dropout_layer(encoder_output)

        # compute the mean and variance of the document posteriors
        posterior_mean = self.mean_layer(encoder_output_do)
        posterior_logvar = self.logvar_layer(encoder_output_do)

        posterior_mean_bn = self.mean_bn_layer(posterior_mean)
        posterior_logvar_bn = self.logvar_bn_layer(posterior_logvar)

        posterior_var = posterior_logvar_bn.exp().to(self.device)

        # sample noise from a standard normal
        eps = X.data.new().resize_as_(posterior_mean_bn.data).normal_().to(self.device)

        # compute the sampled latent representation (var_scale=0 gives the posterior mean)
        z = posterior_mean_bn + posterior_var.sqrt() * eps * var_scale
        z_do = self.z_dropout_layer(z)

        # pass the document representations through a softmax
        theta = F.softmax(z_do, dim=1)
        return z_do, theta, posterior_mean, posterior_mean_bn, posterior_logvar, posterior_logvar_bn

    def _group_contrastive_loss(self, theta, syn_theta):
        """Contrast min/max-pooled random groups of K documents of theta against the same groups of syn_theta."""
        n_batch = theta.size(0)
        # draw the permutations with the default (CPU) generator, then place them next to theta
        batch_rand_perm = torch.rand(self.n_perm, n_batch).argsort(dim=1).to(theta.device)
        # groups are read through the first permutation: grouping p, group i holds the documents
        # batch_rand_perm[0][batch_rand_perm[p, i*K:(i+1)*K]]
        base_order = batch_rand_perm[0]

        min_sets, min_sets_syn, max_sets, max_sets_syn = [], [], [], []
        # an incomplete trailing group (n_batch % K documents) is dropped
        for start in range(0, (n_batch // self.K) * self.K, self.K):
            members = base_order[batch_rand_perm[:, start:start + self.K]]  # [n_perm x K]
            group = theta[members]                                          # [n_perm x K x n_topics]
            group_syn = syn_theta[members]
            min_sets.append(torch.min(group, 1)[0])
            min_sets_syn.append(torch.min(group_syn, 1)[0])
            max_sets.append(torch.max(group, 1)[0])
            max_sets_syn.append(torch.max(group_syn, 1)[0])

        def flatten(sets):
            # [n_perm x n_groups x n_topics] -> [(n_perm * n_groups) x n_topics]
            stacked = torch.stack(sets, 1)
            return stacked.reshape(stacked.size(0) * stacked.size(1), -1)

        return self.contrastive_objective(flatten(max_sets), flatten(max_sets_syn),
                                          flatten(min_sets), flatten(min_sets_syn))

    def _classify(self, theta, PC, TC):
        """Return label probabilities from theta (plus covariates when classify_from_covars), or None without labels."""
        if self.n_labels <= 0:
            return None

        classifier_inputs = [theta]
        if self.classify_from_covars:
            if self.n_prior_covars > 0:
                classifier_inputs.append(PC)
            if self.n_topic_covars > 0:
                classifier_inputs.append(TC)
        if len(classifier_inputs) > 1:
            classifier_input = torch.cat(classifier_inputs, dim=1).to(self.device)
        else:
            classifier_input = theta.to(self.device)

        if self.classifier_layers == 0:
            decoded_y = self.classifier_layer_0(classifier_input)
        elif self.classifier_layers == 1:
            cls0_sp = F.softplus(self.classifier_layer_0(classifier_input))
            decoded_y = self.classifier_layer_1(cls0_sp)
        else:
            cls0_sp = F.softplus(self.classifier_layer_0(classifier_input))
            cls1_sp = F.softplus(self.classifier_layer_1(cls0_sp))
            decoded_y = self.classifier_layer_2(cls1_sp)
        return F.softmax(decoded_y, dim=1)

    def forward(self, X, syn_X, Y, PC, TC, compute_loss=True, do_average=True, eta_bn_prop=1.0, var_scale=1.0, l1_beta=None, l1_beta_c=None, l1_beta_ci=None):
        """
        Do a forward pass of the model
        :param X: tensor of word counts [batch_size x vocab_size]
        :param syn_X: tensor of synthetic counterparts of X [batch_size x vocab_size], or None to skip the contrastive term
        :param Y: tensor of labels [batch_size x n_classes]
        :param PC: tensor of covariates influencing the prior [batch_size x n_prior_covars]
        :param TC: tensor of covariates with explicit topic deviations [batch_size x n_topic_covariates]
        :param compute_loss: if True, compute and return the loss
        :param do_average: if True, average the loss over the minibatch
        :param eta_bn_prop: (float) a weight between 0 and 1 to interpolate between using and not using the final batchnorm layer
        :param var_scale: (float) a parameter which can be used to scale the variance of the random noise in the VAE
        :param l1_beta: np.array of prior variances for the topic weights
        :param l1_beta_c: np.array of prior variances on topic covariate deviations
        :param l1_beta_ci: np.array of prior variances on topic-covariate interactions
        :return: theta; reconstruction; label probs; and, if compute_loss, (loss, [NL, KLD, contrastive_loss])
        """

        z_do, theta, posterior_mean, posterior_mean_bn, posterior_logvar, posterior_logvar_bn = self.encode(X, Y, var_scale, PC, TC)
        if syn_X is not None:
            _, syn_theta, _, _, _, _ = self.encode(syn_X, Y, var_scale, PC, TC)
            contrastive_loss = self._group_contrastive_loss(theta, syn_theta)
        else:
            contrastive_loss = 0.0

        # combine latent representation with topics and background
        # beta layer here includes both the topic weights and the background term (as a bias)
        eta = self.beta_layer(theta)

        # add deviations for covariates (and interactions)
        if self.n_topic_covars > 0:
            eta = eta + self.beta_c_layer(TC)
            if self.use_interactions:
                theta_rsh = theta.unsqueeze(2)
                tc_emb_rsh = TC.unsqueeze(1)
                covar_interactions = theta_rsh * tc_emb_rsh
                batch_size, _, _ = covar_interactions.shape
                eta += self.beta_ci_layer(covar_interactions.reshape((batch_size, self.n_topics * self.n_topic_covars)))

        # pass the unnormalized word probabilities through a batch norm layer
        eta_bn = self.eta_bn_layer(eta)

        # compute X recon with and without batchnorm on eta, and take a convex combination of them
        X_recon_bn = F.softmax(eta_bn, dim=1)
        X_recon_no_bn = F.softmax(eta, dim=1)
        X_recon = eta_bn_prop * X_recon_bn + (1.0 - eta_bn_prop) * X_recon_no_bn

        # predict labels
        Y_recon = self._classify(theta, PC, TC)

        # compute the document prior if using prior covariates
        if self.n_prior_covars > 0:
            prior_mean = self.prior_covar_weights(PC)
            prior_logvar = self.prior_logvar.expand_as(posterior_logvar)
        else:
            prior_mean = self.prior_mean.expand_as(posterior_mean)
            prior_logvar = self.prior_logvar.expand_as(posterior_logvar)

        if compute_loss:
            return theta, X_recon, Y_recon, self._loss(X, Y, X_recon, Y_recon, prior_mean, prior_logvar, posterior_mean_bn, posterior_logvar_bn, do_average, l1_beta, l1_beta_c, l1_beta_ci, contrastive_loss)
        else:
            return theta, X_recon, Y_recon

    def _loss(self, X, Y, X_recon, Y_recon, prior_mean, prior_logvar, posterior_mean, posterior_logvar, do_average=True, l1_beta=None, l1_beta_c=None, l1_beta_ci=None, contrastive_loss=None):
        """Return (loss, [NL, KLD, contrastive_loss]), averaged over the batch when do_average.

        loss = NL + KLD + 0.1 * contrastive_loss plus any enabled weight regularizers.
        """

        # compute reconstruction loss
        NL = -(X * (X_recon + 1e-10).log()).sum(1)
        # compute label loss
        if self.n_labels > 0:
            NL += -(Y * (Y_recon + 1e-10).log()).sum(1)

        # compute KLD between the diagonal Gaussian posterior and prior
        prior_var = prior_logvar.exp()
        posterior_var = posterior_logvar.exp()
        var_division = posterior_var / prior_var
        diff = posterior_mean - prior_mean
        diff_term = diff * diff / prior_var
        logvar_division = prior_logvar - posterior_logvar

        # put KLD together
        KLD = 0.5 * ((var_division + diff_term + logvar_division).sum(1) - self.n_topics)

        # combine
        loss = (NL + KLD) + 0.1 * contrastive_loss
        # add regularization on prior
        if self.l2_prior_reg > 0 and self.n_prior_covars > 0:
            loss += self.l2_prior_reg * torch.pow(self.prior_covar_weights.weight, 2).sum()

        # add regularization on topic and topic covariate weights
        if self.l1_beta_reg > 0 and l1_beta is not None:
            l1_strengths_beta = torch.from_numpy(l1_beta).to(self.device)
            beta_weights_sq = torch.pow(self.beta_layer.weight, 2)
            loss += self.l1_beta_reg * (l1_strengths_beta * beta_weights_sq).sum()

        if self.n_topic_covars > 0 and l1_beta_c is not None and self.l1_beta_c_reg > 0:
            l1_strengths_beta_c = torch.from_numpy(l1_beta_c).to(self.device)
            beta_c_weights_sq = torch.pow(self.beta_c_layer.weight, 2)
            loss += self.l1_beta_c_reg * (l1_strengths_beta_c * beta_c_weights_sq).sum()

        if self.n_topic_covars > 0 and self.use_interactions and l1_beta_ci is not None and self.l1_beta_ci_reg > 0:
            l1_strengths_beta_ci = torch.from_numpy(l1_beta_ci).to(self.device)
            beta_ci_weights_sq = torch.pow(self.beta_ci_layer.weight, 2)
            loss += self.l1_beta_ci_reg * (l1_strengths_beta_ci * beta_ci_weights_sq).sum()

        # average losses if desired
        if do_average:
            return loss.mean(), [NL.mean(), KLD.mean(), contrastive_loss]
        else:
            return loss, [NL, KLD, contrastive_loss]

    def predict_from_theta(self, theta, PC, TC):
        """Predict label probabilities from a distribution over topics (None when the model has no labels)."""
        return self._classify(theta, PC, TC)
