import numpy as np
import torch
import copy

from utils import (
    run_til_idx, run_for_n_steps, choice,
)
from causal_models.samplers import any_sampler
from causal_models.num_equivalence import count_only_cmod

from dl_utils.utils import pad_to, get_mask_between, arglast
from dl_utils.datas import CausalDataset
from dl_utils.tokenizer import Tokenizer

from tasks.num_equivalence import (
    sample_copy_equivalence, sample_num_equivalence
)
from tasks.induction_heads import sample_induction_heads
from tasks.arithmetic import sample_arithmetic


def get_task_seq_fxn(
        task="num_equivalence",
        task_config=None,):
    """
    Determines and returns the appropriate function to sample sequences
    of the requested task.

    Args:
        task: str
            the name of the task. One of "num_equivalence",
            "copy_equivalence", "induction_heads", "arithmetic".
        task_config: dict or None
            a dict of keyword arguments that is unpacked into the
            task's sampling function on every call. The dict is held
            by reference (not copied), so later changes to it are
            seen by the returned function. None is treated as an
            empty dict.
    Returns:
        python function
            a zero-argument function that samples one sequence of the
            task by calling the task's sampler with `task_config`.
    Raises:
        ValueError
            if `task` is not one of the supported task names.
    """
    if task_config is None:
        task_config = {}
    if task == "num_equivalence":
        sample_fxn = sample_num_equivalence
    elif task == "copy_equivalence":
        sample_fxn = sample_copy_equivalence
    elif task == "induction_heads":
        sample_fxn = sample_induction_heads
    elif task == "arithmetic":
        sample_fxn = sample_arithmetic
    else:
        # Fail here instead of returning a closure that would only
        # break (with an unbound-variable error) once it is called.
        raise ValueError(
            "Unknown task '{}'. Expected one of: num_equivalence, "
            "copy_equivalence, induction_heads, arithmetic".format(task))

    def task_seq_fxn():
        return sample_fxn(**task_config)
    return task_seq_fxn

# TODO: ALL OTHER DATASET CREATION FUNCTIONS NEED TO BE AUGMENTED
"""
PSEUDO:
- all tasks return lists of strings
    - this will make it easier to decode them for printing
- the tasks will not include special ids like bos
- all tasks will return "info" which is a dict to group special entries
    - this will be used for causal abstractions
    - keys are strings, values are strings and iterables of strings
- all dataset creations will append special tokens
- all dataset creations will return a tokenizer
- all dataset creations will convert the info dict values to integer ids
    - the info dict will only be used in the causal abstraction creations

"""
def convert_to_ids(
        data,
        task_masks=None,
        tokenizer=None,
        add_bos=True,
        add_eos=True,
        info=None):
    """
    Converts string tokens to integer ids and optionally wraps every
    sample (and its task mask) with bos/eos entries.

    Args:
        data: list of lists of strs
        task_masks: None or list of lists of (ints or bools)
            one mask per sample in `data`. The argued list is not
            modified; a new outer list is returned.
        tokenizer: Tokenizer or None
            if None, a new Tokenizer is created and trained on `data`.
        add_bos: bool
            if true, will prepend a bos token to each sample (and a 0
            to each task mask)
        add_eos: bool
            if true, will append an eos token to each sample (and a 1
            to each task mask)
        info: none or dict
            currently unused by this function.
    Returns:
        ids: list of lists of ints
        task_masks: None or list of lists of (ints or bools)
            the masks extended to align with `ids`. None if no masks
            were argued.
        max_len: int
            longest length of any id list in ids
        tokenizer: Tokenizer
    """
    if tokenizer is None:
        tokenizer = Tokenizer(words=set(), unk_token=None)
        tokenizer.train(tok_X=data, alphabetize=True)
    bos_id = tokenizer.bos_id
    eos_id = tokenizer.eos_id

    # Shallow copy of the outer list so that the caller's masks are not
    # overwritten (repeated calls would otherwise keep growing them).
    out_masks = None if task_masks is None else list(task_masks)

    ids = []
    max_len = 0
    for i, samp in enumerate(data):
        seq = [tokenizer.convert2id(s) for s in samp]
        if add_bos:
            seq = [bos_id] + seq
        if add_eos:
            seq = seq + [eos_id]
        ids.append(seq)
        max_len = max(max_len, len(seq))

        if out_masks is not None:
            mask = out_masks[i]
            if add_bos:
                mask = [0] + mask
            if add_eos:
                mask = mask + [1]
            out_masks[i] = mask
    return ids, out_masks, max_len, tokenizer

def info2tokdict(info, add_bos, add_eos):
    """
    Builds a word->id dict from the string entries of an info dict.

    The pad token always gets id 0. The bos token gets id 1 if
    `add_bos` is true and the eos token gets id 2 if `add_eos` is
    true. All tokens collected from `info` are sorted and assigned
    consecutive ids starting right after the largest special id, so
    they can never collide with a special token id.

    Args:
        info: dict
            values that are strings are added as single tokens. Values
            that have a `__len__` are treated as collections of tokens.
            All other values are ignored.
        add_bos: bool
        add_eos: bool
    returns:
        word2id: dict
            keys: str
            values: int
    """
    toks = set()
    for v in info.values():
        if isinstance(v, str):
            toks.add(v)
        elif hasattr(v, "__len__"):
            toks.update(v)

    word2id = {Tokenizer.pad_token: 0}
    if add_bos: word2id[Tokenizer.bos_token] = 1
    if add_eos: word2id[Tokenizer.eos_token] = 2

    # Start after the largest reserved id rather than at len(word2id).
    # With add_bos=False and add_eos=True the dict has 2 entries but id
    # 2 is already taken by eos.
    offset = max(word2id.values()) + 1
    word2id.update({t: offset+i for i,t in enumerate(sorted(toks))})
    return word2id

