from collections import deque
from heapq import (
    heapify,
    heappush,
    heappop,
)
import numpy as np
import torch
from htssr.primitives import (
    vocab,
    bijection,
    functions,
    tc_functions,
    constants,
    variables,
    special_id,
    special_parameter_id,
    special_symbol,
    arities,
    uops,
    bops,
    latexes,
    nesting_rules,
    complexity_rules,
)

def to_ids(tokens):
    """Look up every token of `tokens` in `bijection` and return the ids as a list."""
    ids = []
    for token in tokens:
        ids.append(bijection[token])
    return ids

def to_tokens(ids):
    """Look up every id of `ids` in `bijection` and return the tokens as a list."""
    tokens = []
    for token_id in ids:
        tokens.append(bijection[token_id])
    return tokens

def is_valid(ids):
    """
    Check that `ids` is a well-formed prefix expression.

    ids: expression (prefix) ids

    The sequence is scanned right-to-left while counting how many complete
    sub-expressions are available (the size an evaluation stack would have).
    Returns True only if every operator finds enough operands and exactly
    one expression is left at the end.
    """
    available = 0
    for idx in reversed(ids):
        arity = arities[idx]
        # An operator consumes `arity` operands. Checking only the final
        # balance would accept sequences such as "<leaf> <unary-op>", whose
        # balance is 1 although the operator has nothing to act on.
        if available < arity:
            return False
        available += 1 - arity
    return available == 1

# BUG
def to_infix(ids):
    """
    Render a prefix expression as an infix string.

    ids: expression (prefix) ids

    Unary tokens become "op(arg)", binary tokens "(a) op (b)" where the
    padding around the operator grows with log2 of the longer operand's
    length. Returns None when a token has an arity other than 0, 1 or 2.
    """
    operands = []
    for idx in reversed(ids):
        arity = arities[idx]
        if arity == 0:
            operands.append(bijection[idx])
            continue
        if arity == 1:
            arg = operands.pop()
            operands.append(f"{bijection[idx]}({arg})")
            continue
        if arity != 2:
            return None
        first = operands.pop()
        second = operands.pop()
        width = int(np.log2(max(len(first), len(second)))) + 1
        pad = " " * width
        operands.append(f"({first}){pad}{bijection[idx]}{pad}({second})")
    assert len(operands) == 1
    return operands[0]

def to_infix_params(ids, params):
    """
    Render a prefix expression as an infix string, inlining parameter values.

    ids: expression (prefix) ids
    params: one value per occurrence of `special_parameter_id` in `ids`;
            each is read with `.item()` and printed with 7 decimals

    Parameters are consumed in the order the tokens are visited, i.e. while
    scanning `ids` from its last position to its first. Returns None when a
    token has an arity other than 0, 1 or 2.
    """
    n_special = sum(_id == special_parameter_id for _id in ids)
    assert n_special == len(params)
    operands = []
    next_param = 0
    for idx in reversed(ids):
        if idx == special_parameter_id:
            operands.append(f"{params[next_param].item():.7f}")
            next_param += 1
            continue
        arity = arities[idx]
        if arity == 0:
            operands.append(bijection[idx])
            continue
        if arity == 1:
            arg = operands.pop()
            operands.append(f"{bijection[idx]}({arg})")
            continue
        if arity != 2:
            return None
        first = operands.pop()
        second = operands.pop()
        width = int(np.log2(max(len(first), len(second)))) + 1
        pad = " " * width
        operands.append(f"({first}){pad}{bijection[idx]}{pad}({second})")
    assert len(operands) == 1
    return operands[0]

def to_latex(ids):
    """
    Render a prefix expression as a LaTeX math string wrapped in "$...$".

    ids: expression (prefix) ids

    Each token's renderer is taken from `latexes` and called with its
    already-rendered operands. Returns None when a token has an arity other
    than 0, 1 or 2.
    """
    pieces = []
    for idx in reversed(ids):
        arity = arities[idx]
        if arity == 0:
            rendered = latexes[idx]()
        elif arity == 1:
            arg = pieces.pop()
            rendered = latexes[idx](arg)
        elif arity == 2:
            first = pieces.pop()
            second = pieces.pop()
            rendered = latexes[idx](first, second)
        else:
            return None
        pieces.append(rendered)
    assert len(pieces) == 1
    return f"${pieces[0]}$"

