import numpy as np
import torch

from clustering.clustering_method import ClusteringMethod
from error_sampler.ErrorSampler import ErrorSampler


class ClusteringErrorSampler(ErrorSampler):
    """Error sampler that draws calibration errors from the test point's cluster.

    ``fit`` trains the wrapped ``ClusteringMethod`` on the non-deleted rows of
    the pooled train + validation data. ``calibrate`` stores the errors and
    cluster assignments of the non-deleted calibration rows. ``sample_error``
    then gives each test point one calibration error drawn uniformly, with
    replacement, from the calibration points that share its cluster.
    """

    def __init__(self, clustering_method: ClusteringMethod):
        super().__init__()
        self.clustering_method = clustering_method
        # Set by ``calibrate``: errors and cluster ids of the non-deleted
        # calibration rows, index-aligned with each other.
        self.cal_errors = None
        self.cal_clusters = None

    def fit(self, x_train, z_train, y_train, errors_train, deleted_train,
            x_val, z_val, y_val, errors_val, deleted_val, **kwargs):
        """Fit the clustering method on the pooled train and validation data.

        Train and validation tensors are concatenated along dim 0. Rows flagged
        in ``deleted_train`` / ``deleted_val`` are then dropped. ``z`` is the
        primary clustering input. ``x`` is passed as ``more_features`` and
        ``y`` as ``y``. Extra ``kwargs`` are forwarded to
        ``clustering_method.fit``. ``errors_train`` and ``errors_val`` are not
        used by this sampler.
        """
        keep = ~torch.cat([deleted_train, deleted_val], dim=0)
        x_all = torch.cat([x_train, x_val], dim=0)[keep]
        z_all = torch.cat([z_train, z_val], dim=0)[keep]
        y_all = torch.cat([y_train, y_val], dim=0)[keep]
        self.clustering_method.fit(z_all, more_features=x_all, y=y_all, **kwargs)

    def calibrate(self, x_cal, z_cal, y_cal, errors_cal, deleted_cal, **kwargs):
        """Record errors and cluster assignments of the calibration set.

        Only rows not flagged in ``deleted_cal`` are kept. The kept errors are
        ``squeeze``d, so singleton dimensions are removed. ``x_cal``, ``y_cal``
        and ``kwargs`` are not used.
        """
        keep = ~deleted_cal
        self.cal_errors = errors_cal[keep].squeeze()
        self.cal_clusters = self.clustering_method.predict_cluster(z_cal[keep])

    def sample_error(self, x_test: torch.Tensor, z_test: torch.Tensor) -> torch.Tensor:
        """Sample one calibration error per test point from its own cluster.

        Args:
            x_test: test inputs. Only their length and device are used.
            z_test: representations passed to ``clustering_method.predict_cluster``.

        Returns:
            A tensor of length ``len(x_test)`` (default float dtype) on
            ``x_test.device``. A test point gets an error of 0 in these cases:
            its cluster has no calibration rows, its cluster id lies outside
            ``0..max(calibration cluster id)``, or the calibration set is empty.

        Raises:
            RuntimeError: if ``calibrate`` has not been called.
        """
        if self.cal_clusters is None or self.cal_errors is None:
            raise RuntimeError("calibrate() must be called before sample_error()")

        device = x_test.device
        sampled = torch.zeros(len(x_test), device=device)
        test_clusters = self.clustering_method.predict_cluster(z_test)
        cal_clusters = self.cal_clusters
        if cal_clusters.numel() == 0:
            # Every calibration row was deleted. All clusters are therefore
            # empty, so apply the same 0-error fallback used for empty clusters
            # below. torch.max would raise on an empty tensor.
            return sampled

        # Only cluster ids 0..max(cal id) are visited. Test points with any
        # other id are never written and keep their initial 0.
        n_clusters = torch.max(cal_clusters).int().item() + 1
        for c in range(n_clusters):
            test_mask = (test_clusters == c).squeeze()
            n_test = int(test_mask.sum().item())
            if n_test == 0:
                continue  # nothing to fill for this cluster
            cal_mask = (cal_clusters == c).squeeze()
            n_cal = int(cal_mask.sum().item())
            if n_cal == 0:
                continue  # no calibration errors to draw from; leave at 0
            cluster_errors = self.cal_errors[cal_mask]
            if n_cal == 1:
                sampled[test_mask] = cluster_errors.item()
                continue
            # Draws use NumPy's global RNG, so np.random.seed (not
            # torch.manual_seed) controls reproducibility here.
            draws = np.random.choice(
                cluster_errors.squeeze().detach().cpu().numpy(), size=n_test
            )
            sampled[test_mask] = torch.as_tensor(draws, dtype=sampled.dtype, device=device)
        return sampled

    @property
    def name(self) -> str:
        """Identifier built from the wrapped clustering method's name."""
        return f"{self.clustering_method.name}_error_sampler"
