import torch
from htssr.sampling import (
    generate_all_pairs_order_samples,
    generate_all_pairs_heldout_samples,
    generate_all_pairs_free_samples,
    generate_all_pairs_free_heldout_samples,
)
from htssr.canon import max_canon_size
from htssr.grammar import expansions
from htssr.utils import unroll
from htssr.numerical import make_vec


def make_search_batch(
    exception_rolls,
    feat_domain,
    dummy=False,
    device="cpu",
):
    """Featurize every id contained in ``exception_rolls`` as one float tensor.

    The rolls are flattened with ``unroll`` and then vectorized with
    ``make_vec`` over ``feat_domain``. ``make_vec`` must return a NumPy
    array, because the result goes through ``torch.from_numpy``.

    Args:
        exception_rolls: rolls to flatten; their structure is whatever
            ``unroll`` accepts.
        feat_domain: feature domain forwarded to ``make_vec``.
        dummy: forwarded unchanged to ``make_vec``.
        device: torch device the returned tensor is moved to.

    Returns:
        A float32 tensor on ``device`` holding ``make_vec``'s output.
    """
    flat_ids = unroll(exception_rolls)
    vec_array = make_vec(flat_ids, feat_domain, dummy=dummy)
    return torch.from_numpy(vec_array).float().to(device)

def make_heldout_batch(
    exception_rolls,
    feat_domain,
    dummy=False,
    device="cpu",
):
    """Build an (inputs, targets) pair from the all-pairs held-out samples.

    Args:
        exception_rolls: forwarded to ``generate_all_pairs_heldout_samples``.
        feat_domain: feature domain forwarded to the sampler.
        dummy: forwarded unchanged to the sampler.
        device: torch device both returned tensors are moved to.

    Returns:
        ``(batch_x, batch_y)``, where:

        - ``batch_x`` is the sampler's ``"vecs"`` array as a float32 tensor.
        - ``batch_y`` is its ``"order"`` array as an int64 tensor.
    """
    heldout = generate_all_pairs_heldout_samples(
        exception_rolls, feat_domain, dummy=dummy
    )
    # Both entries must be NumPy arrays (required by torch.from_numpy).
    inputs = torch.from_numpy(heldout["vecs"]).float().to(device)
    targets = torch.from_numpy(heldout["order"]).long().to(device)
    return inputs, targets

def make_free_heldout_batch(
    exception_rolls,
    feat_domain,
    force_padding=None,
    device="cpu",
):
    """Build a "free" held-out batch (ids, parents, vecs, order).

    Samples come from ``generate_all_pairs_free_heldout_samples``. They are
    then tensorized by ``free_batch_from_samples``, with the same
    ``force_padding`` value passed to both steps.

    Args:
        exception_rolls: forwarded to the held-out sampler.
        feat_domain: feature domain forwarded to the sampler.
        force_padding: optional minimum padded sequence length, used by
            both the sampler and the id padding.
        device: torch device the returned tensors are moved to.

    Returns:
        The 4-tuple produced by ``free_batch_from_samples``.
    """
    heldout = generate_all_pairs_free_heldout_samples(
        exception_rolls, feat_domain, force_padding=force_padding
    )
    return free_batch_from_samples(
        heldout, force_padding=force_padding, device=device
    )

def make_all_pairs_order_batch(
    ids_list,
    key_expr,
    ptrs,
    exceptions,
    bsize,
    domain,
    feat_domain,
    device="cpu",
    rollout=True,
    dummy=False,
    min_pos=1,
    max_pos=max_canon_size,
    root_src=False,
    noisy_fast=False,
    expansions=expansions,
    noise_level=0,
    test_canon=True,
):
    """Sample an all-pairs ordering batch and convert it to tensors.

    The positional arguments and every keyword option except ``device`` are
    passed straight to ``generate_all_pairs_order_samples``. Only
    ``device`` is consumed here.

    Returns:
        ``(batch_x, batch_y)``, where:

        - ``batch_x`` is the samples' ``"vecs"`` array as a float32 tensor.
        - ``batch_y`` is the samples' ``"order"`` array as an int64 tensor.

        Both tensors are placed on ``device``.
    """
    sampler_options = dict(
        rollout=rollout,
        dummy=dummy,
        min_pos=min_pos,
        max_pos=max_pos,
        root_src=root_src,
        noisy_fast=noisy_fast,
        expansions=expansions,
        noise_level=noise_level,
        test_canon=test_canon,
    )
    samples = generate_all_pairs_order_samples(
        ids_list,
        key_expr,
        ptrs,
        exceptions,
        bsize,
        domain,
        feat_domain,
        **sampler_options,
    )
    inputs = torch.from_numpy(samples["vecs"]).float().to(device)
    targets = torch.from_numpy(samples["order"]).long().to(device)
    return inputs, targets

