from error_sampler.ErrorSampler import ErrorSampler
from imputation_methods.regression_imputations.RegressorImputation import RegressorImputation
from models.regressors.MeanRegressor import MeanRegressor


# TODO: RegressorImputationWithErrorSampling relies on RegressorImputation owning and calibrating the base
#  model (the mean regressor). This coupling should be removed.
from utils.utils import get_y, get_seed, set_seeds


class RegressorImputationWithErrorSampling(RegressorImputation):
    """Impute with the mean regressor's prediction plus a sampled error.

    ``fit`` trains the base regressor through the parent class and prepares a
    deferred routine that fits the error sampler on the regressor's residuals
    (``y - predicted mean``). That routine is executed later, inside
    ``calibrate``. ``predict`` returns the predicted mean plus an error drawn
    from the error sampler.

    Fitting, calibrating and sampling from the error sampler all happen after
    ``set_seeds(SAMPLING_SEED)``; the seed returned by ``get_seed()``
    beforehand is restored afterwards, also when an exception is raised.
    """

    # Seed applied around every interaction with the error sampler.
    SAMPLING_SEED = 42

    def __init__(self, mean_regressor: MeanRegressor, error_sampler: ErrorSampler):
        """Store the error sampler and hand the mean regressor to the parent.

        Args:
            mean_regressor: model forwarded to ``RegressorImputation.__init__``;
                read back in this class as ``self.mean_regressor`` (presumably
                assigned by the parent -- the parent is not visible here).
            error_sampler: model that is fitted on residuals and later asked
                for sampled errors.
        """
        super().__init__(mean_regressor)
        # Not read or written anywhere else in this class.
        self.bin_edges = None
        self.proxy_cal = None
        self.error_sampler = error_sampler
        # Zero-argument callable created by ``fit`` and invoked by
        # ``calibrate``; ``None`` means ``fit`` has not been called yet.
        self.fit_error_sampler = None

    @property
    def name(self):
        """Name composed from the mean regressor's and error sampler's names."""
        return f"{self.mean_regressor.name}_with_{self.error_sampler.name}"

    def _residuals_as_column(self, x, y, z):
        """Return ``y - predicted mean``, with a trailing axis added if 1-D."""
        estimated = self.mean_regressor.predict_mean(x, z).squeeze()
        errors = y.squeeze() - estimated
        if len(errors.shape) == 1:
            errors = errors.unsqueeze(-1)
        return errors

    def fit(self, x_train, y_train, z_train, deleted_train, x_val, y_val, z_val, deleted_val, epochs=1000,
            batch_size=64, n_wait=20,
            **kwargs):
        """Fit the base regressor and prepare the deferred error-sampler fit.

        All arguments are forwarded unchanged to the parent's ``fit``. The
        error sampler itself is NOT fitted here: a closure capturing the
        training/validation data and the training options is stored in
        ``self.fit_error_sampler`` and run by ``calibrate``.

        If ``x_train`` is empty, a warning is printed, the deferred routine is
        set to a no-op and the base regressor is not fitted.
        """
        if len(x_train) == 0:
            print(f"warning: {self.name} received an empty training set")
            self.fit_error_sampler = lambda: ()
            return
        super().fit(x_train, y_train, z_train, deleted_train, x_val,
                    y_val, z_val, deleted_val, epochs=epochs,
                    batch_size=batch_size,
                    n_wait=n_wait, **kwargs)

        def fit_error_sampler():
            train_errors = self._residuals_as_column(x_train, y_train, z_train)
            val_errors = self._residuals_as_column(x_val, y_val, z_val)
            # NOTE(review): the training part passes (errors, y) while the
            # validation part passes (y, errors); ``calibrate`` also passes
            # (y, errors). Confirm this matches ErrorSampler.fit's signature.
            self.error_sampler.fit(x_train, z_train, train_errors.detach(), y_train, deleted_train, x_val, z_val,
                                   y_val, val_errors.detach(), deleted_val,
                                   epochs=epochs, batch_size=batch_size, n_wait=n_wait, **kwargs)

        self.fit_error_sampler = fit_error_sampler

    def calibrate(self, x_cal, y_cal, z_cal, deleted_cal):
        """Calibrate the base model, then fit and calibrate the error sampler.

        Runs the parent's ``calibrate``, executes the deferred error-sampler
        fit prepared by ``fit``, and calibrates the error sampler on the
        residuals of the calibration set.

        Raises:
            RuntimeError: if ``fit`` has not been called first. Raised before
                any calibration work is done.
        """
        if self.fit_error_sampler is None:
            raise RuntimeError(f"error: {self.name} is being calibrated before being fit.")
        super().calibrate(x_cal, y_cal, z_cal, deleted_cal)
        seed = get_seed()
        set_seeds(self.SAMPLING_SEED)
        try:
            self.fit_error_sampler()
            estimated = self.mean_regressor.predict_mean(x_cal, z_cal).squeeze()
            # NOTE(review): y_cal goes through get_y here, whereas ``fit`` uses
            # y_train / y_val directly -- confirm this difference is intended.
            errors = get_y(y_cal).squeeze() - estimated
            self.error_sampler.calibrate(x_cal, z_cal, y_cal, errors, deleted_cal)
        finally:
            # Restore the caller's seed even if fitting/calibration failed.
            set_seeds(seed)

    def predict(self, x, z, **kwargs):
        """Return the predicted mean plus an error drawn from the sampler.

        The error is sampled after ``set_seeds(SAMPLING_SEED)``; the previous
        seed is restored afterwards. ``**kwargs`` is accepted but not used.
        """
        mean_prediction = self.mean_regressor.predict_mean(x, z).squeeze()
        seed = get_seed()
        set_seeds(self.SAMPLING_SEED)
        try:
            sampled_errors = self.error_sampler.sample_error(x, z).squeeze()
        finally:
            # Restore the caller's seed even if sampling failed.
            set_seeds(seed)
        return mean_prediction + sampled_errors
