import numpy as np
try:
    from .utils import sep_digit_fxn
except:
    from utils import sep_digit_fxn

def base_num_equivalence(
        targ_count,
        trigger_token=None,
        n_demo_types=1,
        tot_demo_tokens=None,
        multi_trigger=False,
        pre_trigger=True,
        offset=0,
        ret_info=False,
        resp_token=None,
        demo_tokens=None,
        cot_count=False,
        sep_digits=True,
        reverse_digits=False,
        incl_trigger=True,
        trigger_tokens=None,
        unk_tokens=None,
        n_unk_types=1,
        unk_p:float=0,
        max_unks:int=np.inf,
        *args, **kwargs):
    """
    Build a numeric-equivalence sequence: a demonstration phase of tokens
    for the model to count, an optional trigger token indicating that the
    model should now respond, and then targ_count response tokens. No
    beginning- or end-of-sequence tokens are added by this function.

    We can pick the number of demonstration token types (ie the number of
    fruit types if we were counting fruit). With multi_trigger, the trigger
    token selects the one demo type that must be counted, and with
    pre_trigger that trigger is also shown at the start of the sequence.

    Args:
        targ_count: int
            the goal number of things to count. This is also the number of
            response tokens.
        trigger_token: str or None
            the trigger token to use. If None, one is drawn uniformly from
            trigger_tokens. With multi_trigger, its index in trigger_tokens
            selects the demo type to count. If it is not in trigger_tokens,
            its trailing integer suffix is used instead (e.g. "T12" -> 12).
        n_demo_types: int
            the number of demonstration token types. It is used to build
            the default demo_tokens and, with multi_trigger, the default
            trigger_tokens.
        tot_demo_tokens: int or None
            the total number of tokens (target plus filler) in the
            demonstration phase. Only applies if multi_trigger is true.
            Falsy values fall back to targ_count.
        multi_trigger: bool
            if true, only the demo type selected by the trigger is counted.
            The remaining demonstration positions are filled uniformly with
            the other demo types and the unk tokens.
        pre_trigger: bool
            if true and multi_trigger is true, the trigger token is also
            placed at the very start of the sequence.
        offset: int
            accepted for interface compatibility; not used here.
        ret_info: bool
            if true, also return an info dict (see Returns).
        resp_token: str or None
            the response token. Defaults to "R".
        demo_tokens: list of str or None
            the available demonstration token types. Defaults to
            ["D0", ..., "D{n_demo_types-1}"].
        cot_count, sep_digits, reverse_digits:
            accepted for interface compatibility; not used here.
        incl_trigger: bool
            if false, no trigger token separates the demonstration and
            response phases, and the task_mask is 0 for the first response
            token. When trigger_tokens is None, this also makes resp_token
            the only trigger token.
        trigger_tokens: list of str or None
            all possible trigger tokens. Defaults to ["T0", ...] with
            multi_trigger and ["T"] otherwise, or [resp_token] if
            incl_trigger is false. With multi_trigger, trigger_tokens[i]
            refers to demo_tokens[i].
        unk_tokens: list of str or None
            filler tokens that are never counted. Defaults to
            ["U0", ..., "U{n_unk_types-1}"] when unk_p > 0 and max_unks > 0,
            otherwise [].
        n_unk_types: int
            the number of default unk token types.
        unk_p: float
            without multi_trigger, the probability that each sampling step
            emits an unk token instead of a demo token.
        max_unks: int or None
            without multi_trigger, the maximum number of unk tokens that
            are inserted. None means no limit.

    Returns:
        seq: list of str
            the token sequence.
        task_mask: list of int
            1 for response-phase tokens (except the first one when
            incl_trigger is false) and 0 everywhere else.
        info: dict (only if ret_info)
            "resp_token", "demo_tokens", "trigger_tokens",
            "trigger_token" and "unk_tokens", as resolved above.

    Raises:
        ValueError: with multi_trigger, if the trigger token cannot be
            mapped to a demo type, or if filler tokens are required but
            there are no non-target token types to draw them from.
    """
    if resp_token is None:
        resp_token = "R"
    if demo_tokens is None:
        demo_tokens = ["D"+str(i) for i in range(n_demo_types)]
    if trigger_tokens is None:
        if incl_trigger:
            trigger_tokens = ["T"+str(i) for i in range(n_demo_types)]\
                             if multi_trigger else ["T"]
        else:
            trigger_tokens = [resp_token]
    if trigger_token is None:
        trigger_token = trigger_tokens[np.random.randint(len(trigger_tokens))]
    if max_unks is None: max_unks = np.inf
    if unk_tokens is None:
        if unk_p>0 and max_unks>0:
            unk_tokens = ["U"+str(i) for i in range(n_unk_types)]
        else:
            unk_tokens = []

    if not tot_demo_tokens:
        tot_demo_tokens = targ_count

    # Get demo sequence
    if multi_trigger:
        # The trigger's position in trigger_tokens identifies the demo type
        # to count. Parsing only the last character (the previous approach)
        # breaks for types >= 10 and for non-numeric trigger names.
        if trigger_token in trigger_tokens:
            demo_idx = list(trigger_tokens).index(trigger_token)
        else:
            name = str(trigger_token)
            suffix = name[len(name.rstrip("0123456789")):]
            if not suffix:
                raise ValueError(
                    f"cannot map trigger token {trigger_token!r} to a demo "
                    f"type; it is not one of {trigger_tokens!r}")
            demo_idx = int(suffix)
        targ_type = demo_tokens[demo_idx]
        # guarantee the right number of target tokens
        demo_seq = [targ_type for _ in range(targ_count)]
        n_filler = tot_demo_tokens - len(demo_seq)
        # create list of non-target token types
        other_types = [d for d in demo_tokens+unk_tokens if d != targ_type]
        if n_filler > 0 and not other_types:
            raise ValueError(
                f"tot_demo_tokens={tot_demo_tokens} exceeds "
                f"targ_count={targ_count}, but there are no non-target "
                "token types to fill the demonstration phase with")
        demo_seq += [
            other_types[np.random.randint(len(other_types))]
            for _ in range(n_filler)
        ]
        np.random.shuffle(demo_seq) # shuffle the sequence
    else:
        demo_seq = []
        n_demos = 0
        n_unks = 0
        while n_demos<targ_count:
            if n_unks<max_unks and len(unk_tokens)>0 and np.random.random()<unk_p:
                tok = unk_tokens[np.random.randint(len(unk_tokens))]
                n_unks += 1
            else:
                tok = demo_tokens[np.random.randint(len(demo_tokens))]
                n_demos += 1
            demo_seq.append(tok)

    # Get response sequence
    resp_seq = [resp_token for _ in range(targ_count)]

    # Everything before the responses is unsupervised (mask 0).
    prefix = []
    if multi_trigger and pre_trigger:
        prefix.append(trigger_token)
    prefix += demo_seq
    if incl_trigger:
        prefix.append(trigger_token)
    resp_mask = [1 for _ in resp_seq]
    # Without a trigger, the first response token cannot be predicted, so
    # it is not supervised. Guard for targ_count == 0 (empty response).
    if not incl_trigger and resp_mask:
        resp_mask[0] = 0
    seq = prefix + resp_seq
    task_mask = [0 for _ in prefix] + resp_mask
    if not ret_info:
        return seq, task_mask
    return seq, task_mask, {
        "resp_token": resp_token,
        "demo_tokens": demo_tokens,
        "trigger_tokens": trigger_tokens,
        "trigger_token": trigger_token,
        "unk_tokens": unk_tokens,
    }

