# This file contains all the functions to patchify and unpatchify images in the dataset. 

import numpy as np
import cv2
import matplotlib.pyplot as plt 
import GPy 
import pdb
from copy import deepcopy

class patcher:

    def __init__(self, img_dim, patch_dim, patch_border, img_padlen, img_padtype, wrap_x, wrap_y, stride, patch_weight):
        """
        Build a patcher that cuts images into (possibly wrapping) patches and stitches them back.

        args:
            - img_dim: tuple (rows, cols)
            - patch_dim: tuple (rows, cols)
            - patch_border: tuple (rows, cols)
            - img_padlen: tuple (rows, cols)
            - img_padtype: string
            - wrap_x: wrap along columns
            - wrap_y: wrap along rows
            - stride: tuple (row stride, col stride)
            - patch_weight: the patch weight to use when recombining (or None)
        NOTE: the full argument semantics and validation live in reinit, so an existing
        object can be reconfigured without constructing a new one.
        """
        self.reinit(
            img_dim, patch_dim, patch_border, img_padlen, img_padtype,
            wrap_x, wrap_y, stride, patch_weight,
        )


    def reinit(self, img_dim, patch_dim, patch_border, img_padlen, img_padtype, wrap_x, wrap_y, stride, patch_weight):
        """
        Re-Initialize class to patchify and unpatchify images.
        Sets the class attributes, validates them and computes self.all_start_tuples.

        Axis convention: index 0 of every tuple refers to rows (wrapping controlled by
        wrap_y) and index 1 to columns (wrapping controlled by wrap_x).

        args:
            - img_dim: tuple (rows, cols): dimension of the image
            - patch_dim: tuple (rows, cols): dimension of the patch
            - patch_border: tuple (rows, cols): number of rows, cols that are used as the border of the patch
            - img_padlen: tuple (rows, cols): number of rows, cols used to pad each side of the image
            - img_padtype: string: to indicate the method used for padding the image. For viable options look at self.pad_image
            - wrap_x: whether you want the patches to wrap in the x or column direction
            - wrap_y: whether you want the patches to wrap in the y or row direction
            - stride: tuple (row stride, col stride): the number of pixels to skip by when creating image patches
            - patch_weight: the patch weight to use when recombining
        returns:
            - None: NOTE: will print a warning if the attributes will not give proper coverage of the image for full reconstruction
        raises:
            - AssertionError: if an argument has the wrong type or the geometry cannot be reconstructed
        """
        self.img_dim = img_dim
        self.patch_dim = patch_dim
        self.patch_border = patch_border
        self.img_padlen = img_padlen
        self.img_padtype = img_padtype
        self.wrap_x = wrap_x
        self.wrap_y = wrap_y
        self.stride = stride
        self.patch_weight = patch_weight

        # assert things are of the correct type
        assert isinstance(img_dim, tuple)
        assert isinstance(patch_dim, tuple)
        assert isinstance(patch_border, tuple)
        assert isinstance(img_padlen, tuple)
        assert isinstance(stride, tuple)

        print("Stride: ", stride)

        # For the unpatched image to be a faithful reconstruction you need:
        #     1. patch_border < patch_dim / 2 (the predicted interior is non-empty)
        #     2. img_padlen >= patch_border along every axis that does not wrap
        #     3. stride <= patch_dim - 2*patch_border (interiors tile the image; only warned about)
        assert self.patch_border[0] < self.patch_dim[0]/2
        assert self.patch_border[1] < self.patch_dim[1]/2

        # Rows (axis 0) wrap with wrap_y and columns (axis 1) with wrap_x; a
        # non-wrapping axis needs enough padding to absorb the patch border.
        if not self.wrap_y:
            assert self.img_padlen[0] >= self.patch_border[0]
        if not self.wrap_x:
            assert self.img_padlen[1] >= self.patch_border[1]

        # NOTE: in these messages "X" refers to axis 0 (rows) and "Y" to axis 1 (columns).
        out_rows = self.patch_dim[0] - 2*self.patch_border[0]
        out_cols = self.patch_dim[1] - 2*self.patch_border[1]
        if (self.stride[0] > out_rows) or (self.patch_dim[0] < self.stride[0]):
            print("WARNING: these patches will not cover the whole image in the X direction")
            print("\t X stride: ", self.stride[0])
            print("\t X input patch dimension: ", self.patch_dim[0])
            print("\t X output patch dimension: ", out_rows)
        if (self.stride[1] > out_cols) or (self.patch_dim[1] < self.stride[1]):
            print("WARNING: these patches will not cover the whole image in the Y direction")
            print("\t Y stride: ", self.stride[1])
            print("\t Y input patch dimension: ", self.patch_dim[1])
            print("\t Y output patch dimension: ", out_cols)

        # set all start tuples
        self.all_start_tuples = self.get_row_col_start_tuples(img_dim=self.img_dim,
            img_padlen=self.img_padlen,
            patch_dim=self.patch_dim,
            patch_border=self.patch_border,
            stride=self.stride,
            wrap_x=self.wrap_x,
            wrap_y=self.wrap_y)

        # Padding should allow complete reconstruction: on a non-wrapping axis the last
        # patch's prediction region must end at least where the far padding begins
        # (image plus one side of padding).
        meaningful_end = (self.img_dim[0] + self.img_padlen[0], self.img_dim[1] + self.img_padlen[1])
        if not self.wrap_x:
            last_col_start = np.arange(0, meaningful_end[1], self.stride[1])[-1]
            assert last_col_start + self.patch_dim[1] - self.patch_border[1] >= meaningful_end[1]
        if not self.wrap_y:
            last_row_start = np.arange(0, meaningful_end[0], self.stride[0])[-1]
            assert last_row_start + self.patch_dim[0] - self.patch_border[0] >= meaningful_end[0]

    def patchify_image(self, img, img_type): 
        """
        Patchify image
        args: 
            - img
            - img_type: string: the type of image 'x' or 'y'

        args needed from attributes: 
            - patch_dim
            - stride
            - wrap_x
            - wrap_y
            - all_start_tuples
            - patch_border
            - img_padtype
            - img_padlen

        returns: 
            - patch_list: list of patches corresponding to the image

        Example of patch border: 
        Eg: of input and output dataset patches when patch_dim = 4, patch_border = 1
            patch 1       output patch  
            0 0 0 0 
            0 0 0 0  -->     0 0 
            0 0 0 0          0 0  
            0 0 0 0 

        """
        if self.img_padlen[0] > 0 or self.img_padlen[1] > 0:
            padded_img = self.pad_image(img, self.img_padlen, self.img_padtype)
        else:
            padded_img = deepcopy(img)
        padded_img_dim = padded_img.shape
        
        # if wrapping further pad the image
        tile_repeats = [1,1]
        row_offset = 0
        col_offset = 0
        if self.wrap_x: 
            tile_repeats[1] = 3
            col_offset = padded_img.shape[1]
        if self.wrap_y:
            tile_repeats[0] = 3
            row_offset = padded_img.shape[0]
        padded_img = np.tile(padded_img, tile_repeats)

        if img_type == 'y':
            used_patch_dim = self.get_ypatch_dim()
        else: 
            used_patch_dim = self.patch_dim
        
        all_patches = []
        for (row_start, col_start) in self.all_start_tuples:
            if  img_type == 'y':
                row_start += self.patch_border[0]
                col_start += self.patch_border[1]
            row_start = row_start + row_offset
            col_start = col_start + col_offset
            current_patch = padded_img[row_start:row_start + used_patch_dim[0], col_start:col_start + used_patch_dim[1]]

            all_patches.append(current_patch)
            
        return all_patches

    def unpatchify_image(self, patch_list, patch_variance_list, img_type): 
        """
        Unpatchify image from patch list 
        args: 
            - patch_list: list of patches that form an image
            - patch_variance_list: list of floats of length patsch_list corresponding to the variance of the predicted patch
                in patch_list - if specified as None: then do not use
                NOTE: there is one variance value per patch 
            - img_type: string: whether the patches correspond to a 'y' or 'x' image
        returns: 
            - reconstructed img: numpy array 
            - reconstructed padded image
            - reconstructed variance image 
        """
        padded_img_dim = list(self.img_dim)
        padded_img_dim[0] += 2*self.img_padlen[0]
        padded_img_dim[1] += 2*self.img_padlen[1]
        padded_img_dim = tuple(padded_img_dim)

        patch_counter = 0
        patch_img_list = []
        mask_img_list = []

        if img_type == 'y':
            # get the smaller patch dim
            used_patch_dim = self.get_ypatch_dim()
        else: 
            used_patch_dim = self.patch_dim
        
        patch_weight = self.patch_weight
        if type(patch_weight) == type(None): 
            # create corresponding boolean mask 
            if img_type == 'y': 
                patch_weight = np.ones(self.get_ypatch_dim())
            else: 
                patch_weight = np.ones(self.patch_dim)
        else:
            if img_type == 'y':
                old_patch_weight = deepcopy(patch_weight)
                if self.patch_border[0] == 0 and self.patch_border[1] == 0:
                    patch_weight = old_patch_weight
                elif self.patch_border[0] == 0:
                    patch_weight = old_patch_weight[self.patch_border[0]:-self.patch_border[0], :]
                elif self.patch_border[1] == 0:
                    patch_weight = old_patch_weight[:, self.patch_border[1]:-self.patch_border[1]]
                else: 
                    patch_weight = old_patch_weight[self.patch_border[0]:-self.patch_border[0], self.patch_border[1]:-self.patch_border[1]]

        tile_repeats = [1,1]
        row_offset = 0
        col_offset = 0
        if self.wrap_x: 
            tile_repeats[1] = 3
            col_offset = padded_img_dim[1]
        if self.wrap_y:
            tile_repeats[0] = 3
            row_offset = padded_img_dim[0]

        padded_img = np.zeros(padded_img_dim)
        # update the padded image with tiling to account for wrapping
        super_padded_img = np.tile(padded_img, tile_repeats)
        super_padded_img_dim = tuple(super_padded_img.shape)

        patch_img_list = []
        mask_img_list = []
        for (row_start, col_start) in self.all_start_tuples: 
            if img_type == 'y':
                row_start += self.patch_border[0]
                col_start += self.patch_border[1]
            
            row_start = row_start + row_offset
            col_start = col_start + col_offset
            current_patch = patch_list[patch_counter]

            current_patch_img  = np.zeros(super_padded_img_dim)
            current_patch_mask = np.zeros(super_padded_img_dim)
            
            try: 
                current_patch_img[row_start:row_start + used_patch_dim[0], col_start:col_start + used_patch_dim[1]] = current_patch
                # only encourage averaging where the actual predicted pixels are 
                current_patch_mask[row_start:row_start + used_patch_dim[0], col_start:col_start + used_patch_dim[1]] = patch_weight
            except: 
                pdb.set_trace()

            if type(patch_variance_list) != type(None): 
                # weight using variance - you want higher variance to be weighted less!
                current_patch_mask = current_patch_mask * 1/patch_variance_list[patch_counter]

            # append to lists 
            patch_img_list.append(current_patch_img)
            mask_img_list.append(current_patch_mask)
            # increase counter
            patch_counter += 1

        # If there is tiling add the tiles together
        reconstructed_padded_img = np.zeros(padded_img_dim)
        reconstructed_padded_normalization = np.zeros(padded_img_dim)
        reconstructed_padded_var_img = np.zeros(padded_img_dim)
        for found_patch_num in range(len(patch_img_list)):
            current_patch_img = patch_img_list[found_patch_num]
            current_patch_mask = mask_img_list[found_patch_num]

            # add all the tiles together
            for row_tile_num in range(tile_repeats[0]):
                for col_tile_num in range(tile_repeats[1]):
                    curr_patch_img_tile = current_patch_img[row_tile_num*padded_img_dim[0]:(row_tile_num + 1)*padded_img_dim[0], col_tile_num*padded_img_dim[1]: (col_tile_num + 1)*padded_img_dim[1]]
                    curr_patch_mask_tile = current_patch_mask[row_tile_num*padded_img_dim[0]:(row_tile_num + 1)*padded_img_dim[0], col_tile_num*padded_img_dim[1]: (col_tile_num + 1)*padded_img_dim[1]]
                    
                    reconstructed_padded_img += curr_patch_img_tile * curr_patch_mask_tile
                    reconstructed_padded_normalization += curr_patch_mask_tile

        # normalize by the total variance weight
        reconstructed_padded_normalization_muldivide = deepcopy(reconstructed_padded_normalization)
        reconstructed_padded_normalization_muldivide[reconstructed_padded_normalization_muldivide == 0] = 1
        #reconstructed_padded_normalization_divide[np.where(reconstructed_padded_normalization != 0)] = 1/reconstructed_padded_normalization[np.where(reconstructed_padded_normalization != 0)] 
        reconstructed_padded_img = reconstructed_padded_img / reconstructed_padded_normalization_muldivide
        
        if self.img_padlen[0] == 0 and self.img_padlen[1] == 0:
            reconstructed_img = reconstructed_padded_img[:, :]
            reconstructed_variance_img = reconstructed_padded_normalization[:]
        elif self.img_padlen[0] == 0:
            reconstructed_img = reconstructed_padded_img[:, self.img_padlen[1]:-self.img_padlen[1]]
            reconstructed_variance_img = reconstructed_padded_normalization[:, self.img_padlen[1]:-self.img_padlen[1]]
        elif self.img_padlen[1] == 0:
            reconstructed_img = reconstructed_padded_img[self.img_padlen[0]:-self.img_padlen[0], :]
            reconstructed_variance_img = reconstructed_padded_normalization[self.img_padlen[0]:-self.img_padlen[0], :]
        else:
            reconstructed_img = reconstructed_padded_img[self.img_padlen[0]:-(self.img_padlen[0]), self.img_padlen[1]:-(self.img_padlen[1])]
            reconstructed_variance_img = reconstructed_padded_normalization[self.img_padlen[0]:-(self.img_padlen[0]), self.img_padlen[1]:-(self.img_padlen[1])] # NOTE: only valid if the patch variance list is actually specified - UNSURE IF THIS IS VALID

        return reconstructed_img, reconstructed_padded_img, reconstructed_variance_img, patch_img_list, mask_img_list

    def unpatchify_mean_variance_image(self, mean_patch_list, var_patch_list, patch_weight=None):
        """
        This function is used to unpatchify mean and variance images. The patches when combined together 
        form a Gaussian Mixture model at each pixel location. Using a moment matching method these must
        be properly combined to form a new gaussian at each pixel. 

        NOTE: Call refine patch weight before using this function

        Formulas: 
        Moment matching Gaussian Mixture Model to a single Gaussian distribution 
        1. Mean: weighted average of individual gaussian means 
        2. Variance: sum(pi_i * (cov_i + [mu_i mu_i ^ T])) - [mu mu^T]

        args: 
            - mean_patch_list: the mean patch prediction outputted from the mean propagation 
            - var_patch_list: the var patch prediction outputted from the var propagation
                - these correspond to the mean patches 
            - patch_weight: patch weight to be used for weighting the Gaussian Mixture Model
        returns: 
            - mean_image
            - variance_image
        """
        img_type = 'y' # img type is always y for this function 

        used_patch_dim = self.get_ypatch_dim()
        assert(patch_weight.shape == used_patch_dim)

        if type(patch_weight) == type(None):
            patch_weight = np.ones(used_patch_dim)

        padded_img_dim = list(self.img_dim)
        padded_img_dim[0] += 2*self.img_padlen[0]
        padded_img_dim[1] += 2*self.img_padlen[1]
        padded_img_dim = tuple(padded_img_dim)

        tile_repeats = [1,1]
        row_offset = 0
        col_offset = 0
        if self.wrap_x: 
            tile_repeats[1] = 3
            col_offset = padded_img_dim[1]
        if self.wrap_y:
            tile_repeats[0] = 3
            row_offset = padded_img_dim[0]

        padded_img = np.zeros(padded_img_dim)
        # update the padded image with tiling to account for wrapping
        super_padded_img = np.tile(padded_img, tile_repeats)
        super_padded_img_dim = tuple(super_padded_img.shape)

        ######################################################################################
        # Mean and Variance Image: unpatchify loop
        patch_counter = 0
        mean_patch_img_list = []
        var_patch_img_list_presub = []
        mask_img_list = []
        for (row_start, col_start) in self.all_start_tuples: 
            if img_type == 'y':
                row_start += self.patch_border[0]
                col_start += self.patch_border[1]
            
            row_start = row_start + row_offset
            col_start = col_start + col_offset
            current_mean_patch = mean_patch_list[patch_counter]
            current_var_patch = var_patch_list[patch_counter]

            current_mean_patch_img  = np.zeros(super_padded_img_dim)
            current_var_patch_img = np.zeros(super_padded_img_dim)
            current_patch_mask = np.zeros(super_padded_img_dim)
            
            current_mean_patch_img[row_start:row_start + used_patch_dim[0], col_start:col_start + used_patch_dim[1]] = current_mean_patch
            current_var_patch_img[row_start:row_start + used_patch_dim[0], col_start:col_start + used_patch_dim[1]] = current_var_patch # OLD REMOVE + np.square(current_mean_patch)
            # only encourage averaging where the actual predicted pixels are 
            current_patch_mask[row_start:row_start + used_patch_dim[0], col_start:col_start + used_patch_dim[1]] = patch_weight

            # append to lists 
            mean_patch_img_list.append(current_mean_patch_img)
            var_patch_img_list_presub.append(current_var_patch_img)
            mask_img_list.append(current_patch_mask)
            # increase counter
            patch_counter += 1

        # If there is tiling add the tiles together
        mean_reconstructed_padded_img = np.zeros(padded_img_dim)
        var_reconstructed_padded_img_presub = np.zeros(padded_img_dim)
        reconstructed_padded_normalization = np.zeros(padded_img_dim)
        for found_patch_num in range(len(mean_patch_img_list)):
            current_mean_patch_img = mean_patch_img_list[found_patch_num]
            current_var_patch_img = var_patch_img_list_presub[found_patch_num]
            current_patch_mask = mask_img_list[found_patch_num]

            # add all the tiles together
            for row_tile_num in range(tile_repeats[0]):
                for col_tile_num in range(tile_repeats[1]):
                    curr_mean_patch_img_tile = current_mean_patch_img[row_tile_num*padded_img_dim[0]:(row_tile_num + 1)*padded_img_dim[0], col_tile_num*padded_img_dim[1]: (col_tile_num + 1)*padded_img_dim[1]]
                    curr_var_patch_img_tile = current_var_patch_img[row_tile_num*padded_img_dim[0]:(row_tile_num + 1)*padded_img_dim[0], col_tile_num*padded_img_dim[1]: (col_tile_num + 1)*padded_img_dim[1]]
                    curr_patch_mask_tile = current_patch_mask[row_tile_num*padded_img_dim[0]:(row_tile_num + 1)*padded_img_dim[0], col_tile_num*padded_img_dim[1]: (col_tile_num + 1)*padded_img_dim[1]]

                    mean_reconstructed_padded_img += curr_mean_patch_img_tile * curr_patch_mask_tile
                    var_reconstructed_padded_img_presub += curr_var_patch_img_tile * curr_patch_mask_tile
                    reconstructed_padded_normalization += curr_patch_mask_tile

        # normalize by the total variance weight
        reconstructed_padded_normalization_muldivide = deepcopy(reconstructed_padded_normalization)
        reconstructed_padded_normalization_muldivide[reconstructed_padded_normalization_muldivide == 0] = 1
        #reconstructed_padded_normalization_divide[np.where(reconstructed_padded_normalization != 0)] = 1/reconstructed_padded_normalization[np.where(reconstructed_padded_normalization != 0)] 
        mean_reconstructed_padded_img = mean_reconstructed_padded_img / reconstructed_padded_normalization_muldivide
        # var_reconstructed_padded_img_presub = var_reconstructed_padded_img_presub / reconstructed_padded_normalization_muldivide
        # var_reconstructed_padded_img = var_reconstructed_padded_img_presub - np.square(mean_reconstructed_padded_img)
        var_reconstructed_padded_img = var_reconstructed_padded_img_presub / np.square(reconstructed_padded_normalization_muldivide) # Correct averaging method 7/22

        if self.img_padlen[0] == 0 and self.img_padlen[1] == 0:
            reconstructed_mean_img = mean_reconstructed_padded_img[:, :]
            reconstructed_var_img = var_reconstructed_padded_img[:,:]
        elif self.img_padlen[0] == 0:
            reconstructed_mean_img = mean_reconstructed_padded_img[:, self.img_padlen[1]:-self.img_padlen[1]]
            reconstructed_var_img = var_reconstructed_padded_img[:, self.img_padlen[1]:-self.img_padlen[1]]
        elif self.img_padlen[1] == 0:
            reconstructed_mean_img = mean_reconstructed_padded_img[self.img_padlen[0]:-self.img_padlen[0], :]
            reconstructed_var_img = var_reconstructed_padded_img[self.img_padlen[0]:-self.img_padlen[0], :]
        else:
            reconstructed_mean_img = mean_reconstructed_padded_img[self.img_padlen[0]:-(self.img_padlen[0]), self.img_padlen[1]:-(self.img_padlen[1])]
            reconstructed_var_img = var_reconstructed_padded_img[self.img_padlen[0]:-(self.img_padlen[0]), self.img_padlen[1]:-(self.img_padlen[1])]

        #return reconstructed_img, reconstructed_padded_img, reconstructed_variance_img, patch_img_list, mask_img_list
        return reconstructed_mean_img, reconstructed_var_img


    def patchify_dataset(self, dataset, dataset_type): 
        """
        Patchify image on whole dataset
        args: 
            - dataset: list of images or tuples of images
            - dataset_type: string: 'x' or 'y' dataset
        returns: 
            - patch_list: list of patches - concatenated patch list of each of the datapoints in the dataset
        """
        dataset_patches = []
        for datapoint in dataset:
            if type(datapoint) == tuple: 
                # datapoint is a tuple of images
                data_point_patches = []
                for img in datapoint: 
                    data_point_patches.append(self.patchify_image(img, dataset_type))

                patch_tuple_list = []
                for i in range(len(data_point_patches[0])):
                    curr_patch_tuple = tuple([data_point_patches[tup_index][i] for tup_index in range(len(data_point_patches))])
                    patch_tuple_list.append(curr_patch_tuple)

                # extend dataset patches with found patch tuple list 
                dataset_patches.extend(patch_tuple_list)

            else: 
                # datapoint is a single image 
                dataset_patches.extend(self.patchify_image(datapoint, dataset_type))

        return dataset_patches

    def unpatchify_dataset(self, dataset_patch_list, dataset_patch_variance_list, dataset_type): 
        """
        Unpatchify list of patches for whole dataset to list of images
        args: 
            - dataset_patch_list: list of patches: concatenated patch list of each of the datapoints in the dataset
            - dataset_patch_variance_list: list of floats of length dataset_patch_list: there is one variance value per patch 
                in the dataset_patch_list
            - dataset_type: string: 'x' or 'y' dataset
        returns: 
            - img_dataset: list of images reconstructed from the dataset_patch_list
            - var_dataset: list of images that correspond to the variance image reconstructed from the variance values of the patches and their masks
        # NOTE: no support for unpatchifying image tuples only datasets where each datapoint is a single image
        """
        num_patches_per_image = len(self.all_start_tuples) # the number of patches per image

        num_images = len(dataset_patch_list)/num_patches_per_image
        assert(num_images%1==0) # assert enough patches for num_images full images

        img_dataset = []
        var_dataset = []
        all_patch_img_lists = []
        all_mask_img_lists = []
        for img_num in range(int(num_images)): 
            if type(dataset_patch_variance_list) != type(None): 
                curr_img, curr_img_padded, curr_var_img, curr_patch_img_list, curr_mask_img_list = \
                self.unpatchify_image(dataset_patch_list[img_num*num_patches_per_image:(img_num + 1)*num_patches_per_image], 
                    patch_variance_list=dataset_patch_variance_list[img_num*num_patches_per_image:(img_num + 1)*num_patches_per_image], 
                    img_type=dataset_type)
            else: 
                curr_img, curr_img_padded, curr_var_img, curr_patch_img_list, curr_mask_img_list = \
                self.unpatchify_image(dataset_patch_list[img_num*num_patches_per_image:(img_num + 1)*num_patches_per_image], 
                    patch_variance_list=None, 
                    img_type=dataset_type)

            img_dataset.append(curr_img)
            var_dataset.append(curr_var_img)
            #all_patch_img_lists.extend(list(curr_patch_img_list))
            #all_mask_img_lists.extend(list(curr_mask_img_list))
        return img_dataset, var_dataset, all_patch_img_lists, all_mask_img_lists

    ####################################################################################################################
    ############################################## Helper functions ####################################################
    ####################################################################################################################

    def get_row_col_start_tuples(self, img_dim, img_padlen, patch_dim, patch_border, stride, wrap_x, wrap_y, allow_negative_wrap=True):
        """
        Get the starting rows and columns (denoting the top left pixel of the patch)
        from which to sample the patches.

        NOTE:
        - when wrapping, start positions may be negative; patchify_image interprets them
          relative to the middle copy of the 3x tiled padded image
        - along a wrapping axis, starts that differ by a whole period of the padded image
          sample identical pixels, so only the first of them is kept
        - sets self.all_start_tuples as a side effect and returns the result of
          self.refine_row_start_col_tuples()

        args:
            - img_dim: tuple (rows, cols)
            - img_padlen: tuple (rows, cols): number of pixels to pad the rows and cols
            - patch_dim: tuple (rows, cols)
            - patch_border: tuple (rows, cols): how many rows, cols to use for the patch border
            - stride: tuple (row stride, col stride)
            - wrap_x: Boolean: if the patch should wrap the x axis (columns)
            - wrap_y: Boolean: if the patch should wrap the y axis (rows)
            - allow_negative_wrap: boolean: when True allows negative values while wrapping
                DEFAULT: TRUE
        returns:
            - list of tuples: (row, col) start points - assuming the appropriate tiling
        """
        padded_img_dim = list(img_dim)
        padded_img_dim[0] += 2*img_padlen[0]
        padded_img_dim[1] += 2*img_padlen[1]
        padded_img_dim = tuple(padded_img_dim)

        # When wrapping, start (patch - stride) before the origin so the seam between
        # the last and first copies of the image is sampled too.
        start_x = 0
        start_y = 0
        if allow_negative_wrap:
            if wrap_x:
                start_x = stride[1] - patch_dim[1]
            if wrap_y:
                start_y = stride[0] - patch_dim[0]

        # col_list: denotes starting columns of the patches
        if wrap_x:
            col_list = np.arange(start=start_x, stop=padded_img_dim[1], step=stride[1])
        else:
            print("padded img dim: ", padded_img_dim)
            print("patch dim: ", patch_dim)
            print("stride: ", stride)
            col_list = np.arange(start=start_x, stop=padded_img_dim[1] - patch_dim[1] + 1, step=stride[1])

        # row_list: denotes starting rows of the patches
        if wrap_y:
            row_list = np.arange(start=start_y, stop=padded_img_dim[0], step=stride[0])
        else:
            row_list = np.arange(start=start_y, stop=padded_img_dim[0] - patch_dim[0] + 1, step=stride[0])

        # patchify_image tiles the *padded* image, so its period along a wrapping axis is
        # the padded size. Keep the first start of each canonical position only.
        all_start_tuples = []
        seen_positions = set()
        for row in row_list:
            for col in col_list:
                canon_row = int(row) % padded_img_dim[0] if wrap_y else int(row)
                canon_col = int(col) % padded_img_dim[1] if wrap_x else int(col)
                if (canon_row, canon_col) in seen_positions:
                    continue
                seen_positions.add((canon_row, canon_col))
                all_start_tuples.append((row, col))

        print("padded img dim: ", padded_img_dim)
        print("row list: ", row_list)
        print("col list: ", col_list)
        # warn if the last patch's prediction region does not cover the end of the meaningful image
        if not wrap_y:
            if row_list[-1] + patch_dim[0] - patch_border[0] < img_dim[0] + img_padlen[0]:
                print("WARNING: your current parameters will not allow whole image reconstruction: ROW")
                print("Increase row padding by: ", (img_dim[0] + img_padlen[0]) - (row_list[-1] + patch_dim[0] - patch_border[0]))
        if not wrap_x:
            if col_list[-1] + patch_dim[1] - patch_border[1] < img_dim[1] + img_padlen[1]:
                print("WARNING: your current parameters will not allow whole image reconstruction: COLUMN")
                print("Increase col padding by: ", (img_dim[1] + img_padlen[1]) - (col_list[-1] + patch_dim[1] - patch_border[1]))

        # refine_row_start_col_tuples reads self.all_start_tuples, so set it first
        self.all_start_tuples = all_start_tuples
        return self.refine_row_start_col_tuples()

    def refine_row_start_col_tuples(self):
        """
        Drop start tuples whose patch is an exact duplicate of an earlier patch.

        Methodology:
            - Patchify (as an 'x' image) a random image of standard-normal floats of
              shape self.img_dim.
            - Walk the patches in the order of self.all_start_tuples. A patch whose
              values are element-wise identical to an already-kept patch is redundant,
              so its start tuple is dropped.
            - NOTE: because the test image is random, two different pixel sets are
              assumed never to produce identical patches by chance.

        returns:
            - list of the non-redundant start tuples, in their original order
        """
        test_img = np.random.normal(size=(self.img_dim))
        test_patches = self.patchify_image(img=test_img, img_type='x')

        non_redundant_start_tuples = []
        # Exact-equality keys of the patches kept so far. The previous test
        # (np.sum(a - b) == 0) also matched distinct patches whose differences
        # cancel out (e.g. the same pixels in a permuted order), and was O(n^2).
        seen_patch_keys = set()
        for i, start_tuple in enumerate(self.all_start_tuples):
            current_patch = np.asarray(test_patches[i])
            key = (current_patch.shape, current_patch.dtype.str, current_patch.tobytes())
            if key in seen_patch_keys:
                continue
            seen_patch_keys.add(key)
            non_redundant_start_tuples.append(start_tuple)

        print("Finished refinement")
        print(len(non_redundant_start_tuples))
        return non_redundant_start_tuples

    def pad_image(self, img, img_padlen, img_padtype): 
        """
        Pad the image with a border of size img_padlen using the img_padtype method.

        Methodology:
            - img_padtype: 'black'
                the border is all 0 pixels
            - img_padtype: 'extend':
                replicate the outermost row/col of the image out by img_padlen rows/cols
                (the corners take the value of the matching corner pixel)
            - img_padtype: 'mirror'
                mirror the image across its edge, repeating the edge pixel
                (e.g. [1 2 3 4] padded by 2 -> [2 1 | 1 2 3 4 | 4 3])
            - img_padtype: 'wrap'
                wrap the image around periodically
                (e.g. [1 2 3 4] padded by 2 -> [3 4 | 1 2 3 4 | 1 2])

        args:
            - img: numpy array whose first two axes are (rows, cols); any further axes
              (e.g. channels) are left unpadded
            - img_padlen: tuple (rows, cols): number of pixels added on EACH side of the
              rows and cols
            - img_padtype: string: one of the above strings specifying the method of padding

        returns:
            - padded image of shape (rows + 2*img_padlen[0], cols + 2*img_padlen[1], ...).
              When img_padlen is (0, 0) the input array itself is returned. 'black' and
              'extend' return float64 arrays; 'mirror' and 'wrap' keep the dtype of img.

        raises:
            - ValueError: if img_padtype is not one of the strings above
        """
        # map this module's padding names onto numpy.pad modes
        np_pad_modes = {'black': 'constant', 'extend': 'edge', 'mirror': 'symmetric', 'wrap': 'wrap'}
        if img_padtype not in np_pad_modes:
            raise ValueError("Unknown img_padtype: {}".format(img_padtype))

        row_pad, col_pad = img_padlen[0], img_padlen[1]
        if row_pad == 0 and col_pad == 0:
            return img

        img = np.asarray(img)
        # pad only the (rows, cols) axes; trailing axes get no padding
        pad_width = ((row_pad, row_pad), (col_pad, col_pad)) + ((0, 0),) * (img.ndim - 2)
        padded_img = np.pad(img, pad_width, mode=np_pad_modes[img_padtype])
        if img_padtype in ('black', 'extend'):
            # these modes were historically built on np.zeros, i.e. float64 output
            padded_img = padded_img.astype(np.float64, copy=False)
        return padded_img


    def get_ypatch_dim(self): 
        """
        Return the (rows, cols) shape of a y patch: self.patch_dim with
        self.patch_border removed from both ends of each axis.
        """
        rows, cols = self.patch_dim[0], self.patch_dim[1]
        border_rows, border_cols = self.patch_border[0], self.patch_border[1]
        return (rows - 2 * border_rows, cols - 2 * border_cols)


