import torch
from params import PARAMS
from config import CONFIGS
from args import DEVICE3, args

# Device passed to the tensor constructors in this module (some index tensors
# below are created with torch's default device instead).
DEVICE = DEVICE3

# Base step count read from the experiment parameters.
P = PARAMS["post_steps"]
# Trajectory length used for the OTHER entries: P * eval_ratio + P - 1 rows.
L = P * args.eval_ratio + P - 1
# Number of values in one command row.
STATE_DIM = 22
# Length of the long commands: 100 blocks of P rows.
LL = 100 * P

# All-zero 1-D reference state with STATE_DIM entries.
INITIAL_STATE = torch.zeros(STATE_DIM, dtype=torch.float32, device=DEVICE)


def zeros(s):
    """Return a float32 tensor of shape ``s`` on DEVICE, filled with 0.0."""
    return torch.zeros(s, device=DEVICE, dtype=torch.float32)


def compose(*tensors: torch.Tensor, dim=-1):
    """Concatenate two or more 2-D tensors along ``dim``.

    Args:
        *tensors: at least two tensors, each of rank 2.  Their shapes only
            need to agree on the non-concatenated dimension (as required by
            ``torch.cat``).
        dim: dimension to concatenate along; defaults to the last one.

    Returns:
        The result of ``torch.cat(tensors, dim=dim)``.

    Raises:
        AssertionError: if fewer than two tensors are given or any tensor is
            not 2-D.
    """
    assert len(tensors) >= 2, "compose() needs at least two tensors"
    assert all(t.dim() == 2 for t in tensors), "compose() only accepts 2-D tensors"
    return torch.cat(tensors, dim=dim)


def ones(s):
    """Return a float32 tensor of shape ``s`` on DEVICE, filled with 1.0."""
    return torch.ones(s, device=DEVICE, dtype=torch.float32)


def mask(s):
    """Return a float32 placeholder tensor of shape ``s`` on DEVICE, every entry -999.0.

    These rows are used as placeholders inside the command trajectories; the
    index tensors paired with those trajectories generally skip them.
    """
    filled = torch.ones(s, device=DEVICE, dtype=torch.float32)
    return filled.mul(-999)


# Auxiliary trajectories of L rows; currently a single all-zero entry.
OTHER = {}
OTHER["stay-no-mask"] = zeros((L, STATE_DIM))


def THUMB(pos):  # noqa: N802 - upper-case name kept for existing callers
    """Return the "thumb" pose as a 1-D float32 tensor of 22 values on DEVICE.

    Args:
        pos: value stored in the last entry (index 21); every other entry is a
            fixed constant.

    Returns:
        Tensor of shape (22,).  Callers in this module add a leading axis with
        ``.unsqueeze(0)`` before stacking rows.
    """
    # Laid out 4/4/4/5/5 per line (indices 0-3, 4-7, 8-11, 12-16, 17-21)
    # for readability only.
    return torch.tensor(
        [
            0, 1.37, 1.57, 1.57,
            0, 1.37, 1.57, 1.57,
            0, 1.37, 1.57, 1.57,
            0, 0, 1.33, 1.57, 1.57,
            0, 0, -0.209, -0.698, pos,
        ],
        dtype=torch.float32,
        device=DEVICE,
    )
# Static pose named "six": one row of shape (1, 22).
# Values are laid out 4/4/4/5/5 per line (indices 0-3, 4-7, 8-11, 12-16,
# 17-21) for readability only.
SIX = torch.tensor(
    [
        [
            0, 1.37, 1.57, 1.57,
            0, 1.37, 1.57, 1.57,
            0, 1.37, 1.57, 1.57,
            0, -0.349, 0, 0, 0,
            0, 0, -0.209, -0.698, 0,
        ]
    ],
    dtype=torch.float32,
    device=DEVICE,
)

# Static pose named "fire": a 1-D tensor of 22 values (this module's callers
# add the row axis themselves with ``FIRE.unsqueeze(0)``).
# Values are laid out 4/4/4/5/5 per line for readability only.
FIRE = torch.tensor(
    [
        0, 0, 0, 0,
        0, 1.37, 1.57, 1.57,  # index 5 (1.37) is the "fire finger" entry
        0, 1.37, 1.57, 1.57,
        # NOTE(review): index 14 is 1.37 here but 1.33 in FIRE_1st_SMOOTH /
        # FIRE_2nd_SMOOTH, so the dynamic_fire sequences jump at that entry
        # right after this row - confirm this is intended.
        0, 0, 1.37, 1.57, 1.57,
        0, 0, 0, -0.698, 0,
    ],
    dtype=torch.float32,
    device=DEVICE,
)

