import numpy as np

def induction_heads(
        resp_dist=10,
        n_trigs=1,
        n_trig_types=1,
        n_token_types=9,
        trig_types=None,
        token_types=None,
        first_idx=None,
        max_first_idx=20,
        trig_first=True,
        incl_sep=True,
        allow_dupls=False,
        ret_info=True,
        *args, **kwargs,
    ):
    """
    Generate one sequence for the induction-heads recall task.

    The model must recall a specific piece of information conditioned on
    a trigger. The input phase is a random sequence of tokens with one or
    more trigger tokens placed in it. The response phase repeats each
    trigger (in positional order) followed by the token that was adjacent
    to it in the input phase.

    Ex: B A W T R N W P T _ # fill in an R here

    The traditional task has a single trigger that precedes the token of
    interest. This function also supports multiple triggers, and triggers
    that follow their target instead of preceding it. Triggers are never
    placed at adjacent positions, so a trigger is never the target of
    another trigger.

    Args:
        resp_dist: int
            Number of candidate trigger positions, starting at the first
            trigger (positions first_idx .. first_idx+resp_dist-1). The
            input phase ends resp_dist tokens after the first trigger, so
            it has first_idx+resp_dist+1 tokens. Must be >= 1.
        n_trigs: None or int
            Number of triggers to place. If None, defaults to
            len(trig_types). At least one trigger is always placed. Fewer
            than n_trigs may be placed if the randomly-ordered greedy
            placement runs out of free, non-adjacent positions, or if the
            `allow_dupls` cap is exhausted.
        n_trig_types: int
            Number of trigger types to create ("T0", "T1", ...). Ignored
            when trig_types is given.
        n_token_types: int
            Number of non-trigger token types to create ("0", "1", ...).
            Ignored when token_types is given.
        trig_types: None or iterable
            Optional explicit trigger tokens. They are converted to str.
        token_types: None or iterable
            Optional explicit non-trigger tokens. They are converted to
            str.
        first_idx: None or int
            Optional index of the first trigger token. Must be
            >= 0 when trig_first is True and >= 1 when trig_first is
            False, so that the target index is valid.
        max_first_idx: int
            Exclusive upper bound used when sampling first_idx. When
            first_idx is None it is drawn uniformly from
            [int(not trig_first), max_first_idx).
        trig_first: bool
            If True, each trigger precedes its target (target = idx+1).
            Otherwise it follows its target (target = idx-1).
        incl_sep: bool
            If True, a "<SEP>" token separates the input phase from the
            response phase.
        allow_dupls: bool or int
            Falsy: trigger types are assigned cyclically starting from a
            random type. Types therefore only repeat once n_trigs exceeds
            the number of types.
            True (or any truthy non-int): each trigger's type is drawn
            uniformly at random with no limit.
            Positive int: each type is drawn at random but used at most
            `allow_dupls` times. Placement stops once every type has
            reached that cap.
        ret_info: bool
            If True, also return an info dict.
        *args, **kwargs:
            Ignored. Accepted so callers can pass extra options.

    Returns:
        seq: list of str
            The input phase, optional separator, and response phase.
        tmask: list of int
            Mask aligned with seq. It is 1 for response-phase tokens
            (repeated triggers and their targets) and 0 elsewhere.
        info: dict (only when ret_info is True)
            Keys "token_types", "trig_types" and "sep_token".

    Raises:
        ValueError: if resp_dist < 1 or first_idx is too small for the
            chosen trig_first direction.
    """
    # Initialize the vocabularies; tokens are always handled as strings.
    if trig_types is None:
        trig_types = ["T"+str(i) for i in range(n_trig_types)]
    trig_types = [str(t) for t in trig_types]
    if n_trigs is None:
        n_trigs = len(trig_types)
    if token_types is None:
        token_types = [str(i) for i in range(n_token_types)]
    token_types = [str(t) for t in token_types]
    if resp_dist < 1:
        raise ValueError(f"resp_dist must be >= 1, got {resp_dist}")

    # The target of a trailing trigger is idx-1, so index 0 is invalid
    # there (it would wrap around to the end of the sequence).
    min_first_idx = int(not trig_first)
    if first_idx is None:
        first_idx = np.random.randint(min_first_idx, max_first_idx)
    elif first_idx < min_first_idx:
        raise ValueError(
            f"first_idx must be >= {min_first_idx} when "
            f"trig_first={trig_first}, got {first_idx}"
        )
    poss_idxs = list(range(first_idx, first_idx+resp_dist))

    # Loop initialization
    n_types = len(trig_types)
    # bools are ints in Python; only genuine integers act as a cap.
    capped = (
        isinstance(allow_dupls, (int, np.integer))
        and not isinstance(allow_dupls, bool)
        and allow_dupls > 0
    )
    if capped:
        trig_counts = {t: 0 for t in trig_types}
    which_type = np.random.randint(n_types)
    tidx2ttype = dict()
    taken_idxs = set()
    # Determine the indices where trigger tokens go and their types.
    # Visiting candidate 0 first guarantees a trigger at first_idx.
    perm = list(np.random.permutation(len(poss_idxs)))
    for i in [0]+perm:
        idx = poss_idxs[i]
        # taken_idxs holds every trigger index and its neighbours.
        if idx not in taken_idxs:
            if allow_dupls:
                which_type = np.random.randint(n_types)
                if capped:
                    # Scan the types (from a random start) for one that
                    # is still under its cap.
                    for _ in range(n_types):
                        which_type = (which_type + 1) % n_types
                        ttype = trig_types[which_type]
                        if trig_counts[ttype] < allow_dupls:
                            trig_counts[ttype] += 1
                            break
                    else:
                        # Every type is at its cap: stop placing triggers
                        # rather than exceed the limit.
                        break
            tidx2ttype[idx] = trig_types[which_type % n_types]
            which_type += 1
            # Block neighbours so no trigger is another trigger's target.
            taken_idxs.update((idx-1, idx, idx+1))
        if len(tidx2ttype) >= n_trigs:
            break

    # Build the input phase: triggers at their indices, random tokens
    # everywhere else.
    sorted_trig_idxs = sorted(tidx2ttype)
    seq = []
    tmask = []
    for i in range(sorted_trig_idxs[0]+resp_dist+1):
        if i in tidx2ttype:
            seq.append(tidx2ttype[i])
        else:
            seq.append(token_types[np.random.randint(len(token_types))])
        tmask.append(0)
    sep_token = "<SEP>"
    if incl_sep:
        seq.append(sep_token)
        tmask.append(0)

    # Response phase: each trigger (in positional order) followed by its
    # target token from the input phase.
    offset = 1 if trig_first else -1
    for trig_idx in sorted_trig_idxs:
        seq.append(tidx2ttype[trig_idx])
        seq.append(seq[trig_idx + offset])
        tmask.extend((1, 1))
    if ret_info:
        info = {
            "token_types": token_types,
            "trig_types": trig_types,
            "sep_token": sep_token,
        }
        return seq, tmask, info
    return seq, tmask