def _fast_eval_expr(ids, x_data):
    """
    Evaluate a single-variable prefix expression on `x_data`.

    ids: expression (prefix) ids
    x_data: array of values substituted for every leaf whose id is not 0 or 1

    Leaves with id 0 or 1 are pushed as the id itself. The result is
    broadcast to `x_data.shape` by adding an array of zeros. Returns None
    when a token has an arity other than 0, 1 or 2.
    """
    values = []
    for idx in reversed(ids):
        arity = arities[idx]
        if arity == 0:
            values.append(idx if idx in [0, 1] else x_data)
        elif arity == 1:
            arg = values.pop()
            values.append(functions[idx](arg))
        elif arity == 2:
            first = values.pop()
            second = values.pop()
            values.append(functions[idx](first, second))
        else:
            return None
    assert len(values) == 1
    return values[0] + np.zeros(x_data.shape)

def fast_eval_expr(ids, var_data):
    """
    Evaluate a prefix expression with the callables in `functions`.

    ids: expression (prefix) ids
    var_data: dict var_token -> var_values

    The result is broadcast to the length of `var_data[special_symbol]` by
    adding an array of zeros. Returns None when a token that is neither a
    constant nor a variable has an arity other than 1 or 2.
    """
    values = []
    for idx in reversed(ids):
        if idx in constants:
            values.append(functions[idx]())
        elif idx in variables:
            values.append(var_data[bijection[idx]])
        else:
            arity = arities[idx]
            if arity == 1:
                arg = values.pop()
                values.append(functions[idx](arg))
            elif arity == 2:
                first = values.pop()
                second = values.pop()
                values.append(functions[idx](first, second))
            else:
                return None
    assert len(values) == 1
    return values[0] + np.zeros(len(var_data[special_symbol]))

def rolling_fast_eval_expr(ids, var_data):
    """
    Evaluate a prefix expression that may contain the special parameter.

    ids: expression (prefix) ids
    var_data: dict var_token -> var_values

    The special parameter is produced by its `tc_functions` entry; every
    other constant and all operators come from `functions`. The result is
    broadcast to the length of `var_data[special_symbol]`. Returns None when
    a token that is neither a constant nor a variable has an arity other
    than 1 or 2.
    """
    values = []
    for idx in reversed(ids):
        if idx in constants:
            table = tc_functions if idx == special_parameter_id else functions
            values.append(table[idx]())
        elif idx in variables:
            values.append(var_data[bijection[idx]])
        else:
            arity = arities[idx]
            if arity == 1:
                arg = values.pop()
                values.append(functions[idx](arg))
            elif arity == 2:
                first = values.pop()
                second = values.pop()
                values.append(functions[idx](first, second))
            else:
                return None
    assert len(values) == 1
    return values[0] + np.zeros(len(var_data[special_symbol]))

def rolling_tree_eval_expr(ids, var_data):
    """
    Evaluate a prefix expression and keep the value of every operand.

    ids: expression (prefix) ids
    var_data: dict var_token -> var_values

    Each time an operator is applied its operand values are recorded; the
    value of the whole expression is recorded last. The records are then
    reversed, so row 0 of the returned array is the value of the full
    expression. Every row is broadcast to the length of
    `var_data[special_symbol]`. Returns None when a token that is neither a
    constant nor a variable has an arity other than 1 or 2.
    """
    values = []
    recorded = []
    for idx in reversed(ids):
        if idx in constants:
            table = tc_functions if idx == special_parameter_id else functions
            values.append(table[idx]())
        elif idx in variables:
            values.append(var_data[bijection[idx]])
        else:
            arity = arities[idx]
            if arity == 1:
                arg = values.pop()
                recorded.append(arg)
                values.append(functions[idx](arg))
            elif arity == 2:
                first = values.pop()
                second = values.pop()
                recorded.append(first)
                recorded.append(second)
                values.append(functions[idx](first, second))
            else:
                return None
    assert len(values) == 1
    recorded.append(values[0])
    recorded.reverse()
    zeros = np.zeros(len(var_data[special_symbol]))
    return np.array([value + zeros for value in recorded])