def preprocess_raw_samples(
        data,
        task_masks=None,
        add_bos=True,
        add_eos=True,
        seq_len=None,
        info=None,
        tokenizer=None,
        as_tensor=False,
        concat=False,
        word2id=None,
        **kwargs,
    ):
    """
    This function handles all the preprocessing steps to convert a
    list of lists of tokens (strings) of varying lengths to a sequence
    of sequences of token ids (ints) in the desired format.

    Args:
        data: list of lists of str
        task_masks: None or list of lists of ints
            0 means tokens will be teacher forced
        add_bos: bool
        add_eos: bool
        seq_len: None or int
            the length that every sample is padded/truncated to when
            concat is false. Defaults to the longest sample length
            (including bos/eos).
        info: dict or None
            special entries of the task. Every key containing "token"
            whose value is a str, list, or set gets a sibling key (with
            "token" replaced by "id") holding the converted ids. The
            argued dict is updated in place and returned. None is
            treated as a new empty dict.
        tokenizer: Tokenizer or None
            if None, a Tokenizer is created from `word2id`.
        as_tensor: bool
            if true, and concat is false, will return both data and
            task_masks as tensors
        concat: bool
            if true, the samples are left unpadded (and are not
            converted to tensors). An eos token is forced if neither
            bos nor eos was requested so that samples stay separable.
        word2id: dict or None
            optional vocabulary used to build the tokenizer. If None,
            it is built from `info` via `info2tokdict`.
        kwargs: ignored
    Returns:
        data: list of lists of ints (or LongTensor)
        task_masks: None or list of lists of ints (or BoolTensor)
        info: dict
            the info dict with the added id entries and the keys
            "pad_id", "bos_id", and "eos_id".
        tokenizer: Tokenizer
    """
    # A fresh dict per call. A shared default dict would accumulate the
    # id entries written below across unrelated calls.
    if info is None:
        info = {}
    if concat and not (add_bos or add_eos):
        add_eos = True
    if word2id is None:
        word2id = info2tokdict(info, add_bos=add_bos, add_eos=add_eos)
    if tokenizer is None:
        tokenizer = Tokenizer(
            words=set(),
            unk_token=None,
            word2id=word2id)

    # Convert strings to ids and add bos/eos ids
    data, task_masks, max_len, tokenizer = convert_to_ids(
        data=data,
        task_masks=task_masks,
        tokenizer=tokenizer,
        info=info,
        add_bos=add_bos,
        add_eos=add_eos)

    # Pad everything to same lengths if concat is false
    if seq_len is None: seq_len = max_len
    pad_id = tokenizer.pad_id
    if not concat:
        for i in range(len(data)):
            # Right-pad up to seq_len, then truncate anything longer.
            data[i] = pad_to(
                data[i], seq_len, fill_val=pad_id, side="right", dim=-1)
            data[i] = data[i][:seq_len]
            if task_masks is not None:
                # NOTE(review): masks are padded with pad_id rather than
                # a literal 0. Equivalent when pad_id is 0 -- confirm
                # this is intended for tokenizers with another pad id.
                task_masks[i] = pad_to(
                    task_masks[i], seq_len,
                    fill_val=pad_id, side="right", dim=-1)
                task_masks[i] = task_masks[i][:seq_len]

    # Change info dict from string values to id values. info dicts will
    # be used in causal programs
    converter = tokenizer.convert2id
    for k in list(info.keys()):
        if "token" not in k:
            continue
        new_k = k.replace("token", "id")
        value = info[k]
        if isinstance(value, str):
            info[new_k] = converter(value)
        elif type(value) in (list, set):
            # Rebuild the same container type with converted entries
            info[new_k] = type(value)([converter(w) for w in value])
    info["pad_id"] = tokenizer.pad_id
    info["bos_id"] = tokenizer.bos_id
    info["eos_id"] = tokenizer.eos_id

    # Convert to tensor and return
    if as_tensor and not concat:
        data = torch.LongTensor(data)
        if task_masks is not None:
            task_masks = torch.BoolTensor(task_masks)
    return data, task_masks, info, tokenizer

def make_dataset(
        task_config=None,
        task_type="num_equivalence",
        n_samples=100,
        seq_len=None,
        tokenizer=None,
        as_tensor=False,
        add_bos=True,
        add_eos=True,
        concat=False,
        *args, **kwargs):
    """
    This function simplifies sampling an entire dataset.
    
    Args:
        task_config: dict or None
            a configuration dict for the task. If None, the extra
            keyword arguments are used as the task configuration.
        task_type: str
            the name of the desired task. (see `get_task_seq_fxn`)
        n_samples: int
            the number of samples in the dataset
        seq_len: None or int
            the length of the sequences. Otherwise defaults to longest sequence
            length.
        tokenizer: Tokenizer
        as_tensor: bool
            optionally return the data as tensors. Otherwise will be
            returned as lists.
        add_bos: bool
            if true, will add a beginning of sequence token to the
            beginning of each sample.
        add_eos: bool
            if true, will add an end of sequence token to the end of
            each sample.
        concat: bool
            forwarded to `preprocess_raw_samples`.
    Returns:
        data: list of lists of ints
            these are the sampled sequences, padded with the
            tokenizer's pad id when concat is false.
        task_masks: list of lists of ints/bools
            one mask per sample, aligned with `data`.
        info: dict
            a dict of special str->int pairs that are useful for
            desigining the causal abstractions. It is taken from the
            last sampled sequence (empty if nothing was sampled).
            This is an example for the numeric equivalence task:
                "pad_id": int
                    the padding token id
                "bos_id": int
                    the beginning of sequence token id
                "eos_id": int
                    the end of sequence token id
                "resp_id": int
                    the response id
                "trigger_ids": list of ints
                    all possible trigger ids. If this list is greater than 1
                    which will occur if using multi_trigger setting, the
                    order of the list will correspond to the order of the
                    demo type ids that each trigger refers to.
                "trigger_offset": int
                    the offset from zero to the first trigger token.
                "demo_ids": list of ints
                    the possible demonstration token ids. The length of this
                    list will be equal to n_demo_types and the order will
                    correspond to the appropriate trigger_id index if using
                    multi_trigger setting.
        tokenizer: Tokenizer
    """
    if task_config is None: task_config = kwargs
    get_seq = get_task_seq_fxn(
        task=task_type,
        task_config=task_config,)

    # Collect list of lists of strings and list of lists of bools/ints
    data = []
    task_masks = []
    # Keeps `info` defined even when no sample is drawn (n_samples<=0)
    info = {}
    while len(data)<n_samples:
        sample, task_mask, info = get_seq()
        data.append(sample)
        task_masks.append(task_mask)

    data, task_masks, info, tokenizer = preprocess_raw_samples(
        data=data,
        task_masks=task_masks,
        info=info,
        tokenizer=tokenizer,
        add_bos=add_bos,
        add_eos=add_eos,
        seq_len=seq_len,
        as_tensor=as_tensor,
        concat=concat,
        **kwargs,
    )
    return data, task_masks, info, tokenizer

