import numpy as np
import copy

"""
WARNING: the "phase" variable should only ever take on the values 0
and 1. It is manually changed in `utils.run_for_n_steps` based on the
introduction of the trigger token, and it is generally used as a
programmatic signalling mechanism rather than as an actual model
variable. Feel free to use something other than "phase" to determine
how a model internally tracks which phase it is in.

WARNING: the "trigger_id" variable is used in `utils.run_for_n_steps`
to ensure that the appropriate trigger id is used when swapping the
trigger ids.
"""

def count_only_cmod(inpt_id,
               variables,
               pad_id=0,
               bos_id=1,
               eos_id=2,
               resp_id=3,
               trigger_ids=[],
               demo_ids=[],
               unk_ids=[],
               pre_trigger=False,
               *args, **kwargs):
    """
    Single-counter model: count up during the demonstration phase and
    back down during the response phase.

    The model holds one integer "count" and a "phase" flag. Each
    demonstration-phase step increments the count, each response-phase
    step decrements it, and once the count is <= 0 in the response phase
    the model emits the eos token.

    Args:
        inpt_id: int
            the current token in the sequence
        variables: dict or None
            the state returned by the previous call, or None for a fresh
            state. The dict is shallow-copied; the caller's dict is not
            mutated.
            "count": int
                the current count of the process.
            "phase": bool or int
                0: demonstration phase
                1: response phase
            "mask": int
                0 when the emitted token is a demo/pad/unk token, 1 when
                it is a resp or eos token.
            "resp_count": int
                number of response-phase steps since the last bos (created
                on demand if missing).
        pad_id: int
           the padding token id
        bos_id: int
           the beginning of sequence token id
        eos_id: int
           the end of sequence token id
        resp_id: int
           the response id
        trigger_ids: list of ints
           all possible trigger ids. If this list is greater than 1
           which will occur if using multi_trigger setting, the
           order of the list will correspond to the order of the
           demo type ids that each trigger refers to.
        demo_ids: list of ints
            the possible demonstration token ids. The length of this
            list will be equal to n_demo_types and the order will
            correspond to the appropriate trigger_id index if using
            multi_trigger setting. Must be non-empty for bos/demo steps.
        unk_ids: list of ints
            token ids that are echoed back unchanged (with mask 0).
        pre_trigger: bool
            if true, will indicate at the beginning of the demonstration
            phase which token type to count. Otherwise indicates at the
            beginning of the response phase. Only applies if multi_trigger
            is true.
    Returns:
        token_id: int
            the next token in the sequence according to the model
        variables: dict
            the updated (copied) state with the keys described above.
    """
    if variables is None:
        variables = {"count": 0, "phase": 0, "mask": 0}

    # shallow copy so the caller's state dict is never mutated
    variables = {k:v for k,v in variables.items()}
    if inpt_id==bos_id:
        variables["phase"] = 0
        variables["count"] = 0
        variables["mask"] = 0

        #TODO: used for removing first step after trigger from response phase sampling
        variables["resp_count"] = 0
        #TODO:

        return demo_ids[np.random.randint(0,len(demo_ids))], variables
    elif inpt_id==eos_id or inpt_id==pad_id or inpt_id in unk_ids:
        variables["mask"] = 0
        if inpt_id in unk_ids:
            return inpt_id, variables
        return pad_id, variables
    elif inpt_id in trigger_ids:
        # With pre_trigger, a trigger seen before any demo has been counted
        # marks the start of the demo phase instead of the response phase.
        if not pre_trigger or variables["count"]!=0:
            variables["phase"] = 1
            variables["mask"] = 1
            token_id = resp_id
        else:
            variables["phase"] = 0
            variables["mask"] = 0
            token_id = demo_ids[np.random.randint(0,len(demo_ids))]
    elif variables["phase"]==0:
        # any non-special token in the demo phase counts as a demo step
        variables["count"] += 1
        variables["mask"] = 0
        token_id = demo_ids[np.random.randint(0,len(demo_ids))]
    elif variables["phase"]>0:
        # any non-special token in the response phase counts as a resp step
        variables["count"] -= 1
        variables["mask"] = 1
        token_id = resp_id

        #TODO: used for removing first step after trigger from response phase sampling
        # .get: the state may not have been initialised by a bos token
        variables["resp_count"] = variables.get("resp_count", 0) + 1
        #TODO:

    if variables["count"]<=0 and variables["phase"]>0:
        token_id = eos_id
        variables["mask"] = 1
    return token_id, variables

def and_one_cmod(
    inpt_id,
    variables,
    eos_id=2,
    resp_id=3,
    *args, **kwargs):
    """
    This model mimics the count_only_cmod except for that it outputs
    one more resp token than it should. This is useful for testing
    the attention mechanism of the transformer.

    Args:
        inpt_id: int
            the current token in the sequence
        variables: dict or None
            state forwarded to count_only_cmod
        eos_id: int
            the end of sequence token id; forwarded to count_only_cmod
        resp_id: int
            the response id; forwarded to count_only_cmod
        *args, **kwargs:
            forwarded to count_only_cmod (e.g. pad_id, bos_id,
            trigger_ids, demo_ids, unk_ids, pre_trigger)
    Returns:
        new_id: int
            the next token in the sequence according to the model
        variables: dict
            the updated state returned by count_only_cmod
    """
    # eos_id/resp_id are bound by this signature, so they must be passed on
    # explicitly or count_only_cmod would fall back to its own defaults.
    new_id, variables = count_only_cmod(
        inpt_id, variables, *args, eos_id=eos_id, resp_id=resp_id, **kwargs)
    # count_only_cmod emits eos once the count reaches <= 0 in the response
    # phase. Turn the eos produced at count == 0 into one extra resp token;
    # the following step (count == -1) then keeps count_only_cmod's eos.
    # Only tokens count_only_cmod chose as eos are rewritten, so the pad/unk
    # tokens it returns for eos/pad/unk inputs pass through unchanged.
    if new_id == eos_id and variables["phase"] > 0 and variables["count"] == 0:
        new_id = resp_id
    return new_id, variables

