import numpy as np
import pandas as pd
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial.distance import squareform 

class FeatureSelector:
    """Selects features from a set of data according to their correlations"""

    def __init__(self, data: pd.DataFrame, columns: list[str]):
        self.data = data
        self.columns = columns 

    def corr_matrix(self):
        features = self.data[self.columns]
        return features.corr()
    
    def cluster(self, threshold):
        corr = self.corr_matrix()
        corr.fillna(0, inplace=True)
        dissimilarity = 1 - abs(corr)
        for i in range(1, len(corr)):
            dissimilarity.iloc[i, i] = 0 
        Z = linkage(squareform(dissimilarity.values), 'complete')
        labels = fcluster(Z, threshold, criterion='distance')
        clusters = {}
        for c, l in zip(self.columns, labels):
            if l in clusters: clusters[l].append(c)
            else: clusters[l] = [c]
        return list(clusters.values())
    
    def select_features(self, clusters):
        return [c[0] for c in clusters]