from collections import defaultdict as dd
from random import shuffle, randint

import numpy as np

from htssr.canon import max_canon_size
from htssr.utils import (
    is_pseudo_canon,
    expand_ids,
    unroll,
    root_check_nesting,
    complete_check_nesting,
    tree_check_complexity,
    get_tree_parents,
)
from htssr.primitives import (
    bops,
    special_id,
    nesting_rules,
    complexity_rules,
)
from htssr.grammar import expansions
from htssr.numerical import make_vec, make_tree_vec
from htssr.categorical import make_cls, _make_cls


def sample_ids(
    ids_list,
    ptrs,
    min_pos=1,
    max_pos=max_canon_size,
):
    """Draw one id sequence from ``ids_list`` via a random pointer bucket.

    A bucket position is drawn uniformly from ``[min_pos, max_pos]`` after
    clamping ``max_pos`` to ``len(ptrs)`` and ``min_pos`` to ``max_pos``.
    ``ptrs[position]`` is then read as an inclusive ``(first, last)`` index
    range into ``ids_list``, and one index is drawn uniformly from it.

    NOTE(review): the upper clamp is ``len(ptrs)`` rather than
    ``len(ptrs) - 1``, so this presumably expects ``ptrs`` to be indexable
    at ``len(ptrs)`` (e.g. a mapping keyed 1..N) -- TODO confirm.
    """
    upper = min(len(ptrs), max_pos)
    lower = min(min_pos, upper)
    bucket = ptrs[randint(lower, upper)]
    chosen_index = randint(bucket[0], bucket[1])
    return ids_list[chosen_index]

def _sample_ids_except(
    ids_list,
    ptrs,
    exceptions,
    min_pos=1,
    max_pos=max_canon_size,
):
    """Rejection-sample an id sequence whose tuple form is not excluded.

    Calls :func:`sample_ids` with the given bounds until the drawn sequence,
    converted to a tuple, is absent from ``exceptions``. There is no retry
    cap: if every reachable sequence is excluded this never returns.
    """
    while True:
        candidate = sample_ids(
            ids_list,
            ptrs,
            min_pos=min_pos,
            max_pos=max_pos,
        )
        if tuple(candidate) not in exceptions:
            return candidate

# def is_all_commut(ids):
#     for idx in ids:
#         if (idx in bops) and (idx not in cops):
#             return False
#     return True

def sample_ids_except(
    ids_list,
    ptrs,
    exceptions,
    min_pos=1,
    max_pos=max_canon_size,
):
    """Draw an id sequence (via :func:`sample_ids`) not listed in ``exceptions``.

    Sampling is repeated until ``tuple(sequence)`` is not a member of
    ``exceptions``; there is no retry cap. An additional commutativity
    filter (``is_all_commut``) existed previously and is currently disabled.
    """
    def _draw():
        return sample_ids(
            ids_list,
            ptrs,
            min_pos=min_pos,
            max_pos=max_pos,
        )

    picked = _draw()
    while tuple(picked) in exceptions:
        picked = _draw()
    return picked

def step_ids(ids, expansions):
    """Replace one random ``special_id`` placeholder with a random expansion.

    Every position of ``ids`` equal to ``special_id`` is a candidate. One
    candidate position is chosen uniformly, then one element of
    ``expansions`` is chosen uniformly, and that element is spliced in place
    of the placeholder (``ids[:pos] + expansion + ids[pos + 1:]``).

    Returns:
        The rewritten sequence, or ``None`` when ``ids`` has no placeholder.
    """
    holes = [pos for pos, token in enumerate(ids) if token == special_id]
    if not holes:
        return None
    hole = holes[randint(0, len(holes) - 1)]
    replacement = expansions[randint(0, len(expansions) - 1)]
    return ids[:hole] + replacement + ids[hole + 1:]

def expensive_step_ids(
    ids,
    key_expr,
    expansions,
    domain,
    max_len,
):
    """Return one randomly chosen single-level expansion of ``ids``.

    ``expand_ids`` is called with an empty exclusion set and ``levels=1``;
    one of its results is picked uniformly at random.

    Returns:
        The chosen expansion, or ``None`` if ``expand_ids`` yielded nothing.
    """
    candidates = expand_ids(
        ids,
        key_expr,
        set(),
        expansions,
        domain,
        max_len,
        levels=1,
    )
    if not len(candidates):
        return None
    return candidates[randint(0, len(candidates) - 1)]

