# NOTE: This entire module is commented out; none of the code below is executed.
# import numpy as np
# import torch
#
# from clustering.clustering_method import ClusteringMethod
# from error_sampler.ErrorSampler import ErrorSampler
#
#
# class WeightedErrorSampler(ErrorSampler):
#
#     def __init__(self, base_error_sampler: ErrorSampler):
#         super().__init__()
#         self.base_error_sampler = base_error_sampler
#
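#     def fit(self, x_train, z_train, errors_train, deleted_train, x_val, z_val, errors_val, deleted_val, **kwargs):
#         # NOTE(review): x_train and x_val are accepted but not forwarded to the base sampler's fit() - confirm its signature.
#         self.base_error_sampler.fit(z_train, errors_train.detach(), deleted_train, z_val, errors_val.detach(), deleted_val, **kwargs)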
#
#     def sample_error(self, x_test, z_test):
#         sample = self.base_error_sampler.sample_error(x_test, z_test)
#         # NOTE(review): `sample` is assigned but never returned, so this method returns None as written.
#
#     @property
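#     def name(self) -> str:
#         # NOTE(review): this interpolates the sampler object itself; presumably `self.base_error_sampler.name` was intended - confirm.
#         return f"weighted_{self.base_error_sampler}"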
#
#