def rolling_check_determined_expr(ids, var_data):
    """
    Evaluate a prefix expression and describe the root's direct operands.

    ids: expression (prefix) ids
    var_data: dict var_token -> var_values

    Every stack entry is a tuple (value, flag, first_pos, last_pos). The
    flag is False for the special parameter and for the variable whose id is
    `special_id`, True for any other constant or variable; a unary operator
    inherits its operand's flag and a binary operator takes the `and` of
    both. The two positions delimit the sub-expression inside `ids`.

    Returns (ans, checks, subtrees_pos, subtrees_evals) where `ans` is the
    value of the whole expression broadcast to the length of
    `var_data[special_symbol]`, and the other three hold the flag, the
    position span and the value of each operand of the token at position 0.
    Returns None on an unsupported arity; the asserts fail if the token at
    position 0 is not an operator.
    """
    stack = []
    checks = subtrees_pos = subtrees_evals = None
    for pos in reversed(range(len(ids))):
        idx = ids[pos]
        if idx in constants:
            if idx == special_parameter_id:
                entry = (tc_functions[idx](), False, pos, pos)
            else:
                entry = (functions[idx](), True, pos, pos)
            stack.append(entry)
        elif idx in variables:
            stack.append((var_data[bijection[idx]], (idx != special_id), pos, pos))
        else:
            arity = arities[idx]
            if arity == 1:
                value, flag, first, last = stack.pop()
                if pos == 0:
                    checks = (flag,)
                    subtrees_pos = ((first, last),)
                    subtrees_evals = (value,)
                stack.append((functions[idx](value), flag, pos, last))
            elif arity == 2:
                value0, flag0, first0, last0 = stack.pop()
                value1, flag1, first1, last1 = stack.pop()
                if pos == 0:
                    checks = (flag0, flag1)
                    subtrees_pos = ((first0, last0), (first1, last1))
                    subtrees_evals = (value0, value1)
                combined = functions[idx](value0, value1)
                stack.append((combined, flag0 and flag1, pos, max(last0, last1)))
            else:
                return None
    assert len(stack) == 1
    assert checks is not None
    assert subtrees_pos is not None
    assert subtrees_evals is not None
    ans = stack[0][0] + np.zeros(len(var_data[special_symbol]))
    return ans, checks, subtrees_pos, subtrees_evals

def tc_fast_eval_expr(ids, var_data, params):
    """
    Evaluate a prefix expression with the torch callables in `tc_functions`.

    ids: expression (prefix) ids
    var_data: dict var_token -> var_values
    params: torch.Tensor (requires_grad=True), one entry per occurrence of
            `special_parameter_id` in `ids`

    Parameters are consumed while scanning `ids` from its last position to
    its first. Variable values are converted with `torch.from_numpy(...)`
    and cast to float. The result is broadcast to the length of
    `var_data[special_symbol]`. Returns None on an unsupported arity.
    """
    n_special = sum(_id == special_parameter_id for _id in ids)
    assert n_special == len(params)
    values = []
    next_param = 0
    for idx in reversed(ids):
        if idx == special_parameter_id:
            values.append(params[next_param])
            next_param += 1
        elif idx in constants:
            values.append(tc_functions[idx]())
        elif idx in variables:
            as_tensor = torch.from_numpy(var_data[bijection[idx]]).float()
            values.append(as_tensor)
        else:
            arity = arities[idx]
            if arity == 1:
                arg = values.pop()
                values.append(tc_functions[idx](arg))
            elif arity == 2:
                first = values.pop()
                second = values.pop()
                values.append(tc_functions[idx](first, second))
            else:
                return None
    assert len(values) == 1
    return values[0] + torch.zeros(len(var_data[special_symbol]))

def get_tree_parents(ids):
    """
    Compute, for every position of `ids`, the position of its parent token.

    ids: expression (prefix) ids

    The root keeps the initial value 0. Returns None when a token has an
    arity other than 0, 1 or 2, or when the sequence does not reduce to a
    single expression.
    """
    parents = [0] * len(ids)
    pending = []
    for pos in reversed(range(len(ids))):
        arity = arities[ids[pos]]
        if arity == 1:
            child = pending.pop()
            parents[child] = pos
        elif arity == 2:
            first = pending.pop()
            second = pending.pop()
            parents[first] = pos
            parents[second] = pos
        elif arity != 0:
            return None
        pending.append(pos)
    return parents if len(pending) == 1 else None

