from typing import Optional

from error_sampler.ErrorSampler import ErrorSampler
from models.CVAE_GAN import CVAE_GAN, CVAE_GAN_MODE


class CVAEErrorSampler(ErrorSampler):
    """Error sampler backed by a ``CVAE_GAN`` model.

    The wrapped model is trained on pairs ``(z, errors)`` and, at sampling time,
    is queried with ``z`` alone. The ``x`` (and ``y``) inputs required by the
    ``ErrorSampler`` interface are accepted but not used by this implementation.
    """

    def __init__(self, dataset_name: str, y_dim, x_dim, z_dim, cvae_z_dim, device, seed: int, saved_models_path: str,
                 mode: CVAE_GAN_MODE = CVAE_GAN_MODE.CVAE, dropout=0.1, lr=1e-3, wd=0.,
                 batch_norm=False,
                 kl_mult=0.01, figures_dir: Optional[str] = None):
        """Construct the wrapped ``CVAE_GAN`` model.

        Args:
            dataset_name: Forwarded to ``CVAE_GAN``.
            y_dim: Forwarded to ``CVAE_GAN`` as its ``y`` dimension
                (presumably the dimension of the sampled errors — TODO confirm).
            x_dim: Accepted for signature compatibility only; it is NOT
                forwarded to ``CVAE_GAN``.
            z_dim: Forwarded to ``CVAE_GAN`` (dimension of the conditioning ``z``,
                presumably — TODO confirm).
            cvae_z_dim: Forwarded to ``CVAE_GAN`` (presumably its latent size).
            device: Forwarded to ``CVAE_GAN``.
            seed: Forwarded to ``CVAE_GAN``.
            saved_models_path: Forwarded to ``CVAE_GAN``.
            mode: ``CVAE_GAN_MODE`` forwarded to ``CVAE_GAN``; defaults to ``CVAE``.
            dropout, lr, wd, batch_norm, kl_mult: Training/architecture
                hyper-parameters forwarded unchanged to ``CVAE_GAN``.
            figures_dir: Optional directory forwarded to ``CVAE_GAN``; ``None``
                by default.
        """
        super().__init__()
        # x_dim is deliberately absent from this call: the model only ever sees
        # z (conditioning) and errors (targets), see fit() / sample_error().
        self.cvae = CVAE_GAN(dataset_name, y_dim, z_dim, cvae_z_dim, device, seed, saved_models_path, mode, dropout=dropout,
                             lr=lr, wd=wd, batch_norm=batch_norm, kl_mult=kl_mult, figures_dir=figures_dir)

    def fit(self, x_train, z_train, y_train, errors_train, deleted_train, x_val, z_val, y_val, errors_val, deleted_val,
            **kwargs):
        """Train the wrapped model to generate errors conditioned on ``z``.

        Rows whose ``deleted_*`` flag is set are dropped from both the training
        and validation sets. ``x_*`` and ``y_*`` are ignored. Error tensors are
        ``.detach()``-ed before training. Extra keyword arguments are forwarded
        to ``CVAE_GAN.fit``.

        Assumes ``deleted_train`` / ``deleted_val`` are boolean masks aligned
        with the first axis of the corresponding ``z`` and ``errors`` tensors
        (presumably torch tensors — TODO confirm).
        """
        # Invert the "deleted" masks once so each is reused for z and errors.
        keep_train = ~deleted_train
        keep_val = ~deleted_val
        self.cvae.fit(z_train[keep_train], errors_train.detach()[keep_train],
                      z_val[keep_val], errors_val.detach()[keep_val], **kwargs)

    def sample_error(self, x_test, z_test):
        """Sample errors for ``z_test`` from the wrapped model.

        ``x_test`` is ignored. Returns element ``[0]`` of
        ``CVAE_GAN.sample_y(z_test)`` (presumably the sampled errors
        themselves — TODO confirm against ``CVAE_GAN.sample_y``).
        """
        return self.cvae.sample_y(z_test)[0]

    @property
    def name(self) -> str:
        """Identifier derived from the wrapped model's name, suffixed with ``_error_sampler``."""
        return f"{self.cvae.name}_error_sampler"