def linear_rollout_except(
    src,
    key_expr,
    exceptions,
    domain,
    max_len,
    expansions=expansions,
    test_canon=True,
):
    """Grow a chain of successively expanded id sequences starting at ``src``.

    Each attempt rewrites one ``special_id`` placeholder of the latest
    sequence with a random expansion (:func:`step_ids`). The candidate is
    kept only if all of the following hold:

    - it is at most ``max_len`` long;
    - its tuple form is not in ``exceptions``;
    - it passes a filter: ``is_pseudo_canon`` when ``test_canon`` is true,
      otherwise ``tree_check_complexity`` against ``complexity_rules``.

    Rejected candidates still consume an attempt.

    Args:
        src: Starting id sequence; always the first element of the result.
        key_expr: Forwarded to ``is_pseudo_canon``.
        exceptions: Collection of id tuples that must not enter the roll.
        domain: Forwarded to ``is_pseudo_canon``.
        max_len: Maximum allowed sequence length; must be positive.
        expansions: Candidate replacements for a placeholder.
        test_canon: Use the canonicity filter instead of the complexity one.

    Returns:
        ``[src, ids_1, ids_2, ...]`` where each element is one expansion step
        past the previous one, or ``[]`` when ``src`` is longer than
        ``max_len``. Generation stops after ``3 * max_len`` attempts or as
        soon as the latest sequence contains no placeholder.
    """
    assert max_len > 0
    if len(src) > max_len:
        return []

    def _ids_check(ids):
        # Validate the argument itself (the old version silently read the
        # enclosing ``new_ids`` through the closure and ignored ``ids``).
        return (
            (ids is not None)
            and (len(ids) <= max_len)
            and (tuple(ids) not in exceptions)
        )

    roll = [src]
    max_tries = 3 * max_len
    ### x -> x * x -> x * (x + x) -> x * (x + (-x)) -> x * (x + (1 / (-x))) -> ...
    ### TODO: debug expansion; x * (x + (1 / (-x)))
    for _ in range(max_tries):
        node = roll[-1]
        if special_id not in node:
            # Fully expanded: no further step is possible.
            break
        new_ids = step_ids(node, expansions)
        if not _ids_check(new_ids):
            continue
        if test_canon:
            accepted = is_pseudo_canon(
                new_ids,
                domain,
                key_expr,
            )
        else:
            # Earlier variants filtered with root_check_nesting /
            # complete_check_nesting against nesting_rules.
            accepted = tree_check_complexity(new_ids, complexity_rules)
        if accepted:
            roll.append(new_ids)
    return roll

def free_rollout_except(
    src,
    exceptions,
    max_len,
    expansions=expansions,
    nesting_rules=nesting_rules,
    complexity_rules=complexity_rules,
    test_canon=False,
    domain=None,
    key_expr=None,
):
    """Grow a chain of expanded id sequences from ``src`` under complexity rules.

    Each attempt rewrites one ``special_id`` placeholder of the latest
    sequence with a random expansion (:func:`step_ids`). The candidate is
    kept only if all of the following hold:

    - it is at most ``max_len`` long;
    - its tuple form is not in ``exceptions``;
    - it passes ``tree_check_complexity(candidate, complexity_rules)``;
    - when ``test_canon`` is true, it also passes ``is_pseudo_canon``.

    Rejected candidates still consume an attempt.

    Args:
        src: Starting id sequence; always the first element of the result.
        exceptions: Collection of id tuples that must not enter the roll.
        max_len: Maximum allowed sequence length; must be positive.
        expansions: Candidate replacements for a placeholder.
        nesting_rules: Accepted for backward compatibility; currently
            unused (the ``complete_check_nesting`` filter is disabled).
        complexity_rules: Rules passed to ``tree_check_complexity``.
        test_canon: Additionally require ``is_pseudo_canon``; also triples
            the attempt budget.
        domain: Forwarded to ``is_pseudo_canon`` (needed if ``test_canon``).
        key_expr: Forwarded to ``is_pseudo_canon`` (needed if ``test_canon``).

    Returns:
        ``[src, ids_1, ...]``, or ``[]`` when ``src`` is longer than
        ``max_len``. Generation stops after the attempt budget
        (``max_len``, or ``3 * max_len`` with ``test_canon``) is spent or
        once the latest sequence has no placeholder left.
    """
    assert max_len > 0
    if len(src) > max_len:
        return []

    def _ids_check(ids):
        # Validate the argument itself rather than the enclosing ``new_ids``.
        return (
            (ids is not None)
            and (len(ids) <= max_len)
            and (tuple(ids) not in exceptions)
        )

    roll = [src]
    max_tries = 3 * max_len if test_canon else max_len
    for _ in range(max_tries):
        node = roll[-1]
        if special_id not in node:
            # step_ids would return None on every remaining attempt.
            break
        new_ids = step_ids(node, expansions)
        if not _ids_check(new_ids):
            continue
        # nesting_test = complete_check_nesting(new_ids, nesting_rules)
        if not tree_check_complexity(new_ids, complexity_rules):
            continue
        if test_canon and not is_pseudo_canon(new_ids, domain, key_expr):
            # The canonicity result used to be computed and then ignored,
            # letting non-canonical sequences into the roll.
            continue
        roll.append(new_ids)
    return roll

