import os
import torch
import torch.nn as nn
import numpy as np
from torch.optim import Optimizer
import torch.distributions as dist
from .utils import Encoder,MIEstimator,ExponentialScheduler

import torch.optim as optimizer_module


from .numpydataset import NumpyDataset
from cca_zoo.deepmodels import (
    DCCAE,
)
from cca_zoo.deepmodels import architectures
from cca_zoo.data.deep import get_dataloaders
import pytorch_lightning as pl
import numpy as np
import pdb

from abc import ABC, abstractmethod
from itertools import chain, combinations
import os

import torch
import torch.nn as nn
from torch.autograd import Variable

from .divergence_measures.mm_div import calc_alphaJSD_modalities
from .divergence_measures.mm_div import calc_group_divergence_moe
from .divergence_measures.mm_div import poe, alpha_poe
from .divergence_measures.kl_div import calc_kl_divergence
from scipy.stats import truncnorm
from tqdm import tqdm
class MyClass:
    """Empty attribute container; instances are used as a mutable ``flags`` object."""


class EncoderImg(nn.Module):
    """
    Fully connected encoder that outputs Gaussian parameters (mu, logvar).

    A shared ``architectures.Encoder`` MLP (one hidden layer of 512 units,
    input feature size fixed at 2352) feeds linear heads for the content
    ("class") latent and, when ``flags.factorized_representation`` is set,
    for the style latent as well.

    Adopted from:
    https://www.cs.toronto.edu/~lczhang/360/lec/w05/autoencoder.html
    """
    def __init__(self, flags):
        super().__init__()
        self.flags = flags

        # Width of the shared representation that all heads read from.
        shared_dim = flags.style_dim + flags.class_dim
        self.shared_encoder = architectures.Encoder(
            latent_dims=shared_dim,
            feature_size=2352,
            layer_sizes=(512,),
            activation=nn.ReLU(),
        )
        self.act = nn.ReLU()

        # Content branch.
        self.class_mu = nn.Linear(shared_dim, flags.class_dim)
        self.class_logvar = nn.Linear(shared_dim, flags.class_dim)

        # Optional style branch.
        if flags.factorized_representation:
            self.style_mu = nn.Linear(shared_dim, flags.style_dim)
            self.style_logvar = nn.Linear(shared_dim, flags.style_dim)

    def forward(self, x):
        """Return ``(style_mu, style_logvar, class_mu, class_logvar)``.

        The two style entries are ``None`` when the representation is not
        factorized.
        """
        hidden = self.act(self.shared_encoder(x))
        if not self.flags.factorized_representation:
            return None, None, self.class_mu(hidden), self.class_logvar(hidden)
        return (self.style_mu(hidden), self.style_logvar(hidden),
                self.class_mu(hidden), self.class_logvar(hidden))


class DecoderImg(nn.Module):
    """
    Fully connected decoder mapping latent codes back to 2352 features.

    The network is an ``architectures.Decoder`` MLP with one hidden layer of
    512 units. Its output is returned together with a fixed scale of 0.75,
    so the pair can be unpacked into a location/scale likelihood.

    Adopted from:
    https://www.cs.toronto.edu/~lczhang/360/lec/w05/autoencoder.html
    """
    def __init__(self, flags):
        super().__init__()
        self.flags = flags
        latent_dim = flags.style_dim + flags.class_dim
        self.decoder = architectures.Decoder(
            latent_dims=latent_dim,
            feature_size=2352,
            layer_sizes=(512,),
            activation=nn.ReLU(),
        )

    def forward(self, style_latent_space, class_latent_space):
        """Decode the latents and return ``(x_hat, scale)``.

        With a factorized representation the style and class codes are
        concatenated along dim 1; otherwise only the class code is decoded
        and ``style_latent_space`` is ignored.
        """
        if self.flags.factorized_representation:
            z = torch.cat((style_latent_space, class_latent_space), dim=1)
        else:
            z = class_latent_space
        x_hat = self.decoder(z)
        # NOTE: consider learning the scale parameter, too.
        scale = torch.tensor(0.75).to(z.device)
        return x_hat, scale


