import os
from time import time
import numpy as np
import torch
import torchvision.transforms.functional as F
import matplotlib.pyplot as plt
import argparse
from PIL import Image

class MagnitudeLayer(torch.nn.Module):
    '''
    The simple 4 step magnitude layer.

    For every batch element and channel of an input of shape (B, C, H, W) the
    layer
      1. builds a pairwise dissimilarity matrix over all H*W pixels,
             D = l_grid * ||pos_a - pos_b||_p + l_pixel * d(val_a, val_b) ** power
         where d is the absolute difference of the pixel values (or, with
         ``hamming=True``, 1 if the values differ and 0 otherwise),
      2. forms the similarity matrix Z = exp(-D),
      3. solves Z w = 1 (equivalently w = Z^{-1} 1, the row sums of Z^{-1}),
      4. reshapes the weighting w back to the input shape.

    Args:
        p: order of the Minkowski norm used for pixel-position distances
            (forwarded to ``torch.cdist``).
        power: exponent applied to the pixel-value dissimilarity.
        l_grid: weight of the positional term.
        l_pixel: weight of the pixel-value term.
        hamming: use a 0/1 "values differ" indicator instead of the absolute
            difference between pixel values.
    '''
    def __init__(self, p=1, power=1, l_grid=1., l_pixel=1., hamming=False):
        super().__init__()
        # Positional distance matrix of the most recent input, shape (1, H*W, H*W).
        self.grid = None
        self.p = p
        self.power = power
        self.l_grid = l_grid
        self.l_pixel = l_pixel
        self.hamming = hamming

    def forward(self, x):
        '''
        Compute the magnitude weighting of ``x``.

        Args:
            x: tensor of shape (B, C, H, W); the last two dimensions form the
               pixel grid.

        Returns:
            Tensor with the same shape as ``x``.
        '''
        # Build the grid in the input's floating dtype/device so float64 and
        # non-CPU inputs work; integer inputs fall back to the default dtype.
        dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
        self.grid = self._generate_grid(x.shape[2], x.shape[3], dtype=dtype, device=x.device)
        return self._magnitude_vec(x)

    def _generate_grid(self, x_pixel, y_pixel, dtype=None, device=None):
        '''
        Return the (1, N, N) matrix of p-norm distances between the
        N = x_pixel * y_pixel integer pixel coordinates.

        Coordinates are ordered row-major (first index outer), matching the
        flattening of the last two dimensions of the input tensor.
        '''
        xx = torch.linspace(0, x_pixel - 1, x_pixel, dtype=dtype, device=device)
        yy = torch.linspace(0, y_pixel - 1, y_pixel, dtype=dtype, device=device)
        grid = torch.meshgrid(xx, yy, indexing='ij')
        coords = torch.stack(grid).view(2, -1).permute(1, 0)
        return torch.cdist(coords, coords, p=self.p).unsqueeze(0)

    def _magnitude_vec(self, x):
        '''Return w = Z^{-1} 1 for every (batch, channel) slice of ``x``, reshaped to ``x.shape``.'''
        flat = x.reshape(x.shape[0], x.shape[1], -1)
        # (B, C, N, N) matrix of absolute pixel-value differences.
        pixel_diff = torch.abs(flat.unsqueeze(-2) - flat.unsqueeze(-1))
        if self.hamming:
            # Hamming-style metric: 0 where two pixel values agree, 1 where they differ.
            pixel_diff = (pixel_diff != 0).to(pixel_diff.dtype)
        dissimilarity = self.l_grid * self.grid + self.l_pixel * torch.pow(pixel_diff, self.power)
        similarity = torch.exp(-dissimilarity)
        # Solving Z w = 1 gives the row sums of Z^{-1} without forming the inverse.
        ones = similarity.new_ones(similarity.shape[:-1] + (1,))
        weights = torch.linalg.solve(similarity, ones).squeeze(-1)
        return weights.reshape(x.shape)

