import numba
from numba.typed import List
import numpy as np


@numba.njit(cache=True, locals={'p1': numba.int64,
                                'p2': numba.int64,
                                'p3': numba.int64,
                                'new_list': numba.int64[::1]})
def merge_lists(lst1, lst2):
    """
    Merge two ascending sequences into one ascending int64 array.

    Whenever the current heads of ``lst1`` and ``lst2`` are equal, the value
    is written once and both inputs advance, so a value present in both
    inputs is not duplicated. Values repeated within one input that have no
    partner in the other input are copied as-is.

    inputs:
        lst1, lst2: indexable sequences sorted in ascending order
            (e.g. 1-D integer arrays)

    outputs:
        new int64 array holding the merged values
    """
    n1 = len(lst1)
    n2 = len(lst2)
    p1, p2, p3 = numba.int64(0), numba.int64(0), numba.int64(0)
    # Upper bound on the output size: no values shared between the inputs.
    new_list = np.zeros(n1 + n2, dtype=np.int64)

    while p1 < n1 and p2 < n2:
        a = lst1[p1]
        b = lst2[p2]
        if b <= a:
            new_list[p3] = b
            p2 += 1
            if b == a:
                # Shared value: consume it from lst1 too so it appears once.
                p1 += 1
        else:
            # A plain `else` guarantees some pointer advances each iteration,
            # even for values that compare unordered (e.g. NaN).
            new_list[p3] = a
            p1 += 1
        p3 += 1

    # At most one input still has elements left; copy its tail element-wise
    # so the two inputs never have to share a slice type.
    while p1 < n1:
        new_list[p3] = lst1[p1]
        p1 += 1
        p3 += 1
    while p2 < n2:
        new_list[p3] = lst2[p2]
        p2 += 1
        p3 += 1

    return new_list[:p3]


@numba.njit(cache=True)
def ind2pair(ind, n):
    """
    Map a flat index to the pair it denotes in the row-major enumeration of
    all pairs (i, j) with 0 <= i < j < n:

    A list: [(0, 1), (0, 2), (0, 3), ..., (0, n - 1),
                     (1, 2), (1, 3), ..., (1, n - 1),
                                                 ...,
                                      (n - 2, n - 1)]

    This is the inverse of ``pair2ind``.

    inputs:
        ind: index of the pair, 0 <= ind < n * (n - 1) // 2
        n: number of items being paired (n >= 2)

    outputs:
        pair (i, j)

    raises:
        ValueError if ``ind`` is outside the valid range for ``n``
    """
    # Without this check an out-of-range index would silently map to a
    # non-existent pair (or leave the row undefined when n < 2).
    if n < 2 or ind < 0 or ind >= n * (n - 1) // 2:
        raise ValueError("ind2pair: ind out of range for n")

    row = 0
    # Row i holds the n - 1 - i pairs (i, i + 1) .. (i, n - 1).
    row_len = n - 1
    # Skip whole rows until ind falls inside the current one.
    while ind >= row_len:
        ind -= row_len
        row += 1
        row_len -= 1
    # ind is now the 0-based offset within the row; the row starts at j = row + 1.
    return (row, row + 1 + ind)


@numba.njit(cache=True)
def pair2ind(i, j, n):
    """
    Flat index of the pair (i, j) in the row-major enumeration of all pairs
    with 0 <= i < j < n:

        [(0, 1), (0, 2), (0, 3), ..., (0, n - 1),
                 (1, 2), (1, 3), ..., (1, n - 1),
                                             ...,
                                  (n - 2, n - 1)]

    inputs:
        i, j: the pair; expected to satisfy 0 <= i < j < n (not checked)
        n: number of items being paired

    outputs:
        index of the pair
    """
    # Rows 0 .. i-1 contain (n-1) + (n-2) + ... + (n-i) = i * (2n - i - 1) / 2
    # pairs; that product is always even, so the floor division is exact.
    rows_before = i * (2 * n - i - 1) // 2
    # (i, i + 1) is the first entry of row i, at offset 0.
    return rows_before + (j - i - 1)


@numba.njit(cache=True, locals={'p1': numba.int64,
                                'p2': numba.int64})
def sizeof_intersection(lst1, lst2):
    """
    Count the elements two ascending sequences have in common.

    Uses a two-pointer sweep. Each element of one input can match at most
    one equal element of the other, so repeated values are counted as many
    times as both inputs contain them.

    inputs:
        lst1, lst2: indexable sequences sorted in ascending order

    outputs:
        number of matched elements, or -1 as a sentinel when either input
        is empty
    """
    len1 = len(lst1)
    len2 = len(lst2)
    if len1 == 0 or len2 == 0:
        return numba.int64(-1)

    matches = numba.int64(0)
    p1 = numba.int64(0)
    p2 = numba.int64(0)
    while p1 < len1 and p2 < len2:
        x = lst1[p1]
        y = lst2[p2]
        if x == y:
            matches += 1
            p1 += 1
            p2 += 1
        elif x < y:
            # lst1's head is too small to appear in lst2 any more.
            p1 += 1
        else:
            p2 += 1

    return matches