import abc

import numpy as np
from sklearn.metrics import silhouette_score
import torch
from sklearn.cluster import KMeans
from tensordict import TensorDict

from clustering.clustering_method import ClusteringMethod


class Kmeans(ClusteringMethod):
    """K-means clustering method backed by scikit-learn's ``KMeans``.

    Inputs are torch tensors; they are detached, moved to CPU and converted
    to NumPy arrays before being handed to scikit-learn.  When ``use_x`` is
    true, ``more_features`` is concatenated to ``x`` along the last dimension
    both when fitting and when predicting, so the two always see the same
    feature layout.
    """

    def __init__(self, n_clusters=8, use_x: bool = False):
        """Create an unfitted model.

        Args:
            n_clusters: Number of clusters used every time ``fit`` runs.
            use_x: If true, ``fit`` and ``predict_cluster`` require
                ``more_features`` and cluster on ``x`` concatenated with it.
        """
        super().__init__()
        self.kmeans = KMeans(n_clusters=n_clusters, random_state=42)
        self.n_clusters = n_clusters
        self.use_x = use_x
        self.was_fit = False

    def _prepare_features(self, x, more_features=None):
        """Return the tensor to cluster on, honouring ``use_x``."""
        if not self.use_x:
            return x
        if more_features is None:
            raise ValueError("more_features is required when use_x=True")
        return torch.cat([x, more_features], dim=-1)

    def optimal_clusters(self, data, max_clusters=15):
        """Return the cluster count in ``[2, max_clusters]`` with the best silhouette.

        Args:
            data: Tensor of samples; its first dimension is the sample count.
            max_clusters: Largest cluster count to try (inclusive).

        Returns:
            The ``k`` (as a plain ``int``) with the highest silhouette score;
            ties resolve to the smallest such ``k``.

        Raises:
            ValueError: If there is no valid candidate ``k``, i.e. fewer than
                three samples or ``max_clusters < 2``.
        """
        data_np = data.detach().cpu().numpy()

        # silhouette_score is only defined for 2 <= n_labels <= n_samples - 1,
        # so k must stay strictly below the number of samples.
        largest_k = min(max_clusters, len(data_np) - 1)
        if largest_k < 2:
            raise ValueError(
                "optimal_clusters needs at least 3 samples and max_clusters >= 2 "
                f"(got {len(data_np)} samples, max_clusters={max_clusters})"
            )

        silhouette_scores = []
        for k in range(2, largest_k + 1):
            labels = KMeans(n_clusters=k, random_state=42).fit_predict(data_np)
            silhouette_scores.append(silhouette_score(data_np, labels))

        # Index 0 of the score list corresponds to k == 2.
        return int(np.argmax(silhouette_scores)) + 2

    def fit(self, x: torch.Tensor, more_features: torch.Tensor = None, **kwargs):
        """Fit a fresh ``KMeans`` with ``self.n_clusters`` clusters.

        Args:
            x: Tensor of samples to cluster.
            more_features: Extra features concatenated to ``x`` along the last
                dimension; required when ``use_x`` is true, ignored otherwise.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``use_x`` is true and ``more_features`` is ``None``.
        """
        features = self._prepare_features(x, more_features)

        # Rebuild the estimator on every fit so that the configured cluster
        # count (rather than a hard-coded one) is always the one in effect.
        self.kmeans = KMeans(n_clusters=self.n_clusters, random_state=42)
        self.kmeans.fit(features.detach().cpu().numpy())
        self.was_fit = True

    def calibrate(self, x: torch.Tensor, **kwargs):
        """No-op: this method performs no calibration."""
        pass

    def predict_cluster(self, x, more_features=None, **kwargs):
        """Assign each sample in ``x`` to a fitted cluster.

        If the model has not been fitted yet, a warning is printed and the
        model is fitted on the given input first.

        Args:
            x: Tensor of samples to assign.
            more_features: Extra features concatenated to ``x`` along the last
                dimension; required when ``use_x`` is true, ignored otherwise.
            **kwargs: Accepted and ignored.

        Returns:
            A tensor of cluster indices on ``x``'s device, built with
            ``torch.Tensor(...)`` from the scikit-learn labels.

        Raises:
            ValueError: If ``use_x`` is true and ``more_features`` is ``None``.
        """
        if not self.was_fit:
            print("warning: using kmeans_clustering_with_x model before fitting")
            self.fit(x, more_features=more_features)

        # Use the same feature layout as fit(); otherwise scikit-learn rejects
        # the input because of a feature-count mismatch.
        features = self._prepare_features(x, more_features)
        labels = self.kmeans.predict(features.detach().cpu().numpy())
        return torch.Tensor(labels).to(x.device)

    @property
    def name(self):
        """Identifier string of this clustering method."""
        return "kmeans_clustering"


