
import cv2
import os
from glob import glob
import scipy.io as sio
import numpy as np
import math
import shutil


class EMPatches(object):
    '''
    Extract (possibly overlapping) square/cubic patches from 1D, 2D or 3D arrays and
    merge them back together.

    Patch locations are described by index tuples made of ``(start, end)`` pairs, one
    pair per patched axis, in array-axis order:

    * 1D data:                       ``(xstart, xend)``
    * 2D data, or 3D data with ``vox=False`` (the last axis, e.g. color channels,
      is kept whole):                ``(ystart, yend, xstart, xend)``
    * 3D data with ``vox=True``:     ``(ystart, yend, xstart, xend, zstart, zend)``

    y indexes axis 0 (height), x indexes axis 1 (width; axis 0 for 1D data) and
    z indexes axis 2 (depth).
    '''

    def __init__(self):
        pass

    @staticmethod
    def _axis_offsets(size, window, step):
        '''Return start offsets of windows of length ``window`` that cover ``size`` elements.'''
        last = size - window
        offsets = list(range(0, last + 1, step))
        # Unless the axis length lines up with the step, add one more window that
        # ends exactly at the border so the whole axis is covered.
        if not offsets or offsets[-1] != last:
            offsets.append(last)
        return offsets

    @staticmethod
    def _region(index):
        '''Turn an index tuple ``(a0, a1, b0, b1, ...)`` into ``(slice(a0, a1), slice(b0, b1), ...)``.'''
        return tuple(slice(index[k], index[k + 1]) for k in range(0, len(index), 2))

    def extract_patches(self, data, patchsize, overlap=None, stride=None, vox=False):
        '''
        Parameters
        ----------
        data : array to extract patches from; it can be 1D, 2D or 3D with axes [H, W, D]
               (H: Height, W: Width, D: Depth). 3D data includes images (RGB, RGBA, etc.)
               or voxel data.
        patchsize : side length of the patches; only square (cubic with ``vox=True``)
                    patches are supported. Each side is clipped to the data size.
        overlap (Optional): overlap between neighbouring patches as a fraction of the
                 patch size, a float in [0, 1]. Ignored when ``stride`` is given. The
                 resulting step is never smaller than 1.
        stride (Optional): step size between patches (positive integer). Takes
                precedence over ``overlap``. If neither is given the step is 1.
        vox (Optional): whether data is volumetric; if True a 3D array is also cropped
             along its last axis, otherwise the last axis is kept whole.

        Returns
        -------
        data_patches : a list containing the extracted patches (basic slices of ``data``).
        indices : a list with the index tuple of every patch, in the same order, which
                  can be used later by ``merge_patches``. Patches are enumerated with
                  the width offset in the outermost loop, then height, then depth. The
                  last patch along every axis ends at the border, so the patches cover
                  ``data`` completely.

        Raises
        ------
        ValueError : if ``data`` is not 1D, 2D or 3D, or ``stride`` is smaller than 1.
        '''
        shape = tuple(data.shape)
        ndim = len(shape)
        if ndim not in (1, 2, 3):
            raise ValueError(f"data has to be 1D, 2D or 3D, but got {ndim} dims")

        # One window length per array axis, clipped so it never exceeds the data.
        windows = [min(patchsize, size) for size in shape]

        if stride is not None:
            if stride < 1:
                raise ValueError(f"stride has to be a positive integer, but got {stride}")
            steps = [stride] * ndim
        elif overlap is not None:
            # Overlap is a fraction of the window. Clamp the step to 1 so that
            # overlap == 1 does not produce a zero step (range() would raise).
            steps = [max(1, w - int(math.floor(w * overlap))) for w in windows]
        else:
            steps = [1] * ndim

        offsets = [self._axis_offsets(size, w, step)
                   for size, w, step in zip(shape, windows, steps)]

        if ndim == 1:
            (w_x,) = windows
            (x_offsets,) = offsets
            indices = [(x, x + w_x) for x in x_offsets]
        else:
            w_y, w_x = windows[0], windows[1]
            y_offsets, x_offsets = offsets[0], offsets[1]
            if ndim == 3 and vox:
                w_z, z_offsets = windows[2], offsets[2]
                indices = [(y, y + w_y, x, x + w_x, z, z + w_z)
                           for x in x_offsets for y in y_offsets for z in z_offsets]
            else:
                # 2D data, or 3D images whose last axis (channels) is kept whole.
                indices = [(y, y + w_y, x, x + w_x)
                           for x in x_offsets for y in y_offsets]

        data_patches = [data[self._region(index)] for index in indices]
        return data_patches, indices

    def merge_patches(self, data_patches, indices, mode='overwrite'):
        '''
        Parameters
        ----------
        data_patches : list containing the patches that need to be joined (any numeric dtype).
        indices : a list of index tuples generated by 'extract_patches', one per patch, e.g.
                    (yOffset, yOffset+windowSizeY, xOffset, xOffset+windowSizeX)
        mode : how to deal with overlapping patches;
                overwrite -> next patch will overwrite the overlapping area of the previous patch.
                max -> maximum value of overlapping area at each pixel will be written.
                min -> minimum value of overlapping area at each pixel will be written.
                avg -> mean of all patch values covering each pixel will be written.
        Returns
        -------
        Stitched array of dtype float32. Its extent along each patched axis is the
        largest end index found in ``indices``; any trailing axes (e.g. channels) are
        taken from the first patch. Positions covered by no patch are 0.

        Raises
        ------
        ValueError : on an unknown ``mode``, empty input, a different number of patches
                     and index tuples, or index tuples whose length is not 2, 4 or 6.
        '''
        modes = ["overwrite", "max", "min", "avg"]
        if mode not in modes:
            raise ValueError(f"mode has to be either one of {modes}, but got {mode}")
        if len(indices) == 0 or len(data_patches) != len(indices):
            raise ValueError(
                f"expected one patch per index tuple (at least one), but got "
                f"{len(data_patches)} patches and {len(indices)} index tuples")
        if len(indices[-1]) not in (2, 4, 6):
            raise ValueError(f"index tuples must have 2, 4 or 6 entries, but got {len(indices[-1])}")

        n_axes = len(indices[-1]) // 2
        # Use the maximum end index per axis, so the patch order does not matter.
        spatial_shape = tuple(max(index[2 * k + 1] for index in indices) for k in range(n_axes))
        trailing_shape = tuple(data_patches[0].shape[n_axes:])

        # Work in float64. The +/-inf starting values let 'min'/'max' handle values of
        # any sign. The per-position counts turn the running sum into a true mean
        # for 'avg', and they also mark positions no patch covered.
        fill = {'min': np.inf, 'max': -np.inf}.get(mode, 0.0)
        merged = np.full(spatial_shape + trailing_shape, fill, dtype=np.float64)
        counts = np.zeros(spatial_shape, dtype=np.int64)

        for index, patch in zip(indices, data_patches):
            region = self._region(index)
            patch = np.asarray(patch, dtype=np.float64)
            if mode == 'overwrite':
                merged[region] = patch
            elif mode == 'max':
                merged[region] = np.maximum(merged[region], patch)
            elif mode == 'min':
                merged[region] = np.minimum(merged[region], patch)
            else:  # 'avg': accumulate a sum here, divide by the counts below
                merged[region] += patch
            counts[region] += 1

        if mode == 'avg':
            merged /= np.maximum(counts, 1).reshape(spatial_shape + (1,) * len(trailing_shape))
        merged[counts == 0] = 0
        return merged.astype(np.float32)