def reweight_weights(w):
    """Return ``w`` divided by the sum of its entries (input is not modified)."""
    total = w.sum()
    return w / total

def mixture_component_selection(flags, mus, logvars, w_modalities=None):
    """Build a mixture by taking a contiguous slice of samples per component.

    Component ``k`` contributes ``floor(num_samples * w_modalities[k])``
    consecutive rows (along dim 1 of ``mus``/``logvars``), starting where the
    previous component stopped. The component matching the last weight, and
    the last component overall, always run to the end of the batch.

    Args:
        flags: only read when ``w_modalities`` is None (``alpha_modalities``
            and ``device``).
        mus, logvars: tensors indexed as ``[component, sample, feature]``.
        w_modalities: 1-D tensor of mixture weights; defaults to
            ``flags.alpha_modalities``.

    Returns:
        ``[mu_sel, logvar_sel]``: the selected rows concatenated along dim 0.
    """
    n_components, n_samples = mus.shape[0], mus.shape[1]
    if w_modalities is None:
        # Fall back to the pre-defined weights.
        w_modalities = torch.Tensor(flags.alpha_modalities).to(flags.device)
    n_weights = w_modalities.shape[0]

    bounds = []
    cursor = 0
    for comp in range(n_components):
        if comp == n_weights - 1:
            stop = n_samples
        else:
            stop = cursor + int(torch.floor(n_samples * w_modalities[comp]))
        bounds.append((cursor, stop))
        cursor = stop
    # The final component always absorbs the remaining samples.
    bounds[-1] = (bounds[-1][0], n_samples)

    mu_sel = torch.cat([mus[k, bounds[k][0]:bounds[k][1], :]
                        for k in range(n_weights)])
    logvar_sel = torch.cat([logvars[k, bounds[k][0]:bounds[k][1], :]
                            for k in range(n_weights)])
    return [mu_sel, logvar_sel]


def calc_log_probs(exp, result, batch):
    """Compute per-modality negative log-likelihoods and their weighted sum.

    Args:
        exp: object exposing ``modalities`` (dict of modality objects),
            ``flags.batch_size`` and ``rec_weights`` (keyed by modality name).
        result: dict whose ``'rec'`` entry maps modality names to output
            distributions.
        batch: dict mapping modality names to the target data.

    Returns:
        ``(log_probs, weighted_log_prob)``: the dict of negated log-probs and
        the sum of those values weighted by ``exp.rec_weights``.
    """
    modalities = exp.modalities
    neg_log_probs = {}
    total = 0.0
    for m_key in modalities:
        modality = modalities[m_key]
        mod_name = modality.name
        neg_log_probs[mod_name] = -modality.calc_log_prob(result['rec'][mod_name],
                                                          batch[mod_name],
                                                          exp.flags.batch_size)
        total += exp.rec_weights[mod_name] * neg_log_probs[mod_name]
    return neg_log_probs, total


def calc_klds(exp, result):
    """Return a dict with one KL term per entry of ``result['latents']['subsets']``.

    Each entry is a ``(mu, logvar)`` pair passed to ``calc_kl_divergence``
    with ``norm_value=exp.flags.batch_size``.
    """
    subset_latents = result['latents']['subsets']
    return {
        key: calc_kl_divergence(mu, logvar, norm_value=exp.flags.batch_size)
        for key, (mu, logvar) in subset_latents.items()
    }


def calc_klds_style(exp, result):
    """Return KL terms for the style latents only.

    Only entries of ``result['latents']['modalities']`` whose key ends with
    ``'style'`` are evaluated; each is a ``(mu, logvar)`` pair passed to
    ``calc_kl_divergence`` with ``norm_value=exp.flags.batch_size``.
    """
    modality_latents = result['latents']['modalities']
    style_klds = {}
    for key, params in modality_latents.items():
        if not key.endswith('style'):
            continue
        mu, logvar = params
        style_klds[key] = calc_kl_divergence(mu, logvar,
                                             norm_value=exp.flags.batch_size)
    return style_klds


