import torch
from params import PARAMS
from utils.common import get_device
from args import DEVICE3, args
import math
from itertools import chain
import numpy as np

# Device on which every trajectory tensor in this module is allocated
# (re-exported from args.DEVICE3).
DEVICE = DEVICE3

# Number of rows in the short ("-no-mask") commands built below:
# post_steps * eval_ratio + post_steps - 1.
L = PARAMS["post_steps"] * args.eval_ratio + PARAMS["post_steps"] - 1
P = PARAMS["post_steps"]  # base segment length (rows) of the "triangles" trajectory
VELOCITYS = [0.05]  # per-row speeds for which COMMANDS_RAW entries are generated

LL = 12 * P  # row horizon of the LONG trajectories

# Origin (x, y) as a float32 tensor; not referenced in this module itself —
# presumably consumed by importers (TODO confirm).
INITIAL_STATE = torch.tensor([0, 0], dtype=torch.float32, device=DEVICE)


def compose(*tensors: torch.Tensor, dim=-1):
    """Concatenate two or more 2-D tensors along ``dim``.

    In this module it pairs per-axis columns into (x, y) rows (``dim=-1``)
    and appends trajectory segments in time (``dim=0``).

    Args:
        *tensors: At least two 2-D tensors whose sizes agree on every axis
            other than ``dim``.
        dim: Axis to concatenate along (default: the last axis).

    Returns:
        The concatenated tensor (``torch.cat`` result).

    Raises:
        AssertionError: If fewer than two tensors are given or any tensor is
            not 2-D.
    """
    assert len(tensors) >= 2, (
        f"compose needs at least two tensors, got {len(tensors)}"
    )
    assert all(t.dim() == 2 for t in tensors), (
        "compose expects 2-D tensors, got shapes "
        f"{[tuple(t.shape) for t in tensors]}"
    )
    return torch.cat(tensors, dim=dim)


def zeros(s):
    """Return a float32 tensor of zeros with shape ``s``, allocated on ``DEVICE``."""
    return torch.zeros(s, device=DEVICE, dtype=torch.float32)


def ones(s):
    """Return a float32 tensor of ones with shape ``s``, allocated on ``DEVICE``."""
    return torch.ones(s, device=DEVICE, dtype=torch.float32)


def mask(s):
    """Return a float32 tensor of shape ``s`` on ``DEVICE`` filled with -999.

    -999 is the mask value this module places in front of the "arrive"
    target trajectory.
    """
    return ones(s).mul(-999)


def smooth(s, smooth_end=True, *, steps=4):
    """Ramp the first (and optionally last) rows of ``s`` in place.

    Row ``i`` from the start (for ``i < steps``) is divided by ``steps - i``,
    which gives the scale factors 1/steps, 1/(steps-1), ..., 1. With
    ``smooth_end`` the mirrored factors are applied to the last rows. When
    the two ramps overlap (fewer than ``2 * steps`` rows), the overlapping
    rows receive both factors, exactly as before. Inputs shorter than
    ``steps`` are ramped over the rows that exist. Previously such inputs
    raised ``IndexError`` after partially mutating ``s``.

    Args:
        s: Tensor whose first axis is time; it is modified in place.
        smooth_end: Also ramp down the end of the sequence.
        steps: Ramp length in rows (default 4, the previous hard-coded value).

    Returns:
        ``s`` itself, after mutation.
    """
    n_rows = s.shape[0]
    # Only touch rows that exist, so short inputs never index out of range
    # and the end index ``n_rows - 1 - i`` never wraps around negatively.
    for i in range(min(steps, n_rows)):
        s[i].div_(steps - i)
        if smooth_end:
            s[n_rows - 1 - i].div_(steps - i)

    return s


# Constant-velocity commands keyed by speed ``v``. Each entry holds per-row
# (x, y) velocities of length ``L`` with a ramp-in at the start only. The
# direction table gives (name, x sign, y sign); its order fixes the key order.
COMMANDS_RAW = {
    v: {
        f"move_{direction}_{v}-no-mask": smooth(
            compose(sign_x * v * ones((L, 1)), sign_y * v * ones((L, 1))),
            smooth_end=False,
        )
        for direction, sign_x, sign_y in (
            ("up_right", 1, 1),
            ("up_left", -1, 1),
            ("down_right", 1, -1),
            ("down_left", -1, -1),
            ("up", 0, 1),
            ("down", 0, -1),
            ("left", -1, 0),
            ("right", 1, 0),
        )
    }
    for v in VELOCITYS
}

FIXED_VELOCITY = 0.025