class BatchPatching(EMPatches):
    '''
    Apply EMPatches patch extraction and merging to every sample of a batch.
    '''

    def __init__(self, patchsize, overlap=None, stride=None, typ='tf', vox=False):
        '''
        Parameters
        ----------
        patchsize :  size of patch to extract from image only square patches can be
                     extracted for now.
        overlap (Optional): overlap between patches as a fraction of the patch size, a float in [0, 1].
        stride (Optional): Step size between patches; takes precedence over ``overlap``.
        typ (Optional): layout of 4D batches, 'tf' -> BxHxWxC (also use for numpy) or
                        'torch' -> BxCxHxW.
        vox (Optional): forwarded to ``extract_patches``; if True each (channels-last)
                        sample is also cropped along its last axis.
        '''
        super().__init__()
        self.patchsize = patchsize
        self.overlap = overlap
        self.stride = stride
        self.typ = typ
        self.vox = vox

    def patch_batch(self, batch):
        '''
        Parameters
        ----------
        batch : Batch of images of shape either BxCxHxW -> pytorch or BxHxWxC -> tf
                to extract patches. For 1D spectra the batch shape should be BxD
                (used as is for either typ).
        Returns
        -------
        batch_patches : a list containing, per sample, the list of its extracted patches.
        batch_indices : a list containing, per sample, the list of patch indices in order,
                        which can be used later by 'merge_batch'.

        Raises
        ------
        ValueError : if a non-2D batch is not 4D, or ``typ`` is not 'tf' or 'torch'.
        '''
        n_dims = len(batch.shape)
        if n_dims != 2:
            typs = ["tf", "torch"]
            if self.typ not in typs:
                raise ValueError(
                    f"typ has to be either one of {typs}, but got {self.typ}. For numpy also use 'tf' type.")
            if n_dims != 4:
                raise ValueError(
                    f'Input batch should be of shape BxDxHxW or BxHxWxD i.e. 4D for image or volumetric data or BxD i.e. 2D for 1D spectral data, but got {n_dims} dims')

            if self.typ == 'torch':
                # BxCxHxW -> BxHxWxC: extract_patches expects channels last.
                # Assumes a torch.Tensor here (Tensor.permute).
                batch = batch.permute(0, 2, 3, 1)

        b_patches, b_indices = [], []
        for sample in batch:
            patches, indices = super().extract_patches(sample, self.patchsize, self.overlap, self.stride, self.vox)
            b_patches.append(patches)
            b_indices.append(indices)

        return b_patches, b_indices

    def merge_batch(self, b_patches, b_indices, mode='overwrite'):
        '''
        Parameters
        ----------
        b_patches : list containing lists of patches of images to be merged together
                    e.g. list(list1, list2, ...), where, list1->([H W C], [H W C], ...) and so on.
        b_indices : list containing lists of indices of images to be merged in format as return by
                    patch_batch method.
        mode : overlap handling passed to ``merge_patches`` ('overwrite', 'max', 'min' or 'avg').
        Returns
        -------
        merged_batch : a np array of shape BxCxHxW -> pytorch or BxHxWxC -> tf
                       (BxD for 1D spectra).
        '''
        merged = []
        for patches, indices in zip(b_patches, b_indices):
            merged.append(super().merge_patches(patches, indices, mode))
        merged = np.asarray(merged)

        if self.typ == 'torch' and merged.ndim == 4:
            # Exact inverse of patch_batch's permute(0, 2, 3, 1): BxHxWxC -> BxCxHxW.
            merged = merged.transpose(0, 3, 1, 2)

        return merged