def calc_style_kld(exp, klds):
    """Return the weighted sum of the per-modality style KL terms.

    For every key ``m`` in ``exp.modalities`` the term ``klds[m + '_style']``
    is scaled by ``exp.style_weights[m]``; the sum starts from ``0.0``.
    """
    modalities = exp.modalities
    style_weights = exp.style_weights
    return sum(
        (style_weights[m_key] * klds[m_key + '_style'] for m_key in modalities),
        0.0,
    )


def calc_klds_cvib(exp, result):
    """Compute KL terms between the joint posterior and each content posterior.

    Every entry of ``result['latents']['modalities']`` whose key does not
    contain ``'style'`` is compared against ``result['latents']['joint']``
    via ``calc_kl_divergence(joint_mu, joint_logvar, mu, logvar, ...)``.

    Returns:
        ``(klds, kld_losses)``: the per-key dict and the running sum of its
        values (``0.0`` when nothing qualifies).
    """
    modality_latents = result['latents']['modalities']
    joint_mu, joint_logvar = result['latents']['joint']
    per_modality = {}
    total = 0.0
    for key, params in modality_latents.items():
        if 'style' in key:
            continue
        mu, logvar = params
        per_modality[key] = calc_kl_divergence(joint_mu, joint_logvar, mu, logvar,
                                               norm_value=exp.flags.batch_size)
        total += per_modality[key]
    return per_modality, total

class Modality():
    """Bundle of everything belonging to one view: networks, sizes, likelihood."""

    def __init__(self, name, enc, dec, class_dim, style_dim, lhood_name='normal'):
        super().__init__()
        self.name = name
        self.encoder = enc
        self.decoder = dec
        self.class_dim = class_dim
        self.style_dim = style_dim
        self.likelihood_name = lhood_name
        self.likelihood = self.get_likelihood(lhood_name)

    def get_likelihood(self, name):
        """Map a likelihood name to its ``torch.distributions`` class.

        Unknown names print a notice and yield ``None``.
        """
        known = (
            ('laplace', dist.Laplace),
            ('bernoulli', dist.Bernoulli),
            ('normal', dist.Normal),
            ('categorical', dist.OneHotCategorical),
        )
        for label, distribution in known:
            if name == label:
                return distribution
        print('likelihood not implemented')
        return None

    def calc_log_prob(self, out_dist, target, norm_value):
        """Return the summed log-probability of ``target`` divided by ``norm_value``."""
        return out_dist.log_prob(target).sum() / norm_value


