import pickle
import numpy as np
import time
import itertools
import pandas as pd
import torch
import pdb
import torch.nn.utils.prune as prune
from torch.nn import functional as F
import torch.distributed as dist
import matplotlib.pyplot as plt
from torch.distributed.fsdp import (
                                    FullyShardedDataParallel as FSDP,
                                    MixedPrecision,
                                    BackwardPrefetch,
                                    ShardingStrategy,
                                    FullStateDictConfig,
                                    StateDictType,
                                    )

from .models import multimodal_mixVAE as mixVAE_model
from ..utils.deepnet_tools import zinb_error
from ..utils.deepnet_tools import safe_item
from ..utils.deepnet_tools import jaccard_distance
from ..utils.deepnet_tools import save_model_with_gzip, load_model_with_gzip
from ..utils.deepnet_tools import gumbel_softmax


class multi_mod_cpl_mixVAE:

    def __init__(self, saving_folder='', augmenter=None, device=None, eps=1e-8, save_flag=True):
        """
        Initialize the multimodal cpl-mixVAE wrapper.

        input args:
            saving_folder: a string that indicates the folder to save the model(s) and file(s).
            augmenter: the pre-trained augmenter; when falsy (default ``None``) no
                       augmenter is used and ``self.netA`` is not set.
            device: computing device, either 'cpu' or 'cuda'.
            eps: a small constant value to fix computation overflow (e.g. log(0)).
            save_flag: a boolean variable, if True, the model is saved.
        """
        self.eps = eps
        self.save = save_flag
        self.folder = saving_folder
        self.aug = bool(augmenter)
        self.device = device
        # Default until ``train`` overrides it. ``load_model`` reads this flag and
        # ``init_model`` may call ``load_model`` before ``train`` has run.
        self.compress = False

        if self.aug:
            self.netA = augmenter



    def init_model(self,
                   networks,
                   n_categories,
                   input_dim,
                   lowD_dim,
                   state_dim,
                   x_drop=None,
                   n_arm=None,
                   noise_model='Gaussian',
                   tau=0.005,
                   trained_model='',
                   n_pr=0,
                   ):
        """
        Initialize the deep mixture model and, optionally, load pre-trained weights.

        input args:
            networks: dictionary keyed by modality name; its keys define the
                      modalities and it is forwarded to the mixVAE model.
            n_categories: number of categories of the latent variables.
            input_dim: input dimension (size of the input layer), per modality.
            lowD_dim: dimension of the latent representation, per modality.
            state_dim: dimension of the state variable, per modality.
            x_drop: dropout probability at the first (input) layer, per modality;
                    defaults to 0. for every modality.
            n_arm: number of arms, per modality; defaults to 1 for every modality.
            noise_model: the noise model of the data, either 'Gaussian' or 'ZINB'.
            tau: temperature of the softmax layers, usually equals to 1/n_categories (0 < tau <= 1).
            trained_model: path of a pre-trained model; if non-empty, its weights
                           are loaded into the newly built network.
            n_pr: number of pruned categories, only used together with a pre-trained model.
        """
        self.modalities = list(networks.keys())
        self.mod_pairs = list(itertools.combinations(self.modalities, 2))
        # name of each cross-modality term: the concatenated modality names
        self.cross_mod_pairs = [''.join(pair) for pair in self.mod_pairs]
        self.all_modalities = self.modalities + self.cross_mod_pairs

        # Resolve the per-modality defaults *before* storing them: the loss and
        # the training loop index self.n_arm[m] directly, so None must not be kept.
        if x_drop is None:
            x_drop = dict.fromkeys(self.modalities, 0.)
        if n_arm is None:
            n_arm = dict.fromkeys(self.modalities, 1)

        self.lowD_dim = lowD_dim
        self.n_categories = n_categories
        self.state_dim = state_dim
        self.input_dim = input_dim
        self.n_arm = n_arm

        self.model = mixVAE_model(
                                modalities=self.modalities,
                                networks=networks,
                                input_dim=input_dim,
                                n_categories=n_categories,
                                x_drop=x_drop,
                                n_arm=n_arm,
                                tau=tau,
                                eps=self.eps,
                                noise_model=noise_model,
                                )

        # NOTE(review): appears to flag distributed training; never enabled here.
        self.dist = False

        # truthiness (not len) so that None / pathlib.Path values are accepted too
        if trained_model:
            print('Load the pre-trained model')
            # if you wish to load another model for evaluation
            self.load_model(trained_model)
            self.init = False
            self.n_pr = n_pr
        else:
            self.init = True
            self.n_pr = 0


    def load_model(self, trained_model):
        """
        Load pre-trained weights from ``trained_model`` into ``self.model``.

        input args:
            trained_model: path to the saved model file. When ``self.compress`` is
                True the file is read with ``load_model_with_gzip``; otherwise it is
                read with ``torch.load`` and may hold either a checkpoint dictionary
                with a ``'model_state_dict'`` entry or a bare state dict.

        raises:
            RuntimeError: (from ``load_state_dict``) if the stored weights do not
                match the current network.
        """
        # ``compress`` is only assigned by ``train``; ``init_model`` may call this
        # method earlier, so fall back to an uncompressed file in that case.
        if getattr(self, 'compress', False):
            self.model = load_model_with_gzip(model=self.model, filepath=trained_model)
            return

        # SECURITY: weights_only=False unpickles arbitrary Python objects; only
        # load checkpoint files that come from a trusted source.
        loaded_file = torch.load(trained_model, map_location='cpu', weights_only=False)
        # Accept both a checkpoint dict and a bare state dict, without hiding real
        # load_state_dict errors (e.g. shape mismatches) behind a bare except.
        if isinstance(loaded_file, dict) and 'model_state_dict' in loaded_file:
            state_dict = loaded_file['model_state_dict']
        else:
            state_dict = loaded_file
        self.model.load_state_dict(state_dict)
    
    
    def loss(self, recon_x, x, mu, log_sigma, qc, c, lam, beta, prior_c, mask, ref_mod):
        """
        Loss function of the multimodal cpl-mixVAE network.

        The total loss is the sum of
          * one term per modality: the reconstruction loss of every arm (plus
            ``beta`` times the KL divergence of the state variable when
            ``self.variational`` is True) and the coupling loss between the arms
            of that modality (and the prior, when ``self.ref_prior`` is True);
          * one term per modality pair: the coupling loss between every arm of
            the first modality and every arm of the second, averaged over the arm
            combinations and scaled by ``pair_mask.shape[0] / pair_mask.sum()``.

        Reads ``self.variational``, ``self.ref_prior`` and ``self.arm_paired``,
        which are set by ``train``.

        input args
            recon_x: dictionary of reconstructed data, indexed as ``recon_x[m][arm]``.
            x: dictionary of original input data, indexed as ``x[m][arm]``.
            mu: dictionary of means of the Gaussian state variable, ``mu[m][arm]``.
            log_sigma: dictionary of log-variances of the Gaussian state variable.
            qc: dictionary of categorical probabilities, ``qc[m][arm]``.
            c: dictionary of categorical samples, ``c[m][arm]``.
            lam: dictionary of coupling factors keyed by modality name and by the
                 concatenated name of each modality pair.
            beta: weight of the KL divergence term.
            prior_c: prior categorical variable, only used if ``self.ref_prior``.
            mask: dictionary of per-sample masks keyed like ``lam``, used to index
                  the batch (presumably boolean -- confirm against the caller).
            ref_mod: name of a reference modality. In a pair that contains it, the
                     reference's ``qc`` is detached and used as the BCE target for
                     the hard Gumbel-softmax output of the other modality.

        return
            total_loss: total loss value.
            loss: dictionary of the loss of each modality and modality pair.
            l_rec: reconstruction loss for each modality and arm.
            loss_joint: coupling loss for each modality and modality pair.
            neg_joint_entropy: negative joint entropy of the categorical variables.
            z_distance: the distance terms that enter the coupling losses.
            z_distance_rep: squared Euclidean distance between categorical samples.
            KLD_cont: KL divergence of the state variable for each modality and arm.
            var_min: minimum per-category variance of ``qc`` for arm ``self.arm_paired``.

        raises
            NotImplementedError: if ``self.model.noise_model`` is not 'Gaussian'.
        """
        if self.model.noise_model != 'Gaussian':
            # Only the Gaussian reconstruction loss is implemented; any other noise
            # model would otherwise leave l_rec as None and fail with a TypeError.
            raise NotImplementedError(
                f"loss() only supports the 'Gaussian' noise model, got {self.model.noise_model!r}"
            )

        l_rec = dict.fromkeys(self.modalities)
        loss_indep = dict.fromkeys(self.modalities)
        KLD_cont = dict.fromkeys(self.modalities)
        neg_joint_entropy = dict.fromkeys(self.all_modalities)
        z_distance_rep = dict.fromkeys(self.all_modalities)
        z_distance = dict.fromkeys(self.all_modalities)
        loss_joint = dict.fromkeys(self.all_modalities)
        loss = dict.fromkeys(self.all_modalities)

        # lam-independent constant shared by every coupling loss
        half_k_log_2pi = (self.n_categories / 2) * (np.log(2 * np.pi))

        # ---- within-modality terms: reconstruction, KL and arm-to-arm coupling ----
        for m in self.modalities:
            n_arm = self.n_arm[m]
            m_mask = mask[m]
            l_rec[m] = [None] * n_arm
            loss_indep[m] = [None] * n_arm
            KLD_cont[m] = [None] * n_arm
            neg_joint_entropy[m], z_distance_rep[m], z_distance[m] = [], [], []

            for arm_a in range(n_arm):
                recon_x_ = recon_x[m][arm_a][m_mask]
                x_ = x[m][arm_a][m_mask]

                # zero / non-zero patterns of the data and of the reconstruction
                x_bin = 0. * x_
                x_bin[x_ > self.eps] = 1.
                rec_bin = 0. * recon_x_
                rec_bin[recon_x_ > self.eps] = 1.
                # NOTE(review): rec_bin is a thresholded copy of recon_x_, so the BCE
                # term adds to the loss value but back-propagates a zero gradient.
                l_rec[m][arm_a] = (F.mse_loss(recon_x_, x_, reduction='sum') / (x_.size(0))
                                   + F.binary_cross_entropy(rec_bin, x_bin))

                if self.variational:
                    log_sigma_ = log_sigma[m][arm_a][m_mask]
                    mu_ = mu[m][arm_a][m_mask]
                    KLD_cont[m][arm_a] = (-0.5 * torch.mean(1 + log_sigma_ - mu_.pow(2) - log_sigma_.exp(), dim=0)).sum()
                    loss_indep[m][arm_a] = l_rec[m][arm_a] + beta * KLD_cont[m][arm_a]
                else:
                    loss_indep[m][arm_a] = l_rec[m][arm_a]
                    KLD_cont[m][arm_a] = [0.]

                qc_a = qc[m][arm_a][m_mask]
                log_qc_a = torch.log(qc_a + self.eps)
                # per-category 1 / std of q(c), broadcast over the samples
                inv_std_a = (1 / (qc_a.var(0) + self.eps)).sqrt()

                for arm_b in range(arm_a + 1, n_arm):
                    qc_b = qc[m][arm_b][m_mask]
                    log_qc_b = torch.log(qc_b + self.eps)
                    neg_joint_entropy[m].append(
                        (torch.sum(qc_a * log_qc_a, dim=-1)).mean() + (torch.sum(qc_b * log_qc_b, dim=-1)).mean()
                    )
                    inv_std_b = (1 / (qc_b.var(0) + self.eps)).sqrt()
                    # ||c_a - c_b||^2 between the samples (reported, not used in the loss)
                    z_distance_rep[m].append(
                        (torch.norm(c[m][arm_a][m_mask] - c[m][arm_b][m_mask], p=2, dim=1).pow(2)).mean()
                    )
                    # variance-normalised distance between log q(c) of the two arms
                    z_distance[m].append(
                        (torch.norm(log_qc_a * inv_std_a - log_qc_b * inv_std_b, p=2, dim=1).pow(2)).mean()
                    )

                if self.ref_prior:
                    # couple this arm to the prior categorical variable
                    prior_distance = (torch.norm(c[m][arm_a][m_mask] - prior_c[m_mask], p=2, dim=1).pow(2)).mean()
                    z_distance_rep[m].append(prior_distance)
                    neg_joint_entropy[m].append((torch.sum(qc_a * log_qc_a, dim=-1)).mean())
                    z_distance[m].append(prior_distance)

            loss_joint[m] = (lam[m] * sum(z_distance[m]) + sum(neg_joint_entropy[m])
                             + (half_k_log_2pi - 0.5 * np.log(2 * lam[m])))
            loss[m] = sum(loss_indep[m]) + loss_joint[m]

        # ---- cross-modality coupling, averaged over all arm combinations ----
        for (mod_1, mod_2), mn in zip(self.mod_pairs, self.cross_mod_pairs):
            pair_mask = mask[mn]
            pair_entropy, pair_distance, pair_distance_rep = 0., 0., 0.
            n_terms = 0
            for arm_a in range(self.n_arm[mod_1]):
                qc_1 = qc[mod_1][arm_a][pair_mask]
                log_qc_1 = torch.log(qc_1 + self.eps)
                c_1 = c[mod_1][arm_a][pair_mask]
                inv_std_1 = (1 / (qc_1.var(0) + self.eps)).sqrt()
                for arm_b in range(self.n_arm[mod_2]):
                    qc_2 = qc[mod_2][arm_b][pair_mask]
                    log_qc_2 = torch.log(qc_2 + self.eps)
                    c_2 = c[mod_2][arm_b][pair_mask]
                    inv_std_2 = (1 / (qc_2.var(0) + self.eps)).sqrt()

                    if ref_mod == mod_1:
                        # mod_1 is the reference: its q(c) is a fixed (detached) target
                        qc_prior = qc_1.clone().detach()
                        pair_entropy += (torch.sum(qc_2 * log_qc_2, dim=-1)).mean()
                        qc_bin = gumbel_softmax(qc_2, 1, self.n_categories, 1, hard=True, gumble_noise=False)
                        pair_distance += F.binary_cross_entropy(qc_bin, qc_prior)
                    elif ref_mod == mod_2:
                        # mod_2 is the reference: its q(c) is a fixed (detached) target
                        qc_prior = qc_2.clone().detach()
                        pair_entropy += (torch.sum(qc_1 * log_qc_1, dim=-1)).mean()
                        qc_bin = gumbel_softmax(qc_1, 1, self.n_categories, 1, hard=True, gumble_noise=False)
                        pair_distance += F.binary_cross_entropy(qc_bin, qc_prior)
                    else:
                        pair_entropy += (torch.sum(qc_1 * log_qc_1, dim=-1)).mean() + (torch.sum(qc_2 * log_qc_2, dim=-1)).mean()
                        # variance-normalised distance between log q(c) of the two modalities
                        pair_distance += (torch.norm(log_qc_1 * inv_std_1 - log_qc_2 * inv_std_2, p=2, dim=1).pow(2)).mean()

                    # ||c_1 - c_2||^2 between the samples (monitoring only)
                    pair_distance_rep += (torch.norm(c_1 - c_2, p=2, dim=1).pow(2)).mean()
                    n_terms += 1

            neg_joint_entropy[mn] = pair_entropy / n_terms
            z_distance[mn] = pair_distance / n_terms
            z_distance_rep[mn] = pair_distance_rep / n_terms
            loss_joint[mn] = (lam[mn] * z_distance[mn] + neg_joint_entropy[mn]
                              + (half_k_log_2pi - 0.5 * np.log(2 * lam[mn])))
            # NOTE(review): rescales by (batch size / number of selected samples);
            # the factor is not finite when pair_mask selects no sample.
            loss[mn] = (pair_mask.shape[0] / pair_mask.sum()) * loss_joint[mn]

        # Not part of the loss; returned for reporting. With several modalities it
        # is evaluated for the first two modalities on the samples of the last
        # modality pair; with a single modality (no pair mask) that modality's own
        # mask is used.
        if self.cross_mod_pairs:
            last_pair_mask = mask[self.cross_mod_pairs[-1]]
            var_min = min(qc[mod][self.arm_paired][last_pair_mask].var(0).min()
                          for mod in self.modalities[:2])
        else:
            only_mod = self.modalities[0]
            var_min = qc[only_mod][self.arm_paired][mask[only_mod]].var(0).min()

        total_loss = sum(loss.values())
        return total_loss, loss, l_rec, loss_joint, neg_joint_entropy, z_distance, z_distance_rep, KLD_cont, var_min



    def train(self, train_loader, test_loader, n_epoch, n_epoch_p, lr=1e-3, min_con=.5, max_prun_it=0, ref_prior=False, temp=1., ws=1,
              hard=False, variational=True, c_p=None, lam=1., beta=1., arm_paired=0, world_size=1, wandb_run=None, compress=False, rank=None, 
              add_noise=False, noise_std=0.1, ref_mod=''):
        """
        run the training of the cpl-mixVAE with the pre-defined parameters/settings
        pcikle used for saving the file

        input args
            train_loader: train dataloader.
            test_loader: test dataloader.
            n_epoch: number of training epoch, without pruning.
            n_epoch_p: number of training epoch, with pruning.
            lr: the learning rate of the optimizer, here Adam.
            c_p: the prior categorical variable, only if ref_prior is True.
            c_onehot: the one-hot representation of the prior categorical variable, only if ref_prior is True.
            min_con: minimum value of consensus among pair of arms.
            max_prun_it: maximum number of pruning iterations.
            temp: temperature of sampling
            hard: a boolean variable, True uses one-hot method that is used in Gumbel-softmax, and False uses the Gumbel-softmax function.
            variational: a boolean variable for variational mode, False mode does not use sampling.

        return
            data_file_id: the output dictionary.
        """
        # define current_time
        self.current_time = time.strftime('%Y-%m-%d-%H-%M-%S')
        self.variational = variational
        self.temp = temp
        self.ref_prior = ref_prior
        self.lr = lr
        self.arm_paired = arm_paired
        self.n_modality = len(self.modalities)
        
        if rank is None:
            rank = self.device
        
        # make a preformance dataframes for training, validation, and testing 
        df_keys = ['total_loss', 'minVar']
        for m in self.modalities:
            df_keys.append(f'loss_joint_{m}')
            df_keys.append(f'entropy_{m}')
            df_keys.append(f'distance_{m}')
            df_keys.append(f'log_distance_{m}')
            for arm in range(self.n_arm[m]):
                df_keys.append(f'recon_error_arm{m}_{arm}')

        for m in self.cross_mod_pairs:
            df_keys.append(f'loss_joint_{m}')
            df_keys.append(f'entropy_{m}')
            df_keys.append(f'distance_{m}')
            df_keys.append(f'log_distance_{m}')

        train_df = pd.DataFrame(np.zeros((n_epoch, len(df_keys))), columns=df_keys)
        validation_df = pd.DataFrame(np.zeros((n_epoch, len(df_keys))), columns=df_keys)
        
        # initialize the model parameters for pruning
        bias_mask = torch.ones(self.n_categories)
        weight_mask = dict.fromkeys(self.modalities)
        fc_mu = dict.fromkeys(self.modalities)
        fc_sigma = dict.fromkeys(self.modalities)
        lowD_mask = dict.fromkeys(self.modalities)
        for m in self.modalities:
            weight_mask[m] = torch.ones((self.n_categories, self.lowD_dim[m]))
            fc_mu[m] = torch.ones((self.state_dim[m], self.n_categories + self.lowD_dim[m]))
            fc_sigma[m] = torch.ones((self.state_dim[m], self.n_categories + self.lowD_dim[m]))
            lowD_mask[m] = torch.ones((self.lowD_dim[m], self.state_dim[m] + self.n_categories))
        
        batch_size = train_loader.batch_size
        self.compress = compress
        self.model = self.model.to(rank)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        print(f"Model is on device: {next(self.model.parameters()).device}")
        
        if self.aug:
            self.netA.to(rank)
            
        # training the model without pruning
        if self.init:
            print("Start training ...")
            for epoch in range(n_epoch):
                mod_loss_val = {key: 0 for key in self.all_modalities} 
                jointloss_val = {key: 0 for key in self.all_modalities} 
                dist_qc = {key: 0 for key in self.all_modalities} 
                log_dist_qc = {key: 0 for key in self.all_modalities} 
                entr = {key: 0 for key in self.all_modalities} 
                loss_rec = {key: 0 for key in self.all_modalities} 
                for m in self.modalities:
                    loss_rec[m] = np.zeros(self.n_arm[m])
                var_min = 0
                loss_val = 0
                self.model.train()
                t0 = time.time()
                for batch_indx, data_block in enumerate(train_loader):
                    d_idx = data_block[-1]
                    train_data = dict.fromkeys(self.modalities)
                    mod_mask = dict.fromkeys(self.all_modalities)
                    for im, m in enumerate(self.modalities):
                        if m == 'M':
                            data = data_block[im].reshape(data_block[im].shape[0], -1, 4, 4)
                        else:
                            data = data_block[im]
                        mod_mask[m] = data_block[im + self.n_modality]
                        train_data[m] = []
                        for arm in range(self.n_arm[m]):
                            data = data.to(rank)
                            if self.aug:
                                if arm == 0:
                                    mod_mask[m] = torch.concat((mod_mask[m], mod_mask[m]), 0)
                                if m == 'T':
                                    _, gen_data = self.netA(data, True, noise_std)
                                    train_data[m].append(torch.concat((data, gen_data), 0))
                                elif m == 'M': 
                                    noise = torch.distributions.Exponential(1/0.5).sample(data.shape).to(rank)  
                                    gen_data = data.clone()
                                    tmp_mask = gen_data > 0
                                    gen_data[tmp_mask] += noise[tmp_mask]
                                    train_data[m].append(torch.concat((data, gen_data), 0))
                                else:
                                    gen_data = data + torch.randn(data.shape).to(rank) * noise_std
                                    train_data[m].append(torch.concat((data, gen_data), 0))
                            else:
                                train_data[m].append(data)
                            
                    for im, (m1, m2) in enumerate(self.mod_pairs):
                        mod_mask[self.cross_mod_pairs[im]] = mod_mask[m1] & mod_mask[m2]

                    if self.ref_prior:
                        prior_c = torch.FloatTensor(c_p[d_idx, :]).to(rank)
                    else:
                        prior_c = None

                    self.optimizer.zero_grad()
                    recon_batch, _, qc, _, c, mu, log_var, _ = self.model(
                                                                            x=train_data, 
                                                                            temp=self.temp, 
                                                                            hard=hard, 
                                                                            variational=variational,
                                                                            )
                    loss, loss_dict, l_rec, loss_joint, neg_entropy, c_distance, c_rep_dist, _, min_var = self.loss(
                                                                                                                    recon_x=recon_batch, 
                                                                                                                    x=train_data, 
                                                                                                                    mu=mu, 
                                                                                                                    log_sigma=log_var, 
                                                                                                                    qc=qc, 
                                                                                                                    c=c, 
                                                                                                                    lam=lam,
                                                                                                                    beta=beta,
                                                                                                                    prior_c=prior_c, 
                                                                                                                    mask=mod_mask,
                                                                                                                    ref_mod=ref_mod,
                                                                                                                    )
                                                                                                                  
                    loss.backward()
                    self.optimizer.step()
                    var_min += min_var.data.item()
                    loss_val += loss.data.item()
                    
                    for m in self.modalities:
                        for arm in range(self.n_arm[m]):
                            loss_rec[m][arm] += l_rec[m][arm].data.item() / self.input_dim[m]
                    
                    for m in self.all_modalities:
                        mod_loss_val[m] += safe_item(loss_dict[m])
                        jointloss_val[m] += safe_item(loss_joint[m])
                        log_dist_qc[m] += safe_item(c_distance[m])
                        dist_qc[m] += safe_item(c_rep_dist[m])
                        entr[m] += safe_item(neg_entropy[m])
                
                # if self.dist:
                #     dist.all_reduce(loss, op=dist.ReduceOp.SUM)
                #     dist.all_reduce(loss_rec, op=dist.ReduceOp.SUM)
                #     dist.all_reduce(c_dist, op=dist.ReduceOp.SUM)
                
                train_df.loc[epoch, 'total_loss'] = loss_val / (batch_indx + 1)
                train_df.loc[epoch, 'minVar'] = var_min / (batch_indx + 1)
                for m in self.modalities:
                    train_df.loc[epoch, f'loss_joint_{m}'] = jointloss_val[m] / (batch_indx + 1)
                    train_df.loc[epoch, f'entropy_{m}'] = entr[m] / (batch_indx + 1)
                    train_df.loc[epoch, f'distance_{m}'] = dist_qc[m] / (batch_indx + 1)
                    train_df.loc[epoch, f'log_distance_{m}'] = log_dist_qc[m] / (batch_indx + 1)
                    for arm in range(self.n_arm[m]):
                        train_df.loc[epoch, f'recon_error_arm{m}_{arm}'] = loss_rec[m][arm] / (batch_indx + 1)
                       
                for m in self.cross_mod_pairs:
                    train_df.loc[epoch, f'loss_joint_{m}'] = jointloss_val[m] / (batch_indx + 1)
                    train_df.loc[epoch, f'entropy_{m}'] = entr[m] / (batch_indx + 1)
                    train_df.loc[epoch, f'distance_{m}'] = dist_qc[m] / (batch_indx + 1)
                    train_df.loc[epoch, f'log_distance_{m}'] = log_dist_qc[m] / (batch_indx + 1)
            
                print('Training')
                print('Epoch:{}, Elapsed Time:{:.2f}'.format(epoch, time.time() - t0))
                print('====> Total Loss: {:.4f}, minVar: {:.4f}'.format(train_df['total_loss'][epoch], train_df['minVar'][epoch]))
                for m in self.modalities:
                    val_1 = train_df[f'recon_error_arm{m}_0'][epoch]
                    val_2 = train_df[f'distance_{m}'][epoch]
                    val_3 = train_df[f'log_distance_{m}'][epoch]
                    val_4 = train_df[f'entropy_{m}'][epoch]
                    print('====> {}, Rec Error:{:.4f}, Dist:{:.4f}, Log Dist:{:.4f}, H:{:.4f}'.format(m, val_1, val_2, val_3, val_4))
                for m in self.cross_mod_pairs:
                    val_1 = train_df[f'distance_{m}'][epoch]
                    val_2 = train_df[f'log_distance_{m}'][epoch]
                    val_3 = train_df[f'entropy_{m}'][epoch]
                    print('====> {}, Dist:{:.4f}, Log Dist:{:.4f}, H:{:.4f}'.format(m, val_1, val_2, val_3))
                    
                if wandb_run:
                    wandb_run.log(
                                {
                                "train/total-loss": train_df['total_loss'][epoch],
                                "train/min-var": train_df['minVar'][epoch],
                                **dict(map(lambda m: (f"train/rec-loss-{m}", train_df[f'recon_error_arm{m}_0'][epoch]), self.modalities)),
                                **dict(map(lambda m: (f"train/distance-{m}", train_df[f'distance_{m}'][epoch]), self.all_modalities)),
                                **dict(map(lambda m: (f"train/log_distance-{m}", train_df[f'log_distance_{m}'][epoch]), self.all_modalities)),
                                "train/time": time.time() - t0,
                                }
                                )
                # validation step
                self.model.eval()
                with torch.no_grad():
                    mod_loss_val = {key: 0 for key in self.all_modalities} 
                    jointloss_val = {key: 0 for key in self.all_modalities} 
                    dist_qc = {key: 0 for key in self.all_modalities} 
                    log_dist_qc = {key: 0 for key in self.all_modalities} 
                    entr = {key: 0 for key in self.all_modalities} 
                    loss_rec = dict.fromkeys(self.modalities)
                    for m in self.modalities:
                        loss_rec[m] = np.zeros(self.n_arm[m])
                        
                    var_min = 0
                    loss_val = 0
                    for batch_indx, data_block in enumerate(test_loader):
                        d_idx = data_block[-1]
                        val_data = dict.fromkeys(self.modalities)
                        mod_mask = dict.fromkeys(self.all_modalities)
                        for im, m in enumerate(self.modalities):
                            if m == 'M':
                                data = data_block[im].reshape(data_block[im].shape[0], -1, 4, 4)
                            else:
                                data = data_block[im]
                            mod_mask[m] = data_block[im + self.n_modality]
                            val_data[m] = []
                            for arm in range(self.n_arm[m]):
                                val_data[m].append(data.to(rank))
                        for im, (m1, m2) in enumerate(self.mod_pairs):
                            mod_mask[self.cross_mod_pairs[im]] = mod_mask[m1] & mod_mask[m2]
                        if self.ref_prior:
                            prior_c = torch.FloatTensor(c_p[d_idx, :]).to(rank)
                        else:
                            prior_c = None
                            
                        recon_batch, _, qc, _, c, mu, log_var, _ = self.model(
                                                                                x=val_data, 
                                                                                temp=self.temp, 
                                                                                hard=hard, 
                                                                                variational=variational, 
                                                                                eval=True,
                                                                                )
                        loss, loss_dict, l_rec, loss_joint, neg_entropy, c_distance, c_rep_dist, _, min_var = self.loss(
                                                                                                                        recon_x=recon_batch, 
                                                                                                                        x=val_data, 
                                                                                                                        mu=mu, 
                                                                                                                        log_sigma=log_var, 
                                                                                                                        qc=qc, 
                                                                                                                        c=c, 
                                                                                                                        lam=lam,
                                                                                                                        beta=beta,
                                                                                                                        prior_c=prior_c, 
                                                                                                                        mask=mod_mask,
                                                                                                                        ref_mod=ref_mod,
                                                                                                                        )
                        var_min += min_var.data.item()
                        loss_val += loss.data.item()
                        
                        for m in self.modalities:
                            for arm in range(self.n_arm[m]):
                                loss_rec[m][arm] += l_rec[m][arm].data.item() / self.input_dim[m]
                        
                        for m in self.all_modalities:
                            mod_loss_val[m] += safe_item(loss_dict[m])
                            jointloss_val[m] += safe_item(loss_joint[m])
                            log_dist_qc[m] += safe_item(c_distance[m])
                            dist_qc[m] += safe_item(c_rep_dist[m])
                            entr[m] += safe_item(neg_entropy[m])

                    validation_df.loc[epoch, 'total_loss'] = loss_val / (batch_indx + 1)
                    validation_df.loc[epoch, 'minVar'] = var_min / (batch_indx + 1)
                    for m in self.modalities:
                        validation_df.loc[epoch, f'loss_joint_{m}'] = jointloss_val[m] / (batch_indx + 1)
                        validation_df.loc[epoch, f'entropy_{m}'] = entr[m] / (batch_indx + 1)
                        validation_df.loc[epoch, f'distance_{m}'] = dist_qc[m] / (batch_indx + 1)
                        validation_df.loc[epoch, f'log_distance_{m}'] = log_dist_qc[m] / (batch_indx + 1)
                        for arm in range(self.n_arm[m]):
                            validation_df.loc[epoch, f'recon_error_arm{m}_{arm}'] = loss_rec[m][arm] / (batch_indx + 1)
                        
                    for m in self.cross_mod_pairs:
                        validation_df.loc[epoch, f'loss_joint_{m}'] = jointloss_val[m] / (batch_indx + 1)
                        validation_df.loc[epoch, f'entropy_{m}'] = entr[m] / (batch_indx + 1)
                        validation_df.loc[epoch, f'distance_{m}'] = dist_qc[m] / (batch_indx + 1)
                        validation_df.loc[epoch, f'log_distance_{m}'] = log_dist_qc[m] / (batch_indx + 1)
                        
                    print('Validation')
                    for m in self.modalities:
                        val_1 = validation_df[f'recon_error_arm{m}_0'][epoch]
                        val_2 = validation_df[f'distance_{m}'][epoch]
                        val_3 = validation_df[f'log_distance_{m}'][epoch]
                        val_4 = validation_df[f'entropy_{m}'][epoch]
                        print('====> {}, Rec Error:{:.4f}, Dist:{:.4f}, Log Dist:{:.4f}, H:{:.4f}'.format(m, val_1, val_2, val_3, val_4))
                    
                    if wandb_run:
                        wandb_run.log(
                                    {
                                    "validate/total-loss": validation_df['total_loss'][epoch],
                                    "validate/min-var": validation_df['minVar'][epoch],
                                    **dict(map(lambda m: (f"validate/rec-loss-{m}", validation_df[f'recon_error_arm{m}_0'][epoch]), self.modalities)),
                                    **dict(map(lambda m: (f"validate/distance-{m}", validation_df[f'distance_{m}'][epoch]), self.modalities)),
                                    **dict(map(lambda m: (f"validate/log_distance-{m}", validation_df[f'log_distance_{m}'][epoch]), self.modalities)),
                                    "validate/time": time.time() - t0,
                                    }
                                    )

                if self.save and n_epoch > 10 and epoch // 1000 == 0:
                    trained_model = self.folder + f'/model/cpl_mixVAE_model_{self.current_time}.pth'
                    if ws > 1:
                        dist.barrier()
                        save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
                        with FSDP.state_dict_type(self.model, StateDictType.FULL_STATE_DICT, save_policy):
                            cpu_state = self.model.state_dict()
                        if rank == 0:
                            torch.save(cpu_state, trained_model)
                        dist.barrier()
                    else:
                        if self.compress:
                            trained_model += '.gz'
                            save_model_with_gzip(self.model.state_dict(), trained_model, compression_level=3)
                        else:
                            torch.save({'model_state_dict': self.model.state_dict()}, trained_model, pickle_protocol=4)
                            
                        prune_indx = []
                        # save train and validation dataframes as text file
                        train_df.to_csv(self.folder + f'/train_df_{self.current_time}.txt', sep='\t', index=False)
                        validation_df.to_csv(self.folder + f'/validation_df_{self.current_time}.txt', sep='\t', index=False)

            # save the model and the learning curve
            if self.save and n_epoch > 0:
                trained_model = self.folder + f'/model/cpl_mixVAE_model_before_pruning_{self.current_time}.pth'
                if ws > 1:
                    dist.barrier()
                    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
                    with FSDP.state_dict_type(self.model, StateDictType.FULL_STATE_DICT, save_policy):
                        cpu_state = self.model.state_dict()
                    if rank == 0:
                        torch.save(cpu_state, trained_model)
                    dist.barrier()
                else:
                    if self.compress:
                        trained_model += '.gz'
                        save_model_with_gzip(self.model.state_dict(), trained_model, compression_level=3)
                    else:
                        torch.save({'model_state_dict': self.model.state_dict()}, trained_model, pickle_protocol=4)
                        
                    prune_indx = []
                    # save train and validation dataframes as text file
                    train_df.to_csv(self.folder + f'/train_df_{self.current_time}.txt', sep='\t', index=False)
                    validation_df.to_csv(self.folder + f'/validation_df_{self.current_time}.txt', sep='\t', index=False)
                
    
        if ws > 1:
            #Ensure all ranks reach this point
            dist.barrier()
        # training the model with pruning
        if n_epoch_p > 0:
            # initialized pruning parameters of the layer of the discrete variable
            bias = self.model.qc[self.modalities[0]][0].bias.detach().cpu().numpy()
            pruning_mask = np.where(bias != 0.)[0]
            prune_indx = np.where(bias == 0.)[0]
            stop_prune = False
        else:
            stop_prune = True

        pr = self.n_pr
        ind = []
        while not stop_prune:
            if ref_mod:
                predicted_label = [None] * self.n_arm[ref_mod]
            else:
                predicted_label = {key: [] for key in self.modalities}
            # Assessment over all dataset
            self.model.eval()
            with torch.no_grad():
                for i, data_block in enumerate(train_loader):
                    d_idx = data_block[-1]
                    data_dict = dict.fromkeys(self.modalities)
                    mod_mask = dict.fromkeys(self.all_modalities)
                    for im, m in enumerate(self.modalities):
                        if m == 'M':
                            data = data_block[im].reshape(data_block[im].shape[0], -1, 4, 4)
                        else:
                            data = data_block[im]
                        mod_mask[m] = data_block[im + self.n_modality]
                        data_dict[m] = []
                        for arm in range(self.n_arm[m]):
                            data_dict[m].append(data.to(rank))
                            
                    for im, (m1, m2) in enumerate(self.mod_pairs):
                        mod_mask[self.cross_mod_pairs[im]] = mod_mask[m1] & mod_mask[m2]
                    
                    if self.ref_prior:
                        prior_c = torch.FloatTensor(c_p[d_idx, :]).to(rank)
                    else:
                        prior_c = None

                    _, _, qc, _, _, _, _, _ = self.model(x=data_dict, temp=self.temp, hard=hard, variational=variational, eval=True, pruning_mask=pruning_mask)

                    if ref_mod:
                        for arm in range(self.n_arm[ref_mod]):
                            c_encoder = qc[ref_mod][arm].data.view(qc[ref_mod][arm].size()[0], self.n_categories)
                            predicted_label[arm].append(c_encoder[mod_mask[ref_mod]].argmax(dim=1).detach().cpu().numpy())
                    else:
                        for im, (mod_1, mod_2) in enumerate(self.mod_pairs):
                            c_encoder = qc[mod_1][self.arm_paired].data.view(qc[mod_1][self.arm_paired].size()[0], self.n_categories)
                            predicted_label[mod_1].append(c_encoder[mod_mask[self.cross_mod_pairs[im]]].argmax(dim=1).detach().cpu().numpy())
                            c_encoder = qc[mod_2][self.arm_paired].data.view(qc[mod_2][self.arm_paired].size()[0], self.n_categories)
                            predicted_label[mod_2].append(c_encoder[mod_mask[self.cross_mod_pairs[im]]].argmax(dim=1).detach().cpu().numpy())

            
            if ws > 1:
                #Ensure all ranks reach this point
                dist.barrier()
            
            if ws == 1 or (ws > 1 and rank == 0):
                c_agreement = []
                if ref_mod:
                    for arm_a in range(self.n_arm[ref_mod]):
                        pred_a = predicted_label[arm_a]
                        for arm_b in range(arm_a + 1, self.n_arm):
                            pred_b = predicted_label[arm_b]
                            armA_vs_armB = np.zeros((self.n_categories, self.n_categories))

                            for samp in range(pred_a.shape[0]):
                                armA_vs_armB[pred_a[samp].astype(int), pred_b[samp].astype(int)] += 1

                            num_samp_arm = []
                            for ij in range(self.n_categories):
                                sum_row = armA_vs_armB[ij, :].sum()
                                sum_column = armA_vs_armB[:, ij].sum()
                                num_samp_arm.append(max(sum_row, sum_column))

                            armA_vs_armB = np.divide(armA_vs_armB, np.array(num_samp_arm), out=np.zeros_like(armA_vs_armB),
                                                    where=np.array(num_samp_arm) != 0)
                            c_agreement.append(np.diag(armA_vs_armB))
                            ind_sort = np.argsort(c_agreement[-1])

                            # plot the consensus matrix
                            plt.figure()
                            plt.imshow(armA_vs_armB[:, ind_sort[::-1]][ind_sort[::-1]], cmap='binary')
                            plt.colorbar()
                            plt.xlabel('arm_' + str(arm_a), fontsize=20)
                            plt.xticks(range(self.n_categories), range(self.n_categories))
                            plt.yticks(range(self.n_categories), range(self.n_categories))
                            plt.ylabel('arm_' + str(arm_b), fontsize=20)
                            plt.xticks([])
                            plt.yticks([])
                            plt.title('|c|=' + str(self.n_categories), fontsize=20)
                            plt.savefig(self.folder + f'/consensus_{pr}_{ref_mod}_arm_{arm_a}_arm_{arm_b}.png', dpi=600)
                            plt.close("all")

                else:
                    for im, (mod_1, mod_2) in enumerate(self.mod_pairs):
                        pred_a = np.concatenate(predicted_label[mod_1])
                        pred_b = np.concatenate(predicted_label[mod_2])
                        armA_vs_armB = np.zeros((self.n_categories, self.n_categories))

                        for samp in range(pred_a.shape[0]):
                            armA_vs_armB[pred_a[samp].astype(int), pred_b[samp].astype(int)] += 1

                        num_samp_arm = []
                        for ij in range(self.n_categories):
                            sum_row = armA_vs_armB[ij, :].sum()
                            sum_column = armA_vs_armB[:, ij].sum()
                            num_samp_arm.append(max(sum_row, sum_column))

                        armA_vs_armB = np.divide(armA_vs_armB, np.array(num_samp_arm), out=np.zeros_like(armA_vs_armB),
                                                where=np.array(num_samp_arm) != 0)
                        c_agreement.append(np.diag(armA_vs_armB))
                        ind_sort = np.argsort(c_agreement[-1])

                        # plot the consensus matrix
                        plt.figure()
                        plt.imshow(armA_vs_armB[:, ind_sort[::-1]][ind_sort[::-1]], cmap='binary')
                        plt.colorbar()
                        plt.xlabel(f'arm_{mod_1}', fontsize=20)
                        plt.xticks(range(self.n_categories), range(self.n_categories))
                        plt.yticks(range(self.n_categories), range(self.n_categories))
                        plt.ylabel(f'arm_{mod_2}', fontsize=20)
                        plt.xticks([])
                        plt.yticks([])
                        plt.title('|c|=' + str(self.n_categories), fontsize=20)
                        plt.savefig(self.folder + f'/consensus_{pr}_{mod_1}_vs_{mod_2}.png', dpi=600)
                        plt.close("all")

                # consensus among arms for each pair of arms 
                c_agreement = np.mean(c_agreement, axis=0)
                agreement = c_agreement[pruning_mask]
                if (np.min(agreement) <= min_con) and pr < max_prun_it:
                    if pr > 0:
                        ind_min = pruning_mask[np.argmin(agreement)]
                        ind_min = np.array([ind_min])
                        ind = np.concatenate((ind, ind_min))
                    else:
                        ind_min = pruning_mask[np.argmin(agreement)]
                        if len(prune_indx) > 0:
                            ind_min = np.array([ind_min])
                            ind = np.concatenate((prune_indx, ind_min))
                        else:
                            ind.append(ind_min)
                        ind = np.array(ind)

                    ind = ind.astype(int)
                    bias_mask[ind] = 0.
                    for m in self.modalities:
                        weight_mask[m][ind, :] = 0.
                        fc_mu[m][:, self.lowD_dim[m] + ind] = 0.
                        fc_sigma[m][:, self.lowD_dim[m] + ind] = 0.
                        lowD_mask[m][:, ind] = 0.
                        
                    stop_prune = False
                else:
                    print('No more pruning!')
                    stop_prune = True

            # continue the training with pruning
            if not stop_prune:
                print("Continue training with pruning ...")
                print(f"Pruned categories: {ind}")
                bias = bias_mask.detach().cpu().numpy()
                pruning_mask = np.where(bias != 0.)[0]
                pruning_train_df = pd.DataFrame(np.zeros((n_epoch_p, len(df_keys))), columns=df_keys)
                pruning_validation_df = pd.DataFrame(np.zeros((n_epoch_p, len(df_keys))), columns=df_keys)
                # prune the model based on the consensus
                for m in self.modalities:
                    for arm in range(self.n_arm[m]):
                        if ws > 1:
                            dist.barrier()
                            # Average the weights across all processes
                            dist.all_reduce(self.model.qc[m][arm].weight, op=dist.ReduceOp.AVG)
                            dist.all_reduce(self.model.qc[m][arm].bias, op=dist.ReduceOp.AVG)
                            dist.all_reduce(self.model.mu[m][arm].weight, op=dist.ReduceOp.AVG)
                            dist.all_reduce(self.model.sigma[m][arm].weight, op=dist.ReduceOp.AVG)
                            dist.all_reduce(self.model.fc_lowD[m][arm].weight, op=dist.ReduceOp.AVG)
                        
                            with FSDP.summon_full_params(self.model, writeback=True):
                                prune.custom_from_mask(self.model.qc[m][arm], 'weight', mask=weight_mask[m].to(rank))
                                prune.custom_from_mask(self.model.qc[m][arm], 'bias', mask=bias_mask.to(rank))
                                prune.custom_from_mask(self.model.mu[m][arm], 'weight', mask=fc_mu[m].to(rank))
                                prune.custom_from_mask(self.model.sigma[m][arm], 'weight', mask=fc_sigma[m].to(rank))
                                prune.custom_from_mask(self.model.fc_lowD[m][arm], 'weight', mask=lowD_mask[m].to(rank))
                                
                            dist.barrier()  
                        else:
                            prune.custom_from_mask(self.model.qc[m][arm], 'weight', mask=weight_mask[m].to(rank))
                            prune.custom_from_mask(self.model.qc[m][arm], 'bias', mask=bias_mask.to(rank))
                            prune.custom_from_mask(self.model.mu[m][arm], 'weight', mask=fc_mu[m].to(rank))
                            prune.custom_from_mask(self.model.sigma[m][arm], 'weight', mask=fc_sigma[m].to(rank))
                            prune.custom_from_mask(self.model.fc_lowD[m][arm], 'weight', mask=lowD_mask[m].to(rank))
                          
                
                # FSDP.rewrap(self.model) # not sure about this line!
                for epoch in range(n_epoch_p):
                    # training
                    mod_loss_val = {key: 0 for key in self.all_modalities} 
                    jointloss_val = {key: 0 for key in self.all_modalities} 
                    dist_qc = {key: 0 for key in self.all_modalities} 
                    log_dist_qc = {key: 0 for key in self.all_modalities} 
                    entr = {key: 0 for key in self.all_modalities} 
                    loss_rec = {key: 0 for key in self.all_modalities} 
                    for m in self.modalities:
                        loss_rec[m] = np.zeros(self.n_arm[m])
                    var_min = 0
                    loss_val = 0
                    self.model.train()
                    t0 = time.time()
                    for batch_indx, data_block in enumerate(train_loader):
                        d_idx = data_block[-1]
                        train_data = dict.fromkeys(self.modalities)
                        mod_mask = dict.fromkeys(self.all_modalities)
                        for im, m in enumerate(self.modalities):
                            if m == 'M':
                                data = data_block[im].reshape(data_block[im].shape[0], -1, 4, 4)
                            else:
                                data = data_block[im]
                            mod_mask[m] = data_block[im + self.n_modality]
                            train_data[m] = []
                            for arm in range(self.n_arm[m]):
                                data = data.to(rank)
                                if self.aug:
                                    if arm == 0:
                                        mod_mask[m] = torch.concat((mod_mask[m], mod_mask[m]), 0)
                                    if m == 'T':
                                        _, gen_data = self.netA(data, True, noise_std)
                                        train_data[m].append(torch.concat((data, gen_data), 0))
                                    elif m == 'M': 
                                        noise = torch.distributions.Exponential(1/0.5).sample(data.shape).to(rank)  
                                        gen_data = data.clone()
                                        tmp_mask = gen_data > 0
                                        gen_data[tmp_mask] += noise[tmp_mask]
                                        train_data[m].append(torch.concat((data, gen_data), 0))
                                    else:
                                        gen_data = data + torch.randn(data.shape).to(rank) * noise_std
                                        train_data[m].append(torch.concat((data, gen_data), 0))
                                        
                                else:
                                    train_data[m].append(data)
                            
                        for im, (m1, m2) in enumerate(self.mod_pairs):
                            mod_mask[self.cross_mod_pairs[im]] = mod_mask[m1] & mod_mask[m2]
                            
                        if self.ref_prior:
                            prior_c = torch.FloatTensor(c_p[d_idx, :]).to(rank)
                        else:
                            prior_c=None

                        self.optimizer.zero_grad()
                        recon_batch, _, qc, _, c, mu, log_var, _ = self.model(
                                                                                x=train_data, 
                                                                                temp=self.temp, 
                                                                                hard=hard, 
                                                                                variational=variational, 
                                                                                pruning_mask=pruning_mask,
                                                                                )
                        loss, loss_dict, l_rec, loss_joint, neg_entropy, c_distance, c_rep_dist, _, min_var = self.loss(
                                                                                                                        recon_x=recon_batch, 
                                                                                                                        x=train_data, 
                                                                                                                        mu=mu, 
                                                                                                                        log_sigma=log_var, 
                                                                                                                        qc=qc, 
                                                                                                                        c=c, 
                                                                                                                        lam=lam,
                                                                                                                        beta=beta,
                                                                                                                        prior_c=prior_c, 
                                                                                                                        mask=mod_mask,
                                                                                                                        ref_mod=ref_mod,
                                                                                                                        )
                        loss.backward()
                        self.optimizer.step()
                        var_min += min_var.data.item()
                        loss_val += loss.data.item()
                        
                        for m in self.modalities:
                            for arm in range(self.n_arm[m]):
                                loss_rec[m][arm] += l_rec[m][arm].data.item() / self.input_dim[m]
                        
                        for m in self.all_modalities:
                            mod_loss_val[m] += safe_item(loss_dict[m])
                            jointloss_val[m] += safe_item(loss_joint[m])
                            log_dist_qc[m] += safe_item(c_distance[m])
                            dist_qc[m] += safe_item(c_rep_dist[m])
                            entr[m] += safe_item(neg_entropy[m])

                    pruning_train_df.loc[epoch, 'total_loss'] = loss_val / (batch_indx + 1)
                    pruning_train_df.loc[epoch, 'minVar'] = var_min / (batch_indx + 1)
                    for m in self.modalities:
                        pruning_train_df.loc[epoch, f'loss_joint_{m}'] = jointloss_val[m] / (batch_indx + 1)
                        pruning_train_df.loc[epoch, f'entropy_{m}'] = entr[m] / (batch_indx + 1)
                        pruning_train_df.loc[epoch, f'distance_{m}'] = dist_qc[m] / (batch_indx + 1)
                        pruning_train_df.loc[epoch, f'log_distance_{m}'] = log_dist_qc[m] / (batch_indx + 1)
                        for arm in range(self.n_arm[m]):
                            pruning_train_df.loc[epoch, f'recon_error_arm{m}_{arm}'] = loss_rec[m][arm] / (batch_indx + 1)
                        
                    for m in self.cross_mod_pairs:
                        pruning_train_df.loc[epoch, f'loss_joint_{m}'] = jointloss_val[m] / (batch_indx + 1)
                        pruning_train_df.loc[epoch, f'entropy_{m}'] = entr[m] / (batch_indx + 1)
                        pruning_train_df.loc[epoch, f'distance_{m}'] = dist_qc[m] / (batch_indx + 1)
                        pruning_train_df.loc[epoch, f'log_distance_{m}'] = log_dist_qc[m] / (batch_indx + 1)
                
                    
                    print('Training with pruning')
                    print('====> Pruning Epoch:{}, Elapsed Time:{:.2f}'.format(epoch, time.time() - t0))
                    print('====> Total Loss: {:.4f}, minVar: {:.4f}'.format(pruning_train_df['total_loss'][epoch], pruning_train_df['minVar'][epoch]))
                    for m in self.modalities:
                        val_1 = pruning_train_df[f'recon_error_arm{m}_0'][epoch]
                        val_2 = pruning_train_df[f'distance_{m}'][epoch]
                        val_3 = pruning_train_df[f'log_distance_{m}'][epoch]
                        val_4 = pruning_train_df[f'entropy_{m}'][epoch]
                        print('====> {}, Rec Error:{:.4f}, Dist:{:.4f}, Log Dist:{:.4f}, H:{:.4f}'.format(m, val_1, val_2, val_3, val_4))
                    for m in self.cross_mod_pairs:
                        val_1 = pruning_train_df[f'distance_{m}'][epoch]
                        val_2 = pruning_train_df[f'log_distance_{m}'][epoch]
                        val_3 = pruning_train_df[f'entropy_{m}'][epoch]
                        print('====> {}, Dist:{:.4f}, Log Dist:{:.4f}, H:{:.4f}'.format(m, val_1, val_2, val_3))

                    if wandb_run:
                        wandb_run.log(
                                    {
                                    "train/total-loss": pruning_train_df['total_loss'][epoch],
                                    "train/min-var": pruning_train_df['minVar'][epoch],
                                    **dict(map(lambda m: (f"train/rec-loss-{m}", pruning_train_df[f'recon_error_arm{m}_0'][epoch]), self.modalities)),
                                    **dict(map(lambda m: (f"train/distance-{m}", pruning_train_df[f'distance_{m}'][epoch]), self.all_modalities)),
                                    **dict(map(lambda m: (f"train/log_distance-{m}", pruning_train_df[f'log_distance_{m}'][epoch]), self.all_modalities)),
                                    "train/time": time.time() - t0,
                                    }
                                    )
                    # validation step
                    self.model.eval()
                    with torch.no_grad():
                        mod_loss_val = {key: 0 for key in self.all_modalities} 
                        jointloss_val = {key: 0 for key in self.all_modalities} 
                        dist_qc = {key: 0 for key in self.all_modalities} 
                        log_dist_qc = {key: 0 for key in self.all_modalities} 
                        entr = {key: 0 for key in self.all_modalities} 
                        loss_rec = dict.fromkeys(self.modalities)
                        for m in self.modalities:
                            loss_rec[m] = np.zeros(self.n_arm[m])
                            
                        var_min = 0
                        loss_val = 0
                        for batch_indx, data_block in enumerate(test_loader):
                            d_idx = data_block[-1]
                            val_data = dict.fromkeys(self.modalities)
                            mod_mask = dict.fromkeys(self.all_modalities)
                            for im, m in enumerate(self.modalities):
                                if m == 'M':
                                    data = data_block[im].reshape(data_block[im].shape[0], -1, 4, 4)
                                else:
                                    data = data_block[im]
                                mod_mask[m] = data_block[im + self.n_modality]
                                val_data[m] = []
                                for arm in range(self.n_arm[m]):
                                    val_data[m].append(data.to(rank))
                            for im, (m1, m2) in enumerate(self.mod_pairs):
                                mod_mask[self.cross_mod_pairs[im]] = mod_mask[m1] & mod_mask[m2]
                                
                            if self.ref_prior:
                                prior_c = torch.FloatTensor(c_p[d_idx, :]).to(rank)
                            else:
                                prior_c=None
                                
                            recon_batch, _, qc, _, c, mu, log_var, _ = self.model(
                                                                                    x=val_data, 
                                                                                    temp=self.temp, 
                                                                                    hard=hard, 
                                                                                    variational=variational, 
                                                                                    eval=True, 
                                                                                    pruning_mask=pruning_mask,
                                                                                    )
                            loss, loss_dict, l_rec, loss_joint, neg_entropy, c_distance, c_rep_dist, _, min_var = self.loss(
                                                                                                                            recon_x=recon_batch, 
                                                                                                                            x=val_data, 
                                                                                                                            mu=mu, 
                                                                                                                            log_sigma=log_var, 
                                                                                                                            qc=qc, 
                                                                                                                            c=c, 
                                                                                                                            lam=lam,
                                                                                                                            beta=beta,
                                                                                                                            prior_c=prior_c, 
                                                                                                                            mask=mod_mask,
                                                                                                                            ref_mod=ref_mod,
                                                                                                                            )
                            var_min += min_var.data.item()
                            loss_val += loss.data.item()
                            
                            for m in self.modalities:
                                for arm in range(self.n_arm[m]):
                                    loss_rec[m][arm] += l_rec[m][arm].data.item() / self.input_dim[m]
                            
                            for m in self.all_modalities:
                                mod_loss_val[m] += safe_item(loss_dict[m])
                                jointloss_val[m] += safe_item(loss_joint[m])
                                log_dist_qc[m] += safe_item(c_distance[m])
                                dist_qc[m] += safe_item(c_rep_dist[m])
                                entr[m] += safe_item(neg_entropy[m])
                                

                        pruning_validation_df.loc[epoch, 'total_loss'] = loss_val / (batch_indx + 1)
                        pruning_validation_df.loc[epoch, 'minVar'] = var_min / (batch_indx + 1)
                        for m in self.modalities:
                            pruning_validation_df.loc[epoch, f'loss_joint_{m}'] = jointloss_val[m] / (batch_indx + 1)
                            pruning_validation_df.loc[epoch, f'entropy_{m}'] = entr[m] / (batch_indx + 1)
                            pruning_validation_df.loc[epoch, f'distance_{m}'] = dist_qc[m] / (batch_indx + 1)
                            pruning_validation_df.loc[epoch, f'log_distance_{m}'] = log_dist_qc[m] / (batch_indx + 1)
                            for arm in range(self.n_arm[m]):
                                pruning_validation_df.loc[epoch, f'recon_error_arm{m}_{arm}'] = loss_rec[m][arm] / (batch_indx + 1)
                            
                        for m in self.cross_mod_pairs:
                            pruning_validation_df.loc[epoch, f'loss_joint_{m}'] = jointloss_val[m] / (batch_indx + 1)
                            pruning_validation_df.loc[epoch, f'entropy_{m}'] = entr[m] / (batch_indx + 1)
                            pruning_validation_df.loc[epoch, f'distance_{m}'] = dist_qc[m] / (batch_indx + 1)
                            pruning_validation_df.loc[epoch, f'log_distance_{m}'] = log_dist_qc[m] / (batch_indx + 1)
                        
                        print('Validation')
                        for m in self.modalities:
                            val_1 = pruning_validation_df[f'recon_error_arm{m}_0'][epoch]
                            val_2 = pruning_validation_df[f'distance_{m}'][epoch]
                            val_3 = pruning_validation_df[f'log_distance_{m}'][epoch]
                            val_4 = pruning_validation_df[f'entropy_{m}'][epoch]
                            print('====> {}, Rec Error:{:.4f}, Dist:{:.4f}, Log Dist:{:.4f}, H:{:.4f}'.format(m, val_1, val_2, val_3, val_4))
                        
                        if wandb_run:
                            wandb_run.log(
                                        {
                                        "validate/total-loss": pruning_validation_df['total_loss'][epoch],
                                        "validate/min-var": pruning_validation_df['minVar'][epoch],
                                        **dict(map(lambda m: (f"validate/rec-loss-{m}", pruning_validation_df[f'recon_error_arm{m}_0'][epoch]), self.modalities)),
                                        **dict(map(lambda m: (f"validate/distance-{m}", pruning_validation_df[f'distance_{m}'][epoch]), self.modalities)),
                                        **dict(map(lambda m: (f"validate/log_distance-{m}", pruning_validation_df[f'log_distance_{m}'][epoch]), self.modalities)),
                                        "validate/time": time.time() - t0,
                                        }
                                        )

                for m in self.modalities:
                    for arm in range(self.n_arm[m]):
                        if ws > 1:
                            dist.barrier()
                            # Average the weights across all processes
                            dist.all_reduce(self.model.qc[m][arm].weight, op=dist.ReduceOp.AVG)
                            dist.all_reduce(self.model.qc[m][arm].bias, op=dist.ReduceOp.AVG)
                            dist.all_reduce(self.model.mu[m][arm].weight, op=dist.ReduceOp.AVG)
                            dist.all_reduce(self.model.sigma[m][arm].weight, op=dist.ReduceOp.AVG)
                            dist.all_reduce(self.model.fc_lowD[m][arm].weight, op=dist.ReduceOp.AVG)
                           
                            with FSDP.summon_full_params(self.model, writeback=True):
                                prune.remove(self.model.qc[m][arm], 'weight')
                                prune.remove(self.model.qc[m][arm], 'bias')
                                prune.remove(self.model.mu[m][arm], 'weight')
                                prune.remove(self.model.sigma[m][arm], 'weight')
                                prune.remove(self.model.fc_lowD[m][arm], 'weight')
                            
                            dist.barrier()  
                        else:
                            prune.remove(self.model.qc[m][arm], 'weight')
                            prune.remove(self.model.qc[m][arm], 'bias')
                            prune.remove(self.model.mu[m][arm], 'weight')
                            prune.remove(self.model.sigma[m][arm], 'weight')
                            prune.remove(self.model.fc_lowD[m][arm], 'weight')
                           
                pr += 1
                # save the model and the learning curve
                trained_model = self.folder + f'/model/cpl_mixVAE_model_after_pruning_{pr}_{self.current_time}.pth'
                if ws > 1:
                    dist.barrier()
                    with FSDP.state_dict_type(self.model, StateDictType.FULL_STATE_DICT, save_policy):
                        cpu_state = self.model.state_dict()
                    if rank == 0:
                        torch.save(cpu_state, trained_model)
                    dist.barrier()
                else:
                    if self.compress:
                        trained_model += '.gz'
                        save_model_with_gzip(self.model.state_dict(), trained_model, compression_level=3)
                    else:
                        torch.save({'model_state_dict': self.model.state_dict()}, trained_model, pickle_protocol=4)
                            
                pruning_train_df.to_csv(self.folder + f'/pruning_train_df_{pr}_{self.current_time}.txt', sep='\t', index=False)
                pruning_validation_df.to_csv(self.folder + f'/pruning_validation_df_{pr}_{self.current_time}.txt', sep='\t', index=False)
                    
            
        print('Training is done!')
        return trained_model
        

    @torch.no_grad()
    def eval_model(self, data_loader, hard=False, temp=1., c_p=None, ref_prior=False, ref_mod=''):
        """
        Run the cpl-mixVAE in evaluation mode over a data loader and collect its outputs.

        The model is switched to eval mode. Every batch is passed through the model
        and through ``self.loss``. Per-arm latent variables, categorical
        probabilities/samples and reconstruction losses are gathered and
        concatenated over the whole loader.

        input args
            data_loader: input data loader. Each batch is indexed as: one data
                tensor per entry of ``self.modalities`` (in order), then one mask
                per modality (same order), and the sample indices as the last item.
            hard: forwarded to the model as ``hard``.
            temp: forwarded to the model as ``temp``.
            c_p: the prior categorical variable, indexed by sample index; required
                when ref_prior is True.
            ref_prior: if True, the rows of ``c_p`` for each batch are passed to the
                loss as ``prior_c``. The value is also stored on ``self.ref_prior``.
            ref_mod: forwarded to the loss as ``ref_mod``.

        return
            d_dict: the output dictionary with keys
                'state_sample', 'state_mu', 'state_var': per modality, per arm,
                    the model's ``state``, ``mu`` and ``log_var`` outputs,
                    concatenated over batches.
                'z_prob', 'z_sample': per modality, per arm, the model's ``qc`` and
                    ``c`` outputs viewed as (n_samples, n_categories).
                'predicted_label': per modality, per arm, argmax of ``qc`` plus 1
                    (1-based category labels).
                'prob_cat': per modality, per arm, the max of ``qc`` per sample.
                'total_loss_rec': per modality, an array with the per-arm mean over
                    batches of the reconstruction loss divided by ``input_dim``.
                'total_dist_z', 'total_dist_qz': per entry of
                    ``self.all_modalities``, the batch mean of the loss's
                    ``c_distance`` and ``c_rep_dist`` terms respectively.
                'data_indx': the sample indices, in loader order.
                'prune_indx': category indices whose ``qc`` bias (first modality,
                    arm 0) is exactly zero, i.e. categories treated as pruned.

        raises
            ValueError: if ref_prior is True but c_p is None, or if data_loader
                yields no batches.
        """
        if ref_prior and c_p is None:
            raise ValueError('c_p must be provided when ref_prior is True.')

        # A category counts as pruned when the qc bias of the first modality's
        # first arm is exactly zero; the remaining indices are the mask given
        # to the model's forward pass.
        bias = self.model.qc[self.modalities[0]][0].bias.detach().cpu().numpy()
        pruning_mask = np.where(bias != 0.)[0]
        prune_indx = np.where(bias == 0.)[0]
        self.ref_prior = ref_prior

        def _per_arm():
            # One empty accumulator list per arm, for every modality.
            return {m: [[] for _ in range(self.n_arm[m])] for m in self.modalities}

        recon_loss = _per_arm()
        state_sample = _per_arm()
        state_mu = _per_arm()
        state_var = _per_arm()
        c_prob = _per_arm()
        c_sample = _per_arm()
        predicted_label = _per_arm()
        prob_cat = _per_arm()

        dist_qc = {m: [] for m in self.all_modalities}
        dist_c = {m: [] for m in self.all_modalities}
        # Unit weight for every modality / cross-modality term in the loss.
        lam = {m: 1 for m in self.all_modalities}

        samp_id = []
        n_modality = len(self.modalities)
        # evaluation of the model
        self.model.eval()
        for data_block in data_loader:
            d_idx = data_block[-1]
            eval_data = dict.fromkeys(self.modalities)
            mod_mask = dict.fromkeys(self.all_modalities)
            for im, m in enumerate(self.modalities):
                if m == 'M':
                    # NOTE(review): modality 'M' is reshaped to (batch, -1, 4, 4);
                    # the meaning of this layout is not visible here.
                    data = data_block[im].reshape(data_block[im].shape[0], -1, 4, 4)
                else:
                    data = data_block[im]
                mod_mask[m] = data_block[im + n_modality]
                # Every arm of a modality receives the same input tensor.
                eval_data[m] = [data for _ in range(self.n_arm[m])]
            # A cross-modality pair is present only where both modalities are.
            for im, (m1, m2) in enumerate(self.mod_pairs):
                mod_mask[self.cross_mod_pairs[im]] = mod_mask[m1] & mod_mask[m2]

            if ref_prior:
                # NOTE(review): prior_c stays on the CPU; confirm the loss does
                # not need it on self.device.
                prior_c = torch.FloatTensor(c_p[d_idx, :])
            else:
                prior_c = None

            # NOTE(review): the validation loop in the training code passes
            # eval=True to the model, but this call does not. Confirm this is intended.
            recon_batch, _, qc, state, c, mu, log_var, _ = self.model(
                                                                        x=eval_data,
                                                                        temp=temp,
                                                                        hard=hard,
                                                                        variational=self.variational,
                                                                        pruning_mask=pruning_mask,
                                                                        )
            _, _, l_rec, _, _, c_distance, c_rep_dist, _, _ = self.loss(
                                                                        recon_x=recon_batch,
                                                                        x=eval_data,
                                                                        mu=mu,
                                                                        log_sigma=log_var,
                                                                        qc=qc,
                                                                        c=c,
                                                                        lam=lam,
                                                                        beta=1,
                                                                        prior_c=prior_c,
                                                                        mask=mod_mask,
                                                                        ref_mod=ref_mod,
                                                                        )

            for m in self.modalities:
                for arm in range(self.n_arm[m]):
                    recon_loss[m][arm].append(l_rec[m][arm].data.item() / self.input_dim[m])
                    state_sample[m][arm].append(state[m][arm].cpu().detach().numpy())
                    state_mu[m][arm].append(mu[m][arm].cpu().detach().numpy())
                    state_var[m][arm].append(log_var[m][arm].cpu().detach().numpy())
                    c_encoder = qc[m][arm].cpu().data.view(qc[m][arm].size()[0], self.n_categories).detach().numpy()
                    c_prob[m][arm].append(c_encoder)
                    c_samp = c[m][arm].cpu().data.view(c[m][arm].size()[0], self.n_categories).detach().numpy()
                    c_sample[m][arm].append(c_samp)
                    # Labels are 1-based.
                    predicted_label[m][arm].append(np.argmax(c_encoder, axis=1) + 1)
                    prob_cat[m][arm].append(np.max(c_encoder, axis=1))

            for m in self.all_modalities:
                dist_qc[m].append(safe_item(c_rep_dist[m]))
                dist_c[m].append(safe_item(c_distance[m]))

            samp_id.append(d_idx.numpy().astype(int))

        if not samp_id:
            raise ValueError('data_loader yielded no batches; nothing to evaluate.')

        samp_id = np.concatenate(samp_id, axis=0)
        total_dist_c = dict()
        total_dist_qc = dict()
        mean_total_loss_rec = dict()
        for m in self.modalities:
            mean_total_loss_rec[m] = np.zeros(self.n_arm[m])
            for arm in range(self.n_arm[m]):
                state_sample[m][arm] = np.concatenate(state_sample[m][arm], axis=0)
                state_mu[m][arm] = np.concatenate(state_mu[m][arm], axis=0)
                state_var[m][arm] = np.concatenate(state_var[m][arm], axis=0)
                c_prob[m][arm] = np.concatenate(c_prob[m][arm], axis=0)
                c_sample[m][arm] = np.concatenate(c_sample[m][arm], axis=0)
                predicted_label[m][arm] = np.concatenate(predicted_label[m][arm], axis=0)
                prob_cat[m][arm] = np.concatenate(prob_cat[m][arm], axis=0)
                mean_total_loss_rec[m][arm] = np.mean(np.array(recon_loss[m][arm]))

        for m in self.all_modalities:
            total_dist_c[m] = np.mean(np.array(dist_c[m]))
            total_dist_qc[m] = np.mean(np.array(dist_qc[m]))

        # save the output in a dictionary
        d_dict = dict()
        d_dict['state_sample'] = state_sample
        d_dict['state_mu'] = state_mu
        d_dict['state_var'] = state_var
        d_dict['prob_cat'] = prob_cat
        d_dict['total_loss_rec'] = mean_total_loss_rec
        d_dict['total_dist_z'] = total_dist_c
        d_dict['total_dist_qz'] = total_dist_qc
        d_dict['predicted_label'] = predicted_label
        d_dict['data_indx'] = samp_id
        d_dict['z_prob'] = c_prob
        d_dict['z_sample'] = c_sample
        d_dict['prune_indx'] = prune_indx

        return d_dict
    


    def save_file(self, fname, **kwargs):
        """
        Save keyword arguments as a dictionary in a pickle (.p) file.

        input args
            fname: output path without extension; '.p' is appended to it.
            kwargs: keyword arguments for the variables to save, e.g. x=[], y=[].
                They are stored as a dict keyed by argument name, using pickle
                protocol 4.
        """
        # The context manager closes the file even if pickling raises.
        with open(fname + '.p', "wb") as f:
            pickle.dump(dict(kwargs), f, protocol=4)

    def load_file(self, fname):
        """
        Load data from a pickle (.p) file. Make sure to use a pickle version
        compatible with the one used to save the file.

        input args
            fname: path of the file without extension; '.p' is appended to it.

        return
            data: the unpickled object (a dictionary when written by save_file).

        Warning: unpickling can execute arbitrary code. Only load files from
        trusted sources.
        """
        # The context manager closes the file handle after loading.
        with open(fname + '.p', "rb") as f:
            return pickle.load(f)