def uniform_sample_ids(
    ids_list,
    ptrs,
    exceptions,
    size,
    min_pos=1,
    max_pos=max_canon_size,
):
    """Draw ``size`` independent id sequences that avoid ``exceptions``.

    Each sequence comes from :func:`sample_ids_except` with the given
    position bounds.

    Returns:
        ``(all_ids, mask_pos, mask_neg)`` where ``all_ids`` is the list of
        drawn sequences and both masks are ``size x size`` all-ones float
        arrays. The masks are distinct arrays, so a caller modifying one does
        not silently alter the other.
    """
    all_ids = [
        sample_ids_except(
            ids_list,
            ptrs,
            exceptions,
            min_pos=min_pos,
            max_pos=max_pos,
        )
        for _ in range(size)
    ]
    mask_pos = np.ones((len(all_ids), len(all_ids)))
    # Separate buffer: returning the same object twice made the two masks
    # alias each other.
    mask_neg = mask_pos.copy()
    return all_ids, mask_pos, mask_neg

def disjunct_linear_rolls(
    ids_list,
    key_expr,
    ptrs,
    exceptions,
    domain,
    size,
    min_pos=1,
    max_len=max_canon_size,
    root_src=False,
    expansions=expansions,
    test_canon=True,
):
    """Collect independent linear rollouts until ``size`` sequences in total.

    Each rollout starts either from a bare ``[special_id]`` (``root_src``) or
    from a sequence drawn with :func:`sample_ids_except`. It is regenerated
    from a fresh source until it holds at least two sequences. The last
    rollout is truncated so the total number of sequences equals ``size``.

    NOTE(review): the regenerate loop has no cap; if no source can ever be
    expanded this spins forever.
    """
    def get_src():
        if root_src:
            return [special_id]
        return sample_ids_except(
            ids_list,
            ptrs,
            exceptions,
            min_pos=min_pos,
            max_pos=max_len,
        )

    def _draw_roll():
        # Retry until the rollout contains at least one expansion step.
        candidate = []
        while len(candidate) < 2:
            candidate = linear_rollout_except(
                get_src(),
                key_expr,
                exceptions,
                domain,
                max_len,
                expansions=expansions,
                test_canon=test_canon,
            )
        return candidate

    rolls = []
    collected = 0
    # TODO: make this loop concurrent
    while collected < size:
        remaining = size - collected
        roll = _draw_roll()[:remaining]
        collected += len(roll)
        rolls.append(roll)
    return rolls

def disjunct_free_rolls(
    exceptions,
    size,
    min_pos=1,
    max_len=max_canon_size,
    root_src=False,
    expansions=expansions,
    test_canon=False,
    domain=None,
    key_expr=None,
):
    """Collect independent free rollouts until ``size`` sequences in total.

    Each rollout starts either from a bare ``[special_id]`` (``root_src``) or
    from a uniformly chosen element of a warm-up :func:`free_rollout_except`
    run seeded at ``[special_id]``. It is regenerated until it holds at
    least two sequences. The last rollout is truncated so the total equals
    ``size``.

    ``min_pos`` is accepted for signature parity with
    :func:`disjunct_linear_rolls` but is not used.
    """
    def get_src():
        if root_src:
            return [special_id]
        warmup = free_rollout_except(
            [special_id],
            exceptions,
            max_len,
            expansions=expansions,
            test_canon=test_canon,
            domain=domain,
            key_expr=key_expr,
        )
        return warmup[randint(0, len(warmup) - 1)]

    def _draw_roll():
        # Retry until the rollout contains at least one expansion step.
        candidate = []
        while len(candidate) < 2:
            candidate = free_rollout_except(
                get_src(),
                exceptions,
                max_len,
                expansions=expansions,
                test_canon=test_canon,
                domain=domain,
                key_expr=key_expr,
            )
        return candidate

    rolls = []
    collected = 0
    # TODO: make this loop concurrent
    while collected < size:
        remaining = size - collected
        roll = _draw_roll()[:remaining]
        collected += len(roll)
        rolls.append(roll)
    return rolls

