# Pre- and post-processing utilities for image-sequence datasets.

import numpy as np 
import cv2
import os 
import matplotlib.pyplot as plt 
from scipy.io import loadmat
from copy import deepcopy

class processing:
    """
    Pre- and post-processing helpers that turn image sequences into (x, y) datasets.

    The components of each x datapoint are selected by ``xtypes`` and the prediction
    target by ``ytype`` (see ``__init__``). Loaders are provided for ``.mat`` files
    holding image stacks under the key ``'u'`` and for video files read through OpenCV.
    """

    # Every component name accepted in ``xtypes`` / ``ytype``.
    _VALID_XTYPES = ('img-2', 'img-1', 'img0', 'vel_diff', 'accel_diff', 'of_x', 'of_y')
    _VALID_YTYPES = ('img', 'diff')

    def __init__(self, xtypes, ytype, farneback_flow_params=(None, 0.5, 3, 15, 3, 5, 1.2, 0)):
        """
        Initialize the processing class
        args:
            - xtypes: list of strings - order indicates order of types of x components in dataset
                - 'img-2': 2 images before current img
                - 'img-1': 1 image before current img
                - 'img0': current img - when predicting the next
                - 'vel_diff': image difference velocity ('img0' - 'img-1')
                - 'accel_diff': image difference acceleration (('img0' - 'img-1') - ('img-1' - 'img-2'))
                - 'of_x': optical flow x (Farneback flow computed from 'img-1' to 'img0')
                - 'of_y': optical flow y (Farneback flow computed from 'img-1' to 'img0')
            - ytype: a string that indicates the type of the y components in dataset (also what is to be predicted)
                - 'img': the next image
                - 'diff': the next image minus the current image ('img0')
            - farneback_flow_params: the 8 positional arguments passed to
              cv2.calcOpticalFlowFarneback after the two images (flow, pyr_scale, levels,
              winsize, iterations, poly_n, poly_sigma, flags). May be None when no
              optical-flow xtype is used.
        raises:
            - ValueError: if 'of_x' or 'of_y' is requested but farneback_flow_params is None
        """
        self.xtypes = xtypes
        self.ytype = ytype
        # The default is an immutable tuple so it can never be shared and mutated across
        # instances; each instance still gets its own mutable list, as before.
        self.farneback_flow_params = (
            None if farneback_flow_params is None else list(farneback_flow_params)
        )

        if self._uses_optical_flow() and self.farneback_flow_params is None:
            raise ValueError(
                "farneback_flow_params must be provided when 'of_x' or 'of_y' is in xtypes")

    def _uses_optical_flow(self):
        """Return True if any optical-flow component ('of_x' / 'of_y') is requested."""
        return 'of_x' in self.xtypes or 'of_y' in self.xtypes

    def create_x(self, image_seq):
        """
        Creates the x dataset of the types specified by xtypes given a sequence of images
        args:
            - image_seq: sequence of images (arrays supporting subtraction and indexing)
        returns:
            - xdataset: list of tuples, components ordered as in xtypes. Entry k is built
              around the current image image_seq[images_per_xdatapoint() - 1 + k]; the last
              image is never a current image because it has no following image to act as y.
        raises:
            - ValueError: if xtypes contains an unknown component name
        note:
            - differences are computed in the images' own dtype, so unsigned integer images
              wrap around where the difference is negative.
        """
        for xtype in self.xtypes:
            if xtype not in self._VALID_XTYPES:
                raise ValueError("Incorrect xtype found in processing create x: %r" % (xtype,))

        uses_flow = self._uses_optical_flow()
        # The current image is the last image of its x datapoint, so start far enough in
        # that every earlier image the xtypes reference exists.
        first_img_num = self.images_per_xdatapoint() - 1

        xdataset = []
        # stop before the final image: it can only ever be the y of the previous datapoint
        for curr_img_num in range(first_img_num, len(image_seq) - 1):
            if uses_flow:
                xy_optical_flow = cv2.calcOpticalFlowFarneback(
                    image_seq[curr_img_num - 1], image_seq[curr_img_num],
                    *self.farneback_flow_params[:8])
                x_optical_flow = xy_optical_flow[..., 0]
                y_optical_flow = xy_optical_flow[..., 1]

            curr_img_tup_list = []
            for xtype in self.xtypes:
                if xtype == 'img-2':
                    component = image_seq[curr_img_num - 2]
                elif xtype == 'img-1':
                    component = image_seq[curr_img_num - 1]
                elif xtype == 'img0':
                    component = image_seq[curr_img_num]
                elif xtype == 'vel_diff':
                    component = image_seq[curr_img_num] - image_seq[curr_img_num - 1]
                elif xtype == 'accel_diff':
                    component = ((image_seq[curr_img_num] - image_seq[curr_img_num - 1])
                                 - (image_seq[curr_img_num - 1] - image_seq[curr_img_num - 2]))
                elif xtype == 'of_x':
                    component = x_optical_flow
                else:  # 'of_y' is the only remaining valid xtype (validated above)
                    component = y_optical_flow
                curr_img_tup_list.append(component)

            xdataset.append(tuple(curr_img_tup_list))

        return xdataset

    def create_xy(self, image_seq):
        """
        Creates corresponding x and y datasets of the types specified by the xtypes and ytype given a sequence of images
        args:
            - image_seq: sequence of images
        returns:
            - xdataset: see create_x
            - ydataset: list with one target per x datapoint (same length as xdataset)
        raises:
            - ValueError: if ytype or any xtype is unknown
        """
        if self.ytype not in self._VALID_YTYPES:
            raise ValueError("Incorrect ytype used: %r" % (self.ytype,))

        xdataset = self.create_x(image_seq)

        # the first y is the image right after the first x datapoint's current image
        first_y_num = self.images_per_xdatapoint()
        y_nums = range(first_y_num, len(image_seq))
        if self.ytype == 'img':
            ydataset = [image_seq[i] for i in y_nums]
        else:  # 'diff': next image minus the current ('img0') image
            ydataset = [image_seq[i] - image_seq[i - 1] for i in y_nums]
        return xdataset, ydataset

    def postprocess_y(self, xdataset, ydataset):
        """
        Post processes the ydataset output into a sequence of images given the ydataset and the xdataset used to generate
        the ydataset
        args:
            - xdataset
            - ydataset: corresponding to the xdataset
        returns:
            - image_seq: corresponding to the xdataset and ydataset
        raises:
            - ValueError: if ytype is unknown, or ytype is 'diff' and 'img0' is not in xtypes
        """
        # if the y type is just an image nothing is needed to be done
        if self.ytype == 'img':
            return ydataset
        if self.ytype == 'diff':
            if 'img0' not in self.xtypes:
                raise ValueError("ytype 'diff' needs 'img0' in xtypes to recover the images")
            # position of the current image inside each x tuple; the diff is added back to it
            img0_index = self.xtypes.index('img0')
            return [diff_img + xdataset[diff_img_index][img0_index]
                    for diff_img_index, diff_img in enumerate(ydataset)]
        raise ValueError("Incorrect ytype used: %r" % (self.ytype,))

    def convert_imgdataset_to_vecdataset(self, dataset):
        """
        Convert dataset: list of images or list of tuples of images to a dataset that is a list of
        1 dimensional vectors. If tuple of images it will create an extended list of the flattened
        images in the tuple concatenated.
        args:
            - dataset: list of images or list of tuples of images
        returns:
            - dataset_vec: list of 1 dimensional arrays; the input is returned unchanged if it
              is empty or its entries are scalars
        """
        if len(dataset) == 0:
            return dataset

        first = dataset[0]
        is_tuple_data = isinstance(first, tuple)
        # only flatten if the entries are arrays (or tuples of arrays), not scalars
        if not is_tuple_data and np.ndim(first) < 1:
            return dataset

        if is_tuple_data:
            dataset_vec = []
            for datapoint in dataset:
                # flatten each component separately so components may differ in shape
                parts = [np.ravel(part) for part in datapoint]
                dataset_vec.append(np.concatenate(parts) if parts else np.array([]))
            return dataset_vec

        return [np.array(datapoint).reshape(-1) for datapoint in dataset]

    def images_per_xdatapoint(self):
        """
        Based on the xtypes returns the number of images used to generate each x datapoint
        Note: this includes the current image - the y value corresponding will be the image after
        """
        if 'img-2' in self.xtypes or 'accel_diff' in self.xtypes:
            return 3
        # optical flow is computed from the previous image to the current one
        if 'img-1' in self.xtypes or 'vel_diff' in self.xtypes or self._uses_optical_flow():
            return 2
        # only the current image is needed
        return 1

    @staticmethod
    def _mat_to_image_seq(all_image_mat):
        """Split a 3-D array into a list of the 2-D slices taken along its last axis."""
        return [all_image_mat[:, :, i] for i in range(all_image_mat.shape[-1])]

    def nvstokes_file_to_image_seq(self, filename):
        """
        Convert the filename to sequence of images (the first sequence stored under key 'u')
        args:
            - filename: type: string: the full path to the nvstokes dataset
        returns:
            - image_seq: sequence of images
        """
        loaded_dataset = loadmat(filename)
        return self._mat_to_image_seq(loaded_dataset['u'][0])

    def nvstokes_file_to_image_sequences(self, filename):
        """
        Convert the filename to pull out all the image sequences stored within the file (key 'u')
        args:
            - filename: type: string: the full path to the nvstokes dataset
        returns:
            - image_sequences: list of image sequences, one per entry along the first axis of 'u'
        """
        loaded_dataset = loadmat(filename)
        return [self._mat_to_image_seq(all_image_mat) for all_image_mat in loaded_dataset['u']]

    def nvstokes_file_to_xy(self, filename):
        """Load the first image sequence from an nvstokes .mat file and build its (x, y) datasets."""
        image_seq = self.nvstokes_file_to_image_seq(filename)
        return self.create_xy(image_seq)

    def nvstokes_file_to_x(self, filename):
        """Load the first image sequence from an nvstokes .mat file and build its x dataset."""
        image_seq = self.nvstokes_file_to_image_seq(filename)
        return self.create_x(image_seq)

    def video_to_imageseq(self, filename):
        """
        Takes a filename converts it to a sequence of RGB images and returns the list
        args:
            - filename: string
        returns:
            - list of images: list of numpy arrays that are images (empty if nothing could be read)
        """
        all_images = []
        vidcap = cv2.VideoCapture(filename)
        try:
            success, image = vidcap.read()
            while success:
                # OpenCV decodes frames as BGR; store them as RGB
                all_images.append(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
                success, image = vidcap.read()
        finally:
            # always free the capture handle, even if decoding/conversion raises
            vidcap.release()

        return all_images

    def downsample_imageseq(self, image_seq, num_pyrDowns):
        """
        downsample the image sequence by the number of pyrDowns
        args:
            - image_seq
            - num_pyrDowns: number of times to downsample
        returns:
            - downsampled image sequence (new arrays; the inputs are not modified)
        """
        down_image_seq = []
        for img in image_seq:
            # copy so the output never aliases the input, even when num_pyrDowns == 0
            down_image = deepcopy(img)
            for _ in range(num_pyrDowns):
                down_image = cv2.pyrDown(down_image)
            down_image_seq.append(down_image)

        return down_image_seq

    def pre_process_video(self, filename, num_pyrDowns, to_blackandwhite=True, background_subtraction=False):
        """
        Fully pre process data from filename of a video (.mp4) to a sequence of images
        args:
            - filename: type: string: full filepath to video file
            - num_pyrDowns: type: int: the number of times to pyramid downsample the images in the video
            - to_blackandwhite: type bool:
                - True: converts the image to black and white if 3 channeled images, and scales
                  the grayscale result to floats in [0, 1]
                - False: leaves the images as is
            - background_subtraction: type: bool
                - True: removes the background by subtracting the average image of all the frames
                - False: does nothing
        returns:
            - image_seq: the pre processed image sequence that corresponds to the video
              (empty if no frames could be read)
        """
        init_image_seq = self.video_to_imageseq(filename=filename)
        image_seq = self.downsample_imageseq(image_seq=init_image_seq, num_pyrDowns=num_pyrDowns)

        if to_blackandwhite and image_seq:
            if len(image_seq[0].shape) >= 3 and image_seq[0].shape[-1] == 3:
                # video_to_imageseq already converted frames from BGR to RGB
                image_seq = [cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) for img in image_seq]
                # convert to float images in [0, 1]
                max_img_val = 255.0
                image_seq = [img / max_img_val for img in image_seq]

        if background_subtraction and image_seq:
            mean_background = np.mean(np.array(image_seq), axis=0)
            image_seq = [img - mean_background for img in image_seq]

        return image_seq

