import numpy as np
from .num_equivalence import count_only_cmod, distributed_cmod

def choice(seq):
    """
    Draws one element uniformly at random from a 1D sequence. It
    fills the same role as numpy's choice function but does nothing
    more than draw a single random integer position and index seq
    with it.

    Args:
        seq: sequence of objects
            must support len() and integer indexing
    Returns:
        val: object
            the element of seq at the randomly drawn position
    """
    position = np.random.randint(len(seq))
    return seq[position]

def idx_sampler(seq, token_ids, *args, **kwargs):
    """
    Samples a position in the argued sequence whose id is one of the
    argued token ids. All matching positions are collected first and
    one of them is then drawn uniformly at random. This is used to
    pick the point in a sequence at which an interchange is performed.

    Args:
        seq: iterable of ints
            the id sequence
        token_ids: list or set of ints
            the token ids whose positions are candidates for
            sampling.
    Returns:
        idx: int
            a sampled index. The id at this index in the argued
            seq is contained in the argued token_ids. If no id in
            seq is contained in `token_ids`, -1 is returned.
    """
    candidates = []
    for pos, tok_id in enumerate(seq):
        if tok_id in token_ids:
            candidates.append(pos)
    if not candidates:
        return -1
    return choice(candidates)

def value_sampler(
        seq,
        causal_model=count_only_cmod,
        check_fxn=lambda v: v["count"]%2==0,
        invalids=None,
        info=None,
        *args, **kwargs):
    """
    This function will sample an index in the argued sequence at which
    the high level variables produced by the causal_model satisfy the
    check_fxn and at which the token id is not one of the invalids.
    In the default settings, this function will return an index that
    has an even count value.

    Args:
        seq: iterable of ints
            the id sequence
        causal_model: python function
            the high level causal model. It is called once per token
            as `causal_model(tok_id, variables=varbs, **info)` and
            must return a pair whose second element is the updated
            variables.
        check_fxn: lambda fxn
            the check function will receive the variables returned by
            the causal model and check if they satisfy the sampling
            criteria.
        invalids: set of ints or None
            token ids that must never be sampled, even if the
            variables at their index satisfy the check function.
            None means that no token ids are excluded.
        info: dict
            optionally argue the info dict that will be passed to the
            causal_model.
    Returns:
        idx: int
            a sampled index. This index will correspond to an index in
            which the high level variables satisfy the check function.
            If no index qualifies, -1 is returned.
    """
    if info is None: info = dict()
    # None sentinel instead of a mutable `set()` default that would be
    # shared between calls
    if invalids is None: invalids = frozenset()
    varbs = None
    valid_idxs = []
    for i,tok_id in enumerate(seq):
        # the causal model is stepped on every token, including the
        # invalid ones, because each step receives the previous
        # step's variables
        _, varbs = causal_model(tok_id, variables=varbs, **info)
        if check_fxn(varbs) and tok_id not in invalids:
            valid_idxs.append(i)
    if len(valid_idxs)==0: return -1
    return choice(valid_idxs)

def demo_sampler(
        seq,
        causal_model=count_only_cmod,
        info=None,
        *args, **kwargs):
    """
    Samples an index from the demonstration phase of the argued
    sequence, where the demonstration phase is every step at which
    the causal model's "phase" variable is equal to 0. Indices that
    hold a bos, eos, pad, or trigger token are never sampled.

    Args:
        seq: iterable of ints
            the id sequence
        causal_model: python function
            the high level causal model
        info: dict
            the info dict that is forwarded to the causal_model. It
            is indexed directly, so it must contain the keys
            "bos_id", "eos_id", "pad_id", and "trigger_ids".
    Returns:
        idx: int
            a sampled index within the demo phase of the argued
            sequence, as returned by value_sampler.
    """
    def in_demo_phase(varbs):
        return varbs["phase"]==0

    special_ids = {
        info["bos_id"],
        info["eos_id"],
        info["pad_id"],
        *info["trigger_ids"],
    }
    return value_sampler(
        seq=seq,
        causal_model=causal_model,
        check_fxn=in_demo_phase,
        invalids=special_ids,
        info=info)