def make_systematic_dataset(
        n_samples=15,
        task_config=None,
        task_type="num_equivalence",
        keyword="targ_count",
        min_val=1,
        max_val=20,
        step_size=1,
        seq_len=None,
        tokenizer=None,
        as_tensor=False,
        add_bos=True,
        add_eos=True,
        concat=False,
        *args, **kwargs):
    """
    This function simplifies making a dataset consisting of
    n_samples for each value of `keyword` in
    range(min_val, max_val+1, step_size). Values contained in a
    "hold_outs" entry (in task_config or the keyword arguments) are
    skipped.
    
    Args:
        n_samples: int
            the number of samples for each value of `keyword`.
        task_config: dict or None
            the task configuration. It is copied (shallowly), so the
            argued dict is left untouched. If None, the extra keyword
            arguments are used as the configuration. An "info" entry,
            and every key that also appears in it, is dropped from
            the copy.
        task_type: str
            the name of the task type. This will determine what task
            to sample from.
        keyword: str
            the keyword argument to systematically search over.
        min_val: int (inclusive)
            the minimum value (corresponding to the argued `keyword`)
            to include in the systematic samples
        max_val: int (inclusive)
            the maximum value (corresponding to the argued `keyword`)
            to include in the systematic samples
        step_size: int
            the step size for the range from min_val to max_val i.e. the
            returned values will include numbers from
            range(min_val, max_val+1, step_size)
        seq_len: None or int
            the length of the sequences. Otherwise defaults to longest
            sequence length.
        tokenizer: Tokenizer or None
        as_tensor: bool
        add_bos: bool
        add_eos: bool
        concat: bool
            the last five arguments are forwarded to
            `preprocess_raw_samples`.
    Returns:
        data: list of lists of ints
            these are the sampled sequences, padded with the
            tokenizer's pad id when concat is false. It holds
            n_samples entries per sampled value of `keyword`.
        task_masks: list of lists of ints/bools
            0s mean corresponding tokens will always be teacher forced
        info: dict
            the info of the last sampled sequence (empty if nothing
            was sampled) after preprocessing. Example entries:
            "pad_id": int
                the padding token id
            "bos_id": int
                the beginning of sequence token id
            "eos_id": int
                the end of sequence token id
            "resp_id": int
                the response id
            "trigger_ids": list of ints
                all possible trigger ids. If this list is greater than 1
                which will occur if using multi_trigger setting, the
                order of the list will correspond to the order of the
                demo type ids that each trigger refers to.
            "trigger_offset": int
                the offset from zero to the first trigger token.
            "demo_ids": list of ints
                the possible demonstration token ids. The length of this
                list will be equal to n_demo_types and the order will
                correspond to the appropriate trigger_id index if using
                multi_trigger setting.
        samp_types: dict
            a dict with a single key, `keyword`, that has a value that
            is a list detailing the value of the `keyword` for each
            sample.

            Contents:
                `keyword`: list of values
                    the values (somewhere between `min_val` and `max_val`)
                    corresponding to each of the samples.
        tokenizer: Tokenizer
    """
    # Always work on a private copy: the config is edited below and the
    # caller's dict must not change as a side effect.
    if task_config is None:
        task_config = {**kwargs}
    else:
        task_config = {**task_config}
    if "info" in task_config:
        for k in task_config["info"]:
            if k in task_config:
                del task_config[k]
        del task_config["info"]
    task_config["ret_info"] = True
    if "hold_outs" in task_config:
        # Hold outs are applied here by skipping values, so the sampler
        # itself gets an empty hold out set.
        kwargs["hold_outs"] = task_config["hold_outs"]
        task_config["hold_outs"] = set()
    hold_outs = kwargs.get("hold_outs", ())

    samp_types = []
    data = []
    task_masks = []
    # Keeps `info` defined even if every value is skipped
    info = {}
    for value in range(min_val, max_val+1, step_size):
        if "hold_outs" in kwargs and value in hold_outs:
            continue
        task_config[keyword] = value
        sample_function = get_task_seq_fxn(
            task=task_type,
            task_config=task_config)
        for _ in range(n_samples):
            sample, task_mask, info = sample_function()
            data.append(sample)
            task_masks.append(task_mask)
            samp_types.append(value)

    # Drop entries that are passed explicitly below so that they cannot
    # collide as duplicate keyword arguments.
    explicit = {"data", "task_masks", "info"}
    extra = {k: v for k,v in kwargs.items() if k not in explicit}
    data, task_masks, info, tokenizer = preprocess_raw_samples(
        data=data,
        task_masks=task_masks,
        info=info,
        tokenizer=tokenizer,
        add_bos=add_bos,
        add_eos=add_eos,
        seq_len=seq_len,
        as_tensor=as_tensor,
        concat=concat,
        **extra,
    )
    samp_types = {keyword: samp_types}
    return data, task_masks, info, samp_types, tokenizer

def make_intervened_seqs(
        seq1s,
        seq2s,
        idx1s,
        idx2s,
        tmask1s,
        tmask2s,
        seq1_cmod,
        seq2_cmod,
        seq1_var="count",
        seq2_var="count",
        trigger_steps=None,
        trigger_step=None,
        trigger_id=None,
        seq1_info=None,
        seq2_info=None,):
    """
    This function will combine each seq1 with the expected output
    under the argued causal model after the variable `seq2_var` of
    seq2 at idx2 has been swapped into `seq1_var` of seq1 at idx1.
    
    Args:
        seq1s: list of lists of ints
            okay to argue an ndarray
        seq2s: list of lists of ints
            okay to argue an ndarray
        idx1s: list of ints
            okay to argue an ndarray
        idx2s: list of ints
            okay to argue an ndarray
        tmask1s: list of lists of bools/ints
            the task masks where 0 means that the corresponding token
            will be teacher forced. okay to argue an ndarray.
        tmask2s: list of lists of bools/ints
            the task masks where 0 means that the corresponding token
            will be teacher forced. okay to argue an ndarray.
        seq1_cmod: python function
            the causal/symbolic program for seq1. It is also used to
            roll out the continuation after the intervention.
        seq2_cmod: python function
            the causal/symbolic program for seq2
        seq1_var: str
            the variable to perform the causal interchange on. The
            available options will depend on the causal model. The
            value "no_swap" skips the interchange.
        seq2_var: str
            the variable to perform the causal interchange on. The
            available options will depend on the causal model. Must
            currently be equal to seq1_var.
        trigger_steps: None or int or sequence of ints
            optionally argue the number of steps until encountering
            a trigger token. Only applies if idx1 is in the demo
            phase. If None, will not insert a trigger token. If you
            argue an int, the number of trigger steps will be sampled
            from 0 to that argued value (inclusive), if you argue a
            sequence, it will be sampled uniformly from the sequence.
        trigger_step: int or None
            optionally specify a particular number of steps before
            inserting the trigger token. This value will be overwritten
            by trigger_steps.
        trigger_id: None or int
            optionally argue the trigger id that you would like
            to use when inserting a trigger token. Otherwise
            will sample the trigger id from the argued seq1_info
            trigger ids. Only applies if trigger_step is not None.
        seq1_info: dict or None
            None is treated as a new empty dict. An argued dict is
            updated in place: missing entries are filled with defaults
            and "trigger_id" is overwritten with the argued
            `trigger_id`. The dict is unpacked as keyword arguments
            into `run_til_idx` and `run_for_n_steps`.
            pad_id: int
               the padding token id (default 0)
            bos_id: int
               the beginning of sequence token id
            eos_id: int
               the end of sequence token id (default 2)
            resp_id: int
               the response id (defaults to the first demo id)
            trigger_ids: list of ints
               all possible trigger ids. If this list is greater than 1
               which will occur if using multi_trigger setting, the
               order of the list will correspond to the order of the
               demo type ids that each trigger refers to.
            demo_ids: list of ints
                the possible demonstration token ids (default
                [4,5,6]). The length of this list will be equal to
                n_demo_types and the order will correspond to the
                appropriate trigger_id index if using multi_trigger
                setting.
            pre_trigger: bool
                if true, will indicate at the beginning of the demonstration
                phase which token type to count. Otherwise indicates at the
                beginning of the response phase. Only applies if multi_trigger
                is true.
        seq2_info: dict or None
            same structure and default handling as seq1_info (without
            the "trigger_id" entry). Used for the seq2 run.
    Returns:
        ret_dict: dict
            "seq1s": list of lists of ints
                the modified seq1s that will now follow the causal
                interventions under the argued sequences, indices,
                and causal model. Padded with pad_id to equal length.
            "seq2s": the argued seq2s
            "idx1s": the argued idx1s
            "idx2s": the argued idx2s
            "var1s": list of dicts
                copies of the variables at the intervention index from
                seq1, taken before the interchange
            "var2s": list of dicts
                copies of the variables at the intervention index from
                seq2
            "tmask1s": list of lists of bools/ints
                the task masks aligned with the new seq1s, padded
                with 0.
            "tmask2s": the argued tmask2s
    Raises:
        NotImplementedError
            if seq1_var and seq2_var differ.
    """
    if seq1_var!=seq2_var:
        raise NotImplementedError(
            "seq1_var and seq2_var must currently be the same variable")

    # Fresh dicts per call. Shared default dicts would carry the
    # entries written below from one call to the next.
    if seq1_info is None: seq1_info = {}
    if seq2_info is None: seq2_info = {}
    seq1_info["pad_id"] = seq1_info.get("pad_id", 0)
    seq1_info["demo_ids"] = seq1_info.get("demo_ids", [4,5,6])
    seq1_info["resp_id"] = seq1_info.get("resp_id", seq1_info["demo_ids"][0])
    seq1_info["eos_id"] = seq1_info.get("eos_id", 2)
    seq1_info["trigger_id"] = trigger_id
    tids = seq1_info.get("trigger_ids", [7])
    seq2_info["pad_id"] = seq2_info.get("pad_id", 0)
    seq2_info["demo_ids"] = seq2_info.get("demo_ids", [4,5,6])
    seq2_info["resp_id"] = seq2_info.get("resp_id", seq2_info["demo_ids"][0])
    seq2_info["eos_id"] = seq2_info.get("eos_id", 2)
    eos_id1 = seq1_info["eos_id"]
    eos_id2 = seq2_info["eos_id"]
    pad_id1 = seq1_info["pad_id"]

    max_len = 0
    var1s = []
    var2s = []
    intr_seqs = []
    intr_tmasks = []

    # Assume that we will perform the swap after the representation
    # is used for the current prediction. So model makes pred for
    # current step, then we swap.
    for i in range(len(seq1s)):
        vars1s,_ = run_til_idx(seq1_cmod, seq1s[i], idx1s[i], **seq1_info)
        vars1 = vars1s[-1] # last element of list of vars
        var1s.append({**vars1})
        vars2s,_ = run_til_idx(seq2_cmod, seq2s[i], idx2s[i], **seq2_info)
        vars2 = vars2s[-1] # last element of list of vars
        var2s.append({**vars2})

        if seq1_var!="no_swap":
            if isinstance(vars1[seq1_var], list):
                # copy so that vars1 and vars2 do not share the list
                vars1[seq1_var] = list(vars2[seq2_var])
            else:
                vars1[seq1_var] = vars2[seq2_var]

        # Includes the first inpt_id and tmask in the returned lists.
        # This is tricky because the trigger token will sometimes
        # replace the first token. The variables follow the inpt tokens
        inpt_mask = tmask1s[i][idx1s[i]+1]
        if seq1_var=="inpt_val":
            # We do this because the transformer appears to use the current
            # token to decide which token to produce next. This is okay
            # though because we never sample from the last token, so
            # we never just give the correct answer
            inpt_mask = 0 
        inpt_id = seq1s[i][idx1s[i]+1]
        if seq1_var=="phase":
            if vars1["phase"]==0:
                inpt_id = choice(seq1_info["demo_ids"])
                inpt_mask = 0
            else:
                inpt_id = seq1_info["resp_id"]
                inpt_mask = 1
        else:
            if inpt_id==eos_id1:
                inpt_id = seq1_info["resp_id"]
            elif inpt_id in tids:
                inpt_id = choice(seq1_info["demo_ids"])
                inpt_mask = 0
        seq3, tmask3, vbls = run_for_n_steps(
            causal_model=seq1_cmod,
            variables=vars1,
            inpt_id=inpt_id,
            inpt_tmask=inpt_mask,
            n_steps=None,
            stop_ids={pad_id1},
            trigger_steps=trigger_steps,
            trigger_step=trigger_step,
            **seq1_info)

        # Unchanged prefix of seq1 up to and including idx1
        intr_seq = list(seq1s[i][:idx1s[i]+1])
        tmask = list(tmask1s[i][:idx1s[i]+1])
        # Can sometimes be difficult to use actually predicted token
        # as the input at the start of the run_for_n_steps function.
        # This is a hacky way to ignore the first token, if needed.
        if seq1_var=="no_swap":
            start = seq1s[i][idx1s[i]+1]==eos_id1
        elif vars1["phase"]==0:
            start = 0
        else:
            start = seq2s[i][idx2s[i]+1]==eos_id2
        for j in range(int(start), len(seq3)):
            intr_seq.append(seq3[j])
            tmask.append(tmask3[j])
        max_len = max(max_len, len(intr_seq))
        intr_seqs.append(intr_seq)
        intr_tmasks.append(tmask)

    # Pad all intervened sequences and masks to the longest sequence
    for i in range(len(intr_seqs)):
        intr_seqs[i] = pad_to(
            intr_seqs[i],
            max_len,
            fill_val=pad_id1)
        intr_seqs[i] = intr_seqs[i][:max_len]
        intr_tmasks[i] = pad_to(
            intr_tmasks[i],
            max_len,
            fill_val=0)
        intr_tmasks[i] = intr_tmasks[i][:max_len]

    ret_dict = {
        "seq1s": intr_seqs,
        "seq2s": seq2s,
        "idx1s": idx1s,
        "idx2s": idx2s,
        "var1s": var1s,
        "var2s": var2s,
        "tmask1s": intr_tmasks,
        "tmask2s": tmask2s,
    }
    return ret_dict