def make_patches(img,x_y_patches,overlap=1):
    '''
    Create overlapping patches of the image with overlap of overlap many pixels.

    The image is divided into ``x_patches`` columns and ``y_patches`` rows of
    equally sized core patches. Each patch is then extended by ``overlap``
    pixels on every side, clipped at the image border.

    Args:
        img: array of shape (H, W) or (H, W, C).
        x_y_patches: ``(x_patches, y_patches)``, the number of patches along
            the width and along the height respectively.
        overlap: number of extra pixels taken on each side of a patch.

    Returns:
        ``(patches, patch_grid)``. ``patches[k]`` is a slice of ``img``, and
        ``patch_grid[k] == (j, i)`` gives its row ``j`` and column ``i`` in the
        patch grid. Patches are ordered column by column (``i`` outer,
        ``j`` inner).

    Raises:
        ValueError: if a patch count is not positive, ``overlap`` is negative,
            or the height (width) is not divisible by ``y_patches``
            (``x_patches``).
    '''
    height, width = img.shape[:2]
    x_patches, y_patches = x_y_patches

    if x_patches < 1 or y_patches < 1:
        raise ValueError('The number of patches must be positive!')
    if overlap < 0:
        raise ValueError('The overlap must be non-negative!')
    # Rows of patches split the height, columns of patches split the width.
    if height % y_patches:
        raise ValueError('The image height must be divisible by the number of patches!')
    if width % x_patches:
        raise ValueError('The image width must be divisible by the number of patches!')

    x_patch_size = width // x_patches
    y_patch_size = height // y_patches

    patches = []
    patch_grid = []
    for i in range(x_patches):
        x_start = max(i * x_patch_size - overlap, 0)
        x_stop = min((i + 1) * x_patch_size + overlap, width)
        for j in range(y_patches):
            y_start = max(j * y_patch_size - overlap, 0)
            y_stop = min((j + 1) * y_patch_size + overlap, height)
            patch_grid.append((j, i))
            patches.append(img[y_start:y_stop, x_start:x_stop])
    return patches, patch_grid

def _core_size(first_len, n_patches, overlap):
    '''Recover the un-padded patch size along one axis from the length of the first (top/left) patch.'''
    # make_patches gives the first patch a length of min(size + overlap, n_patches * size).
    if n_patches == 1:
        return first_len
    size = first_len - overlap
    if size >= 1 and overlap <= (n_patches - 1) * size:
        return size
    # The first patch reached the far image border, i.e. it spans the whole axis.
    return first_len // n_patches


def stitch_patches(patches,patch_grid,x_y_patches,channels=1,overlap=1):
    '''
    Stitch the patches back together.

    Inverse of ``make_patches``: the overlap border is cut from every patch
    and the remaining core regions are written into a float32 image.

    Args:
        patches: arrays laid out as produced by ``make_patches`` (2-D, or 3-D
            with trailing channel dimension).
        patch_grid: ``(j, i)`` row/column grid position of each patch.
        x_y_patches: ``(x_patches, y_patches)`` patch counts along width/height.
        channels: kept for backward compatibility; the channel layout is
            inferred from the patches themselves.
        overlap: the overlap that was used when creating the patches.

    Returns:
        float32 array of shape (y_patches*h, x_patches*w[, C]), where h and w
        are the core patch sizes. The input ``patches`` list is not modified.
    '''
    x_patches, y_patches = x_y_patches
    first = patches[patch_grid.index((0, 0))]
    patch_y_shape = _core_size(first.shape[0], y_patches, overlap)
    patch_x_shape = _core_size(first.shape[1], x_patches, overlap)

    img = np.zeros((y_patches * patch_y_shape, x_patches * patch_x_shape) + tuple(first.shape[2:]),
                   dtype=np.float32)
    for patch, (j, i) in zip(patches, patch_grid):
        # make_patches starts patch (j, i) min(k*size, overlap) pixels before
        # its core region (the start is clipped at 0 near the image border).
        off_y = min(j * patch_y_shape, overlap)
        off_x = min(i * patch_x_shape, overlap)
        img[j * patch_y_shape:(j + 1) * patch_y_shape, i * patch_x_shape:(i + 1) * patch_x_shape] = \
            patch[off_y:off_y + patch_y_shape, off_x:off_x + patch_x_shape]
    return img