def resp_sampler(
        seq,
        causal_model=count_only_cmod,
        info=None,
        *args, **kwargs):
    """
    Samples an index from the response phase of the argued sequence,
    where the response phase is every step at which the causal
    model's "phase" variable is greater than 0. Indices that hold a
    bos, eos, pad, or trigger token are never sampled.

    Args:
        seq: iterable of ints
            the id sequence
        causal_model: python function
            the high level causal model
        info: dict
            the info dict that is forwarded to the causal_model. It
            is indexed directly, so it must contain the keys
            "bos_id", "eos_id", "pad_id", and "trigger_ids".
    Returns:
        idx: int
            a sampled index within the resp phase of the argued
            sequence, as returned by value_sampler.
    """
    def in_resp_phase(varbs):
        return varbs["phase"]>0

    special_ids = {
        info["bos_id"],
        info["eos_id"],
        info["pad_id"],
        *info["trigger_ids"],
    }
    return value_sampler(
        seq=seq,
        causal_model=causal_model,
        check_fxn=in_resp_phase,
        invalids=special_ids,
        info=info)

def any_sampler(
        seq,
        causal_model=count_only_cmod,
        info=None,
        *args, **kwargs):
    """
    Samples an index from anywhere in the argued sequence, regardless
    of phase. The only restriction is that indices holding a bos,
    eos, pad, or trigger token are never sampled.

    Args:
        seq: iterable of ints
            the id sequence
        causal_model: python function
            the high level causal model. It is still stepped over the
            sequence by value_sampler even though its variables do
            not affect which indices qualify.
        info: dict
            the info dict that is forwarded to the causal_model. It
            is indexed directly, so it must contain the keys
            "bos_id", "eos_id", "pad_id", and "trigger_ids".
    Returns:
        idx: int
            a sampled index of a non-special token in the argued
            sequence, as returned by value_sampler.
    """
    def accept_all(varbs):
        return True

    special_ids = {
        info["bos_id"],
        info["eos_id"],
        info["pad_id"],
        *info["trigger_ids"],
    }
    return value_sampler(
        seq=seq,
        causal_model=causal_model,
        check_fxn=accept_all,
        invalids=special_ids,
        info=info)

def demo_value_sampler(
        seq,
        causal_model=count_only_cmod,
        check_fxn=lambda v: v["count"]>1,
        info=None,
        *args, **kwargs):
    """
    This function will sample an index that satisfies the check_fxn
    while also being in the demonstration phase. The demo phase is
    determined assuming it corresponds to the phase being equal to 0.
    Indices that hold a bos, pad, or trigger token are never sampled.
    Note that, unlike the other phase samplers in this file, the eos
    token is not excluded here.

    Args:
        seq: iterable of ints
            the id sequence
        causal_model: python function
            the high level causal model
        check_fxn: lambda fxn
            the check function will receive variable dicts and check
            if they satisfy the sampling criteria. If no values
            satisfy the check function, then -1 is returned.
        info: dict
            the info dict that will be passed to the causal_model.
            It is indexed directly, so it must contain the keys
            "bos_id", "pad_id", and "trigger_ids".
    Returns:
        idx: int
            a sampled index. This index will correspond to an index in
            which the high level variables satisfy the check function.
            If no non-special index qualifies, -1 is returned.
    """
    # Filtering the special tokens inside a single draw gives a uniform
    # sample over the qualifying non-special indices. Drawing first and
    # resampling in a loop on a special token is equivalent in
    # distribution, but it never terminates when no non-special index
    # qualifies and seq ends in a bos, pad, or trigger token: the
    # sampler returns -1 and seq[-1] is special again.
    invalids = {info["bos_id"], info["pad_id"], *info["trigger_ids"]}
    return value_sampler(
        seq=seq,
        causal_model=causal_model,
        check_fxn=lambda v: v["phase"]==0 and check_fxn(v),
        invalids=invalids,
        info=info)

def resp_value_sampler(
        seq,
        causal_model=count_only_cmod,
        check_fxn=lambda v: v["count"]%2==0,
        info=None,
        *args, **kwargs):
    """
    Samples an index that is in the response phase and whose high
    level variables also satisfy the check_fxn. The response phase is
    every step at which the causal model's "phase" variable is
    something other than 0. Indices that hold a bos, eos, pad, or
    trigger token are never sampled.

    Args:
        seq: iterable of ints
            the id sequence
        causal_model: python function
            the high level causal model
        check_fxn: lambda fxn
            receives the variable dicts and decides whether they
            satisfy the sampling criteria. It is only consulted for
            steps whose phase is not 0.
        info: dict
            the info dict that is forwarded to the causal_model. It
            is indexed directly, so it must contain the keys
            "bos_id", "eos_id", "pad_id", and "trigger_ids".
    Returns:
        idx: int
            a sampled index at which the phase is not 0 and the
            variables satisfy the check function, as returned by
            value_sampler.
    """
    def in_resp_and_passes(varbs):
        return varbs["phase"]!=0 and check_fxn(varbs)

    special_ids = {
        info["bos_id"],
        info["eos_id"],
        info["pad_id"],
        *info["trigger_ids"],
    }
    return value_sampler(
        seq=seq,
        causal_model=causal_model,
        check_fxn=in_resp_and_passes,
        info=info,
        invalids=special_ids)

