import torch
from params import PARAMS
from config import CONFIGS
from args import DEVICE3, args
import math

# Device on which every tensor created in this module is allocated.
DEVICE = DEVICE3

# Number of columns of every table built in this module.
STATE_DIM = 5

# Read the step count once; the sizes below are all derived from it.
P = PARAMS["post_steps"]

# Row count of the tables stored in OTHER.
L = P * args.eval_ratio + P - 1

# Row count of the LONG "stay" table and of the ORIGIN tail of the hold tables.
LL = 100 * P

INITIAL_STATE = torch.zeros((STATE_DIM,), dtype=torch.float32, device=DEVICE)


def zeros(s):
    """Return a float32 tensor of shape ``s`` on ``DEVICE`` filled with zeros."""
    tensor_opts = {"dtype": torch.float32, "device": DEVICE}
    return torch.zeros(s, **tensor_opts)


def compose(*tensors: torch.Tensor, dim=-1) -> torch.Tensor:
    """Concatenate two or more 2-D tensors along ``dim``.

    Args:
        *tensors: the tensors to join; at least two are required and each
            must have exactly two dimensions.
        dim: dimension passed to ``torch.cat`` (defaults to the last one).

    Returns:
        The result of ``torch.cat(tensors, dim=dim)``.

    Raises:
        AssertionError: if fewer than two tensors are given or any of them
            is not 2-D.
    """
    assert len(tensors) >= 2, "compose() needs at least two tensors"
    # Shapes are deliberately not required to match exactly; torch.cat itself
    # rejects tensors whose non-concatenated dimensions disagree.
    assert all(t.dim() == 2 for t in tensors), "compose() only accepts 2-D tensors"
    return torch.cat(tensors, dim=dim)


def ones(s):
    """Return a float32 tensor of shape ``s`` on ``DEVICE`` filled with ones."""
    tensor_opts = {"dtype": torch.float32, "device": DEVICE}
    return torch.ones(s, **tensor_opts)


def mask(s):
    """Return a float32 tensor of shape ``s`` on ``DEVICE`` where every entry is -999.

    NOTE(review): -999 is presumably a sentinel marking rows to be ignored;
    confirm with the code that consumes these tables.
    """
    sentinel = -999
    return sentinel * torch.ones(s, device=DEVICE, dtype=torch.float32)


# Tables that COMMANDS stores after a cumulative sum over the row axis.
# "stay-no-mask" is an all-zero table with L rows and STATE_DIM columns.
OTHER = {"stay-no-mask": zeros((L, STATE_DIM))}


def _pose(values):
    """Return ``values`` as a float32 tensor of shape ``(1, len(values))`` on ``DEVICE``."""
    return torch.tensor(values, dtype=torch.float32, device=DEVICE).unsqueeze(0)


# Single-row poses. 1.57 is pi / 2 rounded to two decimals.
REACHER_RIGHT = _pose([0, 0, 1.57, 0, 0])
REACHER_LEFT = _pose([0, 0, -1.57, 0, 0])

# 320 int64 entries: a block of 20 ones followed by 20 zeros, tiled 8 times.
REACHER_MASK = torch.cat(
    (
        torch.ones((20,), dtype=torch.long, device=DEVICE),
        torch.zeros((20,), dtype=torch.long, device=DEVICE),
    ),
    dim=0,
).repeat(8)

FRONT = _pose([torch.pi / 4, 0, 1.57, -1.57, 0])
BACK = _pose([-torch.pi / 4, 0, -1.57, 1.57, 0])
ORIGIN = _pose([0, 0, 0, 0, 0])

# Alternative HANG_1st pose kept for reference: [1.57, 1.57, -1.57, 0, 0]
HANG_1st = _pose([0, 0, 1.57, -1.57, 0])
HANG_2nd = _pose([0, 0, -1.57, 1.57, 0])


def get_reacher():
    """Build the 560-row "reacher" table as a ``(560, 5)`` float32 tensor on ``DEVICE``.

    Only columns 0 and 2 ever leave zero. Row layout, in order:

    1. 20 rows: column 2 ramps linearly from 0 to pi/2.
    2. 40 rows: hold ``[0, 0, pi/2, 0, 0]``.
    3. 20 rows: column 0 ramps linearly from 0 to -pi/2 (column 2 stays pi/2).
    4. 40 rows: hold column 0 at -pi/2.
    5. Three sweeps of column 0 from -pi/2 to +pi/2 (40 rows each), each
       followed by a hold at +pi/2. After the first two holds (40 rows)
       column 0 sweeps back to -pi/2 (40 rows) and is held there (40 rows);
       the final hold at +pi/2 lasts 80 rows and ends the table.

    Returns:
        A freshly allocated tensor; every call builds a new one.
    """
    half_pi = torch.pi / 2
    rows = []

    # Opening: raise column 2, then bring column 0 down to -pi/2.
    rows.extend([0, 0, torch.pi / 2 * x / 19, 0, 0] for x in range(20))
    rows.extend([0, 0, half_pi, 0, 0] for _ in range(40))
    rows.extend([-torch.pi / 2 * x / 19, 0, half_pi, 0, 0] for x in range(20))
    rows.extend([-torch.pi / 2, 0, half_pi, 0, 0] for _ in range(40))

    # Back-and-forth sweeps of column 0 between -pi/2 and +pi/2.
    sweeps = 3
    for sweep in range(sweeps):
        is_last = sweep == sweeps - 1
        rows.extend(
            [-torch.pi / 2 + torch.pi * x / 39, 0, half_pi, 0, 0] for x in range(40)
        )
        rows.extend([half_pi, 0, half_pi, 0, 0] for _ in range(80 if is_last else 40))
        if is_last:
            break
        rows.extend(
            [torch.pi / 2 + -torch.pi * x / 39, 0, half_pi, 0, 0] for x in range(40)
        )
        rows.extend([-torch.pi / 2, 0, half_pi, 0, 0] for _ in range(40))

    # One allocation for the whole table instead of one tiny tensor per row;
    # each value is still computed as a Python float and cast to float32.
    return torch.tensor(rows, dtype=torch.float32, device=DEVICE)


