import numpy as np
import torch

from error_sampler.ErrorSampler import ErrorSampler


class MarginalErrorSampler(ErrorSampler):
    """Error sampler that resamples stored calibration errors, ignoring test inputs.

    ``calibrate`` stores the calibration errors. ``sample_error`` then draws,
    uniformly and with replacement, one stored error (one row along the first
    dimension) per test row. The draw does not depend on the values of
    ``x_test`` or ``z_test``, only on ``len(x_test)`` and ``x_test.device``.
    """

    def __init__(self):
        super().__init__()
        # Calibration errors to resample from; ``None`` until ``calibrate`` runs.
        self.cal_errors = None

    def calibrate(self, x_cal, z_cal, y_cal, errors_cal, deleted_cal, **kwargs):
        """Store the calibration errors that ``sample_error`` will resample.

        Only ``errors_cal`` is used; the remaining arguments are accepted for
        interface compatibility and ignored here.

        Args:
            errors_cal: tensor of calibration errors. Singleton dimensions are
                squeezed; the rows along the first remaining dimension are the
                units that ``sample_error`` draws.
        """
        # ``squeeze`` turns a single-element tensor into a 0-d scalar, which
        # cannot be sampled by row; always keep at least one dimension.
        self.cal_errors = torch.atleast_1d(errors_cal.squeeze())

    def sample_error(self, x_test: torch.Tensor, z_test: torch.Tensor) -> torch.Tensor:
        """Draw one calibration error per test row, uniformly with replacement.

        Args:
            x_test: test inputs; only ``len(x_test)`` and ``x_test.device`` are used.
            z_test: ignored.

        Returns:
            Tensor of ``len(x_test)`` rows drawn from ``cal_errors``, detached
            from autograd, on ``x_test.device`` with torch's default float dtype.

        Raises:
            RuntimeError: if ``calibrate`` has not been called yet.
        """
        if self.cal_errors is None:
            raise RuntimeError(
                "MarginalErrorSampler.calibrate must be called before sample_error")
        # np.random.choice(n, size) draws the same indices from the global NumPy
        # RNG as choosing from an n-element array, so seeded runs keep their
        # sampled errors, without copying the whole calibration set to the host.
        idx = np.random.choice(len(self.cal_errors), size=len(x_test))
        idx = torch.as_tensor(idx, dtype=torch.long, device=self.cal_errors.device)
        sampled_errors = self.cal_errors.detach()[idx]
        # Errors are returned as default-dtype floats on the test device.
        return sampled_errors.to(device=x_test.device, dtype=torch.get_default_dtype())

    @property
    def name(self) -> str:
        return "marginal_error_sampler"