##################################################################################################################################################
##################################################################################################################################################
##################################################################################################################################################
############################################################ OLD CODE ############################################################################
##################################################################################################################################################
##################################################################################################################################################
##################################################################################################################################################

class patcher_OLD:

    def __init__(self, img_dim, patch_dim, patch_border, img_padlen, img_padtype, wrap_x, wrap_y, stride, patch_weight):
        """
        Initialize class to patchify and unpatchify images
        args:  
            - img_dim
            - patch_dim
            - patch_border
            - img_padlen
            - img_padtype
            - wrap_x
            - wrap_y
            - stride 
            - patch_weight: the patch weight to use when recombining
        NOTE: for argument descriptions look at the reinit function
        """
        #print("I AM A  TACO")
        self.reinit(img_dim, patch_dim, patch_border, img_padlen, img_padtype, wrap_x, wrap_y, stride, patch_weight)
        return 


    def reinit(self, img_dim, patch_dim, patch_border, img_padlen, img_padtype, wrap_x, wrap_y, stride, patch_weight):
        """
        Re-Initialize class to patchify and unpatchify images. 
        Sets the class attributes. 

        args: 
            - img_dim: tuple (rows, cols): dimension of the image
            - patch_dim: tuple (rows, cols): dimension of the patch
            - patch_border: tuple (rows, cols): number of rows, cols that are used as the border of the patch 
            - img_padlen: tuple (rows, cols): number of rows, cols used to pad the image
            - img_padtype: string: to indicate the method used for padding the image. For viable options look at self.pad_image 
            - wrap_x: whether you want the patches to wrap in the x or column direction
            - wrap_y: whether you want the patches to wrap in the y or row direction
            - stride: tuple (row stride, col stride): the number of pixels to skip by when creating image patches
            - patch_weight: the patch weight to use when recombining
        returns: 
            - None: NOTE: will print a warning if the attributes will not give proper coverage of the image for full reconstruction
        """
        #print("I AM A TACO 2")
        self.img_dim = img_dim
        self.patch_dim = patch_dim
        self.patch_border = patch_border
        self.img_padlen = img_padlen
        self.img_padtype = img_padtype
        self.wrap_x = wrap_x
        self.wrap_y = wrap_y
        self.stride = stride
        self.patch_weight = patch_weight

        # assert things are of the correct type
        assert(type(img_dim) == tuple)
        assert(type(patch_dim) == tuple)
        assert(type(patch_border) == tuple)
        assert(type(img_padlen) == tuple)
        assert(type(stride) == tuple)
        
        print("Stride: ", stride)
        
        # TESTS: make sure patchification is reconstructable
        # For the unpatched image to make sense/be the same you need the following
        #     1. img_padlen >= patch_border
        #     2. stride <= patch_border
        assert(self.patch_border[0] < self.patch_dim[0]/2)
        assert(self.patch_border[1] < self.patch_dim[1]/2)

        assert(self.img_padlen[0] >= self.patch_border[0])
        assert(self.img_padlen[1] >= self.patch_border[1])
        
        assert(self.stride[0] <= self.patch_dim[0] - 2*self.patch_border[0])
        assert(self.stride[1] <= self.patch_dim[1] - 2*self.patch_border[1])
        
        assert(self.patch_dim[0] >= self.stride[0])
        assert(self.patch_dim[1] >= self.stride[1])

        
        # set all start tuples
        self.all_start_tuples = self.get_row_col_start_tuples(img_dim=self.img_dim, 
            img_padlen=self.img_padlen, 
            patch_dim=self.patch_dim, 
            patch_border=self.patch_border, 
            stride=self.stride, 
            wrap_x=self.wrap_x, 
            wrap_y=self.wrap_y)
        
        # padding should allow for complete reconstruction - make sure the last patches prediction region ends at least as far as the border starts 
        #(no part of the non border image is not in the potential prediction region of the patches)
        padded_img_dim = (self.img_dim[0] + self.img_padlen[0], self.img_dim[1] + self.img_padlen[1])
        if not self.wrap_x: 
            assert(np.arange(0, padded_img_dim[1], self.stride[1])[-1] + self.patch_dim[1] - self.patch_border[1] >= self.img_dim[1] + self.img_padlen[1])
            #assert(np.arange(0, padded_img_dim[1], stride)[-1] + patch_dim - patch_border >= img_dim[1] + img_padlen)

        if not wrap_y: 
            assert(np.arange(0, padded_img_dim[0], self.stride[0])[-1] + self.patch_dim[0] - self.patch_border[0] >= self.img_dim[0] + self.img_padlen[0])
            #assert(np.arange(0, padded_img_dim[0], stride)[-1] + patch_dim - patch_border >= img_dim[0] + img_padlen)
                
        return 

    def patchify_image(self, img, img_type): 
        """
        Patchify image
        args: 
            - img
            - img_type: string: the type of image 'x' or 'y'

        args needed from attributes: 
            - patch_dim
            - stride
            - wrap_x
            - wrap_y
            - all_start_tuples
            - patch_border
            - img_padtype
            - img_padlen

        returns: 
            - patch_list: list of patches corresponding to the image

        Example of patch border: 
        Eg: of input and output dataset patches when patch_dim = 4, patch_border = 1
            patch 1       output patch  
            0 0 0 0 
            0 0 0 0  -->     0 0 
            0 0 0 0          0 0  
            0 0 0 0 

        """
        if self.img_padlen[0] > 0 or self.img_padlen[1] > 0:
            padded_img = self.pad_image(img, self.img_padlen, self.img_padtype)
        else:
            padded_img = deepcopy(img)
        padded_img_dim = padded_img.shape
        
        # if wrapping further pad the image
        tile_repeats = [1,1]
        row_offset = 0
        col_offset = 0
        if self.wrap_x: 
            tile_repeats[1] = 3
            col_offset = padded_img.shape[1]
        if self.wrap_y:
            tile_repeats[0] = 3
            row_offset = padded_img.shape[0]
        padded_img = np.tile(padded_img, tile_repeats)
            
        all_patches = []
        for (row_start, col_start) in self.all_start_tuples:
            row_start = row_start + row_offset
            col_start = col_start + col_offset
            current_patch = padded_img[row_start:row_start + self.patch_dim[0], col_start:col_start + self.patch_dim[1]]
            if img_type == 'y':
                """
                # truncate patch borders
                fig = plt.figure()
                plt.title((row_start - row_offset,col_start - col_offset))
                pos = plt.imshow(current_patch, vmin=0, vmax=1)
                fig.colorbar(pos)
                plt.show()
                """
                if self.patch_border[0] == 0 and self.patch_border[1] == 0:
                    current_patch = current_patch
                elif self.patch_border[0] == 0: 
                    current_patch = current_patch[:, self.patch_border[1]:-self.patch_border[1]]
                elif self.patch_border[1] == 0:
                    current_patch = current_patch[self.patch_border[0]:-self.patch_border, :]
                else:
                    current_patch = current_patch[self.patch_border[0]:-self.patch_border[0], self.patch_border[1]:-self.patch_border[1]]
            
            all_patches.append(current_patch)
            
        return all_patches

    def unpatchify_image(self, patch_list, patch_variance_list, img_type): 
        """
        Unpatchify image from patch list 
        args: 
            - patch_list: list of patches that form an image
            - patch_variance_list: list of floats of length patsch_list corresponding to the variance of the predicted patch
                in patch_list - if specified as None: then do not use
                NOTE: there is one variance value per patch 
            - img_type: string: whether the patches correspond to a 'y' or 'x' image
        returns: 
            - reconstructed img: numpy array 
            - reconstructed padded image
            - reconstructed variance image 
        """
        padded_img_dim = list(self.img_dim)
        padded_img_dim[0] += 2*self.img_padlen[0]
        padded_img_dim[1] += 2*self.img_padlen[1]
        padded_img_dim = tuple(padded_img_dim)

        patch_counter = 0
        patch_img_list = []
        mask_img_list = []

        if type(self.patch_weight) == type(None): 
            # create corresponding boolean mask 
            if img_type == 'y': 
                self.patch_weight = np.zeros(self.patch_dim)
                if self.patch_border[0] == 0 and self.patch_border[1] == 0:
                    self.patch_weight[:, :] = 1
                elif self.patch_border[0] == 0: 
                    self.patch_weight[:, self.patch_border[1]:-self.patch_border[1]] = 1
                elif self.patch_border[1] == 0:
                    self.patch_weight[self.patch_border[0]:-self.patch_border[0], :] = 1
                else:
                    self.patch_weight[self.patch_border[0]:-self.patch_border[0], self.patch_border[1]:-self.patch_border[1]] = 1
            else: 
                self.patch_weight = np.ones(self.patch_dim)

        tile_repeats = [1,1]
        row_offset = 0
        col_offset = 0
        if self.wrap_x: 
            tile_repeats[1] = 3
            col_offset = padded_img_dim[1]
        if self.wrap_y:
            tile_repeats[0] = 3
            row_offset = padded_img_dim[0]

        padded_img = np.zeros(padded_img_dim)
        # update the padded image with tiling to account for wrapping
        super_padded_img = np.tile(padded_img, tile_repeats)
        super_padded_img_dim = tuple(super_padded_img.shape)

        patch_img_list = []
        mask_img_list = []
        for (row_start, col_start) in self.all_start_tuples: 
            row_start = row_start + row_offset
            col_start = col_start + col_offset
            current_patch = patch_list[patch_counter]

            if img_type == 'y': 
                smaller_patch = deepcopy(current_patch)
                current_patch = np.zeros(self.patch_dim)

                if self.patch_border[0] == 0 and self.patch_border[1] == 0:
                    current_patch[:, :] = smaller_patch
                elif self.patch_border[0] == 0:
                    current_patch[:, self.patch_border[1]:-self.patch_border[1]] = smaller_patch
                elif self.patch_border[1] == 0:
                    current_patch[self.patch_border[0]:-self.patch_border[0], :] = smaller_patch
                else:
                    current_patch[self.patch_border[0]:-self.patch_border[0], self.patch_border[1]:-self.patch_border[1]] = smaller_patch

            current_patch_img  = np.zeros(super_padded_img_dim)
            current_patch_mask = np.zeros(super_padded_img_dim)

            #print(tile_repeats)
            #print(super_padded_img_dim)

            current_patch_img[row_start:row_start + self.patch_dim[0], col_start:col_start + self.patch_dim[1]] = current_patch
            # only encourage averaging where the actual predicted pixels are 
            if img_type == 'y':
                if self.patch_border[0] == 0 and self.patch_border[1] == 0:
                    current_patch_mask[row_start + self.patch_border[0]:row_start + self.patch_dim[0] - self.patch_border[0], 
                        col_start + self.patch_border[1]:col_start + self.patch_dim[1] - self.patch_border[1]] = self.patch_weight[:, :]
                elif self.patch_border[0] == 0:
                    current_patch_mask[row_start + self.patch_border[0]:row_start + self.patch_dim[0] - self.patch_border[0], 
                        col_start + self.patch_border[1]:col_start + self.patch_dim[1] - self.patch_border[1]] = self.patch_weight[:, self.patch_border[1]:-self.patch_border[1]]
                elif self.patch_border[1] == 0:
                    current_patch_mask[row_start + self.patch_border[0]:row_start + self.patch_dim[0] - self.patch_border[0], 
                        col_start + self.patch_border[1]:col_start + self.patch_dim[1] - self.patch_border[1]] = self.patch_weight[self.patch_border[0]:-self.patch_border, :]
                else: 
                    current_patch_mask[row_start + self.patch_border[0]:row_start + self.patch_dim[0] - self.patch_border[0], 
                        col_start + self.patch_border[1]:col_start + self.patch_dim[1] - self.patch_border[1]] = self.patch_weight[self.patch_border[0]:-self.patch_border[0], self.patch_border[1]:-self.patch_border[1]]

            else:
                current_patch_mask[row_start:row_start + self.patch_dim[0], col_start:col_start + self.patch_dim[1]] = np.ones(self.patch_dim)

            if type(patch_variance_list) != type(None): 
                # weight using variance
                current_patch_mask = current_patch_mask * patch_variance_list[patch_counter]

            # append to lists 
            patch_img_list.append(current_patch_img)
            mask_img_list.append(current_patch_mask)

            # increase counter
            patch_counter += 1

        # If there is tiling add the tiles together
        reconstructed_padded_img = np.zeros(padded_img_dim)
        reconstructed_padded_normalization = np.zeros(padded_img_dim)
        reconstructed_padded_var_img = np.zeros(padded_img_dim)
        for found_patch_num in range(len(patch_img_list)):
            current_patch_img = patch_img_list[found_patch_num]
            current_patch_mask = mask_img_list[found_patch_num]

            # add all the tiles together
            for row_tile_num in range(tile_repeats[0]):
                for col_tile_num in range(tile_repeats[1]):
                    curr_patch_img_tile = current_patch_img[row_tile_num*padded_img_dim[0]:(row_tile_num + 1)*padded_img_dim[0], col_tile_num*padded_img_dim[1]: (col_tile_num + 1)*padded_img_dim[1]]
                    curr_patch_mask_tile = current_patch_mask[row_tile_num*padded_img_dim[0]:(row_tile_num + 1)*padded_img_dim[0], col_tile_num*padded_img_dim[1]: (col_tile_num + 1)*padded_img_dim[1]]

                    reconstructed_padded_img += curr_patch_img_tile * curr_patch_mask_tile
                    reconstructed_padded_normalization += curr_patch_mask_tile

        # normalize by the total variance weight
        reconstructed_padded_normalization_divide = deepcopy(reconstructed_padded_normalization)
        reconstructed_padded_normalization_divide[np.where(reconstructed_padded_normalization != 0)] = 1/reconstructed_padded_normalization[np.where(reconstructed_padded_normalization != 0)] 
        reconstructed_padded_img = reconstructed_padded_img * reconstructed_padded_normalization_divide
        if self.img_padlen[0] == 0 and self.img_padlen[1] == 0:
            reconstructed_img = reconstructed_padded_img[:, :]
            reconstructed_variance_img = reconstructed_padded_normalization[:]
        elif self.img_padlen[0] == 0:
            reconstructed_img = reconstructed_padded_img[:, self.img_padlen[1]:-self.img_padlen[1]]
            reconstructed_variance_img = reconstructed_padded_normalization[:, self.img_padlen[1]:-self.img_padlen[1]]
        elif self.img_padlen[1] == 0:
            reconstructed_img = reconstructed_padded_img[self.img_padlen[0]:-self.img_padlen[0], :]
            reconstructed_variance_img = reconstructed_padded_normalization[self.img_padlen[0]:-self.img_padlen[0], :]
        else:
            reconstructed_img = reconstructed_padded_img[self.img_padlen[0]:-(self.img_padlen[0]), self.img_padlen[1]:-(self.img_padlen[1])]
            reconstructed_variance_img = reconstructed_padded_normalization[self.img_padlen[0]:-(self.img_padlen[0]), self.img_padlen[1]:-(self.img_padlen[1])] # NOTE: only valid if the patch variance list is actually specified - UNSURE IF THIS IS VALID

        return reconstructed_img, reconstructed_padded_img, reconstructed_variance_img

    def patchify_dataset(self, dataset, dataset_type): 
        """
        Patchify image on whole dataset
        args: 
            - dataset: list of images or tuples of images
            - dataset_type: string: 'x' or 'y' dataset
        returns: 
            - patch_list: list of patches - concatenated patch list of each of the datapoints in the dataset
        """
        dataset_patches = []
        for datapoint in dataset:
            if type(datapoint) == tuple: 
                # datapoint is a tuple of images
                data_point_patches = []
                for img in datapoint: 
                    data_point_patches.append(self.patchify_image(img, dataset_type))

                patch_tuple_list = []
                for i in range(len(data_point_patches[0])):
                    curr_patch_tuple = tuple([data_point_patches[tup_index][i] for tup_index in range(len(data_point_patches))])
                    patch_tuple_list.append(curr_patch_tuple)

                # extend dataset patches with found patch tuple list 
                dataset_patches.extend(patch_tuple_list)

            else: 
                # datapoint is a single image 
                dataset_patches.extend(self.patchify_image(datapoint, dataset_type))

        return dataset_patches

    def unpatchify_dataset(self, dataset_patch_list, dataset_patch_variance_list, dataset_type): 
        """
        Unpatchify list of patches for whole dataset to list of images
        args: 
            - dataset_patch_list: list of patches: concatenated patch list of each of the datapoints in the dataset
            - dataset_patch_variance_list: list of floats of length dataset_patch_list: there is one variance value per patch 
                in the dataset_patch_list
            - dataset_type: string: 'x' or 'y' dataset
        returns: 
            - img_dataset: list of images reconstructed from the dataset_patch_list
            - var_dataset: list of images that correspond to the variance image reconstructed from the variance values of the patches and their masks
        # NOTE: no support for unpatchifying image tuples only datasets where each datapoint is a single image
        """
        num_patches_per_image = len(self.all_start_tuples) # the number of patches per image

        num_images = len(dataset_patch_list)/num_patches_per_image
        assert(num_images%1==0) # assert enough patches for num_images full images

        img_dataset = []
        var_dataset = []
        for img_num in range(int(num_images)): 
            if type(dataset_patch_variance_list) != type(None): 
                curr_img, curr_img_padded, curr_var_img = self.unpatchify_image(dataset_patch_list[img_num*num_patches_per_image:(img_num + 1)*num_patches_per_image], 
                    patch_variance_list=dataset_patch_variance_list[img_num*num_patches_per_image:(img_num + 1)*num_patches_per_image], 
                    img_type=dataset_type)
            else: 
                curr_img, curr_img_padded, curr_var_img = self.unpatchify_image(dataset_patch_list[img_num*num_patches_per_image:(img_num + 1)*num_patches_per_image], 
                    patch_variance_list=None, 
                    img_type=dataset_type)

            img_dataset.append(curr_img)
            var_dataset.append(curr_var_img)

        return img_dataset, var_dataset 

    ####################################################################################################################
    ############################################## Helper functions ####################################################
    ####################################################################################################################

    def get_row_col_start_tuples(self, img_dim, img_padlen, patch_dim, patch_border, stride, wrap_x, wrap_y, allow_negative_wrap=True): 
        """ 
        Get the starting rows and columns (denoting the top left pixel of the patch)
        from which to sample the patches. 
        
        NOTE: 
        - when wrapping you may get images that are less than 0
        
        args: 
            - img_dim: tuple (rows, cols)
            - img_padlen: tuple (rows, cols): number of pixels to pad the rows and cols
            - patch_dim: tuple (rows, cols)
            - patch_border: tuple (rows, cols): how many rows, cols to use for the patch border
            - stride: tuple (row stride, col stride)
            - wrap_x: Boolean: if the patch should wrap the x axis (columns)
            - wrap_y: Boolean: if the patch should wrap the y axis (rows)
            - allow_negative_wrap: boolean: when True allows negative values while wrapping
                DEFAULT: TRUE
        returns: 
            - list of tuples: (row, col) start points - assuming the appropriate tiling
        """
        padded_img_dim = list(img_dim)
        padded_img_dim[0] += 2*img_padlen[0]
        padded_img_dim[1] += 2*img_padlen[1]
        padded_img_dim = tuple(padded_img_dim)
        
        if allow_negative_wrap : 
            negative_row_start = 0 - patch_dim[0]
            negative_col_start = 0 - patch_dim[1]
            if wrap_x: 
                start_x = negative_col_start + stride[1]
            else: 
                start_x = 0
            if wrap_y:
                start_y = negative_row_start + stride[0]
            else: 
                start_y = 0
        else: 
            start_x = 0
            start_y = 0
            
        # col_list: denotes starting columns of the patches
        if wrap_x: 
            col_list = np.arange(start=start_x, stop=padded_img_dim[1] , step=stride[1])
        else: 
            print("padded img dim: ", padded_img_dim)
            print("patch dim: ", patch_dim)
            print("stride: ", stride)
            col_list = np.arange(start=start_x, stop=padded_img_dim[1] - patch_dim[1] + 1, step=stride[1])

        # row_list: denotes starting rows of the patches
        if wrap_y: 
            row_list = np.arange(start=start_y, stop=padded_img_dim[0] , step=stride[0])
        else:
            row_list = np.arange(start=start_y, stop=padded_img_dim[0] - patch_dim[0] + 1, step=stride[0])

        all_start_tuples = []
        for row in row_list: 
            for col in col_list: 
                orig_row = row
                orig_col = col
                # Note: might change - don't consider patch starts that you have already considered
                if wrap_y and row < 0:
                    row = row + img_dim[0]
                if wrap_x and col < 0: 
                    col = col + img_dim[1]
                if (row, col) in all_start_tuples: 
                    continue
                else:
                    all_start_tuples.append((orig_row, orig_col))
        
        print("padded img dim: ", padded_img_dim)
        print("row list: ", row_list)
        print("col list: ", col_list)
        # print warning if the whole
        # if the last patch prediction region does not cover the end of meaningful image
        if not wrap_y: 
            if row_list[-1] + patch_dim[0] - patch_border[0] < img_dim[0] + img_padlen[0]: 
                print("WARNING: your current parameters will not allow whole image reconstruction: ROW")
                print("Increase row padding by: ", (img_dim[0] + img_padlen[0]) - (row_list[-1] + patch_dim[0] - patch_border[0]))
        if not wrap_x: 
            if col_list[-1] + patch_dim[1] - patch_border[1] < img_dim[1] + img_padlen[1]: 
                print("WARNING: your current parameters will not allow whole image reconstruction: COLUMN")
                print("Increase col padding by: ", (img_dim[1] + img_padlen[1]) - (col_list[-1] + patch_dim[1] - patch_border[1]))
                
        return all_start_tuples

    def pad_image(self, img, img_padlen, img_padtype):
        """
        Pad the image with a border of size img_padlen using the method img_padtype.

        Methodology:
            - img_padtype: 'black'
                the border is all 0 pixels
            - img_padtype: 'extend'
                the outermost row/col is repeated out by img_padlen rows/cols;
                the corner blocks take the value of the nearest corner pixel
            - img_padtype: 'mirror'
                the image is mirrored about its edge, edge pixel included
                (e.g. a row [a b c d] padded by 2 becomes [b a | a b c d | d c])
            - img_padtype: 'wrap'
                the image wraps around: the top border holds the last rows, the
                bottom border the first rows, the left border the last cols and
                the right border the first cols

        Only the first two axes (rows, cols) are padded; any trailing axes
        (e.g. colour channels) are left untouched.

        args:
            - img: array whose first two axes are (rows, cols)
            - img_padlen: tuple (rows, cols): number of pixels added on EACH side of the rows and cols
            - img_padtype: string: one of the above strings specifying the method of padding

        returns:
            - padded image of shape (rows + 2*img_padlen[0], cols + 2*img_padlen[1], ...).
              'black' and 'extend' return a float64 array (these modes were
              historically built on an np.zeros canvas); 'mirror' and 'wrap'
              keep img's dtype. If both pad lengths are 0, img itself is returned.

        raises:
            - ValueError: if img_padtype is not one of the options above
        """
        # Map this class's padding names onto numpy.pad modes.
        pad_modes = {
            'black': 'constant',   # constant_values defaults to 0
            'extend': 'edge',
            'mirror': 'symmetric',  # reflection that repeats the edge pixel
            'wrap': 'wrap',
        }
        if img_padtype not in pad_modes:
            raise ValueError(
                "unknown img_padtype %r; expected one of %s"
                % (img_padtype, sorted(pad_modes))
            )

        pad_rows, pad_cols = int(img_padlen[0]), int(img_padlen[1])
        if pad_rows == 0 and pad_cols == 0:
            # Nothing to pad: hand back the input untouched (previous behaviour).
            return img

        img = np.asarray(img)
        if img_padtype in ('black', 'extend'):
            # Preserve the float64 output these modes have always produced.
            img = img.astype(np.float64, copy=False)

        # Pad rows and cols symmetrically; leave any trailing axes alone.
        pad_width = [(pad_rows, pad_rows), (pad_cols, pad_cols)] + [(0, 0)] * (img.ndim - 2)
        return np.pad(img, pad_width, mode=pad_modes[img_padtype])


    def get_ypatch_dim(self): 
        """ 
        Get the shape of the ypatches (patch dim with the patch borders removed)
        """
        ypatch_dim = (self.patch_dim[0] - 2*self.patch_border[0], self.patch_dim[1] - 2*self.patch_border[1])
        return ypatch_dim

if __name__ == '__main__':
    # Smoke check: construct a patcher for a 10x10 image cut into
    # non-overlapping 5x5 patches, with no patch border and no image padding.
    demo_patch_dim = (5, 5)
    thing = patcher(
        img_dim=(10, 10),
        patch_dim=demo_patch_dim,
        patch_border=(0, 0),
        img_padlen=(0, 0),
        img_padtype='black',
        wrap_x=False,
        wrap_y=True,
        stride=(5, 5),
        patch_weight=np.ones(demo_patch_dim),
    )
