import cv2
import numpy as np

from train.behavioral_cloning.datasets.agent_state import OBSERVATION_EQUIP
from train.behavioral_cloning.spaces.action_spaces import BINARY_ACTIONS


def build_data_processor(data_processors: list):
    """
    Compose a list of data processors into a single callable.

    The processors are applied left to right: the output of one becomes the
    input of the next. The list is referenced (not copied), so later changes
    to it are visible to the returned function.

    :param data_processors: list of data processors (e.g.: [to_float32, divide_pov_by_255])
    :return: callable data processing/augmentation function
    """

    def call_processors(data):
        """ Feed ``data`` through every processor in order and return the result. """
        result = data
        for step in data_processors:
            result = step(result)
        return result

    return call_processors


def to_float32(data):
    """ Cast every non-None element of ``data`` to float32.

    ``None`` entries are passed through unchanged. A new list is returned.
    """
    return [None if item is None else item.astype(np.float32) for item in data]


def divide_pov_by_255(data):
    """ Convert the frames (element 0) to float32 and divide them by 255.

    For frames in [0, 255] this maps them to [0, 1]. All other elements are
    passed through untouched, and a new list is returned.
    """
    items = list(data)
    frames_f32 = items[0].astype(np.float32)
    items[0] = frames_f32 / 255
    return items


def random_left_right_flip(data, p: float = 0.5):
    """ Randomly mirror a sample horizontally, together with its actions.

    With probability ``p`` the frames are reversed along their last axis
    (presumably width, given the (N, C, H, W) layout used elsewhere in this
    module), the "left" and "right" binary actions are swapped, and the
    horizontal camera action (column 1) is negated. Otherwise the sample is
    returned unchanged.

    Frames keep their dtype and value range, so flipped and unflipped samples
    stay on the same scale. The caller's arrays are never modified in place.

    :param data: sequence (pov, binary_actions, camera_actions, ...); pov must be 4-D
    :param p: probability of applying the flip
    :return: list with the (possibly) flipped first three elements
    """
    data = list(data)
    pov, binary_actions, camera_actions = data[:3]

    # random() is in [0, 1), so this branch is taken with probability p
    if np.random.random() < p:
        # flip frames (ascontiguousarray materializes the reversed view as a copy)
        assert np.ndim(pov) == 4
        data[0] = np.ascontiguousarray(pov[:, :, :, ::-1])

        # swap left and right on a copy; fancy indexing on the right-hand side
        # yields a copy, so the swap does not read already-overwritten values
        l, r = BINARY_ACTIONS.index("left"), BINARY_ACTIONS.index("right")
        binary_actions = binary_actions.copy()
        binary_actions[:, [l, r]] = binary_actions[:, [r, l]]
        data[1] = binary_actions

        # mirror the horizontal camera movement, again on a copy
        camera_actions = camera_actions.copy()
        camera_actions[:, 1] *= -1
        data[2] = camera_actions
    return data


def resize_64_to_48(data):
    """ Resize every frame to 48 x 48 pixels.

    Frames (element 0) are indexed as (N, C, H, W). Each one is moved to a
    channel-last layout for ``cv2.resize`` (default interpolation) and moved
    back afterwards. The input size is not checked.

    :return: list with the resized (N, C, 48, 48) frames, in the input dtype, at index 0
    """
    out = list(data)
    frames = out[0]
    target = (48, 48)
    resized = np.zeros(frames.shape[:2] + target, dtype=frames.dtype)
    for idx in range(len(frames)):
        channels_last = frames[idx].transpose(1, 2, 0)
        resized[idx] = cv2.resize(channels_last, target).transpose(2, 0, 1)
    out[0] = resized
    return out


def resize_64_to_32(data):
    """ Resize every frame to 32 x 32 pixels.

    Frames (element 0) are indexed as (N, C, H, W). Each one is moved to a
    channel-last layout for ``cv2.resize`` (default interpolation) and moved
    back afterwards. The input size is not checked.

    :return: list with the resized (N, C, 32, 32) frames, in the input dtype, at index 0
    """
    out = list(data)
    frames = out[0]
    target = (32, 32)
    resized = np.zeros(frames.shape[:2] + target, dtype=frames.dtype)
    for idx in range(len(frames)):
        channels_last = frames[idx].transpose(1, 2, 0)
        resized[idx] = cv2.resize(channels_last, target).transpose(2, 0, 1)
    out[0] = resized
    return out