def _ids2key(ids, domain):
    """
    Build a key from the expression's values at five sample positions.

    The expression is evaluated on `domain[[0, 10, 20, 30, 40]]` and the
    values, rounded to 6 decimals, are returned as a tuple.
    NOTE(review): `fast_eval_expr` documents its second argument as a dict,
    while this list-indexing looks like array indexing - confirm that this
    helper is still usable with the current `fast_eval_expr`.
    """
    samples = domain[[0, 10, 20, 30, 40]]
    values = fast_eval_expr(ids, samples)
    return tuple(np.round(values, decimals=6))

def vals2key(vals):
    """
    Build an integer key from seven sampled entries of `vals`.

    The entries at indices -30, -20, -10, 0, 10, 20 and 30 are rounded to
    6 decimals, scaled by 1e6 and converted with `int`. `vals` must support
    indexing with a list of indices.
    """
    sampled = vals[[-30, -20, -10, 0, 10, 20, 30]]
    scaled = 1e6 * np.round(sampled, decimals=6)
    return tuple(map(int, scaled))

def ids2key(ids, domain):
    """
    Build an integer key from the expression's values on `domain`.

    The expression is evaluated with `fast_eval_expr`; every value is
    rounded to 6 decimals, scaled by 1e6 and converted with `int`.
    """
    values = fast_eval_expr(ids, domain)
    scaled = 1e6 * np.round(values, decimals=6)
    return tuple(map(int, scaled))

def size_lex_order(ids0, ids1):
    """
    Return True when `ids0` comes strictly before `ids1`: shorter sequences
    first, sequences of equal length compared with `<`.
    """
    len0, len1 = len(ids0), len(ids1)
    if len0 != len1:
        return len0 < len1
    return bool(ids0 < ids1)

def size_order(ids0, ids1):
    """Return True when `ids0` is not longer than `ids1`."""
    return len(ids0) <= len(ids1)

from random import randint

def pick_op(ops):
    """Return one element of the sequence `ops`, chosen with `randint`."""
    last = len(ops) - 1
    chosen = randint(0, last)
    return ops[chosen]

### The code below needs more testing

def str2ids(expr):
    """Split `expr` on whitespace and convert the tokens with `to_ids`."""
    tokens = expr.split()
    return to_ids(tokens)

def ids2tree(ids):
    """
    Convert a prefix expression into nested tuples.

    ids: expression (prefix) ids

    A node is `(idx,)`, `(idx, child)` or `(idx, child0, child1)`; children
    keep the order they have in `ids` (no reordering for commutative
    operators). Returns None on an arity other than 0, 1 or 2, or when the
    sequence does not reduce to a single tree.
    """
    nodes = []
    for idx in reversed(ids):
        arity = arities[idx]
        if arity == 0:
            node = (idx,)
        elif arity == 1:
            node = (idx, nodes.pop())
        elif arity == 2:
            first = nodes.pop()
            second = nodes.pop()
            node = (idx, first, second)
        else:
            return None
        nodes.append(node)
    return nodes[0] if len(nodes) == 1 else None

def crude_ids2tree(ids):
    """
    Convert a prefix expression into nested tuples, children in source order.

    ids: expression (prefix) ids

    Returns None on an arity other than 0, 1 or 2, or when the sequence does
    not reduce to a single tree.
    """
    built = []
    for position in reversed(range(len(ids))):
        token = ids[position]
        n_children = arities[token]
        if n_children == 0:
            built.append((token,))
            continue
        if n_children == 1:
            only = built.pop()
            built.append((token, only))
            continue
        if n_children != 2:
            return None
        left = built.pop()
        right = built.pop()
        built.append((token, left, right))
    if len(built) == 1:
        return built[0]
    return None

def tree2ids(t):
    """
    Flatten a nested-tuple tree back into its list of prefix ids.

    t: node of the form `(idx, *children)`
    """
    ids = []
    pending = [t]
    while pending:
        node = pending.pop()
        ids.append(node[0])
        # Push children right-to-left so the leftmost is emitted next
        pending.extend(reversed(node[1:]))
    return ids

def tree_complexity(tree):
    """
    Number of transformations from "x" to build the tree.

    An empty tree and a node whose id is `special_id` count 0 (their
    children are not visited); any other node counts 1 plus the complexity
    of its children.
    """
    if len(tree) == 0 or tree[0] == special_id:
        return 0
    return 1 + sum(tree_complexity(child) for child in tree[1:])

