import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy

class Flow:
    """Build sparse flow images from per-particle pixel coordinates.

    ``get_image`` rescales observed and goal coordinates onto a square
    ``im_size x im_size`` grid. At each observed particle's grid pixel it
    stores the rounded displacement towards the goal.
    """

    def __init__(self):
        # Stateless: every size is supplied to ``get_image`` explicitly.
        pass

    @staticmethod
    def _to_grid(uv, src_max, im_max):
        """Rescale the NaN-free rows of ``uv`` from source pixels to int grid indices."""
        uv = uv[~np.isnan(uv).any(axis=1)]
        # Same operation order as the original ``/719*199`` so results match bit-for-bit.
        return np.rint(uv / src_max * im_max).astype(int)

    @staticmethod
    def _plot_flow(uv_o_occl_px, uv_o_px, uv_diff_px, uv_g_px, im_size, stride=20):
        """Debug plot: occluded/visible particles, sparse flow (full and strided), goal.

        All coordinate arguments are integer grid indices. Column 0 is plotted
        on the vertical axis and column 1 on the horizontal axis.
        """
        im = np.zeros((im_size, im_size))
        _, ax = plt.subplots(1, 5, figsize=(16, 8))

        ax[0].set_title('occluded particles')
        ax[0].scatter(uv_o_occl_px[:, 1], uv_o_occl_px[:, 0], c='r', s=0.1)
        ax[0].imshow(im)

        ax[1].set_title('visible particles')
        ax[1].imshow(im)
        ax[1].scatter(uv_o_px[:, 1], uv_o_px[:, 0], c='b', s=0.1)

        quiver_kw = dict(alpha=0.8, color='white', angles='xy', scale_units='xy', scale=1)
        ax[2].set_title('sparse flow plot')
        ax[2].imshow(im)
        ax[2].quiver(uv_o_px[:, 1], uv_o_px[:, 0],
                     uv_diff_px[:, 1], uv_diff_px[:, 0], **quiver_kw)

        ax[3].set_title('sparse flow plot (downsampled)')
        ax[3].imshow(im)
        ax[3].quiver(uv_o_px[::stride, 1], uv_o_px[::stride, 0],
                     uv_diff_px[::stride, 1], uv_diff_px[::stride, 0], **quiver_kw)

        ax[4].set_title('goal image')
        ax[4].imshow(im)
        ax[4].scatter(uv_g_px[:, 1], uv_g_px[:, 0], c='r', s=0.1)
        plt.show()

    def get_image(self, uv_o, uv_g_f, rgb_o=None, rgb_g=None, *,
                  src_size=720, im_size=200, visualize=False):
        """Return a sparse flow image of shape ``(im_size, im_size, 2)``.

        Args:
            uv_o: ``(N, 2)`` array-like of observed particle coordinates in
                source pixels. Column 0 indexes the first image axis and
                column 1 the second. Rows containing NaN are skipped; the
                debug plot labels them occluded.
            uv_g_f: ``(N, 2)`` array-like of goal coordinates, row-aligned
                with ``uv_o``. Rows containing NaN are skipped.
            rgb_o, rgb_g: Currently unused; kept for interface compatibility.
            src_size: Source resolution. Coordinates in ``[0, src_size - 1]``
                map linearly onto ``[0, im_size - 1]``. The default of 720
                reproduces the original ``/719*199`` scaling.
            im_size: Side length of the output grid.
            visualize: If true, show a matplotlib debug figure (blocks on
                ``plt.show()``).

        Returns:
            Float array of shape ``(im_size, im_size, 2)``. Each pixel that
            holds a visible, in-bounds observed particle stores that
            particle's rounded goal-minus-observed displacement, in grid
            units. All other pixels are 0.
        """
        uv_o = np.asarray(uv_o, dtype=float)
        uv_g_f = np.asarray(uv_g_f, dtype=float)
        src_max = src_size - 1
        im_max = im_size - 1

        # The scaled difference is rounded, rather than differencing rounded
        # coordinates; this matches the original behaviour.
        uv_diff = np.rint((uv_g_f - uv_o) / src_max * im_max)

        # A NaN in either endpoint makes the diff NaN, so this single mask
        # drops particles that are invalid in the observed OR the goal frame.
        valid = ~np.isnan(uv_diff).any(axis=1)
        uv_o_px = self._to_grid(uv_o[valid], src_max, im_max)
        uv_diff_valid = uv_diff[valid]

        # Skip particles that fall outside the grid. Otherwise negative indices
        # would silently wrap to the opposite edge, and indices >= im_size
        # would raise IndexError.
        in_bounds = ((uv_o_px >= 0) & (uv_o_px < im_size)).all(axis=1)
        uv_o_px = uv_o_px[in_bounds]
        uv_diff_valid = uv_diff_valid[in_bounds]

        im_diff_sp = np.zeros((im_size, im_size, 2))
        # If several particles land on one pixel, only one displacement is
        # kept; fancy-index assignment does not accumulate.
        im_diff_sp[uv_o_px[:, 0], uv_o_px[:, 1], :] = uv_diff_valid

        if visualize:
            occluded = ~valid
            self._plot_flow(
                self._to_grid(uv_o[occluded], src_max, im_max),
                uv_o_px,
                uv_diff_valid,
                self._to_grid(uv_g_f, src_max, im_max),
                im_size,
            )

        return im_diff_sp