def num_equivalence(
        *args,
        chain_of_num=False,
        sep_digits=True,
        reverse_digits=False,
        numeral_base=10,
        strategy="stack",
        copy_task=False,
        **kwargs):
    """
    Build a numeric-equivalence sequence with base_num_equivalence. With
    chain_of_num, also insert tokens representing the running count after
    every token.

    All positional arguments and any keyword arguments not listed below are
    forwarded to base_num_equivalence (targ_count, trigger_token,
    n_demo_types, tot_demo_tokens, multi_trigger, pre_trigger, offset,
    ret_info, resp_token, demo_tokens, incl_trigger, trigger_tokens, unk_*).

    Args:
        chain_of_num: bool
            if true, insert a chain-of-thought-like running count after each
            token of the base sequence. The count tokens always get a
            task_mask value of 1.
        sep_digits: bool
            if true, the count is split into individual digit tokens by
            sep_digit_fxn. Otherwise it is a single token, i.e. for the
            number 102:
                sep_digits==True: ["1","0","2"]
                sep_digits==False: ["102"]
        reverse_digits: bool
            forwarded to sep_digit_fxn when sep_digits is true; intended to
            put the ones place first (i.e. 102 -> ["2", "0", "1"]).
        numeral_base: int
            forwarded to sep_digit_fxn as its base. It is only used when
            sep_digits is true; otherwise the count is written with str().
        strategy: str (options: "stack", "match")
            "stack" counts up over the demonstrations, then down over the
            responses. "match" counts up over the demonstrations, then counts
            the responses up from zero once the response phase has started.
        copy_task: bool
            if true, use the same token ("D") for demonstrations and
            responses, and force n_demo_types to 1.

    Phase detection for the count: the response phase starts at the first
    trigger token that is not at position 0 (a token at position 0 can only
    be a pre-trigger), or at the first response token if that token is not
    a demo token. Once in the response phase, resp_token tokens are counted
    as responses even when resp_token is also a demo token (copy tasks).

    Returns:
        seq: list of str
            the token sequence.
        task_mask: list of int
        info: dict (only if ret_info is true)
            the info dict returned by base_num_equivalence.
    """
    if copy_task:
        kwargs["n_demo_types"] = 1
        kwargs["resp_token"] =  "D"
        kwargs["demo_tokens"] = ["D"]

    # The info dict is always needed internally, but we must honour the
    # caller's ret_info (base_num_equivalence defaults it to False, which
    # previously made the 3-way unpack below crash).
    ret_info_pos = 7  # position of ret_info in base_num_equivalence
    if len(args) > ret_info_pos:
        args = list(args)
        ret_info = args[ret_info_pos]
        args[ret_info_pos] = True
    else:
        ret_info = kwargs.get("ret_info", False)
        kwargs["ret_info"] = True
    seq, mask, info = base_num_equivalence(*args, **kwargs)

    if chain_of_num:
        demo_tokens = info["demo_tokens"]
        resp_token = info["resp_token"]
        trigger_tokens = info["trigger_tokens"]
        new_seq = []
        new_mask = []
        demo_count = 0
        resp_count = 0
        in_resp = False
        for i, (tok, m) in enumerate(zip(seq, mask)):
            if i == 0 and tok not in demo_tokens:
                # A leading non-demo token is a pre-trigger (or filler); it
                # must not start the response phase.
                pass
            elif in_resp and tok == resp_token:
                resp_count += 1
            elif tok in demo_tokens:
                demo_count += 1
            elif tok == resp_token:
                # First response token of a sequence without a trigger.
                resp_count += 1
                in_resp = True
            elif tok in trigger_tokens:
                in_resp = True
            new_seq.append(tok)
            new_mask.append(m)
            if in_resp and strategy == "match":
                count = resp_count
            else:
                count = demo_count - resp_count
            if sep_digits:
                count_toks = list(sep_digit_fxn(
                    value=count,
                    reverse_digits=reverse_digits,
                    base=numeral_base))
            else:
                count_toks = [str(count)]
            new_seq += count_toks
            new_mask += [1] * len(count_toks)
        seq, mask = new_seq, new_mask

    if ret_info:
        return seq, mask, info
    return seq, mask

