"""
This task is defined by a sequence of arithmetic operations. You can
choose which operations are permitted out of +, -, *, and /.

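The sequences take one of the following two forms. In case 1, the
running result is reported after each `=` and the operations that
follow the comma continue from that result. In case 2, only the final
result is reported. Following PEMDAS is optional; both examples apply
the operations strictly left to right (i.e. they ignore PEMDAS).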
    1:
        1 + 3 - 2 = 2 , * 4 + 5 = 13 , + 10 = 23

    or

    2:
        1 + 3 - 2 * 4 + 5 + 10 = 23
"""
import numpy as np
try:
    from .utils import sep_digit_fxn
except:
    from utils import sep_digit_fxn

def _choose_op(candidates, purpose):
    """
    Uniformly samples one operation from `candidates`.

    Args:
        candidates: list of str
            the operations that are allowed in the current situation
        purpose: str
            short description of the situation, only used to build the
            error message
    Returns:
        op: str
    Raises:
        ValueError: if `candidates` is empty
    """
    if not candidates:
        raise ValueError(
            "arithmetic: `ops` contains no operation that can be used "
            + purpose)
    return candidates[np.random.randint(len(candidates))]


def _number_tokens(value, sep_digits, reverse_digits, numeral_base):
    """
    Converts one number into the list of tokens used in the sequence.

    Args:
        value: int or float
        sep_digits: bool
            if true, the tokens are whatever `sep_digit_fxn` returns for
            the value, otherwise a single token `str(value)` is used.
        reverse_digits: bool
            forwarded to `sep_digit_fxn`
        numeral_base: int
            forwarded to `sep_digit_fxn` as `base`
    Returns:
        tokens: list of str
    """
    if sep_digits:
        return sep_digit_fxn(
            value, reverse_digits=reverse_digits, base=numeral_base)
    return [str(value)]


def arithmetic(
        max_val=100,
        min_val=-100,
        max_new=9,
        n_ops=10,
        sep_every=None,
        blank_state=False,
        use_pemdas=False,
        ops={"+", "-", "*"}, # available: +, -, *, /
        round_to=3,
        sep_digits=True,
        reverse_digits=False,
        numeral_base=10,
    ):
    """
    Creates a single arithmetic sequence. The sequence starts with the
    number of operations followed by the token "ops" and a random
    starting value, continues with randomly sampled (operation, operand)
    pairs and ends with "=" and the final result. If the result is a
    float (i.e. division was used), reported results are rounded to
    `round_to` decimals.

    Args:
        max_val: int (inclusive)
            the maximum allowed cumulative value of the arithmetic.
            Values below 0 are treated as 0.
        min_val: int (inclusive)
            the minimum allowed cumulative value of the arithmetic.
            Values above 0 are treated as 0.
        max_new: int (exclusive)
            the maximum primitive to be used in the sequence. i.e. the
            sequence will never sample an operand larger than max_new.
            Reported cumulative values (see sep_every) can be larger.
        n_ops: int
            the number of operations in the sequence.
        sep_every: None or int
            if None (or 0), the cumulative value will not be reported
            till the very end of the sequence. If an int is argued, the
            cumulative value will be printed every sep_every operations.
        blank_state: bool
            if true, will be the same as using `sep_every` except that
            the printed cumulative state will be an empty, blank
            character and the tmask will be zero. This allows us to
            control for the possibility that merely having extra spaces
            to "reason" is advantageous for transformers.
        use_pemdas: bool
            if true, each subsegment will follow the conventional
            order of operations PEMDAS, otherwise will simply apply
            each successive operation to the cumulative value.
        ops: set or sequence of str
            the potential math operations. A set is sorted before use
            so that a seeded numpy RNG yields the same sequence in every
            interpreter run; lists and tuples keep the argued order.
        round_to: int
            number of decimals that float results are rounded to.
        sep_digits: bool
            if true, will always separate individual digits as individual
            elements in the list. Otherwise keeps numerals together as
            a single token. i.e. for the number 102, we can do
                sep_digits==True: ["1","0","2"]
                sep_digits==False: ["102"]
        reverse_digits: bool
            if true and sep_digits is true, will reverse the ordering
            of the digits so that the ones place comes first, followed
            by 10s, etc. i.e. for 102, we will get ["2", "0", "1"]
        numeral_base: int
            optionally change the base of the number system
    Returns:
        seq: list of str
            the values and the operations
        task_mask: list of int
            a mask where 0 denotes randomly sampled values and 1s
            denote predictable values.
        info: dict
            the special tokens used in the sequence under the keys
            "op_ids", "equals_id", "comma_id" and "blank_id".
    Raises:
        ValueError: if an operation is needed but `ops` offers none that
            fits, e.g. the value sits at max_val and neither "-" nor "/"
            is permitted.
    """
    max_val = max(0,max_val)
    min_val = min(0,min_val)
    # Sets iterate in hash order, which differs between interpreter
    # runs for strings. Sorting keeps seeded sampling reproducible.
    if isinstance(ops, (set, frozenset)):
        ops = sorted(ops)
    else:
        ops = list(ops)
    comma = ","
    equals = "="
    blank = "?"
    info = {
        "op_ids": list(ops),
        "equals_id": equals,
        "comma_id": comma,
        "blank_id": blank,
    }
    # Loop invariant: the operations usable when the value sits on a bound
    reducing_ops = [o for o in ops if o in {"-","/"}]
    increasing_ops = [o for o in ops if o in {"+","/"}]
    # Division must never sample 0 as the divisor
    divisor_high = max(2,max_new)

    cumu_val = np.random.randint(max_new)
    seq = _number_tokens(n_ops, sep_digits, reverse_digits, numeral_base)
    seq += ["ops", str(cumu_val)]
    task_mask = [0 for _ in seq] # tracks which indices should be predictable
    # Expression string of the current segment, only evaluated when
    # use_pemdas is true. All evaluated strings are built in this function
    # from sampled numbers and the argued `ops`.
    pemdas_str = seq[-1]
    op_count = 0
    # NOTE: when use_pemdas is true, cumu_val is only refreshed when a
    # running total is reported and at the very end, so the bound checks
    # below use the value from the last refresh.
    for i in range(n_ops):
        # Value needs to be reduced
        if cumu_val>=max_val:
            op = _choose_op(
                reducing_ops,
                "to reduce a value that reached max_val")
            if op=="/":
                new_val = np.random.randint(1,divisor_high)
            else:
                new_val = np.random.randint(max_new)
        # Value needs to be increased
        elif cumu_val<=min_val:
            op = _choose_op(
                increasing_ops,
                "to increase a value that reached min_val")
            if op=="/":
                new_val = np.random.randint(1,divisor_high)
            else:
                new_val = np.random.randint(max_new)
        # Value can be whatever
        else:
            op = _choose_op(ops, "at all (it is empty)")
            if op=="*":
                # Largest factor that keeps the product within bounds
                if cumu_val<0:
                    temp = int(min(np.abs(min_val/cumu_val),max_new))
                elif cumu_val==0:
                    temp = max_new
                else:
                    temp = int(min(max_val/cumu_val,max_new))
                new_val = np.random.randint(1,max(2,temp))
            elif op=="+":
                # A float value closer than 1 to max_val truncates the
                # headroom to 0, so keep the exclusive bound at least 1
                headroom = int(min(max_new,max_val-cumu_val))
                new_val = np.random.randint(max(1,headroom))
            elif op=="-":
                headroom = np.abs(cumu_val-min_val)
                new_val = np.random.randint(int(
                    min(max_new,max(headroom,1))))
            else:
                new_val = np.random.randint(1,divisor_high)
        if not use_pemdas:
            cumu_val = eval(str(cumu_val)+op+str(new_val))
        pemdas_str += op+str(new_val)
        seq.append(op)
        op_count += 1

        # The operation and the operand are sampled, thus unpredictable
        operand_tokens = _number_tokens(
            new_val, sep_digits, reverse_digits, numeral_base)
        seq += operand_tokens
        task_mask += [0]+[0 for _ in operand_tokens]

        # Output a running total
        if sep_every and op_count==sep_every and i<n_ops-1:
            op_count = 0
            seq.append(equals)
            task_mask.append(0)
            if use_pemdas:
                cumu_val = eval(pemdas_str)
                pemdas_str = str(cumu_val)
            if isinstance(cumu_val, float):
                cumu_val = round(cumu_val, round_to)
            if blank_state:
                seq.append(blank)
                task_mask.append(0)
            else:
                total_tokens = _number_tokens(
                    cumu_val, sep_digits, reverse_digits, numeral_base)
                seq += total_tokens
                task_mask += [1 for _ in total_tokens]
            seq.append(comma)
            task_mask.append(int(not blank_state))
    seq.append(equals)
    task_mask.append(0)
    if use_pemdas: cumu_val = eval(pemdas_str)
    if isinstance(cumu_val, float): cumu_val = round(cumu_val, round_to)
    result_tokens = _number_tokens(
        cumu_val, sep_digits, reverse_digits, numeral_base)
    seq += result_tokens
    task_mask += [1 for _ in result_tokens]
    return seq, task_mask, info

def sample_arithmetic(
        max_val=100,
        min_val=-100,
        max_new=10,
        max_ops=20,
        min_ops=1,
        sep_every=None,
        use_pemdas=False,
        ops={"+", "-", "*"}, # available: +, -, *, /
        round_to=3,
        hold_outs=None,
        sep_digits=True,
        reverse_digits=False,
        n_ops=None,
        blank_state=False,
        *args,**kwargs):
    """
    Use this function to sample an arithmetic sequence. The number of
    operations is sampled uniformly from [min_ops, max_ops] unless
    `n_ops` is argued, everything else is forwarded to `arithmetic`.

    Args:
        max_val: int (inclusive)
            the maximum allowed cumulative value of the arithmetic.
        min_val: int (inclusive)
            the minimum allowed cumulative value of the arithmetic.
        max_new: int (exclusive)
            the maximum primitive to be used in the sequence.
        max_ops: int (inclusive)
            the maximum number of operations allowed in a single sequence
        min_ops: int (inclusive)
            the minimum number of operations allowed in a single sequence
        sep_every: None or int
            if an int is argued, the cumulative value will be printed
            every sep_every operations.
        use_pemdas: bool
            if true, each subsegment follows the conventional order of
            operations.
        ops: set of str
            the set of potential math operations.
        round_to: int
            number of decimals that float results are rounded to.
        hold_outs: any
            accepted for interface compatibility, not used by this
            function.
        sep_digits: bool
            if true, will always separate individual digits as individual
            elements in the list. Otherwise keeps numerals together as
            a single token. i.e. for the number 102, we can do
                sep_digits==True: ["1","0","2"]
                sep_digits==False: ["102"]
        reverse_digits: bool
            if true and sep_digits is true, will reverse the ordering
            of the digits so that the ones place comes first, followed
            by 10s, etc. i.e. for 102, we will get ["2", "0", "1"]
        n_ops: none or int
            optionally argue the number of operations instead of
            sampling it.
        blank_state: bool
            if true, will be the same as using `sep_every` except that
            the printed cumulative state will be an empty, blank
            character and the tmask will be zero. This allows us to
            control for the possibility that merely having extra spaces
            to "reason" is advantageous for transformers.
        *args, **kwargs:
            extra arguments are accepted so that shared config dicts can
            be splatted into this function. Only `numeral_base` (int,
            defaults to 10) is read from kwargs, the rest is ignored.
    Returns:
        seq: list of str
            the sequence of strings
        task_mask: list of ints
            0s denote unpredictable tokens, 1s denote predictable tokens
        info: dict
            the special tokens reported by `arithmetic`
    """
    if n_ops is None:
        n_ops = np.random.randint(min_ops, max_ops+1)
    return arithmetic(
        max_val=max_val,
        min_val=min_val,
        max_new=max_new,
        n_ops=n_ops,
        sep_every=sep_every,
        blank_state=blank_state,
        use_pemdas=use_pemdas,
        ops=ops,
        round_to=round_to,
        sep_digits=sep_digits,
        reverse_digits=reverse_digits,
        # Previously swallowed by **kwargs, so the base was always 10
        numeral_base=kwargs.get("numeral_base", 10),
    )


if __name__=="__main__":
    # Small demo: print ten sampled sequences above their task masks.
    demo_settings = dict(
        max_val=20,
        min_val=0,
        max_new=10,
        max_ops=20,
        min_ops=1,
        sep_every=0,
        blank_state=True,
        use_pemdas=False,
        ops={"+", "-", "*"}, # available: +, -, *, /
        round_to=3,
    )
    for _ in range(10):
        arith, mask, info = sample_arithmetic(n_ops=5, **demo_settings)
        # Only the first character of every token is shown so that the
        # tokens line up with the single digit mask entries below them
        token_row = " ".join(str(token)[:1] for token in arith)
        mask_row = " ".join(str(flag) for flag in mask)
        print(token_row)
        print(mask_row)
        print()


