"""
This module contains all instruction types for arithmetic computation
and general control of the virtual machine such as control flow.

The parameter descriptions refer to the instruction arguments in the
right order.
"""

# All base classes, utility functions etc. should go in
# instructions_base.py instead. This is for two reasons:
# 1) Easier generation of documentation
# 2) Ensures that 'from instructions import *' will only import assembly
# instructions and nothing else.
#
# Note: every instruction should have a suitable docstring for
# auto-generation of documentation

import itertools
import operator
from . import tools
from random import randint
from functools import reduce
from Compiler.config import *
from Compiler.exceptions import *
import Compiler.instructions_base as base


# avoid naming collision with input instruction
_python_input = input

###
### Load and store instructions
###

@base.gf2n
@base.vectorize
class ldi(base.Instruction):
    """ Assign (constant) immediate value to clear register (vector).

    :param: destination (cint)
    :param: value (int)
    """
    __slots__ = []
    code = base.opcodes['LDI']
    # operands in docstring order: written clear register, immediate value
    arg_format = ['cw','i']

@base.gf2n
@base.vectorize
class ldsi(base.Instruction):
    """ Assign (constant) immediate value to secret register (vector).

    :param: destination (sint)
    :param: value (int)
    """
    __slots__ = []
    code = base.opcodes['LDSI']
    # operands in docstring order: written secret register, immediate value
    arg_format = ['sw','i']

