import gym
import numpy as np
import einops
from scipy.spatial.transform import Rotation as R
import pdb

from .d4rl import load_environment

# -----------------------------------------------------------------------------#
# -------------------------------- general api --------------------------------#
# -----------------------------------------------------------------------------#


def compose(*fns):
    """
    Chain single-argument callables into one function.

    The returned function passes its argument to the first callable, feeds
    each result into the next callable, and returns the last result. When no
    callables are given it returns its argument unchanged.
    """

    def _fn(x):
        value = x
        for stage in fns:
            value = stage(value)
        return value

    return _fn


def get_preprocess_fn(fn_names, env):
    """
    Build a single dataset preprocessing function from a list of names.

    Each entry of `fn_names` is evaluated with `eval`, the resulting object is
    called with `env`, and the callables returned by those calls are chained
    with `compose` in the order given.
    """
    ## NOTE(review): `eval` runs arbitrary expressions, so `fn_names` must only
    ## come from trusted configuration
    fns = [eval(name)(env) for name in fn_names]
    return compose(*fns)


def get_policy_preprocess_fn(fn_names):
    """
    Build a single policy-side preprocessing function from a list of names.

    Each entry of `fn_names` is evaluated with `eval` and the resulting
    callable is used as-is (it is not called with an environment first); the
    callables are chained with `compose` in the order given.
    """
    ## NOTE(review): `eval` runs arbitrary expressions, so `fn_names` must only
    ## come from trusted configuration
    fns = [eval(name) for name in fn_names]
    return compose(*fns)


# -----------------------------------------------------------------------------#
# -------------------------- preprocessing functions --------------------------#
# -----------------------------------------------------------------------------#

# ------------------------ @TODO: remove some of these ------------------------#


def arctanh_actions(*args, **kwargs):
    """
    Build a dataset transform that maps actions through arctanh.

    All arguments are accepted and ignored. The returned function asserts
    that `dataset["actions"]` lies in [-1, 1], clips it to
    [-1 + 1e-4, 1 - 1e-4] so the result stays finite, writes the arctanh of
    the clipped values back to `dataset["actions"]`, and returns the same
    dataset object.
    """
    epsilon = 1e-4
    lower, upper = -1 + epsilon, 1 - epsilon

    def _fn(dataset):
        actions = dataset["actions"]
        in_range = actions.min() >= -1 and actions.max() <= 1
        assert (
            in_range
        ), f"applying arctanh to actions in range [{actions.min()}, {actions.max()}]"

        ## clip first : arctanh(+-1) is infinite
        dataset["actions"] = np.arctanh(np.clip(actions, lower, upper))
        return dataset

    return _fn


def add_deltas(env):
    """
    Build a dataset transform that stores observation differences.

    `env` is not used. The returned function writes
    `next_observations - observations` to `dataset["deltas"]` and returns the
    same dataset object.
    """

    def _fn(dataset):
        dataset["deltas"] = dataset["next_observations"] - dataset["observations"]
        return dataset

    return _fn


def maze2d_set_terminals(env):
    """
    Build a dataset transform that re-segments a maze2d dataset at goal exits.

    `env` is either an environment object or a string; a string is resolved
    with `load_environment`. The environment must expose `_target` (the goal,
    compared against the first two observation columns) and `name` (used in
    the log line).

    The returned function overwrites `dataset["timeouts"]` so that step t is
    flagged iff the agent is within 0.5 of the goal at t and no longer within
    it at t + 1, prints a one-line summary of the resulting path lengths, and
    returns the same dataset object.
    """
    env = load_environment(env) if isinstance(env, str) else env
    goal = np.array(env._target)
    threshold = 0.5

    def _fn(dataset):
        xy = dataset["observations"][:, :2]
        distances = np.linalg.norm(xy - goal, axis=-1)
        at_goal = distances < threshold
        timeouts = np.zeros_like(dataset["timeouts"])

        ## timeout at time t iff
        ##      at goal at time t and
        ##      not at goal at time t + 1
        ## (the final step has no successor and is never flagged)
        timeouts[:-1] = at_goal[:-1] & ~at_goal[1:]

        timeout_steps = np.where(timeouts)[0]
        path_lengths = np.diff(timeout_steps)

        ## with fewer than two timeouts there is no complete path to measure,
        ## and min() / max() of an empty array would raise
        if path_lengths.size > 0:
            print(
                f"[ utils/preprocessing ] Segmented {env.name} | {len(path_lengths)} paths | "
                f"min length: {path_lengths.min()} | max length: {path_lengths.max()}"
            )
        else:
            print(
                f"[ utils/preprocessing ] Segmented {env.name} | {len(path_lengths)} paths | "
                f"min length: n/a | max length: n/a"
            )

        dataset["timeouts"] = timeouts
        return dataset

    return _fn


