import warnings
import numpy as np
from numpy.linalg import svd
from numba import jit
from sklearn.decomposition import TruncatedSVD
import cProfile
class K_Subspaces_algo:
    """K-subspaces clustering: fit ``k`` linear subspaces of dimension ``dim``.

    Each point is assigned to the subspace with the smallest squared
    projection residual; each subspace is then re-estimated from its assigned
    points with a truncated SVD (a Lloyd-style alternation).

    Attributes:
        k: number of subspaces.
        dim: dimension (number of basis vectors) of every subspace.
        max_iter: upper bound on assignment/update rounds in ``fit``.
        centers: array of shape ``(k, dim, n_features)`` once fitted; the rows
            of ``centers[j]`` are the basis vectors of subspace ``j``.
        training_data: the array most recently passed to ``fit``.
        history: total cost recorded at every iteration of the last ``fit``.
    """

    def __init__(self, k, dim, max_iter=100):
        self.k = k
        self.dim = dim
        self.centers = None
        self.training_data = None
        self.history = []
        self.max_iter = max_iter

    def get_original_solution(self, data, n_restarts=10):
        """Fit ``n_restarts`` times and keep the lowest-cost solution.

        Args:
            data: array of shape ``(n_points, n_features)``.
            n_restarts: number of independent random restarts (default 10,
                the value that used to be hard-coded).

        Returns:
            ``(best_centers, best_cost)`` for the best restart. ``self.centers``
            is set to the returned centers; ``self.history`` still belongs to
            the last restart that was run.

        Raises:
            ValueError: if ``n_restarts`` is smaller than 1.
        """
        if n_restarts < 1:
            raise ValueError("n_restarts must be at least 1")
        best_cost = None
        best_centers = None
        for _ in range(n_restarts):
            self.fit(data)
            score = self.score(data)
            if best_cost is None or score < best_cost:
                best_cost = score
                # Snapshot: a later restart must not be able to alter the winner.
                best_centers = np.copy(self.centers)
        # Leave the model in the state that matches what is returned.
        self.centers = best_centers
        return best_centers, best_cost

    def get_subset_solution_original(self, dataset, subset_data):
        """Fit on ``subset_data`` and evaluate on both the subset and ``dataset``.

        Returns:
            ``(subset_cost, original_cost, centers)``.
        """
        self.fit(subset_data)
        subset_cost = self.score(subset_data)
        original_cost = self.score(dataset)
        return subset_cost, original_cost, self.centers

    def check_convergence(self):
        """Return True when the cost has stalled over the last 5 iterations.

        The run counts as converged when the latest cost is not below 97% of
        any of the last five recorded costs. Fewer than five recorded costs
        never count as converged.
        """
        lookback = 5
        if len(self.history) < lookback:
            return False
        latest = self.history[-1]
        if any(latest < 0.97 * past for past in self.history[-lookback:]):
            return False
        print(f"early stopped with {self.history}")
        return True

    def fit(self, data):
        """Run the alternating assignment / subspace-update loop on ``data``.

        Resets ``history``, stores ``data`` as ``training_data`` and overwrites
        ``centers``. Stops after ``max_iter`` rounds or as soon as
        ``check_convergence`` reports a stall.

        Args:
            data: array of shape ``(n_points, n_features)``.
        """
        self.history = []
        self.training_data = data
        self.centers = self.get_initial_subspaces2(data)
        for i in range(self.max_iter):
            print("iteration ", i)
            clusters, cost = self.assign_points_batch(data)
            self.history.append(cost)
            if self.check_convergence():
                break
            for j, members in enumerate(clusters):
                # An empty cluster keeps its previous subspace.
                if members.shape[0] == 0:
                    continue
                self.centers[j] = self.find_best_subspace(members)
        print(self.history)

    def score(self, data):
        """Return the total squared residual of ``data`` to its nearest subspaces."""
        _, cost = self.assign_points_batch(data)
        return cost

    def find_best_subspace(self, points):
        """Return a ``dim``-dimensional basis fitted to ``points`` by truncated SVD.

        When the cluster holds fewer than ``dim`` points it is padded with
        randomly chosen rows of ``training_data`` so that the SVD can produce
        all ``dim`` basis vectors.

        Args:
            points: array of shape ``(n_points, n_features)``, ``n_points >= 1``.

        Returns:
            Array of basis vectors, one per row (``dim`` rows expected).
        """
        n_missing = self.dim - points.shape[0]
        if n_missing > 0:
            print("got into it")
            pool = self.training_data
            extra_idx = np.random.choice(pool.shape[0], n_missing, replace=False)
            points = np.concatenate((points, pool[extra_idx]))
        basis = self.silent_svd(self.dim, points)
        if basis.shape[0] < self.dim:
            print("failing here")
        return basis

    def calculate_cost(self, point, center):
        """Return the summed squared residual of ``point`` to one subspace.

        Args:
            point: a single point of shape ``(n_features,)`` or a batch of
                shape ``(n_points, n_features)``.
            center: basis of shape ``(dim, n_features)``; a 1-D vector is
                treated as a single basis vector. Rows are assumed orthonormal.

        Returns:
            Sum over all given points of ``||(I - U^T U) x||^2``.
        """
        basis = np.atleast_2d(center)
        pts = np.atleast_2d(point)
        residual_op = np.eye(pts.shape[1]) - basis.T @ basis
        residuals = residual_op @ pts.T
        return np.sum(np.linalg.norm(residuals, axis=0) ** 2)

    def assign_points_batch(self, data):
        """Partition ``data`` by nearest subspace.

        Returns:
            ``(clusters, cost)`` where ``clusters[i]`` holds the rows of
            ``data`` assigned to subspace ``i`` (possibly empty) and ``cost``
            is the sum of all nearest-subspace squared residuals.
        """
        dists, labels = self.get_dists_to_nearest_center(data)
        cost = np.sum(dists)
        clusters_list = [data[labels == i] for i in range(self.k)]
        return clusters_list, cost

    def get_initial_subspaces(self, data):
        """Initialise ``k`` subspaces from ``k`` independent uniform samples.

        Each subspace is the truncated SVD of ``dim`` rows drawn from ``data``
        without replacement.

        Returns:
            Array stacking the ``k`` bases.
        """
        subspaces = []
        for _ in range(self.k):
            picked = np.random.choice(data.shape[0], self.dim, replace=False)
            subspaces.append(self.silent_svd(self.dim, data[picked]))
        return np.array(subspaces)

    def silent_svd(self, dim, data):
        """Return the ``components_`` of a ``dim``-component TruncatedSVD of ``data``.

        Warnings raised while fitting are suppressed when ``data`` has exactly
        one row.
        """
        model = TruncatedSVD(n_components=dim)
        if data.shape[0] == 1:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                return model.fit(data).components_
        return model.fit(data).components_

    def get_dists_to_nearest_center(self, data, centers=None):
        """Compute each point's squared residual to its nearest subspace.

        Args:
            data: array of shape ``(n_points, n_features)``.
            centers: sequence of bases, each reshapeable to
                ``(dim, n_features)``; defaults to ``self.centers``.

        Returns:
            ``(min_dists, clusters)``: the smallest squared residual per point
            and the index of the subspace that attains it.
        """
        if centers is None:
            centers = self.centers
        n_features = data.shape[1]
        identity = np.eye(n_features)
        dists_to_centers = np.zeros((data.shape[0], len(centers)))
        for i, center in enumerate(centers):
            basis = center.reshape((self.dim, n_features))
            # (I - U^T U) maps a point to its residual off the subspace.
            residuals = (identity - np.matmul(basis.T, basis)) @ data.T
            dists_to_centers[:, i] = np.linalg.norm(residuals, axis=0) ** 2
        clusters = np.argmin(dists_to_centers, axis=1)
        min_dists = np.min(dists_to_centers, axis=1)
        return min_dists, clusters

    def get_initial_subspaces2(self, data):
        """Initialise ``k`` subspaces with distance-weighted seeding.

        The first subspace is built from ``dim`` uniformly sampled rows. Every
        further subspace is built from ``dim`` rows sampled with probability
        proportional to their squared residual to the nearest subspace chosen
        so far.

        Falls back to uniform sampling when the residuals cannot form a valid
        distribution for sampling ``dim`` rows without replacement (all zero,
        or fewer than ``dim`` non-zero entries).

        Returns:
            Array stacking the ``k`` bases.
        """
        n_points = data.shape[0]
        seed_idx = np.random.choice(n_points, self.dim, replace=False)
        centers = [self.silent_svd(self.dim, data[seed_idx])]
        for _ in range(1, self.k):
            dists, _labels = self.get_dists_to_nearest_center(data, centers)
            total = np.sum(dists)
            if total > 0 and np.count_nonzero(dists) >= self.dim:
                probs = dists / total
            else:
                probs = None  # uniform
            seed_idx = np.random.choice(n_points, self.dim, replace=False, p=probs)
            centers.append(self.silent_svd(self.dim, data[seed_idx]))
        return np.array(centers)
        
        
        
        