def FIRE_2nd_SMOOTH(x):  # noqa: N802 - upper-case name kept for existing callers
    """Return one (1, 22) "fire" row whose entry 5 is interpolated by ``x``.

    Args:
        x: interpolation weight, a Python float or a 0-d tensor (callers in
            this module pass elements of ``torch.linspace``).  Entry 5 is
            ``0.23 + x * (1.37 - 0.23)``, i.e. about 0.23 at x=0 and 1.37 at
            x=1; every other entry is fixed.

    Returns:
        Float32 tensor of shape (1, 22) on DEVICE.
    """
    # Laid out 4/4/4/5/5 per line for readability only.
    return torch.tensor(
        [
            0, 0, 0, 0,
            0, 0.23 + x * (1.37 - 0.23), 1.57, 1.57,  # index 5: fire finger
            0, 1.37, 1.57, 1.57,
            0, 0, 1.33, 1.57, 1.57,
            0, 0, 0, -0.698, 0,
        ],
        dtype=torch.float32,
        device=DEVICE,
    ).unsqueeze(0)

def FIRE_1st_SMOOTH(x, i, mask_step=1):  # noqa: N802 - upper-case name kept
    """Return one (1, 22) row of the "fire" sweep that moves entry 20.

    Args:
        x: interpolation weight, a Python float or a 0-d tensor.  Entry 20 is
            ``-0.698 + (0.698 + 0.698) * x``, i.e. about -0.698 at x=0 and
            0.698 at x=1; every other entry is fixed.
        i: position of this row within its sweep.
        mask_step: only rows with ``i % mask_step == 0`` carry the pose; any
            other row is replaced by a -999 mask row.  The default of 1 keeps
            every row.

    Returns:
        Float32 tensor of shape (1, 22) on DEVICE (either the pose row or a
        mask row).
    """
    if i % mask_step != 0:
        # Thinned-out position: emit a placeholder row instead of the pose.
        return mask(STATE_DIM).unsqueeze(0)
    # Laid out 4/4/4/5/5 per line for readability only.
    return torch.tensor(
        [
            0, 0, 0, 0,
            0, 1.37, 1.57, 1.57,
            0, 1.37, 1.57, 1.57,
            0, 0, 1.33, 1.57, 1.57,
            0, 0, 0, -0.698 + (0.698 + 0.698) * x, 0,  # index 20: fire finger
        ],
        dtype=torch.float32,
        device=DEVICE,
    ).unsqueeze(0)


# Static pose named "rock": one row of shape (1, 22).
# Values are laid out 4/4/4/5/5 per line for readability only.
ROCK = torch.tensor(
    [
        [
            0, 1.37, 1.57, 1.57,
            0, 1.37, 1.57, 1.57,
            0, 1.37, 1.57, 1.57,
            0.00712, 0, 1.33, 1.57, 1.57,
            0, 1.22, 0.209, 0.6987, 1.1,
        ]
    ],
    dtype=torch.float32,
    device=DEVICE,
)

# Static pose named "scissor": one row of shape (1, 22).
# Values are laid out 4/4/4/5/5 per line for readability only.
SCISSOR = torch.tensor(
    [
        [
            -0.349, -0.262, 0, 0,
            0.349, -0.262, 0, 0,
            0, 1.37, 1.57, 1.57,
            0, 0, 1.33, 1.57, 1.57,
            0, 1.22, 0.209, 0.6987, 0.783,
        ]
    ],
    dtype=torch.float32,
    device=DEVICE,
)


# Static pose named "paper": one row of shape (1, 22).
# Values are laid out 4/4/4/5/5 per line for readability only.
PAPER = torch.tensor(
    [
        [
            -0.349, -0.262, 0, 0,
            -0.178, -0.262, 0, 0,
            -0.0524, -0.262, 0, 0,
            0.00712, -0.349, -0.262, 0, 0,
            -1.05, 0, -0.209, 0.105, -0.262,
        ]
    ],
    dtype=torch.float32,
    device=DEVICE,
)


# First pose of the "circles" sequence: one row of shape (1, 22).
# Values are laid out 4/4/4/5/5 per line for readability only.
CIRCLE_1st = torch.tensor(
    [
        [
            0, 0.82, 1.03, 1.3,
            0, 0, 0, 0,
            0, 0, 0, 0,
            0, 0, 0, 0, 0,
            -0.0105, 0.831, 0, 0.161, 1.45,
        ]
    ],
    dtype=torch.float32,
    device=DEVICE,
)

# Second pose of the "circles" sequence: one row of shape (1, 22).
# Values are laid out 4/4/4/5/5 per line for readability only.
CIRCLE_2nd = torch.tensor(
    [
        [
            0, 0.82, 1.03, 1.3,
            0, 0.82, 1.03, 1.3,
            # NOTE(review): index 11 is 0.13 while the matching entry in the
            # other groups is 1.3 - possibly a typo; confirm before changing.
            0, 0.82, 1.03, 0.13,
            0, 0, 0.82, 1.03, 1.3,
            0.303, 1.22, 0, 0.161, 1.45,
        ]
    ],
    dtype=torch.float32,
    device=DEVICE,
)

