import numpy as np
import torch
from torch.optim.lr_scheduler import _LRScheduler
from dl_utils.utils import arglast, get_mask_past_id, get_mask_between
import copy

# Lookup table from a device index to a device spec: -1 maps to "cpu",
# the indices 0 through 9 map to themselves.
DEVICES = {-1: "cpu"}
DEVICES.update((i, i) for i in range(10))

def device_fxn(device):
    """
    Converts a device index into a device spec.

    Args:
        device: int
            a device index. Negative values denote the cpu.
    Returns:
        "cpu" if device is negative, otherwise device unchanged
    """
    return "cpu" if device < 0 else device

def choice(seq):
    """
    Uniformly samples a single element from an indexable sequence.
    Intended as a lightweight stand-in for numpy's choice function
    for the simple case of a 1D list.

    Args:
        seq: sequence of objects
            must support len() and integer indexing
    Returns:
        val: object
            the element of seq at a uniformly sampled index
    """
    sampled_idx = np.random.randint(len(seq))
    return seq[sampled_idx]

def pretty_string(arr):
    """
    Formats every element of arr with a minimum width of 2 and joins
    the results with ", ".

    Args:
        arr: iterable of formattable objects
    Returns:
        str
    """
    return ", ".join(f"{item:2}" for item in arr)

def get_counts(src,trigger_ids,eos_id,fill_id=None,last_occurence=False):
    """
    Finds and returns the number of tokens between the last occurrence
    of a trigger id and the first occurrence of the eos_id. Tokens equal
    to fill_id are optionally excluded from the count.

    The final position of every sequence is overwritten with the eos_id
    (on a copy, the argued tensor is not modified) so that every
    sequence contains at least one eos token.

    Args:
        src: torch long tensor (..., S)
        trigger_ids: int or set or list or tuple or Long Tensor (K,)
        eos_id: int
        fill_id: None or int
            an id to ignore when finding the count
        last_occurence: bool
            currently unused. The last occurrence of a trigger id is
            always used, regardless of this value.
    Returns:
        src_counts: torch long tensor (...)
            eos index - trigger index - number of fill tokens - 1
    """
    if isinstance(trigger_ids, int):
        trigger_ids = torch.LongTensor([trigger_ids])
    elif isinstance(trigger_ids, (set, list, tuple)):
        trigger_ids = torch.LongTensor(list(trigger_ids))
    trigger_ids = trigger_ids.to(src.device)

    # Work on a detached copy so the caller's tensor is left untouched
    src = src.detach().clone()
    # Ellipsis indexing keeps this valid for any number of leading dims
    src[..., -1] = eos_id
    # NOTE(review): get_mask_past_id presumably flags the positions that
    # come after the first eos_id; confirm against dl_utils.
    mask = get_mask_past_id(src, id_=eos_id)
    src[mask] = 0
    # argmax over a 0/1 tensor returns the first index equal to 1
    src_eoss = torch.argmax((src==eos_id).long(), dim=-1)
    src_triggers = arglast(
        torch.isin(src, trigger_ids).long(), dim=-1)
    if fill_id is not None:
        # NOTE(review): get_mask_between presumably flags the positions
        # between the trigger index and the eos index; confirm.
        mask = get_mask_between(src.shape,src_triggers,src_eoss)
        src_fills = ((src==fill_id)&mask).long().sum(-1)
    else:
        src_fills = torch.zeros_like(src_triggers)
    src_counts = src_eoss-src_triggers-src_fills-1
    return src_counts