def make_ids_batch(all_ids, force_padding=None):
    """Pad variable-length id sequences into one 2-D int64 tensor.

    Each sequence is right-padded with ``-1`` to the longest sequence
    length, or to ``force_padding`` if that is larger. Every entry is then
    shifted by ``+1``, so an original id ``i`` is stored as ``i + 1`` and
    padding positions become ``0``.

    Args:
        all_ids: iterable of integer sequences. Lists, tuples, 1-D NumPy
            arrays and 1-D tensors are all accepted.
        force_padding: optional minimum padded length. It never truncates.

    Returns:
        A CPU ``torch.long`` tensor of shape ``(len(all_ids), max_seq_len)``.
        Empty input yields shape ``(0, force_padding or 0)``.
    """
    # Normalize each sequence to a plain list of ints. With array inputs,
    # ``seq + [-1] * k`` would broadcast an element-wise addition instead
    # of concatenating the padding.
    rows = [[int(tok) for tok in _ids] for _ids in all_ids]
    # ``default=0`` avoids max() raising ValueError on an empty batch.
    max_seq_len = max((len(row) for row in rows), default=0)
    if force_padding is not None:
        max_seq_len = max(max_seq_len, force_padding)
    if not rows:
        return torch.empty((0, max_seq_len), dtype=torch.long)
    padded_ids = [row + [-1] * (max_seq_len - len(row)) for row in rows]
    # Shift by one so the -1 padding becomes 0 and real ids are >= 1.
    return 1 + torch.tensor(padded_ids, dtype=torch.long)

def make_all_pairs_free_batch(
    exceptions,
    size,
    feat_domain,
    min_pos=1,
    max_pos=max_canon_size,
    root_src=False,
    noisy_fast=False,
    expansions=expansions,
    force_padding=None,
    test_canon=False,
    domain=None,
    key_expr=None,
    device="cpu",
):
    """Sample a "free" all-pairs batch and tensorize it.

    All arguments except ``device`` are forwarded to
    ``generate_all_pairs_free_samples``. ``force_padding`` is also reused
    when padding the id and parent sequences.

    Returns:
        The 4-tuple ``(batch_ids, batch_parents, batch_vals, batch_y)``
        produced by ``free_batch_from_samples``.
    """
    sampler_options = dict(
        min_pos=min_pos,
        max_pos=max_pos,
        root_src=root_src,
        noisy_fast=noisy_fast,
        expansions=expansions,
        force_padding=force_padding,
        test_canon=test_canon,
        domain=domain,
        key_expr=key_expr,
    )
    free_samples = generate_all_pairs_free_samples(
        exceptions, size, feat_domain, **sampler_options
    )
    return free_batch_from_samples(
        free_samples, force_padding=force_padding, device=device
    )

def free_batch_from_samples(samples, force_padding=None, device="cpu"):
    """Convert a free-sample dict into tensors on ``device``.

    Args:
        samples: mapping with keys ``"ids"``, ``"parents"``, ``"vecs"`` and
            ``"order"``. ``"ids"`` and ``"parents"`` are sequences of id
            sequences. ``"vecs"`` and ``"order"`` must be NumPy arrays
            (required by ``torch.from_numpy``). Presumably ``"parents"`` is
            aligned position-by-position with ``"ids"`` (TODO confirm with
            the sampler).
        force_padding: optional minimum padded length for the id and parent
            sequences.
        device: torch device all returned tensors are moved to.

    Returns:
        ``(batch_ids, batch_parents, batch_vals, batch_y)``, where:

        - The first two come from ``make_ids_batch``: padded with 0, and
          real ids stored as ``id + 1``.
        - ``batch_vals`` is a float32 tensor.
        - ``batch_y`` is an int64 tensor.
    """
    ids_tensor, parents_tensor = (
        make_ids_batch(samples[key], force_padding=force_padding).to(device)
        for key in ("ids", "parents")
    )
    vals_tensor = torch.from_numpy(samples["vecs"]).float().to(device)
    order_tensor = torch.from_numpy(samples["order"]).long().to(device)
    return ids_tensor, parents_tensor, vals_tensor, order_tensor
