
import numpy as np
import torch
import torch.nn as nn
from src.abstract.mrl import MRL


# Default value of the `model_name` argument of the LateFusionVAE constructor.
MODEL_STR = "LateFusionVAE"


class LateFusionVAE(MRL):
    """Late-fusion multimodal (variational) autoencoder.

    Every modality is encoded and decoded by its own encoder/decoder pair
    (``self.encoders[i]`` / ``self.decoders[i]`` belong to
    ``self.modalities_list[i]``). The per-modality latents are kept side by
    side in a dict keyed by modality name.

    The latent ``z`` is either deterministic (plain autoencoder,
    ``probabilistic=False``) or probabilistic (VAE, ``probabilistic=True``).
    In the probabilistic case each encoder output is read as
    ``(mu, log_var)`` and a latent is drawn with ``self.reparam``.

    NOTE(review): ``encoders``, ``decoders``, ``device``, ``subset_list``,
    ``beta``, ``batch_size``, ``reparam`` and ``Kl_div_gaussian`` are not
    defined in this class; they are presumably provided by ``MRL`` --
    confirm against the base class.
    """

    def __init__(self, latent_dim,
                 modalities_list,
                 train_loader,
                 test_loader,
                 model_name=MODEL_STR,
                 subsampling_strategy="fullset",
                 beta=1,
                 annealing_beta_gradualy=False,
                 nb_samples=8,
                 batch_size=256,
                 num_train_lr=500,
                 eval_epoch=5,
                 do_evaluation=True,
                 do_fd=True,
                 log_epoch=5,
                 n_fd=5000,
                 lr=0.001,
                 probabilistic=False,
                 do_class=False,
                 ):
        """Build the model.

        Every argument except ``probabilistic`` is forwarded unchanged to
        ``MRL.__init__``; see the base class for its meaning.

        Args:
            probabilistic: When True the model behaves as a VAE (latents are
                sampled and a KL term is added to the loss); when False it
                is a plain autoencoder trained on reconstruction only.
        """
        # Stored before the base constructor runs.
        # NOTE(review): presumably MRL.__init__ reads this flag (e.g. while
        # building the encoders) -- confirm before moving this line.
        self.probabilistic = probabilistic

        super().__init__(
            latent_dim=latent_dim,
            modalities_list=modalities_list,
            test_loader=test_loader,
            train_loader=train_loader,
            model_name=model_name,
            subsampling_strategy=subsampling_strategy,
            beta=beta,
            batch_size=batch_size,
            annealing_beta_gradualy=annealing_beta_gradualy,
            nb_samples=nb_samples,
            num_train_lr=num_train_lr,
            eval_epoch=eval_epoch,
            do_evaluation=do_evaluation,
            do_fd=do_fd,
            log_epoch=log_epoch,
            n_fd=n_fd,
            lr=lr,
            do_class=do_class,
        )

        self.posterior = LateFusion()

    def encode(self, x):
        """Encode each modality of ``x`` with its own encoder.

        Args:
            x: Mapping from modality name to that modality's input data.
                Modalities of ``self.modalities_list`` that are missing from
                ``x`` are skipped.

        Returns:
            dict mapping modality name to the raw output of its encoder.
        """
        return {
            modality.name: self.encoders[idx](x[modality.name])
            for idx, modality in enumerate(self.modalities_list)
            if modality.name in x
        }

    def decode(self, z):
        """Decode each latent of ``z`` with the decoder of its modality.

        Args:
            z: Mapping from modality name to the latent of that modality.

        Returns:
            dict mapping modality name to the output of its decoder.
        """
        # Modalities without a latent are skipped (mirrors `encode`).
        # Indexing `z` for every modality raised KeyError as soon as `z`
        # covered only a subset, which broke subset-conditional generation.
        return {
            modality.name: self.decoders[idx](z[modality.name])
            for idx, modality in enumerate(self.modalities_list)
            if modality.name in z
        }

    def _latents_from_encodings(self, encodings):
        """Turn encoder outputs into latents, sampling when probabilistic."""
        if not self.probabilistic:
            # Deterministic AE: the encoder output already is the latent.
            return encodings
        return {
            name: self.reparam(mu=enc[0], log_var=enc[1])
            for name, enc in encodings.items()
        }

    def forward(self, x):
        """Encode ``x``, obtain the latents and decode them.

        Args:
            x: Mapping from modality name to that modality's input data.

        Returns:
            Tuple ``(reconstruction, encodings)``: the decoder outputs and
            the raw encoder outputs, both keyed by modality name.
        """
        encodings = self.encode(x)
        z_mods = self._latents_from_encodings(encodings)
        reconstruction = self.decode(z_mods)
        return reconstruction, encodings

    def compute_loss(self, x):
        """Compute the training loss for a batch.

        In probabilistic mode the loss is
        ``reconstruction error + beta * KL``; otherwise it is the
        reconstruction error alone. Both terms are divided by
        ``self.batch_size``.

        Args:
            x: Mapping from modality name to that modality's input data.

        Returns:
            dict with ``"loss"`` (total loss) and ``"Rec_loss"``
            (per-modality reconstruction losses); in probabilistic mode
            also ``"KLDs"`` (per-modality KL terms).
        """
        reconstruction, posterior = self.forward(x)
        reconstruction_error = self.compute_reconstruction_error(
            x, reconstruction, self.batch_size)

        if not self.probabilistic:
            return {
                "loss": reconstruction_error["total"],
                "Rec_loss": reconstruction_error["rec_loss"],
            }

        KLD = self.compute_KLD(posterior, self.batch_size)
        loss = self.elbo_objectif(
            reconstruction_error=reconstruction_error["total"],
            KLD=KLD["KLD_joint"],
            beta=self.beta,
        )
        return {
            "loss": loss,
            "KLDs": KLD["KLDs"],
            "Rec_loss": reconstruction_error["rec_loss"],
        }

    def elbo_objectif(self, reconstruction_error, KLD, beta):
        """Return ``reconstruction_error + beta * KLD``."""
        return reconstruction_error + beta * KLD

    def compute_KLD(self, posterior, batch_size):
        """Compute the KL term of every modality and their sum.

        Args:
            posterior: Mapping from modality name to a ``(mu, logvar)``
                pair.
            batch_size: Divisor applied to every KL term.

        Returns:
            dict with ``"KLD_joint"`` (sum of the per-modality KL terms)
            and ``"KLDs"`` (the per-modality terms keyed by modality name).
        """
        klds = torch.zeros(len(posterior), device=self.device)
        kld_mods = {}
        for idx, (name, (mu, logvar)) in enumerate(posterior.items()):
            kl_mod = self.Kl_div_gaussian(mu, logvar) / batch_size
            kld_mods[name] = kl_mod
            klds[idx] = kl_mod

        # The sum is deliberately NOT scaled by beta here: `elbo_objectif`
        # applies beta, and scaling in both places weighted the KL term by
        # beta squared.
        return {"KLD_joint": klds.sum(dim=0), "KLDs": kld_mods}

    def compute_reconstruction_error(self, x, reconstruction, batch_size):
        """Compute the reconstruction loss of every modality and their sum.

        The loss of a modality is ``-mod.calc_log_prob(input, output)``
        divided by ``batch_size``. Modalities absent from ``x`` or from
        ``reconstruction`` are skipped.

        Args:
            x: Mapping from modality name to the input data.
            reconstruction: Mapping from modality name to the decoder output.
            batch_size: Divisor applied to every loss term.

        Returns:
            dict with ``"total"`` (sum over modalities) and ``"rec_loss"``
            (per-modality losses keyed by modality name).

        Raises:
            KeyError: If no modality is present in both mappings.
        """
        recons_log = {}
        for mod in self.modalities_list:
            if mod.name not in x or mod.name not in reconstruction:
                continue
            recons_log[mod.name] = (
                -mod.calc_log_prob(x[mod.name], reconstruction[mod.name])
                / batch_size
            )

        if not recons_log:
            raise KeyError(
                "no modality is present in both the input and the reconstruction")

        # Summing the loss values themselves keeps their dtype and device;
        # a buffer typed after the input data would truncate the losses
        # whenever that data is integer-typed.
        return {"total": sum(recons_log.values()), "rec_loss": recons_log}

    def sample(self, N):
        """Decode ``N`` standard-normal latents per modality.

        Args:
            N: Number of samples to draw.

        Returns:
            dict mapping modality name to the decoder output.
        """
        with torch.no_grad():
            z_mods = {
                mod.name: torch.randn(N, mod.latent_dim, device=self.device)
                for mod in self.modalities_list
            }
            return self.decode(z_mods)

    def _iter_subset_posteriors(self, x):
        """Yield ``(subset_key, latents)`` for every entry of ``subset_list``.

        ``subset_key`` is the comma-joined names of the subset's modalities.
        The generator body does not disable gradients itself; callers
        iterate it inside ``torch.no_grad()``.
        """
        modalities_str = np.array([mod.name for mod in self.modalities_list])
        subsets = {','.join(modalities_str[s]): s for s in self.subset_list}

        encodings = self.encode(x)
        for s_key, members in subsets.items():
            sub_encodings = {
                modalities_str[mod_i]: encodings[modalities_str[mod_i]]
                for mod_i in members
            }
            z_mods = self._latents_from_encodings(sub_encodings)
            yield s_key, self.posterior(z_mods)

    def conditional_gen_all_subsets(self, x, N=None):
        """Reconstruct from every modality subset of ``self.subset_list``.

        Args:
            x: Mapping from modality name to that modality's input data.
            N: Unused; kept so existing callers keep working.

        Returns:
            dict mapping the comma-joined subset name to the decoder outputs
            obtained from the latents of that subset.
        """
        results = {}
        with torch.no_grad():
            for s_key, posterior in self._iter_subset_posteriors(x):
                results[s_key] = self.decode(posterior)
        return results

    def conditional_gen_latent_subsets(self, x):
        """Return the concatenated latents of every modality subset.

        Args:
            x: Mapping from modality name to that modality's input data.

        Returns:
            dict mapping the comma-joined subset name to a one-element list
            holding the subset's latents joined by ``concat_vect``.
        """
        results = {}
        with torch.no_grad():
            for s_key, posterior in self._iter_subset_posteriors(x):
                results[s_key] = [concat_vect(posterior)]
        return results
    

def concat_vect(encodings):
    """Concatenate all tensors of ``encodings`` along their last dimension.

    Args:
        encodings: Mapping whose values are tensors; they are joined in the
            mapping's iteration order.

    Returns:
        One tensor holding the values side by side along ``dim=-1``, placed
        on the device of the last value. An empty mapping yields an empty
        tensor.
    """
    tensors = list(encodings.values())
    if not tensors:
        # torch.cat rejects an empty list; keep returning an empty tensor.
        return torch.Tensor()

    # A single cat over the real tensors keeps their dtype; seeding the
    # result with a float32 placeholder promoted other dtypes.
    target = tensors[-1].device
    return torch.cat([t.to(target) for t in tensors], dim=-1)
    
    

class LateFusion(nn.Module):
    """Identity fusion step for per-modality encodings.

    No merging is performed here: ``forward`` hands the object it receives
    straight back to the caller.
    """

    def forward(self, encodings):
        """Return ``encodings`` unchanged (the very same object)."""
        fused = encodings
        return fused
    
    
 