def copy_equivalence(*args, **kwargs):
    """
    Build a copy-task sequence: some number of demonstration tokens, a
    trigger token (unless incl_trigger is false), and then the same number
    of response tokens. The demonstration and response tokens are both "D",
    so the response repeats the demonstration.

    All arguments are forwarded to num_equivalence, and through it to
    base_num_equivalence (targ_count, trigger_token, n_demo_types,
    tot_demo_tokens, multi_trigger, pre_trigger, offset, ret_info,
    chain_of_num, strategy, ...). resp_token and demo_tokens are always
    overridden with "D" and ["D"].

    Because there is only a single demonstration type, n_demo_types is set
    to 1 unless it was passed positionally. Otherwise, multi_trigger would
    generate triggers for demo types that do not exist.

    Returns:
        seq: list of str
            the token sequence.
        task_mask: list of int
        info: dict (only if ret_info is true)
            see base_num_equivalence.
    """
    kwargs["resp_token"] = "D"
    kwargs["demo_tokens"] = ["D"]
    # n_demo_types is the 3rd positional parameter of base_num_equivalence;
    # only override it when the caller did not supply it positionally.
    if len(args) < 3:
        kwargs["n_demo_types"] = 1
    return num_equivalence(*args, **kwargs)

def sample_num_equivalence(
        targ_count=None,
        trigger_token=None,
        min_count=1,
        max_count=20,
        max_demo_tokens=None,
        n_demo_types=3,
        multi_trigger=True,
        pre_trigger=False,
        trigger_tokens=None,
        hold_outs=None,
        offset=0,
        ret_info=True,
        task_fxn=num_equivalence,
        *args, **kwargs):
    """
    Sample a target count (and the demonstration length) and build a
    sequence with task_fxn.

    Args:
        targ_count: int or None
            optionally argue a specific target count. If None, it is drawn
            uniformly from [min_count, max_count]. A value in hold_outs is
            resampled as well.
        trigger_token: str or None
            optionally argue a specific trigger token. Forwarded to task_fxn.
        min_count: int
            the minimum possible target count to sample.
        max_count: int (inclusive)
            the maximum possible target count to sample.
        max_demo_tokens: int or None
            the maximum (inclusive) total number of demonstration tokens.
            This only applies when multi_trigger is true and
            n_demo_types > 1. None or 0 defaults to max_count. The total is
            drawn uniformly from [targ_count, max_demo_tokens].
        n_demo_types: int
            the number of demonstration types (ie number of types of
            fruit if we were counting fruit).
        multi_trigger: bool
            if true, use a conditional trigger to indicate which token type
            to count.
        pre_trigger: bool
            if true, the trigger is also shown at the beginning of the
            demonstration phase. Only applies if multi_trigger is true.
        trigger_tokens: list of str or None
            all possible trigger tokens; forwarded to task_fxn.
        hold_outs: collection of ints or None
            target counts that must not be produced.
        offset: int
            forwarded to task_fxn.
        ret_info: bool
            if true, task_fxn also returns its info dict.
        task_fxn: callable
            the sequence builder, e.g. num_equivalence or copy_equivalence.
        *args, **kwargs:
            forwarded to task_fxn.

    Returns:
        whatever task_fxn returns, i.e. (seq, task_mask) or
        (seq, task_mask, info) for the builders in this module.

    Raises:
        ValueError: if hold_outs excludes every count in
            [min_count, max_count].
    """
    if targ_count is None:
        targ_count = np.random.randint(min_count, max_count+1)
    if hold_outs is not None and targ_count in hold_outs:
        # Resampling can only terminate if at least one count is allowed.
        if all(c in hold_outs for c in range(min_count, max_count+1)):
            raise ValueError(
                f"hold_outs excludes every count in [{min_count}, "
                f"{max_count}]; cannot sample a target count")
        while targ_count in hold_outs:
            targ_count = np.random.randint(min_count, max_count+1)
    tot_toks = None
    if multi_trigger and n_demo_types>1:
        if max_demo_tokens is None or max_demo_tokens==0:
            max_demo_tokens = max_count
        if max_demo_tokens>targ_count:
            # randint's upper bound is exclusive; +1 makes max_demo_tokens
            # itself reachable (matching the inclusive max_count convention).
            tot_toks = targ_count + np.random.randint(
                max_demo_tokens-targ_count+1)
    return task_fxn(targ_count=targ_count,
                    tot_demo_tokens=tot_toks,
                    n_demo_types=n_demo_types,
                    multi_trigger=multi_trigger,
                    pre_trigger=pre_trigger,
                    offset=offset,
                    trigger_token=trigger_token,
                    trigger_tokens=trigger_tokens,
                    ret_info=ret_info,
                    *args, **kwargs)

