
from causalml.inference.torch.cevae import CEVAE as CEAVE_Model
from causally.model.abstract_model import SKAbstractModel


class CEVAE(SKAbstractModel):
    """Adapter exposing causalml's CEVAE estimator through ``SKAbstractModel``.

    Hyper-parameters are read from ``config`` under the keys
    ``outcome_dist``, ``latent_dim``, ``hidden_dim``, ``num_epochs``,
    ``learning_rate``, ``learning_rate_decay``, ``num_layers`` and
    ``batch_size``. A missing key raises ``KeyError`` at construction time.
    The values are forwarded unchanged to causalml's ``CEVAE`` constructor.
    """

    def __init__(self, config, dataset):
        super().__init__(config, dataset)

        # NOTE(review): not read anywhere in this class; presumably consumed by
        # the base class or external callers — confirm before removing.
        self.n_units = dataset.get_X_size()[0]

        self.outcome_dist = config['outcome_dist']
        self.latent_dim = config['latent_dim']
        self.hidden_dim = config['hidden_dim']
        self.num_epochs = config['num_epochs']
        self.learning_rate = config['learning_rate']
        self.learning_rate_decay = config['learning_rate_decay']
        self.num_layers = config['num_layers']
        self.batch_size = config['batch_size']

        self.model = CEAVE_Model(
            outcome_dist=self.outcome_dist,
            latent_dim=self.latent_dim,
            hidden_dim=self.hidden_dim,
            num_epochs=self.num_epochs,
            batch_size=self.batch_size,
            learning_rate=self.learning_rate,
            learning_rate_decay=self.learning_rate_decay,
            num_layers=self.num_layers,
        )

    def calculate_loss(self, x, t, y, w):
        """Fit the wrapped CEVAE on ``(x, t, y)`` and return the fit result.

        Args:
            x: Covariates, forwarded as ``X``.
            t: Treatment assignment, forwarded as ``treatment``.
            y: Observed outcome, forwarded as ``y``.
            w: Accepted for interface compatibility; not used by this model.

        Returns:
            The value returned by ``self.model.fit``. Per the causalml docs this
            is the list of training losses. The original code computed it and
            then discarded it.
        """
        # NOTE(review): causalml's CEVAE.fit appears to build a fresh inner
        # model on every call, so repeated calls retrain from scratch — confirm
        # this matches how the training loop invokes calculate_loss.
        return self.model.fit(X=x, treatment=t, y=y)

    def predict(self, x, t_0, t_1):
        """Return the wrapped CEVAE's predictions for covariates ``x``.

        ``t_0`` and ``t_1`` are accepted for interface compatibility but are not
        forwarded; only ``X`` is passed to ``self.model.predict``.
        """
        # NOTE(review): per causalml docs this is an individual treatment
        # effect estimate per row of x — confirm callers expect that.
        return self.model.predict(X=x)