@base.gf2n
@base.vectorize
class ldmc(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
    """ Assign clear memory value(s) to clear register (vector) by
    immediate address. The vectorized version starts at the base
    address and then iterates the memory address.

    :param: destination (cint)
    :param: memory address base (int)

    """
    __slots__ = []
    code = base.opcodes['LDMC']
    # operands in docstring order: written clear register, immediate address
    arg_format = ['cw','int']

@base.gf2n
@base.vectorize
class ldms(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
    """ Assign secret memory value(s) to secret register (vector) by
    immediate address. The vectorized version starts at the base
    address and then iterates the memory address.

    :param: destination (sint)
    :param: memory address base (int)

    """
    __slots__ = []
    code = base.opcodes['LDMS']
    # operands in docstring order: written secret register, immediate address
    arg_format = ['sw','int']

@base.gf2n
@base.vectorize
class stmc(base.DirectMemoryWriteInstruction):
    """ Assign clear register (vector) to clear memory value(s) by
    immediate address. The vectorized version starts at the base
    address and then iterates the memory address.

    :param: source (cint)
    :param: memory address base (int)

    """
    __slots__ = []
    code = base.opcodes['STMC']
    # operands in docstring order: read clear register, immediate address
    arg_format = ['c','int']

@base.gf2n
@base.vectorize
class stms(base.DirectMemoryWriteInstruction):
    """ Assign secret register (vector) to secret memory value(s) by
    immediate address. The vectorized version starts at the base
    address and then iterates the memory address.

    :param: source (sint)
    :param: memory address base (int)

    """
    __slots__ = []
    code = base.opcodes['STMS']
    # operands in docstring order: read secret register, immediate address
    arg_format = ['s','int']

@base.vectorize
class ldmint(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
    """ Assign clear integer memory value(s) to clear integer register
    (vector) by immediate address. The vectorized version starts at
    the base address and then iterates the memory address.

    :param: destination (regint)
    :param: memory address base (int)

    """
    __slots__ = []
    code = base.opcodes['LDMINT']
    # operands in docstring order: written integer register, immediate address
    arg_format = ['ciw','int']

@base.vectorize
class stmint(base.DirectMemoryWriteInstruction):
    """ Assign clear integer register (vector) to clear integer memory
    value(s) by immediate address. The vectorized version starts at
    the base address and then iterates the memory address.

    :param: source (regint)
    :param: memory address base (int)

    """
    __slots__ = []
    code = base.opcodes['STMINT']
    # operands in docstring order: read integer register, immediate address
    arg_format = ['ci','int']

@base.vectorize
class ldmci(base.ReadMemoryInstruction, base.IndirectMemoryInstruction):
    """ Assign clear memory value(s) to clear register (vector) by
    register address. The vectorized version starts at the base
    address and then iterates the memory address.

    :param: destination (cint)
    :param: memory address base (regint)

    """
    code = base.opcodes['LDMCI']
    # operands in docstring order: written clear register, address register
    arg_format = ['cw','ci']
    # Immediate-address counterpart of this instruction. staticmethod()
    # keeps it from being bound as a method when read through an instance.
    # NOTE(review): presumably used by IndirectMemoryInstruction to fall
    # back to the direct form -- confirm in instructions_base.
    direct = staticmethod(ldmc)

@base.vectorize
class ldmsi(base.ReadMemoryInstruction, base.IndirectMemoryInstruction):
    """ Assign secret memory value(s) to secret register (vector) by
    register address. The vectorized version starts at the base
    address and then iterates the memory address.

    :param: destination (sint)
    :param: memory address base (regint)

    """
    code = base.opcodes['LDMSI']
    # operands in docstring order: written secret register, address register
    arg_format = ['sw','ci']
    # Immediate-address counterpart of this instruction. staticmethod()
    # keeps it from being bound as a method when read through an instance.
    # NOTE(review): presumably used by IndirectMemoryInstruction to fall
    # back to the direct form -- confirm in instructions_base.
    direct = staticmethod(ldms)

@base.vectorize
class stmci(base.WriteMemoryInstruction, base.IndirectMemoryInstruction):
    """ Assign clear register (vector) to clear memory value(s) by
    register address. The vectorized version starts at the base
    address and then iterates the memory address.

    :param: source (cint)
    :param: memory address base (regint)

    """
    code = base.opcodes['STMCI']
    # operands in docstring order: read clear register, address register
    arg_format = ['c','ci']
    # Immediate-address counterpart of this instruction. staticmethod()
    # keeps it from being bound as a method when read through an instance.
    # NOTE(review): presumably used by IndirectMemoryInstruction to fall
    # back to the direct form -- confirm in instructions_base.
    direct = staticmethod(stmc)

@base.vectorize
class stmsi(base.WriteMemoryInstruction, base.IndirectMemoryInstruction):
    """ Assign secret register (vector) to secret memory value(s) by
    register address. The vectorized version starts at the base
    address and then iterates the memory address.

    :param: source (sint)
    :param: memory address base (regint)

    """
    code = base.opcodes['STMSI']
    # operands in docstring order: read secret register, address register
    arg_format = ['s','ci']
    # Immediate-address counterpart of this instruction. staticmethod()
    # keeps it from being bound as a method when read through an instance.
    # NOTE(review): presumably used by IndirectMemoryInstruction to fall
    # back to the direct form -- confirm in instructions_base.
    direct = staticmethod(stms)

@base.vectorize
class ldminti(base.ReadMemoryInstruction, base.IndirectMemoryInstruction):
    """ Assign clear integer memory value(s) to clear integer register
    (vector) by register address. The vectorized version starts at the
    base address and then iterates the memory address.

    :param: destination (regint)
    :param: memory address base (regint)

    """
    code = base.opcodes['LDMINTI']
    # operands in docstring order: written integer register, address register
    arg_format = ['ciw','ci']
    # Immediate-address counterpart of this instruction. staticmethod()
    # keeps it from being bound as a method when read through an instance.
    # NOTE(review): presumably used by IndirectMemoryInstruction to fall
    # back to the direct form -- confirm in instructions_base.
    direct = staticmethod(ldmint)

@base.vectorize
class stminti(base.WriteMemoryInstruction, base.IndirectMemoryInstruction):
    """ Assign clear integer register (vector) to clear integer memory
    value(s) by register address. The vectorized version starts at the
    base address and then iterates the memory address.

    :param: source (regint)
    :param: memory address base (regint)

    """
    code = base.opcodes['STMINTI']
    # operands in docstring order: read integer register, address register
    arg_format = ['ci','ci']
    # Immediate-address counterpart of this instruction. staticmethod()
    # keeps it from being bound as a method when read through an instance.
    # NOTE(review): presumably used by IndirectMemoryInstruction to fall
    # back to the direct form -- confirm in instructions_base.
    direct = staticmethod(stmint)

@base.vectorize
class gldmci(base.ReadMemoryInstruction, base.IndirectMemoryInstruction):
    r""" Assigns register $c_i$ the value in memory \verb+C[cj]+. """
    # GF(2^n) counterpart of ldmci: same opcode shifted by 0x100 (1 << 8).
    code = base.opcodes['LDMCI'] + 0x100
    # written GF(2^n) clear register, address register
    arg_format = ['cgw','ci']
    # NOTE(review): gldmc is not defined in the visible text of this file;
    # presumably @base.gf2n on ldmc injects it into the module globals --
    # confirm, otherwise this line raises NameError at import time.
    direct = staticmethod(gldmc)

@base.vectorize
class gldmsi(base.ReadMemoryInstruction, base.IndirectMemoryInstruction):
    r""" Assigns register $s_i$ the value in memory \verb+S[cj]+. """
    # GF(2^n) counterpart of ldmsi: same opcode shifted by 0x100 (1 << 8).
    code = base.opcodes['LDMSI'] + 0x100
    # written GF(2^n) secret register, address register
    arg_format = ['sgw','ci']
    # NOTE(review): gldms is not defined in the visible text of this file;
    # presumably @base.gf2n on ldms injects it into the module globals --
    # confirm, otherwise this line raises NameError at import time.
    direct = staticmethod(gldms)

@base.vectorize
class gstmci(base.WriteMemoryInstruction, base.IndirectMemoryInstruction):
    r""" Sets \verb+C[cj]+ to be the value $c_i$. """
    # GF(2^n) counterpart of stmci: same opcode shifted by 0x100 (1 << 8).
    code = base.opcodes['STMCI'] + 0x100
    # read GF(2^n) clear register, address register
    arg_format = ['cg','ci']
    # NOTE(review): gstmc is not defined in the visible text of this file;
    # presumably @base.gf2n on stmc injects it into the module globals --
    # confirm, otherwise this line raises NameError at import time.
    direct = staticmethod(gstmc)

@base.vectorize
class gstmsi(base.WriteMemoryInstruction, base.IndirectMemoryInstruction):
    r""" Sets \verb+S[cj]+ to be the value $s_i$. """
    # GF(2^n) counterpart of stmsi: same opcode shifted by 0x100 (1 << 8).
    code = base.opcodes['STMSI'] + 0x100
    # read GF(2^n) secret register, address register
    arg_format = ['sg','ci']
    # NOTE(review): gstms is not defined in the visible text of this file;
    # presumably @base.gf2n on stms injects it into the module globals --
    # confirm, otherwise this line raises NameError at import time.
    direct = staticmethod(gstms)

@base.gf2n
@base.vectorize
class movc(base.Instruction):
    """ Copy clear register (vector).

    :param: destination (cint)
    :param: source (cint)
    """
    __slots__ = []
    code = base.opcodes['MOVC']
    # operands in docstring order: written clear register, read clear register
    arg_format = ['cw','c']

@base.gf2n
@base.vectorize
class movs(base.Instruction):
    """ Copy secret register (vector).

    :param: destination (sint)
    :param: source (sint)
    """
    __slots__ = []
    code = base.opcodes['MOVS']
    # operands in docstring order: written secret register, read secret register
    arg_format = ['sw','s']

@base.vectorize
class movint(base.Instruction):
    """ Copy clear integer register (vector).

    :param: destination (regint)
    :param: source (regint)
    """
    __slots__ = []
    code = base.opcodes['MOVINT']
    # operands in docstring order: written integer register, read integer register
    arg_format = ['ciw','ci']

@base.vectorize
class pushint(base.StackInstruction):
    """ Push clear integer register to the thread-local stack.

    :param: source (regint)
    """
    code = base.opcodes['PUSHINT']
    # single operand: read integer register
    arg_format = ['ci']

@base.vectorize
class popint(base.StackInstruction):
    """ Pop from the thread-local stack to clear integer register.

    :param: destination (regint)
    """
    code = base.opcodes['POPINT']
    # single operand: written integer register
    arg_format = ['ciw']


###
### Machine
###

@base.vectorize
class ldtn(base.Instruction):
    """ Store the number of the current thread in clear integer register.

    :param: destination (regint)
    """
    code = base.opcodes['LDTN']
    # single operand: written integer register
    arg_format = ['ciw']

@base.vectorize
class ldarg(base.Instruction):
    """ Store the argument passed to the current thread in clear integer
    register.

    :param: destination (regint)
    """
    code = base.opcodes['LDARG']
    # single operand: written integer register
    arg_format = ['ciw']

@base.vectorize
class starg(base.Instruction):
    """ Copy clear integer register to the thread argument.

    :param: source (regint)
    """
    code = base.opcodes['STARG']
    # single operand: read integer register
    arg_format = ['ci']

@base.gf2n
class reqbl(base.Instruction):
    """ Requirement on computation modulus. Minimal bit length of prime if
    positive, minus exact bit length of power of two if negative.

    :param: requirement (int)
    """
    code = base.opcodes['REQBL']
    # single operand: immediate (signed, see docstring for the sign convention)
    arg_format = ['int']

class time(base.IOInstruction):
    """ Output time since start of computation. """
    code = base.opcodes['TIME']
    # takes no operands
    arg_format = []

class start(base.Instruction):
    """ Start timer.

    :param: timer number (int)
    """
    code = base.opcodes['START']
    # single operand: immediate timer number
    arg_format = ['i']

class stop(base.Instruction):
    """ Stop timer.

    :param: timer number (int)
    """
    code = base.opcodes['STOP']
    # single operand: immediate timer number
    arg_format = ['i']

class use(base.Instruction):
    # The docstring is raw because it contains the LaTeX command \mathrm;
    # in a non-raw literal "\m" is an invalid escape sequence, which newer
    # Python versions warn about. The resulting string is unchanged.
    r""" Offline data usage. Necessary to avoid reusage while using
    preprocessing from files. Also used for multithreading with expensive
    preprocessing.

    :param: domain (0: integer, 1: :math:`\mathrm{GF}(2^n)`, 2: bit)
    :param: type (0: triple, 1: square, 2: bit, 3: inverse, 6: daBit)
    :param: number (int, -1 for unknown)
    """
    code = base.opcodes['USE']
    # operands in docstring order: three immediates
    arg_format = ['int','int','int']

class use_inp(base.Instruction):
    # The docstring is raw because it contains the LaTeX command \mathrm;
    # in a non-raw literal "\m" is an invalid escape sequence, which newer
    # Python versions warn about. The resulting string is unchanged.
    r""" Input usage.  Necessary to avoid reusage while using
    preprocessing from files.

    :param: domain (0: integer, 1: :math:`\mathrm{GF}(2^n)`, 2: bit)
    :param: input player (int)
    :param: number (int, -1 for unknown)
    """
    code = base.opcodes['USE_INP']
    # operands in docstring order: three immediates
    arg_format = ['int','int','int']

class use_edabit(base.Instruction):
    """ edaBit usage. Necessary to avoid reusage while using
    preprocessing from files. Also used for multithreading with expensive
    preprocessing.

    :param: loose/strict (0/1)
    :param: length (int)
    :param: number (int, -1 for unknown)
    """
    code = base.opcodes['USE_EDABIT']
    # operands in docstring order: three immediates
    arg_format = ['int','int','int']

class run_tape(base.Instruction):
    """ Start tape/bytecode file in another thread.

    :param: number of arguments to follow (multiple of three)
    :param: tape number (int)
    :param: virtual machine thread number (int)
    :param: tape argument (int)
    :param: (repeat the last three)...
    """
    code = base.opcodes['RUN_TAPE']
    # Repeating group of three immediates (tape, thread, argument).
    # NOTE(review): tools.cycle is a project helper; presumably it behaves
    # like itertools.cycle so any number of groups is accepted -- confirm.
    arg_format = tools.cycle(['int','int','int'])

class join_tape(base.Instruction):
    """ Join thread.

    :param: virtual machine thread number (int)
    """
    code = base.opcodes['JOIN_TAPE']
    # single operand: immediate thread number
    arg_format = ['int']

class crash(base.IOInstruction):
    """ Crash runtime. """
    code = base.opcodes['CRASH']
    # takes no operands
    arg_format = []

class start_grind(base.IOInstruction):
    """ Emit the STARTGRIND opcode (no operands). """
    # NOTE(review): the name suggests starting a Valgrind-style profiling
    # region in the virtual machine -- confirm against the VM and extend the
    # docstring, since docstrings feed the generated documentation.
    code = base.opcodes['STARTGRIND']
    arg_format = []

class stop_grind(base.IOInstruction):
    """ Emit the STOPGRIND opcode (no operands). """
    # NOTE(review): the name suggests ending a Valgrind-style profiling
    # region in the virtual machine -- confirm against the VM and extend the
    # docstring, since docstrings feed the generated documentation.
    code = base.opcodes['STOPGRIND']
    arg_format = []

@base.gf2n
class use_prep(base.Instruction):
    """ Custom preprocessed data usage.

    :param: tag (16 bytes / 4 units, cut off at first zero byte)
    :param: number of items to use (int, -1 for unknown)
    """
    code = base.opcodes['USE_PREP']
    # operands in docstring order: string tag, immediate count
    arg_format = ['str','int']

class nplayers(base.Instruction):
    """ Store number of players in clear integer register.

    :param: destination (regint)
    """
    code = base.opcodes['NPLAYERS']
    # single operand: written integer register
    arg_format = ['ciw']

class threshold(base.Instruction):
    """ Store maximal number of corrupt players in clear integer register.

    :param: destination (regint)
    """
    code = base.opcodes['THRESHOLD']
    # single operand: written integer register
    arg_format = ['ciw']

class playerid(base.Instruction):
    """ Store current player number in clear integer register.

    :param: destination (regint)
    """
    code = base.opcodes['PLAYERID']
    # single operand: written integer register
    arg_format = ['ciw']

###
### Basic arithmetic
###

@base.gf2n
@base.vectorize
class addc(base.AddBase):
    """ Clear addition.

    :param: result (cint)
    :param: summand (cint)
    :param: summand (cint)
    """
    __slots__ = []
    code = base.opcodes['ADDC']
    # written clear register, then two read clear registers
    arg_format = ['cw','c','c']

@base.gf2n
@base.vectorize
class adds(base.AddBase):
    """ Secret addition.

    :param: result (sint)
    :param: summand (sint)
    :param: summand (sint)
    """
    __slots__ = []
    code = base.opcodes['ADDS']
    # written secret register, then two read secret registers
    arg_format = ['sw','s','s']

@base.gf2n
@base.vectorize
class addm(base.AddBase):
    """ Mixed addition.

    :param: result (sint)
    :param: summand (sint)
    :param: summand (cint)
    """
    __slots__ = []
    code = base.opcodes['ADDM']
    # written secret register, read secret register, read clear register
    arg_format = ['sw','s','c']

@base.gf2n
@base.vectorize
class subc(base.SubBase):
    """ Clear subtraction.

    :param: result (cint)
    :param: first operand (cint)
    :param: second operand (cint)
    """
    __slots__ = []
    code = base.opcodes['SUBC']
    # written clear register, then two read clear registers
    arg_format = ['cw','c','c']

@base.gf2n
@base.vectorize
class subs(base.SubBase):
    """ Secret subtraction.

    :param: result (sint)
    :param: first operand (sint)
    :param: second operand (sint)
    """
    __slots__ = []
    code = base.opcodes['SUBS']
    # written secret register, then two read secret registers
    arg_format = ['sw','s','s']

@base.gf2n
@base.vectorize
class subml(base.SubBase):
    """ Subtract clear from secret value.

    :param: result (sint)
    :param: first operand (sint)
    :param: second operand (cint)
    """
    __slots__ = []
    code = base.opcodes['SUBML']
    # written secret register, read secret register, read clear register
    arg_format = ['sw','s','c']

@base.gf2n
@base.vectorize
class submr(base.SubBase):
    """ Subtract secret from clear value.

    :param: result (sint)
    :param: first operand (cint)
    :param: second operand (sint)
    """
    __slots__ = []
    code = base.opcodes['SUBMR']
    # written secret register, read clear register, read secret register
    arg_format = ['sw','c','s']

@base.gf2n
@base.vectorize
class mulc(base.MulBase):
    """ Clear multiplication.

    :param: result (cint)
    :param: factor (cint)
    :param: factor (cint)
    """
    __slots__ = []
    code = base.opcodes['MULC']
    # written clear register, then two read clear registers
    arg_format = ['cw','c','c']

@base.gf2n
@base.vectorize
class mulm(base.MulBase):
    """ Multiply secret and clear value.

    :param: result (sint)
    :param: factor (sint)
    :param: factor (cint)
    """
    __slots__ = []
    code = base.opcodes['MULM']
    # written secret register, read secret register, read clear register
    arg_format = ['sw','s','c']

@base.gf2n
@base.vectorize
class divc(base.InvertInstruction):
    """ Clear division.

    :param: result (cint)
    :param: dividend (cint)
    :param: divisor (cint)
    """
    __slots__ = []
    code = base.opcodes['DIVC']
    # written clear register, then two read clear registers
    arg_format = ['cw','c','c']

@base.gf2n
@base.vectorize
class modc(base.Instruction):
    """ Clear modular reduction.

    :param: result (cint)
    :param: dividend (cint)
    :param: divisor (cint)
    """
    __slots__ = []
    code = base.opcodes['MODC']
    # written clear register, then two read clear registers
    arg_format = ['cw','c','c']

@base.vectorize
class inv2m(base.InvertInstruction):
    """ Inverse of power of two modulo prime (the computation modulus).

    :param: result (cint)
    :param: exponent (int)
    """
    __slots__ = []
    code = base.opcodes['INV2M']
    # written clear register, immediate exponent
    arg_format = ['cw','int']

@base.vectorize
class legendrec(base.Instruction):
    """ Clear Legendre symbol computation (a/p) over prime p
    (the computation modulus).

    :param: result (cint)
    :param: a (cint)
    """
    __slots__ = []
    code = base.opcodes['LEGENDREC']
    # written clear register, read clear register (a is a register, not
    # an immediate)
    arg_format = ['cw','c']

@base.vectorize
class digestc(base.Instruction):
    """ Clear truncated hash computation.

    :param: result (cint)
    :param: input (cint)
    :param: byte length of hash value used (int)
    """
    __slots__ = []
    code = base.opcodes['DIGESTC']
    # written clear register, read clear register, immediate byte length
    arg_format = ['cw','c','int']

###
### Bitwise operations
###

@base.gf2n
@base.vectorize
class andc(base.Instruction):
    """ Logical AND of clear (vector) registers.

    :param: result (cint)
    :param: operand (cint)
    :param: operand (cint)
    """
    __slots__ = []
    code = base.opcodes['ANDC']
    # written clear register, then two read clear registers
    arg_format = ['cw','c','c']

@base.gf2n
@base.vectorize
class orc(base.Instruction):
    """ Logical OR of clear (vector) registers.

    :param: result (cint)
    :param: operand (cint)
    :param: operand (cint)
    """
    __slots__ = []
    code = base.opcodes['ORC']
    # written clear register, then two read clear registers
    arg_format = ['cw','c','c']

@base.gf2n
@base.vectorize
class xorc(base.Instruction):
    """ Logical XOR of clear (vector) registers.

    :param: result (cint)
    :param: operand (cint)
    :param: operand (cint)
    """
    __slots__ = []
    code = base.opcodes['XORC']
    # written clear register, then two read clear registers
    arg_format = ['cw','c','c']

@base.vectorize
class notc(base.Instruction):
    """ Clear logical NOT of a constant number of bits of clear
    (vector) register.

    :param: result (cint)
    :param: operand (cint)
    :param: bit length (int)
    """
    __slots__ = []
    code = base.opcodes['NOTC']
    # written clear register, read clear register, immediate bit length
    arg_format = ['cw','c', 'int']

@base.vectorize
class gnotc(base.Instruction):
    r""" Clear logical NOT $cg_i = \lnot cg_j$ """
    __slots__ = []
    # GF(2^n) counterpart of notc: same opcode with bit 8 set. Unlike notc
    # it takes no bit-length operand.
    code = (1 << 8) + base.opcodes['NOTC']
    # written GF(2^n) clear register, read GF(2^n) clear register
    arg_format = ['cgw','cg']

    def is_gf2n(self):
        # Defined by hand because this class is not produced by @base.gf2n.
        return True

@base.vectorize
class gbitdec(base.Instruction):
    r""" Store every $n$-th bit of $cg_i$ in $cg_j, \dots$. """
    __slots__ = []
    code = base.opcodes['GBITDEC']
    # read GF(2^n) clear register, immediate n, then any number of
    # written GF(2^n) clear registers
    arg_format = tools.chain(['cg', 'int'], itertools.repeat('cgw'))

    def is_gf2n(self):
        """ Report that this instruction works on GF(2^n) registers. """
        return True

    # The method used to be misspelled, so the is_gf2n() spelling used by
    # the sibling GF(2^n) instructions was never overridden here. The old
    # name is kept as an alias for any caller relying on it.
    is_g2fn = is_gf2n

    def has_var_args(self):
        """ Report that the number of operands is variable. """
        return True

@base.vectorize
class gbitcom(base.Instruction):
    r""" Store the bits $cg_j, \dots$ as every $n$-th bit of $cg_i$. """
    __slots__ = []
    code = base.opcodes['GBITCOM']
    # written GF(2^n) clear register, immediate n, then any number of
    # read GF(2^n) clear registers
    arg_format = tools.chain(['cgw', 'int'], itertools.repeat('cg'))

    def is_gf2n(self):
        """ Report that this instruction works on GF(2^n) registers. """
        return True

    # The method used to be misspelled, so the is_gf2n() spelling used by
    # the sibling GF(2^n) instructions was never overridden here. The old
    # name is kept as an alias for any caller relying on it.
    is_g2fn = is_gf2n

    def has_var_args(self):
        """ Report that the number of operands is variable. """
        return True


###
### Special GF(2) arithmetic instructions
###

@base.vectorize
class gmulbitc(base.MulBase):
    r""" Clear GF(2^n) by clear GF(2) multiplication """
    __slots__ = []
    code = base.opcodes['GMULBITC']
    # written GF(2^n) clear register, then two read GF(2^n) clear registers
    arg_format = ['cgw','cg','cg']

    def is_gf2n(self):
        # Defined by hand because this class is not produced by @base.gf2n.
        return True

@base.vectorize
class gmulbitm(base.MulBase):
    r""" Secret GF(2^n) by clear GF(2) multiplication """
    __slots__ = []
    code = base.opcodes['GMULBITM']
    # written GF(2^n) secret register, read GF(2^n) secret register,
    # read GF(2^n) clear register
    arg_format = ['sgw','sg','cg']

    def is_gf2n(self):
        # Defined by hand because this class is not produced by @base.gf2n.
        return True

###
### Arithmetic with immediate values
###

@base.gf2n
@base.vectorize
class addci(base.ClearImmediate):
    """ Addition of clear register (vector) and (constant) immediate value.

    :param: result (cint)
    :param: summand (cint)
    :param: summand (int)
    """
    __slots__ = []
    code = base.opcodes['ADDCI']
    # Name of the Python operator method matching this instruction.
    # NOTE(review): presumably consumed by base.ClearImmediate; the operand
    # format is not set here and presumably also comes from that base class.
    op = '__add__'

@base.gf2n
@base.vectorize
class addsi(base.SharedImmediate):
    """ Addition of secret register (vector) and (constant) immediate value.

    :param: result (sint)
    :param: summand (sint)
    :param: summand (int)
    """
    __slots__ = []
    code = base.opcodes['ADDSI']
    # Name of the Python operator method matching this instruction.
    # NOTE(review): presumably consumed by base.SharedImmediate; the operand
    # format is not set here and presumably also comes from that base class.
    op = '__add__'

@base.gf2n
@base.vectorize
class subci(base.ClearImmediate):
    """ Subtraction of (constant) immediate value from clear register (vector).

    :param: result (cint)
    :param: first operand (cint)
    :param: second operand (int)
    """
    __slots__ = []
    code = base.opcodes['SUBCI']
    # Name of the Python operator method matching this instruction.
    # NOTE(review): presumably consumed by base.ClearImmediate -- confirm.
    op = '__sub__'

@base.gf2n
@base.vectorize
class subsi(base.SharedImmediate):
    """ Subtraction of (constant) immediate value from secret
    register (vector).

    :param: result (sint)
    :param: first operand (sint)
    :param: second operand (int)
    """
    __slots__ = []
    code = base.opcodes['SUBSI']
    # Name of the Python operator method matching this instruction.
    # NOTE(review): presumably consumed by base.SharedImmediate -- confirm.
    op = '__sub__'

@base.gf2n
@base.vectorize
class subcfi(base.ClearImmediate):
    """ Subtraction of clear register (vector) from (constant)
    immediate value.

    :param: result (cint)
    :param: first operand (int)
    :param: second operand (cint)
    """
    __slots__ = []
    code = base.opcodes['SUBCFI']
    # Reflected subtraction: the immediate is the minuend (see docstring).
    # NOTE(review): presumably consumed by base.ClearImmediate -- confirm.
    op = '__rsub__'

@base.gf2n
@base.vectorize
class subsfi(base.SharedImmediate):
    """ Subtraction of secret register (vector) from (constant)
    immediate value.

    :param: result (sint)
    :param: first operand (int)
    :param: second operand (sint)
    """
    __slots__ = []
    code = base.opcodes['SUBSFI']
    # Reflected subtraction: the immediate is the minuend (see docstring).
    # NOTE(review): presumably consumed by base.SharedImmediate -- confirm.
    op = '__rsub__'

@base.gf2n
@base.vectorize
class mulci(base.ClearImmediate):
    """ Multiplication of clear register (vector) and (constant)
    immediate value.

    :param: result (cint)
    :param: factor (cint)
    :param: factor (int)
    """
    __slots__ = []
    code = base.opcodes['MULCI']
    # Name of the Python operator method matching this instruction.
    # NOTE(review): presumably consumed by base.ClearImmediate -- confirm.
    op = '__mul__'

@base.gf2n
@base.vectorize
class mulsi(base.SharedImmediate):
    """ Multiplication of secret register (vector) and (constant)
    immediate value.

    :param: result (sint)
    :param: factor (sint)
    :param: factor (int)

    """
    __slots__ = []
    code = base.opcodes['MULSI']
    # Name of the Python operator method matching this instruction.
    # NOTE(review): presumably consumed by base.SharedImmediate -- confirm.
    op = '__mul__'

@base.gf2n
@base.vectorize
class divci(base.InvertInstruction, base.ClearImmediate):
    """ Division of clear register (vector) and (constant) immediate value.

    :param: result (cint)
    :param: dividend (cint)
    :param: divisor (int)
    """
    __slots__ = []
    # Unlike the other immediate instructions in this section, no `op`
    # attribute is set here.
    code = base.opcodes['DIVCI']

@base.gf2n
@base.vectorize
class modci(base.ClearImmediate):
    """ Modular reduction of clear register (vector) and (constant)
    immediate value.

    :param: result (cint)
    :param: dividend (cint)
    :param: divisor (int)

    """
    __slots__ = []
    code = base.opcodes['MODCI']
    # Name of the Python operator method matching this instruction.
    # NOTE(review): presumably consumed by base.ClearImmediate -- confirm.
    op = '__mod__'

@base.gf2n
@base.vectorize
class andci(base.ClearImmediate):
    """ Logical AND of clear register (vector) and (constant)
    immediate value.

    :param: result (cint)
    :param: operand (cint)
    :param: operand (int)
    """
    __slots__ = []
    code = base.opcodes['ANDCI']
    # Name of the Python operator method matching this instruction.
    # NOTE(review): presumably consumed by base.ClearImmediate -- confirm.
    op = '__and__'

@base.gf2n
@base.vectorize
class xorci(base.ClearImmediate):
    """ Logical XOR of clear register (vector) and (constant)
    immediate value.

    :param: result (cint)
    :param: operand (cint)
    :param: operand (int)
    """
    __slots__ = []
    code = base.opcodes['XORCI']
    # Name of the Python operator method matching this instruction.
    # NOTE(review): presumably consumed by base.ClearImmediate -- confirm.
    op = '__xor__'

@base.gf2n
@base.vectorize
class orci(base.ClearImmediate):
    """ Logical OR of clear register (vector) and (constant)
    immediate value.

    :param: result (cint)
    :param: operand (cint)
    :param: operand (int)
    """
    __slots__ = []
    code = base.opcodes['ORCI']
    # Name of the Python operator method matching this instruction.
    # NOTE(review): presumably consumed by base.ClearImmediate -- confirm.
    op = '__or__'


###
### Shift instructions
###

@base.gf2n
@base.vectorize
class shlc(base.Instruction):
    """ Bitwise left shift of clear register (vector).

    :param: result (cint)
    :param: first operand (cint)
    :param: second operand (cint)
    """
    __slots__ = []
    code = base.opcodes['SHLC']
    # written clear register, then two read clear registers (the shift
    # amount is a register here; shlci takes it as an immediate)
    arg_format = ['cw','c','c']

@base.gf2n
@base.vectorize
class shrc(base.Instruction):
    """ Bitwise right shift of clear register (vector).

    :param: result (cint)
    :param: first operand (cint)
    :param: second operand (cint)
    """
    __slots__ = []
    code = base.opcodes['SHRC']
    # written clear register, then two read clear registers (the shift
    # amount is a register here; shrci takes it as an immediate)
    arg_format = ['cw','c','c']

@base.gf2n
@base.vectorize
class shlci(base.ClearShiftInstruction):
    """ Bitwise left shift of clear register (vector) by (constant)
    immediate value.

    :param: result (cint)
    :param: first operand (cint)
    :param: second operand (int)

    """
    __slots__ = []
    code = base.opcodes['SHLCI']
    # Name of the Python operator method matching this instruction.
    # NOTE(review): presumably consumed by base.ClearShiftInstruction, which
    # presumably also supplies the operand format -- confirm.
    op = '__lshift__'

@base.gf2n
@base.vectorize
class shrci(base.ClearShiftInstruction):
    """ Bitwise right shift of clear register (vector) by (constant)
    immediate value.

    :param: result (cint)
    :param: first operand (cint)
    :param: second operand (int)

    """
    __slots__ = []
    code = base.opcodes['SHRCI']
    # Name of the Python operator method matching this instruction.
    # NOTE(review): presumably consumed by base.ClearShiftInstruction, which
    # presumably also supplies the operand format -- confirm.
    op = '__rshift__'

@base.vectorize
class shrsi(base.ClearShiftInstruction):
    """ Bitwise right shift of secret register (vector) by (constant)
    immediate value. This only makes sense in connection with
    protocols allowing local share conversion (i.e., based on additive
    secret sharing modulo a power of two). Moreover, the result is not
    a secret sharing of the right shift of the secret value but needs
    to be corrected using the overflow. This is explained by `Dalskov
    et al. <https://eprint.iacr.org/2020/1330>`_ in the appendix.

    :param: result (sint)
    :param: first operand (sint)
    :param: second operand (int)

    """
    __slots__ = []
    code = base.opcodes['SHRSI']
    # Set explicitly because the registers are secret although the base
    # class is ClearShiftInstruction: written secret register, read secret
    # register, immediate shift. No `op` is set, unlike shlci/shrci.
    arg_format = ['sw','s','i']

###
### Data access instructions
###

@base.gf2n
@base.vectorize
class triple(base.DataInstruction):
    """ Store fresh random triple(s) in secret register (vectors).

    :param: factor (sint)
    :param: factor (sint)
    :param: product (sint)
    """
    __slots__ = []
    code = base.opcodes['TRIPLE']
    # all three operands are written secret registers
    arg_format = ['sw','sw','sw']
    # NOTE(review): presumably the key under which base.DataInstruction
    # accounts preprocessing usage -- confirm.
    data_type = 'triple'

@base.vectorize
class gbittriple(base.DataInstruction):
    r""" Load secret variables $s_i$, $s_j$ and $s_k$
    with the next GF(2) multiplication triple. """
    __slots__ = []
    code = base.opcodes['GBITTRIPLE']
    # all three operands are written GF(2^n) secret registers
    arg_format = ['sgw','sgw','sgw']
    # NOTE(review): data_type/field_type are presumably the keys used by
    # base.DataInstruction for preprocessing usage accounting -- confirm.
    data_type = 'bittriple'
    field_type = 'gf2n'

    def is_gf2n(self):
        # Defined by hand because this class is not produced by @base.gf2n.
        return True

@base.vectorize
class gbitgf2ntriple(base.DataInstruction):
    r""" Load secret variables $s_i$, $s_j$ and $s_k$
    with the next GF(2) and GF(2^n) multiplication triple. """
    code = base.opcodes['GBITGF2NTRIPLE']
    # all three operands are written GF(2^n) secret registers
    arg_format = ['sgw','sgw','sgw']
    # NOTE(review): data_type/field_type are presumably the keys used by
    # base.DataInstruction for preprocessing usage accounting -- confirm.
    data_type = 'bitgf2ntriple'
    field_type = 'gf2n'

    def is_gf2n(self):
        # Defined by hand because this class is not produced by @base.gf2n.
        return True

@base.gf2n
@base.vectorize
class bit(base.DataInstruction):
    """ Store fresh random bit(s) in secret register (vectors).

    :param: destination (sint)
    """
    __slots__ = []
    code = base.opcodes['BIT']
    # single operand: written secret register
    arg_format = ['sw']
    # NOTE(review): presumably the key under which base.DataInstruction
    # accounts preprocessing usage -- confirm.
    data_type = 'bit'

@base.vectorize
class dabit(base.DataInstruction):
    """ Store fresh random daBit(s) in secret register (vectors).

    :param: arithmetic part (sint)
    :param: binary part (sbit)
    """
    __slots__ = []
    code = base.opcodes['DABIT']
    # written secret register, written secret bit register
    arg_format = ['sw', 'sbw']
    # NOTE(review): field_type/data_type are presumably the keys used by
    # base.DataInstruction for preprocessing usage accounting -- confirm.
    field_type = 'modp'
    data_type = 'dabit'

@base.vectorize
class edabit(base.Instruction):
    """ Store fresh random loose edaBit(s) in secret register (vectors).
    The length is the first argument minus one.

    :param: number of arguments to follow / number of bits plus two (int)
    :param: arithmetic (sint)
    :param: binary (sbit)
    :param: (binary)...
    """
    # NOTE(review): "first argument minus one" and "number of bits plus two"
    # above do not obviously agree; add_usage() below takes the bit length to
    # be the number of compiler-level arguments minus one -- confirm the
    # bytecode encoding.
    __slots__ = []
    code = base.opcodes['EDABIT']
    # one written secret register followed by any number of written
    # secret bit registers
    arg_format = tools.chain(['sw'], itertools.repeat('sbw'))
    field_type = 'modp'

    def add_usage(self, req_node):
        """ Record usage of loose edaBits on the given requirement node.

        The key is ``('edabit', bit length)`` and the amount is the
        instruction (vector) size.
        """
        usage_key = ('edabit', len(self.args) - 1)
        req_node.increment(usage_key, self.get_size())

@base.vectorize
class sedabit(base.Instruction):
    """ Store fresh random strict edaBit(s) in secret register (vectors).
    The length is the first argument minus one.

    :param: number of arguments to follow / number of bits plus two (int)
    :param: arithmetic (sint)
    :param: binary (sbit)
    :param: (binary)...
    """
    # NOTE(review): "first argument minus one" and "number of bits plus two"
    # above do not obviously agree; add_usage() below takes the bit length to
    # be the number of compiler-level arguments minus one -- confirm the
    # bytecode encoding.
    __slots__ = []
    code = base.opcodes['SEDABIT']
    # one written secret register followed by any number of written
    # secret bit registers
    arg_format = tools.chain(['sw'], itertools.repeat('sbw'))
    field_type = 'modp'

    def add_usage(self, req_node):
        """ Record usage of strict edaBits on the given requirement node.

        The key is ``('sedabit', bit length)`` and the amount is the
        instruction (vector) size.
        """
        usage_key = ('sedabit', len(self.args) - 1)
        req_node.increment(usage_key, self.get_size())

@base.vectorize
class randoms(base.Instruction):
    """ Store fresh length-restricted random shares(s) in secret register
    (vectors). This is only implemented for protocols that also implement
    local share conversion with :py:obj:`~Compiler.GC.instructions.split`.

    :param: destination (sint)
    :param: length (int)
    """
    __slots__ = []
    code = base.opcodes['RANDOMS']
    # written secret register, immediate length
    arg_format = ['sw','int']
    field_type = 'modp'

@base.vectorize
class randomfulls(base.Instruction):
    """ Store share(s) of a fresh secret random element in secret
    register (vectors).

    :param: destination (sint)
    """
    __slots__ = []
    code = base.opcodes['RANDOMFULLS']
    # single operand: written secret register
    arg_format = ['sw']
    field_type = 'modp'

@base.gf2n
@base.vectorize
class square(base.DataInstruction):
    """ Store fresh random square(s) in secret register (vectors).

    :param: value (sint)
    :param: square (sint)
    """
    __slots__ = []
    code = base.opcodes['SQUARE']
    # both operands are written secret registers
    arg_format = ['sw','sw']
    # NOTE(review): presumably the key under which base.DataInstruction
    # accounts preprocessing usage -- confirm.
    data_type = 'square'

@base.gf2n
@base.vectorize
class inverse(base.DataInstruction):
    """ Store fresh random inverse(s) in secret register (vectors).

    :param: value (sint)
    :param: inverse (sint)
    """
    __slots__ = []
    code = base.opcodes['INV']
    # both operands are written secret registers
    arg_format = ['sw','sw']
    data_type = 'inverse'

    def __init__(self, *args, **kwargs):
        """ Create the instruction, rejecting the non-GF(2^n) variant when
        the program is compiled for a ring.

        :raises CompilerError: if the ring option is set and this is not
            the GF(2^n) variant
        """
        ring_mode = program.options.ring
        if ring_mode and not self.is_gf2n():
            raise CompilerError('random inverse in ring not implemented')
        # The parent is named explicitly because the decorators rebind the
        # module-level name `inverse`.
        base.DataInstruction.__init__(self, *args, **kwargs)

@base.gf2n
@base.vectorize
class inputmask(base.Instruction):
    r""" Load secret $s_i$ with the next input mask for player $p$ and
    write the mask on player $p$'s private output. """
    __slots__ = []
    code = base.opcodes['INPUTMASK']
    # written secret register, player number
    arg_format = ['sw', 'p']
    field_type = 'modp'

    def add_usage(self, req_node):
        """ Record one input per vector element for the player given as
        the second argument. """
        owner = self.args[1]
        usage_key = (self.field_type, 'input', owner)
        req_node.increment(usage_key, self.get_size())

@base.vectorize
class inputmaskreg(base.Instruction):
    """ Store fresh random input mask(s) in secret register (vector) and clear
    register (vector) of the relevant player.

    :param: mask (sint)
    :param: mask (cint, player only)
    :param: player (regint)
    """
    __slots__ = []
    code = base.opcodes['INPUTMASKREG']
    # written secret register, written clear register, read integer register
    arg_format = ['sw', 'cw', 'ci']
    field_type = 'modp'

    def add_usage(self, req_node):
        """ Record an unbounded number of inputs.

        The player is held in a register and therefore not known here, so
        player 0 stands in as a proxy and the amount is infinite.
        """
        proxy_key = (self.field_type, 'input', 0)
        req_node.increment(proxy_key, float('inf'))

@base.gf2n
@base.vectorize
class prep(base.Instruction):
    """ Store custom preprocessed data in secret register (vectors).

    :param: number of arguments to follow (int)
    :param: tag (16 bytes / 4 units, cut off at first zero byte)
    :param: destination (sint)
    :param: (repeat destination)...
    """
    __slots__ = []
    code = base.opcodes['PREP']
    # string tag followed by any number of written secret registers; the
    # second format is the same layout with GF(2^n) registers
    arg_format = tools.chain(['str'], itertools.repeat('sw'))
    gf2n_arg_format = tools.chain(['str'], itertools.repeat('sgw'))
    field_type = 'modp'

    def add_usage(self, req_node):
        """ Record one use of the preprocessing identified by the tag
        (first argument). """
        tag = self.args[0]
        req_node.increment((self.field_type, tag), 1)

    def has_var_args(self):
        """ Report that the number of operands is variable. """
        return True

###
### I/O
###

@base.gf2n
@base.vectorize
class asm_input(base.TextInputInstruction):
    r""" Receive input from player $p$ and put in register $s_i$. """
    __slots__ = []
    code = base.opcodes['INPUT']
    # repeating pairs of (written secret register, player number)
    arg_format = tools.cycle(['sw', 'p'])
    field_type = 'modp'

    def add_usage(self, req_node):
        """ Record one input per vector element for the player of every
        (register, player) pair. """
        players = self.args[1::2]
        for owner in players:
            usage_key = (self.field_type, 'input', owner)
            req_node.increment(usage_key, self.get_size())

@base.vectorize
class inputfix(base.TextInputInstruction):
    """ Text input into secret registers (vectors), with one integer
    parameter per input. Judging by the opcode name this reads
    fixed-point numbers (to be confirmed).

    :param: destination (sint)
    :param: integer parameter (int; presumably the fixed-point precision)
    :param: player (int)
    :param: (repeat the last three)...
    """
    __slots__ = []
    code = base.opcodes['INPUTFIX']
    field_type = 'modp'
    arg_format = tools.cycle(['sw', 'int', 'p'])

    def add_usage(self, req_node):
        # the player is the last entry of every group of three arguments
        players = self.args[2::3]
        for player in players:
            usage_key = (self.field_type, 'input', player)
            req_node.increment(usage_key, self.get_size())

@base.vectorize
class inputfloat(base.TextInputInstruction):
    """ Text input into groups of four secret registers (vectors), with
    one integer parameter per input. Judging by the opcode name this
    reads floating-point numbers (to be confirmed).

    :param: destination (sint)
    :param: destination (sint)
    :param: destination (sint)
    :param: destination (sint)
    :param: integer parameter (int; presumably the significand precision)
    :param: player (int)
    :param: (repeat the last six)...
    """
    __slots__ = []
    code = base.opcodes['INPUTFLOAT']
    field_type = 'modp'
    arg_format = tools.cycle(['sw', 'sw', 'sw', 'sw', 'int', 'p'])

    def add_usage(self, req_node):
        # the player is the last entry of every group of six arguments;
        # every input fills four secret registers
        players = self.args[5::6]
        for player in players:
            usage_key = (self.field_type, 'input', player)
            req_node.increment(usage_key, 4 * self.get_size())

class inputmixed_base(base.TextInputInstruction):
    """ Common part of the mixed-type input instructions.

    The arguments are a sequence of groups, each consisting of a type
    id, the destinations, the parameters, and the player. Subclasses
    provide :py:obj:`player_arg_type`.
    """
    __slots__ = []
    field_type = 'modp'
    # maps type name to TYPE
    type_ids = {
        'int': 0,
        'fix': 1,
        'float': 2
    }
    # the following has to match TYPE: (N_DEST, N_PARAM)
    types = {
        0: (1, 0),
        1: (1, 1),
        2: (4, 1)
    }

    def __init__(self, name, *args):
        # callers give the type by name; the instruction stores its id
        super(inputmixed_base, self).__init__(self.type_ids[name], *args)

    @property
    def arg_format(self):
        # format of every group: type id, destinations, parameters, player
        for start in self.bases():
            type_id = self.args[start]
            yield 'int'
            n_dest, n_param = self.types[type_id]
            for _ in range(n_dest):
                yield 'sw'
            for _ in range(n_param):
                yield 'int'
            yield self.player_arg_type

    def bases(self):
        # yield the index of the type id of every group
        pos = 0
        while pos < len(self.args):
            yield pos
            # skip destinations and parameters plus type id and player
            pos += sum(self.types[self.args[pos]]) + 2

@base.vectorize
class inputmixed(inputmixed_base):
    """ Store private input in secret registers (vectors). The input is
    read as integer or floating-point number and the latter is then
    converted to the internal representation using the given precision.
    This instruction uses compile-time player numbers.

    :param: number of arguments to follow (int)
    :param: type (0: integer, 1: fixed-point, 2: floating-point)
    :param: destination (sint)
    :param: destination (sint, only for floating-point)
    :param: destination (sint, only for floating-point)
    :param: destination (sint, only for floating-point)
    :param: fixed-point precision or precision of floating-point significand (int, not with integer)
    :param: input player (int)
    :param: (repeat from type parameter)...

    """
    code = base.opcodes['INPUTMIXED']
    player_arg_type = 'p'

    def add_usage(self, req_node):
        for start in self.bases():
            n_dest, n_param = self.types[self.args[start]]
            # the player follows the type id, destinations and parameters
            player = self.args[start + n_dest + n_param + 1]
            usage_key = (self.field_type, 'input', player)
            req_node.increment(usage_key, n_dest * self.get_size())

@base.vectorize
class inputmixedreg(inputmixed_base):
    """ Store private input in secret registers (vectors). The input is
    read as integer or floating-point number and the latter is then
    converted to the internal representation using the given precision.
    This instruction uses run-time player numbers.

    :param: number of arguments to follow (int)
    :param: type (0: integer, 1: fixed-point, 2: floating-point)
    :param: destination (sint)
    :param: destination (sint, only for floating-point)
    :param: destination (sint, only for floating-point)
    :param: destination (sint, only for floating-point)
    :param: fixed-point precision or precision of floating-point significand (int, not with integer)
    :param: input player (regint)
    :param: (repeat from type parameter)...

    """
    code = base.opcodes['INPUTMIXEDREG']
    player_arg_type = 'ci'

    def add_usage(self, req_node):
        # The player is a run-time value, so the usage is booked on
        # player 0 as a proxy with an unbounded count.
        unbounded = float('inf')
        req_node.increment((self.field_type, 'input', 0), unbounded)

@base.gf2n
@base.vectorize
class rawinput(base.RawInputInstruction, base.Mergeable):
    """ Store private input in secret registers (vectors). The input is
    read in the internal binary format according to the protocol.

    :param: number of arguments to follow (multiple of two)
    :param: player number (int)
    :param: destination (sint)
    """
    __slots__ = []
    code = base.opcodes['RAWINPUT']
    field_type = 'modp'
    arg_format = tools.cycle(['p', 'sw'])

    def add_usage(self, req_node):
        # players come first in every (player, destination) pair
        for player in self.args[::2]:
            usage_key = (self.field_type, 'input', player)
            req_node.increment(usage_key, self.get_size())

@base.gf2n
@base.vectorize
class print_reg(base.IOInstruction):
    """ Debugging output of clear register (vector).

    :param: source (cint)
    :param: comment (4 bytes / 1 unit)
    """
    __slots__ = []
    code = base.opcodes['PRINTREG']
    arg_format = ['c', 'i']

    def __init__(self, reg, comment=''):
        # the comment string is passed on as an integer
        encoded_comment = self.str_to_int(comment)
        # NOTE(review): print_reg_class is not defined in the visible
        # source; presumably the decorators bind it to this class.
        super(print_reg_class, self).__init__(reg, encoded_comment)

@base.gf2n
@base.vectorize
class print_reg_plain(base.IOInstruction):
    """ Output the value of a clear register.

    :param: source (cint)
    """
    __slots__ = []
    arg_format = ['c']
    code = base.opcodes['PRINTREGPLAIN']

class cond_print_plain(base.IOInstruction):
    r""" Conditionally output clear register (with precision).
    Outputs :math:`x \cdot 2^p` where :math:`p` is the precision.

    :param: condition (cint, no output if zero)
    :param: source (cint)
    :param: precision (cint)
    """
    # The docstring is a raw string because it contains a LaTeX
    # backslash command, which is an invalid escape sequence otherwise.
    code = base.opcodes['CONDPRINTPLAIN']
    arg_format = ['c', 'c', 'c']

class print_int(base.IOInstruction):
    """ Output the value of a clear integer register.

    :param: source (regint)
    """
    __slots__ = []
    arg_format = ['ci']
    code = base.opcodes['PRINTINT']

@base.vectorize
class print_float_plain(base.IOInstruction):
    """ Output a floating-point number given by clear registers.

    :param: significand (cint)
    :param: exponent (cint)
    :param: zero bit (cint, zero output if bit is one)
    :param: sign bit (cint, negative output if bit is one)
    :param: NaN (cint, regular number if zero)
    """
    __slots__ = []
    code = base.opcodes['PRINTFLOATPLAIN']
    # significand, exponent, zero bit, sign bit, NaN
    arg_format = ['c'] * 5

class print_float_prec(base.IOInstruction):
    """ Set the number of digits after the decimal point used by
    :py:obj:`~Compiler.instructions.print_float_plain`.

    :param: number of digits (int)
    """
    __slots__ = []
    arg_format = ['int']
    code = base.opcodes['PRINTFLOATPREC']

class print_char(base.IOInstruction):
    """ Output a single byte.

    :param: byte (int)
    """
    code = base.opcodes['PRINTCHR']
    arg_format = ['int']

    def __init__(self, ch):
        # the character is passed on as its code point
        byte_value = ord(ch)
        super(print_char, self).__init__(byte_value)

class print_char4(base.IOInstruction):
    """ Output four bytes.

    :param: four bytes (int)
    """
    code = base.opcodes['PRINTSTR']
    arg_format = ['int']

    def __init__(self, val):
        # the string is passed on as an integer
        packed = self.str_to_int(val)
        super(print_char4, self).__init__(packed)

class cond_print_str(base.IOInstruction):
    """ Conditionally output four bytes.

    :param: condition (cint, no output if zero)
    :param: four bytes (int)
    """
    code = base.opcodes['CONDPRINTSTR']
    arg_format = ['c', 'int']

    def __init__(self, cond, val):
        # the string is passed on as an integer
        packed = self.str_to_int(val)
        super(cond_print_str, self).__init__(cond, packed)

@base.vectorize
class pubinput(base.PublicFileIOInstruction):
    """ Read public input into a clear register (vector).

    :param: destination (cint)
    """
    __slots__ = []
    arg_format = ['cw']
    code = base.opcodes['PUBINPUT']

@base.vectorize
class readsocketc(base.IOInstruction):
    """ Read a variable number of clear values in internal representation
    from the socket of the given client and store them in clear registers.

    :param: number of arguments to follow / number of inputs minus one (int)
    :param: client id (regint)
    :param: destination (cint)
    :param: (repeat destination)...
    """
    __slots__ = []
    code = base.opcodes['READSOCKETC']
    # client id followed by any number of clear destinations
    arg_format = tools.chain(['ci'], itertools.repeat('cw'))

    def has_var_args(self):
        return True

@base.vectorize
class readsockets(base.IOInstruction):
    """ Read a variable number of secret shares + MACs from socket for a
    client id and store in registers.

    :param: client id (regint)
    :param: destination (sint)
    :param: (repeat destination)...
    """
    __slots__ = []
    code = base.opcodes['READSOCKETS']
    # client id followed by any number of secret destinations
    arg_format = tools.chain(['ci'], itertools.repeat('sw'))

    def has_var_args(self):
        return True

@base.vectorize
class readsocketint(base.IOInstruction):
    """ Read a variable number of 32-bit integers from the socket of the
    given client and store them in clear integer registers.

    :param: number of arguments to follow / number of inputs minus one (int)
    :param: client id (regint)
    :param: destination (regint)
    :param: (repeat destination)...
    """
    __slots__ = []
    code = base.opcodes['READSOCKETINT']
    # client id followed by any number of integer destinations
    arg_format = tools.chain(['ci'], itertools.repeat('ciw'))

    def has_var_args(self):
        return True

@base.vectorize
class writesocketc(base.IOInstruction):
    """ Write a variable number of clear GF(p) values from registers into
    socket for a specified client id, message_type.

    :param: client id (regint)
    :param: message type (int)
    :param: source (cint)
    :param: (repeat source)...
    """
    __slots__ = []
    code = base.opcodes['WRITESOCKETC']
    # client id and message type followed by any number of clear sources
    arg_format = tools.chain(['ci', 'int'], itertools.repeat('c'))

    def has_var_args(self):
        return True

@base.vectorize
class writesocketshare(base.IOInstruction):
    """ Write a variable number of shares (without MACs) from secret
    registers into the socket of the given client.

    :param: client id (regint)
    :param: message type (must be 0)
    :param: source (sint)
    :param: (repeat source)...
    """
    __slots__ = []
    code = base.opcodes['WRITESOCKETSHARE']
    # client id and message type followed by any number of secret sources
    arg_format = tools.chain(['ci', 'int'], itertools.repeat('s'))

    def has_var_args(self):
        return True

@base.vectorize
class writesocketint(base.IOInstruction):
    """ Write a variable number of 32-bit ints from registers into socket
    for a specified client id, message_type.

    :param: client id (regint)
    :param: message type (int)
    :param: source (regint)
    :param: (repeat source)...
    """
    __slots__ = []
    code = base.opcodes['WRITESOCKETINT']
    # client id and message type followed by any number of integer sources
    arg_format = tools.chain(['ci', 'int'], itertools.repeat('ci'))

    def has_var_args(self):
        return True

class listen(base.IOInstruction):
    """ Open a server socket on a party-specific port number and listen
    for client connections (non-blocking).

    :param: port number (int)
    """
    __slots__ = []
    arg_format = ['int']
    code = base.opcodes['LISTEN']

class acceptclientconnection(base.IOInstruction):
    """ Wait for a connection at the given port and write the socket
    handle to a clear integer register.

    :param: client id destination (regint)
    :param: port number (int)
    """
    __slots__ = []
    arg_format = ['ciw', 'int']
    code = base.opcodes['ACCEPTCLIENTCONNECTION']

class closeclientconnection(base.IOInstruction):
    """ Close the connection to a client.

    :param: client id (regint)
    """
    __slots__ = []
    arg_format = ['ci']
    code = base.opcodes['CLOSECLIENTCONNECTION']

class writesharestofile(base.IOInstruction):
    """ Append shares to the end of
    ``Persistence/Transactions-P<playerno>.data``.

    :param: number of shares (int)
    :param: source (sint)
    :param: (repeat from source)...

    """
    __slots__ = []
    code = base.opcodes['WRITEFILESHARE']
    # any number of secret sources
    arg_format = itertools.repeat('s')

    def has_var_args(self):
        return True

class readsharesfromfile(base.IOInstruction):
    """ Read shares from ``Persistence/Transactions-P<playerno>.data``.

    :param: number of arguments to follow / number of shares plus two (int)
    :param: starting position in number of shares from beginning (regint)
    :param: destination for final position, -1 for eof reached, or -2 for file not found (regint)
    :param: destination for share (sint)
    :param: (repeat from destination for share)...
    """
    __slots__ = []
    code = base.opcodes['READFILESHARE']
    # start position and end position destination, then share destinations
    arg_format = tools.chain(['ci', 'ciw'], itertools.repeat('sw'))

    def has_var_args(self):
        return True

@base.gf2n
@base.vectorize
class raw_output(base.PublicFileIOInstruction):
    r""" Raw output of register \verb|ci| to file.

    :param: source (cint)
    """
    __slots__ = []
    arg_format = ['c']
    code = base.opcodes['RAWOUTPUT']

@base.gf2n
@base.vectorize
class startprivateoutput(base.Instruction):
    r""" Initiate private output to $n$ of $s_j$ via $s_i$.

    :param: destination (sint)
    :param: source (sint)
    :param: player (int)
    """
    __slots__ = []
    code = base.opcodes['STARTPRIVATEOUTPUT']
    field_type = 'modp'
    arg_format = ['sw', 's', 'p']

    def add_usage(self, req_node):
        # booked like an input of the receiving player, per vector element
        usage_key = (self.field_type, 'input', self.args[2])
        req_node.increment(usage_key, self.get_size())

@base.gf2n
@base.vectorize
class stopprivateoutput(base.Instruction):
    r""" Previously initiated private output to $n$ via $c_i$.

    :param: destination (cint)
    :param: source (cint)
    :param: player (int)
    """
    __slots__ = []
    arg_format = ['cw', 'c', 'p']
    code = base.opcodes['STOPPRIVATEOUTPUT']

@base.vectorize
class rand(base.Instruction):
    """ Store an insecure random value of the specified length in a
    clear integer register (vector).

    :param: destination (regint)
    :param: length (regint)
    """
    __slots__ = []
    arg_format = ['ciw', 'ci']
    code = base.opcodes['RAND']

###
### Integer operations
### 

@base.vectorize
class ldint(base.Instruction):
    """ Store a (constant) immediate value in a clear integer register
    (vector).

    :param: destination (regint)
    :param: immediate (int)
    """
    __slots__ = []
    arg_format = ['ciw', 'i']
    code = base.opcodes['LDINT']

@base.vectorize
class addint(base.IntegerInstruction):
    """ Addition of clear integer registers (vectors).

    :param: result (regint)
    :param: summand (regint)
    :param: summand (regint)
    """
    __slots__ = []
    # corresponding Python operator
    op = operator.add
    code = base.opcodes['ADDINT']

@base.vectorize
class subint(base.IntegerInstruction):
    """ Subtraction of clear integer registers (vectors).

    :param: result (regint)
    :param: first operand (regint)
    :param: second operand (regint)
    """
    __slots__ = []
    # corresponding Python operator
    op = operator.sub
    code = base.opcodes['SUBINT']

@base.vectorize
class mulint(base.IntegerInstruction):
    """ Multiplication of clear integer registers (element-wise for
    vectors).

    :param: result (regint)
    :param: factor (regint)
    :param: factor (regint)
    """
    __slots__ = []
    # corresponding Python operator
    op = operator.mul
    code = base.opcodes['MULINT']

@base.vectorize
class divint(base.IntegerInstruction):
    """ Division of clear integer registers (element-wise for vectors)
    with floor rounding.

    :param: result (regint)
    :param: dividend (regint)
    :param: divisor (regint)
    """
    __slots__ = []
    # corresponding Python operator
    op = operator.floordiv
    code = base.opcodes['DIVINT']

@base.vectorize
class bitdecint(base.Instruction):
    """ Bit decomposition of a clear integer.

    :param: number of arguments to follow / number of bits minus one (int)
    :param: source (regint)
    :param: destination for least significant bit (regint)
    :param: (destination for one bit higher)...
    """
    __slots__ = []
    code = base.opcodes['BITDECINT']
    # source followed by any number of bit destinations
    arg_format = tools.chain(['ci'], itertools.repeat('ciw'))

class incint(base.VectorInstruction):
    """ Create incremental clear integer vector. For example, vector size 10,
    base 1, increment 2, repeat 3, and wrap 2 produces the following::

        (1, 1, 1, 3, 3, 3, 1, 1, 1, 3)

    This is because the first number is always the :py:obj:`base`,
    every number is repeated :py:obj:`repeat` times, after which
    :py:obj:`increment` is added, and after :py:obj:`wrap` increments
    the number returns to :py:obj:`base`.

    :param: destination (regint)
    :param: base (non-vector regint)
    :param: increment (int)
    :param: repeat (int)
    :param: wrap (int)

    """
    __slots__ = []
    code = base.opcodes['INCINT']
    arg_format = ['ciw', 'ci', 'i', 'i', 'i']

    def __init__(self, *args, **kwargs):
        # the base has to be a single register
        assert len(args[1]) == 1
        if len(args) == 3:
            # repeat defaults to 1 and wrap to the destination length
            full_args = list(args)
            full_args.extend((1, len(args[0])))
            args = full_args
        super(incint, self).__init__(*args, **kwargs)

class shuffle(base.VectorInstruction):
    """ Randomly shuffle a clear integer vector with public randomness.

    :param: destination (regint)
    :param: source (regint)
    """
    __slots__ = []
    code = base.opcodes['SHUFFLE']
    arg_format = ['ciw', 'ci']

    def __init__(self, *args, **kwargs):
        super(shuffle, self).__init__(*args, **kwargs)
        # destination and source must have the same length
        dest, source = args[0], args[1]
        assert len(dest) == len(source)

###
### Clear comparison instructions
###

@base.vectorize
class eqzc(base.UnaryComparisonInstruction):
    """ Test a clear integer for zero. The result is 1 for true and 0
    for false.

    :param: destination (regint)
    :param: operand (regint)
    """
    code = base.opcodes['EQZC']
    __slots__ = []

@base.vectorize
class ltzc(base.UnaryComparisonInstruction):
    """ Test whether a clear integer is less than zero. The result is 1
    for true and 0 for false.

    :param: destination (regint)
    :param: operand (regint)
    """
    code = base.opcodes['LTZC']
    __slots__ = []

@base.vectorize
class ltc(base.IntegerInstruction):
    """ Less-than comparison of clear integers. The result is 1 if the
    first operand is less and 0 otherwise.

    :param: destination (regint)
    :param: first operand (regint)
    :param: second operand (regint)
    """
    __slots__ = []
    # corresponding Python operator
    op = operator.lt
    code = base.opcodes['LTC']

@base.vectorize
class gtc(base.IntegerInstruction):
    """ Greater-than comparison of clear integers. The result is 1 if
    the first operand is greater and 0 otherwise.

    :param: destination (regint)
    :param: first operand (regint)
    :param: second operand (regint)
    """
    __slots__ = []
    # corresponding Python operator
    op = operator.gt
    code = base.opcodes['GTC']

@base.vectorize
class eqc(base.IntegerInstruction):
    """ Equality test of clear integers. The result is 1 if the
    operands are equal and 0 otherwise.

    :param: destination (regint)
    :param: first operand (regint)
    :param: second operand (regint)
    """
    __slots__ = []
    # corresponding Python operator
    op = operator.eq
    code = base.opcodes['EQC']


###
### Jumps etc
###

class jmp(base.JumpInstruction):
    """ Unconditional relative jump in the bytecode by a compile-time
    offset. The offset is added to the regular advance of one after
    every instruction, so an offset of 0 is a no-op while an offset
    of -1 is an infinite loop.

    :param: number of instructions (int)
    """
    __slots__ = []
    code = base.opcodes['JMP']
    arg_format = ['int']
    # presumably the position of the jump offset among the arguments
    jump_arg = 0

class jmpi(base.JumpInstruction):
    """ Unconditional relative jump in the bytecode by a run-time
    offset. The offset is added to the regular advance of one after
    every instruction, so an offset of 0 is a no-op while an offset
    of -1 is an infinite loop.

    :param: number of instructions (regint)
    """
    __slots__ = []
    code = base.opcodes['JMPI']
    arg_format = ['ci']
    # presumably the position of the jump offset among the arguments
    jump_arg = 0

class jmpnz(base.JumpInstruction):
    """ Conditional relative jump in the bytecode, taken if the
    condition is not zero. The offset is added to the regular advance
    of one after every instruction, so an offset of 0 is a no-op while
    an offset of -1 is an infinite loop.

    :param: condition (regint, only jump if not zero)
    :param: number of instructions (int)
    """
    __slots__ = []
    code = base.opcodes['JMPNZ']
    arg_format = ['ci', 'int']
    # presumably the position of the jump offset among the arguments
    jump_arg = 1

class jmpeqz(base.JumpInstruction):
    """ Conditional relative jump in the bytecode, taken if the
    condition is zero. The offset is added to the regular advance of
    one after every instruction, so an offset of 0 is a no-op while
    an offset of -1 is an infinite loop.

    :param: condition (regint, only jump if zero)
    :param: number of instructions (int)
    """
    __slots__ = []
    code = base.opcodes['JMPEQZ']
    arg_format = ['ci', 'int']
    # presumably the position of the jump offset among the arguments
    jump_arg = 1

###
### Conversions
###

@base.gf2n
@base.vectorize
class convint(base.Instruction):
    """ Convert a clear integer register (vector) to a clear register
    (vector).

    :param: destination (cint)
    :param: source (regint)
    """
    __slots__ = []
    arg_format = ['cw', 'ci']
    code = base.opcodes['CONVINT']

@base.vectorize
class convmodp(base.Instruction):
    r""" Convert clear register (vector) to clear integer register
    (vector). If the bit length is zero, the unsigned conversion is
    used, otherwise signed conversion is used. This makes a difference
    when computing modulo a prime :math:`p`. Signed conversion of
    :math:`p-1` results in -1 while unsigned conversion results in
    :math:`(p-1) \mod 2^{64}`.

    If the bit length is not given as third positional argument, it is
    taken from the keyword argument ``bitlength`` or, failing that,
    from ``program.bit_length``. A :py:obj:`CompilerError` is raised if
    a bit length determined this way exceeds 64.

    :param: destination (regint)
    :param: source (cint)
    :param: bit length (int)

    """
    # The docstring is a raw string because it contains a LaTeX
    # backslash command, which is an invalid escape sequence otherwise.
    __slots__ =  []
    code = base.opcodes['CONVMODP']
    arg_format = ['ciw', 'c', 'int']
    def __init__(self, *args, **kwargs):
        # NOTE(review): convmodp_class is not defined in the visible
        # source; presumably the decorator binds it to this class.
        if len(args) == len(self.arg_format):
            # all arguments given explicitly, nothing to fill in
            super(convmodp_class, self).__init__(*args)
            return
        bitlength = kwargs.get('bitlength')
        bitlength = program.bit_length if bitlength is None else bitlength
        if bitlength > 64:
            raise CompilerError('%d-bit conversion requested ' \
                                'but integer registers only have 64 bits' % \
                                bitlength)
        super(convmodp_class, self).__init__(*(args + (bitlength,)))

@base.vectorize
class gconvgf2n(base.Instruction):
    """ Convert from clear register $c_j$ to integer register $ci_i$.

    :param: destination (regint)
    :param: source (clear register, format ``cg``)
    """
    # NOTE(review): the previous description called the source a modp
    # register, but the 'cg' format looks like the GF(2^n) variant;
    # to be confirmed.
    __slots__ = []
    arg_format = ['ciw', 'cg']
    code = base.opcodes['GCONVGF2N']

###
### Other instructions
###

# rename 'open' to avoid conflict with built-in open function
@base.gf2n
@base.vectorize
class asm_open(base.VarArgsInstruction):
    """ Reveal secret registers (vectors) into clear registers (vectors).

    :param: number of arguments to follow (multiple of two)
    :param: destination (cint)
    :param: source (sint)
    :param: (repeat the last two)...
    """
    __slots__ = []
    code = base.opcodes['OPEN']
    # pairs of clear destination and secret source
    arg_format = tools.cycle(['cw', 's'])

@base.gf2n
@base.vectorize
class muls(base.VarArgsInstruction, base.DataInstruction):
    """ Multiplication of secret registers (element-wise for vectors).

    :param: number of arguments to follow (multiple of three)
    :param: result (sint)
    :param: factor (sint)
    :param: factor (sint)
    :param: (repeat the last three)...
    """
    __slots__ = []
    code = base.opcodes['MULS']
    data_type = 'triple'
    # triples of result and two factors
    arg_format = tools.cycle(['sw', 's', 's'])

    def get_repeat(self):
        # one multiplication per argument triple
        return len(self.args) // 3

    def merge_id(self):
        # can merge different sizes
        # but not if large
        size = self.get_size()
        if size is None or size > 100:
            return type(self), size
        return type(self)

    # def expand(self):
    #     s = [program.curr_block.new_reg('s') for i in range(9)]
    #     c = [program.curr_block.new_reg('c') for i in range(3)]
    #     triple(s[0], s[1], s[2])
    #     subs(s[3], self.args[1], s[0])
    #     subs(s[4], self.args[2], s[1])
    #     asm_open(c[0], s[3])
    #     asm_open(c[1], s[4])
    #     mulm(s[5], s[1], c[0])
    #     mulm(s[6], s[0], c[1])
    #     mulc(c[2], c[0], c[1])
    #     adds(s[7], s[2], s[5])
    #     adds(s[8], s[7], s[6])
    #     addm(self.args[0], s[8], c[2])

@base.gf2n
class mulrs(base.VarArgsInstruction, base.DataInstruction):
    """ Constant-vector multiplication of secret registers.

    :param: number of arguments to follow (multiple of four)
    :param: vector size (int)
    :param: result (sint)
    :param: vector factor (sint)
    :param: constant factor (sint)
    :param: (repeat the last four)...
    """
    __slots__ = []
    code = base.opcodes['MULRS']
    data_type = 'triple'
    # groups of vector size, result, vector factor and constant factor
    arg_format = tools.cycle(['int', 'sw', 's', 's'])
    is_vec = lambda self: True

    def __init__(self, res, x, y):
        # the constant factor is a single register, and the result has
        # the size of the vector factor
        assert y.size == 1
        assert res.size == x.size
        vector_size = res.size
        base.Instruction.__init__(self, vector_size, res, x, y)

    def get_repeat(self):
        # total of the vector sizes over all groups
        return sum(self.args[::4])

    def get_def(self):
        # all registers of all results
        regs = []
        for result in self.args[1::4]:
            regs = regs + result.get_all()
        return regs

    def get_used(self):
        # all registers of all vector factors, then of all constant factors
        operands = self.args[2::4] + self.args[3::4]
        regs = []
        for operand in operands:
            regs = regs + operand.get_all()
        return regs

@base.gf2n
@base.vectorize
class dotprods(base.VarArgsInstruction, base.DataInstruction):
    """ Dot product of secret registers (vectors).
    Note that the vectorized version works element-wise.

    The constructor takes triples of result, list of first factors,
    and list of second factors, and flattens them to the format below.

    :param: number of arguments to follow (int)
    :param: twice the dot product length plus two (I know...)
    :param: result (sint)
    :param: first factor (sint)
    :param: first factor (sint)
    :param: second factor (sint)
    :param: second factor (sint)
    :param: (remaining factors)...
    :param: (repeat from dot product length)...
    """
    __slots__ = []
    code = base.opcodes['DOTPRODS']
    data_type = 'triple'

    def __init__(self, *args):
        flat_args = []
        for start in range(0, len(args), 3):
            res, xs, ys = args[start:start + 3]
            assert len(xs) == len(ys)
            # length marker counts itself, the result and all factors
            flat_args.extend((2 * len(xs) + 2, res))
            for pair in zip(xs, ys):
                flat_args.extend(pair)
        base.Instruction.__init__(self, *flat_args)

    @property
    def arg_format(self):
        suffix = 'g' if self.is_gf2n() else ''
        read_format = 's' + suffix
        write_format = read_format + 'w'
        for start in self.bases():
            yield 'int'
            yield write_format
            # the rest of the group are the factors
            for _ in range(self.args[start] - 2):
                yield read_format

    gf2n_arg_format = arg_format

    def bases(self):
        # yield the index of the length marker of every dot product
        pos = 0
        while pos < len(self.args):
            yield pos
            pos += self.args[pos]

    def get_repeat(self):
        per_element = sum(self.args[start] // 2 for start in self.bases())
        return per_element * self.get_size()

    def get_def(self):
        # the result directly follows the length marker
        return [self.args[start + 1] for start in self.bases()]

    def get_used(self):
        # the factors fill the rest of every group
        for start in self.bases():
            end = start + self.args[start]
            for reg in self.args[start + 2:end]:
                yield reg

class matmul_base(base.DataInstruction):
    """ Common part of the matrix multiplication instructions, which
    consume triples according to the three dimensions given in
    arguments 3 to 5. """
    data_type = 'triple'
    is_vec = lambda self: True

    def get_repeat(self):
        # product of the three matrix dimensions
        dimensions = self.args[3:6]
        return reduce(operator.mul, dimensions)

class matmuls(matmul_base):
    """ Secret matrix multiplication from registers. All matrices are
    represented as vectors in row-first order.

    :param: result (sint vector)
    :param: first factor (sint vector)
    :param: second factor (sint vector)
    :param: number of rows in first factor and result (int)
    :param: number of columns in first factor and rows in second factor (int)
    :param: number of columns in second factor and result (int)
    """
    code = base.opcodes['MATMULS']
    # result and two factors followed by the three dimensions
    arg_format = ['sw', 's', 's', 'int', 'int', 'int']

class matmulsm(matmul_base):
    """ Secret matrix multiplication reading directly from memory.

    :param: result (sint vector in row-first order)
    :param: base address of first factor (regint value)
    :param: base address of second factor (regint value)
    :param: number of rows in first factor and result (int)
    :param: number of columns in first factor and rows in second factor (int)
    :param: number of columns in second factor and result (int)
    :param: rows of first factor to use (regint vector, length as number of rows in first factor)
    :param: columns of first factor to use (regint vector, length below)
    :param: rows of second factor to use (regint vector, length below)
    :param: columns of second factor to use (regint vector, length below)
    :param: number of columns of first / rows of second factor to use (int)
    :param: number of columns of second factor to use (int)
    """
    code = base.opcodes['MATMULSM']
    arg_format = ['sw', 'ci', 'ci', 'int', 'int', 'int',
                  'ci', 'ci', 'ci', 'ci', 'int', 'int']

    def __init__(self, *args, **kwargs):
        matmul_base.__init__(self, *args, **kwargs)
        # the index vectors of the first factor have to match the
        # dimensions in arguments 3 and 4
        assert args[6].size == args[3]
        assert args[7].size == args[4]
        # the index vectors of the second factor have to match the
        # dimensions in arguments 4 and 5
        assert args[8].size == args[4]
        assert args[9].size == args[5]

class conv2ds(base.DataInstruction):
    """ Secret 2D convolution.

    :param: result (sint vector in row-first order)
    :param: inputs (sint vector in row-first order)
    :param: weights (sint vector in row-first order)
    :param: output height (int)
    :param: output width (int)
    :param: input height (int)
    :param: input width (int)
    :param: weight height (int)
    :param: weight width (int)
    :param: stride height (int)
    :param: stride width (int)
    :param: number of channels (int)
    :param: padding height (int)
    :param: padding width (int)
    :param: batch size (int)
    """
    code = base.opcodes['CONV2DS']
    data_type = 'triple'
    is_vec = lambda self: True
    # result, inputs and weights followed by twelve integer parameters
    arg_format = ['sw', 's', 's'] + ['int'] * 12

    def __init__(self, *args, **kwargs):
        super(conv2ds, self).__init__(*args, **kwargs)
        result, inputs, weights = args[0], args[1], args[2]
        # the vector sizes have to match the given dimensions
        assert result.size == args[3] * args[4] * args[14]
        assert inputs.size == args[5] * args[6] * args[11] * args[14]
        assert weights.size == args[7] * args[8] * args[11]

    def get_repeat(self):
        # output size times weight size times channels times batch size
        params = self.args
        output_size = params[3] * params[4]
        return output_size * params[7] * params[8] * params[11] * params[14]

@base.vectorize
class trunc_pr(base.VarArgsInstruction):
    """ Probabilistic truncation, if the protocol supports it.

    :param: number of arguments to follow (multiple of four)
    :param: destination (sint)
    :param: source (sint)
    :param: bit length of source (int)
    :param: number of bits to truncate (int)
    """
    __slots__ = []
    code = base.opcodes['TRUNC_PR']
    # groups of destination, source and two integer parameters
    arg_format = tools.cycle(['sw', 's', 'int', 'int'])

###
### CISC-style instructions
###

@base.gf2n
@base.vectorize
class sqrs(base.CISC):
    r""" Secret squaring $s_i = s_j \cdot s_j$.

    :param: result (sint)
    :param: operand (sint)
    """
    # The docstring is a raw string because it contains a LaTeX
    # backslash command, which is an invalid escape sequence otherwise.
    __slots__ = []
    arg_format = ['sw', 's']
    
    def expand(self):
        if program.options.ring:
            # with the ring option, use a plain multiplication instead
            return muls(self.args[0], self.args[1], self.args[1])
        s = [program.curr_block.new_reg('s') for i in range(6)]
        c = [program.curr_block.new_reg('c') for i in range(2)]
        # Assuming square() yields a pair (a, a^2) and x is the operand:
        # open x - a, then combine so that the result equals
        # a^2 + 2 * x * (x - a) - (x - a)^2 = x^2.
        square(s[0], s[1])
        subs(s[2], self.args[1], s[0])
        asm_open(c[0], s[2])
        mulc(c[1], c[0], c[0])
        mulm(s[3], self.args[1], c[0])
        adds(s[4], s[3], s[3])
        adds(s[5], s[1], s[4])
        subml(self.args[0], s[5], c[1])


@base.gf2n
@base.vectorize
class lts(base.CISC):
    """ Secret comparison $s_i = (s_j < s_k)$.

    The last two arguments are integers that are passed on to
    ``comparison.LTZ`` unchanged.
    """
    __slots__ = []
    arg_format = ['sw', 's', 's', 'int', 'int']

    def expand(self):
        from .types import sint
        dest, left, right, param_a, param_b = self.args[:5]
        # compare by testing the difference for being less than zero
        difference = sint()
        subs(difference, left, right)
        comparison.LTZ(dest, difference, param_a, param_b)

@base.vectorize
class g2muls(base.CISC):
    r""" Secret GF(2) multiplication.

    :param: result (secret register, format ``sgw``)
    :param: factor (secret register, format ``sg``)
    :param: factor (secret register, format ``sg``)
    """
    __slots__ = []
    arg_format = ['sgw', 'sg', 'sg']

    def expand(self):
        new_reg = program.curr_block.new_reg
        secret = [new_reg('sg') for _ in range(9)]
        clear = [new_reg('cg') for _ in range(3)]
        result, x, y = self.args[0], self.args[1], self.args[2]
        # mask both factors with the first two triple elements and open
        gbittriple(secret[0], secret[1], secret[2])
        gsubs(secret[3], x, secret[0])
        gsubs(secret[4], y, secret[1])
        gasm_open(clear[0], secret[3])
        gasm_open(clear[1], secret[4])
        # combine the opened values with the triple
        gmulbitm(secret[5], secret[1], clear[0])
        gmulbitm(secret[6], secret[0], clear[1])
        gmulbitc(clear[2], clear[0], clear[1])
        gadds(secret[7], secret[2], secret[5])
        gadds(secret[8], secret[7], secret[6])
        gaddm(result, secret[8], clear[2])

#@base.vectorize
#class gmulbits(base.CISC):
#    r""" Secret $GF(2^n) \times GF(2)$ multiplication """
#    __slots__ = []
#    arg_format = ['sgw','sg','sg']
#
#    def expand(self):
#        s = [program.curr_block.new_reg('s') for i in range(9)]
#        c = [program.curr_block.new_reg('c') for i in range(3)]
#        g2ntriple(s[0], s[1], s[2])
#        subs(s[3], self.args[1], s[0])
#        subs(s[4], self.args[2], s[1])
#        startopen(s[3], s[4])
#        stopopen(c[0], c[1])
#        mulm(s[5], s[1], c[0])
#        mulm(s[6], s[0], c[1])
#        mulc(c[2], c[0], c[1])
#        adds(s[7], s[2], s[5])
#        adds(s[8], s[7], s[6])
#        addm(self.args[0], s[8], c[2])

# hack for circular dependency
from Compiler import comparison
