import os
from time import time
import numpy as np
import torch
import torchvision.transforms.functional as F
import matplotlib.pyplot as plt
import argparse
from PIL import Image
from tqdm import tqdm

from joblib import Parallel, delayed


'''
Calculate the magnitude of an image using the Schur complement method of
Bunch et al., and compare it with a direct full-matrix inversion.
'''


class MagnitudeLayer(torch.nn.Module):
    '''
    Compute the magnitude of an image by inverting its full similarity matrix.

    Every pixel of a (batch, channel) slice is a point. The distance between
    pixels i and j is

        l_grid * d_p(pos_i, pos_j) + l_pixel * d(v_i, v_j) ** power

    where d_p is the Minkowski p-distance between pixel coordinates and
    d(v_i, v_j) is |v_i - v_j|, or, when ``hamming`` is True, 1 if the two
    values differ and 0 if they are equal. The similarity matrix
    Z = exp(-distance) is inverted and its row sums are returned, one weight
    per pixel.

    The pairwise matrix has shape (x.shape[0], x.shape[1], H*W, H*W), so
    memory grows with the square of the number of pixels.
    '''
    def __init__(self, p=1, power=1, l_grid=1., l_pixel=1., hamming=False):
        super().__init__()
        self.grid = None  # pixel-position distances; rebuilt on every forward()
        self.p = p
        self.power = power
        self.l_grid = l_grid
        self.l_pixel = l_pixel
        self.hamming = hamming

    def forward(self, x):
        '''
        Return the per-pixel magnitude weights of ``x``.

        ``x.shape[2]`` and ``x.shape[3]`` are used as the pixel-grid sizes, and
        every ``(x.shape[0], x.shape[1])`` slice is handled independently.
        The result has the same shape as ``x``.
        '''
        # Build the grid on x's device, in the floating dtype that a default
        # (float32) grid combined with x would be promoted to.
        dtype = torch.promote_types(x.dtype, torch.get_default_dtype())
        self.grid = self._generate_grid(x.shape[2], x.shape[3], device=x.device, dtype=dtype)
        return self._magnitude_vec(x)

    def _generate_grid(self, x_pixel, y_pixel, device=None, dtype=None):
        '''
        Pairwise p-distances between the coordinates of an x_pixel-by-y_pixel grid.

        Points are ordered row-major, matching ``x.reshape(..., -1)``. The
        result has shape (1, x_pixel*y_pixel, x_pixel*y_pixel), so it
        broadcasts over the (batch, channel) axes.
        '''
        xx = torch.linspace(0, x_pixel - 1, x_pixel, device=device, dtype=dtype)
        yy = torch.linspace(0, y_pixel - 1, y_pixel, device=device, dtype=dtype)
        # 'ij' is the historical default. Spelling it out silences the
        # deprecation warning and pins the row-major ordering used above.
        grid = torch.meshgrid(xx, yy, indexing='ij')
        grid_t = torch.stack(grid).reshape(2, -1).permute(1, 0)
        return torch.cdist(grid_t, grid_t, p=self.p).unsqueeze(0)

    def _magnitude_vec(self, x):
        '''Return the row sums of exp(-distance)^{-1} for every pixel, reshaped like ``x``.'''
        # Cast before differencing so integer inputs cannot wrap around.
        flat = x.reshape(x.shape[0], x.shape[1], -1).to(self.grid.dtype)
        # |v_i - v_j| for every pixel pair of each (batch, channel) slice
        value_dist = torch.abs(flat.unsqueeze(-2) - flat.unsqueeze(-1))
        if self.hamming:
            # Hamming-style distance: 1 where the values differ, 0 where they are equal.
            value_dist = (value_dist > 0).to(value_dist.dtype)
        distance = self.l_grid * self.grid + self.l_pixel * torch.pow(value_dist, self.power)
        similarity = torch.exp(-distance)
        weights = torch.inverse(similarity)
        return torch.sum(weights, dim=-1).reshape(x.shape)