def disjunct_mask(rolls):
    """Build positive/negative pair masks over the flattened ``rolls``.

    Rows and columns are laid out roll after roll in the order of ``rolls``.
    Both masks are integer arrays of shape ``(len(all_ids), len(all_ids))``.

    Returns:
        ``(all_ids, mask_pos, mask_neg)`` where:

        - ``all_ids`` is ``unroll(rolls)``.
        - ``mask_pos[i, j] == 1`` when ``i`` and ``j`` lie in the same roll
          at most one position apart (the diagonal included).
        - ``mask_neg[i, j] == 1`` for ``i != j`` when ``i`` and ``j`` are the
          "base" items of two different rolls. A roll's base is its second
          element if it has more than one, otherwise its first.

        Empty rolls contribute no rows and no base.

    NOTE(review): assumes ``unroll`` concatenates the rolls in order, so the
    running offset used here matches each element's row -- TODO confirm.
    """
    def _make_tri(num):
        # Tridiagonal band of ones (main diagonal plus both neighbours).
        # Works for num == 0, where the old np.eye(num - 1) construction
        # raised on a negative dimension.
        return np.eye(num, k=-1) + np.eye(num) + np.eye(num, k=1)

    all_ids = unroll(rolls)
    nids = len(all_ids)
    mask_pos = np.zeros((nids, nids), dtype=int)
    mask_neg = np.zeros((nids, nids), dtype=int)
    bases = []
    start = 0
    for roll in rolls:
        length = len(roll)
        if length == 0:
            continue
        end = start + length
        bases.append(start + int(length > 1))
        mask_pos[start:end, start:end] = _make_tri(length)
        start = end
    if len(bases) > 1:
        base_idx = np.asarray(bases, dtype=int)
        # Mark every pair of bases, then clear the diagonal: a base is never
        # its own negative.
        mask_neg[np.ix_(base_idx, base_idx)] = 1
        mask_neg[base_idx, base_idx] = 0
    return all_ids, mask_pos, mask_neg

def make_exceptions(
    ids_list,
    key_expr,
    ptrs,
    domain,
    nexc=100,
    min_pos=6,
    max_pos=max_canon_size,
    expansions=expansions,
    test_canon=True,
):
    """Deterministically generate a held-out set of id sequences.

    The global ``random`` generator is seeded with the fixed value 1968, so
    identical inputs always yield identical exceptions. Afterwards the
    generator's previous state is restored, even if generation raises. Any
    seeding the caller performed is therefore preserved; the old code
    reseeded from ``time()`` and discarded it.

    Returns:
        ``(exc, rolls)`` where ``rolls`` comes from
        :func:`disjunct_linear_rolls` (``nexc`` sequences in total, no
        exclusions) and ``exc`` is the set of those sequences as tuples.
    """
    import random

    saved_state = random.getstate()
    random.seed(1968)
    try:
        rolls = disjunct_linear_rolls(
            ids_list,
            key_expr,
            ptrs,
            set(),
            domain,
            nexc,
            min_pos=min_pos,
            max_len=max_pos,
            root_src=False,
            expansions=expansions,
            test_canon=test_canon,
        )
    finally:
        random.setstate(saved_state)
    exc = {tuple(exc_) for exc_ in unroll(rolls)}
    return exc, rolls

