import warp as wp
import jax
import jax.numpy as jnp
import numpy as np
from typing import Tuple

# import experimental feature
from warp.jax_experimental import jax_kernel

# wp.config.print_launches = True
# wp.config.mode = "debug"

import os, sys
# Put the parent of this file's directory at the front of sys.path (once).
# NOTE(review): presumably so that project packages such as `util`, imported in
# the __main__ block below, can be resolved -- confirm against the repo layout.
BASEDIR = os.path.dirname(os.path.dirname(__file__))
if BASEDIR not in sys.path:
    sys.path.insert(0, BASEDIR)
    


def get_hungarian_algorithm_by_dim(dim):
    """Build a Warp kernel that solves N x N assignment problems, with N = dim.

    All helper functions and the kernel are defined inside this factory because
    their Warp vector/matrix types (`wp.vec(N, ...)`, `wp.mat((N, N), ...)`) are
    parameterized by N.

    Args:
        dim: side length N of the square cost matrices.

    Returns:
        The `try_matching` Warp kernel. Its input is an array of N x N float32
        matrices and its outputs are two arrays of length-N int32 vectors that
        receive the row indices and column indices of the matched pairs.
    """
    N = dim

    # Sum of all N components of x (accumulated into x[0], which is returned).
    @wp.func
    def reduce_sumf(x: wp.vec(N, wp.float32))->wp.float32:
        for i in range(1,N):
            x[0] = x[0] + x[i]
        return x[0]

    # Indicator vector: 1.0 where a component equals `val`, 0.0 elsewhere.
    @wp.func
    def vec_equalf(vec: wp.vec(N, wp.float32), val: wp.float32):
        for i in range(N):
            if vec[i] == val:
                vec[i] = 1.0
            else:
                vec[i] = 0.0
        return vec

    # True when k occurs among the first `vec_max` entries of `vec`.
    @wp.func
    def k_in_vec(vec:wp.vec(N, wp.int32), k:wp.int32, vec_max:wp.int32)->wp.bool:
        for i in range(vec_max):
            if vec[i] == k:
                return True
        return False

    # Complement of an index list: collects, in increasing order, every index
    # in [0, N) that is NOT among the first `vec_max` entries of `vec`.
    # Returns the (zero-initialized) result vector and the number of valid entries.
    @wp.func
    def collect_not_in_vec(vec:wp.vec(N, wp.int32), vec_max:wp.int32)->Tuple[wp.vec(N, wp.int32), wp.int32]:
        res = vec*0
        collected_cnt = wp.int32(0)
        for k in range(N):
            in_check = k_in_vec(vec, k, vec_max)
            if in_check:
                continue
            res[collected_cnt] = k
            collected_cnt+=1
        return res, collected_cnt

    # Greedy matching on a 0/1 mask (the kernel passes 1.0 where the cost is zero).
    # Each round picks the row with the fewest remaining ones (lowest index wins
    # ties because the comparison is strict), takes the first column holding a
    # one in that row, records the pair, and clears that row and column.
    # Returns (row indices, column indices, number of pairs found); only the
    # first `num_matched_pair` entries of the two vectors are meaningful.
    @wp.func
    def greedy_matching_with_zero(zero_mask_matrix: wp.mat((N,N), wp.float32))->Tuple[wp.vec(N, wp.int32), wp.vec(N, wp.int32), wp.int32]:
        # working copy of the mask; rows/columns are cleared as pairs are chosen
        zero_mask_matrix_masked = zero_mask_matrix

        # perform greedy matching
        num_matched_pair = wp.int32(0)
        matched_pair_i = wp.vec(0, length=N, dtype=wp.int32)
        matched_pair_j = wp.vec(0, length=N, dtype=wp.int32)
        for tmp_ in range(N):
            # find the row with the smallest non-zero count of ones;
            # 1e6 is the "nothing found yet" sentinel
            selected_row_idx = wp.int32(0)
            min_zero_val = wp.float32(1e6)
            for tmp_k in range(N):
                nzero_in_row = reduce_sumf(zero_mask_matrix_masked[tmp_k])
                if nzero_in_row!=0.0 and nzero_in_row < min_zero_val:
                    min_zero_val = nzero_in_row
                    selected_row_idx = tmp_k
            # sentinel untouched -> no row has a one left, matching is finished
            if min_zero_val == wp.float32(1e6):
                break

            # select the first column holding a one in the chosen row
            selected_col_idx = int(0)
            for z in range(N):
                if zero_mask_matrix_masked[selected_row_idx, z]==1.0:
                    selected_col_idx = z
                    break

            # clear the selected row and column so they cannot be chosen again
            for l in range(N):
                zero_mask_matrix_masked[l, selected_col_idx] = 0.0
                zero_mask_matrix_masked[selected_row_idx, l] = 0.0

            matched_pair_i[num_matched_pair] = selected_row_idx
            matched_pair_j[num_matched_pair] = selected_col_idx
            num_matched_pair += 1

        return matched_pair_i, matched_pair_j, num_matched_pair


    # Line-cover step. Starts from the rows that have no matched pair and then
    # alternates, for at most N rounds or until nothing changes:
    #   * mark every column that holds a one in a currently collected row;
    #   * add every matched row whose matched column has been marked.
    # Returns (rows NOT collected, their count, marked columns, their count);
    # `adjust_matrix` below treats these as the covered rows and columns.
    @wp.func
    def min_line_cover(matched_pair_i: wp.vec(N, wp.int32), matched_pair_j: wp.vec(N, wp.int32), num_matched_pair: wp.int32, zero_mask_matrix: wp.mat((N,N), wp.float32)):
        # rows without a matched pair
        non_matched_i_vec, non_matched_i_cnt = collect_not_in_vec(matched_pair_i, num_matched_pair)

        marked_cols = wp.vec(0, length=N, dtype=wp.int32)
        marked_col_cnt = wp.int32(0)

        check_switch = wp.bool(True)
        # bounded loop used in place of `while check_switch:`
        for _ in range(N):
            check_switch = wp.bool(False)

            # mark columns that hold a one in any collected row
            for ki in range(non_matched_i_cnt):
                row_array = zero_mask_matrix[non_matched_i_vec[ki]]
                for t in range(N):
                    if row_array[t] == 1.0 and not k_in_vec(marked_cols, t, marked_col_cnt):
                        #step 2-2-3
                        marked_cols[marked_col_cnt] = t
                        marked_col_cnt += 1
                        check_switch = wp.bool(True)

            # collect matched rows whose matched column is now marked
            for h in range(num_matched_pair):
                if not k_in_vec(non_matched_i_vec, matched_pair_i[h], non_matched_i_cnt) and k_in_vec(marked_cols, matched_pair_j[h], marked_col_cnt):
                    non_matched_i_vec[non_matched_i_cnt] = matched_pair_i[h]
                    non_matched_i_cnt += 1
                    check_switch = wp.bool(True)
            
            if not check_switch:
                break
        
        # complement of the collected rows
        updated_matched_i_vec, updated_matched_i_cnt = collect_not_in_vec(non_matched_i_vec, non_matched_i_cnt)

        return updated_matched_i_vec, updated_matched_i_cnt, marked_cols, marked_col_cnt


    # Cost-matrix adjustment. With "covered rows" = updated_matched_i_vec and
    # "covered columns" = marked_cols:
    #   * find the minimum over cells lying in no covered row and no covered column;
    #   * subtract it from exactly those cells;
    #   * add it to cells lying in both a covered row and a covered column.
    # Returns the adjusted matrix.
    @wp.func
    def adjust_matrix(updated_matched_i_vec: wp.vec(N, wp.int32), updated_matched_i_cnt: wp.int32, marked_cols: wp.vec(N, wp.int32), marked_col_cnt: wp.int32, cost_matrix: wp.mat((N,N), wp.float32)):
        # find min value among cells in non-covered rows and non-covered columns
        # (1e6 is the initial upper bound)
        min_val = wp.float32(1e6)
        for row in range(N):
            if not k_in_vec(updated_matched_i_vec, row, updated_matched_i_cnt):
                for c in range(N):
                    if not k_in_vec(marked_cols, c, marked_col_cnt):
                        cur_cost = cost_matrix[row][c]
                        if cur_cost < min_val:
                            min_val = cur_cost
        
        for row in range(N):
            if not k_in_vec(updated_matched_i_vec, row, updated_matched_i_cnt):
                # subtract min value from uncovered cells
                for c in range(N):
                    if not k_in_vec(marked_cols, c, marked_col_cnt):
                        cost_matrix[row,c] = cost_matrix[row,c] - min_val
            else:
                # add min value to cells at the intersection of covered lines
                for c in range(N):
                    if k_in_vec(marked_cols, c, marked_col_cnt):
                        cost_matrix[row,c] = cost_matrix[row,c] + min_val
                        
        return cost_matrix

    @wp.kernel
    def try_matching(
        # inputs
        zero_grounded_cost_matrix: wp.array(dtype=wp.mat((N,N), wp.float32)),
        # outputs
        matched_pair_i_out: wp.array(dtype=wp.vec(N, dtype=wp.int32)),
        matched_pair_j_out: wp.array(dtype=wp.vec(N, dtype=wp.int32)),
    ):
        '''
        Solve one N x N assignment problem per thread.

        Thread `i` reads `zero_grounded_cost_matrix[i]` and repeats
        "mask zeros -> greedy matching -> line cover -> adjust matrix" until the
        greedy matching yields N pairs, then writes the pairs to
        `matched_pair_i_out[i]` / `matched_pair_j_out[i]`.

        Author's notes from experimenting with Warp:

        # output should not be array(bool)

        # equivalent ways to build a local vector
        local_vec = wp.vec2(0.0,0.0)
        local_vec = wp.vector(0.0, 0.0, length=2, dtype=wp.float32)
        local_vec = wp.vec(0.0, 0.0, length=2, dtype=wp.float32)
        '''

        i = wp.tid()
        cost_matrix = zero_grounded_cost_matrix[i]

        # NOTE(review): there is no iteration cap; the loop only exits once a
        # full matching of N pairs is found.
        while True:
            # build the 0/1 mask: 1.0 where the current cost is exactly zero
            # NOTE(review): presumably this assignment copies the matrix by value,
            # since cost_matrix is passed to adjust_matrix afterwards -- confirm.
            zero_mask_matrix = cost_matrix
            for j in range(N):
                zero_mask_matrix[j] = vec_equalf(zero_mask_matrix[j], 0.0)
            
            matched_pair_i, matched_pair_j, num_matched_pair = greedy_matching_with_zero(zero_mask_matrix)

            # complete matching found: write it out and stop
            if num_matched_pair == N:
                for k in range(N):
                    matched_pair_i_out[i][k] = matched_pair_i[k]
                    matched_pair_j_out[i][k] = matched_pair_j[k]
                break
            
            # otherwise compute the covered rows / columns ...
            updated_matched_i_vec, updated_matched_i_cnt, marked_cols, marked_col_cnt = min_line_cover(matched_pair_i, matched_pair_j, num_matched_pair, zero_mask_matrix)

            # ... and shift costs so that new zeros appear for the next round
            cost_matrix = adjust_matrix(updated_matched_i_vec, updated_matched_i_cnt, marked_cols, marked_col_cnt, cost_matrix)

    return try_matching