def check_correct_count(
        preds,
        targs,
        trigger_id,
        eos_id,
        sep_id=None,
        fill_id=None,
        last_occurence=False):
    """
    Determines if the predicted count is the same as the target count.

    The predictions are copied and every position that holds a trigger
    id in targs is overwritten with a trigger id in the copy, so that
    both counts start from the same trigger position.

    Args:
        preds: torch tensor (B,S)
        targs: torch tensor (B,S)
        trigger_id: int or list or set or tuple or Long Tensor
            the id of the trigger token. Also can argue a collection
            of trigger ids.
        eos_id: int
            the id of the end of sequence token
        sep_id: int or None
            currently unused by this function.
        fill_id: int or None
            optionally argue an id that should be excluded from calculating
            the count.
        last_occurence: bool
            forwarded to get_counts, which currently ignores it.
    Returns:
        corrects: torch bool tensor (B,)
            a tensor where true means that the predicted count was
            correct and that preds and targs share at least one eos
            position
    """
    targ_counts = get_counts(
        targs, trigger_ids=trigger_id,
        eos_id=eos_id, fill_id=fill_id,
        last_occurence=last_occurence,
    )

    # Normalize the trigger ids to a plain list so that sets, tuples and
    # tensors are all indexable and convertible below.
    if isinstance(trigger_id, int):
        trig_list = [trigger_id]
    elif torch.is_tensor(trigger_id):
        trig_list = trigger_id.reshape(-1).tolist()
    else:
        trig_list = list(trigger_id)

    # Ensure trigger token at correct index in preds
    preds = preds.detach().clone()
    tid = torch.LongTensor(trig_list).to(targs.device)
    idx = torch.isin(targs, tid)
    if trig_list:
        # Any trigger id works here because get_counts matches against
        # the whole collection of trigger ids.
        preds[idx] = trig_list[0]

    pred_counts = get_counts(
        preds, trigger_ids=trig_list,
        eos_id=eos_id, fill_id=fill_id,
        last_occurence=last_occurence,
    )
    eos_match = torch.any((preds==eos_id)&(targs==eos_id),dim=-1)
    corrects = (targ_counts==pred_counts)&eos_match
    return corrects

def get_combos(d, combo=None, keys=None, idx=0, combos=None):
    """
    Recursively finds all combinations of hyperparameters.

    Args:
        d: dict
            maps each hyperparameter name to an iterable of candidate
            values
        combo: dict or None
            the current (partial) combo. Used by the recursion.
        keys: list or None
            the ordered keys of d. Defaults to list(d.keys()).
        idx: int
            index into keys of the next hyperparameter to assign. Used
            by the recursion.
        combos: list or None
            accumulator that completed combos are appended to. Used by
            the recursion.
    Returns:
        combos: list of dicts
            one dict per combination. The first key varies slowest and
            the last key varies fastest.
    """
    if keys is None:
        keys = list(d.keys())
    if combos is None:
        combos = []
    if combo is None:
        combo = dict()
    if idx >= len(keys):
        # Every key has a value assigned, the combo is complete
        combos.append(combo)
        return combos
    key = keys[idx]
    for val in d[key]:
        # A fresh dict per branch keeps sibling branches independent
        get_combos(
            d, combo={key: val, **combo},
            keys=keys, idx=idx+1, combos=combos)
    return combos

