import abc
import warnings
from typing import List, Optional

import torch
from tensordict import TensorDict

from clustering.clustering_method import ClusteringMethod
from clustering.one_dimensional_clustering import OneDimensionalClustering
from imputation_methods.imputation_utils import construct_histogram
from models.LinearModel import LinearModel



class LinearClustering(ClusteringMethod):
    """Cluster samples by the prediction of a fitted linear model.

    ``fit`` trains a ``LinearModel`` on ``(x, y)``, uses its predictions on
    ``x`` as a one-dimensional representation, and fits a
    ``OneDimensionalClustering`` on that representation.
    ``predict_cluster`` applies the same reduction before delegating to the
    one-dimensional clusterer.
    """

    def __init__(self, dataset_name: str, saved_models_path: str, figures_dir: str, seed):
        """Create the (not yet fitted) linear model and 1-D clusterer.

        Args:
            dataset_name: Forwarded to ``LinearModel``.
            saved_models_path: Forwarded to ``LinearModel``.
            figures_dir: Forwarded to ``LinearModel``.
            seed: Forwarded to ``LinearModel``.
        """
        super().__init__()
        self.one_dimensional_clustering = OneDimensionalClustering()
        self.linear_model = LinearModel(dataset_name, saved_models_path, figures_dir, seed)
        # Set by ``fit``; only used to warn when the model is used before training.
        self.is_linear_model_fit = False

    def reduce_x_to_1_dim(self, x: torch.Tensor) -> torch.Tensor:
        """Return the linear model's predictions for ``x``.

        These predictions are the reduced representation the clusterer works
        on (presumably one value per sample; this depends on
        ``LinearModel.predict``, so confirm against that class).

        Emits a ``RuntimeWarning`` if called before ``fit``. The warning can
        be filtered, or turned into an error, with the ``warnings`` module.
        """
        if not self.is_linear_model_fit:
            # stacklevel=2 attributes the warning to the caller, not this method.
            warnings.warn("using linear model before fitting", RuntimeWarning, stacklevel=2)
        return self.linear_model.predict(x)

    def fit(self, x: torch.Tensor, y: torch.Tensor, more_features: Optional[torch.Tensor] = None,
            **kwargs):
        """Fit the linear model on ``(x, y)``, then cluster its predictions on ``x``.

        Args:
            x: Training inputs for the linear model.
            y: Targets for the linear model.
            more_features: Currently unused by this implementation.
            **kwargs: Forwarded unchanged to ``OneDimensionalClustering.fit``.
        """
        self.linear_model.fit(x, y)
        # Mark as fitted before reducing, so the reduction below does not warn.
        self.is_linear_model_fit = True
        reduced_x = self.reduce_x_to_1_dim(x)
        self.one_dimensional_clustering.fit(reduced_x, **kwargs)

    def predict_cluster(self, x):
        """Return the 1-D clusterer's assignments for the reduced form of ``x``."""
        reduced_x = self.reduce_x_to_1_dim(x)
        return self.one_dimensional_clustering.predict_cluster(reduced_x)

    @property
    def name(self):
        """Identifier of this clustering method: ``"linear_clustering"``."""
        return "linear_clustering"