def stack_gray_scale_delta_frame(data):
    """ Append a gray scale frame-difference channel to every frame.

    Output frame ``i`` holds the color channels of input frame ``i`` followed
    by one extra channel ``gray(frame i) - gray(frame i - 1)``. The first frame
    has no predecessor, so its delta channel is all zeros.

    :param data: sequence whose element 0 is the frame stack, indexed as
        (N, C, H, W); C is expected to be 3 (fed to cv2.COLOR_BGR2GRAY)
    :return: list with an (N, C + 1, H, W) array of the input dtype at index 0
    """
    data = list(data)
    pov = data[0]
    n_frames, n_channels, height, width = pov.shape

    # spatial size follows the input instead of a hard-coded 48 x 48
    delta_pov = np.zeros((n_frames, n_channels + 1, height, width), dtype=pov.dtype)

    # color channels of every frame, including the first one (which has no delta)
    delta_pov[:, :n_channels] = pov

    # integer frames (e.g. uint8) would wrap around when subtracting, so the
    # delta is computed in float32 and clipped into the dtype's value range
    clip_range = None
    if np.issubdtype(pov.dtype, np.integer):
        info = np.iinfo(pov.dtype)
        clip_range = (info.min, info.max)

    prev_gray = None
    for i in range(n_frames):
        # NOTE(review): BGR2GRAY weights assume BGR channel order — confirm for this data
        gray = cv2.cvtColor(np.transpose(pov[i], (1, 2, 0)), cv2.COLOR_BGR2GRAY).astype(np.float32)
        if prev_gray is not None:
            delta = gray - prev_gray
            if clip_range is not None:
                delta = np.clip(delta, *clip_range)
            delta_pov[i, n_channels] = delta
        # reuse this frame's gray image as the predecessor of the next frame
        prev_gray = gray

    data[0] = delta_pov
    return data


def add_equipped_item_to_frame(data):
    """ Encode the equipped item as a colored border around each frame.

    The equipped-item index of frame ``i`` selects one of nine fixed colors,
    which is painted as a 2 pixel wide border along all four image edges.

    :param data: sequence with the frames at index 0 (indexed below as
        (N, 3, H, W)) and one equipped-item index per frame at index 5
    :return: list with the border-painted frames at index 0. NOTE: the frame
        array is modified in place, so the caller's array changes as well.
    """
    data = list(data)

    # get relevant data
    pov = data[0]
    equipped_items = data[5]

    # one RGB color per entry of OBSERVATION_EQUIP
    colors = np.array([[255, 0, 0], [0, 255, 0], [0, 0, 255],
                       [128, 0, 0], [0, 128, 0], [0, 0, 128],
                       [255, 255, 0], [0, 255, 255], [255, 0, 255]], dtype=np.float64)
    assert len(OBSERVATION_EQUIP) == colors.shape[0]

    # floating point frames (any float dtype) are assumed to be normalized to
    # [0, 1]; integer frames keep the [0, 255] scale
    if np.issubdtype(pov.dtype, np.floating):
        colors /= 255
    colors = colors.astype(pov.dtype)

    w = 2
    for i in range(len(pov)):
        # shape (3, 1, 1) broadcasts over the rows/columns of each border strip
        color = colors[int(equipped_items[i])][:, np.newaxis, np.newaxis]
        pov[i, :, :w, :] = color
        pov[i, :, -w:, :] = color
        pov[i, :, :, :w] = color
        pov[i, :, :, -w:] = color
    data[0] = pov

    return data


def add_inventory_to_frame(data):
    """ Draw the inventory counts as a gray-scale strip inside each frame.

    Each inventory slot becomes a 3 pixel wide column with value
    ``255 - clip(10 * count, 0, 255)``: an empty slot is 255, and a slot with 26
    or more items is 0. The strip covers channels 0-2, rows 1-3 and columns
    5-58 of every frame.

    NOTE(review): the values are on the 0-255 scale whatever the frame dtype;
    if frames were already divided by 255 the strip will be out of range —
    confirm where this sits in the pipeline.

    :param data: sequence with frames at index 0 and inventory at index 4; the
        inventory is transposed when its second axis does not have length 18
    :return: list with the frames at index 0 (modified in place)
    """
    items = list(data)
    frames, inventory = items[0], items[4]

    # bring the inventory to one row per frame
    if inventory.shape[1] != 18:
        inventory = inventory.T

    pixels_per_item = 3
    brightness_per_item = 10
    for idx in range(len(frames)):
        # np.repeat flattens the (1, 18) row and repeats each count 3 times
        strip = np.repeat(inventory[idx:idx + 1], pixels_per_item) * brightness_per_item
        strip = 255 - np.clip(strip, 0, 255)
        # the (54,) strip broadcasts over 3 channels x 3 rows
        frames[idx, 0:3, 1:4, 5:5 + 54] = strip

    items[0] = frames
    return items


# Module-level accumulators: a quick hack for collecting inventory statistics
# while debugging (two independent, initially empty lists).
cobblestones, dirts = [], []


