from scipy.io import loadmat
import numpy as np 
import cv2

# create dictionaries for xy data for each dataset
def list_loadmats_todatasets(list_loadmats, num_x_predictors):
    """
    Turn a list of Navier-Stokes solver outputs (as loaded by scipy.io.loadmat)
    into sliding-window x/y datasets.

    For every loaded record, the image stack at record['u'][0] is walked along
    its last axis. Each window of num_x_predictors consecutive 2D slices becomes
    one x entry (a tuple of slices) and the slice right after the window becomes
    the matching y entry.

    args:
        - list_loadmats: list of objects returned by scipy.io.loadmat
        - num_x_predictors: how many consecutive images make up one x entry
    returns:
        - nvstokes_xy: dict mapping each record's position in list_loadmats to
          a dict {'x': [tuple of images, ...], 'y': [image, ...]}
    """
    nvstokes_xy = {idx: {'x': [], 'y': []} for idx in range(len(list_loadmats))}

    for idx, record in enumerate(list_loadmats):
        frames = record['u'][0]
        # one xy pair per window that still has a following image to predict
        for start in np.arange(0, frames.shape[-1] - num_x_predictors):
            window = tuple(frames[:, :, start + k] for k in range(num_x_predictors))
            target = frames[:, :, start + num_x_predictors]
            nvstokes_xy[idx]['x'].append(window)
            nvstokes_xy[idx]['y'].append(target)

    return nvstokes_xy
        

# create dictionaries for xy data for each dataset
def list_loadmats_todatasets_skip(list_loadmats, num_x_predictors, skip_factor):
    """
    Like list_loadmats_todatasets, but first subsample each image stack in time.

    For every loaded record, only the images at indices 0, skip_factor,
    2*skip_factor, ... along the last axis of record['u'][0] are kept
    (skip_factor is a stride, so skip_factor=1 keeps every image). Sliding
    windows of num_x_predictors kept images form the x entries, and the next
    kept image forms the matching y entry.

    args:
        - list_loadmats: list of objects returned by scipy.io.loadmat
        - num_x_predictors: how many kept images make up one x entry
        - skip_factor: step between kept image indices
    returns:
        - nvstokes_xy: dict mapping each record's position in list_loadmats to
          a dict {'x': [tuple of images, ...], 'y': [image, ...]}
    """
    nvstokes_xy = {}

    for idx, record in enumerate(list_loadmats):
        frames = record['u'][0]
        kept = np.arange(0, frames.shape[-1], skip_factor)

        xs = []
        ys = []
        for pos in np.arange(0, len(kept) - num_x_predictors):
            xs.append(tuple(frames[:, :, kept[pos + k]] for k in range(num_x_predictors)))
            ys.append(frames[:, :, kept[pos + num_x_predictors]])

        nvstokes_xy[idx] = {'x': xs, 'y': ys}

    return nvstokes_xy
        


