import torch
from termcolor import colored, cprint  # noqa:F401

# Import-time side effect: widen torch's tensor repr for the whole process so
# that printed tensors show up to 30 items at each edge of a dimension before
# being summarized, and lines wrap at 256 characters.
torch.set_printoptions(edgeitems=30, linewidth=256)


def print_row(row, color=None, w=16, title="...", title_w=16, false_value=None, true_value=None, mask=None, mask_false_color="dark_grey"):
    """Print one labelled, column-aligned row of values to stdout.

    The row is written as ``"<title>:"`` left-aligned in ``title_w`` columns,
    followed by every cell left-aligned in ``w`` columns. Each field is
    followed by a single space and the row ends with a newline.

    Args:
        row: A ``torch.Tensor`` or any iterable of cell values. Tensors are
            converted to a Python list; bfloat16/float16/float32 tensors are
            first cast to float32 and rounded to 4 decimals.
        color: termcolor color name used for the title and the cells. When
            ``None`` everything is printed with plain ``print`` and ``mask``
            is ignored.
        w: Width of each cell. String cells are truncated to ``w`` characters;
            other values are padded but never truncated.
        title: Row label; a ``":"`` is appended before padding.
        title_w: Width of the label column.
        false_value: When not ``None``, printed (uncolored) in place of every
            falsy cell.
        true_value: When not ``None``, printed (uncolored) in place of every
            truthy cell.
        mask: Optional sequence indexed by cell position. Cells whose mask
            entry is falsy are printed in ``mask_false_color`` instead of
            ``color``. Only consulted when ``color`` is not ``None``.
        mask_false_color: termcolor color name for masked-out cells.
    """
    if isinstance(row, torch.Tensor):
        if row.dtype in [torch.bfloat16, torch.float32, torch.float16]:
            # Tensor.numpy() refuses tensors that require grad or live on a
            # non-CPU device, so detach and move to CPU first. Both calls are
            # no-ops for plain CPU tensors. The float32 cast is needed because
            # NumPy has no bfloat16 dtype.
            row = row.detach().cpu().to(torch.float32).numpy().round(4)

        row = row.tolist()

    label = f"{title + ':':<{title_w}}"
    if color is not None:
        cprint(label, color, end=' ')
    else:
        print(label, end=' ')

    for j, x in enumerate(row):
        # Placeholder substitution takes precedence over any other formatting
        # and is always printed without color.
        if false_value is not None and not x:
            print(f"{false_value:<{w}}", end=' ')
            continue
        if true_value is not None and x:
            print(f"{true_value:<{w}}", end=' ')
            continue

        if isinstance(x, float):
            # float32 values widened to Python floats pick up trailing noise
            # (e.g. 0.12349999...), so round again for display.
            x = round(x, 4)
        if isinstance(x, str):
            x = x[:w]

        cell = f"{x:<{w}}"
        if color is None:
            print(cell, end=' ')
            continue

        if mask is not None and not mask[j]:
            cprint(cell, mask_false_color, end=' ')
            continue
        cprint(cell, color, end=' ')

    print()


def vis_cycle(cycle, tokenizer):
    """Print a colored, column-aligned report for a single ``cycle`` record.

    Args:
        cycle: Mapping read under the keys ``draft_tokens``, ``target_argmax``,
            ``draft_probas``, ``target_probas``, ``target_max``,
            ``adjusted_probas`` and ``accept_mask``. ``draft_tokens`` is
            compared element-wise against ``target_argmax[:-1]``, so
            ``target_argmax`` presumably carries one more entry than
            ``draft_tokens`` -- TODO confirm with the producer of ``cycle``.
            ``accept_mask`` must support ``cumprod`` and ``~``.
        tokenizer: Object exposing ``convert_ids_to_tokens``.

    Rows printed, in order: decoded draft tokens (yellow where they differ
    from the target argmax), decoded target tokens, the three probability
    rows, the adjusted probabilities (red where ``accept_mask`` is false),
    the raw accept mask, and the running acceptance (cumulative product of
    the mask). A ``false_positive`` row is added when a draft token differs
    from the target argmax yet falls inside the running acceptance, and a
    ``false_negative`` row when a draft token matches the target argmax but
    its ``accept_mask`` entry is false. A blank line closes the report.
    """
    drafted = cycle['draft_tokens']
    target_best = cycle['target_argmax']

    drafted_text = tokenizer.convert_ids_to_tokens(drafted)
    target_text = tokenizer.convert_ids_to_tokens(target_best)

    # Token rows: highlight draft tokens that disagree with the target argmax.
    print_row(
        drafted_text,
        "cyan",
        title="draft_tokens",
        mask=drafted == target_best[:-1],
        mask_false_color="yellow",
    )
    print_row(target_text, "blue", title="target_tokens")

    # Probability rows; each title is identical to the key it is read from.
    for key, shade in (
        ("draft_probas", "cyan"),
        ("target_probas", "blue"),
        ("target_max", "dark_grey"),
    ):
        print_row(cycle[key], shade, title=key)

    adjusted = cycle['adjusted_probas']
    accept_mask = cycle['accept_mask']
    print_row(adjusted, "green", title="adj_probas", mask=accept_mask, mask_false_color="red")
    print_row(accept_mask, "dark_grey", title="accept_mask")

    # The cumulative product stays non-zero only up to the first rejected slot.
    accepted_prefix = accept_mask.cumprod(dim=-1)
    print_row(accepted_prefix, None, title="acceptance", false_value="--", true_value="accepted")

    wrongly_accepted = torch.logical_and(drafted != target_best[:-1], accepted_prefix)
    if wrongly_accepted.any():
        # Scaling by a marker number makes flagged cells print as 1111.
        print_row(wrongly_accepted * 1111, "magenta", title="false_positive", false_value="")

    wrongly_rejected = torch.logical_and(drafted == target_best[:-1], ~accept_mask)
    if wrongly_rejected.any():
        print_row(wrongly_rejected * 2222, 'red', title="false_negative", false_value="")

    print()