class MagnitudeLayerSchur(torch.nn.Module):
    '''
    Compute the image magnitude by growing the inverse similarity matrix in blocks.

    Every pixel becomes a point ``(row, column, value)``. The similarity
    between two points is ``exp(-cdist(., ., p=self.p))``.

    The inverse for the first ``seed_size`` points is computed directly. The
    remaining points are then appended ``step_size`` at a time, and the inverse
    is extended with the block-matrix inversion formula built on the Schur
    complement ``S = D - C A^{-1} B``, instead of re-inverting everything:

        [[A, B], [C, D]]^{-1} = [[A^{-1} + A^{-1} B S^{-1} C A^{-1}, -A^{-1} B S^{-1}],
                                 [-S^{-1} C A^{-1},                   S^{-1}]]

    NOTE(review): ``power``, ``l_grid``, ``l_pixel`` and ``hamming`` are stored
    but no distance in this class uses them; only ``p`` affects the result.
    '''
    def __init__(self, seed_size=100, step_size=1, p=1, power=1, l_grid=1., l_pixel=1., hamming=False):
        super().__init__()
        # A non-positive step either raises inside range() (0) or silently skips
        # every update (< 0). A negative seed slices the ground set nonsensically.
        if step_size < 1:
            raise ValueError(f'step_size must be a positive integer, got {step_size}')
        if seed_size < 0:
            raise ValueError(f'seed_size must be non-negative, got {seed_size}')
        self.seed_size = seed_size
        self.step_size = step_size
        self.p = p
        self.power = power
        self.l_grid = l_grid
        self.l_pixel = l_pixel
        self.hamming = hamming

    def forward(self, x):
        '''
        Return the row sums of the full inverse similarity matrix.

        ``x.shape[2]`` and ``x.shape[3]`` are used as the pixel-grid sizes, and
        every ``(x.shape[0], x.shape[1])`` slice is processed independently.
        The result has shape ``(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])``.
        '''
        ground_set = self._generate_ground_set(x)
        n_points = ground_set.shape[2]
        # self.w_x holds the seed-only weights; the updates below do not refresh it.
        self.A_inv, self.w_x = self._inverse_seed(ground_set[:, :, :self.seed_size, :])
        for start in tqdm(range(self.seed_size, n_points, self.step_size)):
            old_set = ground_set[:, :, :start, :]
            new_set = ground_set[:, :, start:start + self.step_size, :]
            self.B, self.C, self.D = self._get_B_C_D(old_set, new_set)
            self.M_minus_A_inv = self._get_M_minus_A_inv()
            self.rho_M_A = self._get_rho_M_A()
            self.A_inv = self._update_A_inv()
        return torch.sum(self.A_inv, dim=-1)

    def _generate_ground_set(self, x):
        '''
        Build the point cloud ``(row, column, value)`` for every pixel.

        The result has shape ``(x.shape[0], x.shape[1], H*W, 3)``. Points are
        ordered row-major, matching ``x.reshape(..., -1)``. Coordinates are
        created on ``x``'s device, in the floating dtype that x promotes to
        together with torch's default dtype.
        '''
        x_pixel, y_pixel = x.shape[2], x.shape[3]
        dtype = torch.promote_types(x.dtype, torch.get_default_dtype())
        xx = torch.linspace(0, x_pixel - 1, x_pixel, device=x.device, dtype=dtype)
        yy = torch.linspace(0, y_pixel - 1, y_pixel, device=x.device, dtype=dtype)
        # 'ij' is the historical default and matches the row-major flattening of x.
        grid = torch.meshgrid(xx, yy, indexing='ij')
        coords = torch.stack(grid).reshape(2, -1).permute(1, 0)  # (H*W, 2)
        coords = coords.expand(x.shape[0], x.shape[1], -1, -1)
        values = x.reshape(x.shape[0], x.shape[1], -1, 1).to(dtype)
        return torch.cat([coords, values], dim=3)

    def _inverse_seed(self, x):
        '''
        Directly invert the similarity matrix of the seed points ``x``.

        Returns ``(Z^{-1}, row sums of Z^{-1})`` with ``Z = exp(-cdist(x, x, p))``.
        '''
        similarity = torch.exp(-torch.cdist(x, x, p=self.p))
        inverse = torch.inverse(similarity)
        return inverse, torch.sum(inverse, dim=-1)

    def _get_B_C_D(self, old_set, new_set):
        '''
        Build the similarity blocks between already-inverted points and the points being appended.

        Returns ``(B, C, D)``:
          * B is old-by-new;
          * C is B with its last two axes swapped (new-by-old);
          * D is new-by-new.
        '''
        B = torch.exp(-torch.cdist(old_set, new_set, p=self.p))
        C = B.transpose(-2, -1)
        D = torch.exp(-torch.cdist(new_set, new_set, p=self.p))
        return B, C, D

    def _get_M_minus_A_inv(self):
        '''Return the inverse of the Schur complement ``S = D - C A^{-1} B``.'''
        return torch.inverse(self.D - torch.matmul(torch.matmul(self.C, self.A_inv), self.B))

    def _get_rho_M_A(self):
        '''
        Build the correction that turns a zero-padded ``A^{-1}`` into the inverse of the grown matrix.

        Returns the block matrix
            [[A^{-1} B S^{-1} C A^{-1}, -A^{-1} B S^{-1}],
             [-S^{-1} C A^{-1},          S^{-1}]].
        '''
        S_inv = self.M_minus_A_inv
        A_inv_B = torch.matmul(self.A_inv, self.B)  # (..., n_old, n_new)
        C_A_inv = torch.matmul(self.C, self.A_inv)  # (..., n_new, n_old)
        R_2 = -torch.matmul(A_inv_B, S_inv)
        R_3 = -torch.matmul(S_inv, C_A_inv)
        # Reusing R_2 keeps this an (n_old x n_new) @ (n_new x n_old) product.
        # Multiplying A_inv by an n_old x n_old matrix would cost O(n_old^3) per step.
        R_1 = -torch.matmul(R_2, C_A_inv)
        R_4 = S_inv
        top = torch.cat([R_1, R_2], dim=3)
        bottom = torch.cat([R_3, R_4], dim=3)
        return torch.cat([top, bottom], dim=2)

    def _update_A_inv(self):
        '''Place ``A^{-1}`` in the top-left corner of a zero matrix of the grown size and add the correction.'''
        padded = torch.zeros_like(self.rho_M_A)
        padded[:, :, :self.A_inv.shape[2], :self.A_inv.shape[3]] = self.A_inv
        return padded + self.rho_M_A