def img_datasets_difference_datasets(x_dataset, y_dataset, use_difference_output, image_in_x):
    """
    Convert image datasets into finite-difference datasets.

    Each x datapoint is a 3-tuple of images whose last element is treated as
    the most recent image. It is replaced by its time differences:
        velocity     = x[-1] - x[-2]
        acceleration = (x[-1] - x[-2]) - (x[-2] - x[-3])
    optionally preceded by the last image itself.

    When use_difference_output is True, each y datapoint becomes y - x[-1], so a
    model is trained to output only the difference to the next image. When it
    is False, y is passed through unchanged (the model outputs the whole image).

    args:
        - x_dataset: sequence of 3-tuples of images, or None to skip x
        - y_dataset: sequence of single images, or None to skip y
        - use_difference_output: if True, output y - last x image instead of y
        - image_in_x: if True the new x tuples are (image, velocity, acceleration),
          otherwise (velocity, acceleration)
    returns:
        - diff_x_dataset: list of new x tuples (None entries when x_dataset is None)
        - diff_y_dataset: list of new y images (None entries when y_dataset is None)

    When both datasets are given, only the first
    min(len(x_dataset), len(y_dataset)) datapoints are converted.

    raises:
        - AssertionError: if the first x datapoint is not a 3-tuple
        - ValueError: if use_difference_output is True and y_dataset is given
          without x_dataset (there is no last x image to subtract)
    """
    # Validate the x format only when there is an x datapoint to look at, so
    # x_dataset=None (or an empty x_dataset) is accepted.
    if x_dataset is not None and len(x_dataset) > 0:
        assert isinstance(x_dataset[0], tuple)
        assert len(x_dataset[0]) == 3
    if use_difference_output and x_dataset is None and y_dataset is not None:
        raise ValueError(
            'use_difference_output=True requires x_dataset: the y difference '
            'is taken against the last x image')

    # Stop at the shorter of the provided datasets so a y_dataset shorter than
    # x_dataset does not raise IndexError.
    lengths = [len(d) for d in (x_dataset, y_dataset) if d is not None]
    num_datapoints = min(lengths) if lengths else 0

    diff_x_dataset = []
    diff_y_dataset = []

    for i in range(num_datapoints):
        new_x_datapoint = None
        new_y_datapoint = None
        last_x_img = None

        if x_dataset is not None:
            x_datapoint = x_dataset[i]
            last_x_img = x_datapoint[-1]
            x_velocity = x_datapoint[-1] - x_datapoint[-2]
            x_past_velocity = x_datapoint[-2] - x_datapoint[-3]
            x_acceleration = x_velocity - x_past_velocity

            if image_in_x:
                new_x_datapoint = (last_x_img, x_velocity, x_acceleration)
            else:
                new_x_datapoint = (x_velocity, x_acceleration)

        if y_dataset is not None:
            y_datapoint = y_dataset[i]
            if use_difference_output:
                # difference between the y image and the last x image
                new_y_datapoint = y_datapoint - last_x_img
            else:
                new_y_datapoint = y_datapoint

        diff_x_dataset.append(new_x_datapoint)
        diff_y_dataset.append(new_y_datapoint)

    return diff_x_dataset, diff_y_dataset


def _farneback_flow(prev_img, next_img, farneback_flow_params):
    """Return cv2's dense Farneback optical flow from prev_img to next_img.

    The first eight entries of farneback_flow_params are forwarded positionally
    after the two images, i.e. in cv2's order (flow, pyr_scale, levels,
    winsize, iterations, poly_n, poly_sigma, flags).
    """
    return cv2.calcOpticalFlowFarneback(prev_img, next_img, *farneback_flow_params[:8])


def _flow_to_gray(flow):
    """Collapse a 2-channel flow field into a single-channel float32 image.

    Channel 0 of the flow is written to the HSV hue plane and channel 1 to the
    value plane (saturation stays 0), then converted HSV -> BGR -> gray.
    The buffer is float32 because cv2.cvtColor does not accept float64 input.
    """
    hsv = np.zeros((flow.shape[0], flow.shape[1], 3), dtype=np.float32)
    hsv[:, :, 0] = flow[..., 0]
    hsv[:, :, 2] = flow[..., 1]
    bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)


def _optical_flow_features(x_datapoint, image_in_x, time_diff_vel_in_x,
                           time_diff_accel_in_x, use_double_optical_flow,
                           farneback_flow_params):
    """Build the feature tuple and its title list for one (img -2, img -1, img 0) datapoint."""
    oldest_img = x_datapoint[-3]
    prev_img = x_datapoint[-2]
    last_img = x_datapoint[-1]

    features = []
    titles = []

    x_optical_flow = _farneback_flow(prev_img, last_img, farneback_flow_params)

    if image_in_x:
        features.append(last_img)
        titles.append('image')

    # flow channels 0 / 1 are the x / y components of the displacement field
    features.append(x_optical_flow[..., 0])
    features.append(x_optical_flow[..., 1])
    titles.append('optical flow x comp')
    titles.append('optical flow y comp')

    if time_diff_vel_in_x:
        features.append(last_img - prev_img)
        titles.append('velocity (diff)')

    if time_diff_accel_in_x:
        # second time difference: oldest - 2*prev + last
        features.append((oldest_img - prev_img) - (prev_img - last_img))
        titles.append('accel (diff)')

    if use_double_optical_flow:
        # NOTE: experimental "optical flow of optical flow" - not validated, don't rely on it yet.
        earlier_flow = _farneback_flow(oldest_img, prev_img, farneback_flow_params)
        # NOTE(review): this flow runs from the newer flow image to the older one;
        # confirm that direction is intended.
        x_dbl_optical_flow = _farneback_flow(_flow_to_gray(x_optical_flow),
                                             _flow_to_gray(earlier_flow),
                                             farneback_flow_params)
        features.append(x_dbl_optical_flow[..., 0])
        features.append(x_dbl_optical_flow[..., 1])
        titles.append('dbl of mag')
        titles.append('dbl of ang')

    return tuple(features), titles