def sample_copy_equivalence( *args,**kwargs ):
    """
    Sample a copy-task sequence: the demonstration tokens are repeated after
    the trigger token. This is sample_num_equivalence with
    task_fxn=copy_equivalence.

    All arguments are forwarded to sample_num_equivalence (targ_count,
    trigger_token, min_count, max_count, max_demo_tokens, n_demo_types,
    multi_trigger, pre_trigger, trigger_tokens, hold_outs, offset, ret_info,
    and any extra options for copy_equivalence / num_equivalence).

    The copy task has a single demonstration type ("D"), so n_demo_types
    defaults to 1 here instead of the sampler's default of 3. With
    multi_trigger, a default of 3 would request filler types and triggers
    that do not exist.

    Returns:
        seq: list of str
            the token sequence.
        task_mask: list of int
        info: dict (only if ret_info is true, which is the sampler's
            default)
    """
    # n_demo_types is the 6th positional parameter of sample_num_equivalence.
    if len(args) < 6:
        kwargs.setdefault("n_demo_types", 1)
    return sample_num_equivalence(
        *args, task_fxn=copy_equivalence, **kwargs)

if __name__ == "__main__":
    # Quick visual check: sample one sequence and print the position index,
    # the token, and the task mask in aligned columns.
    demo_config = dict(
        min_count=1,
        max_count=20,
        max_demo_tokens=None,
        n_demo_types=3,
        multi_trigger=False,
        pre_trigger=False,
        hold_outs=None,
        offset=0,
        trigger_token=None,
        ret_info=True,
        task_fxn=num_equivalence,
        chain_of_num=False,
        sep_digits=True,
        reverse_digits=False,
        numeral_base=10,
        incl_trigger=False,
        strategy="match",
    )
    seq, mask, info = sample_num_equivalence(**demo_config)
    index_row = [f"{pos:2}" for pos in range(len(seq))]
    token_row = [f"{tok:2}" for tok in seq]
    mask_row = [str(flag) + " " for flag in mask]
    for row in (index_row, token_row, mask_row):
        print(" ".join(row))