# -------------------------- block-stacking --------------------------#


def blocks_quat_to_euler(observations):
    """
    Re-encode each block's quaternion as the sin / cos of its Euler angles.

    input : [ N x robot_dim + n_blocks * 8 ] = [ N x 39 ]
        per block
            xyz: 3
            quat: 4
            contact: 1

    returns : [ N x robot_dim + n_blocks * 10 ] = [ N x 47 ]
        per block
            xyz: 3
            sin: 3
            cos: 3
            contact: 1

    The leading robot_dim columns are passed through unchanged.
    """
    robot_dim, block_dim, n_blocks = 7, 8, 4
    assert observations.shape[-1] == robot_dim + n_blocks * block_dim

    ## gather every output column group, then join them in one go
    pieces = [observations[:, :robot_dim]]

    for offset in range(robot_dim, robot_dim + n_blocks * block_dim, block_dim):
        block = observations[:, offset : offset + block_dim]

        ## columns 3..6 of a block hold the quaternion
        angles = R.from_quat(block[:, 3:-1]).as_euler("xyz")

        pieces += [
            block[:, :3],
            np.sin(angles),
            np.cos(angles),
            block[:, -1:],
        ]

    return np.concatenate(pieces, axis=-1)


def blocks_euler_to_quat_2d(observations):
    """
    Re-encode each block's sin / cos Euler features as a quaternion.

    input : [ N x robot_dim + n_blocks * 10 ] = [ N x 47 ]
        per block
            xyz: 3
            sin: 3
            cos: 3
            contact: 1

    returns : [ N x robot_dim + n_blocks * 8 ] = [ N x 39 ]
        per block
            xyz: 3
            quat: 4
            contact: 1

    The leading robot_dim columns are passed through unchanged.
    """
    robot_dim, block_dim, n_blocks = 7, 10, 4
    assert observations.shape[-1] == robot_dim + n_blocks * block_dim

    ## gather every output column group, then join them in one go
    columns = [observations[:, :robot_dim]]

    for block_index in range(n_blocks):
        first = robot_dim + block_index * block_dim
        block = observations[:, first : first + block_dim]

        ## recover the angles from their (sin, cos) pairs
        angles = np.arctan2(block[:, 3:6], block[:, 6:9])
        rotation = R.from_euler("xyz", angles, degrees=False)

        columns.extend([block[:, :3], rotation.as_quat(), block[:, 9:]])

    return np.concatenate(columns, axis=-1)


def blocks_euler_to_quat(paths):
    """
    Apply `blocks_euler_to_quat_2d` to every path and stack the results.

    `paths` is iterated along its first axis; each element is converted
    independently and the converted arrays are stacked along a new leading
    axis.
    """
    converted = []
    for path in paths:
        converted.append(blocks_euler_to_quat_2d(path))
    return np.stack(converted, axis=0)


def blocks_process_cubes(env):
    """
    Build a dataset transform that re-encodes block orientations.

    `env` is not used. The returned function replaces
    `dataset["observations"]` and then `dataset["next_observations"]` with
    the output of `blocks_quat_to_euler` and returns the same dataset object.
    """

    def _fn(dataset):
        dataset["observations"] = blocks_quat_to_euler(dataset["observations"])
        dataset["next_observations"] = blocks_quat_to_euler(
            dataset["next_observations"]
        )
        return dataset

    return _fn


def blocks_remove_kuka(env):
    """
    Build a dataset transform that drops the leading robot columns.

    `env` is not used. The returned function removes the first 7 columns
    (the robot_dim used by the other block helpers in this file) from
    `dataset["observations"]` and `dataset["next_observations"]` and returns
    the same dataset object.
    """
    robot_dim = 7

    def _fn(dataset):
        dataset["observations"] = dataset["observations"][:, robot_dim:]
        dataset["next_observations"] = dataset["next_observations"][:, robot_dim:]
        return dataset

    return _fn


def blocks_add_kuka(observations):
    """
    Prepend zero-filled robot columns to block-only observations.

    observations : [ batch_size x horizon x 32 ]

    returns : [ batch_size x horizon x 39 ]
        the first 7 columns of the last axis are zeros, the rest is the input
    """
    robot_dim = 7

    ## unpacking also rejects inputs that are not 3-dimensional
    batch_size, horizon, _ = observations.shape

    robot_padding = np.zeros((batch_size, horizon, robot_dim))
    return np.concatenate([robot_padding, observations], axis=-1)