def make_systematic_intrv_dataset(
        n_samples=15,
        dest_incr=1,
        source_incr=4,
        dest_step=1,
        source_step=1,
        trigger_steps=None,
        trigger_step=1,
        seq1_kwargs=None,
        seq2_kwargs=None,
        ret_info=True,
        *args, **kwargs):
    """
    Systematically builds an interchange intervention dataset.

    This function samples n_samples base (destination) sequences for
    each target quantity from min_count to max_count (taken from the
    task_config of seq1_kwargs) incrementing by dest_incr. Then, for
    every base sequence, it samples n_samples source sequences for each
    target quantity from min_count to max_count (taken from the
    task_config of seq2_kwargs) incrementing by source_incr. Every
    admissible index of the base sequence (stepping by dest_step) is
    paired with every admissible index of the source sequence (stepping
    by source_step). Indices holding pad, bos, eos, or trigger tokens
    are never used, nor are indices at which the causal model's "count"
    variable is one of the hold outs. The pairs are then handed to
    make_intervened_seqs to create the intervened sequences.

    Args:
        n_samples: int
            the number of samples for each target quantity.
        dest_incr: int
            the increment of the target quantities of the base
            (destination) sequences.
        source_incr: int
            the increment of the target quantities of the source
            sequences, i.e. range(min_count, max_count+1, source_incr)
        dest_step: int
            the increment of the intervention index within a base
            sequence.
        source_step: int
            the increment of the intervention index within a source
            sequence.
        trigger_steps: None or int or sequence of ints
            forwarded to make_intervened_seqs. The number of
            demonstration tokens to precede the trigger token
            after an intervention in the demonstration phase. If you
            argue an int, the number of trigger steps will be sampled
            uniformly in the range from 0 to the argued value
            (inclusive). If you argue a sequence, it will be sampled
            uniformly from the sequence. No trigger will be inserted if
            None.
        trigger_step: int or None
            forwarded to make_intervened_seqs. Optionally argue a
            specific number of steps that will precede the insertion of
            a trigger token. This value is overwritten if trigger_steps
            is not None.
        seq1_kwargs: dict or None
            the settings for the base sequences. Missing keys are
            filled from the defaults listed at the top of the function
            body. The argued dict is deep copied, never mutated.
            Keys read directly by this function:
                "task_type": str
                    "num_equivalence" or "arithmetic" (among others).
                    "induction_heads" is not supported.
                "task_config": dict
                    keyword arguments for the task sampler. The
                    defaults hold "min_count", "max_count" (inclusive),
                    "hold_outs", "max_demo_tokens", "n_demo_types",
                    "multi_trigger", and "pre_trigger".
                "hold_outs": set of ints (optional)
                    counts that should not be used as intervention
                    values. Falls back to task_config["hold_outs"].
                    The hold outs are removed from the kwargs before
                    sampling so that sequences of all values are still
                    sampled.
                "pad_id": int
                "seq_len": None or int
                    if a positive int, the returned base tensors are
                    truncated to this length.
                "causal_model": function or str
                    if a str, it is looked up in this module's globals.
                "intr_var_key": str
                    the causal variable to intervene on.
            All keys are additionally forwarded to
            make_systematic_dataset.
        seq2_kwargs: dict or None
            the settings for the source sequences. Any key that is
            missing or None defaults to the corresponding value in
            seq1_kwargs. All keys are forwarded to
            preprocess_raw_samples.
        ret_info: bool
            if true, will additionally return the info dict of the
            base sequences.
    Returns:
        data_dict: dict
            "base_data": torch long tensor (B,S)
            "base_idxs": torch long tensor (B,)
            "base_vars": list of dicts
                the "var1s" entry returned by make_intervened_seqs
            "base_types": list of ints
                the target quantity of the base sequence, capped at
                the intervention index
            "base_tmasks": torch long tensor (B,S)
            "source_data": torch long tensor (B,S)
            "source_idxs": torch long tensor (B,)
            "source_vars": list of dicts
                the "var2s" entry returned by make_intervened_seqs
            "source_types": list of ints
                the target quantity of the source sequence
            "source_tmasks": torch long tensor (B,S)
        info: dict
            only returned if ret_info is true. The info dict returned
            by make_systematic_dataset for the base sequences.
    Raises:
        NotImplementedError: if the base task_type is "induction_heads"
        ValueError: if no source sequence could be sampled or if the
            resulting sequences cannot be stacked into tensors
    """
    # Fail fast on the unsupported task before any sampling work is done
    if (seq1_kwargs or {}).get("task_type", None)=="induction_heads":
        raise NotImplementedError(
            "make_systematic_intrv_dataset does not support the "
            "induction_heads task")

    kwarg_defaults={
        "task_type": "num_equivalence",
        "task_config": {
            "min_count": 1,
            "max_count": 20,
            "hold_outs": set(),
            "max_demo_tokens": None,
            "n_demo_types": 3,
            "multi_trigger": False,
            "pre_trigger": False,
        },
        "pad_id": 0,
        "bos_id": 1,
        "eos_id": 2,
        "seq_len": None,
        "trigger_ids": None,
        "trigger_id": None,
        "demo_ids": None,
        "resp_id": None,
        "causal_model": count_only_cmod,
        "intr_var_key": "count",
        "idx_sampler": any_sampler,
        "word2id": None,
        "tokenizer": None,
        "add_bos": True,
        "add_eos": True,
    }
    if seq1_kwargs is None: seq1_kwargs = kwarg_defaults
    if seq2_kwargs is None: seq2_kwargs = seq1_kwargs
    # Deep copies keep the caller's dicts untouched by the edits below
    seq1_kwargs = copy.deepcopy(seq1_kwargs)
    seq2_kwargs = copy.deepcopy(seq2_kwargs)
    for k,v in kwarg_defaults.items():
        if k not in seq1_kwargs: seq1_kwargs[k] = v
    # Missing or None entries in seq2 inherit from seq1
    for k,v in seq1_kwargs.items():
        if seq2_kwargs.get(k, None) is None: seq2_kwargs[k] = v

    base_cmod = seq1_kwargs["causal_model"]
    if isinstance(base_cmod, str):
        base_cmod = globals()[base_cmod]
    source_cmod = seq2_kwargs["causal_model"]
    if isinstance(source_cmod, str):
        source_cmod = globals()[source_cmod]

    # The task keyword that is systematically swept over
    keyword = "targ_count"
    if seq1_kwargs["task_type"]=="arithmetic":
        keyword = "n_ops"

    base_holdouts = seq1_kwargs.get(
        "hold_outs",
        seq1_kwargs["task_config"].get("hold_outs", set()))
    source_holdouts = seq2_kwargs.get(
        "hold_outs",
        seq2_kwargs["task_config"].get("hold_outs", set()))

    # We delete these here because we want to sample sequences of all
    # values, but we can hold out the interchanged values
    seq1_kwargs.pop("hold_outs", None)
    seq1_kwargs["task_config"].pop("hold_outs", None)
    seq2_kwargs.pop("hold_outs", None)
    seq2_kwargs["task_config"].pop("hold_outs", None)

    print("Creating Base Samples")
    tconfig = seq1_kwargs.get("task_config", {})
    dkwargs = {
        "n_samples": n_samples,
        "keyword": keyword,
        "step_size": dest_incr,
        "concat": False,
        **seq1_kwargs,
        "ret_info": True,
        "min_val": tconfig.get("min_count", 1),
        "max_val": tconfig.get("max_count", 20),
    }
    # types are the values that we are systematically searching over
    og_data,og_tmasks,og_info,og_types,tokenizer=make_systematic_dataset(
        **dkwargs,)

    # Preview of the first few base sequences. Slicing (instead of
    # indexing 0 through 4) avoids an IndexError on small datasets.
    for preview in og_data[:5]:
        print(preview[:40])
    print()

    # Share the base tokenizer unless the source brings its own vocab
    if seq2_kwargs["word2id"] is None and seq2_kwargs["tokenizer"] is None:
        seq2_kwargs["tokenizer"] = tokenizer

    base_data = []
    base_idxs = []
    base_types = []
    base_tmasks = []
    source_data = []
    source_idxs = []
    source_types = []
    source_tmasks = []
    print("Creating Source Samples")
    og_types = og_types[keyword]

    # Token ids that are never used as base intervention positions.
    # A missing trigger_ids entry is treated as no trigger tokens.
    base_trigger_ids = og_info.get("trigger_ids", None)
    if base_trigger_ids is None: base_trigger_ids = ()
    base_ignores = {
        seq1_kwargs.get("pad_id", 0),
        og_info.get("bos_id",None),
        og_info.get("eos_id",None),
        *base_trigger_ids,
    }
    min_count = seq2_kwargs["task_config"].get("min_count",1)
    max_count = seq2_kwargs["task_config"].get("max_count",20)
    task_type = seq2_kwargs.get("task_type", "num_equivalence")
    seq2_kwargs["concat"] = False

    # Stays None only if not a single source sequence gets sampled
    source_info = None
    for base_seq,base_tmask,base_type in zip(og_data,og_tmasks,og_types):
        for src_val in range(min_count, max_count+1, source_incr):
            task_config = {**seq2_kwargs["task_config"], keyword: src_val}
            sample_function = get_task_seq_fxn(
                task=task_type,
                task_config=task_config)
            for _ in range(n_samples):
                source_seq, source_tmask, source_info = sample_function()
                tup = preprocess_raw_samples(
                    data=[source_seq],
                    task_masks=[source_tmask],
                    info=source_info,
                    as_tensor=False,
                    **seq2_kwargs,
                )
                seqs, tmasks, source_info, tokenizer = tup
                source_seq =   seqs[0]
                source_tmask = tmasks[0]
                # Carry the tokenizer forward so that all source
                # sequences are encoded with the same tokenizer
                seq2_kwargs["tokenizer"] = tokenizer

                source_trigger_ids = source_info.get("trigger_ids", None)
                if source_trigger_ids is None: source_trigger_ids = ()
                source_ignores = {
                    seq2_kwargs.get("pad_id", 0),
                    seq2_kwargs.get("bos_id",None),
                    seq2_kwargs.get("eos_id",None),
                    *source_trigger_ids,
                }

                # Causal variables at every position of both sequences
                base_vars,_ = run_til_idx(
                    base_cmod,
                    base_seq,
                    len(base_seq),
                    **og_info)
                source_vars,_ = run_til_idx(
                    source_cmod,
                    source_seq,
                    len(source_seq),
                    **source_info)
                for bi in range(0,len(base_seq),dest_step):
                    # Ignores unwanted ids
                    if base_seq[bi] in base_ignores:   continue
                    elif base_vars[bi].get("count",0) in base_holdouts:
                        continue
                    for si in range(0,len(source_seq),source_step):
                        if source_seq[si] in source_ignores: continue
                        elif source_vars[si].get("count",0) in source_holdouts:
                            continue
                        base_data.append([b for b in base_seq])
                        source_data.append([s for s in source_seq])
                        base_idxs.append(bi)
                        source_idxs.append(si)
                        # Type refers to the type of sequence. The base
                        # type is capped at the intervention index.
                        base_types.append(
                            bi if base_type > bi else base_type)
                        source_types.append(src_val)
                        base_tmasks.append([t for t in base_tmask])
                        source_tmasks.append([t for t in source_tmask])

    if source_info is None:
        raise ValueError(
            "No source sequences were sampled. Check n_samples and "
            "the min_count/max_count ranges of the task configs.")

    print("Tot Samples:", len(base_data))
    print()
    print("Creating Intervention Set")
    intr_dict = make_intervened_seqs(
            seq1s=base_data,
            seq2s=source_data,
            idx1s=base_idxs,
            idx2s=source_idxs,
            tmask1s=base_tmasks,
            tmask2s=source_tmasks,
            seq1_cmod=base_cmod,
            seq2_cmod=source_cmod,
            seq1_var=seq1_kwargs["intr_var_key"],
            seq2_var=seq2_kwargs["intr_var_key"],
            trigger_steps=trigger_steps,
            trigger_step=trigger_step,
            trigger_id=og_info["trigger_id"],
            seq1_info=og_info,
            seq2_info=source_info,)
    base_data = intr_dict["seq1s"]
    source_data = intr_dict["seq2s"]
    base_tmasks = intr_dict["tmask1s"]
    source_tmasks = intr_dict["tmask2s"]
    base_idxs = intr_dict["idx1s"]
    source_idxs = intr_dict["idx2s"]
    base_vars = intr_dict["var1s"]
    source_vars = intr_dict["var2s"]

    # Pad everything to the single longest sequence so that the base
    # and source tensors share the same sequence dimension
    max_len = 0
    for b,s in zip(base_data, source_data):
        max_len = max(max_len, len(b), len(s))
    for i in range(len(base_data)):
        if len(base_data[i])<max_len:
            base_data[i] = pad_to(
                base_data[i], max_len, fill_val=seq1_kwargs["pad_id"])
            base_tmasks[i] = pad_to(
                base_tmasks[i], max_len, fill_val=0)
    for i in range(len(source_data)):
        if len(source_data[i])<max_len:
            source_data[i] = pad_to(
                source_data[i],max_len,fill_val=seq2_kwargs["pad_id"])
            source_tmasks[i] = pad_to(
                source_tmasks[i], max_len, fill_val=0)
    try:
        data_dict = {
            "base_data": torch.LongTensor(base_data),
            "base_idxs": torch.LongTensor(base_idxs),
            "base_vars": base_vars,
            "base_types": base_types,
            "base_tmasks": torch.LongTensor(base_tmasks),
            "source_data": torch.LongTensor(source_data),
            "source_idxs": torch.LongTensor(source_idxs),
            "source_vars": source_vars,
            "source_types": source_types,
            "source_tmasks": torch.LongTensor(source_tmasks),
        }
    except (ValueError, TypeError, RuntimeError) as err:
        # Most likely ragged rows. Report the lengths and keep the
        # original error chained for debugging.
        base_lens = sorted({len(s) for s in base_data})
        source_lens = sorted({len(s) for s in source_data})
        raise ValueError(
            "Could not convert the sequences to tensors. "
            "Base lens: {} - Source lens: {}".format(
                base_lens, source_lens)) from err
    seq_len = seq1_kwargs.get("seq_len", None)
    if seq_len is not None and seq_len>0:
        data_dict["base_data"] = data_dict["base_data"][:,:seq_len]
        data_dict["base_tmasks"] = data_dict["base_tmasks"][:,:seq_len]
    seq_len = seq2_kwargs.get("seq_len", None)
    if seq_len is not None and seq_len>0:
        data_dict["source_data"] = data_dict["source_data"][:,:seq_len]
        data_dict["source_tmasks"] = data_dict["source_tmasks"][:,:seq_len]
    if ret_info:
        return data_dict, og_info
    return data_dict