def _generation_dist(s, t):
    """
    Distance between two trees (as produced by `ids2tree`).

    When the roots differ the distance is the sum of both tree complexities.
    When they match, the children are compared position by position (first
    with first, second with second - operands are never swapped) and the
    distances are added. Matching roots that are neither in `uops` nor in
    `bops` give 0.
    """
    root = s[0]
    if root != t[0]:
        return tree_complexity(s) + tree_complexity(t)
    if root in uops:
        return _generation_dist(s[1], t[1])
    if root in bops:
        first = _generation_dist(s[1], t[1])
        second = _generation_dist(s[2], t[2])
        return first + second
    return 0

def generation_dist(s, t):
    """Convert both prefix id sequences with `ids2tree` and return their `_generation_dist`."""
    tree_s = ids2tree(s)
    tree_t = ids2tree(t)
    return _generation_dist(tree_s, tree_t)

def _nesting_levels(t, ops):
    if len(t) == 0:
        return 0
    levels = (
        int(t[0] in ops)
        + max(
            [_nesting_levels(child, ops) for child in t[1:]] + [0]
        )
    )
    return levels

def nesting_levels(ids, ops):
    """Build the tree of `ids` with `ids2tree` and return its `_nesting_levels` for `ops`."""
    return _nesting_levels(ids2tree(ids), ops)

def root_check_nesting(ids, nesting_rules):
    """
    Check the nesting limit of the root operator only.

    ids: expression (prefix) ids
    nesting_rules: mapping operator id -> maximum allowed nesting level

    Returns True when the root has no rule, otherwise whether the root
    operator's nesting level in `ids` does not exceed its limit.
    """
    root = ids[0]
    if root not in nesting_rules:
        return True
    return nesting_levels(ids, [root]) <= nesting_rules[root]

def complete_check_nesting(ids, nesting_rules):
    """
    Check the nesting limit of every restricted operator present in `ids`.

    ids: expression (prefix) ids
    nesting_rules: mapping operator id -> maximum allowed nesting level

    Returns False as soon as one operator is nested deeper than allowed.
    """
    restricted = set(ids).intersection(nesting_rules.keys())
    return all(
        not (nesting_levels(ids, [op]) > nesting_rules[op])
        for op in restricted
    )

def tree_check_complexity(ids, complexity_rules):
    """
    Check per-operator limits on what a sub-expression may contain.

    ids: expression (prefix) ids
    complexity_rules: mapping operator id -> array of per-token limits,
                      indexed like the counters (length `len(vocab)`)

    A counter of token occurrences is kept for every sub-expression. Before
    an operator adds itself to the counter of its operands, that counter is
    compared against the operator's rule: any token counted more often than
    the rule allows makes the check fail. Returns False on a violation or
    on an arity other than 0, 1 or 2, True otherwise.
    """
    counters = []
    for idx in reversed(ids):
        arity = arities[idx]
        if arity == 0:
            counts = np.zeros(len(vocab))
            counts[idx] = 1
            counters.append(counts)
            continue
        if arity == 1:
            counts = counters.pop()
        elif arity == 2:
            first = counters.pop()
            second = counters.pop()
            counts = first + second
        else:
            return False
        if idx in complexity_rules:
            limits = complexity_rules[idx]
            if any((limits - counts) < 0):
                return False
        # The operator is counted only after its operands were checked
        counts[idx] += 1
        counters.append(counts)
    return True

# Test with "-"; (1 - x), (x - 1)
def precedes(s, t, debug=False):
    """
    Returns s <= t (partial order of expression generation)

    s, t: expressions (prefix) ids
    debug: accepted but not used by this function

    `s` precedes `t` when it is neither longer nor more complex than `t`
    and their generation distance equals the complexity difference.
    """
    if len(s) > len(t):
        return False
    tree_s = ids2tree(s)
    tree_t = ids2tree(t)
    comp_s = tree_complexity(tree_s)
    comp_t = tree_complexity(tree_t)
    if comp_s > comp_t:
        return False
    return _generation_dist(tree_s, tree_t) == comp_t - comp_s