# Third pose of the "circles" sequence: one row of shape (1, 22).
# Values are laid out 4/4/4/5/5 per line for readability only.
CIRCLE_3rd = torch.tensor(
    [
        [
            0, 0, 0, 0,
            0, 0.00396, 0, 0,
            0.349, 0.801, 1.03, 1.3,
            0, 0, 0.82, 1.03, 1.3,
            0.429, 1.22, 0.209, 0.328, 1.05,
        ]
    ],
    dtype=torch.float32,
    device=DEVICE,
)

def _build_long():
    """Assemble the LONG command table.

    Every value is either a bare trajectory tensor of shape (T, STATE_DIM)
    or a pair ``(trajectory, index)``.  Trajectories use ``mask`` rows
    (every entry -999) as placeholders.  ``index`` is a 1-D int64 tensor of
    row positions into the trajectory.  For the segmented and delayed-hold
    entries it covers exactly the non-mask rows.  The dynamic_fire entries
    keep their fixed cut-off at row 200 even though their trajectories are
    longer.
    """

    def masks(n):
        # n placeholder rows, every entry equal to the -999 mask value.
        return mask((n, STATE_DIM))

    def hold(pose, n):
        # Tile a (1, STATE_DIM) pose row n times along the time axis.
        return pose.repeat(n, 1)

    def segmented(segments):
        # segments: sequence of (n_mask, pose, n_pose).  Each segment emits
        # n_mask placeholder rows followed by n_pose copies of pose.  Row
        # offsets are accumulated as integers, so the index always matches
        # the rows that were actually emitted.  The index lives on DEVICE.
        pieces, spans, offset = [], [], 0
        for n_mask, pose, n_pose in segments:
            offset += n_mask
            pieces += [masks(n_mask), hold(pose, n_pose)]
            spans.append(
                torch.arange(
                    offset, offset + n_pose, dtype=torch.int64, device=DEVICE
                )
            )
            offset += n_pose
        return compose(*pieces, dim=0), torch.cat(spans, dim=0)

    def delayed_hold(pose, lead):
        # `lead` placeholder rows, then `pose` held through row LL - 1.  The
        # index starts where the pose starts.
        return (
            compose(masks(lead), hold(pose, LL - lead), dim=0),
            torch.arange(lead, LL),
        )

    def fire_sweeps(make_row, n_sweeps, start_high):
        # 19 placeholder rows, one FIRE row, then n_sweeps 20-point sweeps of
        # x with alternating direction.  make_row(x, i) receives the position
        # i within the current sweep (restarting at 0 for every sweep).
        directions = ((1, 0), (0, 1)) if start_high else ((0, 1), (1, 0))
        rows = [masks(19), FIRE.unsqueeze(0) - INITIAL_STATE]
        for k in range(n_sweeps):
            start, end = directions[k % 2]
            rows.extend(
                make_row(x, i)
                for i, x in enumerate(torch.linspace(start, end, 20))
            )
        return compose(*rows, dim=0)

    # Integer length of each short segment in "rsp-full-mask".  The old table
    # used int(P / 2) for the rows but the float P / 2 for the index bounds.
    # For odd P that put an extra index on a mask row and shifted every later
    # span.
    half = P // 2

    return {
        "rock_no_mask": hold(ROCK, LL),
        "rsp": segmented(
            [(10, ROCK, 3 * P), (10, SCISSOR, 3 * P), (10, PAPER, LL)]
        ),
        "rsp-full-mask": segmented(
            [
                (10, ROCK, half),
                (10, SCISSOR, half),
                (10, PAPER, half),
                (10, ROCK, half),
                (10, SCISSOR, half),
                (10, PAPER, LL),
            ]
        ),
        "circles": segmented(
            [(10, CIRCLE_1st, 3 * P), (10, CIRCLE_2nd, 3 * P), (10, CIRCLE_3rd, LL)]
        ),
        "stay": zeros((LL, STATE_DIM)).cumsum(dim=0),
        "thumb": delayed_hold(THUMB(0).unsqueeze(0) - INITIAL_STATE, 19),
        "six": delayed_hold(SIX - INITIAL_STATE, 19),
        # The pose starts at row 9 here.  The index previously began at 19
        # (copied from "six") and skipped ten real pose rows.
        "six_early": delayed_hold(SIX - INITIAL_STATE, 9),
        "fire": delayed_hold(FIRE.unsqueeze(0) - INITIAL_STATE, 19),
        "dynamic_fire_2nd": (
            fire_sweeps(lambda x, i: FIRE_2nd_SMOOTH(x), 10, start_high=True),
            torch.arange(19, 200),
        ),
        "dynamic_fire_1st": (
            fire_sweeps(lambda x, i: FIRE_1st_SMOOTH(x, i), 12, start_high=False),
            torch.arange(19, 200),
        ),
        "dynamic_fire_1st-full-mask": (
            fire_sweeps(
                lambda x, i: FIRE_1st_SMOOTH(x, i, 2), 12, start_high=False
            ),
            # Only even positions carry the pose (mask_step=2, and every sweep
            # has an even length of 20).
            torch.cat((torch.arange(19, 20), torch.arange(20, 200, 2)), dim=0),
        ),
    }


