import timeit
import torch.nn.functional as F
import numpy as np
import unittest

from slot_attention.helpers.clustering_heuristic import batched_extract_query_mat, extract_query_mat, unbatched_extract_query_mat

class TestClustering(unittest.TestCase):
    """Tests for the clustering heuristic in ``slot_attention.helpers.clustering_heuristic``."""

    # TODO: add an unbatched counterpart exercising ``unbatched_extract_query_mat``
    # (imported at the top of this file) on the same synthetic data.

    @staticmethod
    def _make_clusters(rng, cluster_sizes=(8, 6, 4), n_features=5, spread=1.0, centre_scale=3.0):
        """Build standardized synthetic points grouped around random centres.

        Args:
            rng: ``numpy.random.Generator`` used for every random draw.
            cluster_sizes: number of points in each cluster; one centre per entry.
            n_features: dimensionality of every point.
            spread: width of the uniform noise ``[0, spread)`` added around a centre.
            centre_scale: each centre coordinate is drawn uniformly from ``[0, centre_scale)``.

        Returns:
            Array of shape ``(sum(cluster_sizes), n_features)`` with one point per row,
            clusters in the order given; each feature column has zero mean and unit std.
        """
        groups = []
        for size in cluster_sizes:
            centre = rng.random((n_features, 1)) * centre_scale
            groups.append(rng.random((n_features, size)) * spread + centre)
        points = np.concatenate(groups, axis=1)  # (n_features, n_points)
        # Standardize each feature (zero mean, unit std) across all points.
        points = (points - points.mean(axis=1, keepdims=True)) / points.std(axis=1, keepdims=True)
        return points.T

    def test_cluster_with_batch(self):
        """``extract_query_mat`` on a batch yields finite centroids of the expected shape."""
        cluster_sizes = (8, 6, 4)
        n_features = 5
        n_clusters = len(cluster_sizes)
        batch_size = 2

        # Fixed seed so any failure is reproducible.
        points = self._make_clusters(np.random.default_rng(0), cluster_sizes, n_features)
        # Stack identical copies along a new leading batch axis: (batch, n_points, n_features).
        batch = np.repeat(points[np.newaxis, :, :], batch_size, axis=0)

        centroid_mat = extract_query_mat(batch, n_clusters)
        # The einsum requires centroid_mat to be (batch, c, n_points); c is presumably
        # n_clusters -- the shape assertion below checks that assumption.
        centroids = np.einsum('b c n, b n d -> b c d', centroid_mat, batch)

        self.assertEqual(centroids.shape, (batch_size, n_clusters, n_features))
        self.assertTrue(np.all(np.isfinite(centroids)))

if "__main__" == __name__:
    # Only when executed directly as a script: run this module's TestCase classes.
    unittest.main()