from time import time
import numpy as np
import torch
import torch.nn as nn
from heapq import (
    heapify,
    heappush,
    heappop,
)
from htssr.primitives import (
    special_id,
    special_parameter_id,
    uops,
    bops,
    inversions
)
from htssr.utils import (
    str2ids,
    ids2key,
    vals2key,
    expand_ids,
    iterative_expand_ids,
    expand_rules,
    cheap_expand_ids,
    is_pseudo_canon,
    fast_is_pseudo_canon,
    tc_fast_eval_expr,
    verify_sol,
)
from htssr.grammar import expansions
from htssr.canon import max_canon_size
from htssr.batching import (
    make_heldout_batch,
    make_search_batch,
    make_ids_batch,
    make_free_heldout_batch,
)
from htssr.utils import (
    precedes,
    to_infix,
    to_latex,
    rolling_check_determined_expr,
)
from htssr.sampling import step_ids
from htssr.numerical import (
    make_vals,
    make_vec_from_vals,
    make_tree_vec,
)
from htssr.fitting import fit_constants


def make_learned_heuristic(
    model,
    secret,
    domain,
    feat_domain,
    # random=False,
    # dummy=False,
    # random_transform=None,
    # noise_level=0,
    device="cpu",
    tree_vec=False,
    force_padding=None,
):
    """Build a batched heuristic that scores candidate expressions against ``secret``.

    Args:
        model: Object exposing ``tgt_forward(vals, secret_vals, expr_ids,
            parent_ids, ...)``. Its raw output is mapped through
            ``sigmoid(-x)``, so lower heuristic values mean "closer".
        secret: The target. Accepted forms are precomputed values
            (``np.ndarray`` or ``torch.Tensor``), an expression string (parsed
            with ``str2ids``), or a list of token ids (featurized like a guess).
        domain: Not used by this function; kept for signature compatibility.
        feat_domain: Evaluation domain forwarded to the featurizers.
        device: Device the precomputed secret values (and tree batches) go to.
        tree_vec: If True, featurize with ``make_free_heldout_batch``, which
            also yields token/parent ids. Otherwise use ``make_vals``.
        force_padding: Forwarded to ``make_free_heldout_batch``.

    Returns:
        ``(parallel_heuristic, secret_vals)``.

    Raises:
        TypeError: If ``secret`` is not one of the supported types.
    """
    def _make_vals(all_ids):
        """Featurize a list of token-id lists into ``(expr_ids, parent_ids, vals)``."""
        if tree_vec:
            bids, bparents, vals, _ = make_free_heldout_batch(
                [[_ids] for _ids in all_ids],
                feat_domain,
                force_padding=force_padding,
                device=device,
            )
            return bids, bparents, vals
        # The flat featurizer produces no tree structure. Previously this path
        # returned unassigned names and always crashed.
        # NOTE(review): confirm model.tgt_forward accepts None ids here.
        return None, None, make_vals(all_ids, feat_domain)

    if isinstance(secret, str):
        secret = str2ids(secret)
    if isinstance(secret, np.ndarray):
        secret = torch.tensor(secret)
    if isinstance(secret, torch.Tensor):
        secret_vals = secret.float()
        if not tree_vec:
            # The flat path compares against a leading batch dimension.
            secret_vals = secret_vals.unsqueeze(0)
        secret_vals = secret_vals.to(device)
    elif isinstance(secret, list):
        _, _, secret_vals = _make_vals([secret])
    else:
        raise TypeError(
            f"unsupported secret type {type(secret).__name__!r}; expected "
            "np.ndarray, torch.Tensor, str or list of ids"
        )

    def parallel_heuristic(guesses, rel_tol=None, domain_id=None, sorts=None):
        """Score ``guesses`` (a list of token-id lists). Lower is better.

        Guesses whose values already match the secret get exactly 0. Every
        other guess gets ``sigmoid(-model_score) + 1e-4``, so it stays strictly
        above an exact match.

        What counts as a match:
        - ``rel_tol`` is None: the summed absolute difference is below 1e-8.
        - otherwise: the mean squared difference divided by the secret's mean
          square is below ``rel_tol``.

        ``domain_id`` and ``sorts`` are forwarded to ``model.tgt_forward``
        only when ``domain_id`` is given. Guesses are processed in chunks of
        2**14, and the per-chunk results are concatenated along dim 0. An
        empty ``guesses`` yields an empty tensor.
        """
        if len(guesses) == 0:
            return torch.empty(0, device=device)
        chunk_size = 2 ** 14
        hvals = []
        for _start in range(0, len(guesses), chunk_size):
            _guesses = guesses[_start : _start + chunk_size]
            expr_ids, parent_ids, tc_guesses_vals = _make_vals(_guesses)
            # Compare the first entry along dim 1 of the guess values to the secret.
            guess_diff = tc_guesses_vals[:, 0] - secret_vals
            if rel_tol is None:
                slack_mask = guess_diff.abs().sum(dim=-1) < 1e-8
            else:
                slack_mask = (
                    (guess_diff ** 2).mean(dim=-1)
                    / (secret_vals ** 2).mean(dim=-1)
                ) < rel_tol
            neg_slack_mask = ~slack_mask
            with torch.no_grad():
                if domain_id is None:
                    _hvals = model.tgt_forward(
                        tc_guesses_vals,
                        secret_vals,
                        expr_ids,
                        parent_ids,
                    )
                else:
                    _hvals = model.tgt_forward(
                        tc_guesses_vals,
                        secret_vals,
                        expr_ids,
                        parent_ids,
                        domain_ids=domain_id,
                        sorts=sorts,
                    )
            _hvals = torch.sigmoid(-_hvals)
            _hvals[slack_mask] = 0.0
            _hvals[neg_slack_mask] += 1e-4
            hvals.append(_hvals)
        return torch.concat(hvals, dim=0)

    return parallel_heuristic, secret_vals