def eos_sampler(
        seq,
        causal_model=count_only_cmod,
        info=None,
        *args, **kwargs):
    """
    This function will sample an index that directly precedes an eos
    token in the argued sequence, i.e. the index at which the eos
    token is the next token rather than the index at which the eos
    token appears.

    Args:
        seq: iterable of ints
            the id sequence
        causal_model: python function
            unused. Only accepted so that this function has the same
            signature as the other samplers.
        info: dict
            an information dict. Must contain "eos_id".
    Returns:
        idx: int
            a sampled index such that seq[idx+1] is the eos id. If
            seq has no eos token with a preceding index, -1 is
            returned.
    """
    assert isinstance(info, dict), "info must be a dict containing 'eos_id'"
    eos_id = info["eos_id"]
    # want to grab index in which the model will predict the eos
    # token, rather than the index at which the EOS will be used as
    # input. An eos at index 0 has no preceding index; skipping it keeps
    # i-1 from producing -1, which would be indistinguishable from the
    # "nothing found" return value.
    valid_idxs = [
        i-1 for i,tok_id in enumerate(seq) if i>0 and tok_id==eos_id
    ]
    if len(valid_idxs)==0: return -1
    return choice(valid_idxs)

def trig_sampler(
        seq,
        causal_model=count_only_cmod,
        info=None,
        *args, **kwargs):
    """
    Samples an index in the argued sequence at which one of the
    trigger tokens appears.

    Args:
        seq: iterable of ints
            the id sequence
        causal_model: python function
            unused. Only accepted so that this function has the same
            signature as the other samplers.
        info: dict
            an information dict. Must contain "trigger_ids".
    Returns:
        idx: int
            a sampled index at which seq holds one of the trigger
            ids, or -1 if seq holds no trigger id.
    """
    assert type(info)==dict
    trigger_set = set(info["trigger_ids"])
    # the sampled index is the trigger token's own position
    hits = [pos for pos,tok_id in enumerate(seq) if tok_id in trigger_set]
    if not hits:
        return -1
    return choice(hits)

def odd_demo_sampler(
        seq,
        info=None,
        *args, **kwargs):
    """
    Samples an index within the demo phase at which the "count"
    variable of the count_only_cmod is odd. Thin wrapper around
    demo_value_sampler.

    Args:
        seq: iterable of ints
            the id sequence
        info: dict
            an information dict that is handed to demo_value_sampler
            and from there to the causal_model. demo_value_sampler
            reads "bos_id", "pad_id", and "trigger_ids" from it; any
            further keys are whatever count_only_cmod requires.
    Returns:
        idx: int
            a sampled index. This index will correspond to an index in
            which the high level variables satisfy the check function.
    """
    def count_is_odd(varbs):
        return varbs["count"]%2==1

    return demo_value_sampler(
        seq=seq,
        causal_model=count_only_cmod,
        check_fxn=count_is_odd,
        info=info)

def even_demo_sampler(
        seq,
        info=None,
        *args, **kwargs):
    """
    Samples an index within the demo phase at which the "count"
    variable of the count_only_cmod is even. Thin wrapper around
    demo_value_sampler.

    Args:
        seq: iterable of ints
            the id sequence
        info: dict
            an information dict that is handed to demo_value_sampler
            and from there to the causal_model. demo_value_sampler
            reads "bos_id", "pad_id", and "trigger_ids" from it; any
            further keys are whatever count_only_cmod requires.
    Returns:
        idx: int
            a sampled index. This index will correspond to an index in
            which the high level variables satisfy the check function.
    """
    def count_is_even(varbs):
        return varbs["count"]%2==0

    return demo_value_sampler(
        seq=seq,
        causal_model=count_only_cmod,
        check_fxn=count_is_even,
        info=info)

def odd_resp_sampler(
        seq,
        info=None,
        *args, **kwargs):
    """
    Samples an index within the response phase at which the "count"
    variable of the count_only_cmod is odd. Thin wrapper around
    resp_value_sampler.

    Args:
        seq: iterable of ints
            the id sequence
        info: dict
            an information dict that is handed to resp_value_sampler
            and from there to the causal_model. resp_value_sampler
            reads "bos_id", "eos_id", "pad_id", and "trigger_ids"
            from it; any further keys are whatever count_only_cmod
            requires.
    Returns:
        idx: int
            a sampled index. This index will correspond to an index in
            which the high level variables satisfy the check function.
    """
    def count_is_odd(varbs):
        return varbs["count"]%2==1

    return resp_value_sampler(
        seq=seq,
        causal_model=count_only_cmod,
        check_fxn=count_is_odd,
        info=info)