def random_cmod(inpt_id,
               variables,
               *args, eps=0.15, **kwargs):
    """
    This is a baseline model that, with probability eps, increments the
    count by one before delegating to count_only_cmod. The extra count
    makes the model occasionally produce one more response token than
    the demonstration warranted.

    Args:
        inpt_id: int
            the current token in the sequence
        variables: dict or None
            "count": int
                the current count of the process.
            "phase": bool or int
                0: demonstration phase
                1: response phase
        eps: float (keyword-only)
            probability of adding one to the count on this step.
            Defaults to 0.15.
        *args, **kwargs:
            forwarded to count_only_cmod, e.g.
            pad_id: int
               the padding token id
            bos_id: int
               the beginning of sequence token id
            eos_id: int
               the end of sequence token id
            resp_id: int
               the response id
            trigger_ids: list of ints
               all possible trigger ids. If this list is greater than 1
               which will occur if using multi_trigger setting, the
               order of the list will correspond to the order of the
               demo type ids that each trigger refers to.
            demo_ids: list of ints
                the possible demonstration token ids. The length of this
                list will be equal to n_demo_types and the order will
                correspond to the appropriate trigger_id index if using
                multi_trigger setting.
            pre_trigger: bool
                if true, will indicate at the beginning of the demonstration
                phase which token type to count. Otherwise indicates at the
                beginning of the response phase. Only applies if multi_trigger
                is true.
    Returns:
        token_id: int
            the next token in the sequence according to count_only_cmod
        variables: dict
            the updated state returned by count_only_cmod
    """
    if variables is None:
        variables = {"count": 0, "phase": 0, "mask": 0}
    # copy before the random bump so the caller's dict is not mutated
    variables = {k:v for k,v in variables.items()}

    if np.random.random()<eps:
        variables["count"] += 1

    # Pass inpt_id/variables positionally: combining them as keywords with
    # *args would bind any positional extras to inpt_id/variables and raise
    # TypeError ("got multiple values for argument").
    return count_only_cmod(inpt_id, variables, *args, **kwargs)

def demo_resp_count_cmod(inpt_id,
                   variables,
                   pad_id=0,
                   bos_id=1,
                   eos_id=2,
                   resp_id=4,
                   trigger_ids=[],
                   demo_ids=[],
                   unk_ids=[],
                   pre_trigger=False,
                   *args, **kwargs):
    """
    Two-counter model: tally demonstration tokens and response tokens
    separately, and end the response once the response tally catches up
    with the demonstration tally.

    NOTE(review): the default resp_id here is 4, whereas most other models
    in this file default to 3 — confirm which default is intended.

    Args:
        inpt_id: int
            the current token in the sequence
        variables: dict or None
            state from the previous call (None starts fresh); it is
            shallow-copied, never mutated.
            "demo_count": int
                number of demonstration steps seen since bos
            "resp_count": int
                number of response steps seen since the trigger
            "phase": bool or int
                0: demonstration phase
                1: response phase
            "mask": int
                0 when the emitted token is a demo/pad/unk token,
                1 when it is a resp or eos token.
        pad_id: int
           the padding token id
        bos_id: int
           the beginning of sequence token id
        eos_id: int
           the end of sequence token id
        resp_id: int
           the response id
        trigger_ids: list of ints
           all possible trigger ids. If this list is greater than 1
           which will occur if using multi_trigger setting, the
           order of the list will correspond to the order of the
           demo type ids that each trigger refers to.
        demo_ids: list of ints
            the possible demonstration token ids. The length of this
            list will be equal to n_demo_types and the order will
            correspond to the appropriate trigger_id index if using
            multi_trigger setting.
        unk_ids: list of ints
            token ids that are echoed back unchanged (with mask 0).
        pre_trigger: bool
            if true, will indicate at the beginning of the demonstration
            phase which token type to count. Otherwise indicates at the
            beginning of the response phase. Only applies if multi_trigger
            is true.
    Returns:
        token_id: int
            the next token in the sequence according to the model
        variables: dict
            the updated copy of the state described above
    """
    def draw_demo():
        # uniformly pick one of the demonstration token ids
        return demo_ids[np.random.randint(0, len(demo_ids))]

    if variables is None:
        state = {"demo_count": 0, "resp_count": 0, "phase": 0, "mask": 0}
    else:
        state = dict(variables)

    # --- special tokens: handled immediately ---
    if inpt_id == bos_id:
        state.update(phase=0, demo_count=0, resp_count=0, mask=0)
        return draw_demo(), state
    if inpt_id == eos_id or inpt_id == pad_id:
        state["mask"] = 0
        return pad_id, state
    if inpt_id in unk_ids:
        state["mask"] = 0
        return inpt_id, state

    # --- trigger / demo / response steps ---
    if inpt_id in trigger_ids:
        starts_demo = pre_trigger and state["demo_count"] == 0
        if starts_demo:
            state["phase"] = 0
            state["mask"] = 0
            token_id = draw_demo()
        else:
            state["phase"] = 1
            token_id = resp_id
            state["mask"] = 1
    elif state["phase"] == 0:
        state["demo_count"] += 1
        state["mask"] = 0
        token_id = draw_demo()
    elif state["phase"] > 0:
        state["resp_count"] += 1
        state["mask"] = 1
        token_id = resp_id

    caught_up = state["demo_count"] <= state["resp_count"]
    if caught_up and state["phase"] > 0:
        state["mask"] = 1
        token_id = eos_id
    return token_id, state

