import numpy as np 

def norm_l12(A):
	"""Return the L1,2 norm of A: the sum of the Euclidean norms of its columns.

	Column norms are accumulated one at a time from left to right, starting
	from 0; a matrix with no columns therefore has norm 0.
	"""
	column_norms = np.linalg.norm(A, ord=2, axis=0)
	total = 0
	for col_norm in column_norms:
		total = total + col_norm
	return total

def find_V_l12(U, A):
	"""Return the coefficient matrix V that best expresses A in terms of U.

	V = pinv(U) @ A is, column by column, the minimum-norm least-squares
	solution of U @ v = a. Minimizing each column's residual 2-norm
	independently also minimizes the sum of those norms, i.e. the L1,2 norm
	of A - U @ V.
	"""
	U_pinv = np.linalg.pinv(U)
	return np.matmul(U_pinv, A)

def pick_first_column_greedy_l12(A):
	"""Return the index of the single column of A that best approximates A.

	Each column is tried as a one-column basis: its coefficients are fit with
	find_V_l12 and the residual is scored with norm_l12. The lowest score
	wins, and the earliest column wins ties. Returns -1 when no column scores
	below +inf, which happens when A has no columns or every score is inf/NaN.
	"""
	chosen = -1
	lowest = float("inf")
	for idx in range(A.shape[-1]):
		# Turn the column into a (num_rows, 1) matrix so it can act as a basis.
		candidate = A[:, idx][..., np.newaxis]
		residual = A - candidate @ find_V_l12(candidate, A)
		cost = norm_l12(residual)
		if cost < lowest:
			chosen, lowest = idx, cost
	return chosen

def pick_new_column_greedy_l12(A, U_old, V_old):
	"""Greedily choose the column of A that best extends the current basis.

	Every column of A is tried as an extra basis vector appended to U_old.
	For each candidate the coefficients are refit with find_V_l12 and the
	residual is scored with norm_l12(A - U_new @ V). The candidate with the
	lowest score is kept, and the earliest column wins ties.

	Args:
		A: Data matrix of shape (num_rows, num_cols).
		U_old: Current basis, shape (num_rows, r).
		V_old: Coefficients for U_old. Not used by the selection; kept so
			existing callers continue to work.

	Returns:
		(U_new, V_new, best_col_idx): U_new is U_old with A[:, best_col_idx]
		appended as its last column, and V_new holds the refit coefficients
		for U_new.

	Raises:
		ValueError: if A has no columns, or no candidate produces a finite
			cost (e.g. A contains NaNs).
	"""
	num_rows, num_cols = A.shape
	best_col_idx = -1
	V_new = None
	# Start from +inf rather than the current residual so that a column is
	# always chosen. If A is already reproduced exactly (residual 0), no
	# candidate can strictly improve on it. Seeding with the current residual
	# would then leave best_col_idx == -1 and V_new None, silently appending
	# the last column and handing None back to the caller.
	min_cost = float("inf")

	U_new = np.concatenate([U_old, np.zeros((num_rows, 1))], axis=-1)
	for k in range(num_cols):
		# Reuse the trailing slot for each candidate instead of reallocating.
		U_new[:, -1] = A[:, k]
		V_test = find_V_l12(U_new, A)
		new_cost = norm_l12(A - U_new @ V_test)
		if new_cost < min_cost:
			min_cost = new_cost
			best_col_idx = k
			V_new = V_test

	if best_col_idx < 0:
		raise ValueError("no candidate column of A yields a finite L1,2 cost")

	# The loop leaves the last candidate in the slot; restore the winner.
	U_new[:, -1] = A[:, best_col_idx]
	return U_new, V_new, best_col_idx

def greedy_approx_l12(A, num_cols):
	"""Greedily select num_cols columns of A to approximate A in the L1,2 norm.

	The first column comes from pick_first_column_greedy_l12. Each later
	column comes from pick_new_column_greedy_l12, which appends the column
	that gives the lowest norm_l12(A - U @ V) given the columns already chosen.

	Args:
		A: Data matrix of shape (num_rows, n).
		num_cols: Number of columns to select; must be at least 1.

	Returns:
		(U, V, indices): U holds the selected columns of A, V the coefficients
		so that U @ V approximates A, and indices the selected column indices
		in selection order. Nothing excludes already-chosen columns from later
		rounds, so indices may repeat.

	Raises:
		ValueError: if num_cols < 1, or if no usable first column exists
			(e.g. A has no columns or its costs are all NaN/inf).
	"""
	if num_cols < 1:
		raise ValueError(f"num_cols must be at least 1, got {num_cols}")

	first_column_idx = pick_first_column_greedy_l12(A)
	# The picker signals "nothing usable" with -1. Indexing A with it would
	# silently select the last column instead of failing.
	if first_column_idx < 0:
		raise ValueError("could not select a first column of A")

	indices = [first_column_idx]
	U_greedy = np.expand_dims(A[:, first_column_idx], axis=-1)
	V_greedy = find_V_l12(U_greedy, A)
	for _ in range(num_cols - 1):
		U_greedy, V_greedy, idx = pick_new_column_greedy_l12(A, U_greedy, V_greedy)
		indices.append(idx)
	return U_greedy, V_greedy, indices