from Compiler.oram import *

from Compiler.program import Program

ORAM = OptimalORAM

prog = program.Program.prog
prog.set_bit_length(min(64, prog.bit_length))

class HeapEntry(object):
    fields = ['empty', 'prio', 'value']
    def __init__(self, int_type, *args):
        self.int_type = int_type
        if not len(args):
            raise CompilerError()
        if len(args) == 1:
            args = args[0]
        for field,arg in zip(self.fields, args):
            self.__dict__[field] = arg
    def data(self):
        return self.prio, self.value
    def __repr__(self):
        return '(' + ', '.join('%s=%s' % (field,self.__dict__[field]) \
                                   for field in self.fields) + ')'
    def __eq__(self, other):
        return self.__dict__ == other.__dict__
    def __gt__(self, other):
        return (1 - self.empty) * (1 - other.empty) * \
            (self.int_type(self.prio) > self.int_type(other.prio))
    def __iter__(self):
        for field in self.fields:
            yield self.__dict__[field]
    def __add__(self, other):
        return type(self)(self.int_type, (i + j for i,j in zip(self, other)))
    def __sub__(self, other):
        return type(self)(self.int_type, (i - j for i,j in zip(self, other)))
    def __xor__(self, other):
        return type(self)(self.int_type, (i ^ j for i,j in zip(self, other)))
    def __mul__(self, other):
        return type(self)(self.int_type, (other * i for i in self))
    __rxor__ = __xor__
    __rmul__ = __mul__
    def hard_conv_me(self, value_type):
        return type(self)(self.int_type, \
                              *(value_type.hard_conv(x) for x in self))
    def dump(self):
        print_ln('empty %s, prio %s, value %s', *(reveal(x) for x in self))

class HeapORAM(object):
    def __init__(self, size, oram_type, init_rounds, int_type):
        self.int_type = int_type
        self.oram = oram_type(size, entry_size=(32,log2(size)), \
                                  init_rounds=init_rounds, \
                                  value_type=int_type.basic_type)
    def __getitem__(self, index):
        return self.make_entry(*self.oram.read(index))
    def make_entry(self, value, empty):
        return HeapEntry(self.int_type, (empty,) + value)
    def __setitem__(self, index, value):
        self.oram.access(index, value.data(), True, new_empty=value.empty)
    def access(self, index, value, write):
        tmp = self.oram.access(index, value.data(), write)
        return self.make_entry(*tmp)
    def delete(self, index, for_real):
        self.oram.delete(index, for_real)
    def read_and_maybe_remove(self, index):
        entry, state = self.oram.read_and_maybe_remove(index)
        return self.make_entry(*entry), state
    def add(self, index, entry, state):
        self.oram.add(Entry(MemValue(index), \
                                [MemValue(i) for i in entry.data()], \
                                entry.empty), state=state)
    def __len__(self):
        return len(self.oram)

