from .abstract_imputation import Imputation

import torch
import torch.nn as nn


class NoDestructionImputation(Imputation):
  """Identity imputation: the input data is passed through untouched.

  Every constructor option, plus any extra keyword arguments, is handed
  verbatim to the ``Imputation`` base class. How those options are used is
  defined there, not here.
  """

  def __init__(
    self,
    nb_imputation_iwae=1,
    nb_imputation_mc=1,
    nb_imputation_iwae_test=None,
    nb_imputation_mc_test=None,
    reconstruction_reg=None,
    mask_reg=None,
    add_mask=False,
    post_process_regularization=None,
    **kwargs,
  ):
    # Gather the named options so they can be forwarded in a single call.
    # None of these keys can also appear in ``kwargs``: Python binds them to
    # the named parameters above.
    base_options = dict(
      nb_imputation_iwae=nb_imputation_iwae,
      nb_imputation_mc=nb_imputation_mc,
      nb_imputation_iwae_test=nb_imputation_iwae_test,
      nb_imputation_mc_test=nb_imputation_mc_test,
      reconstruction_reg=reconstruction_reg,
      mask_reg=mask_reg,
      add_mask=add_mask,
      post_process_regularization=post_process_regularization,
    )
    super().__init__(**base_options, **kwargs)

  def imputation_function(self, data, mask, index=None):
    """Return ``data`` exactly as received.

    ``mask`` and ``index`` are accepted to match the imputation interface but
    are not used.
    """
    untouched = data
    return untouched



class MaskAsInput(Imputation):
  """Imputation that substitutes the mask for the input data.

  ``imputation_function`` ignores ``data`` and returns ``mask`` in its place.
  Constructor options, including any extra keyword arguments, are forwarded
  to the ``Imputation`` base class.
  """

  def __init__(self,
               nb_imputation_iwae = 1,
               nb_imputation_mc = 1,
               nb_imputation_iwae_test = None,
               nb_imputation_mc_test = None,
               reconstruction_reg = None,
               mask_reg = None,
               add_mask = False,
               post_process_regularization = None,
               **kwargs,
              ):
    # Forward ``**kwargs`` as NoDestructionImputation does. Previously extra
    # keyword options were accepted here but silently discarded, so callers
    # passing them got no error and no effect.
    super().__init__(nb_imputation_iwae = nb_imputation_iwae,
                     nb_imputation_mc = nb_imputation_mc,
                     nb_imputation_iwae_test = nb_imputation_iwae_test,
                     nb_imputation_mc_test = nb_imputation_mc_test,
                     reconstruction_reg=reconstruction_reg,
                     mask_reg=mask_reg,
                     add_mask= add_mask,
                     post_process_regularization=post_process_regularization,
                     **kwargs,
                    )

  def imputation_function(self, data, mask, index = None):
    """Return ``mask`` in place of ``data``.

    ``data`` and ``index`` are accepted to match the imputation interface but
    are not used. NOTE(review): presumably this lets a model see only the
    observation pattern rather than the values; confirm against callers.
    """
    return mask