def extract_metrics(seqs, idxs, vars, trg_counts, info, task_type="num_equivalence"):
    """
    Loops through the sequences and finds relevant information about them.

    Args:
        seqs: long tensor (B, S)
        idxs: long tensor (B,)
            the intervention index of each sequence
        vars: sequence of dicts (B,)
            the variables of each sequence. Only the keys whose value
            in vars[0] has no `__len__` attribute are collected.
        trg_counts: sequence of ints (B,)
            the original target counts of the sequences
        info: dict
            holds general information about the sequences. Reads
            "trigger_ids", and "resp_ids" (optional) or "resp_id".
        task_type: str
            only "num_equivalence" is supported
    Returns:
        metrics: dict of lists
            "targ_count": trg_counts as argued (replaced by a collected
                list if the vars also have a scalar "targ_count" key)
            "trigger_steps": list of ints
                the index of the first trigger token minus the
                intervention index. 0 if the token at the intervention
                index is a response token.
            one list for every collected key in vars
    Raises:
        NotImplementedError: if task_type is not "num_equivalence"
    """
    if task_type!="num_equivalence":
        raise NotImplementedError(
            "extract_metrics only supports the num_equivalence task")
    # Sized values (lists, strings, ...) are not collected. Empty vars
    # simply yield no collected keys instead of an IndexError.
    scalar_keys = set()
    if len(vars)>0:
        scalar_keys = {
            k for k in vars[0] if not hasattr(vars[0][k],"__len__")
        }
    metrics = {
        "targ_count": trg_counts,
        "trigger_steps": [],
        **{k:[] for k in vars[0] if k in scalar_keys}
    } if len(vars)>0 else {
        "targ_count": trg_counts,
        "trigger_steps": [],
    }
    tids = torch.LongTensor(list(info["trigger_ids"]))
    resp_ids = info.get("resp_ids", set())
    if len(resp_ids)==0: resp_ids = {info["resp_id"]}
    for b in range(len(seqs)):
        # Only append to the lists that were created above. Appending
        # for every key raised a KeyError on sized values.
        for k,v in vars[b].items():
            if k in scalar_keys:
                metrics[k].append(v)
        # Trigger Steps
        tsteps = 0
        if seqs[b][idxs[b]].item() not in resp_ids:
            tidx = torch.isin(seqs[b], tids).long()
            # argmax picks the first trigger position. NOTE: it is 0
            # when the sequence has no trigger token at all.
            tsteps = (torch.argmax(tidx,dim=-1) - idxs[b]).item()
        metrics["trigger_steps"].append(tsteps)
    return metrics