class HeapQ(object):
    def __init__(self, max_size, oram_type=ORAM, init_rounds=-1, int_type=sint):
        basic_type = int_type.basic_type
        self.max_size = max_size
        self.levels = log2(max_size)
        self.depth = self.levels - 1
        self.heap = HeapORAM(2**self.levels, oram_type, init_rounds, int_type)
        self.value_index = oram_type(max_size, entry_size=log2(max_size), \
                                         init_rounds=init_rounds, \
                                         value_type=basic_type)
        self.size = MemValue(int_type(0))
        self.int_type = int_type
        self.basic_type = basic_type
        print('heap: %d levels, depth %d, size %d, index size %d' % \
            (self.levels, self.depth, self.heap.oram.size, self.value_index.size))
    def update(self, value, prio, for_real=True):
        self._update(self.basic_type.hard_conv(value), \
                         self.basic_type.hard_conv(prio), \
                         self.basic_type.hard_conv(for_real))
    def pop(self, for_real=True):
        return self._pop(self.basic_type.hard_conv(for_real))
    def bubble_up(self, start):
        bits = bit_decompose(start, self.levels)
        bits.reverse()
        bits = [0] + floatingpoint.PreOR(bits, self.levels)
        bits = [bits[i+1] - bits[i] for i in range(self.levels)]
        shift = sum([bit << i for i,bit in enumerate(bits)])
        childpos = MemValue(start * shift)
        @for_range(self.levels - 1)
        def f(i):
            parentpos = childpos.right_shift(1, self.levels + 1)
            parent, parent_state = self.heap.read_and_maybe_remove(parentpos)
            child, child_state = self.heap.read_and_maybe_remove(childpos)
            swap = parent > child
            new_parent, new_child = cond_swap(swap, parent, child)
            self.heap.add(childpos, new_child, child_state)
            self.heap.add(parentpos, new_parent, parent_state)
            self.value_index.access(new_parent.value, parentpos, swap)
            self.value_index.access(new_child.value, childpos, swap)
            childpos.write(parentpos)
    @method_block
    def _pop(self, for_real=True):
        Program.prog.curr_tape.\
            start_new_basicblock(name='heapq-pop')
        pop_for_real = for_real * (self.size != 0)
        entry = self.heap[1]
        self.value_index.delete(entry.value, for_real)
        last = self.heap[self.basic_type(self.size)]
        self.heap.access(1, last, pop_for_real)
        self.value_index.access(last.value, 1, for_real * (self.size != 1))
        self.heap.delete(self.basic_type(self.size), for_real)
        self.size -= self.int_type(pop_for_real)
        parentpos = MemValue(self.basic_type(1))
        @for_range(self.levels - 1)
        def f(i):
            childpos = 2 * parentpos
            left_child, l_state = self.heap.read_and_maybe_remove(childpos)
            right_child, r_state = self.heap.read_and_maybe_remove(childpos+1)
            go_right = left_child > right_child
            otherchildpos = childpos + 1 - go_right
            childpos += go_right
            child, other_child = cond_swap(go_right, left_child, right_child)
            child_state, other_state = cond_swap(go_right, l_state, r_state)
            parent, parent_state = self.heap.read_and_maybe_remove(parentpos)
            swap = parent > child
            new_parent, new_child = cond_swap(swap, parent, child)
            self.heap.add(childpos, new_child, child_state)
            self.heap.add(otherchildpos, other_child, other_state)
            self.heap.add(parentpos, new_parent, parent_state)
            self.value_index.access(new_parent.value, parentpos, swap)
            self.value_index.access(new_child.value, childpos, swap)
            parentpos.write(childpos)
        self.check()
        return entry.value
    @method_block
    def _update(self, value, prio, for_real=True):
        Program.prog.curr_tape.\
            start_new_basicblock(name='heapq-update')
        index, not_found = self.value_index.read(value)
        self.size += self.int_type(not_found * for_real)
        index = if_else(not_found, self.basic_type(self.size), index[0])
        self.value_index.access(value, self.basic_type(self.size), \
                                    not_found * for_real)
        self.heap.access(index, HeapEntry(self.int_type, 0, prio, value), for_real)
        self.bubble_up(index)
        self.check()
    def __len__(self):
        return self.size
    def check(self):
        if debug:
            for i in range(len(self.heap)):
                if ((2 * i + 1 < len(self.heap) and \
                         self.heap[i] > self.heap[2*i+1]) or \
                        (2 * i + 2 < len(self.heap) and \
                             self.heap[i] > self.heap[2*i+2])) and \
                             not self.heap[i].empty:
                    raise Exception('heap condition violated at %d' % i)
                if i >= self.size and not self.heap[i].empty:
                    raise Exception('wrong size at %d' % i)
                if i < self.size and self.heap[i].empty:
                    raise Exception('empty entry in heap at %d' % i)
                # if not self.heap[i].empty and \
                #         self.heap[i].value not in self.value_index:
                #     raise Exception('missing index at %d' % i)
            for value,(index,empty) in enumerate(self.value_index):
                if not empty and self.heap[index].value != value:
                    raise Exception('index violated at %d' % index)
        if debug_online:
            @for_range(self.max_size)
            def f(value):
                index, not_found = self.value_index.read(value)
                index, not_found = index[0].reveal(), not_found.reveal()
                @if_(not_found == 0)
                def f():
                    heap_value = self.heap[index].value.reveal()
                    @if_(heap_value != value)
                    def f():
                        print_ln('heap mismatch: %s:%s in index, %s in heap', \
                                     value, index, heap_value)
                        crash()            
    def dump(self, msg=''):
        print_ln(msg)
        print_ln('size: %s', self.size.reveal())
        print_str('heap:')
        if isinstance(self.heap.oram, LinearORAM):
            for entry in self.heap.oram.ram:
                print_str(' %s:%s,%s', entry.empty().reveal(), \
                              entry.x[0].reveal(), entry.x[1].reveal())
        else:
            for i in range(self.max_size+1):
                print_str(' %s:%s', *(x.reveal() for x in self.heap.oram[i]))
        print_ln()
        print_str('value index:')
        if isinstance(self.value_index, LinearORAM):
            for entry in self.value_index.ram:
                print_str(' %s:%s', entry.empty().reveal(), entry.x[0].reveal())
        else:
            for i in range(self.max_size):
                print_str(' %s:%s', i, self.value_index[i].reveal())
        print_ln()
        print_ln()