def run_for_n_steps(
        causal_model,
        inpt_id,
        n_steps=None,
        stop_ids={0},
        variables=None,
        trigger_steps=None,
        trigger_step=-1,
        trigger_id=None,
        inpt_tmask=0,
        **kwargs):
    """
    Will run the causal model for the argued number of steps.
    Optionally argue a trigger step to intervene on the causal
    model and switch phases.

    Args:
        causal_model: python function
            called as causal_model(inpt_id=..., variables=..., **kwargs)
            and must return a tuple (next_id, variables).
        inpt_id: int
            the first input id
        n_steps: int or None
            the number of steps to run the causal model. If None, runs
            until a stop id is produced, in which case stop_ids must
            not be None.
        stop_ids: None or int or collection of ints
            optionally argue ids that if produced by the model will
            stop the sequence generation. The default set is never
            mutated by this function.
        variables: dict or None
            the initial variables of the causal model. Defaults to an
            empty dict.
        trigger_steps: None or int or sequence of ints
            optionally sample the trigger step. will uniformly sample
            from the sequence of ints or in the range from 0 to the argued
            int (inclusive). If none, leaves trigger_step unchanged.
            arguing something to trigger_steps will overwrite the value
            of trigger_step.
        trigger_step: int or None
            optionally argue the step at which the input id is replaced
            with a trigger id. The replacement only happens when
            variables is non-empty and variables["phase"]==0 at that
            step. If negative or None, will not insert a trigger token.
        trigger_id: int or None
            the desired trigger id. If None, uses variables["trigger_id"]
            when it is contained in kwargs["trigger_ids"], otherwise
            uniformly samples from kwargs["trigger_ids"].
        inpt_tmask: int
            the first value of the trigger mask
    Returns:
        seq: list of ints
            the inputs that were fed to the causal model, including the
            first input token. At most n_steps+1 long. The id that
            stopped the generation is not included.
        tmask: list of ints
            the mask value recorded for every entry of seq. Entries
            after the first are read from variables["mask"] (default 1)
            and inserted trigger tokens get 0.
        vbls: list of dicts
            a shallow copy of the variables after every step
    """
    if n_steps is None:
        # Without a step budget, a stop id is the only way to terminate
        assert stop_ids is not None
        n_steps = np.inf
    # Normalize stop_ids so that the membership test below always works
    if stop_ids is None:
        stop_ids = ()
    elif isinstance(stop_ids, (int, np.integer)):
        stop_ids = (stop_ids,)
    if trigger_step is None:
        trigger_step = -1
    if trigger_steps is not None:
        if isinstance(trigger_steps, (int, np.integer)):
            trigger_step = int(np.random.randint(trigger_steps+1))
        else:
            trigger_step = choice(list(trigger_steps))
    if variables is None:
        variables = dict()
    seq = []
    tmask = []
    vbls = []
    i = -1
    # i runs from 0 to n_steps inclusive, hence up to n_steps+1 entries
    while i < n_steps:
        i += 1
        if variables and variables["phase"]==0 and i==trigger_step:
            if trigger_id is None:
                trigs = kwargs["trigger_ids"]
                if "trigger_id" in variables and\
                        variables["trigger_id"] in trigs:
                    trigger_id = variables["trigger_id"]
                else:
                    trigger_id = trigs[np.random.randint(len(trigs))]
            # Intervene: the trigger replaces the model's own output
            inpt_id = trigger_id
            inpt_tmask = 0
        seq.append(inpt_id)
        tmask.append(inpt_tmask)
        inpt_id, variables = causal_model(
            inpt_id=inpt_id,
            variables=variables,
            **kwargs)
        # Copy so that later in-place updates don't alter the history
        vbls.append({**variables})
        inpt_tmask = variables.get("mask",1)
        if inpt_id in stop_ids:
            break
    return seq, tmask, vbls

def run_til_idx(causal_model, seq, idx=None, **kwargs):
    """
    Will run the causal model on the sequence until and
    including the argued idx in the argued seq. This allows us
    to find the high level variables at the argued index.

    Whenever an input token is one of the trigger ids, the previously
    recorded output is overwritten with that trigger id (if there is a
    previous output).

    Args:
        causal_model: python function
            called as causal_model(inpt_id=..., variables=..., **kwargs)
            and must return a tuple (output_id, variables). The first
            call receives variables=None.
        seq: list of ints
            sequence to use as input to the causal model.
        idx: int or None
            index after which to stop running the causal model. The
            returned variables will include computations from the argued
            idx. Optionally argue None to go the full length of the
            argued seq.
        kwargs:
            forwarded to the causal model. "trigger_ids" is also read
            directly by this function and is therefore required.
            pad_id: int
               the padding token id
            bos_id: int
               the beginning of sequence token id
            eos_id: int
               the end of sequence token id
            resp_id: int
               the response id
            trigger_ids: list of ints
               all possible trigger ids. If this list is greater than 1
               which will occur if using multi_trigger setting, the
               order of the list will correspond to the order of the
               demo type ids that each trigger refers to.
            demo_ids: list of ints
                the possible demonstration token ids. The length of this
                list will be equal to n_demo_types and the order will
                correspond to the appropriate trigger_id index if using
                multi_trigger setting.
            pre_trigger: bool
                if true, will indicate at the beginning of the demonstration
                phase which token type to count. Otherwise indicates at the
                beginning of the response phase. Only applies if multi_trigger
                is true.
    Returns:
        var_seq: list of dicts
            a shallow copy of the variables after every step. Will be
            of length min(idx+1, len(seq))
        output_seq: list of ints
            the output of the causal model. Will be of length
            min(idx+1, len(seq))
    """
    variables = None
    if idx is None:
        idx = len(seq)-1
    output_seq = []
    var_seq = []
    for i in range(min(idx+1,len(seq))):
        inpt_id = seq[i]
        out_id, variables = causal_model(
            inpt_id=inpt_id,
            variables=variables,
            **kwargs)
        # A trigger at the very first position has no previous output
        # to overwrite, so the overwrite is skipped in that case.
        if inpt_id in kwargs["trigger_ids"] and output_seq:
            output_seq[-1] = inpt_id
        output_seq.append(out_id)
        # Snapshot so in-place updates by the model don't rewrite history
        var_seq.append({**variables})
    return var_seq, output_seq