def even_resp_sampler(
        seq,
        info=None,
        *args, **kwargs):
    """
    Samples an index within the response phase at which the "count"
    variable of the count_only_cmod is even. Thin wrapper around
    resp_value_sampler.

    Args:
        seq: iterable of ints
            the id sequence
        info: dict
            an information dict that is handed to resp_value_sampler
            and from there to the causal_model. resp_value_sampler
            reads "bos_id", "eos_id", "pad_id", and "trigger_ids"
            from it; any further keys are whatever count_only_cmod
            requires.
    Returns:
        idx: int
            a sampled index. This index will correspond to an index in
            which the high level variables satisfy the check function.
    """
    def count_is_even(varbs):
        return varbs["count"]%2==0

    return resp_value_sampler(
        seq=seq,
        causal_model=count_only_cmod,
        check_fxn=count_is_even,
        info=info)

def nofirst_demo_sampler( *args, **kwargs ):
    """
    Samples an index that is in the demo phase and at which the
    "count" variable is above 1. The causal model is always the
    count_only_cmod; any argued `causal_model` keyword is replaced.

    Args:
        see demo_value_sampler function for args
    Returns:
        idx: int
            a sampled index. This index will correspond to an index in
            which the high level variables satisfy the check function.
    """
    def count_above_one(varbs):
        return varbs["count"]>1

    # force the count-only causal model regardless of what was argued
    forwarded = {**kwargs, "causal_model": count_only_cmod}
    return demo_value_sampler(
        *args,
        check_fxn=count_above_one,
        **forwarded)

def nofirst_resp_sampler( *args, **kwargs ):
    """
    Samples an index that is in the resp phase and at which the
    "resp_count" variable is above 1. The causal model is always the
    count_only_cmod; any argued `causal_model` keyword is replaced.
    Assumes variables have `resp_count`.

    Args:
        see resp_value_sampler function for args
    Returns:
        idx: int
            a sampled index. This index will correspond to an index in
            which the high level variables satisfy the check function.
    """
    def resp_count_above_one(varbs):
        return varbs["resp_count"]>1

    # force the count-only causal model regardless of what was argued
    forwarded = {**kwargs, "causal_model": count_only_cmod}
    return resp_value_sampler(
        *args,
        check_fxn=resp_count_above_one,
        **forwarded)

def nofirst_any_sampler( *args, **kwargs ):
    """
    Delegates to nofirst_resp_sampler when a uniform random draw is
    below 0.5 and to nofirst_demo_sampler otherwise, forwarding all
    arguments unchanged. Assumes variables dicts have `resp_count`.

    Args:
        see demo_value_sampler function for args
    Returns:
        idx: int
            a sampled index. This index will correspond to an index in
            which the high level variables satisfy the check function
            of whichever sampler was picked.
    """
    sampler = (
        nofirst_resp_sampler if np.random.random()<0.5
        else nofirst_demo_sampler
    )
    return sampler(*args, **kwargs)

def last_resp_sampler( *args, **kwargs ):
    """
    Samples an index that is in the resp phase and at which the
    "count" variable is equal to 0. The causal model is always the
    count_only_cmod; any argued `causal_model` keyword is replaced.
    Assumes variables have `count` and `phase`.
    NOTE(review): a count of 0 in the resp phase is presumably the
    step 1 away from the eos id -- confirm against count_only_cmod.

    Args:
        see resp_value_sampler function for args
    Returns:
        idx: int
            a sampled index. This index will correspond to an index in
            which the high level variables satisfy the check function.
    """
    def count_is_zero(varbs):
        return varbs["count"]==0

    # force the count-only causal model regardless of what was argued
    forwarded = {**kwargs, "causal_model": count_only_cmod}
    return resp_value_sampler(
        *args,
        check_fxn=count_is_zero,
        **forwarded)

def second2last_resp_sampler( *args, **kwargs ):
    """
    Samples an index that is in the resp phase and at which the
    "count" variable is equal to 1. The causal model is always the
    count_only_cmod; any argued `causal_model` keyword is replaced.
    Assumes variables have `count` and `phase`.
    NOTE(review): a count of 1 in the resp phase is presumably the
    step 2 away from the eos id -- confirm against count_only_cmod.

    Args:
        see resp_value_sampler function for args
    Returns:
        idx: int
            a sampled index. This index will correspond to an index in
            which the high level variables satisfy the check function.
    """
    def count_is_one(varbs):
        return varbs["count"]==1

    # force the count-only causal model regardless of what was argued
    forwarded = {**kwargs, "causal_model": count_only_cmod}
    return resp_value_sampler(
        *args,
        check_fxn=count_is_one,
        **forwarded)