def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=sint):
    basic_type = int_type.basic_type
    vert_loops = n_loops * e_index.size // edges.size \
        if n_loops else -1
    dist = oram_type(e_index.size, entry_size=(32,log2(e_index.size)), \
                         init_rounds=vert_loops, value_type=basic_type)
    #visited = ORAM(e_index.size)
    #previous = oram_type(e_index.size)
    Q = HeapQ(e_index.size, oram_type, init_rounds=vert_loops, \
                  int_type=int_type)

    if n_loops is not None:
        # put initialization in different timer
        stop_timer()
        start_timer(1)
    dist[source] = (0,0)
    Q.update(source, 0)
    if n_loops is not None:
        stop_timer(1)
        start_timer()
    last_edge = MemValue(basic_type(1))
    i_edge = MemValue(int_type(0))
    u = MemValue(basic_type(0))
    @for_range(n_loops or edges.size)
    def f(i):
        cint(i).print_reg('loop')
        time()
        u.write(if_else(last_edge, Q.pop(last_edge), u))
        #visited.access(u, True, last_edge)
        i_edge.write(int_type(if_else(last_edge, e_index[u], i_edge)))
        v, weight, le = edges[i_edge]
        last_edge.write(le)
        i_edge.iadd(1)
        alt = int_type(dist[u][0]) + int_type(weight)
        #is_shorter = (alt < dist[v]) * (1 - visited[v])
        dv, not_visited = dist.read(v)
        # relying on default dv negative here
        is_shorter = (alt < int_type(dv[0])) + not_visited
        dist.access(v, (basic_type(alt), u), is_shorter)
        #previous.access(v, u, is_shorter)
        Q.update(v, basic_type(alt), is_shorter)
        print_ln('u: %s, v: %s, alt: %s, dv: %s, first visit: %s', \
                     u.reveal(), v.reveal(), alt.reveal(), dv[0].reveal(), \
                     not_visited.reveal())
    return dist

def convert_graph(G):
    edges = [None] * (2 * G.size())
    e_index = [None] * (len(G))
    i = 0
    for v in G:
        e_index[v] = i
        for u in G[v]:
            edges[i] = [u, G[v][u]['weight'], 0]
            i += 1
        edges[i-1][-1] = 1
    return edges, e_index

def test_dijkstra(G, source, oram_type=ORAM, n_loops=None, int_type=sint):
    for u in G:
        for v in G[u]:
            G[u][v].setdefault('weight', 1)
    edges_list, e_index_list = convert_graph(G)
    edges = oram_type(len(edges_list), \
                          entry_size=(log2(len(G)), log2(len(G)), 1), \
                          init_rounds=0, value_type=int_type.basic_type)
    e_index = oram_type(len(e_index_list), entry_size=log2(len(edges_list)), \
                            value_type=int_type.basic_type)
    for i in range(n_loops or edges.size):
        cint(i).print_reg('edge')
        time()
        edges[i] = edges_list[i]
    vert_loops = n_loops * e_index.size // edges.size \
        if n_loops else e_index.size
    for i in range(vert_loops):
        cint(i).print_reg('vert')
        time()
        e_index[i] = e_index_list[i]
    return dijkstra(source, edges, e_index, oram_type, n_loops, int_type)

