import torch
import pickle
import torch.nn as nn
import numpy as np
from torch.autograd import Variable
from torch.nn import functional as F
import torch.nn.utils.prune as prune
import time
from torch.utils.data import DataLoader, TensorDataset, Dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score
import matplotlib.pyplot as plt

from mmids.RNA_ATAC_networks import RNA_ATAC_mixVAE
    

class train_cplmixVAE:

    def __init__(self, saving_folder='', device=None, eps=1e-8, save_flag=True):
        """
        Store the run settings and select the compute device.

        input args
            saving_folder: folder prefix used for every file written by this object.
            device: CUDA device handed to torch.cuda.set_device; None runs on the CPU.
            eps: small constant forwarded to the model for numerical stability.
            save_flag: whether run() writes checkpoints, learning curves and the result file.
        """
        self.eps = eps
        self.save = save_flag
        self.folder = saving_folder
        self.device = device
        # A device other than None always means "use the GPU".
        self.gpu = device is not None

        if self.gpu:
            torch.cuda.set_device(device)
            print('using GPU ' + torch.cuda.get_device_name(torch.cuda.current_device()))
        else:
            print('using CPU ...')


    def data_gen(self, dataset, train_size, seed):
        """
        Draw train / validation / test row indices for a dataset.

        The test set receives ``dataset.shape[0] - train_size`` rows. A validation set of the same
        size is then carved out of the remaining ``train_size`` rows, which leaves
        ``train_size - test_size`` rows for training.

        input args
            dataset: array-like whose first dimension enumerates the samples; only its length is used.
            train_size: number of rows that are not assigned to the test set.
            seed: random state used for both splits.

        return
            train_ind, val_ind, test_ind: index arrays into the rows of ``dataset``.
        """
        n_samples = dataset.shape[0]
        test_size = n_samples - train_size

        # Only the indices are needed, so split them directly instead of also splitting (and
        # copying) the data matrix. The partition depends only on the number of samples and the
        # seed, hence the indices are the same as when the data is split alongside.
        train_ind, test_ind = train_test_split(np.arange(n_samples), train_size=train_size,
                                               test_size=test_size, random_state=seed)
        train_ind, val_ind = train_test_split(train_ind, train_size=train_size - test_size,
                                              test_size=test_size, random_state=seed)

        return train_ind, val_ind, test_ind

    def getdata(self, dataset_1, dataset_2, batch_size=128, train_size=0.9, seed=0):
        """
        Split two paired datasets and wrap every split in a dataloader.

        Row i of ``dataset_1`` and row i of ``dataset_2`` are treated as the same sample: the split
        is drawn once (from ``dataset_1``) and applied to both arrays. Every batch is a triple
        (rows of dataset_1, rows of dataset_2, row indices), where the indices refer to the rows of
        the full arrays and are stored as a float tensor.

        input args
            dataset_1: 2D array of the first modality, samples along the first axis.
            dataset_2: 2D array of the second modality, indexed with the same rows as dataset_1.
            batch_size: batch size of all dataloaders; also stored in self.batch_size.
            train_size: fraction of the samples that is not assigned to the test set.
            seed: random state of the split.

        return
            alldata_loader: all samples in their original order (not shuffled).
            train_loader: shuffled training batches; an incomplete last batch is dropped.
            validation_loader: shuffled validation batches.
            test_loader: shuffled test batches.
            train_ind, val_ind, test_ind: the row indices of each split.
        """
        self.batch_size = batch_size
        tt_size = int(train_size * dataset_1.shape[0])
        train_ind, val_ind, test_ind = self.data_gen(dataset_1, tt_size, seed)

        # Page-locked host memory only speeds up host-to-GPU copies; requesting it on a CPU-only
        # run has no benefit.
        pin = self.gpu

        def make_loader(indices, shuffle, drop_last):
            # indices=None selects every row in its original order.
            if indices is None:
                part_1, part_2 = dataset_1, dataset_2
                index_tensor = torch.FloatTensor(range(dataset_1.shape[0]))
            else:
                part_1, part_2 = dataset_1[indices, :], dataset_2[indices, :]
                index_tensor = torch.FloatTensor(indices)

            # NOTE: the indices are kept as floats for backward compatibility; run() casts them
            # back to integers.
            subset = TensorDataset(torch.FloatTensor(part_1), torch.FloatTensor(part_2), index_tensor)
            return DataLoader(subset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, pin_memory=pin)

        train_loader = make_loader(train_ind, shuffle=True, drop_last=True)
        validation_loader = make_loader(val_ind, shuffle=True, drop_last=False)
        test_loader = make_loader(test_ind, shuffle=True, drop_last=False)
        alldata_loader = make_loader(None, shuffle=False, drop_last=False)

        return alldata_loader, train_loader, validation_loader, test_loader, train_ind, val_ind, test_ind

    def init_model(self, n_categories, state_dim, input_dim, fc_dim, lowD_dim, x_drop, lr, lam, beta, noise_std, n_layers, temp=1.,
                   tau=0.01, hard=False, state_det=False, trained_model='', n_modal=2):
        """
        Initialized the deep mixture model and its optimizer.

        input args:
            n_categories: number of categories of the discrete latent variable.
            state_dim: dimension of the continuous state variable; run() indexes it by modality key.
            input_dim: input dimension(s), forwarded to the network unchanged.
            fc_dim: dimension of the hidden layer.
            lowD_dim: dimension of the latent representation; run() indexes it by modality key.
            x_drop: dropout probability at the first (input) layer.
            lr: the learning rate of the optimizer, here Adam.
            lam: coupling factor in the cpl-mixVAE model.
            beta: regularizer for the KL divergence term.
            noise_std: noise standard deviation, forwarded to the network as n_std.
            n_layers: number of layers, forwarded to the network as n_layer.
            temp: temperature passed to the model on every forward call.
            tau: temperature of the softmax layers, usually equals to 1/n_categories (0 < tau <= 1).
            hard: a boolean variable, True uses one-hot method that is used in Gumbel-softmax, and False uses the Gumbel-softmax function.
            state_det: a boolean variable, False uses sampling.
            trained_model: the path of a pre-trained model, in case you wish to initialized the network with a pre-trained network.
            n_modal: number of modalities (stored only).

        The network is always created with momentum=0.01 and affine=False; these two settings are
        not exposed as arguments.
        """
        self.lowD_dim = lowD_dim
        self.n_categories = n_categories
        self.state_dim = state_dim
        self.input_dim = input_dim
        self.temp = temp
        self.n_modal = n_modal
        self.noise_std = noise_std
        self.model = RNA_ATAC_mixVAE(input_dim=self.input_dim, fc_dim=fc_dim, n_categories=self.n_categories, state_dim=self.state_dim,
                                lowD_dim=lowD_dim, x_drop=x_drop, lam=lam, tau=tau, beta=beta, hard=hard,
                                state_det=state_det, n_std=noise_std, device=self.device, eps=self.eps,
                                n_layer=n_layers, momentum=0.01, affine=False)

        # Move the model first: torch.optim asks for models to be on their final device before an
        # optimizer is constructed for them.
        if self.gpu:
            self.model = self.model.cuda(self.device)

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

        if trained_model:
            print('Load the pre-trained model')
            # The checkpoint is read onto the CPU; load_state_dict copies it into the model and
            # optimizer, wherever they live.
            loaded_file = torch.load(trained_model, map_location='cpu')
            self.model.load_state_dict(loaded_file['model_state_dict'])
            self.optimizer.load_state_dict(loaded_file['optimizer_state_dict'])

    def load_model(self, trained_model):
        """
        Replace the weights of the current model with those stored in a checkpoint.

        input args
            trained_model: path of a checkpoint holding a 'model_state_dict' entry.

        The optimizer is left untouched and self.current_time is reset to the current time.
        """
        checkpoint = torch.load(trained_model, map_location='cpu')
        weights = checkpoint['model_state_dict']
        self.model.load_state_dict(weights)

        stamp_format = '%Y-%m-%d-%H-%M-%S'
        self.current_time = time.strftime(stamp_format)


    def run(self, train_loader, validation_loader, test_loader, alldata_loader, mask, modalities, n_epoch=0, n_epoch_p=0, min_con=.5, max_pron_it=30, min_density=20):
        """
        run the training of the cpl-mixVAE with the pre-defined parameters/settings

        The method goes through three stages:
            1. n_epoch epochs of plain training, each followed by a validation pass.
            2. if n_epoch_p > 0, a pruning loop: the categories assigned by the two arms are
               compared over alldata_loader; while the lowest per-category consensus among the
               remaining categories is <= min_con (and fewer than max_pron_it rounds were done),
               that category is masked out and the model is retrained for n_epoch_p epochs.
            3. evaluation over test_loader and alldata_loader; the collected outputs are pickled
               when saving is enabled.

        input args
            train_loader: train dataloader, yielding (data_1, data_2, index) batches.
            validation_loader: validation dataloader with the same batch layout.
            test_loader: test dataloader with the same batch layout.
            alldata_loader: dataloader over all samples with the same batch layout.
            mask: stored in self.mask; not used otherwise by this method.
            modalities: sequence of the two modality keys; they key the model inputs/outputs and
                        index self.lowD_dim and self.state_dim.
            n_epoch: number of training epoch, without pruning
            n_epoch_p: number of training epoch in every pruning round; 0 disables pruning
            min_con: minimum value of consensus among pairs of arms
            max_pron_it: maximum number of pruning rounds
            min_density: unused; kept so existing callers keep working

        return
            data_file_id: the path of the output dictionary (without the '.p' extension).
        """
        # define current_time
        self.current_time = time.strftime('%Y-%m-%d-%H-%M-%S')
        self.mask = mask
        model_dir = self.folder + '/model/'
        arm_keys = (modalities[0], modalities[1])

        # Pruning masks. An entry is set to zero once its category is pruned; the bias mask is
        # shared by both arms.
        bias_mask = torch.ones(self.n_categories)
        arm_masks = []
        for mod in arm_keys:
            lowD, state = self.lowD_dim[mod], self.state_dim[mod]
            arm_masks.append({
                'qc': torch.ones((self.n_categories, lowD)),
                'mu': torch.ones((state, self.n_categories + lowD)),
                'sigma': torch.ones((state, self.n_categories + lowD)),
                'decoder': torch.ones((lowD, state + self.n_categories)),
            })

        if self.gpu:
            bias_mask = bias_mask.cuda(self.device)
            arm_masks = [{key: m.cuda(self.device) for key, m in masks.items()} for masks in arm_masks]

        print("Start training...")
        train_loss, validation_loss = self._run_epochs(train_loader, validation_loader, modalities, n_epoch)

        if self.save and n_epoch > 0:
            trained_model = model_dir + 'cpl_mixVAE_model_before_pruning_' + self.current_time + '.pth'
            torch.save({'model_state_dict': self.model.state_dict(), 'optimizer_state_dict': self.optimizer.state_dict()}, trained_model)
            self._save_learning_curve(train_loss, model_dir + 'learning_curve_before_pruning_K_' + str(self.n_categories) + '_' + self.current_time + '.png')

        if n_epoch_p > 0:
            # initialized pruning parameters of the layer of the discrete variable
            bias = self.model.qc_1.bias.detach().cpu().numpy()
            pruning_mask = np.where(bias != 0.)[0]
            prune_indx = np.where(bias == 0.)[0]

        pr = 0
        ind = []
        # The loop is left through the break below once no category qualifies for pruning.
        while n_epoch_p > 0:
            # Assessment over all dataset
            predicted_label = {mod: [] for mod in modalities}
            self.model.eval()
            with torch.no_grad():
                for data_1, data_2, _ in alldata_loader:
                    a_data = self._make_batch(data_1, data_2, modalities)

                    # NOTE(review): unlike the final evaluation, this call passes neither eval=True
                    # nor pruning_mask - confirm that this is intended.
                    recon, x_low, z_category, state, z_smp, mu, log_sigma, _ = self.model(x=a_data, temp=self.temp)
                    # The loss values are not used here; the call is kept because the loss
                    # implementation is not visible from this file.
                    self.model.loss(recon, a_data, mu, log_sigma, z_category, z_smp)

                    for mod in modalities:
                        z_encoder = z_category[mod].cpu().data.view(z_category[mod].size()[0], self.n_categories).detach().numpy()
                        # zero-based category with the highest probability for every sample
                        predicted_label[mod].append(np.argmax(z_encoder, axis=1))

            torch.cuda.empty_cache()

            pred_1 = np.concatenate(predicted_label[modalities[0]])
            pred_2 = np.concatenate(predicted_label[modalities[1]])
            m1_vs_m2 = self._consensus_matrix(pred_1, pred_2)
            c_agreement = np.diag(m1_vs_m2)
            self._plot_consensus(m1_vs_m2, modalities, self.folder + f'/consensus_{pr}_arm_{modalities[0]}_arm_{modalities[1]}.png')
            plt.close("all")

            agreement = c_agreement[pruning_mask]
            if not (np.min(agreement) <= min_con and pr < max_pron_it):
                print('No more pruning!')
                break

            # Add the surviving category with the lowest consensus to the pruned ones. In the first
            # round the list starts from the categories that already have a zero bias.
            ind_min = pruning_mask[np.argmin(agreement)]
            already_pruned = ind if pr > 0 else prune_indx
            ind = np.concatenate((already_pruned, [ind_min])).astype(int)
            print(ind)

            bias_mask[ind] = 0.
            for mod, masks in zip(arm_keys, arm_masks):
                lowD = self.lowD_dim[mod]
                masks['qc'][ind, :] = 0.
                masks['mu'][:, lowD + ind] = 0.
                masks['sigma'][:, lowD + ind] = 0.
                masks['decoder'][:, ind] = 0.

            print("Training with pruning...")
            bias = bias_mask.detach().cpu().numpy()
            pruning_mask = np.where(bias != 0.)[0]

            layers = self._pruned_layers(bias_mask, arm_masks)
            for module, name, layer_mask in layers:
                prune.custom_from_mask(module, name, mask=layer_mask)

            train_loss, validation_loss = self._run_epochs(train_loader, validation_loader, modalities, n_epoch_p,
                                                           pruning_mask=pruning_mask)

            if self.save:
                self._save_learning_curve(train_loss, model_dir + 'learning_curve_after_pruning_' + str(pr+1) + '_K_' + str(self.n_categories) + '_' + self.current_time + '.png')

            # Make the pruning permanent before the checkpoint is written.
            for module, name, _ in layers:
                prune.remove(module, name)

            trained_model = model_dir + 'cpl_mixVAE_model_after_pruning_' + str(pr+1) + '_' + self.current_time + '.pth'
            torch.save({'model_state_dict': self.model.state_dict(), 'optimizer_state_dict': self.optimizer.state_dict()}, trained_model)
            pr += 1

        # Evaluate the trained model
        bias = self.model.qc_1.bias.detach().cpu().numpy()
        pruning_mask = np.where(bias != 0.)[0]
        prune_indx = np.where(bias == 0.)[0]

        state_sample = {mod: [] for mod in modalities}
        state_mu = {mod: [] for mod in modalities}
        state_var = {mod: [] for mod in modalities}
        z_prob = {mod: [] for mod in modalities}
        z_sample = {mod: [] for mod in modalities}
        predicted_label = {mod: [] for mod in modalities}
        x_low_all = {mod: [] for mod in modalities}
        test_loss_rec = {mod: 0. for mod in modalities}
        total_dist_z = []
        total_dist_qz = []

        self.model.eval()
        with torch.no_grad():
            n_test = 0
            for n_test, (test_data_1, test_data_2, _) in enumerate(test_loader, start=1):
                test_data = self._make_batch(test_data_1, test_data_2, modalities)

                recon_batch, x_low, qc, s, c, mu, log_var, _ = self.model(x=test_data, temp=self.temp, eval=True, pruning_mask=pruning_mask)
                _, loss_arms, _, _, _, _, _, _ = self.model.loss(recon_batch, test_data, mu, log_var, qc, c)

                for mod in modalities:
                    test_loss_rec[mod] += loss_arms[mod].data.item() / sum(test_data[mod].shape[1:])

            test_loss = {mod: np.mean(test_loss_rec[mod] / max(n_test, 1)) for mod in modalities}

            print('====> Test Loss_1: {:.4f}, Loss_2: {:.4f}'.format(test_loss[modalities[0]], test_loss[modalities[1]]))

            for data_1, data_2, _ in alldata_loader:
                a_data = self._make_batch(data_1, data_2, modalities)

                recon, x_low, z_category, state, z_smp, mu, log_sigma, _ = self.model(x=a_data, temp=self.temp, eval=True, pruning_mask=pruning_mask)
                _, _, _, _, dist_z, d_qz, _, _ = self.model.loss(recon, a_data, mu, log_sigma, z_category, z_smp)
                total_dist_z.append(dist_z.data.item())
                total_dist_qz.append(d_qz.data.item())

                for mod in modalities:
                    state_sample[mod].append(state[mod].cpu().detach().numpy())
                    state_mu[mod].append(mu[mod].cpu().detach().numpy())
                    state_var[mod].append(log_sigma[mod].cpu().detach().numpy())
                    z_encoder = z_category[mod].cpu().data.view(z_category[mod].size()[0], self.n_categories).detach().numpy()
                    z_prob[mod].append(z_encoder)
                    z_samp = z_smp[mod].cpu().data.view(z_smp[mod].size()[0], self.n_categories).detach().numpy()
                    z_sample[mod].append(z_samp)
                    x_low_all[mod].append(x_low[mod].cpu().detach().numpy())
                    # The saved labels are one-based.
                    predicted_label[mod].append(list(np.argmax(z_encoder, axis=1) + 1))

        pred_1 = np.concatenate(predicted_label[modalities[0]])
        pred_2 = np.concatenate(predicted_label[modalities[1]])
        m1_vs_m2 = self._consensus_matrix(pred_1 - 1, pred_2 - 1)
        self._plot_consensus(m1_vs_m2, modalities, self.folder + f'/consensus_all_training_arm_{modalities[0]}_arm_{modalities[1]}.png')

        # save data
        data_file_id = model_dir + 'data_' + self.current_time

        if self.save:
            self.save_file(data_file_id,
                           state_sample=state_sample,
                           state_mu=state_mu,
                           state_var=state_var,
                           train_loss=train_loss,
                           validation_loss=validation_loss,
                           total_dist_z=np.mean(np.array(total_dist_z)),
                           total_dist_qz=np.mean(np.array(total_dist_qz)),
                           mean_test_rec=test_loss,
                           predicted_label=predicted_label,
                           z_prob=z_prob,
                           z_sample=z_sample,
                           lowD_rep=x_low_all,
                           prune_indx=prune_indx)

        return data_file_id

    def _make_batch(self, data_1, data_2, modalities):
        """Move a pair of batch tensors to the GPU if one is used and key them by modality."""
        if self.gpu:
            data_1 = data_1.cuda(self.device)
            data_2 = data_2.cuda(self.device)
        return {modalities[0]: data_1, modalities[1]: data_2}

    def _run_epochs(self, train_loader, validation_loader, modalities, n_epoch, pruning_mask=None):
        """
        Train for n_epoch epochs with a validation pass after each one.

        pruning_mask is forwarded to the model only when it is given. Returns the per-epoch total
        training loss and a dict of per-epoch validation reconstruction losses keyed by modality.
        """
        fwd_kwargs = {} if pruning_mask is None else {'pruning_mask': pruning_mask}

        # initialized saving arrays
        train_loss = np.zeros(n_epoch)
        train_loss_joint = np.zeros(n_epoch)
        train_entropy = np.zeros(n_epoch)
        train_distance = np.zeros(n_epoch)
        train_minVar = np.zeros(n_epoch)
        train_log_distance = np.zeros(n_epoch)
        train_recon = {mod: np.zeros(n_epoch) for mod in modalities}
        validation_loss = {mod: np.zeros(n_epoch) for mod in modalities}

        for epoch in range(n_epoch):
            t0 = time.time()
            loss_sum = 0
            joint_sum = 0
            dqc_sum = 0
            log_dqc_sum = 0
            entropy_sum = 0
            var_min_sum = 0
            train_rec_sum = {mod: 0. for mod in modalities}
            val_rec_sum = {mod: 0. for mod in modalities}

            self.model.train()
            n_train = 0
            for n_train, (data_1, data_2, _) in enumerate(train_loader, start=1):
                train_data = self._make_batch(data_1, data_2, modalities)

                self.optimizer.zero_grad()
                recon_batch, x_low, qc, s, c, mu, log_var, log_qc = self.model(x=train_data, temp=self.temp, **fwd_kwargs)
                loss, loss_rec, loss_joint, entropy, dist_c, d_qc, KLD_cont, min_var_0 = self.model.loss(recon_batch, train_data, mu, log_var, qc, c)
                loss.backward()
                self.optimizer.step()

                loss_sum += loss.data.item()
                joint_sum += loss_joint
                dqc_sum += d_qc
                log_dqc_sum += dist_c
                entropy_sum += entropy
                var_min_sum += min_var_0.data.item()

                for mod in modalities:
                    # reconstruction loss scaled by the summed non-batch dimensions of the input
                    train_rec_sum[mod] += loss_rec[mod].data.item() / sum(train_data[mod].shape[1:])

            # An empty loader leaves the sums at zero; avoid dividing by zero in that case.
            n_train = max(n_train, 1)
            train_loss[epoch] = loss_sum / n_train
            train_loss_joint[epoch] = joint_sum / n_train
            train_distance[epoch] = dqc_sum / n_train
            train_entropy[epoch] = entropy_sum / n_train
            train_log_distance[epoch] = log_dqc_sum / n_train
            train_minVar[epoch] = var_min_sum / n_train

            for mod in modalities:
                train_recon[mod][epoch] = train_rec_sum[mod] / n_train

            print('====> Epoch:{}, Total Loss: {:.4f}, Loss_1: {:.4f}, Loss_2: {:.4f}, Joint Loss: {:.4f}, '
                  'Entropy: {:.4f}, d_logqz: {:.4f}, d_qz: {:.4f}, var_min: {:.4f}, Elapsed Time:{:.2f}'.format(
                epoch, train_loss[epoch], train_recon[modalities[0]][epoch], train_recon[modalities[1]][epoch], train_loss_joint[epoch],
                train_entropy[epoch], train_log_distance[epoch], train_distance[epoch], train_minVar[epoch], time.time() - t0))

            # validation
            self.model.eval()
            n_val = 0
            with torch.no_grad():
                for n_val, (val_data_1, val_data_2, _) in enumerate(validation_loader, start=1):
                    val_data = self._make_batch(val_data_1, val_data_2, modalities)

                    recon_batch, x_low, qc, s, c, mu, log_var, _ = self.model(x=val_data, temp=self.temp, eval=True, **fwd_kwargs)
                    _, loss_arms, _, _, _, _, _, _ = self.model.loss(recon_batch, val_data, mu, log_var, qc, c)

                    for mod in modalities:
                        val_rec_sum[mod] += loss_arms[mod].data.item() / sum(val_data[mod].shape[1:])

            for mod in modalities:
                validation_loss[mod][epoch] = val_rec_sum[mod] / max(n_val, 1)

            print('====> Validation Loss_1: {:.4f}, Loss_2: {:.4f}'.format(validation_loss[modalities[0]][epoch], validation_loss[modalities[1]][epoch]))
            torch.cuda.empty_cache()

        return train_loss, validation_loss

    def _pruned_layers(self, bias_mask, arm_masks):
        """List the (module, parameter name, mask) triples that are pruned, arm by arm."""
        layers = []
        for arm, masks in enumerate(arm_masks, start=1):
            qc = getattr(self.model, f'qc_{arm}')
            layers.extend([
                (qc, 'weight', masks['qc']),
                (qc, 'bias', bias_mask),
                (getattr(self.model, f'mu_{arm}'), 'weight', masks['mu']),
                (getattr(self.model, f'sigma_{arm}'), 'weight', masks['sigma']),
                (getattr(self.model, f'layer_d_lowD_{arm}'), 'weight', masks['decoder']),
            ])
        return layers

    def _consensus_matrix(self, labels_1, labels_2):
        """
        Build the normalized co-assignment matrix of two arms from zero-based category labels.

        Entry (i, j) counts the samples labelled i by the first arm and j by the second one;
        column j is divided by the larger of the j-th row sum and the j-th column sum, and columns
        whose divisor is zero stay zero.
        """
        counts = np.zeros((self.n_categories, self.n_categories))
        np.add.at(counts, (labels_1.astype(int), labels_2.astype(int)), 1)
        arm_size = np.maximum(counts.sum(axis=1), counts.sum(axis=0))
        return np.divide(counts, arm_size, out=np.zeros_like(counts), where=arm_size != 0)

    def _plot_consensus(self, consensus, modalities, fname):
        """Save the consensus matrix as an image, categories sorted by decreasing agreement."""
        order = np.argsort(np.diag(consensus))[::-1]
        plt.figure()
        plt.imshow(consensus[:, order][order], cmap='binary')
        plt.colorbar()
        plt.xlabel(f'arm {modalities[0]}', fontsize=20)
        plt.ylabel(f'arm {modalities[1]}', fontsize=20)
        plt.xticks([])
        plt.yticks([])
        plt.title('|c|=' + str(self.n_categories), fontsize=20)
        plt.savefig(fname, dpi=600)

    def _save_learning_curve(self, losses, fname):
        """Plot the per-epoch training loss and save the figure to fname."""
        fig, ax = plt.subplots()
        ax.plot(range(len(losses)), losses)
        ax.set_xlabel('# epoch', fontsize=16)
        ax.set_ylabel('loss value', fontsize=16)
        ax.set_title('Learning curve of the cpl-mixVAE for K=' + str(self.n_categories))
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.figure.savefig(fname)
        plt.close("all")

     
    def save_file(self, fname, **kwargs):
        """
        Save data as a .p file using pickle.

        input args
            fname: the path of the output file, without the '.p' extension.
            kwarg: keyword arguments for input variables e.g., x=[], y=[], etc.; they are stored
                   as one dictionary keyed by the argument names.
        """
        # The context manager closes the file even if pickling fails half-way.
        with open(fname + '.p', "wb") as f:
            pickle.dump(dict(kwargs), f)

    def load_file(self, fname):
        """
        load data .p file using pickle. Make sure to use the same version of
        pickle used for saving the file

        input args
            fname: the path of the saved file, without the '.p' extension.

        return
            data: a dictionary including the save dataset
        """
        # NOTE: unpickling can execute arbitrary code - only load files from a trusted source.
        with open(fname + '.p', "rb") as f:
            data = pickle.load(f)
        return data