def min_max(img):
    '''
    Linearly rescale ``img`` so its minimum maps to 0 and its maximum to 1.

    A constant input has zero range. Rather than dividing by zero (which
    yields NaNs and a RuntimeWarning), it is mapped to an all-zero float array
    of the same shape.
    '''
    lowest = np.min(img)
    value_range = np.max(img) - lowest
    if value_range == 0:
        return np.zeros_like(img, dtype=float)
    return (img - lowest) / value_range

def compute_error(mag_img, mag_img_approx):
    '''
    Measure how far ``mag_img_approx`` is from the reference ``mag_img``.

    Returns a tuple ``(max_abs_error, relative_error)``:
      * ``max_abs_error`` is the largest element-wise ``|mag_img - mag_img_approx|``;
      * ``relative_error`` is ``sum((mag_img - mag_img_approx)**2) / sum(mag_img**2)``,
        a ratio of *squared* Frobenius norms (no square root is taken).
    '''
    residual = mag_img - mag_img_approx
    squared_residual_norm = np.sum(np.square(residual))
    squared_reference_norm = np.sum(np.square(mag_img))
    return np.max(np.abs(residual)), squared_residual_norm / squared_reference_norm

def _channels_last(arr):
    '''
    Arrange a (channels, height, width) array for ``plt.imshow``.

    A single channel is squeezed to (height, width); otherwise the channel axis
    is moved last. NOTE(review): imshow only accepts 3 or 4 trailing channels,
    so two-channel images still cannot be displayed.
    '''
    if arr.shape[0] == 1:
        return arr[0]
    return np.moveaxis(arr, 0, -1)