def test_dijkstra_on_cycle(n, oram_type=ORAM, n_loops=None, int_type=sint):
    n_edges = 2 * n
    edges = oram_type(n_edges, entry_size=(log2(n),log2(n),1), init_rounds=0, 
                      value_type=int_type.basic_type)
    e_index = oram_type(n, entry_size=log2(n_edges), init_rounds=0, \
                            value_type=int_type.basic_type)
    @for_range(n_loops or edges.size)
    def f(i):
        cint(i).print_reg('edge')
        time()
        neighbour = ((i >> 1) + 2 * (i % 2) - 1 + n) % n
        edges[i] = (neighbour, 1, i % 2)
    vert_loops = n_loops * e_index.size // edges.size \
        if n_loops else e_index.size
    @for_range(vert_loops)
    def f(i):
        cint(i).print_reg('vert')
        time()
        e_index[i] = 2 * i
    return dijkstra(0, edges, e_index, oram_type, n_loops, int_type)

def test_dijkstra_on_complete(n, oram_type=ORAM, n_loops=None, int_type=sint):
    n_edges = n**2
    edges = oram_type(n_edges, entry_size=(log2(n),log2(n),1), init_rounds=0, 
                      value_type=int_type.basic_type)
    e_index = oram_type(n, entry_size=log2(n_edges), init_rounds=0, \
                            value_type=int_type.basic_type)
    @for_range(n_loops or n)
    def f(i):
        @for_range(n_loops - 1 if n_loops else n - 1)
        def f(j):
            cint(i).print_reg('v1')
            cint(j).print_reg('v2')
            time()
            edges[i*n+j] = (j, 1, 0)
        edges[i*n+n-1] = (n - 1, 1, 1)
    if n_loops is not None:
        stop_timer()
        start_timer(2)
    @for_range(n_loops or n)
    def f(i):
        cint(i).print_reg('vert')
        time()
        e_index[i] = n * i
    if n_loops is not None:
        stop_timer(2)
        start_timer()
    return dijkstra(0, edges, e_index, oram_type, \
                        n_loops**2 if n_loops else n**2, int_type)

class ExtInt(object):
    def __init__(self, x, inf=False):
        self.x = x
        self.inf = inf
    def __add__(self, other):
        if isinstance(other, ExtInt):
            return ExtInt(self.x + other.x, self.inf + other.inf)
        else:
            return ExtInt(self.x + other, self.inf)
    def __sub__(self, other):
        if isinstance(other, ExtInt):
            return ExtInt(self.x - other.x, self.inf - other.inf)
        else:
            return ExtInt(self.x - other, self.inf)    
    def __rsub__(self, other):
        return ExtInt(other - self.x, -self.inf)
    def __mul__(self, other):
        if isinstance(other, ExtInt):
            raise Exception()
            return ExtInt(self.x * other.x, self.inf * other.inf)
        else:
            return ExtInt(self.x * other, self.inf * other)
    __radd__ = __add__
    __rmul__ = __mul__
    def __lt__(self, other):
        if isinstance(other, ExtInt):
            return ((1 - self.inf) * (1 - other.inf) * (self.x < other.x)) + \
                other.inf
        else:
            return (1 - self.inf) * (self.x < other)
    def __gt__(self, other):
        if isinstance(other, ExtInt):
            return ((1 - self.inf) * (1 - other.inf) * (self.x > other.x)) + \
                self.inf
        else:
            return 1 - (1 -self.inf) * (1 - (self.x > other))
    def __repr__(self):
        if self.inf:
            return 'T'
        else:
            return str(self.x)

class Vector(object):
    """ Works like a vector. """
    def __add__(self, other):
        print('add', type(self))
        res = type(self)(len(self))
        @for_range(len(self))
        def f(i):
            res[i] = self[i] + other[i]
        return res
    def __sub__(self, other):
        print('sub', type(other))
        res = type(other)(len(self))
        @for_range(len(self))
        def f(i):
            res[i] = self[i] - other[i]
        return res
    def __mul__(self, other):
        if isinstance(other, Vector):
            res = type(self)(1)
            res[0] = ExtInt(0)
            @for_range(len(self))
            def f(i):
                res[0] += self[i] * other[i]
            return res[0]
        else:
            print('mul', type(self))
            res = type(self)(len(self))
            @for_range_parallel(1024, len(self))
            def f(i):
                res[i] = self[i] * other
            return res
    __rmul__ = __mul__

class VectorList(Vector, list):
    pass