def evens_offset_cmod(
               inpt_id,
               variables,
               pad_id=0,
               bos_id=1,
               eos_id=2,
               resp_id=3,
               trigger_ids=[],
               demo_ids=[],
               unk_ids=[],
               pre_trigger=False,
               *args, **kwargs):
    """
    This model holds one variable for the total count integer-divided by
    two and another variable that offsets it by one, so the represented
    total is 2*count + offset. It also has a phase variable that tracks
    which token to output and whether to add to or subtract from the
    represented total.

    Args:
        inpt_id: int
            the current token in the sequence
        variables: dict or None
            state from the previous call (None starts fresh); it is
            shallow-copied, never mutated.
            "count": int
                the current total integer-divided by two.
            "offset": int (0 or 1)
                the remainder of the total modulo two; 1 means the total
                is odd.
            "phase": bool or int
                0: demonstration phase
                1: response phase
            "mask": int
                0 when the emitted token is a demo/pad/unk token, 1 when
                it is a resp or eos token.
            "resp_count": int
                number of response steps since bos (created on demand).
        pad_id: int
           the padding token id
        bos_id: int
           the beginning of sequence token id
        eos_id: int
           the end of sequence token id
        resp_id: int
           the response id
        trigger_ids: list of ints
           all possible trigger ids. If this list is greater than 1
           which will occur if using multi_trigger setting, the
           order of the list will correspond to the order of the
           demo type ids that each trigger refers to.
        demo_ids: list of ints
            the possible demonstration token ids. The length of this
            list will be equal to n_demo_types and the order will
            correspond to the appropriate trigger_id index if using
            multi_trigger setting.
        unk_ids: list of ints
            token ids that are echoed back unchanged (with mask 0).
        pre_trigger: bool
            if true, will indicate at the beginning of the demonstration
            phase which token type to count. Otherwise indicates at the
            beginning of the response phase. Only applies if multi_trigger
            is true.
    Returns:
        token_id: int
            the next token in the sequence according to the model
        variables: dict
            the updated copy of the state described above
    """
    if variables is None:
        variables = {"count": 0, "offset":0, "phase": 0, "mask": 0}

    variables = {k:v for k,v in variables.items()}
    if inpt_id==bos_id:
        variables["phase"] = 0
        variables["offset"] = 0
        variables["count"] = 0
        variables["mask"] = 0

        #TODO: used for removing first step after trigger from response phase sampling
        variables["resp_count"] = 0
        #TODO:

        return demo_ids[np.random.randint(0,len(demo_ids))], variables
    elif inpt_id==eos_id or inpt_id==pad_id:
        variables["mask"] = 0
        return pad_id, variables
    elif inpt_id in unk_ids:
        variables["mask"] = 0
        return inpt_id, variables
    elif inpt_id in trigger_ids:
        if not pre_trigger or variables["count"]!=0 or variables["offset"]!=0:
            variables["phase"] = 1
            variables["mask"] = 1
            token_id = resp_id
        else:
            variables["phase"] = 0
            variables["mask"] = 0
            token_id = demo_ids[np.random.randint(0,len(demo_ids))]
    elif variables["phase"]==0 and inpt_id in demo_ids:
        # +1 to the total: flip the offset, carrying into count on 1 -> 0
        if bool(variables["offset"]):
            variables["count"] += 1
        variables["offset"] = 1-variables["offset"]
        token_id = demo_ids[np.random.randint(0,len(demo_ids))]
        variables["mask"] = 0
    elif variables["phase"]>0 and inpt_id==resp_id:
        # -1 to the total: flip the offset, borrowing from count on 0 -> 1
        if not bool(variables["offset"]):
            variables["count"] -= 1
        variables["offset"] = 1-variables["offset"]
        token_id = resp_id
        variables["mask"] = 1

        #TODO: used for removing first step after trigger from response phase sampling
        # .get: the state may not have been initialised by a bos token
        variables["resp_count"] = variables.get("resp_count", 0) + 1
        #TODO:
    else:
        # The token does not match the current phase (e.g. resp_id during
        # the demo phase, or a demo token during the response phase). Leave
        # the counters untouched and emit the token the current phase
        # normally produces, so token_id is always bound; the eos check
        # below still applies.
        if variables["phase"]>0:
            token_id = resp_id
            variables["mask"] = 1
        else:
            token_id = demo_ids[np.random.randint(0,len(demo_ids))]
            variables["mask"] = 0

    # the represented total 2*count + offset has reached zero (or below)
    if variables["count"]<=0 and not bool(variables["offset"]) and variables["phase"]>0:
        token_id = eos_id
        variables["mask"] = 1
    return token_id, variables


