from random import Random
import math

#import sys
#from Test.core import *
if '_Array' not in dir():
    from Compiler.types import *
    from Compiler.types import _secret
    from Compiler.library import *
    from Compiler.program import Program
    _Array = Array

# Pre-computed comparison results, consumed in order by predefined_comparator.
SORT_BITS = list()
# Fixed-seed PRNG used by random_perm at compile time -- NOT secure.
insecure_random = Random(0)

def predefined_comparator(x, y):
    """ Return the next pre-computed comparison bit, ignoring x and y.

    Assumes SORT_BITS is populated with the required sorting network bits.
    The iterator over SORT_BITS is created lazily on the first call and kept
    in the function attribute ``sort_bits_iter``.
    """
    bits = predefined_comparator.sort_bits_iter
    if bits is None:
        bits = predefined_comparator.sort_bits_iter = iter(SORT_BITS)
    return next(bits)
predefined_comparator.sort_bits_iter = None

def list_comparator(x, y):
    """ Order two lists by their first element, i.e. ``x[0] < y[0]``. """
    key_x, key_y = x[0], y[0]
    return key_x < key_y

def normal_comparator(x, y):
    """ Compare two values directly with their own ``<`` operator. """
    is_less = x < y
    return is_less

def bitwise_list_comparator(x, y):
    """ Bitwise less-than on the first list element.

    Computes ``(1 - x[0]) * y[0]``: for 0/1 values this is 1 exactly when
    x[0] == 0 and y[0] == 1.
    """
    complement = 1 - x[0]
    return complement * y[0]

def bitwise_comparator(x, y):
    """ Bitwise less-than: ``(1 - x) * y``, i.e. 1 iff x == 0 and y == 1 for bits. """
    return (1 - x) * y

def cond_swap_bit(x, y, b):
    """ Return (x, y) swapped if b == 1, unchanged if b == 0.

    The swap is arithmetic (no branching on b): with d = (x - y) * b the
    result is (x - d, y + d). Lists are swapped element-wise.
    A None operand is treated as padding: the other value is returned first
    and None second, regardless of b.
    """
    if x is None:
        return y, None
    if y is None:
        return x, None
    if not isinstance(x, list):
        delta = (x - y) * b
        return x - delta, y + delta
    deltas = [(left - right) * b for left, right in zip(x, y)]
    first = [left - d for left, d in zip(x, deltas)]
    second = [right + d for right, d in zip(y, deltas)]
    return first, second

def cond_swap(x, y, comp):
    """ Compare-exchange: return (x, y) ordered according to ``comp``.

    ``comp(x, y)`` must return 1 when x should stay first. Otherwise the
    pair is swapped via cond_swap_bit. None acts as padding that always ends
    up second, and no comparison is made in that case.
    """
    if x is None or y is None:
        return (x if y is None else y), None
    keep_order = comp(x, y)
    return cond_swap_bit(x, y, 1 - keep_order)