def get_intervention_dataset(
        n_samples=1000,
        trigger_steps=3,
        seq1_kwargs={
            "task_type": "num_equivalence",
            "task_config": {
                "min_count": 1,
                "max_count": 20,
                "hold_outs": {4,9,14,17},
                "max_demo_tokens": None,
                "n_demo_types": 3,
                "multi_trigger": False,
                "pre_trigger": True,
            },
            "pad_id": 0,
            "bos_id": 1,
            "eos_id": 2,
            "seq_len": None,
            "trigger_id": None,
            "causal_model": count_only_cmod,
            "intr_var_key": "count",
            "idx_sampler": any_sampler,
            "word2id": None,
        },
        seq2_kwargs={
            "task_type": None,
            "task_config": None,
            "pad_id": None,
            "bos_id": None,
            "eos_id": None,
            "seq_len": None,
            "trigger_id": None,
            "causal_model": None,
            "intr_var_key": None,
            "idx_sampler": None,
            "word2id": None,
        },):
    """
    This function combines a number of steps to create an interchange
    dataset. Anything in seq2_kwargs that is None or missing will
    default to the corresponding value in seq1_kwargs. seq1 is the
    destination (base), seq2 is the source. Neither of the argued dicts
    is modified.

    Args:
        n_samples: int
            the number of data samples
        trigger_steps: None or int or sequence of ints
            forwarded to make_intervened_seqs. The number of
            demonstration tokens to precede the trigger token
            after an intervention in the demonstration phase. If you
            argue an int, the trigger steps will be sampled from 0 to
            that value (inclusive), if you argue a sequence, it will be
            sampled uniformly from the sequence. If None, will not
            insert a trigger token.
        seq1_kwargs: dict
            the settings of the base sequences. All entries are
            forwarded to make_dataset. Keys read directly by this
            function:
            task_config: dict
                A dict of parameters for the task.
                min_count: int
                    the minimum target count to be sampled
                max_count: int
                    the maximum target count to be sampled
                max_demo_tokens: int
                    optionally argue a maximum number of demonstration
                    tokens. This only applies if using multi_trigger
                    where the total number of tokens is different than
                    the target quantity. Defaults to maximum
                n_demo_types: int
                    the number of demonstration types (ie number of
                    types of fruit if we were counting fruit).
                multi_trigger: bool
                    if true, will use a conditional trigger to indicate
                    which token type to count during the demonstration
                    phase.
                pre_trigger: bool
                    if true, will indicate at the beginning of the
                    demonstration phase which token type to count.
                    Otherwise indicates at the beginning of the
                    response phase. Only applies if multi_trigger is
                    true.
                hold_outs: set of ints
                    optionally argue values that you don't want sampled
                    as the target quantities.
            pad_id: int
                the fill value used if the base and source sequences
                need to be padded to a common length
            trigger_id: int or None
                Optionally argue a trigger id to use for the whole
                dataset. Forwarded to make_intervened_seqs.
            causal_model: function
                the causal model handed to the idx_sampler and to
                make_intervened_seqs
            intr_var_key: str
                the key of the variable dict to perform the
                intervention.
            idx_sampler: function
                the function to sample an intervention index based on
                the tokens in the sequence. Called as
                idx_sampler(seq, info=..., causal_model=...)
        seq2_kwargs: dict
            the settings of the source sequences, same keys as
            seq1_kwargs.
    Returns:
        data: dict
            "base_data": torch.LongTensor (N,S)
                the base sequences after the intervention, as returned
                by make_intervened_seqs
            "source_data": torch.LongTensor (N,S)
                the source sequences, that are used to intervene upon
                the base sequences.
            "base_tmasks": torch.LongTensor (N,S)
                the task masks of the base sequences
            "source_tmasks": torch.LongTensor (N,S)
                the task masks of the source sequences
            "base_idxs": torch.LongTensor (N,)
                the indices indicating which index in the base
                sequence to intervene at
            "source_idxs": torch.LongTensor (N,)
                the indices indicating which index in the source
                sequence to harvest variables from.
    """
    # Work on shallow copies. The argument dicts (the shared default
    # dicts in particular) must not be written to, otherwise the None
    # placeholders of seq2_kwargs are lost after the first call and
    # later calls silently stop inheriting from seq1_kwargs.
    seq1_kwargs = dict(seq1_kwargs)
    seq2_kwargs = dict(seq2_kwargs)
    for k,v in seq1_kwargs.items():
        if seq2_kwargs.get(k, None) is None: seq2_kwargs[k] = v

    seq1s,tmask1s,seq1_info,tokenizer = make_dataset(
        n_samples=n_samples,
        **seq1_kwargs,
        ret_info=True,)
    idx1s = [
        seq1_kwargs["idx_sampler"](
            s,
            info=seq1_info,
            causal_model=seq1_kwargs["causal_model"]) for s in seq1s
    ]
    seq2s,tmask2s,seq2_info,tokenizer = make_dataset(
        n_samples=n_samples,
        **seq2_kwargs,
        ret_info=True, )
    idx2s = [
        seq2_kwargs["idx_sampler"](
            s,
            info=seq2_info,
            causal_model=seq2_kwargs["causal_model"]) for s in seq2s
    ]
    intr_dict = make_intervened_seqs(
            seq1s=seq1s,
            seq2s=seq2s,
            idx1s=idx1s,
            idx2s=idx2s,
            tmask1s=tmask1s,
            tmask2s=tmask2s,
            seq1_cmod=seq1_kwargs["causal_model"],
            seq2_cmod=seq2_kwargs["causal_model"],
            seq1_var=seq1_kwargs["intr_var_key"],
            seq2_var=seq2_kwargs["intr_var_key"],
            trigger_steps=trigger_steps,
            trigger_id=seq1_kwargs["trigger_id"],
            seq1_info=seq1_info,
            seq2_info=seq2_info,
    )
    seq1s = intr_dict["seq1s"]
    seq2s = intr_dict["seq2s"]
    tmask1s = intr_dict["tmask1s"]
    tmask2s = intr_dict["tmask2s"]
    idx1s = intr_dict["idx1s"]
    idx2s = intr_dict["idx2s"]
    s1_pad_id = seq1_kwargs["pad_id"]
    s2_pad_id = seq2_kwargs["pad_id"]

    # Pad whichever side is shorter so that base and source share the
    # same sequence length. Only the first row of each side is
    # compared, which assumes the rows within a side are equally long.
    seq_len = max(len(seq1s[0]),len(seq2s[0]))
    if len(seq1s[0])<seq_len:
        for i in range(len(seq1s)):
            seq1s[i] = pad_to(seq1s[i], seq_len, fill_val=s1_pad_id)
            tmask1s[i] = pad_to(tmask1s[i], seq_len, fill_val=0)
    elif len(seq2s[0])<seq_len:
        for i in range(len(seq2s)):
            seq2s[i] = pad_to(seq2s[i], seq_len, fill_val=s2_pad_id)
            tmask2s[i] = pad_to(tmask2s[i], seq_len, fill_val=0)

    ret_dict = {
        "base_data": torch.LongTensor(seq1s),
        "source_data": torch.LongTensor(seq2s),
        "base_tmasks": torch.LongTensor(tmask1s),
        "source_tmasks": torch.LongTensor(tmask2s),
        "base_idxs": torch.LongTensor(idx1s),
        "source_idxs": torch.LongTensor(idx2s),
    }
    return ret_dict