def fit_parameters(
    guesses,
    feat_domain,
    secret_vals,
    max_iter=5,
    n_inits=1,
):
    """Fit the free constants of every expression in ``guesses`` to ``secret_vals``.

    Each expression is fitted independently with ``fit_constants`` (fixed
    ``damping=1e-3`` and ``d_mult=10.0``).

    Returns:
        A list with one 4-tuple per guess, in input order, holding the values
        ``fit_constants`` returns (named ``params, _it, err, rel_err`` here).
    """
    def _fit_single(expr_ids):
        # Unpack explicitly so a result that is not exactly 4 values still raises.
        fitted, n_iter, abs_err, rel_err = fit_constants(
            expr_ids,
            feat_domain,
            secret_vals,
            max_iter=max_iter,
            damping=1e-3,
            d_mult=10.0,
            n_inits=n_inits,
        )
        return (fitted, n_iter, abs_err, rel_err)

    return [_fit_single(expr_ids) for expr_ids in guesses]

def print_debug(pq, secret, max_dist):
    r"""Render queue entries as LaTeX in a notebook, best first.

    The secret is rendered first. Then each ``(priority, ids)`` entry is shown
    in sorted order until ``priority[0]`` exceeds ``max_dist``. Each line holds:
    - ``\vdash`` if ``precedes(ids, secret)`` holds, otherwise ``\cdots``;
    - the expression rendered with ``to_latex``;
    - every component of the priority, formatted with two decimals.
    """
    from IPython.display import display, Latex

    display(Latex(to_latex(secret)))
    for priority, expr_ids in sorted(pq):
        if priority[0] > max_dist:
            break
        marker = "\\vdash" if precedes(expr_ids, secret) else "\\cdots"
        history = " \\qquad ".join(f"{value:.2f}" for value in priority)
        # to_latex output is wrapped in delimiters; strip them to embed inline.
        body = to_latex(expr_ids)[1:-1]
        display(Latex(f"${marker} \\qquad {body} \\qquad {history}$"))

def print_similar(pq, secret, beam_w):
    r"""Render queue entries as LaTeX, ordered by generation distance to ``secret``.

    Each ``(priority, ids)`` entry is re-keyed by
    ``(generation_dist(ids, secret), priority[0])`` and shown in that order.
    Each line also lists the beam threshold: the first priority component of
    the ``beam_w``-th best entry in the original queue order.
    """
    from htssr.utils import generation_dist
    from IPython.display import display, Latex

    display(Latex(to_latex(secret)))
    rekeyed = [
        ((generation_dist(expr_ids, secret), priority[0]), expr_ids)
        for priority, expr_ids in pq
    ]
    beam_threshold = sorted(pq)[:beam_w][-1][0][0]
    for key, expr_ids in sorted(rekeyed):
        marker = "\\vdash" if precedes(expr_ids, secret) else "\\cdots"
        history = " \\qquad ".join(
            f"{value:.2f}" for value in (*key, beam_threshold)
        )
        body = to_latex(expr_ids)[1:-1]
        display(Latex(f"${marker} \\qquad {body} \\qquad {history}$"))