def sample_induction_heads(
        resp_dist=None,
        min_resp_dist=5,
        max_resp_dist=20,
        n_trigs=None,
        n_trig_types=1,
        *args, **kwargs):
    """
    Call `induction_heads`, sampling `resp_dist` when it is not given.

    Args:
        resp_dist: None or int
            If given, it is passed through unchanged. Otherwise it is
            drawn uniformly from [min_resp_dist, max_resp_dist]
            (inclusive).
        min_resp_dist: int
            Lower bound for the sampled resp_dist. It is raised to
            2 * (expected number of triggers) so the triggers have room
            to be placed non-adjacently.
        max_resp_dist: int
            Upper bound for the sampled resp_dist. It is raised to
            min_resp_dist if it would otherwise be smaller, so sampling
            never fails.
        n_trigs: None or int
            Forwarded to `induction_heads`. When None, the expected
            trigger count is len(trig_types) if `trig_types` is passed in
            kwargs, else n_trig_types.
        n_trig_types: int
            Forwarded to `induction_heads`.
        *args, **kwargs:
            Forwarded to `induction_heads`.

    Returns:
        Whatever `induction_heads` returns.
    """
    if resp_dist is None:
        # Mirror how induction_heads resolves n_trigs=None.
        if n_trigs is not None:
            n_expected = n_trigs
        elif kwargs.get("trig_types") is not None:
            n_expected = len(kwargs["trig_types"])
        else:
            n_expected = n_trig_types
        min_resp_dist = max(min_resp_dist, 2*n_expected)
        # Keep the sampling range non-empty: randint(low, high) requires
        # low < high.
        max_resp_dist = max(max_resp_dist, min_resp_dist)
        resp_dist = np.random.randint(min_resp_dist, max_resp_dist+1)
    return induction_heads(
        resp_dist=resp_dist,
        n_trigs=n_trigs,
        n_trig_types=n_trig_types,
        *args, **kwargs)

if __name__=="__main__":
    # Smoke test: generate many sequences, check the length bound, then
    # print a few examples.
    params = dict(
        resp_dist=20,
        max_first_idx=15,
        n_trig_types=5,
        n_token_types=10,
        allow_dupls=None,
        n_trigs=None,
        trig_types=None,
        token_types=None,
        trig_first=True,
        incl_sep=True,
    )
    # Resolve the trigger count the same way induction_heads does.
    if params["trig_types"] is None:
        n_types = params["n_trig_types"]
    else:
        n_types = len(params["trig_types"])
    n_trigs = n_types if params["n_trigs"] is None else params["n_trigs"]
    # first_idx < max_first_idx, so the input phase has at most
    # max_first_idx + resp_dist tokens. Each trigger adds 2 response tokens.
    max_len = (params["max_first_idx"] + params["resp_dist"]
               + 2*n_trigs + int(params["incl_sep"]))
    for _ in range(1000):
        seq, tmask, info = induction_heads(**params)
        if len(seq) > max_len:
            print(seq)
            print("len:", len(seq))
            print("len", len(tmask))
            raise AssertionError(
                f"sequence length {len(seq)} exceeds bound {max_len}")
    for _ in range(3):
        seq, tmask, info = induction_heads(**params)
        print("Len:", len(seq))
        print("len", len(tmask))
        # Count predictable positions in the mask. seq holds strings, so
        # comparing it to the int 1 would always give 0.
        print("n_ones:", np.sum(np.asarray(tmask)==1))
        print(seq)


