import sys
import math

from Compiler import types

from Compiler.util import *

from .oram import OptimalORAM,LinearORAM,RecursiveORAM,TrivialORAM,Entry
from .library import for_range,do_while,time,start_timer,stop_timer,if_,print_ln,crash,print_str

class OMatrixRow(object):
    """View of one row of an ``OMatrix``.

    Column ``offset`` of the row lives at ORAM index ``base + offset``.
    When the offset is secret (``types._secret``) it is first converted
    with ``add_type.hard_conv`` before being added to ``base``.
    """
    def __init__(self, oram, base, add_type):
        self.oram, self.base, self.add_type = oram, base, add_type

    def get_index(self, offset):
        """Return the flat ORAM index corresponding to column ``offset``."""
        shift = offset
        if isinstance(offset, types._secret):
            shift = self.add_type.hard_conv(offset)
        return self.base + shift

    def __getitem__(self, offset):
        """Read column ``offset`` and return the untuplified value part."""
        result = self.read(offset)
        return untuplify(result[0])

    def __setitem__(self, offset, item):
        """Write ``item`` to column ``offset``."""
        index = self.get_index(offset)
        self.oram[index] = item

    def read(self, offset):
        """Return the raw result of ``oram.read`` for column ``offset``."""
        index = self.get_index(offset)
        return self.oram.read(index)

class OMatrix:
    """N x M matrix stored row-major in a single ORAM of ``N * M`` entries.

    ``matrix[i]`` returns an ``OMatrixRow`` whose base index is ``i * M``,
    so element ``j`` of row ``i`` lives at ORAM index ``i * M + j``.
    ``M`` defaults to ``N`` (any falsy ``M`` gives a square matrix).
    """
    def __init__(self, N, M=None, oram_type=OptimalORAM, int_type=types.sint):
        print('matrix', oram_type)
        self.N = N
        self.M = M or N
        self.oram = oram_type(N * self.M, entry_size=log2(N), init_rounds=0,
                              value_type=int_type.basic_type)
        self.int_type = int_type

    def __getitem__(self, a):
        """Return an ``OMatrixRow`` view of row ``a``.

        ``add_type`` is what the row uses to convert secret column offsets.
        When ``M`` is a power of two or ``int_type`` is ``sint``, this is the
        basic type. Otherwise it is a subclass of ``int_type`` with
        ``n_bits = log2(N * M)``, and a row index of exactly the basic type
        is converted to it first. Presumably this gives ``a * M`` enough
        bits; NOTE(review): confirm.
        """
        M = self.M
        # Exact integer power-of-two test. The previous
        # ``math.log(M, 2) % 1 == 0`` goes through floating point, which the
        # math docs do not guarantee to be exact for base 2.
        is_power_of_two = M > 0 and M & (M - 1) == 0
        if is_power_of_two or self.int_type == types.sint:
            add_type = self.int_type.basic_type
        else:
            class add_type(self.int_type):
                n_bits = log2(self.N * self.M)
            if type(a) is self.int_type.basic_type:
                a = add_type(a)
        return OMatrixRow(self.oram, a * self.M, add_type)

class OReverseMatrixRow(object):
    """Read-only view of row ``index`` of an ``OReverseMatrix``.

    The row is stored as a single ORAM entry holding ``M`` values. Each
    read fetches that entry, copies it into a fresh ``TrivialORAM``, and
    looks up ``offset`` there, so the offset may be secret.
    """
    def __init__(self, oram, index, N, M, basic_type):
        self.oram, self.N, self.M = oram, N, M
        self.index, self.basic_type = index, basic_type

    def __getitem__(self, offset):
        """Return the untuplified value stored at column ``offset``."""
        result = self.read(offset)
        return untuplify(result[0])

    def read(self, offset):
        """Return the ``TrivialORAM.read`` result for column ``offset``."""
        scratch = TrivialORAM(self.M, self.basic_type, 1, log2(self.N))
        row = self.oram[self.index]
        for col in range(self.M):
            scratch.ram[col] = Entry(row[col], col, value_type=self.basic_type)
        return scratch.read(offset)

class OReverseMatrix(OMatrix):
    """N x M matrix stored as N ORAM entries, one whole row per entry.

    Rows are written at once (``matrix[i] = tuple_of_M_values``) and read
    element-wise through ``OReverseMatrixRow``.
    """
    def __init__(self, N, M, oram_type=OptimalORAM, int_type=types.sint):
        basic = int_type.basic_type
        self.N, self.M = N, M
        row_shape = (log2(N),) * M
        self.oram = oram_type(N, entry_size=row_shape, init_rounds=0,
                              value_type=basic)
        self.basic_type = basic

    def __getitem__(self, a):
        """Return a read-only view of row ``a``."""
        row_view = OReverseMatrixRow(self.oram, a, self.N, self.M,
                                     self.basic_type)
        return row_view

    def __setitem__(self, index, value):
        """Overwrite row ``index`` with ``value`` (the whole row)."""
        self.oram[index] = value