def notlast_sampler(
        seq,
        causal_model=count_only_cmod,
        info=None,
        *args, **kwargs):
    """
    Samples an index from anywhere in the argued sequence except the
    steps that are both in phase 1 and "finished". For the
    distributed_cmod a step counts as finished when the mean of its
    "tok_vals" variable is at or below 0; for every other causal
    model it is when the "count" variable is equal to 0. Indices that
    hold a bos, eos, pad, or trigger token are never sampled.

    Args:
        seq: iterable of ints
            the id sequence
        causal_model: python function
            the high level causal model
        info: dict
            the info dict that is forwarded to the causal_model. It
            is indexed directly, so it must contain the keys
            "bos_id", "eos_id", "pad_id", and "trigger_ids".
    Returns:
        idx: int
            a sampled index of a non-special token that passes the
            check described above, as returned by value_sampler.
    """
    special_ids = {
        info["bos_id"],
        info["eos_id"],
        info["pad_id"],
        *info["trigger_ids"],
    }
    if causal_model==distributed_cmod:
        def not_finished(varbs):
            return not (varbs["phase"]==1 and np.mean(varbs["tok_vals"])<=0)
    else:
        def not_finished(varbs):
            return not (varbs["phase"]==1 and varbs["count"]==0)
    return value_sampler(
        seq=seq,
        causal_model=causal_model,
        check_fxn=not_finished,
        invalids=special_ids,
        info=info)

def notlast_resp_sampler(
        seq,
        causal_model=count_only_cmod,
        info=None,
        *args, **kwargs):
    """
    This function will sample an index from the response phase at
    which the "count" variable is above 0. Will not sample a special
    token.

    Args:
        seq: iterable of ints
            the id sequence
        causal_model: python function
            the high level causal model
        info: dict
            the info dict that is passed on to resp_value_sampler and
            from there to the causal_model. resp_value_sampler reads
            the keys "bos_id", "eos_id", "pad_id", and "trigger_ids"
            from it.
    Returns:
        idx: int
            a sampled index. This index will be within the resp phase
            of the argued sequence and have a count above 0.
    """
    # resp_value_sampler builds the set of special token ids to exclude
    # from info itself and has no `invalids` parameter, so no set is
    # built or passed along here.
    return resp_value_sampler(
        seq=seq,
        causal_model=causal_model,
        check_fxn=lambda v: v["count"]>0,
        info=info)

def anyorlast_sampler( *args, **kwargs ):
    """
    Delegates to last_resp_sampler when a uniform random draw is
    above 0.5 and to any_sampler otherwise, forwarding all arguments
    unchanged. Assumes variables have `count` and `phase`, which is
    what last_resp_sampler checks.

    Args:
        see last_resp_sampler and any_sampler functions for args
    Returns:
        idx: int
            a sampled index. This index will correspond to an index in
            which the high level variables satisfy the check function
            of whichever sampler was picked.
    """
    sampler = last_resp_sampler if np.random.random()>0.5 else any_sampler
    return sampler(*args, **kwargs)

if __name__=="__main__":
    # Smoke test: any_sampler must only ever return the index of a
    # non-special token.
    seq = [1,4,5,6,4,5,7,3,3,3,3,3,2]
    info = {
        "demo_ids": [4,5,6],
        "trigger_ids": [7],
        "resp_id": 3,
        "eos_id": 2,
        "bos_id": 1,
        # any_sampler indexes info["pad_id"], so leaving it out raises a
        # KeyError before anything is sampled. 0 does not occur in seq.
        # NOTE(review): assumes count_only_cmod accepts a pad_id
        # keyword, since the whole dict is forwarded to it -- confirm.
        "pad_id": 0,
    }
    # every index except the bos (0), the trigger (6), and the eos (12)
    valids = {1,2,3,4,5,7,8,9,10,11}
    for i in range(1000):
        assert any_sampler(seq, info=info) in valids
    print("test passed")

    #print()
    #print("Demo Even")
    #for i in range(10):
    #    print("Sampled:", even_demo_sampler(seq, info=info))

    #print()
    #print("Resp Odd")
    #for i in range(10):
    #    print("Sampled:", odd_resp_sampler(seq, info=info))

    #print()
    #print("Resp Even")
    #for i in range(10):
    #    print("Sampled:", even_resp_sampler(seq, info=info))