def _show(title, data):
    '''Draw ``data`` with a colorbar in a new figure named ``title``.'''
    plt.figure(title)
    plt.imshow(data)
    plt.colorbar()
    plt.tight_layout()


def main():
    '''
    Command-line entry point: compare full and Schur magnitude computations on one image.

    The steps are:
      * load ``--image`` from ``--data_path`` and scale intensities by 1/255;
      * optionally average the channels to greyscale;
      * resize by the given factors;
      * compute the magnitude with ``MagnitudeLayer`` (direct inverse) and with
        ``MagnitudeLayerSchur`` (block updates), min-max normalising both;
      * print timings and errors, then plot the results.
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', action='store', default='.', type=str)
    # Required: the file name is needed on every code path below.
    parser.add_argument('--image', action='store', type=str, required=True)
    parser.add_argument('--greyscale', action='store_true')
    parser.add_argument('--x_rescale_factor', action='store', default=1, type=float)
    parser.add_argument('--y_rescale_factor', action='store', default=1, type=float)
    args = parser.parse_args()

    mag_layer = MagnitudeLayer()
    mag_layer_schur = MagnitudeLayerSchur(seed_size=2400, step_size=10)

    # Open and preprocess the image. np.array copies the pixels so torch gets a
    # writable buffer, and the context manager closes the file handle.
    with Image.open(os.path.join(args.data_path, args.image)) as img:
        img_arr = np.array(img)

    print(f'The image {args.image} has dimensions (h,w,c): {img_arr.shape}')
    img_t = torch.from_numpy(img_arr) / 255.
    if img_t.ndim == 3:
        img_t = img_t.permute(2, 0, 1)  # (h, w, c) -> (c, h, w)
        if args.greyscale:
            # NOTE(review): an alpha channel, if present, is averaged in as well.
            img_t = torch.mean(img_t, dim=0, keepdim=True)
    else:
        img_t = img_t.unsqueeze(0)  # (h, w) -> (1, h, w)
    img_t = img_t.unsqueeze(0)  # add the batch axis -> (1, c, h, w)
    new_size = (int(img_t.shape[2] * args.y_rescale_factor),
                int(img_t.shape[3] * args.x_rescale_factor))
    img_t_small = F.resize(img_t, new_size)

    print(f'The rescaled image {args.image} has dimensions (b,c,h,w): {img_t_small.shape}')

    y_shape = img_t_small.shape[2]
    x_shape = img_t_small.shape[3]

    # Keep results as (channels, h, w) so multi-channel images work; a plain
    # (h, w) view only fits single-channel input.
    tic = time()
    mag_img = mag_layer(img_t_small).reshape(-1, y_shape, x_shape).numpy()
    mag_img = min_max(mag_img)
    toc = time()
    print(f'Elapsed time during full magnitude calculation: {(toc-tic):.2f}s')

    tic = time()
    mag_img_schur = mag_layer_schur(img_t_small).reshape(-1, y_shape, x_shape).numpy()
    mag_img_schur = min_max(mag_img_schur)
    toc = time()
    print(f'Elapsed time during Schur magnitude calculation: {toc-tic}s')

    max_error, relative_error = compute_error(mag_img, mag_img_schur)
    print(f'Max error: {max_error}, relative squared Frobenius error: {relative_error}')

    _show('Mag Schur', _channels_last(mag_img_schur))
    _show('Error', _channels_last(np.abs(mag_img - mag_img_schur)))
    _show('Img', _channels_last(img_t_small[0].numpy()))
    _show('Mag img', _channels_last(mag_img))

    plt.show()

# Run the comparison only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