@jax.jit
def hungarian_warp(cost_mat):
    """Run the Warp Hungarian kernel on `cost_mat` from JAX.

    The matrix is first reduced by subtracting each row's minimum and then each
    column's minimum, and gradients are stopped on both sides of the kernel call.
    The kernel is built for N = cost_mat.shape[-2].

    Returns the kernel's two index outputs stacked along a new last axis, with
    the pairs ordered by their first (row) index along axis -2.
    NOTE(review): presumably cost_mat is (B, N, N) and the result (B, N, 2) -- confirm.
    """
    # row reduction followed by column reduction
    row_reduced = cost_mat - jnp.min(cost_mat, axis=-1, keepdims=True)
    reduced = row_reduced - jnp.min(row_reduced, axis=-2, keepdims=True)
    reduced = jax.lax.stop_gradient(reduced)

    # kernel specialized for this matrix size
    matching_kernel = get_hungarian_algorithm_by_dim(reduced.shape[-2])
    matched = jax.lax.stop_gradient(jax_kernel(matching_kernel)(reduced))

    pairs = jnp.stack(matched, axis=-1)

    # reorder pairs so that the row index is ascending
    order = jnp.argsort(pairs[..., :1], axis=-2)
    return jnp.take_along_axis(pairs, order, axis=-2)