def get_datasets(config):
    """
    Creates the training and validation datasets for the task described
    by `config`. The training samples are created first with
    make_dataset, then the validation samples are created with the
    tokenizer returned from the training call so that both share the
    same tokenizer.

    If the task config has a non-empty "hold_outs" entry, the
    validation samples are created with the complement of those hold
    outs (within min_count to max_count inclusive) as the hold outs,
    presumably so that validation only samples the values that were
    held out of training (depends on how make_dataset treats
    hold_outs -- TODO confirm).

    NOTE: `config` is modified in place. "n_train_samples" and
    "n_val_samples" are filled with their defaults, all keyword
    arguments used for make_dataset as well as the returned info are
    written back into it, and "word2id", "n_tokens", and "seq_len" are
    set.

    Args:
        config: dict
            a dict of configuration settings. It is forwarded as
            keyword arguments to make_dataset and CausalDataset.
            Keys read directly by this function:
            "n_train_samples": int (defaults to 1000)
            "n_val_samples": int (defaults to 100)
            "task_config": dict (optional)
                if absent, "min_count", "max_count", and "hold_outs"
                are looked up in the flat config instead

    Returns:
        tokenizer: huggingface tokenizer or dl_utils tokenizer
        train_dataset: torch Dataset
        val_dataset: torch Dataset
    """
    config["n_train_samples"] = config.get("n_train_samples", 1000)
    kwargs = {**config}
    kwargs["n_samples"] = config["n_train_samples"]
    # Shallow copy of the task config so that make_dataset is handed
    # its own dict
    if "task_config" in kwargs:
        kwargs["task_config"] = {**config["task_config"]}
    train_samps, train_tmasks, info, tokenizer = make_dataset(**kwargs)

    config["n_val_samples"] = config.get("n_val_samples", 100)
    kwargs["n_samples"] = config["n_val_samples"]
    # Reuse the training tokenizer for the validation samples
    kwargs["tokenizer"] = tokenizer

    # Note that config["n_samples"] records the validation sample
    # count here because kwargs["n_samples"] was just overwritten
    for k,v in kwargs.items(): config[k] = v # update config for records
    for k,v in info.items(): config[k] = v # update config for records
    config["word2id"] = tokenizer.word2id
    config["n_tokens"] = len(config["word2id"])

    if "task_config" in kwargs:
        tconf = kwargs["task_config"]
    else:
        # No nested task config, fall back to the flat kwargs
        tconf = {**kwargs}
        kwargs["task_config"] = tconf
    min_count = tconf.get("min_count", 1)
    max_count = tconf.get("max_count", 20)
    # Invert the hold outs for validation.
    # NOTE(review): when "task_config" is in kwargs, tconf is the same
    # dict object that was written to config["task_config"] above, so
    # the inverted hold outs also end up in the recorded config --
    # confirm that this is intended.
    if tconf.get("hold_outs",None) is not None and len(tconf["hold_outs"])>0:
        all_nums = set(range(min_count, max_count+1))
        tconf["hold_outs"] = all_nums - set(tconf["hold_outs"])

    # Never concat for validation (concat refers to using individual
    # samples in each data row or, if true, use multiple samples
    # concatenated together along a single row)
    kwargs["concat"] = False
    val_samps, val_tmasks, _, _ = make_dataset(**kwargs)

    if len(val_samps[0])!=len(train_samps[0]):
        print(
            "Train seqlen:", len(train_samps[0]),
            "- Val:", len(val_samps[0]))
        # NOTE(review): only train_samps is padded, train_tmasks keeps
        # its original length, and nothing is padded when the training
        # sequences are the longer ones -- confirm that CausalDataset
        # handles this.
        if len(val_samps[0])>len(train_samps[0]):
            train_samps = pad_to(train_samps, len(val_samps[0]))

    # The datasets receive the full config plus the tokenizer's
    # special ids as keyword arguments
    temp = {**config, **tokenizer.special_ids}
    train_dataset = CausalDataset(
        data=train_samps,
        labels=None,
        masks={"task_mask": train_tmasks},
        **temp)
    temp["concat"] = False
    val_dataset =   CausalDataset(
        data=val_samps,
        labels=None,
        masks={"task_mask": val_tmasks},
        **temp)
    config["seq_len"] = train_dataset.seq_len
    return tokenizer, train_dataset, val_dataset