def patch_via_indices(data, indices):
    '''
        Extract patches from ``data`` at precomputed index tuples (e.g. the indices
        returned by ``EMPatches.extract_patches``).

        Parameters
        ----------
        data : array to extract patches from; it can be 1D, 2D or 3D with axes [H, W, D].
               H: Height, W: Width, D: Depth. 3D data includes images (RGB, RGBA, etc)
               or Voxel data.
        indices :   list of index tuples made of (start, end) pairs in array-axis order, e.g.;
                    [(xstart, xend), ...]                                   -> 1D data
                    [(ystart, yend, xstart, xend), ...]                     -> 2D data or 3D images
                                                                               (all channels kept)
                    [(ystart, yend, xstart, xend, zstart, zend), ...]       -> 3D voxel data
        Returns
        -------
        data_patches : a list containing the extracted patches, in the order of ``indices``.

        Raises
        ------
        ValueError : if an index tuple is empty or has an odd number of entries.
        '''
    data_patches = []
    for indice in indices:
        if len(indice) == 0 or len(indice) % 2:
            raise ValueError(f"index tuples must consist of (start, end) pairs, but got {indice}")
        # Pair k of the tuple selects along array axis k; later axes are kept whole.
        region = tuple(slice(indice[k], indice[k + 1]) for k in range(0, len(indice), 2))
        data_patches.append(data[region])

    return data_patches

def load_img(filepath, load_img_tag = 'msi'):
    '''
    Read one variable from a MATLAB ``.mat`` file and return it as a float64 array.

    filepath : path of the ``.mat`` file, passed to ``scipy.io.loadmat``.
    load_img_tag : name of the variable to read ('msi' by default; e.g. 'RGB' for the
                   RGB images). A name missing from the file raises ``KeyError``.
    '''
    # NOTE(review): hyperspectral MSI cubes are expected to be (H, W, C), e.g.
    # (512, 512, 31); the shape is not checked here.
    variables = sio.loadmat(filepath)
    image = variables[load_img_tag]
    return image.astype(np.float64)