class BaseMMVae(nn.Module):
    def __init__(self,modnum=2,batchsize=200):
        """Build the per-view networks and the hard-coded ``flags`` configuration.

        Args:
            modnum: number of views (modalities).
            batchsize: value stored as ``flags.batch_size``; used elsewhere as
                a normalisation constant and as the default sample count.
        """
        super().__init__()
        self.num_modalities = modnum

        # Configuration container normally filled in by an argument parser.
        flags = MyClass()
        flags.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        flags.batch_size = batchsize
        flags.style_dim = 0
        flags.class_dim = 200
        flags.modality_ivw = True
        flags.num_mods = modnum  # number of modalities is set dynamically

        # One uniform weight for the leading entry plus one per modality.
        flags.div_weight_uniform_content = 1 / (flags.num_mods + 1)
        flags.div_weight = 1 / (flags.num_mods + 1)
        flags.alpha_modalities = [
            flags.div_weight_uniform_content,
            *(flags.div_weight for _ in range(flags.num_mods)),
        ]
        flags.factorized_representation = False
        self.flags = flags

        # Order matters: subsets and reconstruction weights are derived from
        # the modalities, and the fusion setup reads flags.alpha_modalities.
        self.modalities = self.set_modalities()
        self.subsets = self.set_subsets()
        self.rec_weights = self.set_rec_weights()
        self.set_fusion_functions()

        # Register the networks so their parameters belong to this module.
        self.encoders = nn.ModuleDict()
        self.decoders = nn.ModuleDict()
        self.lhoods = {}
        for m_key in sorted(self.modalities):
            modality = self.modalities[m_key]
            self.encoders[m_key] = modality.encoder
            self.decoders[m_key] = modality.decoder
            self.lhoods[m_key] = modality.likelihood

    def set_rec_weights(self):
        rec_weights = dict()
        for k, m_key in enumerate(self.modalities.keys()):
            mod = self.modalities[m_key]
            #numel_mod = mod.data_size.numel()
            rec_weights[mod.name] = 1.0
        return rec_weights

    def set_modalities(self):
        """Create one ``Modality`` (own encoder and decoder) per view.

        Modalities are named ``m0``, ``m1``, ... and use a Laplace likelihood.
        The networks are placed on ``self.flags.device`` rather than
        unconditionally on CUDA, so construction also works on CPU-only hosts.

        Returns:
            dict mapping modality name to its ``Modality`` instance.
        """
        device = self.flags.device
        mods = [
            Modality("m%d" % m,
                     EncoderImg(self.flags).to(device),
                     DecoderImg(self.flags).to(device),
                     self.flags.class_dim,
                     self.flags.style_dim,
                     'laplace')
            for m in range(self.num_modalities)
        ]
        # Every view gets its own independent encoder and decoder.
        return {m.name: m for m in mods}


    def set_subsets(self):
        """Enumerate every non-empty subset of the modalities.

        Subsets are produced in order of increasing size, e.g. for
        ``[1, 2, 3]``: (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3).
        The empty subset is not included.

        Returns:
            dict mapping ``'_'.join(sorted(names))`` to the list of
            ``Modality`` objects in that subset (sorted by name).
        """
        names = list(self.modalities)
        subsets = {}
        for size in range(1, len(names) + 1):
            for combo in combinations(names, size):
                ordered = sorted(combo)
                subsets['_'.join(ordered)] = [self.modalities[n] for n in ordered]
        return subsets
    
    

    def reparameterize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()
        eps = Variable(std.data.new(std.size()).normal_())
        return eps.mul(std).add_(mu)

    def truncated_z_sample(self, logvar, truncation=0.5, seed=None):
        state = None# if seed is None else np.random.RandomState(seed)
        values = truncnorm.rvs(-1*self.flags.trunc_range, self.flags.trunc_range, size=logvar.shape, random_state=state)
        values = truncation * values.astype('float32')

        return torch.from_numpy(values).cuda()


    def set_fusion_functions(self):
        """Store the normalised modality weights and select the fusion strategy.

        The configuration is fixed here: subsets are fused with
        ``ivw_fusion``, a subset joins the joint posterior when
        ``fusion_condition_poe`` holds, and the joint divergence is computed
        by ``divergence_static_prior``.
        """
        alpha = torch.Tensor(self.flags.alpha_modalities)
        self.weights = reweight_weights(alpha).to(self.flags.device)

        self.modality_fusion = self.ivw_fusion
        self.fusion_condition = self.fusion_condition_poe
        self.calc_joint_divergence = self.divergence_static_prior


    def divergence_static_prior(self, mus, logvars, weights=None):
        """Compute the group divergence via ``calc_group_divergence_moe``.

        ``weights`` (default ``self.weights``) are cloned and re-normalised
        before use; the result is normalised by ``flags.batch_size``.

        Returns:
            dict with ``'joint_divergence'``, ``'individual_divs'`` and
            ``'dyn_prior'`` (always ``None`` for the static prior).
        """
        base_weights = self.weights if weights is None else weights
        normalized = reweight_weights(base_weights.clone())
        measures = calc_group_divergence_moe(self.flags,
                                             mus,
                                             logvars,
                                             normalized,
                                             normalization=self.flags.batch_size)
        return {
            'joint_divergence': measures[0],
            'individual_divs': measures[1],
            'dyn_prior': None,
        }


    def divergence_dynamic_prior(self, mus, logvars, weights=None):
        """Compute the group divergence via ``calc_alphaJSD_modalities``.

        ``weights`` defaults to ``self.weights`` and is passed through
        unchanged; the result is normalised by ``flags.batch_size``.

        Returns:
            dict with ``'joint_divergence'``, ``'individual_divs'`` and
            ``'dyn_prior'`` taken from the first three returned measures.
        """
        used_weights = self.weights if weights is None else weights
        measures = calc_alphaJSD_modalities(self.flags,
                                            mus,
                                            logvars,
                                            used_weights,
                                            normalization=self.flags.batch_size)
        return {
            'joint_divergence': measures[0],
            'individual_divs': measures[1],
            'dyn_prior': measures[2],
        }


    def moe_fusion(self, mus, logvars, weights=None):
        """Fuse components as a mixture using ``mixture_component_selection``.

        ``weights`` (default ``self.weights``) are re-normalised to sum to one
        before the selection.

        Returns:
            ``[mu, logvar]`` of the selected mixture samples.
        """
        base_weights = self.weights if weights is None else weights
        normalized = reweight_weights(base_weights)
        fused_mu, fused_logvar = mixture_component_selection(self.flags,
                                                             mus,
                                                             logvars,
                                                             normalized)
        return [fused_mu, fused_logvar]


    def poe_fusion(self, mus, logvars, weights=None):
        """Fuse components with ``poe``, optionally adding a zero component.

        A component with ``mu = 0`` and ``logvar = 0`` (zero mean, unit
        variance) is appended before fusion when either

        * ``flags.modality_poe`` is truthy and more than one component is
          given, or
        * the number of components equals the number of modalities.

        ``flags.modality_poe`` is optional and treated as ``False`` when
        absent; the flags object built by this class's constructor does not
        define it, which previously made this method raise AttributeError.

        Args:
            mus, logvars: tensors indexed as ``[component, sample, feature]``.
            weights: unused; kept for a signature parallel to the other
                fusion functions.

        Returns:
            ``[mu_poe, logvar_poe]`` as returned by ``poe``.
        """
        modality_poe = getattr(self.flags, 'modality_poe', False)
        n_components = mus.shape[0]
        if (modality_poe and n_components > 1) or n_components == len(self.modalities.keys()):
            num_samples = mus[0].shape[0]
            zeros = torch.zeros(1, num_samples,
                                self.flags.class_dim).to(self.flags.device)
            # torch.cat copies its inputs, so sharing one zero block is safe.
            mus = torch.cat((mus, zeros), dim=0)
            logvars = torch.cat((logvars, zeros), dim=0)
        mu_poe, logvar_poe = poe(mus, logvars)
        return [mu_poe, logvar_poe]


    def ivw_fusion(self, mus, logvars, weights=None):
        """Fuse components by calling ``poe`` directly (``weights`` is unused).

        Unlike ``poe_fusion`` no extra zero component is appended.
        """
        fused_mu, fused_logvar = poe(mus, logvars)
        return [fused_mu, fused_logvar]


    def fusion_condition_moe(self, subset, input_batch=None):
        """Return True only for subsets containing exactly one modality."""
        return len(subset) == 1


    def fusion_condition_poe(self, subset, input_batch=None):
        """Return True when the subset has as many members as ``input_batch`` has keys."""
        return len(subset) == len(input_batch.keys())


    def fusion_condition_joint(self, subset, input_batch=None):
        """Accept every subset unconditionally."""
        return True


    def forward(self, input_batch):
        latents = self.inference(input_batch);
        results = dict();
        results['latents'] = latents;
        results['group_distr'] = latents['joint'];
        class_embeddings = self.reparameterize(latents['joint'][0],
                                                latents['joint'][1]);   # 产生特征
        div = self.calc_joint_divergence(latents['mus'],
                                         latents['logvars'],
                                         latents['weights']);
        for k, key in enumerate(div.keys()):
            results[key] = div[key];

        results_rec = dict();
        enc_mods = latents['modalities'];
        for m, m_key in enumerate(self.modalities.keys()):
            if m_key in input_batch.keys():
                m_s_mu, m_s_logvar = enc_mods[m_key + '_style'];
                if self.flags.factorized_representation:
                    m_s_embeddings = self.reparameterize(mu=m_s_mu, logvar=m_s_logvar);
                else:
                    m_s_embeddings = None;
                m_rec = self.lhoods[m_key](*self.decoders[m_key](m_s_embeddings, class_embeddings));
                results_rec[m_key] = m_rec;
        results['rec'] = results_rec;
        return results;

    def encode(self, input_batch):
        """Run each available modality through its encoder.

        For every modality ``m`` the result holds two entries:
        ``m + '_style'`` (first two encoder outputs: style mu, logvar) and
        ``m`` (remaining outputs: class mu, logvar). Modalities missing from
        ``input_batch`` get ``[None, None]`` for both.
        """
        latents = {}
        for m_key in self.modalities:
            style_key = m_key + '_style'
            if m_key in input_batch:
                encoded = self.encoders[m_key](input_batch[m_key])
                latents[style_key] = encoded[:2]   # style: mu, logvar
                latents[m_key] = encoded[2:]       # class: mu, logvar
            else:
                latents[style_key] = [None, None]
                latents[m_key] = [None, None]
        return latents


    def inference(self, input_batch, num_samples=None):
        if num_samples is None:
            num_samples = self.flags.batch_size;
        latents = dict();
        enc_mods = self.encode(input_batch);
       # pdb.set_trace()
        latents['modalities'] = enc_mods;
        mus = torch.Tensor().to(self.flags.device);
        logvars = torch.Tensor().to(self.flags.device);
        distr_subsets = dict();
        #pdb.set_trace()
        for k, s_key in enumerate(self.subsets.keys()):
            if s_key != '':
                mods = self.subsets[s_key];  # 这个子集里的mod
                mus_subset = torch.Tensor().to(self.flags.device);
                logvars_subset = torch.Tensor().to(self.flags.device);
                mods_avail = True
                try:
                    for m, mod in enumerate(mods):
                        if mod.name in input_batch.keys():
                            mus_subset = torch.cat((mus_subset,
                                                enc_mods[mod.name][0].unsqueeze(0)),
                                               dim=0);
                            logvars_subset = torch.cat((logvars_subset,
                                                    enc_mods[mod.name][1].unsqueeze(0)),
                                                   dim=0);  # subset_modnum * batch * dim
                        else:
                            mods_avail = False;
                except:
                    pdb.set_trace()
                if mods_avail:  # 输入数据包含这个子集
                    weights_subset = ((1/float(len(mus_subset)))*
                                      torch.ones(len(mus_subset)).to(self.flags.device));
                    #pdb.set_trace()
                    try:
                        s_mu, s_logvar = self.modality_fusion(mus_subset,
                                                          logvars_subset,
                                                          weights_subset);   # 子集内 view fusion
                        distr_subsets[s_key] = [s_mu, s_logvar];
                    
                    # if self.flags.modality_jsd and mus_subset.shape[0] > 1:
                    #     weights_subset = ((1 / float(len(mus_subset)+1)) *
                    #                       torch.ones(len(mus_subset)+1).to(self.flags.device));
                    #     mu_zero = torch.zeros([1] + list( mus_subset.shape[1:] )).to(self.flags.device)
                    #     logvar_zero = torch.zeros([1] + list( mus_subset.shape[1:] )).to(self.flags.device)
                    #     mus_subset = torch.cat((mu_zero, mus_subset), dim=0)
                    #     logvars_subset = torch.cat((logvar_zero, logvars_subset), dim=0)

                    #     mu_prime, logvar_prime = alpha_poe(weights_subset,
                    #                                        mus_subset,
                    #                                        logvars_subset)
                    #     distr_subsets[s_key] = [mu_prime, logvar_prime]
                        if self.fusion_condition(mods, input_batch):  # 默认setting下，必须要子集和输入batch的view一模一样，才行
                            mus = torch.cat((mus, s_mu.unsqueeze(0)), dim=0);   # 包含子集数 * batchnum *dim
                            logvars = torch.cat((logvars, s_logvar.unsqueeze(0)),
                                            dim=0);
                    except:
                        pdb.set_trace()
        # if self.flags.modality_jsd:
        #     mus = torch.cat((mus,
        #                      torch.zeros(1, num_samples, self.flags.class_dim).to(self.flags.device)),
        #                     dim=0);
        #     logvars = torch.cat((logvars,
        #                          torch.zeros(1, num_samples, self.flags.class_dim).to(self.flags.device)),
        #                         dim=0);
        #weights = (1/float(len(mus)))*torch.ones(len(mus)).to(self.flags.device);
        #print(1)
        try:
            weights = (1/float(mus.shape[0]))*torch.ones(mus.shape[0]).to(self.flags.device);
            joint_mu, joint_logvar = self.moe_fusion(mus, logvars, weights);  # fusion
        except Exception as e:
            print(e)
            pdb.set_trace()
        #mus = torch.cat(mus, dim=0);
        #logvars = torch.cat(logvars, dim=0);
        latents['mus'] = mus; # 包含子集数 * batchnum *dim
        latents['logvars'] = logvars; # 包含子集数 * batchnum *dim
        latents['weights'] = weights; # 包含子集数 * batchnum *dim
        latents['joint'] = [joint_mu, joint_logvar]; # batchnum *dim 混合包含子集
        latents['subsets'] = distr_subsets; 
        return latents;


    def generate(self, num_samples=None):
        if num_samples is None:
            num_samples = self.flags.batch_size;

        mu = torch.zeros(num_samples,
                         self.flags.class_dim).to(self.flags.device);
        logvar = torch.zeros(num_samples,
                             self.flags.class_dim).to(self.flags.device);
        z_class = self.reparameterize(mu, logvar); #self.truncated_z_sample(logvar, truncation=self.flags.trunc_rate) #
        z_styles = self.get_random_styles(num_samples);
        random_latents = {'content': z_class, 'style': z_styles};
        random_samples = self.generate_from_latents(random_latents);
        return random_samples;


    def generate_sufficient_statistics_from_latents(self, latents):
        suff_stats = dict();
        content = latents['content']
        for m, m_key in enumerate(self.modalities.keys()):
            s = latents['style'][m_key];
            cg = self.lhoods[m_key](*self.decoders[m_key](s, content));
            suff_stats[m_key] = cg;
        return suff_stats;


    def generate_from_latents(self, latents):
        suff_stats = self.generate_sufficient_statistics_from_latents(latents);
        cond_gen = dict();
        for m, m_key in enumerate(latents['style'].keys()):
            cond_gen_m = suff_stats[m_key].mean;
            cond_gen[m_key] = cond_gen_m;
        return cond_gen;


    def cond_generation(self, latent_distributions, num_samples=None):
        if num_samples is None:
            num_samples = self.flags.batch_size;

        style_latents = self.get_random_styles(num_samples);
        cond_gen_samples = dict();
        for k, key in enumerate(latent_distributions.keys()):
            [mu, logvar] = latent_distributions[key];
            content_rep = self.reparameterize(mu=mu, logvar=logvar);
            latents = {'content': content_rep, 'style': style_latents}
            cond_gen_samples[key] = self.generate_from_latents(latents);
        return cond_gen_samples;


    def get_random_style_dists(self, num_samples):
        styles = dict();
        for k, m_key in enumerate(self.modalities.keys()):
            mod = self.modalities[m_key];
            s_mu = torch.zeros(num_samples,
                               mod.style_dim).to(self.flags.device)
            s_logvar = torch.zeros(num_samples,
                                   mod.style_dim).to(self.flags.device);
            styles[m_key] = [s_mu, s_logvar];
        return styles;


    def get_random_styles(self, num_samples):
        styles = dict();
        for k, m_key in enumerate(self.modalities.keys()):
            if self.flags.factorized_representation:
                mod = self.modalities[m_key];
                z_style = torch.randn(num_samples, mod.style_dim);
                z_style = z_style.to(self.flags.device);
            else:
                z_style = None;
            styles[m_key] = z_style;
        return styles;


    def save_networks(self):
        for k, m_key in enumerate(self.modalities.keys()):
            torch.save(self.encoders[m_key].state_dict(),
                       os.path.join(self.flags.dir_checkpoints, 'enc_' +
                                    self.modalities[m_key].name))
            torch.save(self.decoders[m_key].state_dict(),
                       os.path.join(self.flags.dir_checkpoints, 'dec_' +
                                    self.modalities[m_key].name))