def search_prefix(
    key_expr,
    heuristic,
    secret_vals,
    domain,
    feat_domain,
    eps=1e-8,
    max_len=max_canon_size,
    max_visited=int(1e6),
    expansions=expansions,
    debug_secret=None,
    expansion_cache=None,
    levels=1,
    penalty_factor=0.0,
    prune_dist=np.inf,
    fit_params=False,
    max_fit_iter=5,
    n_param_inits=1,
    fit_err_tol=1e-9,
    h_err_tol=None,
    dropout=0.0,
    starting_point=None,
    test_canon=False,
    beam_w=1,
    topk_children=8,
    domain_id=None,
    sorts=None,
):
    """Best-first (beam) search over prefix expressions guided by ``heuristic``.

    Each round pops up to ``beam_w`` entries (at least one) from a min-heap
    keyed by ``(distance, generation)``. For each popped expression:

    1. Accept it when its distance is below ``eps``. With ``fit_params``, a
       constant fit whose relative error is below ``fit_err_tol`` forces the
       distance to 0.
    2. Otherwise skip it if it is a duplicate. Duplicates are detected by the
       sorted token tuple plus ``ids2key``.
    3. Otherwise, with probability ``dropout``, skip it.
    4. Otherwise expand it with ``iterative_expand_ids`` and score the
       children with ``heuristic``. Push the ``topk_children`` best (all of
       them when ``topk_children`` is negative or None), skipping children
       already popped or queued and those scoring above ``prune_dist``.
       Pushed distances are multiplied by
       ``(1 + penalty_factor) ** generation``.

    Args:
        heuristic: Callable ``(guesses, rel_tol=, domain_id=, sorts=)`` that
            returns a tensor of distances, one per guess (presumably 1-D).
        secret_vals: Tensor of target values; its CPU copy is used for fitting.
        starting_point: Initial expressions. Defaults to ``[[special_id]]``.
        max_visited: Stop once more than this many expressions were popped.

    Returns:
        ``(ids, n_visited, search_trace, params)``:
        - ``ids`` is the accepted expression, or None on a miss.
        - ``search_trace`` is always an empty list here.
        - ``params`` is the last fitted parameter set (None unless
          ``fit_params``).
    """
    if starting_point is None:
        starting_point = [[special_id]]
    init_dists = heuristic(
        starting_point,
        rel_tol=h_err_tol,
        domain_id=domain_id,
        sorts=sorts,
    )
    # Heap entries are ((distance, generation), ids). The generation counts
    # expansion steps from the starting point.
    pq = [((_dist, 0), ids) for ids, _dist in zip(starting_point, init_dists)]
    heapify(pq)
    visited_ids = set()  # exact token tuples of every popped expression
    added = set()  # token tuples of every child ever pushed
    redundant = set()  # dedup keys of expressions already processed
    search_trace = []  # not populated here; kept for the return shape
    params = None
    time_sum = 0  # expansion time
    num_pops = 0
    htime = 0  # heuristic time
    ptime = 0  # push time
    ftime = 0  # fitting time
    cpu_secret_vals = secret_vals.to("cpu")
    loop_time = time()
    while len(pq) > 0:
        if debug_secret is not None:
            import pdb
            pdb.set_trace()
        # Pop the whole beam before expanding any member. Always pop at least
        # one entry so a non-positive beam_w cannot stall the loop.
        n_pop = max(1, min(beam_w, len(pq)))
        current_popped = [heappop(pq) for _ in range(n_pop)]
        num_pops += n_pop
        for dist, ids in current_popped:
            ft0 = time()
            if fit_params:
                fitted = fit_parameters(
                    [ids],
                    feat_domain,
                    cpu_secret_vals,
                    max_iter=max_fit_iter,
                    n_inits=n_param_inits,
                )[0]
                rel_err = fitted[3]
                if (rel_err is not None) and (rel_err < fit_err_tol):
                    # A near-exact fit counts as a hit. Keep the generation so
                    # dist[1] stays valid if the node is expanded anyway.
                    dist = (0.0, dist[1])
                params = fitted[0]
            ftime += time() - ft0
            # TODO: adapt acceptance condition for the noisy case
            if dist[0] < eps:
                loop_time = time() - loop_time
                print(f"[Hit ] Loop: {loop_time:.3f}; heuristic: {htime:.3f}; fit: {ftime:.3f}; expand: {time_sum:.3f}; push: {ptime:.3f}; pops: {num_pops};")
                return ids, len(visited_ids), search_trace, params
            _hash = tuple(ids)
            visited_ids.add(_hash)
            _final_hash = (tuple(sorted(_hash)), ids2key(ids, domain))
            if _final_hash in redundant:
                continue
            redundant.add(_final_hash)
            if len(visited_ids) > max_visited:
                break
            if dropout > 0.0 and np.random.uniform() < dropout:
                continue
            t0 = time()
            to_push = iterative_expand_ids(
                ids,
                key_expr,
                set(),
                expansions,
                domain,
                max_len,
                levels=levels,
                expansion_cache=expansion_cache,
                test_canon=test_canon,
            )
            time_sum += time() - t0
            if len(to_push) == 0:
                continue
            ht0 = time()
            new_dists = heuristic(
                to_push,
                rel_tol=h_err_tol,
                domain_id=domain_id,
                sorts=sorts,
            )
            htime += time() - ht0
            order = torch.argsort(new_dists, descending=False)
            # A negative (or None) limit keeps every child. Slicing with -1
            # would silently drop the worst one.
            if topk_children is not None and topk_children >= 0:
                order = order[:topk_children]
            child_gen = dist[1] + 1
            # TODO: penalize by number of rules applications (generation)
            penalty = (1.0 + penalty_factor) ** child_gen
            pt0 = time()
            for _id in order:
                pos = _id.item()
                new_ids = to_push[pos]
                tids = tuple(new_ids)
                if tids in visited_ids:
                    continue
                new_dist = new_dists[pos]
                if new_dist > prune_dist:
                    continue
                # Out-of-place multiply: don't mutate the heuristic's tensor.
                new_dist = new_dist * penalty
                if tids not in added:
                    heappush(pq, ((new_dist, child_gen), new_ids))
                    added.add(tids)
            ptime += time() - pt0
        if len(visited_ids) > max_visited:
            break
    loop_time = time() - loop_time
    print(f"[Miss] Loop: {loop_time:.3f}; heuristic: {htime:.3f}; fit: {ftime:.3f}; expand: {time_sum:.3f}; push: {ptime:.3f}; pops: {num_pops};")
    return None, len(visited_ids), search_trace, params