#### hungarian numpy

def min_zero_row(zero_mat, mark_zero):
    """Select one remaining zero and cross out its row and column.

    The line (row or column) containing the fewest remaining zeros is found;
    rows win ties. The first zero on that line is appended to `mark_zero` as a
    `(row, col)` tuple, and the zero's whole row and column are set to False
    in `zero_mat`.

    Args:
        zero_mat: 2-D boolean array, True where a zero is still available.
            Modified in place.
        mark_zero: list that receives the selected `(row, col)` position.

    Returns:
        None. If `zero_mat` contains no True entry, nothing is changed.
    """

    def fewest_positive(counts):
        # First index with the smallest strictly positive count, or -1 if none.
        best_count, best_idx = None, -1
        for idx, cnt in enumerate(counts):
            if cnt > 0 and (best_count is None or cnt < best_count):
                best_count, best_idx = cnt, idx
        return best_count, best_idx

    # Zeros per row / per column, computed once. The column scan covers all
    # columns (shape[1]), so non-square matrices are handled as well.
    row_count, row_idx = fewest_positive(np.count_nonzero(zero_mat, axis=1))
    col_count, col_idx = fewest_positive(np.count_nonzero(zero_mat, axis=0))

    if row_idx == -1:
        # no zero left anywhere: nothing to mark
        return

    if row_count <= col_count:
        # take the first zero in the sparsest row
        zero_index = np.flatnonzero(zero_mat[row_idx])[0]
        mark_zero.append((row_idx, zero_index))
        zero_mat[row_idx, :] = False
        zero_mat[:, zero_index] = False
    else:
        # take the first zero in the sparsest column
        zero_index = np.flatnonzero(zero_mat[:, col_idx])[0]
        mark_zero.append((zero_index, col_idx))
        zero_mat[zero_index, :] = False
        zero_mat[:, col_idx] = False

