from kmeans_l2 import Kmeans_l2
import numpy as np

def test_cost_to_centers():
    """Check the cost of three points against one center placed on one of them.

    Squared Euclidean distances from the points to the center [1, 2, 3] are
    0, 27 and 27, so the expected total cost is 54.
    """
    # The first constructor argument appears to be the number of centers;
    # the meaning of the second argument (1) is not visible here — TODO confirm.
    kmeans = Kmeans_l2(1, 1)
    points = np.array([[1, 2, 3], [4, 5, 6], [-2, -1, 0]])
    centers = np.array([[1, 2, 3]])
    cost = kmeans.get_cost_for_centers(points, centers)
    # Tolerant comparison: the cost may be computed in floating point
    # (e.g. via a norm that is then squared), so exact equality is brittle.
    assert np.isclose(cost, 54)
    
def test_cost_to_centers_multiple_centers():
    """Check the cost against two centers and compare it with the model's own cost.

    With centers on [1, 2, 3] and [4, 5, 6], those two points contribute 0.
    The point [-2, -1, 0] contributes 27, its squared distance to the nearer
    center [1, 2, 3]. The expected total is therefore 27.
    """
    kmeans = Kmeans_l2(2, 1)
    points = np.array([[1, 2, 3], [4, 5, 6], [-2, -1, 0]])
    centers = np.array([[1, 2, 3], [4, 5, 6]])
    cost = kmeans.get_cost_for_centers(points, centers)
    # NOTE(review): get_original_cost presumably returns the cost of the
    # clustering the model computes itself — confirm. The best 2-clustering
    # of these points costs 13.5 (pair two points at squared distance 27),
    # so a hand-picked set of centers costing 27 should be strictly worse.
    cost_original = kmeans.get_original_cost(points)
    # Tolerant comparison: the cost may be computed in floating point.
    assert np.isclose(cost, 27)
    assert cost > cost_original

def test_numcenter_equal_numpoints():
    """Check that the cost is zero when there are as many centers as points.

    Every point can then serve as its own center, so the clustering cost
    should be zero.
    """
    kmeans = Kmeans_l2(3, 1)
    points = np.array([[1, 2, 3], [4, 5, 6], [-2, -1, 0]])
    cost_original = kmeans.get_original_cost(points)
    # np.isclose's default absolute tolerance (1e-8) absorbs floating-point
    # round-off around zero; exact `== 0` would be brittle.
    assert np.isclose(cost_original, 0)