def img_datasets_to_opticalflow_datasets(x_dataset, y_dataset,
        image_in_x, time_diff_vel_in_x, time_diff_accel_in_x, use_double_optical_flow,
        use_difference_output, farneback_flow_params=(None, 0.5, 3, 15, 3, 5, 1.2, 0)):
    """
    Converts the normal 3 image datasets: (img -2, img -1, img 0) -> (img 1)
    to a dataset where the X dataset has components of optical flow.

    args:
        - x_dataset: dataset of tuples (img -2, img -1, img 0), or None to skip x
        - y_dataset: dataset of images (img 1), or None to skip y
        - image_in_x: if True, img 0 is the first element of each new x tuple
        - time_diff_vel_in_x: if True, append the velocity img 0 - img -1
        - time_diff_accel_in_x: if True, append the second time difference of the three images
        - use_double_optical_flow: experimental; append the optical flow between
            grayscale renderings of the two consecutive single-layer flow fields
        - use_difference_output: rather than output the next image, output the
            difference between it and img 0
        - farneback_flow_params: sequence of the 8 trailing positional arguments for
            cv2.calcOpticalFlowFarneback (a list is accepted as before)

    returns:
        - x dataset: list of feature tuples (None entries when x_dataset is None)
        - y dataset: list of images (None entries when y_dataset is None)
        - title_order: the titles for the images in each x tuple, in order
            (empty when there are no x datapoints)

    When both datasets are given, only the first
    min(len(x_dataset), len(y_dataset)) datapoints are converted.

    raises:
        - AssertionError: if the first x datapoint is not a 3-tuple
        - ValueError: if use_difference_output is True and y_dataset is given
            without x_dataset (there is no img 0 to subtract)
    """
    # Validate the x format only when there is an x datapoint to look at.
    if x_dataset is not None and len(x_dataset) > 0:
        assert isinstance(x_dataset[0], tuple)
        assert len(x_dataset[0]) == 3
    if use_difference_output and x_dataset is None and y_dataset is not None:
        raise ValueError(
            'use_difference_output=True requires x_dataset: the y difference '
            'is taken against the last x image')

    # Stop at the shorter of the provided datasets to avoid IndexError.
    lengths = [len(d) for d in (x_dataset, y_dataset) if d is not None]
    num_datapoints = min(lengths) if lengths else 0

    of_x_dataset = []
    of_y_dataset = []
    title_order = []

    for i in range(num_datapoints):
        new_x_datapoint = None
        new_y_datapoint = None
        last_x_img = None

        if x_dataset is not None:
            x_datapoint = x_dataset[i]
            last_x_img = x_datapoint[-1]
            new_x_datapoint, titles = _optical_flow_features(
                x_datapoint, image_in_x, time_diff_vel_in_x,
                time_diff_accel_in_x, use_double_optical_flow,
                farneback_flow_params)
            # every datapoint has the same layout, so record the titles once
            if i == 0:
                title_order = titles

        if y_dataset is not None:
            y_datapoint = y_dataset[i]
            if use_difference_output:
                new_y_datapoint = y_datapoint - last_x_img
            else:
                new_y_datapoint = y_datapoint

        of_x_dataset.append(new_x_datapoint)
        of_y_dataset.append(new_y_datapoint)

    return of_x_dataset, of_y_dataset, title_order