def generate_all_pairs_order_samples(
    ids_list,
    key_expr,
    ptrs,
    exceptions,
    size,
    domain,
    feat_domain,
    rollout=True,
    dummy=False,
    min_pos=1,
    max_pos=max_canon_size,
    root_src=False,
    noisy_fast=False,
    expansions=expansions,
    noise_level=0,
    test_canon=True,
):
    """Build a batch of sequence vectors plus their ordering labels.

    With ``rollout`` true, the batch is ``size`` sequences gathered via
    :func:`disjunct_linear_rolls` and labelled by ``make_cls`` over the
    rolls. Otherwise ``size`` independent sequences are drawn with
    :func:`uniform_sample_ids` and labelled as a single group
    (``make_cls([all_ids])``). In that case ``root_src``, ``expansions``,
    ``test_canon`` and ``noisy_fast`` are not used.

    Returns:
        dict with keys ``"vecs"`` (``make_vec`` over the sequences, using
        ``feat_domain``, ``dummy`` and ``noise_level``) and ``"order"``
        (the ``make_cls`` result).
    """
    if rollout:
        rolls = disjunct_linear_rolls(
            ids_list,
            key_expr,
            ptrs,
            exceptions,
            domain,
            size,
            min_pos=min_pos,
            max_len=max_pos,
            root_src=root_src,
            expansions=expansions,
            test_canon=test_canon,
        )
        cls = make_cls(rolls, noisy_fast=noisy_fast)
        all_ids = unroll(rolls)
    else:
        # ``ptrs`` was previously omitted here, shifting every positional
        # argument and making this branch always raise TypeError.
        all_ids, _, _ = uniform_sample_ids(
            ids_list,
            ptrs,
            exceptions,
            size,
            min_pos=min_pos,
            max_pos=max_pos,
        )
        cls = make_cls([all_ids], noisy_fast=False)
    all_vals = make_vec(
        all_ids,
        feat_domain,
        dummy=dummy,
        noise_level=noise_level,
    )
    samples = {
        "vecs": all_vals,
        "order": cls,
    }
    return samples

def generate_all_pairs_free_samples(
    exceptions,
    size,
    feat_domain,
    min_pos=1,
    max_pos=max_canon_size,
    root_src=False,
    noisy_fast=False,
    expansions=expansions,
    force_padding=None,
    test_canon=False,
    domain=None,
    key_expr=None,
):
    """Build a tree-structured batch from free (grammar-driven) rollouts.

    Rollouts come from :func:`disjunct_free_rolls` with ``max_pos`` as the
    length cap.

    Returns:
        dict with keys:

        - ``"vecs"``: ``make_tree_vec`` over the flattened sequences.
        - ``"ids"``: the flattened sequences.
        - ``"parents"``: ``get_tree_parents`` of each sequence.
        - ``"order"``: ``make_cls`` over the rolls.
    """
    rolls = disjunct_free_rolls(
        exceptions,
        size,
        min_pos=min_pos,
        max_len=max_pos,
        root_src=root_src,
        expansions=expansions,
        test_canon=test_canon,
        domain=domain,
        key_expr=key_expr,
    )
    order = make_cls(rolls, noisy_fast=noisy_fast)
    flat_ids = unroll(rolls)
    tree_parents = [get_tree_parents(seq) for seq in flat_ids]
    vecs = make_tree_vec(
        flat_ids,
        feat_domain,
        force_padding=force_padding,
    )
    return {
        "vecs": vecs,
        "ids": flat_ids,
        "parents": tree_parents,
        "order": order,
    }

def generate_all_pairs_heldout_samples(
    exception_rolls,
    feat_domain,
    dummy=False,
):
    """Vectorize held-out rolls (e.g. from :func:`make_exceptions`).

    Returns:
        dict with keys ``"vecs"`` (``make_vec`` over the flattened rolls) and
        ``"order"`` (``make_cls`` over the rolls, without fast noise).
    """
    flat_ids = unroll(exception_rolls)
    vecs = make_vec(
        flat_ids,
        feat_domain,
        dummy=dummy,
    )
    order = make_cls(exception_rolls, noisy_fast=False)
    return {
        "vecs": vecs,
        "order": order,
    }

def generate_all_pairs_free_heldout_samples(
    exception_rolls,
    feat_domain,
    force_padding=None,
):
    """Tree-vectorize held-out rolls.

    Returns:
        dict with keys:

        - ``"ids"``: the flattened sequences.
        - ``"parents"``: ``get_tree_parents`` of each sequence.
        - ``"vecs"``: ``make_tree_vec`` over the sequences.
        - ``"order"``: ``make_cls`` over the rolls, without fast noise.
    """
    flat_ids = unroll(exception_rolls)
    tree_parents = [get_tree_parents(seq) for seq in flat_ids]
    vecs = make_tree_vec(
        flat_ids,
        feat_domain,
        force_padding=force_padding,
    )
    order = make_cls(exception_rolls, noisy_fast=False)
    return {
        "ids": flat_ids,
        "parents": tree_parents,
        "vecs": vecs,
        "order": order,
    }