# Per-row (x, y) velocity tables of length ``L``: "stay" is all zeros. The
# "move_*_then_stay" entries move along one axis for the first half (ramped
# in and out by ``smooth``) and then hold still for the remaining rows.
OTHER = {
    "stay-no-mask": zeros((L, 2)),
    **{
        f"move_{direction}_then_stay-no-mask": compose(
            smooth(
                compose(
                    sign_x * FIXED_VELOCITY * ones((L // 2, 1)),
                    sign_y * FIXED_VELOCITY * ones((L // 2, 1)),
                )
            ),
            zeros((L - L // 2, 2)),
            dim=0,
        )
        for direction, sign_x, sign_y in (
            ("right", 1, 0),
            ("left", -1, 0),
            ("up", 0, 1),
            ("bottom", 0, -1),
        )
    },
}


def get_circle(horizon=None, *, radius=0.2, points_per_loop=40, device=None):
    """Return a repeated circular trajectory as absolute (x, y) positions.

    One loop is sampled at ``points_per_loop`` evenly spaced angles on a
    circle of ``radius`` centred at ``(-radius, 0)``, so the first point is
    the origin. The loop is repeated ``2 * floor(horizon / points_per_loop)``
    times.

    Args:
        horizon: Reference length in rows; defaults to the module-level ``LL``.
        radius: Circle radius.
        points_per_loop: Number of samples per loop.
        device: Target device; defaults to the module-level ``DEVICE``.

    Returns:
        A float32 tensor of shape ``(n, 2)``. ``n`` is 0 when ``horizon`` is
        shorter than one loop; previously that case returned a 1-D tensor of
        shape ``(0,)``.

    Raises:
        ValueError: If ``points_per_loop`` is not positive.
    """
    if points_per_loop <= 0:
        raise ValueError(f"points_per_loop must be positive, got {points_per_loop}")
    if horizon is None:
        horizon = LL
    if device is None:
        device = DEVICE

    pi = math.pi
    half_loop = points_per_loop / 2  # 20.0 for the default, matching t / 20 * pi
    loop = [
        (
            radius * math.cos(t / half_loop * pi) - radius,
            radius * math.sin(t / half_loop * pi),
        )
        for t in range(points_per_loop)
    ]
    reps = int(horizon / points_per_loop) * 2

    coords = torch.as_tensor(loop * reps, device=device, dtype=torch.float32)
    # torch.as_tensor([]) is 1-D; keep the (n, 2) layout even when empty.
    return coords.reshape(-1, 2)


def get_heart2():
    """Return a heart-shaped trajectory of absolute (x, y) positions, then a hold.

    The outline is traced with 100 samples per quarter:

    - upper-right lobe (x: 0 -> 2);
    - lower-right edge (x: 2 -> 0);
    - lower-left edge (x: 0 -> -2);
    - upper-left lobe (x: -2 -> 0).

    It therefore starts and ends at (0, 0). ``LL`` rows of zeros are appended
    after the outline.
    """
    right = np.linspace(0, 2, 100).tolist()
    left = np.linspace(-2, 0, 100).tolist()

    def upper(x):
        # Unit semicircles centred at x = +1 and x = -1.
        return np.sqrt(1 - (np.abs(x) - 1) ** 2)

    def lower(x):
        # Edge from (+-2, 0) down to the tip at (0, -3).
        return -3 * np.sqrt(1 - (np.abs(x) / 2) ** 0.5)

    segments = (
        (right, upper),
        (right[::-1], lower),
        (left[::-1], lower),
        (left, upper),
    )
    points = [(x, curve(x)) for xs, curve in segments for x in xs]

    outline = torch.as_tensor(points, device=DEVICE, dtype=torch.float32)
    return compose(outline, zeros((LL, 2)), dim=0)


TRIANGLE_VELOCITY = 0.025

# Long-horizon trajectories:
# - velocity-based entries are integrated into positions with ``cumsum``;
# - "circle" and "heart" are already positions;
# - "arrive" is a fixed target preceded by 7 mask rows.
LONG = {
    "move_up_right": smooth(
        compose(0.05 * ones((LL, 1)), 0.05 * ones((LL, 1))),
        smooth_end=False,
    ).cumsum(dim=0),
    # 7 rows of the -999 mask value, then the (0.09, 0.12) target for LL rows.
    "arrive": compose(
        mask((7, 2)),
        torch.tensor([[0.09, 0.12]], dtype=torch.float32, device=DEVICE).repeat(LL, 1),
        dim=0,
    ),
    "stay": zeros((LL, 2)),
    "circle": get_circle(),
    "heart": get_heart2() / 2,
    # Four repeats of a six-segment cycle. Each tuple is
    # (x sign, y sign, length in multiples of P). The net displacement of
    # one cycle is zero in both x and y.
    "triangles": compose(
        *(
            compose(
                sign_x * TRIANGLE_VELOCITY * ones((span * P, 1)),
                sign_y * TRIANGLE_VELOCITY * ones((span * P, 1)),
            )
            for _ in range(4)
            for sign_x, sign_y, span in (
                (1, 1, 1),
                (1, -1, 2),
                (1, 1, 1),
                (-1, 1, 1),
                (-1, -1, 2),
                (-1, 1, 1),
            )
        ),
        dim=0,
    ).cumsum(dim=0),
}

# Flat lookup of every short command as absolute positions. The velocity
# tables (COMMANDS_RAW, then OTHER) are integrated with ``cumsum``. The circle
# and the unscaled heart generators already return positions.
COMMANDS = {
    name: velocities.cumsum(dim=0)
    for name, velocities in chain(
        chain.from_iterable(table.items() for table in COMMANDS_RAW.values()),
        OTHER.items(),
    )
}
COMMANDS["circle-no-mask"] = get_circle()
COMMANDS["heart-no-mask"] = get_heart2()

if __name__ == "__main__":
    # Visual sanity check: scatter-plot the circle trajectory.
    import seaborn as sns
    import pandas as pd
    import matplotlib.pyplot as plt

    points = get_circle().tolist()
    frame = pd.DataFrame(points, columns=["x", "y"])

    sns.scatterplot(data=frame, x="x", y="y")
    plt.show()
    plt.close()