class VectorArray(Vector):
    def __init__(self, length, address=None):
        self.length = length
        if address is None:
            self.arrays = [Array(length, 's') for i in range(2)]
        else:
            self.arrays = [Array(length, 's', addr) \
                               for addr in (address,address+length)]
    def assign(self, values):
        @for_range(len(self))
        def f(i):
            self[i] = values[i]
    def assign_all(self, value):
        self.arrays[0].assign_all(value.x)
        self.arrays[1].assign_all(value.inf)
    def __getitem__(self, index):
        return ExtInt(*[v[index] for v in self.arrays])
    def __setitem__(self, index, value):
        self.arrays[0][index] = value.x
        self.arrays[1][index] = value.inf
    def __len__(self):
        return len(self.arrays[0])

class IntVectorArray(Vector, Array):
    def __init__(self, length):
        Array.__init__(self, length, 's')

class Matrix(object):
    """ Guess what. """
    def __init__(self, rows, columns):
        self.rows = rows
        self.columns = columns
        self.address = Array(2 * rows * columns, 's').address
    def __getitem__(self, index):
        return VectorArray(self.columns, self.address + 2 * self.columns * index)
    def __setitem__(self, index, value):
        self[index].assign(value)
    def __len__(self):
        return self.rows
    def assign_all(self, value):
        @for_range(len(self))
        def f(i):
            self[i].assign_all(value)
        return self

def updatevector(vector, index, value):
    @for_range_parallel(1024, len(vector))
    def f(i):
        vector[i] += index[i] * (value - vector[i])

def binarymin(A):
    if len(A) == 1:
        return [1], A[0]
    else:
        half = len(A) // 2
        A_prime = VectorArray(half)
        B = IntVectorArray(half)
        i = IntVectorArray(len(A))
        @for_range_parallel(128, half)
        def f(j):
            B[j] = A[2*j] < A[2*j+1]
            A_prime[j] = if_else(B[j], A[2*j], A[2*j+1])
        i_prime, min = binarymin(A_prime)
        @for_range_parallel(1024, half)
        def f(j):
            i[2*j] = B[j] * i_prime[j]
            i[2*j+1] = (1 - B[j]) * i_prime[j]
        return i, min

def stupid_dijkstra(M, s, n_loops=None):
    if n_loops is not None:
        stop_timer()
        start_timer(1)
    P = Matrix(len(M), len(M))
    P.assign_all(ExtInt(0))
    d = VectorArray(len(M))
    d.assign_all(ExtInt(0,True))
    q = VectorArray(len(M))
    q.assign_all(ExtInt(0))
    d_prime = VectorArray(len(M))
    updatevector(d, s, 0)
    if n_loops is not None:
        stop_timer(1)
        start_timer()
    @for_range(n_loops or len(M))
    def f(i):
        if n_loops is not None:
            stop_timer()
            start_timer(2)
        d_prime.assign(d + q)
        k, min = binarymin(d_prime)
        updatevector(q, k, ExtInt(0,True))
        if n_loops is not None:
            stop_timer(2)
            start_timer()
        @for_range(n_loops or len(M))
        def f(j):
            a = (d + M[j]) * k
            c = a < d[j]
            P[j] = P[j] + c * (k - P[j])
            d[j] += c * (a - d[j])
    return d, P

def convert_graph_to_matrix(G):
    M = Matrix(len(G), len(G))
    M.assign_all(ExtInt(0,True))
    for u in G:
        for v in G[u]:
            M[u][v] = ExtInt(G[u][v].get('weight', 1))
    return M

def test_stupid_dijkstra(G, source):
    return stupid_dijkstra(convert_graph_to_matrix(G), \
                               demux(bit_decompose(source, log2(len(G)))))

def test_stupid_dijkstra_on_cycle(n, n_loops=None):
    if n_loops is not None:
        stop_timer()
        start_timer(1)
    M = Matrix(n, n)
    M.assign_all(ExtInt(0,True))
    s = IntVectorArray(n)
    s.assign_all(0)
    s[0] = 1
    @for_range(n)
    def f(i):
        M[i][(i+1)%n] = ExtInt(1)
        M[i][(i-1)%n] = ExtInt(1)
    if n_loops is not None:
        stop_timer(1)
        start_timer()
    return stupid_dijkstra(M, s, n_loops)