if __name__=="__main__":
    # Manual smoke test: builds a small systematic intervention dataset,
    # prints every base/source pair, and checks the recorded types
    # against counts recomputed from the token sequences.
    demo_task_config = {
        "min_count": 1,
        "max_count": 20,
        "hold_outs": set(),
        "max_demo_tokens": None,
        "n_demo_types": 3,
        "multi_trigger": False,
        "pre_trigger": False,
    }
    demo_kwargs = {
        "task_type": "num_equivalence",
        "task_config": demo_task_config,
        "pad_id": 0,
        "bos_id": 1,
        "eos_id": 2,
        "seq_len": None,
        "trigger_ids": None,
        "trigger_id": None,
        "demo_ids": None,
        "resp_id": None,
        "causal_model": count_only_cmod,
        "intr_var_key": "count",
        "idx_sampler": any_sampler,
        "word2id": None,
        "add_bos": True,
        "add_eos": True,
    }
    # Testing
    dataset, info = make_systematic_intrv_dataset(
        n_samples=1,
        dest_incr=2,
        source_incr=4,
        trigger_steps=2,
        seq1_kwargs=demo_kwargs,
        seq2_kwargs=demo_kwargs,
        ret_info=True,
    )
    trig_id = info["trigger_ids"][0]
    eos_id = info["eos_id"]
    resp_id = info["resp_id"]
    print(info)

    base_rows, src_rows = dataset["base_data"], dataset["source_data"]
    base_pos, src_pos = dataset["base_idxs"], dataset["source_idxs"]
    base_varbs, src_varbs = dataset["base_vars"], dataset["source_vars"]
    base_targs, src_targs = dataset["base_types"], dataset["source_types"]
    base_mask, src_mask = dataset["base_tmasks"], dataset["source_tmasks"]

    sep = ","

    def _fmt_row(values):
        # Width 2 entries joined by the separator
        return sep.join(["{:2}".format(v) for v in values])

    stats = {"bcount": [], "scount": []}
    for row in range(len(dataset["base_data"])):
        print("BIdx:", base_pos[row], " - ", "SIdx:", src_pos[row])
        print("BCount:", base_varbs[row]["count"],
                "- SCount:", src_varbs[row]["count"])
        print("BTarg:", base_targs[row], "- STarg:", src_targs[row])
        print("Idxs:", _fmt_row(range(len(base_rows[row]))))
        print("Base:", _fmt_row(base_rows[row].tolist()))
        print("Srce:", _fmt_row(src_rows[row].tolist()))
        print("Bmsk:", _fmt_row(base_mask[row].tolist()))
        print("Smsk:", _fmt_row(src_mask[row].tolist()))
        print()

        # Base count: the intervention index if it precedes the first
        # trigger token, otherwise the trigger position minus one
        base_trig_idx = torch.argmax((base_rows[row]==trig_id).long())
        demo_ids = torch.LongTensor(info["demo_ids"])
        if base_pos[row]<base_trig_idx:
            bcount = base_pos[row]
        else:
            bcount = base_trig_idx-1
        # Source count: the response tokens that lie inside the task mask
        src_trig_idx = torch.argmax((src_rows[row]==trig_id).long())
        is_resp = src_rows[row]==resp_id
        scount = (is_resp&(src_mask[row].bool())).float().sum(-1)

        stats["bcount"].append(bcount.item())
        stats["scount"].append(scount.item())

        print("BCnt:", base_varbs[row]["count"])
        print("BTyp:", base_targs[row])
        print("Calc:", bcount)
        assert bcount==base_targs[row]
        print("SCnt:", src_varbs[row]["count"])
        print("STyp:", src_targs[row])
        print("Calc:", scount)
        assert scount==src_targs[row]
        print()

    #import pandas as pd
    #df = pd.DataFrame(stats)
    #for i in range(len(df)):
    #    print(df.iloc[i])