def subtrees_ptrs(ids):
    """
    List the (first, last) positions of every sub-expression of `ids`.

    ids: expression (prefix) ids

    Spans are inclusive and listed in the order the tokens are visited,
    i.e. from the last position of `ids` to the first. Returns None on an
    arity other than 0, 1 or 2.
    """
    # `ends` holds the last position of every pending sub-expression
    ends = []
    spans = []
    for pos in reversed(range(len(ids))):
        arity = arities[ids[pos]]
        if arity == 0:
            ends.append(pos)
        elif arity == 2:
            # Drop the first operand's end; the second operand's end
            # becomes the end of the combined sub-expression
            ends.pop()
        elif arity != 1:
            return None
        spans.append((pos, ends[-1]))
    return spans

def canon_subtrees_test(ids, domain, canon_ids):
    """
    Run `maybe_canon` on every sub-expression rooted at an operator.

    ids: expression (prefix) ids
    domain: forwarded to `maybe_canon`
    canon_ids: forwarded to `maybe_canon`

    Leaves are not tested. Returns False as soon as one sub-expression
    fails, or on an arity other than 0, 1 or 2.
    """
    # `ends` holds the last position of every pending sub-expression
    ends = []
    for pos in reversed(range(len(ids))):
        arity = arities[ids[pos]]
        if arity == 0:
            ends.append(pos)
            continue
        if arity == 2:
            ends.pop()
        elif arity != 1:
            return False
        sub_ids = ids[pos:(ends[-1] + 1)]
        if not maybe_canon(sub_ids, domain, canon_ids):
            return False
    return True

def maybe_canon(ids, domain, canon_ids):
    """
    Return True unless a different expression is registered for the same key.

    canon_ids: mapping key (see `ids2key`) -> registered ids
    An unregistered key counts as a pass.
    """
    key = ids2key(ids, domain)
    if key not in canon_ids:
        return True
    return canon_ids[key] == ids

def is_canon(ids, domain, canon_ids):
    """
    Return True only when `ids` is the expression registered for its key.

    canon_ids: mapping key (see `ids2key`) -> registered ids
    An unregistered key counts as a failure.
    """
    key = ids2key(ids, domain)
    if key not in canon_ids:
        return False
    return canon_ids[key] == ids

def is_pseudo_canon(
    ids,
    domain,
    canon_ids,
    nesting_rules=nesting_rules,
    complexity_rules=complexity_rules,
):
    """
    Recursive canonicity test that tolerates unregistered expressions.

    ids: expression (prefix) ids
    domain: evaluation data forwarded to `ids2key`
    canon_ids: mapping key (see `ids2key`) -> registered ids
    nesting_rules: currently not checked here (the nesting test is
                   disabled); only forwarded to the recursive calls
    complexity_rules: forwarded to `tree_check_complexity`

    Returns False when the complexity rules are violated. When the key of
    `ids` is registered, returns whether `ids` is the registered expression.
    Otherwise every direct sub-expression must itself pass this test and
    must not have the same key as `ids` (which would make the root
    operator redundant).
    """
    if not tree_check_complexity(ids, complexity_rules):
        return False
    key = ids2key(ids, domain)
    if key in canon_ids:
        return (canon_ids[key] == ids)
    tree = crude_ids2tree(ids)
    if tree is None:
        # Malformed sequence: cannot be decomposed, so it cannot be vouched for
        return False
    for subtree in tree[1:]:
        sub_ids = tree2ids(subtree)
        # Forward the caller's rules; relying on the defaults here would
        # apply custom rules to the root only
        if not is_pseudo_canon(
            sub_ids,
            domain,
            canon_ids,
            nesting_rules=nesting_rules,
            complexity_rules=complexity_rules,
        ):
            return False
        # Check if parent expression is redundant
        if ids2key(sub_ids, domain) == key:
            return False
    return True

def fast_is_pseudo_canon(
    ids,
    domain,
    canon_ids,
    nesting_rules=nesting_rules,
):
    """
    Non-recursive canonicity test over all sub-expressions of `ids`.

    The spans from `subtrees_ptrs` are visited in reverse order; each
    sub-expression must pass `root_check_nesting` and `maybe_canon`.
    """
    spans = subtrees_ptrs(ids)
    for first, last in spans[::-1]:
        candidate = ids[first:(last + 1)]
        nesting_ok = root_check_nesting(candidate, nesting_rules)
        if not (nesting_ok and maybe_canon(candidate, domain, canon_ids)):
            return False
    return True