def MVTCAE_fit_transform(train,test,dim=100,latent_dim=100,epochs=50):
    """Train a ``BaseMMVae`` on ``train`` and embed ``test`` with it.

    Args:
        train: sequence of per-view arrays used for training.
        test: sequence of per-view arrays (same number of views) to embed.
        dim, latent_dim: currently unused; the latent size is fixed inside
            ``BaseMMVae``.
        epochs: number of passes over the training data.

    Returns:
        numpy array with one row per test sample: the mean of the posterior
        fused from all views (the subset that contains every modality).

    Notes:
        * Sets the global torch default dtype to float64 as a side effect.
        * The batch size is fixed at 2000 for training and inference.
        * The embedding key is derived from the number of views, so any view
          count is supported (it used to be hard-coded for 2 to 5 views and
          failed with a NameError otherwise).
    """
    torch.set_default_dtype(torch.float64)

    batch_size = 2000
    n_views = len(train)

    dataset = NumpyDataset(train, labels=None)
    mm_vae = BaseMMVae(modnum=n_views, batchsize=batch_size)
    # Use the device the model selected for itself instead of forcing CUDA.
    device = mm_vae.flags.device
    mm_vae = mm_vae.to(device)

    optimizer = torch.optim.Adam(mm_vae.parameters(),
                                 lr=0.001,
                                 betas=(0.9, 0.999))
    dataloader = get_dataloaders(dataset=dataset, batch_size=batch_size)

    # Loss weights are constant across batches, so compute them once.
    tc_ratio = 5.0 / 6.0
    rec_weight = (n_views - tc_ratio) / n_views  # 0.58 for two views
    cvib_weight = tc_ratio / n_views             # 0.41 for two views
    vib_weight = 1 - tc_ratio                    # 0.16
    beta = 2.5

    for epoch in tqdm(range(epochs)):
        for batch in dataloader:
            views = {'m{}'.format(k): view.to(device)
                     for k, view in enumerate(batch["views"])}
            results = mm_vae(views)

            log_probs, weighted_log_prob = calc_log_probs(mm_vae, results, views)
            group_divergence = results['joint_divergence']
            klds_cvib_dict, klds_cvib = calc_klds_cvib(mm_vae, results)

            kld_weighted = cvib_weight * klds_cvib + vib_weight * group_divergence
            total_loss = rec_weight * weighted_log_prob + beta * kld_weighted
            print(rec_weight,beta)
            print(epoch,total_loss.item(),weighted_log_prob.item(), kld_weighted.item())

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

    # Key of the subset containing every view, built the same way as
    # BaseMMVae.set_subsets builds its keys.
    full_key = '_'.join(sorted('m{}'.format(k) for k in range(len(test))))

    res = []
    with torch.no_grad():
        mm_vae.eval()
        for start in range(0, len(test[0]), batch_size):
            views = {'m{}'.format(k): torch.Tensor(view)[start:start + batch_size, :].to(device)
                     for k, view in enumerate(test)}
            view_project = mm_vae.inference(views, num_samples=batch_size)
            # Entry 0 of the fused subset is its mean.
            res.append(view_project['subsets'][full_key][0].cpu().data.numpy())

    res = np.concatenate(res, axis=0)
    return res
    