from train.behavioral_cloning.datasets.minerl_dataset import MineRLBaseDataset
from train.behavioral_cloning.spaces.action_spaces import MultiBinaryWithCameraRegression
from train.behavioral_cloning.spaces.input_spaces import SingleFrameWithBinaryActionAndContinuousCameraSequence
from train.behavioral_cloning.datasets.transforms import build_data_processor

import cv2
import tqdm
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression


# When True, main() shows a blocking diagnostic figure for every accepted frame pair.
VERBOSE = False
# Largest horizontal shift (in pixels) tested in either direction; the same number of
# columns is cropped from both image borders before frames are compared.
MAX_SHIFT = 8
# Sequence length requested from the dataset and used to size the input space.
SEQ_LENGTH = 32
# Input/action space descriptors; not referenced in the visible part of this script.
INPUT_SPACE = SingleFrameWithBinaryActionAndContinuousCameraSequence(SEQ_LENGTH)
ACTION_SPACE = MultiBinaryWithCameraRegression()
# Data processor built from an empty transform list and handed to the dataset
# (presumably a no-op pipeline -- confirm against build_data_processor).
DATA_TRANSFORM = build_data_processor([])


def _shift_image(image, shift):
    """Translate ``image`` horizontally by ``shift`` pixels and divide by 255.

    The output has the same height and width as the input, so frames of any
    resolution are supported. Pixels shifted in from outside the image are
    filled by ``cv2.warpAffine`` using its default border handling.
    """
    height, width = image.shape[:2]
    matrix = np.array([[1, 0, shift], [0, 1, 0]], dtype=np.float64)
    # cv2 expects the destination size as (width, height).
    warped = cv2.warpAffine(image, matrix, (width, height))
    return warped.astype(np.float32) / 255


def _cropped_abs_diff(reference, warped):
    """Return |reference - warped| with MAX_SHIFT columns removed on both sides.

    Cropping keeps the border columns that a shift of up to MAX_SHIFT pixels
    fills with padding out of the comparison.
    """
    return np.abs(reference[:, MAX_SHIFT:-MAX_SHIFT] - warped[:, MAX_SHIFT:-MAX_SHIFT])


def _show_shift_plot(img0, img1, img1_float32, shifts, diffs, min_shift, action):
    """Show a blocking diagnostic figure for one frame pair and its best shift."""
    # recompute min_shift image
    warped = _shift_image(img0, min_shift)
    diff = _cropped_abs_diff(img1_float32, warped)

    plt.figure("Shift Plot")
    plt.clf()

    plt.subplot(2, 4, 1)
    plt.imshow(img0, interpolation="nearest")
    plt.title("t-1")

    plt.subplot(2, 4, 2)
    plt.imshow(img1, interpolation="nearest")
    plt.title("t")

    # NOTE(review): this panel shows frame t-1 shifted by min_shift; the title
    # reads the other way round -- confirm the intended label.
    plt.subplot(2, 4, 3)
    plt.imshow(warped, interpolation="nearest")
    plt.title("t -> t-1")

    plt.subplot(2, 4, 4)
    plt.imshow(diff, interpolation="nearest", vmin=0, vmax=1)
    plt.title("diff")

    plt.subplot(2, 1, 2)
    plt.plot(shifts, diffs, "bo-")
    plt.plot(2*[min_shift], [0, 500], "r--")
    plt.grid()
    plt.ylim([0, 500])
    plt.xlabel("image shift")
    plt.ylabel("pixel difference")

    plt.suptitle("Camera LR Action %.2f" % action)

    plt.show(block=True)


def main():
    """Estimate the relation between camera action and horizontal image shift.

    Random sequences are drawn from the MineRL dataset. For each pair of
    consecutive frames whose camera action (column 1 of the camera actions) is
    large enough, the earlier frame is shifted horizontally by every integer
    offset in [-MAX_SHIFT, MAX_SHIFT] and compared with the later frame. Pairs
    with a sufficiently good match contribute one (camera action, best shift)
    sample. A linear regression is then fitted to the samples and plotted
    together with a fan of reference slopes.

    Raises:
        ValueError: if no sample is available for fitting the linear model.
    """
    # Camera actions with a smaller magnitude are skipped.
    min_abs_angle = 2.5
    # Largest summed absolute pixel difference accepted as a valid match.
    max_match_error = 100
    # Only samples with a smaller absolute camera action enter the regression.
    max_fit_angle = 20
    # Stop drawing new sequences once more samples than this were collected.
    max_samples = 10000

    # load dataset
    root, experiment = "tmp/data/MineRL/log_only", "MineRLObtainDiamond-v0"
    data = MineRLBaseDataset(root=root, sequence_len=SEQ_LENGTH, train=True, prepare=False,
                             experiment=experiment,
                             data_transform=DATA_TRANSFORM)

    cam_angles = []
    img_shifts = []
    # Candidate shifts are identical for every frame pair, so build them once.
    shifts = list(range(-MAX_SHIFT, MAX_SHIFT + 1))

    for i in tqdm.tqdm(np.random.permutation(len(data))):

        # The limit is only checked between sequences, so it can be exceeded
        # by the samples of the last sequence processed.
        if len(cam_angles) > max_samples:
            break

        # assumes pov holds frames in channel-first layout with 0..255 values
        # -- TODO confirm against MineRLBaseDataset
        pov, _, camera_actions, _ = data[i]

        lr_action = camera_actions[:, 1]

        for j in range(1, SEQ_LENGTH):

            action = lr_action[j - 1]
            if np.abs(action) < min_abs_angle:
                continue

            # Move the first axis last so the frames can be warped and shown as images.
            img0 = np.transpose(pov[j - 1], (1, 2, 0))
            img1 = np.transpose(pov[j], (1, 2, 0))
            img1_float32 = img1.astype(np.float32) / 255

            # Summed absolute difference between frame j and shifted frame j-1.
            diffs = [np.sum(_cropped_abs_diff(img1_float32, _shift_image(img0, shift)))
                     for shift in shifts]

            # book keeping
            if np.min(diffs) < max_match_error:
                min_shift = shifts[np.argmin(diffs)]
                cam_angles.append(action)
                img_shifts.append(min_shift)

                if VERBOSE:
                    _show_shift_plot(img0, img1, img1_float32, shifts, diffs, min_shift, action)

    # fit linear model
    angles = np.linspace(-30, 30, 10)
    img_shifts = np.asarray(img_shifts)
    cam_angles = np.asarray(cam_angles)

    indices = np.nonzero(np.abs(cam_angles) < max_fit_angle)[0]
    if indices.size == 0:
        # Fail with a clear message instead of an opaque error from the regressor.
        raise ValueError("no camera/shift samples available to fit the linear model")
    reg = LinearRegression().fit(cam_angles[indices].reshape((-1, 1)),
                                 img_shifts[indices].reshape((-1, 1)))
    print("reg.coef_", reg.coef_)

    plt.figure("Cam to Shift")
    plt.clf()
    plt.plot(cam_angles, img_shifts, "o", color="cornflowerblue", alpha=0.5)
    # Reference lines shift = k * angle for visual comparison with the fit.
    for k in np.linspace(-0.5, 0.5, 10):
        pred = k * angles
        plt.plot(angles, pred, "-", alpha=0.5, label="k=%.1f" % k)
    plt.plot(angles, reg.predict(angles.reshape((-1, 1))), "k", linewidth=3)
    plt.legend()
    plt.grid()
    plt.show(block=True)


# Run the analysis only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
