import abc
from typing import List

import torch

from clustering.clustering_method import ClusteringMethod
from imputation_methods.imputation_utils import construct_histogram


class OneDimensionalClustering(ClusteringMethod):
    """Cluster scalar values by assigning each one to a histogram bin.

    ``fit`` builds bin edges from training values via ``construct_histogram``.
    ``predict_cluster`` maps every value to the index of the bin containing it.
    """

    def __init__(self):
        super().__init__()
        # 1-D tensor of bin edges set by ``fit``; None until the model is fitted.
        self.bin_edges = None

    @staticmethod
    def compute_cluster_from_bin_edges(bin_edges: torch.Tensor, x: torch.Tensor):
        """Return the bin index of every element of ``x``.

        Bin ``i`` spans ``[bin_edges[i], bin_edges[i + 1]]``. A value lying
        exactly on an interior edge is assigned to the lower of the two bins.
        Values below the first edge (and NaNs) are assigned to bin 0. Values
        above the last edge are assigned to the last bin.

        Args:
            bin_edges: 1-D tensor of ``n_bins + 1`` edges. Presumably sorted
                ascending (TODO confirm ``construct_histogram`` guarantees this).
                It is moved to ``x``'s device if needed.
            x: tensor that squeezes to zero or one dimension.

        Returns:
            1-D int64 tensor holding one bin index per element of ``x``.
        """
        x = x.squeeze()
        if x.dim() == 0:
            # A single value squeezes to a 0-d tensor; treat it as a batch of one.
            x = x.unsqueeze(0)
        assert len(x.shape) == 1
        # Edges fitted on one device must be comparable with x on another.
        bin_edges = torch.as_tensor(bin_edges, device=x.device)
        n_bins = len(bin_edges) - 1
        # (n, 1) broadcasts against (n_bins,), so there is no need to copy x once per bin.
        x_col = x.unsqueeze(-1)
        cluster_indicator = (bin_edges[:-1] <= x_col) & (x_col <= bin_edges[1:])
        # argmax returns the first maximal index. A value on an interior edge
        # therefore goes to the lower bin, and a row with no match yields 0.
        clusters = torch.argmax(cluster_indicator.float(), dim=-1)
        # Without this, values above the last edge would fall through to bin 0.
        clusters[x > bin_edges[-1]] = n_bins - 1
        return clusters

    def fit(self, x: torch.Tensor, min_bin_size=100, **kwargs):
        """Compute and store histogram bin edges for the values in ``x``.

        Args:
            x: tensor that squeezes to one dimension.
            min_bin_size: forwarded to ``construct_histogram``. Presumably the
                minimum number of samples per bin (confirm against that helper).
            **kwargs: accepted for interface compatibility; unused here.
        """
        x = x.squeeze()
        assert len(x.shape) == 1
        bin_edges = construct_histogram(x, min_bin_size=min_bin_size)
        if not torch.is_tensor(bin_edges):
            # construct_histogram may return a non-tensor sequence; convert it
            # to a default-dtype tensor on x's device.
            bin_edges = torch.Tensor(bin_edges).to(x.device)
        # Assign only after conversion succeeds so a failure leaves prior state intact.
        self.bin_edges = bin_edges

    def predict_cluster(self, x):
        """Return the bin index of every element of ``x``.

        If called before ``fit``, a warning is printed and a single bin
        ``[0, 1]`` is used, so every value maps to cluster 0.
        """
        if self.bin_edges is None:
            print(f"warning: {self.name} was used before fit")
            return OneDimensionalClustering.compute_cluster_from_bin_edges(torch.Tensor([0,1]).to(x.device), x)
        return OneDimensionalClustering.compute_cluster_from_bin_edges(self.bin_edges, x)

    @property
    def name(self):
        """Identifier used for this clustering method."""
        return "1d_clustering"