LONG = _build_long()

def _build_commands():
    """Assemble the COMMANDS table.

    Keys start with the OTHER entries (each cumulatively summed over dim 0).
    They are followed by per-pose variants:

    * ``*-no-mask``: the pose row repeated for LL rows (a bare tensor).
    * ``*-half-mask``: 19 placeholder rows, then the pose through row LL - 1;
      paired with the index ``arange(19, LL)``.
    * ``*-full-mask``: placeholders interleaved with the pose; paired with an
      index of the pose rows (19, 39, 49-59 and 69..LL-1).
    * ``dynamic_fire_*-half-mask``: 19 placeholders, one FIRE row, then
      20-point sweeps of a smooth FIRE variant; paired with ``arange(19, 200)``.
    """

    def masks(n):
        # n placeholder rows, every entry equal to the -999 mask value.
        return mask((n, STATE_DIM))

    def hold(pose, n):
        # Tile a (1, STATE_DIM) pose row n times along the time axis.
        return pose.repeat(n, 1)

    def no_mask(pose):
        return hold(pose, LL)

    def half_mask(pose):
        return (
            compose(masks(19), hold(pose, LL - 19), dim=0),
            torch.arange(19, LL),
        )

    def full_mask(pose):
        trajectory = compose(
            masks(19),
            pose,  # row 19
            masks(19),
            pose,  # row 39
            masks(9),
            hold(pose, 11),  # rows 49-59
            masks(9),
            hold(pose, LL - 69),  # rows 69..LL-1
            dim=0,
        )
        index = torch.cat(
            (
                torch.arange(19, 20),
                torch.arange(39, 40),
                torch.arange(49, 60),
                torch.arange(69, LL),
            ),
            dim=0,
        )
        return trajectory, index

    def fire_sweeps(make_row, n_sweeps, start_high):
        # Sweeps alternate direction; make_row(x, i) gets the position i
        # within the current 20-point sweep.
        directions = ((1, 0), (0, 1)) if start_high else ((0, 1), (1, 0))
        rows = [masks(19), FIRE.unsqueeze(0) - INITIAL_STATE]
        for k in range(n_sweeps):
            start, end = directions[k % 2]
            rows.extend(
                make_row(x, i)
                for i, x in enumerate(torch.linspace(start, end, 20))
            )
        return compose(*rows, dim=0), torch.arange(19, 200)

    thumb_row = THUMB(0).unsqueeze(0)
    fire_row = FIRE.unsqueeze(0)

    table = {name: traj.cumsum(dim=0) for name, traj in OTHER.items()}
    table["thumb-half-mask"] = half_mask(thumb_row - INITIAL_STATE)
    table["thumb-no-mask"] = no_mask(thumb_row)
    table["thumb-full-mask"] = full_mask(thumb_row)
    table["dynamic_fire_2nd-half-mask"] = fire_sweeps(
        lambda x, i: FIRE_2nd_SMOOTH(x), 10, start_high=True
    )
    table["dynamic_fire_1st-half-mask"] = fire_sweeps(
        lambda x, i: FIRE_1st_SMOOTH(x, i), 12, start_high=False
    )
    table["fire-half-mask"] = half_mask(fire_row - INITIAL_STATE)
    table["fire-full-mask"] = full_mask(fire_row)
    table["fire-no-mask"] = no_mask(fire_row - INITIAL_STATE)
    table["six-half-mask"] = half_mask(SIX - INITIAL_STATE)
    table["six-full-mask"] = full_mask(SIX)
    table["six-no-mask"] = no_mask(SIX - INITIAL_STATE)
    table["rock-half-mask"] = half_mask(ROCK)
    table["rock-full-mask"] = full_mask(ROCK)
    table["rock-no-mask"] = no_mask(ROCK)
    table["scissor-half-mask"] = half_mask(SCISSOR)
    table["scissor-no-mask"] = no_mask(SCISSOR)
    table["scissor-full-mask"] = full_mask(SCISSOR)
    return table


COMMANDS = _build_commands()