def min_max(img):
    '''
    Linearly rescale ``img`` so its minimum maps to 0 and its maximum to 1.

    A constant input (zero range) maps to all zeros instead of NaN.
    '''
    img = np.asarray(img)
    shifted = img - np.min(img)
    value_range = np.max(img) - np.min(img)
    if value_range == 0:
        # Avoid 0/0 -> NaN; keep the floating dtype the division would produce.
        return np.zeros_like(shifted, dtype=np.result_type(shifted, 1.0))
    return shifted / value_range

def compute_error(mag_img,mag_img_approx):
    '''
    Compare an approximate magnitude image with the reference one.

    Returns:
        ``(max_abs_error, relative_error)`` where ``max_abs_error`` is the
        largest element-wise absolute difference and ``relative_error`` is
        sum((mag_img - mag_img_approx)**2) / sum(mag_img**2), i.e. a ratio of
        *squared* Frobenius norms (no square root is taken).
    '''
    residual = mag_img - mag_img_approx
    max_abs_error = np.max(np.abs(residual))
    squared_norm_ratio = np.sum(np.power(residual, 2)) / np.sum(np.power(mag_img, 2))
    return max_abs_error, squared_norm_ratio

def parallel_patches(mag_layer,patch):
    '''
    Run ``mag_layer.forward`` on a single numpy patch.

    The patch is wrapped as a contiguous tensor of shape (1, 1, *patch.shape).
    The result comes back as a numpy array with all singleton dimensions
    squeezed out. Despite the name, nothing here runs in parallel; it is
    presumably meant as a worker for a pool map (TODO confirm).
    '''
    batched = torch.from_numpy(patch)[None, None].contiguous()
    result = mag_layer.forward(batched)
    return result.squeeze().numpy()

def _save_figure(fig_name, data, path):
    '''Plot ``data`` as an image with a colour bar in figure ``fig_name`` and save it to ``path``.'''
    plt.figure(fig_name)
    plt.imshow(data)
    plt.colorbar()
    plt.tight_layout()
    plt.savefig(path)
    plt.close()


def _load_image(args):
    '''
    Load ``args.image`` from ``args.data_path``, divide pixel values by 255
    and resize by the requested rescale factors.

    Returns a tensor of shape (1, C, H, W). C is 1 for 2-D (single-band)
    images or when ``--greyscale`` is given; otherwise it is at most 3.
    '''
    with Image.open(os.path.join(args.data_path, args.image)) as img:
        # np.array copies, so torch.from_numpy gets a writable buffer.
        img_arr = np.array(img)

    print(f'The image {args.image} has dimensions (h,w,c): {img_arr.shape}')

    if img_arr.ndim == 3:
        img_t = torch.from_numpy(img_arr).permute(2, 0, 1) / 255.
        img_t = img_t[:3]  # keep at most three channels (drops e.g. an RGBA alpha channel)
        if args.greyscale:
            img_t = torch.mean(img_t, axis=0).unsqueeze(0)
    else:
        img_t = (torch.from_numpy(img_arr) / 255.).unsqueeze(0)
    img_t = img_t.unsqueeze(0)

    new_size = (int(img_t.shape[2] * args.y_rescale_factor),
                int(img_t.shape[3] * args.x_rescale_factor))
    return F.resize(img_t, new_size)