# Split a CAVE .mat image into patches of a fixed height and width.
def crop_cave(filename, size, tag, save_path):
    '''
    Split the ``tag`` variable of a CAVE ``.mat`` file into ``size`` x ``size`` patches
    and save each patch as its own ``.mat`` file.

    Patches come from ``EMPatches.extract_patches`` with ``overlap=0``. The last
    row/column of patches is moved back to end at the image border, so it can overlap
    its neighbour when ``size`` does not divide the image.

    Each patch is written to ``save_path/<stem>_<size>_<n>.mat`` under the variable
    name ``tag``. ``<stem>`` is the input file name without directory or extension,
    and ``<n>`` is the patch number in extraction order. ``save_path`` is created if
    missing; existing files are kept. ``create_folder`` is not used because it would
    wipe patches already written for earlier images.
    '''
    data = load_img(filename, tag)
    patches, _ = EMPatches().extract_patches(data, size, overlap=0)  # patch indices not needed
    # basename/splitext respect the platform's path separators (split('/') does not).
    stem = os.path.splitext(os.path.basename(filename))[0]
    os.makedirs(save_path, exist_ok=True)
    for patch_id, patch in enumerate(patches):
        sio.savemat(os.path.join(save_path, f"{stem}_{size}_{patch_id}.mat"), {tag: patch})

# Split a Pavia Centre .npy image into patches of a fixed height and width.
def crop_pavia(filename, size, save_path):
    '''
    Split the array stored in a Pavia Centre ``.npy`` file into ``size`` x ``size``
    patches and save each patch as its own ``.npy`` file.

    Patches come from ``EMPatches.extract_patches`` with ``overlap=0``. The last
    row/column of patches is moved back to end at the array border, so it can overlap
    its neighbour when ``size`` does not divide the array.

    Each patch is written to ``save_path/<stem>_<size>_<n>.npy``. ``<stem>`` is the
    input file name without directory or extension, and ``<n>`` is the patch number in
    extraction order. ``save_path`` is created if missing; existing files are kept.
    ``create_folder`` is not used because it would wipe patches already written for
    earlier images.
    '''
    data = np.load(filename)
    patches, _ = EMPatches().extract_patches(data, size, overlap=0)  # patch indices not needed
    # basename/splitext respect the platform's path separators (split('/') does not).
    stem = os.path.splitext(os.path.basename(filename))[0]
    os.makedirs(save_path, exist_ok=True)
    for patch_id, patch in enumerate(patches):
        np.save(os.path.join(save_path, f"{stem}_{size}_{patch_id}.npy"), patch)



def create_folder(path):
    '''
    Make ``path`` an empty directory, first deleting it (with all contents) if it exists.

    Progress is printed to stdout. Failures are reported, not raised. An ``OSError``
    while deleting is printed as ``Error: <path> : <strerror>``, and creation is still
    attempted. An ``OSError`` while creating is printed the same way, and the function
    returns normally.
    '''
    def report(err):
        # Shared message format for the delete and the create step.
        print(f"Error: {path} : {err.strerror}")

    if os.path.exists(path):
        print(f"Folder '{path}' already exists. Deleting...")
        try:
            shutil.rmtree(path)
        except OSError as err:
            report(err)

    print(f"Creating folder '{path}'...")
    try:
        os.makedirs(path)
        print(f"Folder '{path}' created successfully.")
    except OSError as err:
        report(err)


if __name__ == '__main__':
    # Cut every CAVE test .mat file into 128/256/512 patches: files in test/Y are
    # read through their 'RGB' variable, files in test/X through their 'msi' variable.
    cave_root = "/data/dataset/CAVE/CAVEdata"
    out_root = cave_root + "/test_split_for_swinir"
    jobs = (("Y", "RGB"), ("X", "msi"))

    for folder, mat_key in jobs:
        for mat_path in glob(cave_root + "/test/" + folder + "/*.mat"):
            print(mat_path)
            for patch_size in (128, 256, 512):
                crop_cave(mat_path, patch_size, mat_key,
                          save_path=f"{out_root}/{folder}_{patch_size}")

    # Pavia Centre variant (disabled): cut the HSI and MSI .npy training images
    # into 64 and 48 patches.
    # for folder in ("HSI", "MSI"):
    #     for npy_path in glob("/data/dataset/Pavia_Centre/train/" + folder + "/*.npy"):
    #         print(npy_path)
    #         for patch_size in (64, 48):
    #             crop_pavia(npy_path, patch_size,
    #                        save_path=f"/data/dataset/Pavia_Centre/train_split{patch_size}/{folder}")