class OStack:
    """Stack of up to ``N`` items backed by an ORAM.

    The current height is kept in a ``MemValue`` of type ``int_type``. A
    push may be conditional: ``for_real`` is passed to ``oram.access``, and
    the height grows by ``int_type(for_real)``. ``pop`` does not check for
    an empty stack.
    """
    def __init__(self, N, oram_type=OptimalORAM, int_type=types.sint):
        print('stack', oram_type)
        self.oram = oram_type(N, entry_size=log2(N), init_rounds=0,
                              value_type=int_type.basic_type)
        self.size = types.MemValue(int_type(0))
        self.int_type = int_type

    def append(self, item, for_real=True):
        """Push ``item`` at position ``size``, gated by ``for_real``."""
        self.oram.access(self.size, item, for_real)
        increment = self.int_type(for_real)
        self.size.iadd(increment)

    def pop(self):
        """Decrement ``size`` and return the entry at the new top position."""
        self.size.isub(1)
        top = self.oram[self.size]
        return top

class Matchmaker:
    """Gale-Shapley-style stable matching on ORAM-backed state.

    Unengaged men are kept on an ``OStack``. Each round, the man on top
    proposes to the next woman on his preference list (``m_prefs``). She
    keeps whichever of him and her current fiance has the lower rank in her
    row of ``f_ranks`` (see ``propose``). Engagement updates are always
    issued and gated by ``for_real`` flags rather than by data-dependent
    Python branches.

    ``N`` is the number of participants per side. ``M`` is the length of
    each man's preference list and defaults to ``N``. Call ``init_hard`` or
    ``init_easy`` to fill ``m_prefs``/``f_ranks`` before calling ``match``.
    """
    def init_hard(self, n_loops=None):
        """Fill ``m_prefs`` and ``f_ranks`` with a fixed "hard" instance.

        If ``n_loops`` is given and at most ``N * M``, the inner loops only
        write the first ``inner_loops = min(n_loops, M)`` columns of the
        first ``n_loops // inner_loops`` rows of each table.
        """
        if n_loops is None or n_loops > self.N * self.M:
            inner_loops = self.M
            outer_loops = self.N
        else:
            inner_loops = min(n_loops, self.M)
            # Integer division: for_range needs an int count. Python 3 '/'
            # would produce a float here.
            outer_loops = n_loops // inner_loops
        self.m_prefs = OMatrix(self.N, self.M, oram_type=self.oram_type,
                               int_type=self.int_type)
        @for_range(outer_loops)
        def f(i):
            time()
            types.cint(i).print_reg('mpre')
            @for_range(inner_loops)
            def f(j):
                self.m_prefs[i][j] = (-i + j + self.N - 1) % (self.N - 1)
            # The last list slot is overwritten with a fixed choice.
            if self.M < self.N:
                self.m_prefs[i][self.M-1] = (2 * self.N - 2 - i) % self.N
            else:
                self.m_prefs[i][self.N-1] = self.N - 1
        if self.reverse:
            self.f_ranks = OReverseMatrix(self.N, self.M,
                                          oram_type=self.oram_type,
                                          int_type=self.int_type)
        else:
            self.f_ranks = OMatrix(self.N, oram_type=self.oram_type,
                                   int_type=self.int_type)
        @for_range(outer_loops)
        def f(i):
            time()
            types.cint(i).print_reg('fran')
            @for_range(inner_loops)
            def f(j):
                if self.reverse:
                    # NOTE(review): the generator's own ``j`` shadows the loop
                    # variable, so the whole row is rewritten on every inner
                    # iteration. The result is the same each time.
                    self.f_ranks[i] = tuple((-i - j + 2 * self.N - 2) % self.N
                                            for j in range(self.M))
                else:
                    self.f_ranks[i][(-i - j + 2 * self.N - 2) % self.N] = j

    def init_easy(self):
        """Fill the tables with a simple cyclic instance.

        Man ``i`` lists woman ``(i + j) % N`` at position ``j``. Woman ``i``
        gives rank ``j`` to man ``(j - i + N) % N``.
        """
        self.m_prefs = OMatrix(self.N, self.M, oram_type=self.oram_type,
                               int_type=self.int_type)
        @for_range(self.N)
        def f(i):
            time()
            types.cint(i).print_reg('mpre')
            @for_range(self.M)
            def f(j):
                self.m_prefs[i][j] = (i + j) % self.N
        self.f_ranks = OMatrix(self.N, oram_type=self.oram_type,
                               int_type=self.int_type)
        @for_range(self.N)
        def f(i):
            time()
            types.cint(i).print_reg('fran')
            @for_range(self.M)
            def f(j):
                self.f_ranks[i][(j-i+self.N)%self.N] = j

    def engage(self, man, woman, for_real):
        """Record ``man`` and ``woman`` as engaged, gated by ``for_real``."""
        self.wives.access(man, woman, for_real)
        self.husbands.access(woman, man, for_real)

    def dump(self, man, woman, for_real):
        """Break the engagement of ``man``/``woman`` and push ``man`` back
        onto the unengaged stack. All three steps are gated by ``for_real``.
        """
        self.wives.delete(man, for_real)
        self.husbands.delete(woman, for_real)
        self.unengaged.append(man, for_real)

    def propose(self, man, woman, for_real):
        """Let ``man`` propose to ``woman`` and update the engagements.

        ``leaving`` is set when ``man``'s rank is strictly lower than the
        current fiance's. When ``M < N`` it is also set by ``worst_fiance``,
        the second value returned by the rank lookup (presumably an
        "entry empty" flag, like ``free`` from ``husbands.read``; confirm).
        The rejected man is pushed back onto ``unengaged``: the old fiance
        via ``dump``, or ``man`` himself.

        NOTE: reveals the intermediate values for a debug trace.
        """
        (fiance,), free = self.husbands.read(woman)
        engaged = 1 - free
        rank_man = self.f_ranks[woman][man]
        # Index 0 is looked up when the woman is free.
        (rank_fiance,), worst_fiance = self.f_ranks[woman].read(engaged*fiance)
        leaving = self.int_type(rank_man) < self.int_type(rank_fiance)
        if self.M < self.N:
            # Logical OR of leaving and worst_fiance on 0/1 values.
            leaving = 1 - (1 - leaving) * (1 - worst_fiance)
        print_str('woman: %s, man: %s, fiance: %s, worst fiance: %s, ',
                  *(x.reveal() for x in (woman, man, fiance, worst_fiance)))
        print_ln('rank man: %s, rank fiance: %s, engaged: %s, leaving: %s',
                 *(x.reveal() for x in
                   (rank_man, rank_fiance, engaged, leaving)))
        self.dump(fiance, woman, engaged * leaving * for_real)
        self.engage(man, woman, (1 - (engaged * (1 - leaving))) * for_real)
        self.unengaged.append(man, engaged * (1 - leaving) * for_real)

    def match(self, n_loops=None):
        """Run the proposal loop, then print the resulting husbands.

        If ``n_loops`` is None or larger than ``N * M``, all ``N`` men start
        unengaged and ``do_while`` repeats until the stack is empty.
        Otherwise ``for_range(n_loops)`` drives the loop, starting with
        ``n_loops // M`` unengaged men.
        """
        if n_loops is None or n_loops > self.N * self.M:
            loop = do_while
            init_rounds = self.N
        else:
            loop = for_range(n_loops)
            # Integer division: for_range needs an int count. Python 3 '/'
            # would produce a float here.
            init_rounds = n_loops // self.M
        self.wives = \
            self.oram_type(self.N, entry_size=log2(self.N),
                           init_rounds=0, value_type=self.basic_type)
        self.husbands = \
            self.oram_type(self.N, entry_size=log2(self.N),
                           init_rounds=0, value_type=self.basic_type)
        # Per-man index of the next woman to propose to.
        next_pref = \
            self.oram_type(self.N, entry_size=log2(self.N),
                           init_rounds=0, value_type=self.basic_type)
        self.unengaged = OStack(self.N, oram_type=self.oram_type,
                                int_type=self.int_type)
        @for_range(init_rounds)
        def f(i):
            self.unengaged.append(i)
        rounds = types.MemValue(types.regint(0))
        @loop
        def f(i=None):
            rounds.iadd(1)
            time()
            man = self.unengaged.pop()
            pref = self.int_type(next_pref[man])
            if self.M < self.N and n_loops is None:
                @if_((pref == self.M).reveal())
                def f():
                    print_ln('run out of acceptable women')
                    crash()
            next_pref[man] = pref + 1
            self.propose(man, self.m_prefs[man][pref], True)
            print_ln('man: %s, pref: %s, left: %s',
                     *(x.reveal() for x in (man, pref, self.unengaged.size)))
            # Nonzero while men remain unengaged. Presumably the loop helper
            # uses this to decide whether to continue; confirm.
            return types.regint((self.unengaged.size > 0).reveal())
        print_ln('%s rounds', rounds)
        @for_range(init_rounds)
        def f(i):
            types.cint(i).print_reg('wife')
            self.husbands[i].reveal().print_reg('husb')

    def __init__(self, N, M=None, reverse=False, oram_type=OptimalORAM,
                 int_type=types.sint):
        """Configure sizes and storage.

        ``reverse`` selects ``OReverseMatrix`` for ``f_ranks`` in
        ``init_hard``.
        """
        self.N = N
        self.M = N if M is None else M
        self.oram_type = oram_type
        self.reverse = reverse
        self.int_type = int_type
        self.basic_type = int_type.basic_type
        print('match', self.oram_type)