def mark_matrix(mat):
    """Find a set of independent zeros in `mat` and the lines covering its zeros.

    Returns a tuple `(marked_zero, marked_rows, marked_cols)`:
    the selected zero positions as `(row, col)` tuples, the covered row
    indices, and the covered column indices.
    """
    # Boolean view of the matrix: True where the entry is zero.
    zero_mask = (mat == 0)
    remaining = zero_mask.copy()

    # Pick zeros one at a time until none is left in the working copy.
    assignments = []
    while remaining.any():
        min_zero_row(remaining, assignments)

    assigned_rows = [position[0] for position in assignments]

    # Step 2-2-1: rows that received no assignment.
    open_rows = list(set(range(mat.shape[0])) - set(assigned_rows))

    covered_cols = []
    changed = True
    while changed:
        changed = False

        # Steps 2-2-2 / 2-2-3: mark columns that hold a zero in an open row.
        for row_idx in open_rows:
            for col_idx, is_zero in enumerate(zero_mask[row_idx, :]):
                if is_zero and col_idx not in covered_cols:
                    covered_cols.append(col_idx)
                    changed = True

        # Steps 2-2-4 / 2-2-5: an assigned row whose zero sits in a marked
        # column becomes an open row as well.
        for row_idx, col_idx in assignments:
            if row_idx not in open_rows and col_idx in covered_cols:
                open_rows.append(row_idx)
                changed = True

    # Step 2-2-6: the covered rows are the ones that never became open.
    covered_rows = list(set(range(mat.shape[0])) - set(open_rows))

    return (assignments, covered_rows, covered_cols)

def adjust_matrix(mat, cover_rows, cover_cols):
    """Shift costs so that new zeros appear outside the covering lines.

    The smallest entry lying in no covered row and no covered column is
    subtracted from every such entry and added to every entry at the
    intersection of a covered row and a covered column. `mat` is modified in
    place and also returned.
    """
    # Step 4-1: positions not covered by any line, and their minimum value.
    uncovered = [
        (r, c)
        for r in range(len(mat))
        if r not in cover_rows
        for c in range(len(mat[r]))
        if c not in cover_cols
    ]
    smallest = min([mat[r][c] for r, c in uncovered])

    # Step 4-2: lower every uncovered entry by the minimum.
    for r, c in uncovered:
        mat[r, c] = mat[r, c] - smallest

    # Step 4-3: raise every doubly covered entry by the minimum.
    for r in cover_rows:
        for c in cover_cols:
            mat[r, c] = mat[r, c] + smallest

    return mat

def hungarian_algorithm_np(mat): 
    """Solve the assignment problem for a cost matrix with NumPy.

    Args:
        mat: 2-D NumPy array of costs. It is not modified; the algorithm works
            on a copy.

    Returns:
        Array of shape `(k, 2)` holding the selected `(row, col)` positions as
        produced by `mark_matrix`. For an empty (0 x 0) input the result has
        shape `(0, 2)`.
    """
    dim = mat.shape[0]
    # Work on a copy so the caller's matrix is left untouched.
    cur_mat = np.array(mat)

    #Step 1 - Every column and every row subtract its internal minimum
    for row_num in range(cur_mat.shape[0]): 
        cur_mat[row_num] = cur_mat[row_num] - np.min(cur_mat[row_num])

    for col_num in range(cur_mat.shape[1]): 
        cur_mat[:,col_num] = cur_mat[:,col_num] - np.min(cur_mat[:,col_num])

    # Defined up front so an empty matrix (loop never entered) still returns.
    ans_pos = []
    zero_count = 0
    while zero_count < dim:
        #Step 2 & 3 - select zeros and count the lines needed to cover all zeros
        ans_pos, marked_rows, marked_cols = mark_matrix(cur_mat)
        zero_count = len(marked_rows) + len(marked_cols)

        # Fewer than `dim` covering lines: create new zeros and try again.
        if zero_count < dim:
            cur_mat = adjust_matrix(cur_mat, marked_rows, marked_cols)

    # reshape is a no-op for non-empty results and gives (0, 2) for empty ones
    return np.array(ans_pos).reshape(-1, 2)