def threevar_sum_cmod(
               inpt_id,
               variables,
               pad_id=0,
               bos_id=1,
               eos_id=2,
               resp_id=3,
               trigger_ids=[],
               demo_ids=[],
               unk_ids=[],
               pre_trigger=False,
               *args, **kwargs):
    """
    This model cycles through three counters whose sum is the count.
    Each demo step increments the counter selected by "idx", each resp
    step decrements it, and "idx" then advances cyclically (0 -> 1 -> 2
    -> 0). The eos token is emitted once the sum is <= 0 in the response
    phase.

    Args:
        inpt_id: int
            the current token in the sequence
        variables: dict or None
            state from the previous call (None starts fresh); it is
            shallow-copied, never mutated.
            "count0": int
                the current count of the zeroth variable
            "count1": int
                the current count of the first variable
            "count2": int
                the current count of the second variable
            "idx": int
                tracks which count to increment or decrement.
            "phase": bool or int
                0: demonstration phase
                1: response phase
            "mask": int
                0 when the emitted token is a demo/pad/unk token, 1 when
                it is a resp or eos token.
            "resp_count": int
                number of response steps since bos (created on demand).
        pad_id: int
           the padding token id
        bos_id: int
           the beginning of sequence token id
        eos_id: int
           the end of sequence token id
        resp_id: int
           the response id
        trigger_ids: list of ints
           all possible trigger ids. If this list is greater than 1
           which will occur if using multi_trigger setting, the
           order of the list will correspond to the order of the
           demo type ids that each trigger refers to.
        demo_ids: list of ints
            the possible demonstration token ids. The length of this
            list will be equal to n_demo_types and the order will
            correspond to the appropriate trigger_id index if using
            multi_trigger setting.
        unk_ids: list of ints
            token ids that are echoed back unchanged (with mask 0).
        pre_trigger: bool
            if true, will indicate at the beginning of the demonstration
            phase which token type to count. Otherwise indicates at the
            beginning of the response phase. Only applies if multi_trigger
            is true.
    Returns:
        token_id: int
            the next token in the sequence according to the model
        variables: dict
            the updated copy of the state described above
    """
    if variables is None:
        variables = {
            "count0": 0,
            "count1": 0,
            "count2": 0,
            "idx":0,
            "phase": 0,
            "mask": 0}

    variables = {k:v for k,v in variables.items()}

    def _total(state):
        # the model's count is the sum of its three counters
        return sum(state["count"+str(i)] for i in range(3))

    if inpt_id==bos_id:
        variables["phase"] = 0
        variables["count0"] = 0
        variables["count1"] = 0
        variables["count2"] = 0
        variables["idx"] = 0
        variables["mask"] = 0

        #TODO: used for removing first step after trigger from response phase sampling
        variables["resp_count"] = 0
        #TODO:

        return demo_ids[np.random.randint(0,len(demo_ids))], variables
    elif inpt_id==eos_id or inpt_id==pad_id:
        variables["mask"] = 0
        return pad_id, variables
    elif inpt_id in unk_ids:
        variables["mask"] = 0
        return inpt_id, variables
    elif inpt_id in trigger_ids:
        # This model has no "count"/"offset" keys; use the summed counters.
        if not pre_trigger or _total(variables)!=0:
            variables["phase"] = 1
            variables["mask"] = 1
            token_id = resp_id
        else:
            variables["phase"] = 0
            variables["mask"] = 0
            token_id = demo_ids[np.random.randint(0,len(demo_ids))]
    elif variables["phase"]==0 and inpt_id in demo_ids:
        variables["count"+str(variables["idx"])] += 1
        token_id = demo_ids[np.random.randint(0,len(demo_ids))]
        variables["mask"] = 0
    elif variables["phase"]>0 and inpt_id==resp_id:
        variables["count"+str(variables["idx"])] -= 1
        token_id = resp_id
        variables["mask"] = 1

        #TODO: used for removing first step after trigger from response phase sampling
        # .get: the state may not have been initialised by a bos token
        variables["resp_count"] = variables.get("resp_count", 0) + 1
        #TODO:
    else:
        # The token does not match the current phase (e.g. resp_id during
        # the demo phase, or a demo token during the response phase). Leave
        # the counters untouched and emit the token the current phase
        # normally produces, so token_id is always bound; the eos check
        # below still applies.
        if variables["phase"]>0:
            token_id = resp_id
            variables["mask"] = 1
        else:
            token_id = demo_ids[np.random.randint(0,len(demo_ids))]
            variables["mask"] = 0

    variables["idx"] = (variables["idx"]+1)%3

    if _total(variables)<=0 and variables["phase"]>0:
        token_id = eos_id
        variables["mask"] = 1
    return token_id, variables