def main():
    '''
    Compare the full magnitude of an image with patch-wise approximations.

    The magnitude of the whole (rescaled) image is computed once. Then, for
    2, 4, ..., 256 square patches per side, the image is split into
    overlapping patches and the magnitude of each patch is computed. The
    per-patch results are stitched back together and compared with the full
    result. Figures and a table of
    (number of patches, seconds, *compute_error(...)) rows are written to the
    ``output`` directory. Patch counts that do not divide the image are skipped.
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', action='store',default='.',type=str)
    parser.add_argument('--image',action='store',type=str,required=True)
    parser.add_argument('--greyscale',action='store_true')
    # NOTE: --x_patches/--y_patches are accepted for compatibility, but the
    # sweep below always uses square patch counts from `patch_counts`.
    parser.add_argument('--x_patches', action='store', default=2, type=int)
    parser.add_argument('--y_patches', action='store', default=2, type=int)
    parser.add_argument('--overlap', action='store', default=1, type=int)
    parser.add_argument('--x_rescale_factor', action='store', default=1, type=float)
    parser.add_argument('--y_rescale_factor', action='store', default=1, type=float)
    parser.add_argument('--l_pixel', action='store', default=1., type=float)
    args = parser.parse_args()

    if args.overlap < 0:
        parser.error('--overlap must be non-negative')

    overlap = args.overlap
    # basename + splitext: robust to directories and to dots inside the name.
    img_name = os.path.splitext(os.path.basename(args.image))[0]
    out_dir = 'output'
    os.makedirs(out_dir, exist_ok=True)

    mag_layer = MagnitudeLayer(l_pixel=args.l_pixel)

    # Open and preprocess the image
    img_t_small = _load_image(args)
    print(f'The rescaled image {args.image} has dimensions (b,c,h,w): {img_t_small.shape}')
    if img_t_small.shape[1] != 1:
        # The patch/stitch pipeline below works on a single 2-D channel only.
        parser.error('only single-channel images are supported; pass --greyscale for colour images')

    y_shape = img_t_small.shape[2]
    x_shape = img_t_small.shape[3]

    # Full magnitude calculation
    tic = time()
    mag_img = mag_layer.forward(img_t_small).squeeze().view(y_shape,x_shape).numpy()
    mag_img = min_max(mag_img)
    toc = time()
    print(f'Elapsed time during full magnitude calculation: {(toc-tic):.2f}s')

    img_np = img_t_small[0, 0].numpy()  # (H, W) array to cut into patches
    patch_counts = [2, 4, 8, 16, 32, 64, 128, 256]
    rows = []

    # Patch calculation for different number of patches
    for n_patches in patch_counts:
        x_patches = y_patches = n_patches  # Use square patches for simplicity
        if y_shape % y_patches or x_shape % x_patches:
            print(f'Skipping {n_patches} patches: the rescaled image ({y_shape}x{x_shape}) '
                  f'is not divisible into {n_patches}x{n_patches} patches.')
            continue

        # Patch-wise magnitude calculation
        tic = time()
        patches, patch_grid = make_patches(img_np, (x_patches, y_patches), overlap=overlap)
        # parallel_patches squeezes singleton dims; restore the patch shape for stitching.
        patches_vecs = [parallel_patches(mag_layer, p).reshape(p.shape) for p in patches]
        mag_img_approx = stitch_patches(patches_vecs, patch_grid, (x_patches, y_patches), overlap=overlap)
        mag_img_approx = min_max(mag_img_approx)
        toc = time()
        print(f'Elapsed time during approximate magnitude calculation: {(toc-tic):.2f}s')

        errors = compute_error(mag_img, mag_img_approx)
        print(f'Max error: {errors}')
        rows.append([n_patches, toc - tic, *errors])

        _save_figure(f'Mag patches_{n_patches}', mag_img_approx,
                     os.path.join(out_dir, f'Mag_approx_{img_name}_patches_{n_patches}.pdf'))
        _save_figure(f'Error_{n_patches}', np.abs(mag_img - mag_img_approx),
                     os.path.join(out_dir, f'Error_{img_name}_patches_{n_patches}.pdf'))

    _save_figure('Img', img_t_small.squeeze().numpy(),
                 os.path.join(out_dir, f'Rescaled_{img_name}.pdf'))
    _save_figure('Mag img', mag_img, os.path.join(out_dir, f'Mag_{img_name}.pdf'))

    output = np.array(rows, dtype=float).reshape(-1, 4)
    np.savetxt(os.path.join(out_dir, f'patches_img_{img_name}.txt'), output)

# Run the experiment only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