def simplify_prefix(
    model,
    key_expr,
    heuristic,
    secret_vals,
    domain,
    feat_domain,
    eps=1e-8,
    max_len=max_canon_size,
    max_visited=int(1e6),
    expansions=expansions,
    debug_secret=None,
    prune_dist=0.1,
    starting_point=None,
    topk_children=-1,
    domain_id=None,
    force_padding=None,
    allow_simplification=True,
    device="cpu",
):
    """Best-first search that can split off invertible subproblems.

    The search pops the lowest ``(distance, generation)`` entry and accepts
    it when its distance is below ``eps``.

    With ``allow_simplification``, a node whose distance is at most
    ``prune_dist`` and whose root token appears in ``inversions`` may be
    reduced to a smaller search:
    - Unary root: the target is inverted through it, and the operand is
      searched recursively.
    - Binary root where ``rolling_check_determined_expr`` flags exactly one
      side as determined: the target is inverted using that side's values,
      and the other side is searched recursively.

    The recursive search uses a fresh learned heuristic for the inverted
    target. A recursive solution is spliced back into the node and checked
    with ``verify_sol`` (``atol=1e-6``, ``rtol=1e-3``):
    - If it verifies, it is printed and returned.
    - If it fails, the node is discarded without being expanded.

    Other nodes are expanded one level, and children scoring at most
    ``prune_dist`` are queued.

    Args:
        model: Passed to ``make_learned_heuristic`` for subproblems.
        heuristic: Callable ``(guesses, domain_id=)`` returning a tensor of
            distances for the top-level problem.
        secret_vals: Target values for the top-level problem.
        starting_point: Initial expressions. Defaults to ``[[special_id]]``.
        topk_children: Number of best children kept per expansion. A negative
            value (the default) or None keeps all of them.
        max_visited: Stop once more than this many expressions were popped.
            The same budget is given to every recursive call.

    Returns:
        ``(ids, n_visited)``: ``ids`` is the solution, or None on a miss.
        ``n_visited`` also counts the visits of the recursive call that
        produced the solution.
    """
    if starting_point is None:
        starting_point = [[special_id]]

    def _invertible_split(node_ids, target):
        """Return ``(pos0, pos1, sub_target)`` for a solvable subtree, else None.

        ``node_ids[pos0:pos1 + 1]`` is the subtree left to search, and
        ``sub_target`` is the target inverted through the root token.
        """
        head = node_ids[0]
        if head not in inversions:
            return None
        if head in uops:
            # Unary root: the operand spans the rest of the expression.
            return 1, len(node_ids) - 1, inversions[head](target)
        if head in bops:
            _, checks, subtrees_pos, subtrees_evals = (
                rolling_check_determined_expr(node_ids, feat_domain)
            )
            # checks[i] appears to flag whether subtree i is fully determined.
            if checks == (False, True):
                pos0, pos1 = subtrees_pos[0]
                return pos0, pos1, inversions[head][1](target, subtrees_evals[1])
            if checks == (True, False):
                pos0, pos1 = subtrees_pos[1]
                return pos0, pos1, inversions[head][0](target, subtrees_evals[0])
        return None

    def _solve_subproblem(parent_ids, sub_ids, sub_target):
        """Recursively search ``sub_ids`` against ``sub_target``; returns ``(sol, visited)``."""
        sub_heuristic, _ = make_learned_heuristic(
            model,
            sub_target,
            domain,
            feat_domain,
            tree_vec=True,
            force_padding=force_padding,
            device=device,
        )
        # The subtree may use whatever length the rest of the parent leaves free.
        sub_max_len = max(1, max_len - (len(parent_ids) - len(sub_ids)))
        return simplify_prefix(
            model,
            key_expr,
            sub_heuristic,
            sub_target,
            domain,
            feat_domain,
            eps=eps,
            max_len=sub_max_len,
            max_visited=max_visited,
            expansions=expansions,
            debug_secret=None,
            prune_dist=prune_dist,
            starting_point=[sub_ids],
            topk_children=topk_children,
            domain_id=domain_id,
            force_padding=force_padding,
            allow_simplification=allow_simplification,
            device=device,
        )

    init_dists = heuristic(starting_point, domain_id=domain_id)
    # Each subproblem is (ids, target values, heuristic). Heap entries are
    # ((distance, generation), index into subproblems).
    subproblems = [
        (ids, secret_vals, heuristic) for ids, _ in zip(starting_point, init_dists)
    ]
    pq = [
        ((_dist, 0), subpos)
        for subpos, (_ids, _dist) in enumerate(zip(starting_point, init_dists))
    ]
    heapify(pq)
    visited_ids = set()
    redundant = set()
    added = set()
    while len(pq) > 0:
        dist, subpos = heappop(pq)
        ids, curr_secret_vals, curr_heuristic = subproblems[subpos]
        _hash = tuple(ids)
        visited_ids.add(_hash)
        _final_hash = (tuple(sorted(_hash)), ids2key(ids, domain))
        if _final_hash in redundant:
            continue
        redundant.add(_final_hash)
        if len(visited_ids) > max_visited:
            break
        if debug_secret is not None:
            import pdb
            pdb.set_trace()
        if dist[0] < eps:
            return ids, len(visited_ids)
        if allow_simplification and dist[0] <= prune_dist:
            split = _invertible_split(ids, curr_secret_vals)
            if split is not None:
                pos0, pos1, subproblem_eval = split
                sub_sol, sub_visited = _solve_subproblem(
                    ids, ids[pos0 : (pos1 + 1)], subproblem_eval
                )
                if sub_sol is not None:
                    sol = ids[:pos0] + sub_sol + ids[(pos1 + 1):]
                    verif = verify_sol(
                        sol,
                        feat_domain,
                        curr_secret_vals,
                        atol=1e-6,
                        rtol=1e-3,
                    )
                    if not verif:
                        # An unverifiable splice discards this node entirely.
                        continue
                    print(f"Seed: {to_infix(ids)}; New: {to_infix(sol)}")
                    return sol, len(visited_ids) + sub_visited
        to_push = iterative_expand_ids(
            ids,
            key_expr,
            set(),
            expansions,
            domain,
            max_len,
            levels=1,
            expansion_cache=None,
            test_canon=False,
        )
        if len(to_push) == 0:
            continue
        new_dists = curr_heuristic(to_push, domain_id=domain_id)
        order = torch.argsort(new_dists, descending=False)
        # A negative (or None) limit keeps every child. The old [: -1] slice
        # silently dropped the worst child and starved single-child nodes.
        if topk_children is not None and topk_children >= 0:
            order = order[:topk_children]
        for _id in order:
            pos = _id.item()
            new_ids = to_push[pos]
            tids = tuple(new_ids)
            if tids in visited_ids:
                continue
            new_dist = new_dists[pos]
            if new_dist > prune_dist:
                continue
            if tids not in added:
                subproblems.append((new_ids, curr_secret_vals, curr_heuristic))
                heappush(pq, ((new_dist, dist[1] + 1), len(subproblems) - 1))
                added.add(tids)
    return None, len(visited_ids)
