import pickle
import numpy as np
import time
import itertools
import pandas as pd
import torch
import pdb
from sklearn.preprocessing import OneHotEncoder
import torch.nn.utils.prune as prune
from torch.nn import functional as F
import torch.distributed as dist
import matplotlib.pyplot as plt
from torch.distributed.fsdp import (
                                    FullyShardedDataParallel as FSDP,
                                    MixedPrecision,
                                    BackwardPrefetch,
                                    ShardingStrategy,
                                    FullStateDictConfig,
                                    StateDictType,
                                    )

from .models import multimodal_mixVAE as mixVAE_model
from ..utils.deepnet_tools import zinb_error
from ..utils.deepnet_tools import safe_item
from ..utils.deepnet_tools import jaccard_distance
from ..utils.deepnet_tools import save_model_with_gzip, load_model_with_gzip
from ..utils.deepnet_tools import gumbel_softmax
from .networks import classifier as classifier_model


class multi_mod_cpl_mixVAE:

    def __init__(self, saving_folder='', augmenter=[], device=None, eps=1e-8, save_flag=True):
        """
        Initialized the cpl_mixVAE class.

        input args:
            saving_folder: a string that indicates the folder to save the model(s) and file(s).
            augmentor: the pre-trained augmenter.
            device: computing device, either 'cpu' or 'cuda'.
            eps: a small constant value to fix computation overflow.
            save_flag: a boolean variable, if True, the model is saved.
        """

        self.eps = eps
        self.save = save_flag
        self.folder = saving_folder
        self.aug = [True if augmenter else False][0]
        self.device = device
        
        if self.aug:
            self.netA = augmenter



    def init_model(self, 
                    networks,
                    n_categories, 
                    input_dim, 
                    lowD_dim,
                    state_dim, 
                    n_class,
                    x_drop=None, 
                    n_arm=None, 
                    noise_model='Gaussian',
                    tau=0.005, 
                    trained_model='', 
                    n_pr=0,
                    ):
        """
        Initialize the deep mixture model, its label classifier and the one-hot label encoder.

        input args:
            networks: a dictionary keyed by modality name; its keys define self.modalities and it is
                      forwarded to the mixture model.
            n_categories: number of categories of the latent variables.
            input_dim: input dimension (size of the input layer).
            lowD_dim: dimension of the latent representation (stored on the instance).
            state_dim: dimension of the state variable (stored on the instance).
            n_class: number of label classes, used for the classifier output and the one-hot encoder.
            x_drop: dropout probability at the first (input) layer, per modality; defaults to 0. for every modality.
            n_arm: number of arms, per modality; defaults to 1 for every modality.
            noise_model: the noise model of the data, either 'Gaussian' or 'ZINB'.
            tau: temperature of the softmax layers, usually equals to 1/n_categories (0 < tau <= 1).
            trained_model: a pre-trained model, in case you want to initialized the network with a pre-trained network.
            n_pr: number of pruned categories, only if you want to initialize the network with a pre-trained network.
        """
        self.modalities = list(networks.keys())
        self.mod_pairs = list(itertools.combinations(self.modalities, 2))
        # a cross-modality entry is named by concatenating the two modality names
        self.cross_mod_pairs = [''.join(pair) for pair in self.mod_pairs]
        self.all_modalities = list(self.modalities) + self.cross_mod_pairs

        # resolve the per-modality defaults BEFORE storing them, otherwise self.n_arm
        # would stay None and every later self.n_arm[m] lookup would fail
        if x_drop is None:
            x_drop = dict.fromkeys(self.modalities, 0.)
        if n_arm is None:
            n_arm = dict.fromkeys(self.modalities, 1)

        self.lowD_dim = lowD_dim
        self.n_categories = n_categories
        self.state_dim = state_dim
        self.input_dim = input_dim
        self.n_arm = n_arm
        self.n_class = n_class

        self.model = mixVAE_model(
                                modalities=self.modalities, 
                                networks=networks,
                                input_dim=input_dim, 
                                n_categories=n_categories, 
                                x_drop=x_drop, 
                                n_arm=n_arm, 
                                tau=tau, 
                                eps=self.eps,
                                noise_model=noise_model,
                                )

        # the classifier takes the categorical probabilities (n_categories wide) as input
        self.classifier = classifier_model(
                                            input_dim=self.n_categories, 
                                            n_class=self.n_class, 
                                            p_drop=0., 
                                            )

        self.onehot_encoder = OneHotEncoder(
                                            categories=[list(range(self.n_class))],
                                            sparse_output=False,
                                            handle_unknown='ignore',
                                            )
        self.dist = False

        if trained_model:
            print('Load the pre-trained model')
            # if you wish to load another model for evaluation
            self.load_model(trained_model)
            self.init = False
            self.n_pr = n_pr
        else:
            self.init = True
            self.n_pr = 0


    def load_model(self, trained_model):
        if self.compress:
            self.model = load_model_with_gzip(model=self.model, filepath=trained_model)
        else:
            loaded_file = torch.load(trained_model, map_location='cpu', weights_only=False)
            try:
                self.model.load_state_dict(loaded_file['model_state_dict'])
                # self.optimizer = torch.optim.Adam(self.model.parameters())
                # self.optimizer.load_state_dict(loaded_file['optimizer_state_dict'])
            except:
                self.model.load_state_dict(loaded_file)
    
    
    def load_classifier(self, trained_clf):
        if self.compress:
            self.classifier = load_model_with_gzip(model=self.model, filepath=trained_clf)
        else:
            loaded_file = torch.load(trained_clf, map_location='cpu', weights_only=False)
            try:
                self.classifier.load_state_dict(loaded_file['model_state_dict'])
                # self.optimizer = torch.optim.Adam(self.model.parameters())
                # self.optimizer.load_state_dict(loaded_file['optimizer_state_dict'])
            except:
                self.classifier.load_state_dict(loaded_file)
    
    
    def loss(self, recon_x, x, y_pred, y, mu, log_sigma, qc, c, lam, beta, prior_c, mask, ref_mod):
        """
        loss function of the cpl-mixVAE network including.

       input args
            recon_x: a distionary including the reconstructed data for each modality.
            x: a distionary includes original input data.
            y_pred: a distionary includes the predicted labels for each modality.
            y: an array of the true labels.
            mu: distionary of mean of the Gaussian distribution for the sate variable.
            log_sigma: log of variance of the Gaussian distribution for the sate variable.
            qc: probability of categories for all modalities.
            c: samples fom all distrubtions for all modalities.
            lam: coupling factor in the cpl-mixVAE model.
            beta: regularizer for the KL divergence term.

        return
            total_loss: total loss value.
            l_rec: reconstruction loss for each modality.
            loss_joint: coupling loss.
            neg_joint_entropy: negative joint entropy of the categorical variable.
            qc_distance: distance between a pair of categorical distributions, i.e. qc_a & qc_b.
            c_distance: Euclidean distance between a pair of categorical variables, i.e. c_a & c_b.
            KLD: dictionary of KL divergences for the state variables across all modalities.
            var_a.min(): minimum variance of the modalities.

        """
        l_rec = dict.fromkeys(self.modalities)
        l_classifier = dict.fromkeys(self.modalities)
        loss_indep = dict.fromkeys(self.modalities)
        KLD_cont = dict.fromkeys(self.modalities)
        neg_joint_entropy = dict.fromkeys(self.all_modalities)
        z_distance_rep = dict.fromkeys(self.all_modalities)
        z_distance = dict.fromkeys(self.all_modalities)
        loss_joint = dict.fromkeys(self.all_modalities)
        loss = dict.fromkeys(self.all_modalities)

        loss_ind = 0.
        for m in self.modalities:
            l_rec[m] = [None] * self.n_arm[m]
            l_classifier[m] = [None] * self.n_arm[m]
            loss_indep[m] = [None] * self.n_arm[m]
            KLD_cont[m] = [None] * self.n_arm[m] 
            var_qc_inv, log_qc = [None] * 2, [None] * 2
            neg_joint_entropy[m], z_distance_rep[m], z_distance[m] = [], [], []
            
            for arm_a in range(self.n_arm[m]):
                recon_x_ = recon_x[m][arm_a][mask[m]]
                x_ = x[m][arm_a][mask[m]]
                
                x_bin = 0. * x_
                x_bin[x_ > self.eps] = 1.
                rec_bin = 0. * recon_x_
                rec_bin[recon_x_ > self.eps] = 1.
                if self.model.noise_model == 'Gaussian':
                    l_rec[m][arm_a] = F.mse_loss(recon_x_, x_, reduction='sum') / (x_.size(0))
                    l_rec[m][arm_a] += F.binary_cross_entropy(rec_bin, x_bin)

                if self.variational:
                    log_sigma_ = log_sigma[m][arm_a][mask[m]]    
                    mu_ = mu[m][arm_a][mask[m]]
                    KLD_cont[m][arm_a] = (-0.5 * torch.mean(1 + log_sigma_ - mu_.pow(2) - log_sigma_.exp(), dim=0)).sum()
                    loss_indep[m][arm_a] = l_rec[m][arm_a] + beta * KLD_cont[m][arm_a]
                else:
                    loss_indep[m][arm_a] = l_rec[m][arm_a]
                    KLD_cont[m][arm_a] = [0.]

                qc_a = qc[m][arm_a][mask[m]]
                log_qc[0] = torch.log(qc_a + self.eps)
                var_qc0 = qc_a.var(0)

                var_qc_inv[0] = (1 / (var_qc0 + self.eps)).repeat(qc_a.size(0), 1).sqrt()
                
                for arm_b in range(arm_a + 1, self.n_arm[m]):
                    qc_b = qc[m][arm_b][mask[m]]
                    log_qc[1] = torch.log(qc_b + self.eps)
                    tmp_entropy = (torch.sum(qc_a * log_qc[0], dim=-1)).mean() + (torch.sum(qc_b * log_qc[1], dim=-1)).mean()
                    neg_joint_entropy[m].append(tmp_entropy)
                    # var = qc[arm_b].var(0)
                    var_qc1 = qc_b.var(0)
                    var_qc_inv[1] = (1 / (var_qc1 + self.eps)).repeat(qc_b.size(0), 1).sqrt()

                    # distance between z_1 and z_2 i.e., ||z_1 - z_2||^2
                    # Euclidean distance
                    z_distance_rep[m].append((torch.norm((c[m][arm_a][mask[m]] - c[m][arm_b][mask[m]]), p=2, dim=1).pow(2)).mean())
                    z_distance[m].append((torch.norm((log_qc[0] * var_qc_inv[0]) - (log_qc[1] * var_qc_inv[1]), p=2, dim=1).pow(2)).mean())
                
                
                # if the prior categorical variable is provided
                if self.ref_prior:
                    n_comb = max(self.n_arm[m] * (self.n_arm[m] + 1) / 2, 1)
                    scaler = self.n_arm[m]
                    # distance between z_1 and z_2 i.e., ||z_1 - z_2||^2
                    # Euclidean distance
                    z_distance_rep[m].append((torch.norm((c[m][arm_a][mask[m]] - prior_c[mask[m]]), p=2, dim=1).pow(2)).mean())
                    tmp_entropy = (torch.sum(qc[m][arm_a][mask[m]] * log_qc[0], dim=-1)).mean()
                    neg_joint_entropy[m].append(tmp_entropy)
                    qc_bin = gumbel_softmax(qc[m][arm_a][mask[m]], 1, self.n_categories, 1, hard=True, gumble_noise=False)
                    # z_distance[m].append(self.lam_pc * F.binary_cross_entropy(qc_bin, prior_c[mask[m]]))
                    z_distance[m].append(z_distance_rep[m][-1])
                else:
                    n_comb = max(self.n_arm[m] * (self.n_arm[m] - 1) / 2, 1)
                    scaler = max((self.n_arm[m] - 1), 1)
            
            loss_joint[m] = lam[m] * sum(z_distance[m]) + sum(neg_joint_entropy[m]) +  ((self.n_categories / 2) * (np.log(2 * np.pi)) - 0.5 * np.log(2 * lam[m])) 
            loss[m] = sum(loss_indep[m]) + loss_joint[m]
            
        for im, (mod_1, mod_2) in enumerate(self.mod_pairs):
            mn = self.cross_mod_pairs[im]
            pair_mask = mask[self.cross_mod_pairs[im]]
            neg_joint_entropy[mn] = 0.
            z_distance[mn] = 0.
            z_distance_rep[mn] = 0.
            iter = 0.
            for arm_a in range(self.n_arm[mod_1]):
                y_pred_ = y_pred[mod_1][arm_a][pair_mask]
                y_ = y[pair_mask]
                l_classifier[mod_1][arm_a] = F.binary_cross_entropy(y_pred_, y_)
                
                qc_1 = qc[mod_1][arm_a][pair_mask]
                log_qc_1 = torch.log(qc_1 + self.eps)
                c_1 = c[mod_1][arm_a][pair_mask]
                var_qc = qc_1.var(0)
                var_qc_inv_1 = (1 / (var_qc + self.eps)).repeat(qc_1.size(0), 1).sqrt()
                for arm_b in range(self.n_arm[mod_2]):
                    y_pred_ = y_pred[mod_2][arm_b][pair_mask]
                    y_ = y[pair_mask]
                    l_classifier[mod_2][arm_b] = F.binary_cross_entropy(y_pred_, y_)
                
                    qc_2 = qc[mod_2][arm_b][pair_mask]
                    log_qc_2 = torch.log(qc_2 + self.eps)
                    c_2 = c[mod_2][arm_b][pair_mask]
                    var_qc = qc_2.var(0)
                    var_qc_inv_2 = (1 / (var_qc + self.eps)).repeat(qc_2.size(0), 1).sqrt()

                    if ref_mod == mod_1:
                        qc_prior = qc_1.clone().detach()
                        neg_joint_entropy[mn] += (torch.sum(qc_2 * log_qc_2, dim=-1)).mean()
                        qc_bin = gumbel_softmax(qc_2, 1, self.n_categories, 1, hard=True, gumble_noise=False)
                        z_distance[mn] += F.binary_cross_entropy(qc_bin, qc_prior)
                    
                    elif ref_mod == mod_2:
                        qc_prior = qc_2.clone().detach()
                        neg_joint_entropy[mn] += (torch.sum(qc_1 * log_qc_1, dim=-1)).mean()
                        qc_bin = gumbel_softmax(qc_1, 1, self.n_categories, 1, hard=True, gumble_noise=False)
                        z_distance[mn] += F.binary_cross_entropy(qc_bin, qc_prior)
                        
                    else:
                        neg_joint_entropy[mn] += (torch.sum(qc_1 * log_qc_1, dim=-1)).mean() + (torch.sum(qc_2 * log_qc_2, dim=-1)).mean()
                        # distance between c_rna and c_atac i.e., ||c_rna - c_atac||^2
                        z_distance[mn] += (torch.norm((log_qc_1 * var_qc_inv_1) - (log_qc_2 * var_qc_inv_2), p=2, dim=1).pow(2)).mean()
                    
                    # Euclidean distance (only for monitoriing purpose)
                    z_distance_rep[mn] += (torch.norm((c_1 - c_2), p=2, dim=1).pow(2)).mean()
                    iter += 1
            
            neg_joint_entropy[mn] /= iter
            z_distance[mn] /= iter
            z_distance_rep[mn] /= iter
            loss_joint[mn] = lam[mn] * z_distance[mn] + neg_joint_entropy[mn] +  ((self.n_categories / 2) * (np.log(2 * np.pi)) - 0.5 * np.log(2 * lam[mn])) 
            loss[mn] = (pair_mask.shape[0] / pair_mask.sum()) *  loss_joint[mn]
            loss[mod_1] += sum(l_classifier[mod_1]) 
            loss[mod_2] += sum(l_classifier[mod_2])

        var_min = min((qc[self.modalities[0]][self.arm_paired][pair_mask].var(0).min(), qc[self.modalities[1]][self.arm_paired][pair_mask].var(0).min()))
        total_loss = sum(loss.values())
        return total_loss, loss, l_rec, l_classifier, loss_joint, neg_joint_entropy, z_distance, z_distance_rep, KLD_cont, var_min


    def train(self, train_loader, test_loader, n_epoch, n_epoch_p, lr=1e-3, min_con=.5, max_prun_it=0, ref_prior=False, temp=1., ws=1,
              hard=False, variational=True, c_p=None, lam=1., beta=1., arm_paired=0, world_size=1, wandb_run=None, compress=False, rank=None, 
              add_noise=False, noise_std=0.1, ref_mod=''):
        """
        run the training of the cpl-mixVAE with the pre-defined parameters/settings
        pcikle used for saving the file

        input args
            train_loader: train dataloader.
            test_loader: test dataloader.
            n_epoch: number of training epoch, without pruning.
            n_epoch_p: number of training epoch, with pruning.
            lr: the learning rate of the optimizer, here Adam.
            c_p: the prior categorical variable, only if ref_prior is True.
            c_onehot: the one-hot representation of the prior categorical variable, only if ref_prior is True.
            min_con: minimum value of consensus among pair of arms.
            max_prun_it: maximum number of pruning iterations.
            temp: temperature of sampling
            hard: a boolean variable, True uses one-hot method that is used in Gumbel-softmax, and False uses the Gumbel-softmax function.
            variational: a boolean variable for variational mode, False mode does not use sampling.

        return
            data_file_id: the output dictionary.
        """
        # define current_time
        self.current_time = time.strftime('%Y-%m-%d-%H-%M-%S')
        self.variational = variational
        self.temp = temp
        self.ref_prior = ref_prior
        self.lr = lr
        self.arm_paired = arm_paired
        self.n_modality = len(self.modalities)
        
        if rank is None:
            rank = self.device
        
        # make a preformance dataframes for training, validation, and testing 
        df_keys = ['total_loss', 'minVar']
        for m in self.modalities:
            df_keys.append(f'loss_joint_{m}')
            df_keys.append(f'entropy_{m}')
            df_keys.append(f'distance_{m}')
            df_keys.append(f'log_distance_{m}')
            for arm in range(self.n_arm[m]):
                df_keys.append(f'recon_error_arm{m}_{arm}')

        for m in self.cross_mod_pairs:
            df_keys.append(f'loss_joint_{m}')
            df_keys.append(f'entropy_{m}')
            df_keys.append(f'distance_{m}')
            df_keys.append(f'log_distance_{m}')

        train_df = pd.DataFrame(np.zeros((n_epoch, len(df_keys))), columns=df_keys)
        validation_df = pd.DataFrame(np.zeros((n_epoch, len(df_keys))), columns=df_keys)
        
        # initialize the model parameters for pruning
        bias_mask = torch.ones(self.n_categories)
        weight_mask = dict.fromkeys(self.modalities)
        fc_mu = dict.fromkeys(self.modalities)
        fc_sigma = dict.fromkeys(self.modalities)
        lowD_mask = dict.fromkeys(self.modalities)
        for m in self.modalities:
            weight_mask[m] = torch.ones((self.n_categories, self.lowD_dim[m]))
            fc_mu[m] = torch.ones((self.state_dim[m], self.n_categories + self.lowD_dim[m]))
            fc_sigma[m] = torch.ones((self.state_dim[m], self.n_categories + self.lowD_dim[m]))
            lowD_mask[m] = torch.ones((self.lowD_dim[m], self.state_dim[m] + self.n_categories))
        
        batch_size = train_loader.batch_size
        self.compress = compress
        self.model = self.model.to(rank)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        print(f"Model is on device: {next(self.model.parameters()).device}")
        
        if self.aug:
            self.netA.to(rank)
        
        self.classifier = self.classifier.to(rank)
        self.cl_optimizer = torch.optim.Adam(self.classifier.parameters(), lr=lr)
            
        # training the model without pruning
        if self.init:
            print("Start training ...")
            for epoch in range(n_epoch):
                mod_loss_val = {key: 0 for key in self.all_modalities} 
                jointloss_val = {key: 0 for key in self.all_modalities} 
                dist_qc = {key: 0 for key in self.all_modalities} 
                log_dist_qc = {key: 0 for key in self.all_modalities} 
                entr = {key: 0 for key in self.all_modalities} 
                loss_rec = {key: 0 for key in self.modalities} 
                loss_clf = dict.fromkeys(self.modalities)
                
                for m in self.modalities:
                    loss_rec[m] = np.zeros(self.n_arm[m])
                    loss_clf[m] = np.zeros(self.n_arm[m])
                    
                var_min = 0
                loss_val = 0
                self.model.train()
                t0 = time.time()
                for batch_indx, data_block in enumerate(train_loader):
                    d_idx = data_block[-1]
                    label = data_block[-2]
                    label_onehot = np.full((label.shape[0], self.n_class), -1)
                    valid_indices = label != -1
                    label_onehot[valid_indices] = self.onehot_encoder.fit_transform(label[valid_indices].reshape(-1, 1))
                    label_onehot = torch.FloatTensor(label_onehot).to(rank)
                    if self.aug:
                        label_onehot = label_onehot.repeat(2, 1)

                    train_data = dict.fromkeys(self.modalities)
                    mod_mask = dict.fromkeys(self.all_modalities)
                    for im, m in enumerate(self.modalities):
                        if m == 'M':
                            data = data_block[im].reshape(data_block[im].shape[0], -1, 4, 4)
                        else:
                            data = data_block[im]
                        mod_mask[m] = data_block[im + self.n_modality]
                        train_data[m] = []
                        for arm in range(self.n_arm[m]):
                            data = data.to(rank)
                            if self.aug:
                                if arm == 0:
                                    mod_mask[m] = torch.concat((mod_mask[m], mod_mask[m]), 0)
                                if m == 'T':
                                    _, gen_data = self.netA(data, True, noise_std)
                                    train_data[m].append(torch.concat((data, gen_data), 0))
                                elif m == 'M': 
                                    noise = torch.distributions.Exponential(1/0.5).sample(data.shape).to(rank)  
                                    gen_data = data.clone()
                                    tmp_mask = gen_data > 0
                                    gen_data[tmp_mask] += noise[tmp_mask]
                                    train_data[m].append(torch.concat((data, gen_data), 0))
                                else:
                                    gen_data = data + torch.randn(data.shape).to(rank) * noise_std
                                    train_data[m].append(torch.concat((data, gen_data), 0))
                            else:
                                train_data[m].append(data)
                            
                    for im, (m1, m2) in enumerate(self.mod_pairs):
                        mod_mask[self.cross_mod_pairs[im]] = mod_mask[m1] & mod_mask[m2]

                    if self.ref_prior:
                        prior_c = torch.FloatTensor(c_p[d_idx, :]).to(rank)
                    else:
                        prior_c = None

                    self.optimizer.zero_grad()
                    self.cl_optimizer.zero_grad()
                    recon_batch, _, qc, _, c, mu, log_var, _ = self.model(
                                                                            x=train_data, 
                                                                            temp=self.temp, 
                                                                            hard=hard, 
                                                                            variational=variational,
                                                                            )
                    
                    label_pred = dict.fromkeys(self.modalities)
                    for m in self.modalities:
                        label_pred[m] = [None] * self.n_arm[m]
                        for arm in range(self.n_arm[m]):
                            label_pred[m][arm] = self.classifier(qc[m][arm])
                            
                    loss, loss_dict, l_rec, l_classifier, loss_joint, neg_entropy, c_distance, c_rep_dist, _, min_var = self.loss(
                                                                                                                                recon_x=recon_batch, 
                                                                                                                                x=train_data, 
                                                                                                                                y_pred=label_pred,
                                                                                                                                y=label_onehot,
                                                                                                                                mu=mu, 
                                                                                                                                log_sigma=log_var, 
                                                                                                                                qc=qc, 
                                                                                                                                c=c, 
                                                                                                                                lam=lam,
                                                                                                                                beta=beta,
                                                                                                                                prior_c=prior_c, 
                                                                                                                                mask=mod_mask,
                                                                                                                                ref_mod=ref_mod,
                                                                                                                                )
                                                                                                                  
                    loss.backward()
                    self.optimizer.step()
                    self.cl_optimizer.step()
                    var_min += min_var.data.item()
                    loss_val += loss.data.item()
                    
                    for m in self.modalities:
                        for arm in range(self.n_arm[m]):
                            loss_rec[m][arm] += l_rec[m][arm].data.item() / self.input_dim[m]
                            loss_clf[m][arm] += l_classifier[m][arm].data.item() 
                    
                    for m in self.all_modalities:
                        mod_loss_val[m] += safe_item(loss_dict[m])
                        jointloss_val[m] += safe_item(loss_joint[m])
                        log_dist_qc[m] += safe_item(c_distance[m])
                        dist_qc[m] += safe_item(c_rep_dist[m])
                        entr[m] += safe_item(neg_entropy[m])
                
                # if self.dist:
                #     dist.all_reduce(loss, op=dist.ReduceOp.SUM)
                #     dist.all_reduce(loss_rec, op=dist.ReduceOp.SUM)
                #     dist.all_reduce(c_dist, op=dist.ReduceOp.SUM)
                
                train_df.loc[epoch, 'total_loss'] = loss_val / (batch_indx + 1)
                train_df.loc[epoch, 'minVar'] = var_min / (batch_indx + 1)
                for m in self.modalities:
                    train_df.loc[epoch, f'loss_joint_{m}'] = jointloss_val[m] / (batch_indx + 1)
                    train_df.loc[epoch, f'entropy_{m}'] = entr[m] / (batch_indx + 1)
                    train_df.loc[epoch, f'distance_{m}'] = dist_qc[m] / (batch_indx + 1)
                    train_df.loc[epoch, f'log_distance_{m}'] = log_dist_qc[m] / (batch_indx + 1)
                    for arm in range(self.n_arm[m]):
                        train_df.loc[epoch, f'recon_error_arm{m}_{arm}'] = loss_rec[m][arm] / (batch_indx + 1)
                        train_df.loc[epoch, f'clf_loss_arm{m}_{arm}'] = loss_clf[m][arm] / (batch_indx + 1)
                       
                for m in self.cross_mod_pairs:
                    train_df.loc[epoch, f'loss_joint_{m}'] = jointloss_val[m] / (batch_indx + 1)
                    train_df.loc[epoch, f'entropy_{m}'] = entr[m] / (batch_indx + 1)
                    train_df.loc[epoch, f'distance_{m}'] = dist_qc[m] / (batch_indx + 1)
                    train_df.loc[epoch, f'log_distance_{m}'] = log_dist_qc[m] / (batch_indx + 1)
            
                print('Training')
                print('Epoch:{}, Elapsed Time:{:.2f}'.format(epoch, time.time() - t0))
                print('====> Total Loss: {:.4f}, minVar: {:.4f}'.format(train_df['total_loss'][epoch], train_df['minVar'][epoch]))
                for m in self.modalities:
                    val_1 = train_df[f'recon_error_arm{m}_0'][epoch]
                    val_2 = train_df[f'distance_{m}'][epoch]
                    val_3 = train_df[f'log_distance_{m}'][epoch]
                    val_4 = train_df[f'entropy_{m}'][epoch]
                    print('====> {}, Rec Error:{:.4f}, Dist:{:.4f}, Log Dist:{:.4f}, H:{:.4f}'.format(m, val_1, val_2, val_3, val_4))
                for m in self.cross_mod_pairs:
                    val_1 = train_df[f'distance_{m}'][epoch]
                    val_2 = train_df[f'log_distance_{m}'][epoch]
                    val_3 = train_df[f'entropy_{m}'][epoch]
                    print('====> {}, Dist:{:.4f}, Log Dist:{:.4f}, H:{:.4f}'.format(m, val_1, val_2, val_3))
                    
                if wandb_run:
                    wandb_run.log(
                                {
                                "train/total-loss": train_df['total_loss'][epoch],
                                "train/min-var": train_df['minVar'][epoch],
                                **dict(map(lambda m: (f"train/rec-loss-{m}", train_df[f'recon_error_arm{m}_0'][epoch]), self.modalities)),
                                **dict(map(lambda m: (f"train/clf-acc-{m}", train_df[f'clf_loss_arm{m}_0'][epoch]), self.modalities)),
                                **dict(map(lambda m: (f"train/distance-{m}", train_df[f'distance_{m}'][epoch]), self.all_modalities)),
                                **dict(map(lambda m: (f"train/log_distance-{m}", train_df[f'log_distance_{m}'][epoch]), self.all_modalities)),
                                "train/time": time.time() - t0,
                                }
                                )
                # validation step
                self.model.eval()
                self.classifier.eval()
                with torch.no_grad():
                    mod_loss_val = {key: 0 for key in self.all_modalities} 
                    jointloss_val = {key: 0 for key in self.all_modalities} 
                    dist_qc = {key: 0 for key in self.all_modalities} 
                    log_dist_qc = {key: 0 for key in self.all_modalities} 
                    entr = {key: 0 for key in self.all_modalities} 
                    loss_rec = dict.fromkeys(self.modalities)
                    loss_clf = dict.fromkeys(self.modalities)
                    for m in self.modalities:
                        loss_rec[m] = np.zeros(self.n_arm[m])
                        loss_clf[m] = np.zeros(self.n_arm[m])
                        
                    var_min = 0
                    loss_val = 0
                    for batch_indx, data_block in enumerate(test_loader):
                        d_idx = data_block[-1]
                        label = data_block[-2]
                        label_onehot = np.full((label.shape[0], self.n_class), -1)
                        valid_indices = label != -1
                        label_onehot[valid_indices] = self.onehot_encoder.fit_transform(label[valid_indices].reshape(-1, 1))
                        label_onehot = torch.FloatTensor(label_onehot).to(rank)
                    
                        val_data = dict.fromkeys(self.modalities)
                        mod_mask = dict.fromkeys(self.all_modalities)
                        for im, m in enumerate(self.modalities):
                            if m == 'M':
                                data = data_block[im].reshape(data_block[im].shape[0], -1, 4, 4)
                            else:
                                data = data_block[im]
                            mod_mask[m] = data_block[im + self.n_modality]
                            val_data[m] = []
                            for arm in range(self.n_arm[m]):
                                val_data[m].append(data.to(rank))
                        for im, (m1, m2) in enumerate(self.mod_pairs):
                            mod_mask[self.cross_mod_pairs[im]] = mod_mask[m1] & mod_mask[m2]
                        if self.ref_prior:
                            prior_c = torch.FloatTensor(c_p[d_idx, :]).to(rank)
                        else:
                            prior_c = None
                            
                        recon_batch, _, qc, _, c, mu, log_var, _ = self.model(
                                                                                x=val_data, 
                                                                                temp=self.temp, 
                                                                                hard=hard, 
                                                                                variational=variational, 
                                                                                eval=True,
                                                                                )
                        label_pred = dict.fromkeys(self.modalities)
                        for m in self.modalities:
                            label_pred[m] = [None] * self.n_arm[m]
                            for arm in range(self.n_arm[m]):
                                label_pred[m][arm] = self.classifier(qc[m][arm])
                            
                        loss, loss_dict, l_rec, l_classifier, loss_joint, neg_entropy, c_distance, c_rep_dist, _, min_var = self.loss(
                                                                                                                                    recon_x=recon_batch, 
                                                                                                                                    x=val_data, 
                                                                                                                                    y_pred=label_pred,
                                                                                                                                    y=label_onehot,
                                                                                                                                    mu=mu, 
                                                                                                                                    log_sigma=log_var, 
                                                                                                                                    qc=qc, 
                                                                                                                                    c=c, 
                                                                                                                                    lam=lam,
                                                                                                                                    beta=beta,
                                                                                                                                    prior_c=prior_c, 
                                                                                                                                    mask=mod_mask,
                                                                                                                                    ref_mod=ref_mod,
                                                                                                                                    )
                        var_min += min_var.data.item()
                        loss_val += loss.data.item()
                        
                        for m in self.modalities:
                            for arm in range(self.n_arm[m]):
                                loss_rec[m][arm] += l_rec[m][arm].data.item() / self.input_dim[m]
                                loss_clf[m][arm] += l_classifier[m][arm].data.item()
                        
                        for m in self.all_modalities:
                            mod_loss_val[m] += safe_item(loss_dict[m])
                            jointloss_val[m] += safe_item(loss_joint[m])
                            log_dist_qc[m] += safe_item(c_distance[m])
                            dist_qc[m] += safe_item(c_rep_dist[m])
                            entr[m] += safe_item(neg_entropy[m])

                    validation_df.loc[epoch, 'total_loss'] = loss_val / (batch_indx + 1)
                    validation_df.loc[epoch, 'minVar'] = var_min / (batch_indx + 1)
                    for m in self.modalities:
                        validation_df.loc[epoch, f'loss_joint_{m}'] = jointloss_val[m] / (batch_indx + 1)
                        validation_df.loc[epoch, f'entropy_{m}'] = entr[m] / (batch_indx + 1)
                        validation_df.loc[epoch, f'distance_{m}'] = dist_qc[m] / (batch_indx + 1)
                        validation_df.loc[epoch, f'log_distance_{m}'] = log_dist_qc[m] / (batch_indx + 1)
                        for arm in range(self.n_arm[m]):
                            validation_df.loc[epoch, f'recon_error_arm{m}_{arm}'] = loss_rec[m][arm] / (batch_indx + 1)
                            validation_df.loc[epoch, f'clf_loss_arm{m}_{arm}'] = loss_clf[m][arm] / (batch_indx + 1)
                        
                    for m in self.cross_mod_pairs:
                        validation_df.loc[epoch, f'loss_joint_{m}'] = jointloss_val[m] / (batch_indx + 1)
                        validation_df.loc[epoch, f'entropy_{m}'] = entr[m] / (batch_indx + 1)
                        validation_df.loc[epoch, f'distance_{m}'] = dist_qc[m] / (batch_indx + 1)
                        validation_df.loc[epoch, f'log_distance_{m}'] = log_dist_qc[m] / (batch_indx + 1)
                        
                    print('Validation')
                    for m in self.modalities:
                        val_1 = validation_df[f'recon_error_arm{m}_0'][epoch]
                        val_2 = validation_df[f'distance_{m}'][epoch]
                        val_3 = validation_df[f'log_distance_{m}'][epoch]
                        val_4 = validation_df[f'entropy_{m}'][epoch]
                        print('====> {}, Rec Error:{:.4f}, Dist:{:.4f}, Log Dist:{:.4f}, H:{:.4f}'.format(m, val_1, val_2, val_3, val_4))
                    
                    if wandb_run:
                        wandb_run.log(
                                    {
                                    "validate/total-loss": validation_df['total_loss'][epoch],
                                    "validate/min-var": validation_df['minVar'][epoch],
                                    **dict(map(lambda m: (f"validate/rec-loss-{m}", validation_df[f'recon_error_arm{m}_0'][epoch]), self.modalities)),
                                    **dict(map(lambda m: (f"validate/clf-acc-{m}", validation_df[f'clf_loss_arm{m}_0'][epoch]), self.modalities)),
                                    **dict(map(lambda m: (f"validate/distance-{m}", validation_df[f'distance_{m}'][epoch]), self.modalities)),
                                    **dict(map(lambda m: (f"validate/log_distance-{m}", validation_df[f'log_distance_{m}'][epoch]), self.modalities)),
                                    "validate/time": time.time() - t0,
                                    }
                                    )

                if self.save and n_epoch > 10 and epoch // 1000 == 0:
                    trained_model = self.folder + f'/model/cpl_mixVAE_model_{self.current_time}.pth'
                    if ws > 1:
                        dist.barrier()
                        save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
                        with FSDP.state_dict_type(self.model, StateDictType.FULL_STATE_DICT, save_policy):
                            cpu_state = self.model.state_dict()
                        if rank == 0:
                            torch.save(cpu_state, trained_model)
                        dist.barrier()
                    else:
                        if self.compress:
                            trained_model += '.gz'
                            save_model_with_gzip(self.model.state_dict(), trained_model, compression_level=3)
                        else:
                            torch.save({'model_state_dict': self.model.state_dict()}, trained_model, pickle_protocol=4)
                            
                        prune_indx = []
                        # save train and validation dataframes as text file
                        train_df.to_csv(self.folder + f'/train_df_{self.current_time}.txt', sep='\t', index=False)
                        validation_df.to_csv(self.folder + f'/validation_df_{self.current_time}.txt', sep='\t', index=False)

            # save the model and the learning curve
            if self.save and n_epoch > 0:
                trained_model = self.folder + f'/model/cpl_mixVAE_model_before_pruning_{self.current_time}.pth'
                classifier_model = self.folder + f'/model/classifier_before_pruning_{self.current_time}.pth'
                if ws > 1:
                    dist.barrier()
                    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
                    with FSDP.state_dict_type(self.model, StateDictType.FULL_STATE_DICT, save_policy):
                        cpu_state = self.model.state_dict()
                    if rank == 0:
                        torch.save(cpu_state, trained_model)
                        torch.save(self.classifier.state_dict(), classifier_model)
                    dist.barrier()
                else:
                    if self.compress:
                        trained_model += '.gz'
                        save_model_with_gzip(self.model.state_dict(), trained_model, compression_level=3)
                        classifier_model += '.gz'
                        save_model_with_gzip(self.classifier.state_dict(), classifier_model, compression_level=3)
                    else:
                        torch.save({'model_state_dict': self.model.state_dict()}, trained_model, pickle_protocol=4)
                        torch.save({'model_state_dict': self.classifier.state_dict()}, classifier_model, pickle_protocol=4)
                        
                    prune_indx = []
                    # save train and validation dataframes as text file
                    train_df.to_csv(self.folder + f'/train_df_{self.current_time}.txt', sep='\t', index=False)
                    validation_df.to_csv(self.folder + f'/validation_df_{self.current_time}.txt', sep='\t', index=False)
                
    
        if ws > 1:
            #Ensure all ranks reach this point
            dist.barrier()
        # training the model with pruning
        if n_epoch_p > 0:
            # initialize the pruning masks from the bias of the discrete-variable (qc) layer
            bias = self.model.qc[self.modalities[0]][0].bias.detach().cpu().numpy()
            pruning_mask = np.where(bias != 0.)[0]
            prune_indx = np.where(bias == 0.)[0]
            stop_prune = False
        else:
            stop_prune = True

        pr = self.n_pr
        ind = []
        while not stop_prune:
            if ref_mod:
                predicted_label = [None] * self.n_arm[ref_mod]
            else:
                predicted_label = {key: [] for key in self.modalities}
            # Assessment over all dataset
            self.model.eval()
            with torch.no_grad():
                for i, data_block in enumerate(train_loader):
                    d_idx = data_block[-1]
                    label = data_block[-2]
                    label_onehot = np.full((label.shape[0], self.n_class), -1)
                    valid_indices = label != -1
                    label_onehot[valid_indices] = self.onehot_encoder.fit_transform(label[valid_indices].reshape(-1, 1))
                    label_onehot = torch.FloatTensor(label_onehot).to(rank)
                    
                    data_dict = dict.fromkeys(self.modalities)
                    mod_mask = dict.fromkeys(self.all_modalities)
                    for im, m in enumerate(self.modalities):
                        if m == 'M':
                            data = data_block[im].reshape(data_block[im].shape[0], -1, 4, 4)
                        else:
                            data = data_block[im]
                        mod_mask[m] = data_block[im + self.n_modality]
                        data_dict[m] = []
                        for arm in range(self.n_arm[m]):
                            data_dict[m].append(data.to(rank))
                            
                    for im, (m1, m2) in enumerate(self.mod_pairs):
                        mod_mask[self.cross_mod_pairs[im]] = mod_mask[m1] & mod_mask[m2]
                    
                    if self.ref_prior:
                        prior_c = torch.FloatTensor(c_p[d_idx, :]).to(rank)
                    else:
                        prior_c = None

                    _, _, qc, _, _, _, _, _ = self.model(x=data_dict, temp=self.temp, hard=hard, variational=variational, eval=True, pruning_mask=pruning_mask)

                    if ref_mod:
                        for arm in range(self.n_arm[ref_mod]):
                            c_encoder = qc[ref_mod][arm].data.view(qc[ref_mod][arm].size()[0], self.n_categories)
                            predicted_label[arm].append(c_encoder[mod_mask[ref_mod]].argmax(dim=1).detach().cpu().numpy())
                    else:
                        for im, (mod_1, mod_2) in enumerate(self.mod_pairs):
                            c_encoder = qc[mod_1][self.arm_paired].data.view(qc[mod_1][self.arm_paired].size()[0], self.n_categories)
                            predicted_label[mod_1].append(c_encoder[mod_mask[self.cross_mod_pairs[im]]].argmax(dim=1).detach().cpu().numpy())
                            c_encoder = qc[mod_2][self.arm_paired].data.view(qc[mod_2][self.arm_paired].size()[0], self.n_categories)
                            predicted_label[mod_2].append(c_encoder[mod_mask[self.cross_mod_pairs[im]]].argmax(dim=1).detach().cpu().numpy())

            
            if ws > 1:
                #Ensure all ranks reach this point
                dist.barrier()
            
            if ws == 1 or (ws > 1 and rank == 0):
                c_agreement = []
                if ref_mod:
                    for arm_a in range(self.n_arm[ref_mod]):
                        pred_a = predicted_label[arm_a]
                        for arm_b in range(arm_a + 1, self.n_arm):
                            pred_b = predicted_label[arm_b]
                            armA_vs_armB = np.zeros((self.n_categories, self.n_categories))

                            for samp in range(pred_a.shape[0]):
                                armA_vs_armB[pred_a[samp].astype(int), pred_b[samp].astype(int)] += 1

                            num_samp_arm = []
                            for ij in range(self.n_categories):
                                sum_row = armA_vs_armB[ij, :].sum()
                                sum_column = armA_vs_armB[:, ij].sum()
                                num_samp_arm.append(max(sum_row, sum_column))

                            armA_vs_armB = np.divide(armA_vs_armB, np.array(num_samp_arm), out=np.zeros_like(armA_vs_armB),
                                                    where=np.array(num_samp_arm) != 0)
                            c_agreement.append(np.diag(armA_vs_armB))
                            ind_sort = np.argsort(c_agreement[-1])

                            # plot the consensus matrix
                            plt.figure()
                            plt.imshow(armA_vs_armB[:, ind_sort[::-1]][ind_sort[::-1]], cmap='binary')
                            plt.colorbar()
                            plt.xlabel('arm_' + str(arm_a), fontsize=20)
                            plt.xticks(range(self.n_categories), range(self.n_categories))
                            plt.yticks(range(self.n_categories), range(self.n_categories))
                            plt.ylabel('arm_' + str(arm_b), fontsize=20)
                            plt.xticks([])
                            plt.yticks([])
                            plt.title('|c|=' + str(self.n_categories), fontsize=20)
                            plt.savefig(self.folder + f'/consensus_{pr}_{ref_mod}_arm_{arm_a}_arm_{arm_b}.png', dpi=600)
                            plt.close("all")

                else:
                    for im, (mod_1, mod_2) in enumerate(self.mod_pairs):
                        pred_a = np.concatenate(predicted_label[mod_1])
                        pred_b = np.concatenate(predicted_label[mod_2])
                        armA_vs_armB = np.zeros((self.n_categories, self.n_categories))

                        for samp in range(pred_a.shape[0]):
                            armA_vs_armB[pred_a[samp].astype(int), pred_b[samp].astype(int)] += 1

                        num_samp_arm = []
                        for ij in range(self.n_categories):
                            sum_row = armA_vs_armB[ij, :].sum()
                            sum_column = armA_vs_armB[:, ij].sum()
                            num_samp_arm.append(max(sum_row, sum_column))

                        armA_vs_armB = np.divide(armA_vs_armB, np.array(num_samp_arm), out=np.zeros_like(armA_vs_armB),
                                                where=np.array(num_samp_arm) != 0)
                        c_agreement.append(np.diag(armA_vs_armB))
                        ind_sort = np.argsort(c_agreement[-1])

                        # plot the consensus matrix
                        plt.figure()
                        plt.imshow(armA_vs_armB[:, ind_sort[::-1]][ind_sort[::-1]], cmap='binary')
                        plt.colorbar()
                        plt.xlabel(f'arm_{mod_1}', fontsize=20)
                        plt.xticks(range(self.n_categories), range(self.n_categories))
                        plt.yticks(range(self.n_categories), range(self.n_categories))
                        plt.ylabel(f'arm_{mod_2}', fontsize=20)
                        plt.xticks([])
                        plt.yticks([])
                        plt.title('|c|=' + str(self.n_categories), fontsize=20)
                        plt.savefig(self.folder + f'/consensus_{pr}_{mod_1}_vs_{mod_2}.png', dpi=600)
                        plt.close("all")

                # average the per-category agreement over all compared pairs
                c_agreement = np.mean(c_agreement, axis=0)
                agreement = c_agreement[pruning_mask]
                if (np.min(agreement) <= min_con) and pr < max_prun_it:
                    if pr > 0:
                        ind_min = pruning_mask[np.argmin(agreement)]
                        ind_min = np.array([ind_min])
                        ind = np.concatenate((ind, ind_min))
                    else:
                        ind_min = pruning_mask[np.argmin(agreement)]
                        if len(prune_indx) > 0:
                            ind_min = np.array([ind_min])
                            ind = np.concatenate((prune_indx, ind_min))
                        else:
                            ind.append(ind_min)
                        ind = np.array(ind)

                    ind = ind.astype(int)
                    bias_mask[ind] = 0.
                    for m in self.modalities:
                        weight_mask[m][ind, :] = 0.
                        fc_mu[m][:, self.lowD_dim[m] + ind] = 0.
                        fc_sigma[m][:, self.lowD_dim[m] + ind] = 0.
                        lowD_mask[m][:, ind] = 0.
                        
                    stop_prune = False
                else:
                    print('No more pruning!')
                    stop_prune = True

            # continue the training with pruning
            if not stop_prune:
                print("Continue training with pruning ...")
                print(f"Pruned categories: {ind}")
                bias = bias_mask.detach().cpu().numpy()
                pruning_mask = np.where(bias != 0.)[0]
                pruning_train_df = pd.DataFrame(np.zeros((n_epoch_p, len(df_keys))), columns=df_keys)
                pruning_validation_df = pd.DataFrame(np.zeros((n_epoch_p, len(df_keys))), columns=df_keys)
                # prune the model based on the consensus
                for m in self.modalities:
                    for arm in range(self.n_arm[m]):
                        if ws > 1:
                            dist.barrier()
                            # Average the weights across all processes
                            dist.all_reduce(self.model.qc[m][arm].weight, op=dist.ReduceOp.AVG)
                            dist.all_reduce(self.model.qc[m][arm].bias, op=dist.ReduceOp.AVG)
                            dist.all_reduce(self.model.mu[m][arm].weight, op=dist.ReduceOp.AVG)
                            dist.all_reduce(self.model.sigma[m][arm].weight, op=dist.ReduceOp.AVG)
                            dist.all_reduce(self.model.fc_lowD[m][arm].weight, op=dist.ReduceOp.AVG)
                        
                            with FSDP.summon_full_params(self.model, writeback=True):
                                prune.custom_from_mask(self.model.qc[m][arm], 'weight', mask=weight_mask[m].to(rank))
                                prune.custom_from_mask(self.model.qc[m][arm], 'bias', mask=bias_mask.to(rank))
                                prune.custom_from_mask(self.model.mu[m][arm], 'weight', mask=fc_mu[m].to(rank))
                                prune.custom_from_mask(self.model.sigma[m][arm], 'weight', mask=fc_sigma[m].to(rank))
                                prune.custom_from_mask(self.model.fc_lowD[m][arm], 'weight', mask=lowD_mask[m].to(rank))
                                
                            dist.barrier()  
                        else:
                            prune.custom_from_mask(self.model.qc[m][arm], 'weight', mask=weight_mask[m].to(rank))
                            prune.custom_from_mask(self.model.qc[m][arm], 'bias', mask=bias_mask.to(rank))
                            prune.custom_from_mask(self.model.mu[m][arm], 'weight', mask=fc_mu[m].to(rank))
                            prune.custom_from_mask(self.model.sigma[m][arm], 'weight', mask=fc_sigma[m].to(rank))
                            prune.custom_from_mask(self.model.fc_lowD[m][arm], 'weight', mask=lowD_mask[m].to(rank))
                          
                
                # FSDP.rewrap(self.model) # not sure about this line!
                for epoch in range(n_epoch_p):
                    # training
                    mod_loss_val = {key: 0 for key in self.all_modalities} 
                    jointloss_val = {key: 0 for key in self.all_modalities} 
                    dist_qc = {key: 0 for key in self.all_modalities} 
                    log_dist_qc = {key: 0 for key in self.all_modalities} 
                    entr = {key: 0 for key in self.all_modalities} 
                    loss_rec = {key: 0 for key in self.modalities} 
                    loss_clf = {key: 0 for key in self.modalities}
                    for m in self.modalities:
                        loss_rec[m] = np.zeros(self.n_arm[m])
                        loss_clf[m] = np.zeros(self.n_arm[m])
                    var_min = 0
                    loss_val = 0
                    self.model.train()
                    t0 = time.time()
                    for batch_indx, data_block in enumerate(train_loader):
                        d_idx = data_block[-1]
                        label = data_block[-2]
                        label_onehot = np.full((label.shape[0], self.n_class), -1)
                        valid_indices = label != -1
                        label_onehot[valid_indices] = self.onehot_encoder.fit_transform(label[valid_indices].reshape(-1, 1))
                        label_onehot = torch.FloatTensor(label_onehot).to(rank)
                        if self.aug:
                            label_onehot = label_onehot.repeat(2, 1)
                        
                        train_data = dict.fromkeys(self.modalities)
                        mod_mask = dict.fromkeys(self.all_modalities)
                        for im, m in enumerate(self.modalities):
                            if m == 'M':
                                data = data_block[im].reshape(data_block[im].shape[0], -1, 4, 4)
                            else:
                                data = data_block[im]
                            mod_mask[m] = data_block[im + self.n_modality]
                            train_data[m] = []
                            for arm in range(self.n_arm[m]):
                                data = data.to(rank)
                                if self.aug:
                                    if arm == 0:
                                        mod_mask[m] = torch.concat((mod_mask[m], mod_mask[m]), 0)
                                    if m == 'T':
                                        _, gen_data = self.netA(data, True, noise_std)
                                        train_data[m].append(torch.concat((data, gen_data), 0))
                                    elif m == 'M': 
                                        noise = torch.distributions.Exponential(1/0.5).sample(data.shape).to(rank)  
                                        gen_data = data.clone()
                                        tmp_mask = gen_data > 0
                                        gen_data[tmp_mask] += noise[tmp_mask]
                                        train_data[m].append(torch.concat((data, gen_data), 0))
                                    else:
                                        gen_data = data + torch.randn(data.shape).to(rank) * noise_std
                                        train_data[m].append(torch.concat((data, gen_data), 0))
                                        
                                else:
                                    train_data[m].append(data)
                            
                        for im, (m1, m2) in enumerate(self.mod_pairs):
                            mod_mask[self.cross_mod_pairs[im]] = mod_mask[m1] & mod_mask[m2]
                            
                        if self.ref_prior:
                            prior_c = torch.FloatTensor(c_p[d_idx, :]).to(rank)
                        else:
                            prior_c=None

                        self.optimizer.zero_grad()
                        self.cl_optimizer.zero_grad()
                        recon_batch, _, qc, _, c, mu, log_var, _ = self.model(
                                                                                x=train_data, 
                                                                                temp=self.temp, 
                                                                                hard=hard, 
                                                                                variational=variational, 
                                                                                pruning_mask=pruning_mask,
                                                                                )
                        label_pred = dict.fromkeys(self.modalities)
                        for m in self.modalities:
                            label_pred[m] = [None] * self.n_arm[m]
                            for arm in range(self.n_arm[m]):
                                label_pred[m][arm] = self.classifier(qc[m][arm])
                                
                        loss, loss_dict, l_rec, l_classifer, loss_joint, neg_entropy, c_distance, c_rep_dist, _, min_var = self.loss(
                                                                                                                            recon_x=recon_batch, 
                                                                                                                            x=train_data, 
                                                                                                                            y_pred=label_pred,
                                                                                                                            y=label_onehot,
                                                                                                                            mu=mu, 
                                                                                                                            log_sigma=log_var, 
                                                                                                                            qc=qc, 
                                                                                                                            c=c, 
                                                                                                                            lam=lam,
                                                                                                                            beta=beta,
                                                                                                                            prior_c=prior_c, 
                                                                                                                            mask=mod_mask,
                                                                                                                            ref_mod=ref_mod,
                                                                                                                            )
                        loss.backward()
                        self.optimizer.step()
                        self.cl_optimizer.step()
                        var_min += min_var.data.item()
                        loss_val += loss.data.item()
                        
                        for m in self.modalities:
                            for arm in range(self.n_arm[m]):
                                loss_rec[m][arm] += l_rec[m][arm].data.item() / self.input_dim[m]
                                loss_clf[m][arm] += l_classifer[m][arm].data.item()
                        
                        for m in self.all_modalities:
                            mod_loss_val[m] += safe_item(loss_dict[m])
                            jointloss_val[m] += safe_item(loss_joint[m])
                            log_dist_qc[m] += safe_item(c_distance[m])
                            dist_qc[m] += safe_item(c_rep_dist[m])
                            entr[m] += safe_item(neg_entropy[m])

                    pruning_train_df.loc[epoch, 'total_loss'] = loss_val / (batch_indx + 1)
                    pruning_train_df.loc[epoch, 'minVar'] = var_min / (batch_indx + 1)
                    for m in self.modalities:
                        pruning_train_df.loc[epoch, f'loss_joint_{m}'] = jointloss_val[m] / (batch_indx + 1)
                        pruning_train_df.loc[epoch, f'entropy_{m}'] = entr[m] / (batch_indx + 1)
                        pruning_train_df.loc[epoch, f'distance_{m}'] = dist_qc[m] / (batch_indx + 1)
                        pruning_train_df.loc[epoch, f'log_distance_{m}'] = log_dist_qc[m] / (batch_indx + 1)
                        for arm in range(self.n_arm[m]):
                            pruning_train_df.loc[epoch, f'recon_error_arm{m}_{arm}'] = loss_rec[m][arm] / (batch_indx + 1)
                            pruning_train_df.loc[epoch, f'clf_loss_arm{m}_{arm}'] = loss_clf[m][arm] / (batch_indx + 1)
                        
                    for m in self.cross_mod_pairs:
                        pruning_train_df.loc[epoch, f'loss_joint_{m}'] = jointloss_val[m] / (batch_indx + 1)
                        pruning_train_df.loc[epoch, f'entropy_{m}'] = entr[m] / (batch_indx + 1)
                        pruning_train_df.loc[epoch, f'distance_{m}'] = dist_qc[m] / (batch_indx + 1)
                        pruning_train_df.loc[epoch, f'log_distance_{m}'] = log_dist_qc[m] / (batch_indx + 1)
                
                    
                    print('Training with pruning')
                    print('====> Pruning Epoch:{}, Elapsed Time:{:.2f}'.format(epoch, time.time() - t0))
                    print('====> Total Loss: {:.4f}, minVar: {:.4f}'.format(pruning_train_df['total_loss'][epoch], pruning_train_df['minVar'][epoch]))
                    for m in self.modalities:
                        val_1 = pruning_train_df[f'recon_error_arm{m}_0'][epoch]
                        val_2 = pruning_train_df[f'distance_{m}'][epoch]
                        val_3 = pruning_train_df[f'log_distance_{m}'][epoch]
                        val_4 = pruning_train_df[f'entropy_{m}'][epoch]
                        print('====> {}, Rec Error:{:.4f}, Dist:{:.4f}, Log Dist:{:.4f}, H:{:.4f}'.format(m, val_1, val_2, val_3, val_4))
                    for m in self.cross_mod_pairs:
                        val_1 = pruning_train_df[f'distance_{m}'][epoch]
                        val_2 = pruning_train_df[f'log_distance_{m}'][epoch]
                        val_3 = pruning_train_df[f'entropy_{m}'][epoch]
                        print('====> {}, Dist:{:.4f}, Log Dist:{:.4f}, H:{:.4f}'.format(m, val_1, val_2, val_3))

                    if wandb_run:
                        wandb_run.log(
                                    {
                                    "train/total-loss": pruning_train_df['total_loss'][epoch],
                                    "train/min-var": pruning_train_df['minVar'][epoch],
                                    **dict(map(lambda m: (f"train/rec-loss-{m}", pruning_train_df[f'recon_error_arm{m}_0'][epoch]), self.modalities)),
                                    **dict(map(lambda m: (f"train/clf-acc-{m}", pruning_train_df[f'clf_loss_arm{m}_0'][epoch]), self.modalities)),
                                    **dict(map(lambda m: (f"train/distance-{m}", pruning_train_df[f'distance_{m}'][epoch]), self.all_modalities)),
                                    **dict(map(lambda m: (f"train/log_distance-{m}", pruning_train_df[f'log_distance_{m}'][epoch]), self.all_modalities)),
                                    "train/time": time.time() - t0,
                                    }
                                    )
                    # validation step
                    self.model.eval()
                    self.classifier.eval()
                    with torch.no_grad():
                        mod_loss_val = {key: 0 for key in self.all_modalities} 
                        jointloss_val = {key: 0 for key in self.all_modalities} 
                        dist_qc = {key: 0 for key in self.all_modalities} 
                        log_dist_qc = {key: 0 for key in self.all_modalities} 
                        entr = {key: 0 for key in self.all_modalities} 
                        loss_rec = dict.fromkeys(self.modalities)
                        loss_clf = dict.fromkeys(self.modalities)
                        for m in self.modalities:
                            loss_rec[m] = np.zeros(self.n_arm[m])
                            loss_clf[m] = np.zeros(self.n_arm[m])
                            
                        var_min = 0
                        loss_val = 0
                        for batch_indx, data_block in enumerate(test_loader):
                            d_idx = data_block[-1]
                            label = data_block[-2]
                            label_onehot = np.full((label.shape[0], self.n_class), -1)
                            valid_indices = label != -1
                            label_onehot[valid_indices] = self.onehot_encoder.fit_transform(label[valid_indices].reshape(-1, 1))
                            label_onehot = torch.FloatTensor(label_onehot).to(rank)
                        
                            val_data = dict.fromkeys(self.modalities)
                            mod_mask = dict.fromkeys(self.all_modalities)
                            for im, m in enumerate(self.modalities):
                                if m == 'M':
                                    data = data_block[im].reshape(data_block[im].shape[0], -1, 4, 4)
                                else:
                                    data = data_block[im]
                                mod_mask[m] = data_block[im + self.n_modality]
                                val_data[m] = []
                                for arm in range(self.n_arm[m]):
                                    val_data[m].append(data.to(rank))
                            for im, (m1, m2) in enumerate(self.mod_pairs):
                                mod_mask[self.cross_mod_pairs[im]] = mod_mask[m1] & mod_mask[m2]
                                
                            if self.ref_prior:
                                prior_c = torch.FloatTensor(c_p[d_idx, :]).to(rank)
                            else:
                                prior_c=None
                                
                            recon_batch, _, qc, _, c, mu, log_var, _ = self.model(
                                                                                    x=val_data, 
                                                                                    temp=self.temp, 
                                                                                    hard=hard, 
                                                                                    variational=variational, 
                                                                                    eval=True, 
                                                                                    pruning_mask=pruning_mask,
                                                                                    )
                            label_pred = dict.fromkeys(self.modalities)
                            for m in self.modalities:
                                label_pred[m] = [None] * self.n_arm[m]
                                for arm in range(self.n_arm[m]):
                                    label_pred[m][arm] = self.classifier(qc[m][arm])
                                    
                            loss, loss_dict, l_rec, l_classifier, loss_joint, neg_entropy, c_distance, c_rep_dist, _, min_var = self.loss(
                                                                                                                                        recon_x=recon_batch, 
                                                                                                                                        x=val_data, 
                                                                                                                                        y_pred=label_pred,
                                                                                                                                        y=label_onehot,
                                                                                                                                        mu=mu, 
                                                                                                                                        log_sigma=log_var, 
                                                                                                                                        qc=qc, 
                                                                                                                                        c=c, 
                                                                                                                                        lam=lam,
                                                                                                                                        beta=beta,
                                                                                                                                        prior_c=prior_c, 
                                                                                                                                        mask=mod_mask,
                                                                                                                                        ref_mod=ref_mod,
                                                                                                                                        )
                            var_min += min_var.data.item()
                            loss_val += loss.data.item()
                            
                            for m in self.modalities:
                                for arm in range(self.n_arm[m]):
                                    loss_rec[m][arm] += l_rec[m][arm].data.item() / self.input_dim[m]
                                    loss_clf[m][arm] += l_classifier[m][arm].data.item()
                            
                            for m in self.all_modalities:
                                mod_loss_val[m] += safe_item(loss_dict[m])
                                jointloss_val[m] += safe_item(loss_joint[m])
                                log_dist_qc[m] += safe_item(c_distance[m])
                                dist_qc[m] += safe_item(c_rep_dist[m])
                                entr[m] += safe_item(neg_entropy[m])
                                

                        pruning_validation_df.loc[epoch, 'total_loss'] = loss_val / (batch_indx + 1)
                        pruning_validation_df.loc[epoch, 'minVar'] = var_min / (batch_indx + 1)
                        for m in self.modalities:
                            pruning_validation_df.loc[epoch, f'loss_joint_{m}'] = jointloss_val[m] / (batch_indx + 1)
                            pruning_validation_df.loc[epoch, f'entropy_{m}'] = entr[m] / (batch_indx + 1)
                            pruning_validation_df.loc[epoch, f'distance_{m}'] = dist_qc[m] / (batch_indx + 1)
                            pruning_validation_df.loc[epoch, f'log_distance_{m}'] = log_dist_qc[m] / (batch_indx + 1)
                            for arm in range(self.n_arm[m]):
                                pruning_validation_df.loc[epoch, f'recon_error_arm{m}_{arm}'] = loss_rec[m][arm] / (batch_indx + 1)
                                pruning_validation_df.loc[epoch, f'clf_loss_arm{m}_{arm}'] = loss_clf[m][arm] / (batch_indx + 1)
                            
                        for m in self.cross_mod_pairs:
                            pruning_validation_df.loc[epoch, f'loss_joint_{m}'] = jointloss_val[m] / (batch_indx + 1)
                            pruning_validation_df.loc[epoch, f'entropy_{m}'] = entr[m] / (batch_indx + 1)
                            pruning_validation_df.loc[epoch, f'distance_{m}'] = dist_qc[m] / (batch_indx + 1)
                            pruning_validation_df.loc[epoch, f'log_distance_{m}'] = log_dist_qc[m] / (batch_indx + 1)
                        
                        print('Validation')
                        for m in self.modalities:
                            val_1 = pruning_validation_df[f'recon_error_arm{m}_0'][epoch]
                            val_2 = pruning_validation_df[f'distance_{m}'][epoch]
                            val_3 = pruning_validation_df[f'log_distance_{m}'][epoch]
                            val_4 = pruning_validation_df[f'entropy_{m}'][epoch]
                            print('====> {}, Rec Error:{:.4f}, Dist:{:.4f}, Log Dist:{:.4f}, H:{:.4f}'.format(m, val_1, val_2, val_3, val_4))
                        
                        if wandb_run:
                            wandb_run.log(
                                        {
                                        "validate/total-loss": pruning_validation_df['total_loss'][epoch],
                                        "validate/min-var": pruning_validation_df['minVar'][epoch],
                                        **dict(map(lambda m: (f"validate/rec-loss-{m}", pruning_validation_df[f'recon_error_arm{m}_0'][epoch]), self.modalities)),
                                        **dict(map(lambda m: (f"validate/clf-acc-{m}", pruning_validation_df[f'clf_loss_arm{m}_0'][epoch]), self.modalities)),
                                        **dict(map(lambda m: (f"validate/distance-{m}", pruning_validation_df[f'distance_{m}'][epoch]), self.modalities)),
                                        **dict(map(lambda m: (f"validate/log_distance-{m}", pruning_validation_df[f'log_distance_{m}'][epoch]), self.modalities)),
                                        "validate/time": time.time() - t0,
                                        }
                                        )

                for m in self.modalities:
                    for arm in range(self.n_arm[m]):
                        if ws > 1:
                            dist.barrier()
                            # Average the weights across all processes
                            dist.all_reduce(self.model.qc[m][arm].weight, op=dist.ReduceOp.AVG)
                            dist.all_reduce(self.model.qc[m][arm].bias, op=dist.ReduceOp.AVG)
                            dist.all_reduce(self.model.mu[m][arm].weight, op=dist.ReduceOp.AVG)
                            dist.all_reduce(self.model.sigma[m][arm].weight, op=dist.ReduceOp.AVG)
                            dist.all_reduce(self.model.fc_lowD[m][arm].weight, op=dist.ReduceOp.AVG)
                           
                            with FSDP.summon_full_params(self.model, writeback=True):
                                prune.remove(self.model.qc[m][arm], 'weight')
                                prune.remove(self.model.qc[m][arm], 'bias')
                                prune.remove(self.model.mu[m][arm], 'weight')
                                prune.remove(self.model.sigma[m][arm], 'weight')
                                prune.remove(self.model.fc_lowD[m][arm], 'weight')
                            
                            dist.barrier()  
                        else:
                            prune.remove(self.model.qc[m][arm], 'weight')
                            prune.remove(self.model.qc[m][arm], 'bias')
                            prune.remove(self.model.mu[m][arm], 'weight')
                            prune.remove(self.model.sigma[m][arm], 'weight')
                            prune.remove(self.model.fc_lowD[m][arm], 'weight')
                           
                pr += 1
                # save the model and the learning curve
                trained_model = self.folder + f'/model/cpl_mixVAE_model_after_pruning_{pr}_{self.current_time}.pth'
                classifier_model = self.folder + f'/model/classifier_after_pruning_{pr}_{self.current_time}.pth'
                if ws > 1:
                    dist.barrier()
                    with FSDP.state_dict_type(self.model, StateDictType.FULL_STATE_DICT, save_policy):
                        cpu_state = self.model.state_dict()
                    if rank == 0:
                        torch.save(cpu_state, trained_model)
                        torch.save(self.classifier.state_dict(), classifier_model)
                    dist.barrier()
                else:
                    if self.compress:
                        trained_model += '.gz'
                        save_model_with_gzip(self.model.state_dict(), trained_model, compression_level=3)
                        save_model_with_gzip(self.classifier.state_dict(), classifier_model, compression_level=3)
                    else:
                        torch.save({'model_state_dict': self.model.state_dict()}, trained_model, pickle_protocol=4)
                        torch.save({'model_state_dict': self.classifier.state_dict()}, classifier_model, pickle_protocol=4)
                            
                pruning_train_df.to_csv(self.folder + f'/pruning_train_df_{pr}_{self.current_time}.txt', sep='\t', index=False)
                pruning_validation_df.to_csv(self.folder + f'/pruning_validation_df_{pr}_{self.current_time}.txt', sep='\t', index=False)
                    
            
        print('Training is done!')
        return trained_model
        

    @torch.no_grad()
    def eval_model(self, data_loader, hard=False, temp=1., c_p=None, ref_prior=False, ref_mod=''):
        """
        Evaluate the trained cpl-mixVAE and its classifier over a data loader and
        collect per-sample latent outputs together with batch-averaged losses.

        Gradients are disabled for the whole call and both self.model and
        self.classifier are switched to eval mode. As a side effect,
        self.ref_prior is overwritten with ref_prior. Inputs are passed to the
        model exactly as yielded by the loader (no device transfer happens here).

        input args
            data_loader: iterable of data blocks. In each block the first
                         len(self.modalities) entries are the modality inputs, the next
                         len(self.modalities) entries are the matching masks, entry [-2]
                         holds the labels and entry [-1] the sample indices. Labels equal
                         to -1 are skipped by the one-hot encoding and stay as rows of -1.
            hard: forwarded to the model as its `hard` argument.
            temp: forwarded to the model as its `temp` argument.
            c_p: the prior categorical variable, indexed by sample index; only read if ref_prior is True.
            ref_prior: a boolean variable, if True, rows of c_p are given to the loss as prior_c.
            ref_mod: forwarded to the loss as its `ref_mod` argument.

        return
            d_dict: the output dictionary. Entries keyed per modality and arm:
                'state_sample', 'state_mu', 'state_var': concatenated model outputs
                    (`state`, `mu` and `log_var` respectively).
                'z_prob', 'z_sample': concatenated `qc` and `c` model outputs, viewed
                    as (batch, n_categories).
                'prob_cat': row-wise maximum of 'z_prob'.
                'predicted_label': row-wise argmax of 'z_prob', shifted to start at 1.
                'total_loss_rec': mean over batches of the reconstruction loss
                    divided by the modality input dimension.
                'loss_clf': mean over batches of the classifier loss.
            Entries keyed per item of self.all_modalities:
                'total_dist_z', 'total_dist_qz': mean over batches of the two
                    distance terms returned by the loss.
            Other entries:
                'data_indx': concatenated sample indices.
                'prune_indx': indices whose bias in the first arm of the first
                    modality's qc layer is exactly zero.
        """

        # Indices with a non-zero bias are handed to the model as the pruning
        # mask; the zero-bias ones are reported back to the caller.
        bias = self.model.qc[self.modalities[0]][0].bias.detach().cpu().numpy()
        pruning_mask = np.where(bias != 0.)[0]
        prune_indx = np.where(bias == 0.)[0]
        self.ref_prior = ref_prior

        def _per_arm_lists():
            # one independent accumulator list per (modality, arm)
            return {m: [[] for _ in range(self.n_arm[m])] for m in self.modalities}

        recon_loss = _per_arm_lists()
        loss_clf = _per_arm_lists()
        state_sample = _per_arm_lists()
        state_mu = _per_arm_lists()
        state_var = _per_arm_lists()
        c_prob = _per_arm_lists()
        c_sample = _per_arm_lists()
        predicted_label = _per_arm_lists()
        prob_cat = _per_arm_lists()

        dist_qc = {m: [] for m in self.all_modalities}
        dist_c = {m: [] for m in self.all_modalities}
        # every coupling term is weighted equally during evaluation
        lam = {m: 1 for m in self.all_modalities}

        samp_id = []
        n_modality = len(self.modalities)
        # evaluation of the model
        self.model.eval()
        self.classifier.eval()

        for data_block in data_loader:
            d_idx = data_block[-1]
            label = data_block[-2]
            # NOTE(review): the encoder is re-fitted on every batch; unless it was
            # built with fixed categories, a batch missing a class would produce
            # fewer columns than self.n_class -- confirm against the encoder setup.
            label_onehot = np.full((label.shape[0], self.n_class), -1)
            valid_indices = label != -1
            label_onehot[valid_indices] = self.onehot_encoder.fit_transform(label[valid_indices].reshape(-1, 1))
            label_onehot = torch.FloatTensor(label_onehot)

            train_data = dict.fromkeys(self.modalities)
            mod_mask = dict.fromkeys(self.all_modalities)
            for im, m in enumerate(self.modalities):
                data = data_block[im]
                if m == 'M':
                    data = data.reshape(data.shape[0], -1, 4, 4)
                mod_mask[m] = data_block[im + n_modality]
                # every arm of a modality receives the same input object
                train_data[m] = [data] * self.n_arm[m]
            # a cross-modal pair is only valid where both modalities are present
            for im, (m1, m2) in enumerate(self.mod_pairs):
                mod_mask[self.cross_mod_pairs[im]] = mod_mask[m1] & mod_mask[m2]

            prior_c = torch.FloatTensor(c_p[d_idx, :]) if ref_prior else None

            recon_batch, _, qc, state, c, mu, log_var, _ = self.model(
                                                                        x=train_data,
                                                                        temp=temp,
                                                                        hard=hard,
                                                                        variational=self.variational,
                                                                        pruning_mask=pruning_mask,
                                                                        )

            label_pred = {
                m: [self.classifier(qc[m][arm]) for arm in range(self.n_arm[m])]
                for m in self.modalities
            }

            _, _, l_rec, l_classifier, _, _, c_distance, c_rep_dist, _, _ = self.loss(
                                                                                recon_x=recon_batch,
                                                                                x=train_data,
                                                                                y_pred=label_pred,
                                                                                y=label_onehot,
                                                                                mu=mu,
                                                                                log_sigma=log_var,
                                                                                qc=qc,
                                                                                c=c,
                                                                                lam=lam,
                                                                                beta=1,
                                                                                prior_c=prior_c,
                                                                                mask=mod_mask,
                                                                                ref_mod=ref_mod,
                                                                                )

            for m in self.modalities:
                for arm in range(self.n_arm[m]):
                    recon_loss[m][arm].append(l_rec[m][arm].data.item() / self.input_dim[m])
                    loss_clf[m][arm].append(l_classifier[m][arm].data.item())
                    state_sample[m][arm].append(state[m][arm].cpu().detach().numpy())
                    state_mu[m][arm].append(mu[m][arm].cpu().detach().numpy())
                    state_var[m][arm].append(log_var[m][arm].cpu().detach().numpy())
                    c_encoder = qc[m][arm].cpu().data.view(qc[m][arm].size()[0], self.n_categories).detach().numpy()
                    c_prob[m][arm].append(c_encoder)
                    c_samp = c[m][arm].cpu().data.view(c[m][arm].size()[0], self.n_categories).detach().numpy()
                    c_sample[m][arm].append(c_samp)
                    # predicted labels are reported 1-based
                    predicted_label[m][arm].append(np.argmax(c_encoder, axis=1) + 1)
                    prob_cat[m][arm].append(np.max(c_encoder, axis=1))

            for m in self.all_modalities:
                dist_qc[m].append(safe_item(c_rep_dist[m]))
                dist_c[m].append(safe_item(c_distance[m]))

            samp_id.append(d_idx.numpy().astype(int))

        # stitch the per-batch pieces together and average the scalar losses
        samp_id = np.concatenate(samp_id, axis=0)
        mean_total_loss_rec = dict()
        mean_loss_clf = dict()
        per_sample_stores = (state_sample, state_mu, state_var, c_prob, c_sample, predicted_label, prob_cat)
        for m in self.modalities:
            mean_total_loss_rec[m] = np.zeros(self.n_arm[m])
            mean_loss_clf[m] = np.zeros(self.n_arm[m])
            for arm in range(self.n_arm[m]):
                for store in per_sample_stores:
                    store[m][arm] = np.concatenate(store[m][arm], axis=0)
                mean_total_loss_rec[m][arm] = np.mean(np.array(recon_loss[m][arm]))
                mean_loss_clf[m][arm] = np.mean(np.array(loss_clf[m][arm]))

        total_dist_c = {m: np.mean(np.array(dist_c[m])) for m in self.all_modalities}
        total_dist_qc = {m: np.mean(np.array(dist_qc[m])) for m in self.all_modalities}

        # save the output in a dictionary
        d_dict = {
            'state_sample': state_sample,
            'state_mu': state_mu,
            'state_var': state_var,
            'prob_cat': prob_cat,
            'total_loss_rec': mean_total_loss_rec,
            'loss_clf': mean_loss_clf,
            'total_dist_z': total_dist_c,
            'total_dist_qz': total_dist_qc,
            'predicted_label': predicted_label,
            'data_indx': samp_id,
            'z_prob': c_prob,
            'z_sample': c_sample,
            'prune_indx': prune_indx,
        }

        return d_dict
    
    
    @torch.no_grad()
    def eval_clf(self, data_loader, hard=False, temp=1., c_p=None, ref_prior=False, ref_mod=''):
        """
        run the training of the cpl-mixVAE with the pre-defined parameters/settings
        pcikle used for saving the file

        input args
            data_loader: input data loader
            c_p: the prior categorical variable, only if ref_prior is True.
            c_onehot: the one-hot representation of the prior categorical variable, only if ref_prior is True.

        return
            d_dict: the output dictionary.
        """
        
        bias = self.model.qc[self.modalities[0]][0].bias.detach().cpu().numpy()
        pruning_mask = np.where(bias != 0.)[0]
        prune_indx = np.where(bias == 0.)[0]
        data_len = len(data_loader.dataset)
        self.ref_prior = ref_prior
        
        acc = dict.fromkeys(self.modalities)
        pred_label = dict.fromkeys(self.modalities)
        true_label = dict.fromkeys(self.modalities)
        for m in self.modalities:
            acc[m] = [[] for arm in range(self.n_arm[m])]
            pred_label[m] = [[] for arm in range(self.n_arm[m])]
            true_label[m] = [[] for arm in range(self.n_arm[m])]
    
        dist_qc = dict()
        dist_c = dict()
        lam = dict()
        for m in self.all_modalities:
            dist_qc[m] = []
            dist_c[m] = []
            lam[m] = 1

        samp_id = []
        n_modality = len(self.modalities)
        # evaluation of the model
        self.model.eval() 
        self.classifier.eval()
        
        for batch_indx, data_block in enumerate(data_loader):
            d_idx = data_block[-1]
            label = data_block[-2]
            label_onehot = np.full((label.shape[0], self.n_class), -1)
            valid_indices = label != -1
            label_onehot[valid_indices] = self.onehot_encoder.fit_transform(label[valid_indices].reshape(-1, 1))
            label_onehot = torch.FloatTensor(label_onehot)
            
            train_data = dict.fromkeys(self.modalities)
            mod_mask = dict.fromkeys(self.all_modalities)
            for im, m in enumerate(self.modalities):
                if m == 'M':
                    data = data_block[im].reshape(data_block[im].shape[0], -1, 4, 4)
                else:
                    data = data_block[im]
                mod_mask[m] = data_block[im + n_modality]
                train_data[m] = []
                for arm in range(self.n_arm[m]):
                    train_data[m].append(data)
            for im, (m1, m2) in enumerate(self.mod_pairs):
                mod_mask[self.cross_mod_pairs[im]] = mod_mask[m1] & mod_mask[m2]
                

            recon_batch, _, qc, state, c, mu, log_var, _ = self.model(
                                                                        x=train_data, 
                                                                        temp=temp, 
                                                                        hard=hard, 
                                                                        variational=self.variational, 
                                                                        pruning_mask=pruning_mask,
                                                                        )     
            samp_id.append(d_idx.numpy().astype(int))
            label_pred = dict.fromkeys(self.modalities)
            for m in self.modalities:
                label_pred[m] = [None] * self.n_arm[m]
                for arm in range(self.n_arm[m]):
                    label_pred[m][arm] = self.classifier(qc[m][arm])
                    acc_ = (label_pred[m][arm].argmax(dim=1) == label_onehot.argmax(dim=1)).sum() / label_onehot.shape[0]
                    pred_label[m][arm].append(label_pred[m][arm].argmax(dim=1).cpu().detach().numpy())
                    true_label[m][arm].append(label_onehot.argmax(dim=1).cpu().detach().numpy())
                    acc[m][arm].append(acc_)
 
        samp_id = np.concatenate(samp_id, axis=0)
        for m in self.modalities:
            for arm in range(self.n_arm[m]):
                acc[m][arm] = np.mean(acc[m][arm], axis=0)
                pred_label[m][arm] = np.concatenate(pred_label[m][arm], axis=0)
                true_label[m][arm] = np.concatenate(true_label[m][arm], axis=0)

        return acc, pred_label, true_label, samp_id
    

    def save_file(self, fname, **kwargs):
        """
        Save the given keyword arguments as a dictionary in a .p file using pickle
        (protocol 4).

        input args
            fname: the output path without extension; '.p' is appended to it.
            kwarg: keyword arguments for input variables e.g., x=[], y=[], etc.
        """

        # the context manager closes the file even if pickling raises
        with open(fname + '.p', "wb") as f:
            pickle.dump(dict(kwargs), f, protocol=4)
        

    def load_file(self, fname):
        """
        Load a .p file using pickle. Make sure to use the same version of
        pickle used for saving the file.

        WARNING: unpickling can execute arbitrary code; only load files that
        come from a trusted source.

        input args
            fname: the path of the file without extension; '.p' is appended to it.

        return
            data: the unpickled object (a dictionary when the file was written by save_file).
        """

        # the context manager closes the file even if unpickling raises
        with open(fname + '.p', "rb") as f:
            data = pickle.load(f)
        return data
