import torch
import pytorch_lightning as pl
from src.logger.utils import log_cond_modalities, log_modalities
# Short identifier string for this model type; presumably used elsewhere to
# select or label the model — TODO confirm against callers.
MODEL_STR: str = "AE"

class AE(pl.LightningModule):
    """Single-modality autoencoder trained with PyTorch Lightning.

    The encoder/decoder are taken from ``modality.enc`` / ``modality.dec`` and
    the reconstruction loss is ``-modality.calc_log_prob(x, recon)``. An
    optional L1 or L2 penalty on the latent code ``z``, weighted by ``alpha``,
    is added to the loss. Logged/returned losses are divided by the batch size.

    Args:
        modality: Object exposing ``name``, ``latent_dim``, ``enc``, ``dec``
            and ``calc_log_prob(x, recon)``.
        test_loader: Loader whose first batch is reconstructed and logged
            periodically in ``training_epoch_end``; ``None`` skips it.
        regularization: ``None``, ``"l1"`` or ``"l2"`` penalty on ``z``.
        alpha: Weight of the latent penalty in the total loss.
        lr: Adam learning rate.
        decay: Adam ``weight_decay``.
        train_loader: Like ``test_loader`` but for the training split.

    Raises:
        ValueError: If ``regularization`` is not one of the supported values.
    """

    # Accepted values for the ``regularization`` argument.
    _REGULARIZATIONS = (None, "l1", "l2")
    # Reconstructions are logged every this many epochs.
    _LOG_EVERY_N_EPOCHS = 10

    def __init__(self, modality, test_loader=None, regularization=None, alpha=0.0,
                 lr=0.001, decay=0.0, train_loader=None):
        super().__init__()
        if regularization not in self._REGULARIZATIONS:
            # Previously an unknown value was silently treated as "no penalty".
            raise ValueError(
                f"regularization must be one of {self._REGULARIZATIONS}, "
                f"got {regularization!r}"
            )
        self.lr = lr
        self.decay = decay
        self.modality = modality
        self.latent_dim = self.modality.latent_dim
        self.encoder = self.modality.enc
        self.decoder = self.modality.dec
        self.regularization = regularization
        self.test_loader = test_loader
        self.alpha = alpha
        self.train_loader = train_loader
        # Keep the modality and the data loaders out of the saved
        # hyperparameters so they are not serialised into checkpoints/loggers.
        self.save_hyperparameters(
            ignore=["modality", "encoder", "test_loader", "decoder", "train_loader"]
        )

    def _latent_penalty(self, z):
        """Return the summed L1/L2 penalty on ``z``, or ``0.0`` when disabled."""
        if self.regularization == "l1":
            return torch.abs(z).sum()
        if self.regularization == "l2":
            return torch.square(z).sum()
        return 0.0

    def _shared_step(self, batch):
        """Forward ``batch`` and compute its losses (not yet divided by batch size).

        ``batch[0]`` is indexed by ``self.modality.name`` to get the input
        tensor; its first dimension is taken as the batch size.

        Returns:
            ``(total_loss, recon_loss, regularization, batch_size)``.
        """
        x = batch[0][self.modality.name]
        batch_size = x.size(0)
        recon, z = self.forward(x)
        regularization = self._latent_penalty(z)
        recon_loss = self.reconstruction_loss(x, recon)
        total_loss = recon_loss + self.alpha * regularization
        return total_loss, recon_loss, regularization, batch_size

    def training_step(self, x, batch_idx=None):
        """Compute the per-sample training loss for batch ``x`` and log it.

        ``batch_idx`` is accepted (and unused) so the signature matches the
        other step hooks.
        """
        total_loss, recon_loss, regularization, batch_size = self._shared_step(x)
        loss = total_loss / batch_size
        # NOTE(review): assumes a TensorBoard-style logger exposing add_scalar.
        self.logger.experiment.add_scalar("loss/train", loss.detach(), self.global_step)
        if torch.is_tensor(regularization):
            # Only "loss" should carry the graph; the extra outputs are metrics.
            regularization = regularization.detach()
        return {
            "loss": loss,
            "recon_loss": recon_loss.detach() / batch_size,
            "regularization": regularization / batch_size,
        }

    def test_step(self, x, batch_idx):
        """Compute and log the per-sample test loss for batch ``x``."""
        total_loss, _, _, batch_size = self._shared_step(x)
        loss = total_loss / batch_size
        self.logger.experiment.add_scalar("loss/test", loss, self.global_step)
        return {"loss": loss}

    def validation_step(self, x, batch_idx):
        """Compute and log the per-sample validation loss for batch ``x``."""
        total_loss, _, _, batch_size = self._shared_step(x)
        loss = total_loss / batch_size
        # Logged under its own tag; it used to be mixed into "loss/test".
        self.logger.experiment.add_scalar("loss/val", loss, self.global_step)
        return {"loss": loss}

    def training_epoch_end(self, training_step_outputs):
        """Every ``_LOG_EVERY_N_EPOCHS`` epochs, log sample reconstructions.

        The encoder/decoder are switched to eval mode for the reconstruction
        and their previous train/eval mode is restored afterwards.

        NOTE(review): this hook was removed in Lightning 2.0 (replaced by
        ``on_train_epoch_end``) — confirm the installed version.
        """
        if self.current_epoch % self._LOG_EVERY_N_EPOCHS != 0:
            return
        encoder_was_training = self.encoder.training
        decoder_was_training = self.decoder.training
        self.encoder.eval()
        self.decoder.eval()
        try:
            self._log_reconstructions()
        finally:
            # Restore the previous mode so the next epoch does not train in eval mode.
            self.encoder.train(encoder_was_training)
            self.decoder.train(decoder_was_training)

    def _first_batch(self, loader):
        """Return this modality's tensor from ``loader``'s first batch, on ``self.device``."""
        return next(iter(loader))[0][self.modality.name].to(self.device)

    def _log_reconstructions(self):
        """Reconstruct one test and one train batch, print latent stats and log images."""
        splits = (("test", self.test_loader), ("train", self.train_loader))
        batches = {
            name: self._first_batch(loader)
            for name, loader in splits
            if loader is not None
        }
        if not batches:
            return
        print("Doing reconstruction")
        with torch.no_grad():
            outputs = {name: self.forward(batch) for name, batch in batches.items()}

        for name, (_, z) in outputs.items():
            print(f"{name} : std  : {z.std().item()}  mean : {z.mean().item()}")

        for name, batch in batches.items():
            recon = outputs[name][0]
            log_modalities(self.logger, {self.modality.name: batch}, [self.modality],
                           self.current_epoch, prefix=f"real_{name}/", nb_samples=8)
            log_modalities(self.logger, {self.modality.name: recon}, [self.modality],
                           self.current_epoch, prefix=f"recon_{name}/", nb_samples=8)

    def encode(self, x):
        """Map input ``x`` to its latent code with the modality's encoder."""
        return self.encoder(x)

    def decode(self, z):
        """Map latent code ``z`` back to input space with the modality's decoder."""
        return self.decoder(z)

    def forward(self, x):
        """Return ``(reconstruction, z)`` for input ``x``."""
        z = self.encode(x)
        return self.decode(z), z

    def configure_optimizers(self):
        """Adam over all parameters with the configured lr and weight decay."""
        return torch.optim.Adam(
            self.parameters(), lr=self.lr, betas=(0.9, 0.999), weight_decay=self.decay
        )

    def reconstruction_loss(self, x, recon):
        """Negative log-probability of ``x`` under ``recon``, as defined by the modality."""
        return -self.modality.calc_log_prob(x, recon)