def distributed_cmod(inpt_id,
               variables,
               pad_id=0,
               bos_id=1,
               eos_id=2,
               resp_id=3,
               trigger_ids=[],
               demo_ids=[],
               unk_ids=[],
               pre_trigger=False,
               *args, **kwargs):
    """
    This model assigns a signed value to each input token (+1 for demo
    tokens, -1 for resp tokens, 0 for the trigger and special tokens) and
    keeps a history of those values. While in the response phase, it
    outputs the eos token as soon as the running average of the history
    is <= 0.

    Each call first appends "inpt_val" (the value assigned on the previous
    step to the token expected as this step's input) to "tok_vals", then
    corrects that last entry if the actual input turned out to be a
    trigger or an unexpected resp token.

    Args:
        inpt_id: int
            the current token in the sequence
        variables: dict or None
            state from the previous call (None starts fresh); it is
            deep-copied, never mutated.
            "tok_vals": list
                a history of the token values it has seen so far
            "inpt_val": int
                the value of the incoming token as it contributes to
                the averaging (unless trigger which
                will overwrite the inpt_val).
            "phase": int
                0: demonstration phase
                1: response phase
            "mask": int
                this is just used to know whether the output token should
                be teacher forced or not. 0 means do teacher force
            "resp_count": int
                number of resp inputs seen since bos (created on demand);
                used for stats tracking.
        pad_id: int
           the padding token id
        bos_id: int
           the beginning of sequence token id
        eos_id: int
           the end of sequence token id
        resp_id: int
           the response id
        trigger_ids: list of ints
           all possible trigger ids. If this list is greater than 1
           which will occur if using multi_trigger setting, the
           order of the list will correspond to the order of the
           demo type ids that each trigger refers to. If the first
           trigger id equals resp_id, the trigger is handled by the
           resp_id branch instead.
        demo_ids: list of ints
            the possible demonstration token ids. The length of this
            list will be equal to n_demo_types and the order will
            correspond to the appropriate trigger_id index if using
            multi_trigger setting.
        unk_ids: list of ints
            token ids that are echoed back unchanged (with mask 0).
        pre_trigger: bool
            if true, will indicate at the beginning of the demonstration
            phase which token type to count. Otherwise indicates at the
            beginning of the response phase. Only applies if multi_trigger
            is true.
    Returns:
        tok_id: int
            the next token in the sequence according to the model
        variables: dict
            the updated copy of the state described above

    NOTE(review): an input id that matches none of the branches (not a
    special, trigger, resp or demo id) leaves token_id unbound and raises
    UnboundLocalError — confirm such inputs cannot occur.
    """
    if variables is None:
        variables ={
            "tok_vals": [],
            "inpt_val": 0,
            "phase": 0,
            "read_out": 0,
            "mask": 0}

    # deep copy: "tok_vals" is a list that is appended to below
    variables = {k:copy.deepcopy(v) for k,v in variables.items()}
    variables["tok_vals"].append(variables["inpt_val"])
    if inpt_id==bos_id:
        variables["mask"] = 0
        variables["phase"] = 0
        # useful for stats tracking
        variables["resp_count"] = 0

        variables["inpt_val"] = 1
        return demo_ids[np.random.randint(0,len(demo_ids))], variables
    elif inpt_id==eos_id or inpt_id==pad_id:
        variables["mask"] = 0
        variables["inpt_val"] = 0
        return pad_id, variables
    elif inpt_id in unk_ids:
        variables["mask"] = 0
        variables["inpt_val"] = 0
        return inpt_id, variables
    elif inpt_id in trigger_ids and trigger_ids[0]!=resp_id:
        # the input was a trigger, not the demo token predicted last step,
        # so it contributes 0 to the average
        variables["tok_vals"][-1] = 0
        if not pre_trigger:
            variables["phase"] = 1
            variables["mask"] = 1
            token_id = resp_id
            variables["inpt_val"] = -1
        else:
            variables["mask"] = 0
            token_id = demo_ids[np.random.randint(0,len(demo_ids))]
            variables["inpt_val"] = 1
    # allows phase change when no trigger
    elif inpt_id==resp_id:
        if variables["phase"] == 0:
            variables["phase"] = 1
            # a demo (+1) was predicted but a resp token arrived: count it as -1
            if variables["inpt_val"]==1:
                variables["tok_vals"][-1] = -1
        variables["mask"] = 1
        # .get: the state may not have been initialised by a bos token
        variables["resp_count"] = variables.get("resp_count", 0) + 1
        token_id = resp_id
        variables["inpt_val"] = -1
    elif inpt_id in demo_ids:
        variables["mask"] = 0
        token_id = demo_ids[np.random.randint(0,len(demo_ids))]
        variables["inpt_val"] = 1

    avg = np.mean(variables["tok_vals"])
    if avg<=0 and variables["phase"]>0:
        token_id = eos_id
        variables["mask"] = 1
    return token_id, variables



#def evens_and_odds_cmod(
#               inpt_id,
#               variables,
#               pad_id=0,
#               bos_id=1,
#               eos_id=2,
#               resp_id=3,
#               trigger_ids=[],
#               demo_ids=[],
#               pre_trigger=False,
#               *args, **kwargs):
#    """
#    This model holds one variable for the current count of odd numbers
#    and another for the evens and a phase variable that tracks what
#    token to output and whether to add or subtract from the counts.
#
#    Args:
#        inpt_id: int
#            the current token in the sequence
#        variables: dict
#            "even_count": int
#                the current count of the even numbers.
#            "odd_count": int
#                the current count of the odd numbers.
#            "phase": bool or int
#                0: demonstration phase
#                1: response phase
#        pad_id: int
#           the padding token id
#        bos_id: int
#           the beginning of sequence token id
#        eos_id: int
#           the end of sequence token id
#        resp_id: int
#           the response id
#        trigger_ids: list of ints
#           all possible trigger ids. If this list is greater than 1
#           which will occur if using multi_trigger setting, the
#           order of the list will correspond to the order of the
#           demo type ids that each trigger refers to.
#        demo_ids: list of ints
#            the possible demonstration token ids. The length of this
#            list will be equal to n_demo_types and the order will
#            correspond to the appropriate trigger_id index if using
#            multi_trigger setting.
#        pre_trigger: bool
#            if true, will indicate at the beginning of the demonstration
#            phase which token type to count. Otherwise indicates at the
#            beginning of the response phase. Only applies if multi_trigger
#            is true.
#    Returns:
#        tok_id: int
#            returns a list of ints representing the sequence
#        variables: dict
#            "count": int
#                the current count of the process.
#            "phase": bool or int
#                0: demonstration phase
#                1: response phase
#    """
#    if variables is None:
#        variables = {"odd_count": 0, "even_count":0, "phase": 0}
#
#    variables = {k:v for k,v in variables.items()}
#    if inpt_id==bos_id:
#        variables["phase"] = 0
#        variables["even_count"] = 0
#        variables["odd_count"] = 0
#
#        #TODO: used for removing first step after trigger from response phase sampling
#        variables["resp_count"] = 0
#        #TODO:
#
#        return demo_ids[np.random.randint(0,len(demo_ids))], variables
#    elif inpt_id==eos_id or inpt_id==pad_id:
#        return pad_id, variables
#    elif inpt_id in trigger_ids:
#        if not pre_trigger or variables["count"]!=0:
#            variables["phase"] = 1
#            token_id = resp_id
#        else:
#            variables["phase"] = 0
#            token_id = demo_ids[np.random.randint(0,len(demo_ids))]
#    elif variables["phase"]==0 and inpt_id in demo_ids:
#        if variables["even_count"]!=variables["odd_count"]:
#            variables
#        variables["count"] += 1
#        token_id = demo_ids[np.random.randint(0,len(demo_ids))]
#    elif variables["phase"]>0 and inpt_id==resp_id:
#        variables["count"] -= 1
#        token_id = resp_id
#
#        #TODO: used for removing first step after trigger from response phase sampling
#        variables["resp_count"] += 1
#        #TODO:
#
#    if variables["count"]<=0 and variables["phase"]>0:
#        token_id = eos_id
#    return token_id, variables