def odd_even_merge(a, comp):
    """ Batcher odd-even merge of the list ``a``, in place.

    The two halves a[:len(a)//2] and a[len(a)//2:] must each already be
    sorted (as odd_even_merge_sort guarantees). Afterwards the whole list is
    sorted. Elements are compare-exchanged with cond_swap, so None entries
    act as padding that sorts last.

    :param a: list whose length is a power of two (0 and 1 are no-ops)
    :param comp: comparator forwarded to cond_swap
    :raises Exception: if len(a) is not a power of two
    """
    n = len(a)
    if n & (n - 1) != 0:
        raise Exception('Length must be a power of 2')
    if n <= 1:
        # Nothing to merge. Without the n == 0 case an empty list would
        # split into two empty lists and recurse forever.
        return
    if n == 2:
        a[0], a[1] = cond_swap(a[0], a[1], comp)
        return
    # Each half is sorted, so the even- and odd-indexed subsequences also
    # consist of two sorted halves and can be merged recursively.
    even = a[::2]
    odd = a[1::2]
    odd_even_merge(even, comp)
    odd_even_merge(odd, comp)
    a[0] = even[0]
    for i in range(1, n // 2):
        a[2*i-1], a[2*i] = cond_swap(odd[i-1], even[i], comp)
    a[-1] = odd[-1]

def odd_even_merge_sort(a, comp=bitwise_comparator):
    """ Sort the list ``a`` in place with Batcher's odd-even merge sort.

    :param a: list whose length is a power of two (0 and 1 are no-ops)
    :param comp: comparator forwarded to cond_swap, by default the
        bitwise comparator for 0/1 values
    :raises CompilerError: if len(a) is not a power of two
    """
    n = len(a)
    if n <= 1:
        # Already sorted. An empty list used to recurse forever because
        # 0 % 2 == 0 split it into two empty halves.
        return
    # Check the whole length once instead of waiting for an odd half to
    # appear deep in the recursion (same error, same message).
    if n & (n - 1) != 0:
        raise CompilerError('Length of list must be power of two')
    lower = a[:n // 2]
    upper = a[n // 2:]
    odd_even_merge_sort(lower, comp)
    odd_even_merge_sort(upper, comp)
    a[:] = lower + upper
    odd_even_merge(a, comp)

def merge(a, b, comp):
    """ General length merge (pads to power of 2).

    Merges two sorted lists of arbitrary lengths and returns a new sorted
    list. Each input is padded with None (which sorts last, see cond_swap)
    up to a common power-of-two length so that odd_even_merge can run on the
    concatenation. The padding is removed afterwards. The inputs are not
    modified.

    :param a: sorted list
    :param b: sorted list
    :param comp: comparator forwarded to odd_even_merge
    :returns: list with the elements of a and b in sorted order
    """
    a = list(a)
    b = list(b)
    # Pad each list to a power of two; the padding keeps it sorted.
    for part in (a, b):
        while len(part) & (len(part) - 1) != 0:
            part.append(None)
    # Both halves need the same length (the larger of two powers of two).
    # Previously b was extended by len(b) - len(b) == 0 when it was shorter.
    size = max(len(a), len(b))
    a += [None] * (size - len(a))
    b += [None] * (size - len(b))
    t = a + b
    if not t:
        return t
    odd_even_merge(t, comp)
    # Identity test: '== None' on secret values would invoke their
    # overloaded comparison, and list.remove made this quadratic.
    return [v for v in t if v is not None]

def sort(a, comp):
    """ Pads to power of 2, sorts, removes padding.

    Sorts the list ``a`` in place for any length. It is padded with None
    (which sorts last, see cond_swap) up to the next power of two, sorted
    with odd_even_merge_sort, and truncated back to its original length.

    :param a: list to sort in place
    :param comp: comparator forwarded to odd_even_merge_sort
    """
    length = len(a)
    if length <= 1:
        # Already sorted; an empty list would make odd_even_merge_sort
        # recurse forever.
        return
    while len(a) & (len(a) - 1) != 0:
        a.append(None)
    odd_even_merge_sort(a, comp)
    del a[length:]

def recursive_merge(a, comp):
    """ Merge a list of sorted lists (initially sorted by size), in place.

    Repeatedly merges the two shortest lists and re-inserts the result so
    that ``a`` stays ordered by length. On return ``a`` holds a single
    merged list. An empty ``a`` or one holding a single list is left
    unchanged.

    Implemented as a loop so the Python recursion limit does not cap the
    number of lists.

    :param a: list of sorted lists, ordered by increasing length
    :param comp: comparator forwarded to merge
    """
    while len(a) > 1:
        # merge the two smallest lists
        merged = merge(a[0], a[1], comp)
        del a[:2]
        # insert before the first list at least as long, keeping the order
        for i, other in enumerate(a):
            if len(other) >= len(merged):
                a.insert(i, merged)
                break
        else:
            a.append(merged)

def random_perm(n):
    """ Return a random permutation of range(n) as a list.

    Fisher-Yates shuffle driven by the module-level, fixed-seed
    ``insecure_random``.

    WARNING: randomness fixed at compile-time, this is NOT secure.
    Refuses to run unless the program was compiled with --insecure.
    """
    if not Program.prog.options.insecure:
        raise CompilerError('no secure implementation of Waksman permution, '
                            'use --insecure to activate')
    perm = list(range(n))
    i = n - 1
    while i > 0:
        j = insecure_random.randint(0, i)
        perm[i], perm[j] = perm[j], perm[i]
        i -= 1
    return perm

def inverse(perm):
    """ Return the inverse permutation ``inv`` with ``inv[perm[i]] == i``. """
    result = [None] * len(perm)
    for position, target in zip(range(len(perm)), perm):
        result[target] = position
    return result

def configure_waksman(perm):
    """ Compute the switch settings of a Waksman permutation network.

    :param perm: list, a permutation of range(n) where perm[i] is the output
        position that input i must be routed to. n is presumably a power of
        two; this is only checked indirectly, via the asserts on the
        sub-permutations.
    :returns: list of configuration levels (log2(n) of them for a power of
        two). Level 0 is ``I + O``: the n/2 input-switch bits followed by the
        n/2 output-switch bits of this network. Each later level concatenates
        the matching levels of the upper (p0) and lower (p1) sub-networks.
        For n == 2 the single level is ``(perm[0], perm[0])``: the one switch
        bit (1 = swap), duplicated, apparently so that every level has n
        entries.

    NOTE(review): n == 1 (or an empty perm) does not terminate:
    configure_waksman([]) keeps calling itself until RecursionError.
    """
    n = len(perm)
    if n == 2:
        return [(perm[0], perm[0])]
    I = [None] * (n//2)   # input switch bits (1 = swap the pair)
    O = [None] * (n//2)   # output switch bits
    p0 = [None] * (n//2)  # permutation for the upper sub-network
    p1 = [None] * (n//2)  # permutation for the lower sub-network
    inv_perm = [0] * n

    # inv_perm[j] = input that must end up at output j
    for i, p in enumerate(perm):
        inv_perm[p] = i

    # Route one cycle at a time, starting from the even port of the first
    # output switch that is still unset.
    while True:
        try:
            j = 2 * O.index(None)
        except ValueError:
            break
        #print 'j =', j
        O[j//2] = 0
        via = 0  # NOTE: unused; only referenced in the commented-out code
        j0 = j
        while True:
            #print '    I[%d] = %d' % (inv_perm[j]/2, ((inv_perm[j] % 2) + via) % 2)

            # Input i must reach output j: send it through the upper
            # sub-network, from input switch i//2 to output switch j//2.
            i = inv_perm[j]
            #print '    p0[%d] = %d' % (inv_perm[j]/2, j/2)
            p0[i//2] = j//2

            # Swap at the input iff i is odd (moves it to the upper half);
            # swap at the output iff j is odd.
            I[i//2] = i % 2
            O[j//2] = j % 2
            #print '    O[%d] = %d' % (j/2, j % 2)
            # i becomes the other input on the same input switch ...
            if i % 2 == 1:
                i -= 1
            else:
                i += 1
            #i, via = set_swapper(I, j, via, inv_perm)

            #print '    O[%d] = %d' % (perm[i]/2, ((perm[i] % 2) + via ) % 2)
            # ... which is forced through the lower sub-network. j becomes
            # the other port of that input's output switch, which must then
            # be served by the upper sub-network on the next iteration.
            j = perm[i]
            #O[j/2] = j % 2
            if j % 2 == 1:
                j -= 1
            else:
                j += 1
            #j, via = set_swapper(O, i, via, perm)
            #print '    p1[%d] = %d' % (i/2, perm[i]/2)
            p1[i//2] = perm[i]//2

            #print '    i = %d, j =  %d' %(i,j)
            # back at the cycle's starting output: cycle fully routed
            if j == j0:
                break
        if None not in p0 and None not in p1:
            break

    # both sub-networks must receive proper permutations of range(n//2)
    assert sorted(p0) == list(range(n//2))
    assert sorted(p1) == list(range(n//2))
    p0_config = configure_waksman(p0)
    p1_config = configure_waksman(p1)
    return [I + O] + [a+b for a,b in zip(p0_config, p1_config)]

def waksman(a, config, depth=0, start=0, reverse=False):
    """ Apply a Waksman network to the list ``a`` in place (recursive form).

    ``config`` is a list of log_2(n) configuration levels, as built by
    configure_waksman. ``depth`` selects the level for this (sub-)network and
    ``start`` its offset within that level. Entries [start, start + n/2)
    drive the first switch column and [start + n/2, start + n) the last one;
    ``reverse`` swaps these roles. Each switch is applied with cond_swap_bit.
    """
    half = len(a) // 2
    if len(a) == 2:
        a[0], a[1] = cond_swap_bit(a[0], a[1], config[depth][start])
        return

    if reverse:
        first_off, last_off = start + half, start
    else:
        first_off, last_off = start, start + half

    upper = [0] * half
    lower = [0] * half
    # first column: split each adjacent pair into the two sub-networks
    for k in range(half):
        upper[k], lower[k] = cond_swap_bit(a[2*k], a[2*k+1],
                                           config[depth][k + first_off])

    waksman(upper, config, depth + 1, start, reverse)
    waksman(lower, config, depth + 1, start + half, reverse)

    # last column: recombine the sub-network outputs pairwise
    for k in range(half):
        a[2*k], a[2*k+1] = cond_swap_bit(upper[k], lower[k],
                                         config[depth][k + last_off])


# Cache of compiled per-column functions keyed by
# (n, reg_type, inwards, reverse), so each variant is compiled only once.
WAKSMAN_FUNCTIONS = {}
def iter_waksman(a, config, reverse=False):
    """ Iterative Waksman algorithm, compilable for large inputs. Input
    must be an Array.

    Applies the switching network described by ``config`` to ``a`` in
    place. Unlike waksman(), which unrolls the recursion at compile time,
    each column of switches is one iteration of a compiled for_range loop,
    and the per-column work is a cached function_block.

    :param a: Array of length n. n is presumably a power of two; nothing is
        checked here, only log2(n) is taken below.
    :param config: Array of flattened configuration levels, n entries per
        level, level d starting at ``config.address + d*n`` (the layout
        written by config_shuffle)
    :param reverse: take each switch bit from the other half of the block's
        slice of the level, mirroring waksman(reverse=True)
    """
    n = len(a)
    #if not isinstance(a, Array):
    #    raise CompilerError('Input must be an Array')

    # Loop state, updated with .write() inside the compiled loops below.
    depth = MemValue(0)    # current configuration level
    nblocks = MemValue(1)  # number of sub-network blocks in this column
    size = MemValue(0)     # half-width of each block (block width 2*size)
    a2 = Array(n, a[0].reg_type)  # scratch output of each column
    #config_array = Array(n, a[0].reg_type)
    #reverse = (int(reverse))

    def create_round_fn(n, reg_type, inwards):
        """ Return the cached compiled function applying one switch column.

        inwards=1: each block's adjacent pairs (2j, 2j+1) are switched into
        positions (j, j+size); inwards=0 is the inverse wiring, used on the
        way back out of the network.
        """
        if (n, reg_type, inwards, reverse) in WAKSMAN_FUNCTIONS:
            return WAKSMAN_FUNCTIONS[(n, reg_type, inwards, reverse)]
        
        def do_round(size, config_address, a_address, a2_address):
            # Arrays presumably viewing the memory at the passed addresses
            A = Array(n, reg_type, a_address)
            A2 = Array(n, reg_type, a2_address)
            C = Array(n, reg_type, config_address)
            outwards = 1 - inwards
            
            sizeval = size
            #for k in range(n//2):
            @for_range_parallel(200, n//2)
            def f(k):
                # switch k: block i, offset j within the block
                j = cint(k) % sizeval
                i = (cint(k) - j)//sizeval
                base = 2*i*sizeval

                # inwards: read adjacent pair, write to both block halves;
                # outwards: the reverse
                in1, in2 = (base+j+j*inwards), (base+j+j*inwards+1*inwards+sizeval*outwards)
                out1, out2 = (base+j+j*outwards), (base+j+j*outwards+1*outwards+sizeval*inwards)
                
                # switch bit from the first or second half of this block's
                # slice of the level, depending on direction and reverse
                if inwards:
                    if reverse:
                        c = C[base + j + sizeval]
                    else:
                        c = C[base + j]
                else:
                    if reverse:
                        c = C[base + j]
                    else:
                        c = C[base + j + sizeval]
                A2[out1], A2[out2] = cond_swap_bit(A[in1], A[in2], c)

        fn = function_block(do_round)
        WAKSMAN_FUNCTIONS[(n, reg_type, inwards, reverse)] = fn
        return fn
    
    # dispatch to the cached compiled column for this n/reg_type/direction
    do_round = lambda size, ca, aa, aa2, inwards: \
               create_round_fn(n, a[0].reg_type, inwards)(size, ca, aa, aa2)

    # NOTE(review): int(math.log(n, 2)) goes through float arithmetic;
    # n.bit_length() - 1 would give log2(n) exactly for a power of two.
    logn = int(math.log(n,2))

    # going into middle of network
    # logn columns: depth 0..logn-1, block count doubling each time
    @for_range(logn)
    def f(i):
        size.write(n//(2*nblocks))
        conf_address = MemValue(config.address + depth.read()*n)
        do_round(size, conf_address, a.address, a2.address, 1)

        # copy the column's output back into a
        @for_range(n)
        def _(i):
            a[i] = a2[i]

        nblocks.write(nblocks*2)
        depth.write(depth+1)

    # The innermost column (depth logn-1) is applied only once, so the way
    # back out starts two levels above where the inward loop ended.
    nblocks.write(nblocks//4)
    depth.write(depth-2)

    # and back out
    # logn-1 columns: depth logn-2 down to 0, block count halving
    @for_range(logn-1)
    def f(i):
        size.write(n//(2*nblocks))
        conf_address = MemValue(config.address + depth.read()*n)
        do_round(size, conf_address, a.address, a2.address, 0)

        @for_range(n)
        def _(i):
            a[i] = a2[i]

        nblocks.write(nblocks//2)
        depth.write(depth-1)

    # Legacy Python-loop version of the two passes above (unused):
    ## going into middle of network
    #while nblocks < n:
    #    #for i in range(n):
    #    #    config_array[i] = config[depth][i].read()
#
    #    size.write(n/(2*nblocks))
    #    conf_address = config.address + depth*n
    #    do_round_in(size, conf_address, a.address, a2.address)
#
    #    for i in range(n):
    #        a[i] = a2[i]
#
    #    nblocks *= 2
    #    depth += 1
    #
    #nblocks /= 4
    #depth -= 2
    ## and back out
    #while nblocks > 0:
    #    #for i in range(n):
    #    #    config_array[i] = config[depth][i].read()
#
    #    size.write(n/(2*nblocks))
    #    conf_address = config.address + depth*n
    #    do_round_out(size, conf_address, a.address, a2.address)
#
    #    for i in range(n):
    #        a[i] = a2[i]
#
    #    nblocks /= 2
    #    depth -= 1

def rec_shuffle(x, config=None, value_type=sgf2n, reverse=False):
    """ Shuffle the list ``x`` in place with the recursive Waksman network.

    The network is applied twice with the same configuration. When no
    ``config`` is given, one is derived from an (insecure, compile-time)
    random permutation, with every switch bit converted to
    ``value_type.bit_type``.

    :raises CompilerError: if len(x) is not a power of two
    """
    n = len(x)
    if n & (n - 1):
        raise CompilerError('shuffle requires n a power of 2')
    if config is None:
        config = [[value_type.bit_type(bit) for bit in level]
                  for level in configure_waksman(random_perm(n))]
    for _ in range(2):
        waksman(x, config, reverse=reverse)


def config_shuffle(n, value_type):
    """ Compute config for oblivious shuffling.

    Take mod 2 for active sec.

    Draws an (insecure, compile-time) random permutation of n elements and
    pads it with fixed points up to the next power of two. Returns its
    Waksman switch bits flattened row-major into an Array of
    ``value_type.reg_type``, one row of len(perm) entries per level.
    For n > 1024 the bits are fed through the program's public input and
    read back in a compiled loop, instead of being assigned one by one.
    """
    perm = random_perm(n)
    if n & (n - 1):
        # pad permutation to power of 2 with fixed points
        padded_size = 2**int(math.ceil(math.log(n, 2)))
        perm.extend(range(n, padded_size))
    config_bits = configure_waksman(perm)
    row_len = len(perm)
    # 2-D array, row-major: level r occupies [r*row_len, (r+1)*row_len)
    config = Array(len(config_bits) * row_len, value_type.reg_type)
    if n <= 1024:
        for row, level in enumerate(config_bits):
            for col, bit in enumerate(level):
                config[row * row_len + col] = bit
        return config
    for level in config_bits:
        for bit in level:
            get_program().public_input(bit)
    total = sum(len(level) for level in config_bits)
    @for_range(total)
    def _(i):
        config[i] = public_input()
    return config

def shuffle(x, config=None, value_type=sgf2n, reverse=False):
    """ Simulate secure shuffling with Waksman network for 2 players.
    WARNING: This is not a properly secure implementation but has roughly the right complexity.

    Returns the network switching config so it may be re-used later.

    :param x: list, list of lists, or Array, shuffled in place. Its length
        must be a power of two. For a list of lists, column i of every row
        is shuffled with the same config (the code currently asserts
        len(x) == len(x[0])).
    :param config: config from config_shuffle to reuse; drawn when None
    :param value_type: secret type used for the network
    :param reverse: forwarded to iter_waksman
    :returns: the config used
    :raises CompilerError: for an unsupported type of x
    """
    n = len(x)
    # Smallest power of two >= n, in exact integer arithmetic. The former
    # 2**ceil(log(n, 2)) is subject to float rounding: math.log(2**29, 2)
    # is slightly above 29, which rejected a valid power-of-two length.
    m = 1 << max(n - 1, 0).bit_length()
    assert n == m, 'only working for powers of two'
    if config is None:
        config = config_shuffle(n, value_type)

    # The network is applied twice, matching the 2-player simulation.
    if isinstance(x, list):
        if isinstance(x[0], list):
            length = len(x[0])
            assert len(x) == length
            for i in range(length):
                xi = Array(m, value_type.reg_type)
                for j in range(n):
                    xi[j] = x[j][i]
                for j in range(n, m):
                    xi[j] = value_type(0)
                iter_waksman(xi, config, reverse=reverse)
                iter_waksman(xi, config, reverse=reverse)
                for j, y in enumerate(xi):
                    x[j][i] = y
        else:
            xa = Array(m, value_type.reg_type)
            for i in range(n):
                xa[i] = x[i]
            for i in range(n, m):
                xa[i] = value_type(0)
            iter_waksman(xa, config, reverse=reverse)
            iter_waksman(xa, config, reverse=reverse)
            x[:] = xa
    elif isinstance(x, Array):
        # (A former "config is None" check here could never fire: config
        # is always set above.)
        iter_waksman(x, config, reverse=reverse)
        iter_waksman(x, config, reverse=reverse)
    else:
        raise CompilerError('Invalid type for shuffle:', type(x))

    return config

def shuffle_entries(x, entry_cls, config=None, value_type=sgf2n, reverse=False, perm_size=None):
    """ Shuffle a list of ORAM entries.

        Randomly permutes the first "perm_size" entries, leaving the rest (empty
        entry padding) in the same position.

        :param x: list of entries, each an iterable of field values
            (MemValue fields are read first); len(x) must be a power of two.
            Entries are replaced in place.
        :param entry_cls: called with a generator of the shuffled field
            values to rebuild each entry
        :param config: existing config to reuse; a new one is drawn if None
        :param perm_size: number of leading entries to permute (default: all).
            config_shuffle pads its permutation only up to the next power of
            two >= perm_size, so when a new config is drawn that power of two
            must equal len(x), i.e. len(x)/2 < perm_size <= len(x).
        :returns: the config used, so it can be re-used
        :raises CompilerError: if len(x) is not a power of two, or perm_size
            is incompatible with len(x) as described above
    """
    n = len(x)
    l = len(x[0])
    if n & (n-1) != 0:
        raise CompilerError('Entries must be padded to power of two length.')
    if perm_size is None:
        perm_size = n
    if config is None:
        # A config whose levels are narrower or wider than n would make
        # iter_waksman read levels at the wrong offsets (config.address +
        # depth*n), i.e. past the end of the config Array.
        padded_size = 1 << max(perm_size - 1, 0).bit_length()
        if perm_size < 1 or padded_size != n:
            raise CompilerError('perm_size %d incompatible with %d entries: '
                                'need %d < perm_size <= %d'
                                % (perm_size, n, n // 2, n))

    # one Array per field, so each field column is shuffled separately
    xarrays = [Array(n, value_type.reg_type) for i in range(l)]
    for i in range(n):
        for j,value in enumerate(x[i]):
            if isinstance(value, MemValue):
                xarrays[j][i] = value.read()
            else:
                xarrays[j][i] = value

    if config is None:
        config = config_shuffle(perm_size, value_type)
    # the same config keeps the fields of each entry together
    for xi in xarrays:
        shuffle(xi, config, value_type, reverse)
    for i in range(n):
        x[i] = entry_cls(xarrays[j][i] for j in range(l))
    return config


def sort_zeroes(bits, x, n_ones, value_type):
    """ Return Array of values in "x" where the corresponding bit in "bits" is
    a 0.

    The total number of zeroes in "bits" must be known.
    "bits" and "x" must be Arrays.

    Both inputs are first shuffled in place with the same (insecure,
    compile-time) config, so bits[i] stays paired with x[i]. The shuffled
    bits are then revealed one at a time, presumably so the revealed
    positions do not leak the original ones.

    :param n_ones: size of the returned Array. Despite its name it must be
        the number of ZERO bits, since only zero bits are collected.
    :param value_type: secret type for the shuffle and the result
    """
    config = config_shuffle(len(x), value_type)
    shuffle(bits, config=config, value_type=value_type)
    shuffle(x, config=config, value_type=value_type)
    result = Array(n_ones, value_type.reg_type)

    sz = MemValue(0)                 # number of zeroes collected so far
    last_x = MemValue(value_type(0))  # value most recently written to result
    #for i,b in enumerate(bits):
        #if_then(b.reveal() == 0)
        #result[sz.read()] = x[i]
        #sz += 1
        #end_if()
    @for_range(len(bits))
    def f(i):
        found = (bits[i].reveal() == 0)  # public 0/1
        szval = sz.read()
        # Branch-free select: x[i] if found, else re-write last_x into the
        # next free slot (later overwritten when the next zero is found).
        # NOTE(review): once all n_ones zeroes are collected, any later
        # iteration indexes result[n_ones], one past the end. Confirm
        # whether Array accesses are bounds-checked at runtime.
        result[szval] = last_x + (x[i] - last_x) * found
        sz.write(sz + found)
        last_x.write(result[szval])
    return result