###### hungarian numpy





import ctypes
import os
# Turn off JAX's memory preallocation.
# NOTE(review): jax is already imported at the top of this file; presumably the
# variable is only read when the backend is first initialized -- confirm that
# setting it here still takes effect.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax.extend as jex

# Load the shared library that exports the `HungarianMatching` symbol.
# NOTE(review): the path is relative, so it is resolved against the current
# working directory at import time rather than against this file's location.
hungarian_lib = ctypes.cdll.LoadLibrary("bindings/hungarian_ffi/libhungarian_kernel.so")

# Register the library symbol as the FFI target "hungarian_matching" for the
# CUDA platform; `hungarian_jax_cuda` below calls it by that name.
jex.ffi.register_ffi_target(
    "hungarian_matching",
    jex.ffi.pycapsule(hungarian_lib.HungarianMatching),
    platform="CUDA"
)

def hungarian_jax_cuda(cost_matrices):
    """Match rows to columns for a batch of square cost matrices via the CUDA FFI kernel.

    Args:
        cost_matrices: array-like of shape (B, N, N); converted to float32.

    Returns:
        int32 array of shape (B, N, 2) holding `(i, j)` pairs as returned by the
        "hungarian_matching" FFI target, ordered by `i` within each batch item.

    Raises:
        AssertionError: if the input is not 3-D or the matrices are not square.
    """
    cost_matrices = jnp.asarray(cost_matrices, dtype=jnp.float32)
    assert cost_matrices.ndim == 3
    batch, n_rows, n_cols = cost_matrices.shape
    assert n_rows == n_cols, "Cost matrices must be square"

    # Both outputs (row indices and column indices) are (B, N) int32 arrays.
    result_types = tuple(
        jax.ShapeDtypeStruct((batch, n_rows), jnp.int32) for _ in range(2)
    )

    # No gradient flows through the matching.
    pair_i, pair_j = jex.ffi.ffi_call(
        "hungarian_matching",
        result_types,
        jax.lax.stop_gradient(cost_matrices),
        vectorized=False,  # Batching is handled inside the kernel
    )

    # Combine into (B, N, 2) and sort each batch item's pairs by row index.
    pairs = jnp.stack([pair_i, pair_j], axis=-1)
    order_by_i = jnp.argsort(pairs[..., 0:1], axis=-2)
    return jnp.take_along_axis(pairs, order_by_i, axis=-2)



if __name__ == '__main__':
    # Benchmark driver: for many random seeds, time the CUDA FFI matcher against
    # the matcher from util.bp_matching_util and print the total cost each one
    # obtains on the first matrix of the batch.
    import util.bp_matching_util as skh
    import time

    N = 50
    B = 32

    # jit-compiled entry points of the two methods being compared
    hungarian_jax_cuda_jit = jax.jit(hungarian_jax_cuda)
    hungarian_jax_sp_jit = jax.jit(skh.bipartite_matching_sp)

    def extact_cost(idx_pair):
        # Total cost of the pairs of batch item 0, looked up in the current
        # `cost_matrix` (rebound on every loop iteration below).
        return sum(cost_matrix[0, i, j] for i, j in idx_pair[0])

    for seed in range(10000):
        rng = np.random.default_rng(seed)
        cost_matrix = rng.integers(0, 100, size=(B, N, N)).astype(np.float32)

        # CUDA FFI implementation
        time_start = time.time()
        x = jax.block_until_ready(hungarian_jax_cuda_jit(cost_matrix))
        time_end = time.time()

        # reference implementation
        time_start_sp = time.time()
        x_sp = jax.block_until_ready(hungarian_jax_sp_jit(cost_matrix))
        time_end_sp = time.time()

        print(f"jax_cuda: {time_end-time_start:.6f}, jax_sp: {time_end_sp-time_start_sp:.6f}")
        print(f"jax_cuda: {extact_cost(x)}, jax_sp: {extact_cost(x_sp)}")