######################################################################
## PRE/POST TRIGGER MULTI TRIGGER MODELS
######################################################################

####def trigid_count_only_cmod(
####        inpt_id,
####        variables,
####        pad_id=0,
####        bos_id=1,
####        eos_id=2,
####        resp_id=4,
####        trigger_ids=[],
####        demo_ids=[],
####        pre_trigger=False,
####        *args, **kwargs):
####    """
####    This model holds one variable for the current count of the
####    process and a phase variable that tracks what the input token was.
####    It ultimately counts up in the demo phase and
####    then back down in the response phase.
####    
####    Args:
####        inpt_id: int
####            the current token in the sequence
####        variables: dict
####            "trigger_id": int
####                the id of the trigger token, used to find the index of
####                the demo_id within the demo_ids list.
####            "count": int
####                the current count of the process.
####            "phase": bool or int
####                0: demonstration phase
####                1: response phase
####        pad_id: int
####           the padding token id
####        bos_id: int
####           the beginning of sequence token id
####        eos_id: int
####           the end of sequence token id
####        resp_id: int
####           the response id
####        trigger_ids: list of ints
####           all possible trigger ids. If this list is greater than 1
####           which will occur if using multi_trigger setting, the
####           order of the list will correspond to the order of the
####           demo type ids that each trigger refers to.
####        demo_ids: list of ints
####            the possible demonstration token ids. The length of this
####            list will be equal to n_demo_types and the order will
####            correspond to the appropriate trigger_id index if using
####            multi_trigger setting.
####        pre_trigger: bool
####            if true, will indicate at the beginning of the demonstration
####            phase which token type to count. Otherwise indicates at the
####            beginning of the response phase. Only applies if multi_trigger
####            is true.
####    Returns:
####        seq: list of ints
####            returns a list of ints representing the sequence
####        variables: dict
####            "count": int
####                the current count of the process.
####            "phase": bool or int
####                0: demonstration phase
####                1: response phase
####    """
####    if variables is None:
####        variables = { "trigger_id": 0, "count": 0, "phase": 0 }
####    variables = {k:v for k,v in variables.items()}
####    if inpt_id==bos_id:
####        variables["phase"] = 0
####        variables["count"] = 0
####        variables["trigger_id"] = 0
####        return demo_ids[np.random.randint(0,len(demo_ids))], variables
####    elif inpt_id==eos_id or inpt_id==pad_id:
####        return pad_id, variables
####    elif inpt_id in trigger_ids:
####        variables["trigger_id"] = inpt_id
####        if variables["count"]>0 or variables["phase"]>0:
####            variables["phase"] = 1
####            token_id = resp_id
####        else:
####            variables["phase"] = 0
####            return demo_ids[np.random.randint(0,len(demo_ids))], variables
####    elif inpt_id in demo_ids:
####        variables["phase"] = 0
####        idx = trigger_ids.index(variables["trigger_id"])
####        variables["count"] += int(inpt_id==demo_ids[idx])
####        token_id = demo_ids[np.random.randint(len(demo_ids))]
####    elif inpt_id==resp_id:
####        variables["phase"] = 1
####        variables["count"] -= 1
####        token_id = resp_id
####
####    if variables["count"]<=0 and variables["phase"]>0:
####        token_id = eos_id
####    return token_id, variables
####
####
####def trigid_multi_counts_cmod(
####        inpt_id,
####        variables,
####        pad_id=0,
####        bos_id=1,
####        eos_id=2,
####        resp_id=4,
####        trigger_ids=[],
####        demo_ids=[],
####        pre_trigger=False,
####        track_all=True,
####        *args, **kwargs):
####    """
####    This model holds a variable for each countable demo id that stores
####    the current count of that demo id. It then stores a variable for
####    which demo_id it was supposed to count when it gets the trigger,
####    and finally counts the appropriate count back downward to zero.
####
####    Args:
####        inpt_id: int
####            the current token in the sequence
####        variables: dict
####            "trigger_id": int
####                the id of the trigger token, used to find the index of
####                the demo_id within the demo_ids list.
####            "counts": list of ints of length len(demo_ids)
####                the current count of each demo_id.
####            "phase": bool or int
####                0: demonstration phase
####                1: response phase
####        pad_id: int
####           the padding token id
####        bos_id: int
####           the beginning of sequence token id
####        eos_id: int
####           the end of sequence token id
####        resp_id: int
####           the response id
####        trigger_ids: list of ints
####           all possible trigger ids. If this list is greater than 1
####           which will occur if using multi_trigger setting, the
####           order of the list will correspond to the order of the
####           demo type ids that each trigger refers to.
####        demo_ids: list of ints
####            the possible demonstration token ids. The length of this
####            list will be equal to n_demo_types and the order will
####            correspond to the appropriate trigger_id index if using
####            multi_trigger setting.
####        pre_trigger: bool
####            if true, will indicate at the beginning of the demonstration
####            phase which token type to count. Otherwise indicates at the
####            beginning of the response phase. Only applies if multi_trigger
####            is true.
####        track_all: bool
####            if track_all is true, this function will count all demo ids
####            independently, even those that are outside of the triggered
####            category. Alternatively it will ignore non-focussed values.
####    Returns:
####        token_id: int
####            the next token in the sequence according to the model
####        variables: dict
####            "trigger_id": int
####                the id of the trigger token, used to find the index of
####                the demo_id within the demo_ids list.
####            "counts": list of ints of length len(demo_ids)
####                the current count of each demo_id.
####            "phase": bool or int
####                0: demonstration phase
####                1: response phase
####    """
####    if variables is None:
####        variables = {
####            "trigger_id": 0,
####            "counts": [0 for _ in demo_ids],
####            "phase": 0,
####            "step": 0,
####            }
####    new_vars = dict()
####    for k,v in variables.items():
####        if type(v)==list:
####            new_vars[k] = [vv for vv in v]
####        else:
####            new_vars[k] = v
####    variables = new_vars
####    variables["step"] += 1
####    if inpt_id==bos_id:
####        variables["phase"] = 0
####        variables["counts"] = [0 for _ in demo_ids]
####        variables["trigger_id"] = 0
####        return demo_ids[np.random.randint(0,len(demo_ids))], variables
####    elif inpt_id==eos_id or inpt_id==pad_id:
####        return pad_id, variables
####    elif inpt_id in trigger_ids:
####        variables["trigger_id"] = inpt_id
####        if not pre_trigger or variables["step"]>2:
####            variables["phase"] = 1
####            token_id = resp_id
####        else:
####            variables["phase"] = 0
####            token_id = demo_ids[np.random.randint(0,len(demo_ids))]
####    elif inpt_id in demo_ids:
####        variables["phase"] = 0
####        if track_all or inpt_id==demo_ids[trigger_ids.index(variables["trigger_id"])]:
####            idx = demo_ids.index(inpt_id)
####            variables["counts"][idx] += 1
####        token_id = demo_ids[np.random.randint(0,len(demo_ids))]
####    elif inpt_id==resp_id:
####        variables["phase"] = 1
####        idx = trigger_ids.index(variables["trigger_id"])
####        variables["counts"][idx] -= 1
####        token_id = resp_id
####
####    if variables["trigger_id"] in trigger_ids:
####        idx = trigger_ids.index(variables["trigger_id"])
####        if variables["phase"]>0 and variables["counts"][idx]<=0:
####            variables["phase"] = 1
####            token_id = eos_id
####    return token_id, variables
####
####def trigid_multi_counts_ignore_cmod(*args, **kwargs):
####    """
####    This model holds a variable for each countable demo id. It increases,
####    however, only the count corresponding to the triggered category,
####    ignoring the other demo_ids. It then counts the appropriate count
####    back downward to zero in the response phase.
####
####    Args:
####        see `trigid_multi_counts_cmod`
####    Returns:
####        see `trigid_multi_counts_cmod`
####    """
####    kwargs["track_all"] = False
####    return trigid_multi_counts_cmod(*args, **kwargs)
####
####def trigid_demo_resp_count_cmod(
####        inpt_id,
####        variables,
####        pad_id=0,
####        bos_id=1,
####        eos_id=2,
####        resp_id=4,
####        trigger_ids=[],
####        demo_ids=[],
####        *args, **kwargs):
####    """
####    This model holds 3 variables. One to track which demo token to focus
####    on, one to track the count of the demonstration phase and one for
####    the count of the response phase. This model might arise in a pre-
####    trigger setting with multiple triggers.
####
####    Args:
####        inpt_id: int
####            the current token in the sequence
####        variables: dict
####            "trigger_id": int
####                the id of the trigger token, used to find the index of
####                the demo_id within the demo_ids list.
####            "demo_count": int
####                the count of the demo sequence
####            "resp_count": int
####                the count of the response sequence
####            "phase": bool or int
####                0: demonstration phase
####                1: response phase
####        pad_id: int
####           the padding token id
####        bos_id: int
####           the beginning of sequence token id
####        eos_id: int
####           the end of sequence token id
####        resp_id: int
####           the response id
####        trigger_ids: list of ints
####           all possible trigger ids. If this list is greater than 1
####           which will occur if using multi_trigger setting, the
####           order of the list will correspond to the order of the
####           demo type ids that each trigger refers to.
####        demo_ids: list of ints
####            the possible demonstration token ids. The length of this
####            list will be equal to n_demo_types and the order will
####            correspond to the appropriate trigger_id index if using
####            multi_trigger setting.
####        pre_trigger: bool
####            if true, will indicate at the beginning of the demonstration
####            phase which token type to count. Otherwise indicates at the
####            beginning of the response phase. Only applies if multi_trigger
####            is true.
####    Returns:
####        token_id: int
####            the next token in the sequence according to the model
####        variables: dict
####            "trigger_id": int
####                the id of the trigger token, used to find the index of
####                the demo_id within the demo_ids list.
####            "demo_count": int
####                the count of the demo sequence
####            "resp_count": int
####                the count of the response sequence
####            "phase": bool or int
####                0: demonstration phase
####                1: response phase
####    """
####    assert type(demo_ids)==list and type(trigger_ids)==list
####    if variables is None:
####        variables = {
####          "trigger_id": 0,
####          "demo_count": 0,
####          "resp_count": 0,
####          "phase": 0,
####          "step": 0,}
####    variables = {k:v for k,v in variables.items()}
####    variables["step"] += 1
####    if inpt_id==bos_id:
####        variables["phase"] = 0
####        variables["trigger_id"] = 0
####        variables["demo_count"] = 0
####        variables["resp_count"] = 0
####        return demo_ids[np.random.randint(0,len(demo_ids))], variables
####    elif inpt_id==eos_id or inpt_id==pad_id:
####        return pad_id, variables
####    elif inpt_id in demo_ids:
####        variables["phase"] = 0
####        idx = trigger_ids.index(variables["trigger_id"])
####        variables["demo_count"] += int(inpt_id==demo_ids[idx])
####        token_id = demo_ids[np.random.randint(0,len(demo_ids))]
####    elif inpt_id in trigger_ids:
####        variables["trigger_id"] = inpt_id
####        if variables["demo_count"]>0 or variables["step"]>2:
####            variables["phase"] = 1
####            token_id = resp_id
####        else:
####            variables["phase"] = 0
####            token_id = demo_ids[np.random.randint(0,len(demo_ids))]
####    elif inpt_id==resp_id:
####        variables["phase"] = 1
####        variables["resp_count"] += 1
####        token_id = resp_id
####
####    if variables["phase"]>0 and variables["demo_count"]<=variables["resp_count"]:
####        variables["phase"] = 1
####        token_id = eos_id
####    return token_id, variables
####
####
####def trigid_demo_counts_resp_count_cmod(
####        inpt_id,
####        variables,
####        pad_id=0,
####        bos_id=1,
####        eos_id=2,
####        resp_id=4,
####        trigger_ids=[],
####        demo_ids=[],
####        *args, **kwargs):
####    """
####    This model holds 3 variables. One to track which demo token to focus
####    on, one list tracking the count of every demo token type during the
####    demonstration phase, and one for the count of the response phase.
####    This model might arise in a pre-trigger setting with multiple
####    triggers.
####
####    Args:
####        inpt_id: int
####            the current token in the sequence
####        variables: dict
####            "trigger_id": int
####                the id of the trigger token, used to find the index of
####                the demo_id within the demo_ids list.
####            "demo_counts": list of ints of length len(demo_ids)
####                the current count of each demo_id.
####            "resp_count": int
####                the count of the response sequence
####            "phase": bool or int
####                0: demonstration phase
####                1: response phase
####        pad_id: int
####           the padding token id
####        bos_id: int
####           the beginning of sequence token id
####        eos_id: int
####           the end of sequence token id
####        resp_id: int
####           the response id
####        trigger_ids: list of ints
####           all possible trigger ids. If this list is greater than 1
####           which will occur if using multi_trigger setting, the
####           order of the list will correspond to the order of the
####           demo type ids that each trigger refers to.
####        demo_ids: list of ints
####            the possible demonstration token ids. The length of this
####            list will be equal to n_demo_types and the order will
####            correspond to the appropriate trigger_id index if using
####            multi_trigger setting.
####        pre_trigger: bool
####            if true, will indicate at the beginning of the demonstration
####            phase which token type to count. Otherwise indicates at the
####            beginning of the response phase. Only applies if multi_trigger
####            is true.
####    Returns:
####        token_id: int
####            the next token in the sequence according to the model
####        variables: dict
####            "trigger_id": int
####                the id of the trigger token, used to find the index of
####                the demo_id within the demo_ids list.
####            "demo_counts": list of ints of length len(demo_ids)
####                the current count of each demo_id.
####            "resp_count": int
####                the count of the response sequence
####            "phase": bool or int
####                0: demonstration phase
####                1: response phase
####    """
####    assert type(demo_ids)==list and type(trigger_ids)==list
####    if variables is None:
####        variables = {
####            "trigger_id": 0, "demo_counts": [0 for _ in demo_ids],
####            "resp_count": 0, "phase": 0
####        }
####    new_vars = dict()
####    for k,v in variables.items():
####        if type(v)==list:
####            new_vars[k] = [vv for vv in v]
####        else:
####            new_vars[k] = v
####    variables = new_vars
####    if inpt_id==bos_id:
####        variables["phase"] = 0
####        variables["trigger_id"] = 0
####        variables["demo_counts"] = [0 for _ in demo_ids]
####        variables["resp_count"] = 0
####        return demo_ids[np.random.randint(0,len(demo_ids))], variables
####    elif inpt_id==eos_id or inpt_id==pad_id:
####        variables["phase"] = 1
####        return pad_id, variables
####    elif inpt_id in demo_ids:
####        variables["phase"] = 0
####        idx = demo_ids.index(inpt_id)
####        variables["demo_counts"][idx] += 1
####        token_id = demo_ids[np.random.randint(0,len(demo_ids))]
####    elif inpt_id in trigger_ids:
####        variables["trigger_id"] = inpt_id
####        if sum(variables["demo_counts"])!=0:
####            variables["phase"] = 1
####            token_id = resp_id
####        else:
####            variables["phase"] = 0
####            token_id = demo_ids[np.random.randint(0,len(demo_ids))]
####    elif inpt_id==resp_id:
####        variables["phase"] = 1
####        variables["resp_count"] += 1
####        token_id = resp_id
####
####    idx = trigger_ids.index(variables["trigger_id"])
####    if variables["demo_counts"][idx]<=variables["resp_count"]:
####        variables["phase"] = 1
####        token_id = eos_id
####    return token_id, variables


if __name__ == "__main__":
    # Smoke run of `and_one_cmod`: feed the model its own predictions,
    # starting from the beginning-of-sequence token, and force the trigger
    # token into the stream after `targ_count` generated tokens. The printed
    # sequence holds only the model outputs (the initial bos_id is not
    # included).
    targ_count = 5
    pad_id, bos_id, eos_id, resp_id = 0, 1, 2, 3
    demo_ids = [4]
    trig_ids = [5]
    trig_id = trig_ids[0]

    # Keyword arguments that stay the same for every model step.
    model_kwargs = dict(
        pad_id=pad_id,
        eos_id=eos_id,
        bos_id=bos_id,
        resp_id=resp_id,
        demo_ids=demo_ids,
        trigger_ids=trig_ids,
        trigger_id=trig_id,
    )

    n_steps = 2 * (targ_count + 1) + 1
    token = bos_id
    state = None
    seq = []
    for step in range(n_steps):
        token, state = and_one_cmod(token, variables=state, **model_kwargs)
        # Replace the model's prediction with the trigger token at this step.
        token = trig_id if step == targ_count else token
        seq.append(token)
    print("seq:", seq)