def add_pseudo_depth_to_frame(data, max_dirt_and_cobblestone: float = 95.0):
    """
    Draw a progress bar encoding the combined cobblestone + dirt count.

    For each frame, inventory column 1 (cobblestone) plus column 3 (dirt) is
    clipped to ``[0, max_dirt_and_cobblestone]`` and drawn as a bar of up to 54
    pixels of value 255. The bar sits in channels 0-2, rows 1-3 and columns
    5-58; the rest of that strip is set to 0.

    Inventory statistics, selected quantiles: [0.5, 0.75, 0.9]

    cobblestone [ 48.  84. 291.]
    dirts       [ 7. 24. 27.]
    both        [ 54.  95. 318.]

    The default cap of 95 is the 0.75 quantile of "both".

    NOTE(review): the bar value 255 is on the 0-255 scale whatever the frame
    dtype — confirm this runs before any division by 255.

    :param data: sequence with frames at index 0 and inventory at index 4; the
        inventory is transposed when its second axis does not have length 18
    :param max_dirt_and_cobblestone: count that maps to a full bar; must be > 0
    :return: list with the frames at index 0 (modified in place)
    :raises ValueError: if ``max_dirt_and_cobblestone`` is not positive
    """
    if max_dirt_and_cobblestone <= 0:
        raise ValueError("max_dirt_and_cobblestone must be positive")

    data = list(data)

    # get relevant data
    pov = data[0]
    inventory = data[4]

    # fix inventory shape
    if inventory.shape[1] != 18:
        inventory = inventory.T

    # hack for collecting statistics
    # global cobblestones, dirts
    # cobblestones.append(max(inventory[:, 1]))
    # dirts.append(max(inventory[:, 3]))
    # print("cobblestone", np.quantile(cobblestones, [0.5, 0.75, 0.9]))
    # print("      dirts", np.quantile(dirts, [0.5, 0.75, 0.9]))
    # print("       both", np.quantile(np.array(cobblestones) + np.array(dirts), [0.5, 0.75, 0.9]))
    # print("-" * 25)

    bar_width = 54
    for i in range(len(pov)):
        dirt_and_cobblestone = np.clip(inventory[i, 1] + inventory[i, 3], 0, max_dirt_and_cobblestone)
        bar_length = int(bar_width * dirt_and_cobblestone / max_dirt_and_cobblestone)

        # filled part of the bar, then clear the remainder of the strip
        pov[i, 0:3, 1:4, 5:5 + bar_length] = 255
        pov[i, 0:3, 1:4, 5 + bar_length:5 + bar_width] = 0

    data[0] = pov

    return data


def default_data_transform(data):
    """ Default pipeline: cast all non-None entries to float32, then divide the frames by 255. """
    return divide_pov_by_255(to_float32(data))

###
### DO NOT USE TRANSFORMS BELOW THIS POINT
###

def _random_cam_shift_on_64(data):
    """ Randomly shift the frames horizontally and adapt the camera action accordingly.

    A random ``cam_shift ~ N(0, 3)`` is added to the horizontal camera action
    (column 1), and every frame is translated along its last axis by
    ``round(-cam_shift)`` pixels with ``cv2.warpAffine``. The output size equals
    the input frame size.

    :param data: 6-tuple (pov, discrete_action_matrix, camera_actions, rewards,
        inventory, equipped_items); pov is indexed as (N, C, H, W)
    :return: 6-tuple with shifted frames and adjusted camera actions; the
        caller's arrays are not modified
    """
    pov, discrete_action_matrix, camera_actions, rewards, inventory, equipped_items = data

    # cam_shift = np.random.normal(loc=0, scale=15)
    # img_shift = int(np.round(-0.15 * cam_shift))

    # let's try something more direct
    cam_shift = np.random.normal(loc=0, scale=3)
    img_shift = int(np.round(-1.0 * cam_shift))

    # adapt camera actions (on a copy, so the caller's array stays untouched)
    camera_actions = camera_actions.copy()
    camera_actions[:, 1] += cam_shift

    # pure horizontal translation; identical for every frame, so built once
    matrix = np.array([[1, 0, img_shift], [0, 1, 0]], dtype=np.float64)

    # adapt image
    shifted_pov = np.zeros_like(pov)
    for i, img in enumerate(pov):
        channels_last = np.transpose(img, (1, 2, 0))
        dsize = (channels_last.shape[1], channels_last.shape[0])  # cv2 expects (width, height)
        shifted_pov[i] = np.transpose(cv2.warpAffine(channels_last, matrix, dsize), (2, 0, 1))

    return shifted_pov, discrete_action_matrix, camera_actions, rewards, inventory, equipped_items


def _random_black_border(data):
    """ Randomly add black border to left or right part of the image """

    pov, discrete_action_matrix, camera_actions, rewards, inventory, equipped_items = data

    # sample border width
    border = np.random.randint(1, 4)

    # add border to the left or to the right
    if np.random.random() > 0.5:
        pov[:, :, :, -border:] = 0
    else:
        pov[:, :, :, :border] = 0

    return pov, discrete_action_matrix, camera_actions, rewards, inventory, equipped_items