# TODO: debug recursive expansion; x * (x - (1 / x))
def expand_ids(
    ids,
    key_expr,
    visited_val,
    expansions,
    domain,
    max_len,
    levels=1,
    test_canon=True,
    nesting_rules=nesting_rules,
    complexity_rules=complexity_rules,
):
    """
    Expand every occurrence of `special_id` in `ids`, recursing on rejects.

    ids: expression (prefix) ids
    key_expr: mapping key -> registered ids, used by `is_pseudo_canon`
    visited_val: container of keys (see `ids2key`) already visited
    expansions: id sequences substituted for one `special_id` token
    domain: evaluation data forwarded to `ids2key`
    max_len: candidates longer than this are discarded
    levels: remaining recursion depth; below 1 nothing is produced
    test_canon: use `is_pseudo_canon` when True, otherwise only
                `tree_check_complexity`
    nesting_rules, complexity_rules: forwarded to the tests above

    A candidate is kept when it passes the test and its key is not in
    `visited_val`; otherwise it is expanded one level deeper with the same
    settings. Returns the kept candidates without duplicates, in the order
    they were first produced.
    """
    if levels < 1:
        return []
    expanded = []
    for pos, idx in enumerate(ids):
        if idx != special_id:
            continue
        for exp in expansions:
            new_ids = ids[:pos] + exp + ids[(pos + 1):]
            if len(new_ids) > max_len:
                continue
            ### Only work with canonical forms
            # WARNING: e.g (x) - (x), even if kept, might get very low score
            #          expand the expr one level more, adding just the "good" children to the stack
            _key = ids2key(new_ids, domain)
            if test_canon:
                canon_test = is_pseudo_canon(
                    new_ids,
                    domain,
                    key_expr,
                    nesting_rules=nesting_rules,
                    complexity_rules=complexity_rules,
                )
            else:
                canon_test = tree_check_complexity(new_ids, complexity_rules)
            if canon_test and (_key not in visited_val):
                expanded.append(new_ids)
            else:
                # Keep the caller's settings for the deeper levels; they
                # would otherwise silently fall back to the defaults
                expanded += expand_ids(
                    new_ids,
                    key_expr,
                    visited_val,
                    expansions,
                    domain,
                    max_len,
                    levels=(levels - 1),
                    test_canon=test_canon,
                    nesting_rules=nesting_rules,
                    complexity_rules=complexity_rules,
                )
    # Order-preserving de-duplication with a set of tuples for O(1) lookups
    seen = set()
    unique = []
    for _ids in expanded:
        marker = tuple(_ids)
        if marker not in seen:
            seen.add(marker)
            unique.append(_ids)
    return unique

def iterative_expand_ids(
    ids,
    key_expr,
    pre_visited,
    expansions,
    domain,
    max_len,
    levels=2,
    expansion_cache=None,
    test_canon=True,
    nesting_rules=nesting_rules,
    complexity_rules=complexity_rules,
):
    """
    Breadth-first expansion of the `special_id` tokens of `ids`.

    ids: expression (prefix) ids
    key_expr: mapping key -> registered ids, used by `is_pseudo_canon`
    pre_visited: container of id tuples that must not be produced
    expansions: id sequences substituted for one `special_id` token
    domain: evaluation data forwarded to `is_pseudo_canon`
    max_len: candidates longer than this are discarded
    levels: queue entries whose generation exceeds this are skipped
    expansion_cache: optional dict id-tuple -> list of expansions; a cached
                     entry is appended to the result instead of expanding,
                     and the final result is stored under the tuple of the
                     initial `ids` (the stored list is the returned one)
    test_canon: use `is_pseudo_canon` when True, otherwise only
                `tree_check_complexity`
    nesting_rules, complexity_rules: forwarded to the tests above

    Candidates that pass the test are collected; the others are queued for
    expansion one generation later. Returns the collected candidates (the
    list may contain duplicates).
    """
    cache_ids = tuple(ids)
    expanded = []
    queue = deque()
    queue.append((0, ids))
    visited = set()
    while queue:
        gen, current = queue.popleft()
        if gen > levels:
            continue
        tids = tuple(current)
        if tids in visited:
            continue
        visited.add(tids)
        if expansion_cache is not None and tids in expansion_cache:
            expanded += expansion_cache[tids]
            continue
        ### TODO: use pre-computed expansion from [special_id] up to depth k
        for pos, idx in enumerate(current):
            if idx != special_id:
                continue
            for exp in expansions:
                new_ids = current[:pos] + exp + current[(pos + 1):]
                if len(new_ids) > max_len:
                    continue
                new_tids = tuple(new_ids)
                if new_tids in pre_visited or new_tids in visited:
                    continue
                if test_canon:
                    # Forward the caller's rules instead of relying on the
                    # defaults of `is_pseudo_canon`
                    canon_test = is_pseudo_canon(
                        new_ids,
                        domain,
                        key_expr,
                        nesting_rules=nesting_rules,
                        complexity_rules=complexity_rules,
                    )
                else:
                    canon_test = tree_check_complexity(new_ids, complexity_rules)
                if canon_test:
                    expanded.append(new_ids)
                else:
                    queue.append((gen + 1, new_ids))
    if expansion_cache is not None:
        expansion_cache[cache_ids] = expanded
    return expanded