def blocks_cumsum_quat(deltas):
    """
    Accumulate per-step deltas along the horizon axis.

    deltas : [ batch_size x horizon x transition_dim ]

    Every column is cumulatively summed along axis 1, except each block's
    four quaternion columns: those are converted to "xyz" Euler angles, the
    angles are cumulatively summed along axis 1, and the sums are converted
    back to quaternions.

    NOTE(review): adding Euler angles is not the same as composing rotations
    in general — confirm this approximation is intended.
    """
    robot_dim, block_dim, n_blocks = 7, 8, 4
    assert deltas.shape[-1] == robot_dim + n_blocks * block_dim

    ## unpacking also rejects inputs that are not 3-dimensional
    batch_size, _, _ = deltas.shape

    totals = deltas.cumsum(axis=1)

    ## first quaternion column of every block (xyz occupies the 3 before it)
    quat_starts = [robot_dim + i * block_dim + 3 for i in range(n_blocks)]

    for lo in quat_starts:
        hi = lo + 4

        ## scipy works on flat [ n x 4 ] arrays, so fold batch and horizon
        flat_quat = einops.rearrange(deltas[:, :, lo:hi].copy(), "b h q -> (b h) q")
        step_euler = R.from_quat(flat_quat).as_euler("xyz")
        step_euler = einops.rearrange(step_euler, "(b h) e -> b h e", b=batch_size)

        summed_euler = step_euler.cumsum(axis=1)

        flat_euler = einops.rearrange(summed_euler, "b h e -> (b h) e")
        summed_quat = R.from_euler("xyz", flat_euler).as_quat()
        summed_quat = einops.rearrange(summed_quat, "(b h) q -> b h q", b=batch_size)

        ## overwrite the plain cumulative sum of the quaternion columns
        totals[:, :, lo:hi] = summed_quat

    return totals


def blocks_delta_quat_helper(observations, next_observations):
    """
    Compute per-transition deltas, with block orientations as relative rotations.

    input : [ N x robot_dim + n_blocks * 8 ] = [ N x 39 ]
        xyz: 3
        quat: 4
        contact: 1

    returns : [ N x 39 ] in the same layout
        robot columns, xyz and contact hold `next - current`;
        quat holds the rotation `delta` satisfying `delta * rot = next_rot`,
        with its sign chosen so that w >= 0

    Raises AssertionError if the inputs do not have 39 columns or if applying
    `delta` to the current rotation does not reproduce the next rotation.
    """
    robot_dim = 7
    block_dim = 8
    n_blocks = 4
    assert (
        observations.shape[-1]
        == next_observations.shape[-1]
        == robot_dim + n_blocks * block_dim
    )

    ## gather every output column group, then join them in one go
    parts = [(next_observations - observations)[:, :robot_dim]]

    for i in range(n_blocks):
        start = robot_dim + i * block_dim
        end = start + block_dim

        block_info = observations[:, start:end]
        next_block_info = next_observations[:, start:end]

        delta_xpos = next_block_info[:, :3] - block_info[:, :3]
        delta_contact = next_block_info[:, -1:] - block_info[:, -1:]

        rot = R.from_quat(block_info[:, 3:-1])
        next_rot = R.from_quat(next_block_info[:, 3:-1])

        delta_quat = (next_rot * rot.inv()).as_quat()
        w = delta_quat[:, -1:]

        ## q and -q encode the same rotation; flip where w < 0 to avoid
        ## [0, 0, 0, -1]. np.sign(w) is 0 when w == 0 (an exact half-turn),
        ## which would wipe the quaternion out, so only ever multiply by +-1
        delta_quat = delta_quat * np.where(w < 0, -1.0, 1.0)

        ## apply rot then delta to ensure we end at next_rot
        ## delta * rot = next_rot * rot' * rot = next_rot
        ## compared as matrices : euler angles of equal rotations can differ
        ## at the +-pi wrap and near gimbal lock
        recomposed = R.from_quat(delta_quat) * rot
        assert np.allclose(next_rot.as_matrix(), recomposed.as_matrix())

        parts += [delta_xpos, delta_quat, delta_contact]

    return np.concatenate(parts, axis=-1)


def blocks_add_deltas(env):
    """
    Build a dataset transform that stores rotation-aware deltas.

    `env` is not used. The returned function writes the output of
    `blocks_delta_quat_helper(observations, next_observations)` to
    `dataset["deltas"]` and returns the same dataset object.
    """

    def _fn(dataset):
        dataset["deltas"] = blocks_delta_quat_helper(
            dataset["observations"], dataset["next_observations"]
        )
        return dataset

    return _fn