def print_raw_tensors(sep=",", **tensors):
    """
    Prints 1d tensors as aligned rows, preceded by a row of indices.
    Every row starts with its name, padded to the length of the longest
    name, followed by a colon.

    Args:
        sep: str
            the string placed between the printed values
        tensors: dict
            tensor_name: tensor (S,)
    """
    idx_name = "Idxs"
    longest = max((len(t) for t in tensors.values()), default=0)
    # Pad every row label to a shared width so that the values line up
    width = max([len(idx_name)] + [len(k) for k in tensors])
    print(f"{idx_name:{width}}:",
        sep.join(["{:2}".format(i) for i in range(longest)]))
    for name,tensor in tensors.items():
        print(f"{name:{width}}:",
            sep.join(["{:2}".format(t) for t in tensor.cpu().tolist()]))

def print_tensor(tensor, n=2):
    """
    Prints the values of a 1d tensor on a single comma separated line.

    Args:
        tensor: tensor (S,)
        n: int
            the minimum width that each value is padded to
    """
    print(",".join([f"{val:{n}}" for val in tensor.tolist()]))

def print_tensors(targ, pred, mask, tokenizer):
    """
    Prints the targets, the predictions and the mask as three aligned
    rows and returns the printed text.

    Each target and prediction entry is shown as the first two
    characters of its token string (or of its id if no tokenizer is
    argued). The first entry of the mask is skipped.

    args:
        targ: tensor (S,)
        pred: tensor (S,)
        mask: tensor (S,)
        tokenizer: Tokenizer or None
            if not None, must provide `pad_token` and `ids_to_toks`.
            Occurrences of the pad token in each row are replaced
            with "_".
    Returns:
        logstr: str
            the three printed rows, each terminated by a newline
    """
    # Only touch the tokenizer if there is one, the id-based fallback
    # below must work without it.
    pad = None if tokenizer is None else tokenizer.pad_token

    def _cells(ids):
        # Two character label for every id
        if tokenizer is not None:
            return [t[:2] for t in tokenizer.ids_to_toks(ids=ids)[0]]
        return [str(t)[:2] for t in ids]

    def _emit(label, cells):
        # Formats, prints and returns a single row
        row = label + " ".join([f"{c:2}" for c in cells])
        # NOTE(review): the replacement runs after the tokens were cut
        # to two characters, so a longer pad token never matches.
        if pad:
            row = row.replace(pad, "_")
        print(row)
        return row + "\n"

    logstr = _emit("Trg: ", _cells(targ.cpu().data.tolist()))
    logstr += _emit("Prd: ", _cells(pred.cpu().data.tolist()))
    msk = [str(m) for m in mask[1:].cpu().data.long().tolist()]
    logstr += _emit("Msk: ", msk)
    return logstr