def expand_rules(
    key_expr,
    expansions,
    domain,
    max_len,
    levels=2,
    test_canon=True,
):
    """
    Run `iterative_expand_ids` starting from the single token `special_id`,
    with no pre-visited expressions and no expansion cache.
    """
    return iterative_expand_ids(
        [special_id],
        key_expr,
        set(),
        expansions,
        domain,
        max_len,
        levels=levels,
        expansion_cache=None,
        test_canon=test_canon,
    )

def cheap_expand_ids(
    ids,
    key_expr,
    expansions,
    domain,
    max_len,
):
    """
    One-level expansion that reports the canonicity of every candidate.

    ids: expression (prefix) ids
    key_expr: mapping key -> registered ids, used by `is_pseudo_canon`
    expansions: id sequences substituted for one `special_id` token
    domain: evaluation data forwarded to `is_pseudo_canon`
    max_len: candidates longer than this are discarded

    Returns `(expanded, canon_tests)`: two lists of equal length where
    `canon_tests[i]` is the `is_pseudo_canon` result for `expanded[i]`.
    """
    expanded = []
    canon_tests = []
    for pos, idx in enumerate(ids):
        if idx != special_id:
            continue
        for exp in expansions:
            new_ids = ids[:pos] + exp + ids[(pos + 1):]
            if len(new_ids) > max_len:
                continue
            # No separate key is computed here: `is_pseudo_canon` evaluates
            # the candidate itself when it needs to
            canon_tests.append(is_pseudo_canon(new_ids, domain, key_expr))
            expanded.append(new_ids)
    return expanded, canon_tests

def top_down_is_pseudo_canon(
    ids,
    domain,
    canon_ids,
    expansions,
):
    """
    Check `ids` and the expressions generated on the way to it.

    ids: expression (prefix) ids
    domain, canon_ids: forwarded to `is_pseudo_canon` / `expand_ids`
    expansions: id sequences substituted for one `special_id` token

    `ids` itself must pass `is_pseudo_canon`. Then, starting from
    `[special_id]`, expressions are expanded breadth-first (two levels per
    step, at most `len(ids)` tokens); only those that `precedes` accepts
    with respect to `ids` are followed, and each one must pass
    `is_pseudo_canon` as well.
    """
    if not is_pseudo_canon(ids, domain, canon_ids):
        return False
    pending = deque([[special_id]])
    while pending:
        current = pending.popleft()
        if not is_pseudo_canon(current, domain, canon_ids):
            return False
        children = expand_ids(
            current,
            canon_ids,
            set(),
            expansions,
            domain,
            len(ids),
            levels=2,
        )
        pending.extend(child for child in children if precedes(child, ids))
    return True

def unroll(rolls):
    """Concatenate the items of every iterable in `rolls` into one flat list."""
    return [item for roll in rolls for item in roll]

def verify_sol(ids, feat_domain, secret_vals, atol=1e-6, rtol=1e-4):
    """
    Check whether the expression reproduces `secret_vals` on `feat_domain`.

    The value of the whole expression (row 0 of `rolling_tree_eval_expr`)
    is compared to `secret_vals` with `np.allclose` using `atol` and `rtol`.
    """
    root_eval = rolling_tree_eval_expr(ids, feat_domain)[0]
    return np.allclose(secret_vals, root_eval, atol=atol, rtol=rtol)