def _swing_pulse(amplitude):
    """Return the 2-D pieces of one pulse: ramp-up, hold rows, ramp-down.

    The pulse lives in column 1; all other columns are multiplied by zero.
    The ramp-up scales ``amplitude`` by 0/20 .. 19/20, the hold repeats the
    full amplitude for ``3 * P`` rows, and the ramp-down scales it by
    20/20 .. 1/20.
    """
    peak = torch.tensor(
        [0, amplitude, 0, 0, 0], dtype=torch.float32, device=DEVICE
    ).unsqueeze(0)

    def scaled(fractions):
        # Spread each fraction across all columns, then weight by the peak row.
        return fractions.unsqueeze(1).repeat_interleave(STATE_DIM, dim=1) * peak

    ramp_up = scaled(torch.arange(20, device=DEVICE) / 20)
    ramp_down = scaled(torch.arange(20, 0, -1, device=DEVICE) / 20)
    return [ramp_up, *([peak] * (3 * P)), ramp_down]


def get_swing(a: float = 0.52):
    """Build the "swing" table: four repetitions of a ``-a`` pulse then a ``+a`` pulse.

    Args:
        a: peak magnitude written to column 1.

    Returns:
        A 2-D tensor with ``8 * (40 + 3 * P)`` rows on ``DEVICE``, produced by
        ``compose(..., dim=0)`` over the pulse pieces.
    """
    pieces = []
    for _ in range(4):
        pieces.extend(_swing_pulse(-a))
        pieces.extend(_swing_pulse(a))
    return compose(*pieces, dim=0)


def _hold_sequence(poses, hold, *, index_tail=False):
    """Build a masked hold table and the row indices of its held poses.

    Table layout: for every pose in ``poses``, 30 rows of ``mask`` values
    followed by ``hold`` copies of the pose; then 30 more mask rows and
    finally ``LL`` copies of ``ORIGIN``.

    Args:
        poses: sequence of ``(1, STATE_DIM)`` tensors, held in the given order.
        hold: number of consecutive rows each pose is repeated for.
        index_tail: when true, ``arange(<first ORIGIN row>, LL)`` is appended
            to the index tensor.

    Returns:
        ``(table, indices)``: the composed table and an int64 tensor holding
        the row numbers of every held-pose row (plus the optional tail).
    """
    gap = 30
    pieces = []
    spans = []
    cursor = 0  # row number at which the next piece starts
    for pose in poses:
        pieces.append(mask((gap, STATE_DIM)))
        cursor += gap
        pieces.extend([pose] * hold)
        spans.append(
            torch.arange(cursor, cursor + hold, dtype=torch.int64, device=DEVICE)
        )
        cursor += hold
    pieces.append(mask((gap, STATE_DIM)))
    cursor += gap
    pieces.extend([ORIGIN] * LL)
    if index_tail:
        # The upper bound is LL (not the table length), as in the original ranges.
        spans.append(torch.arange(cursor, LL, dtype=torch.int64, device=DEVICE))
    return compose(*pieces, dim=0), torch.cat(spans, dim=0)


LONG = {
    # cumsum over an all-zero table leaves it all-zero.
    "stay": zeros((LL, STATE_DIM)).cumsum(dim=0),
    "front_back": _hold_sequence([FRONT, BACK] * 3, 3 * P, index_tail=True),
    "hang": _hold_sequence([HANG_1st, HANG_2nd] * 3, 4 * P),
    # Same construction as "hang"; built separately so the tensors are not shared.
    "hang2": _hold_sequence([HANG_1st, HANG_2nd] * 3, 4 * P),
    "reacher": get_reacher(),
    "swing_0.52": get_swing(),
}

COMMANDS = {
    # Every OTHER table is stored as its cumulative sum over the row axis.
    **{k: v.cumsum(dim=0) for k, v in OTHER.items()},
    "reacher-no-mask": get_reacher(),
    "swing-no-mask": get_swing(0.52),
    # Same construction as LONG["front_back"], built as independent tensors.
    "front_back-full-mask": _hold_sequence(
        [FRONT, BACK] * 3, 3 * P, index_tail=True
    ),
}

# Placeholder entry point: running this module as a script does nothing
# beyond building the module-level tables above.
if __name__ == "__main__":
    ...
