import os
from time import time
import numpy as np
import pandas as pd
import torch
import torchvision.transforms.functional as F
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d
from matplotlib import transforms
import argparse
from PIL import Image
from tqdm import tqdm
from data.MNIST_test import LitMNIST
from data.KMNIST_test import LitKMNIST
from data.FashionMNIST_test import LitFashionMNIST


class MagnitudeLayerGrid(torch.nn.Module):
    '''
    Custom magnitude layer which takes a vector set as an input.

    For a point set ``x`` of shape ``(..., n, d)`` it builds the similarity
    matrix ``Z = exp(-l_grid * D)``, where ``D`` is the pairwise distance
    matrix of the points, and returns the magnitude weighting ``w`` that
    solves ``Z w = 1`` (one weight per point, shape ``(..., n)``).

    Args:
        p: order of the p-norm distance passed to ``torch.cdist``
            (ignored when ``hamming`` is True).
        power: stored on the module; not used by the computation.
        l_grid: scale factor applied to the distances before exponentiation.
        l_pixel: stored on the module; not used by the computation.
        hamming: if True, use the Hamming distance (number of coordinates in
            which two points differ) instead of the p-norm distance.
    '''
    def __init__(self,p=1,power=1,l_grid=1.,l_pixel=1.,hamming=False):
        super().__init__()
        self.p = p
        self.power = power
        self.l_grid = l_grid
        self.l_pixel = l_pixel
        self.hamming = hamming

    def forward(self,x):
        return self._magnitude_vec(x)

    def _magnitude_vec(self,x):
        if x.shape[-2] == 0:
            # No points: the weighting is empty; avoid solving a 0x0 system.
            return x.new_zeros(x.shape[:-1])
        if self.hamming:
            # Count, for every pair of points, the coordinates that differ.
            dist = (x.unsqueeze(-2) != x.unsqueeze(-3)).sum(dim=-1).to(x.dtype)
        else:
            dist = torch.cdist(x, x, p=self.p)
        similarity = torch.exp(-self.l_grid * dist)
        # Solving Z w = 1 equals the row sums of Z^{-1}, without forming the
        # inverse explicitly (more stable numerically).
        ones = torch.ones(*x.shape[:-1], 1, dtype=similarity.dtype, device=similarity.device)
        return torch.linalg.solve(similarity, ones).squeeze(-1)

def generate_ground_set(x):
    '''
    Generate the ground set of vectors for an image batch.

    Every pixel becomes a 3-vector ``(row, col, value)``. The pixel
    coordinates are enumerated row-major over ``x.shape[2]`` x ``x.shape[3]``
    and repeated for every batch entry and channel.

    Returns:
        Tensor of shape ``(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], 3)``
        (for a 4-D input).
    '''
    batch, channels = x.shape[0], x.shape[1]
    height, width = x.shape[2], x.shape[3]

    rows = torch.linspace(0, height - 1, height)
    cols = torch.linspace(0, width - 1, width)
    # (height * width, 2) table of (row, col) pairs, row-major.
    coords = torch.stack(torch.meshgrid(rows, cols)).view(2, -1).t()
    coords = torch.tile(coords, (batch, channels, 1, 1))

    pixel_values = x.view(batch, channels, -1, 1)
    return torch.cat((coords, pixel_values), dim=3)

def min_max(img):
    '''
    Rescale an array linearly to the range [0, 1].

    Args:
        img: array-like of numbers (converted with ``np.asarray``).

    Returns:
        ``(img - min) / (max - min)`` as a NumPy array. A constant or empty
        input, for which that ratio is undefined, yields an array of zeros
        of the same shape instead of NaNs or an exception.
    '''
    arr = np.asarray(img)
    if arr.size == 0:
        return np.zeros(arr.shape)
    lo, hi = np.min(arr), np.max(arr)
    span = hi - lo
    if span == 0:
        # Constant input: avoid 0/0, which would turn every entry into NaN.
        return np.zeros(arr.shape)
    return (arr - lo) / span

def compute_error(mag_img,mag_img_approx):
    '''
    Compare an approximation against a reference array.

    Returns:
        A pair ``(max_abs_error, relative_sq_error)`` where
        ``max_abs_error`` is the largest entry-wise absolute difference and
        ``relative_sq_error`` is ``sum((mag_img - mag_img_approx)**2) /
        sum(mag_img**2)``, i.e. the ratio of *squared* Frobenius norms.
    '''
    residual = mag_img - mag_img_approx
    max_abs_error = np.max(np.abs(residual))
    residual_sq_norm = np.sum(np.square(residual))
    reference_sq_norm = np.sum(np.square(mag_img))
    return max_abs_error, residual_sq_norm / reference_sq_norm

def _plot_filtration_scheme(ground_set_arr, height, width, path):
    '''
    Draw the ground set as a 3-D scatter (pixel row, pixel column, intensity)
    with translucent horizontal planes at intensities 0, 1/2 and 1, and save
    the figure to ``path``.
    '''
    plt.figure()
    xs = np.linspace(0, height - 1, 2)
    ys = np.linspace(0, width - 1, 2)
    X, Y = np.meshgrid(xs, ys)
    ax = plt.axes(projection='3d')
    ax.scatter3D(ground_set_arr[:, 0], ground_set_arr[:, 1], ground_set_arr[:, 2],
                 c=ground_set_arr[:, 2], cmap='jet', zorder=1)
    for level, alpha in ((0.0, 1), (1 / 2, 0.5), (1.0, 0.1)):
        ax.plot_surface(X, Y, level * np.ones_like(X), alpha=alpha, color='black', zorder=0)
    # ``ax.w_*axis`` were deprecated aliases of ``ax.*axis`` (removed in newer
    # Matplotlib); the plain names work on both old and new versions.
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.set_pane_color((0.0, 0.0, 0.0, 0.0))
    plt.tight_layout()
    plt.savefig(path)
    plt.close()


def _plot_subgrid(grid_points, mag_vec, path):
    '''
    Scatter the points of a sub-grid (rotated by -90 degrees so it appears in
    image orientation), sized and coloured by their magnitude weights, and
    save the figure to ``path``.
    '''
    plt.figure(figsize=(6, 6))
    base = plt.gca().transData
    rotation = transforms.Affine2D().rotate_deg(-90)
    plt.scatter(grid_points[:, 0], grid_points[:, 1], s=100 * mag_vec, c=mag_vec,
                transform=rotation + base, cmap='plasma')
    plt.axis('off')
    plt.tight_layout()
    plt.savefig(path)
    plt.close()


def main():
    '''
    Compute filtration-based magnitude vectors for one test image and save
    the resulting plots to the ``output`` directory.

    For each threshold ``t`` the pixel grid is split into an "up" set
    (intensity >= t) and a "down" set (intensity < t). The magnitude weighting
    of each set's pixel coordinates is stored per pixel. The per-pixel mean
    over all sets is then rescaled to [0, 1] and plotted as an image.
    '''
    # Explicit lookup instead of ``eval`` on a command-line string.
    datasets = {'MNIST': LitMNIST, 'KMNIST': LitKMNIST, 'FashionMNIST': LitFashionMNIST}

    parser = argparse.ArgumentParser()
    parser.add_argument('--data_set', action='store', default='MNIST', type=str,
                        choices=list(datasets))
    args = parser.parse_args()

    out_dir = 'output'
    os.makedirs(out_dir, exist_ok=True)

    data = datasets[args.data_set]()
    data.setup()
    data_loader = data.test_dataloader(bs=1)
    data_point = next(iter(data_loader))[0]
    height, width = data_point.shape[-2], data_point.shape[-1]

    ground_set_t = generate_ground_set(data_point)
    ground_set = ground_set_t.squeeze()
    ground_set_arr = ground_set.numpy()

    # Plot the filtration scheme
    _plot_filtration_scheme(ground_set_arr, height, width,
                            os.path.join(out_dir, 'Filtration_scheme.pdf'))

    mag_layer = MagnitudeLayerGrid()

    thresholds = np.linspace(0, 1, 11)
    print(thresholds)

    up_cols = [f'mag_up_{i}' for i in range(len(thresholds))]
    down_cols = [f'mag_down_{i}' for i in range(len(thresholds))]

    # One row per pixel, in ground-set order; magnitude columns start as NaN and
    # are filled only for pixels belonging to the corresponding sub-grid.
    mag_vecs_df = pd.DataFrame(
        columns=['point', 'x', 'y', *up_cols, *down_cols],
        data=np.nan * np.ones((ground_set.shape[0], 3 + 2 * len(thresholds))),
    )
    mag_vecs_df.loc[:, ['x', 'y']] = ground_set_arr[:, :2]
    # Plain column assignment stores one (x, y) tuple per row; assigning a list
    # of tuples through ``.loc`` may be treated as a 2-D array by pandas.
    mag_vecs_df['point'] = list(zip(ground_set_arr[:, 0], ground_set_arr[:, 1]))

    # Run the filtration
    intensities = ground_set[:, 2]
    with torch.no_grad():
        for i, t in enumerate(thresholds):
            # Rounded label keeps file names free of float noise such as
            # 0.30000000000000004.
            t_label = round(float(t), 6)
            for kind, mask, col in (('Up', intensities >= t, up_cols[i]),
                                    ('Down', intensities < t, down_cols[i])):
                grid_points = ground_set[mask][:, :2]
                mag_vec = mag_layer(grid_points).numpy()
                # Boolean masking preserves ground-set order, which is exactly
                # the DataFrame's row order, so values line up with their pixels.
                mag_vecs_df.loc[mask.numpy(), col] = mag_vec
                # Plot the magnitude vectors of the sub-grid
                _plot_subgrid(grid_points.numpy(), mag_vec,
                              os.path.join(out_dir, f'{kind}_filtration_threshold_{t_label}.pdf'))

    # Aggregate the sub-grids with a simple per-pixel mean (NaNs are skipped).
    mag_vecs_df['avg_mag_vec'] = mag_vecs_df.loc[:, [*up_cols, *down_cols]].mean(axis=1)

    mag_vec_filtration = min_max(mag_vecs_df['avg_mag_vec'].values)
    mag_vec_filtration = mag_vec_filtration.reshape(height, width)

    plt.figure(figsize=(6, 6))
    plt.imshow(mag_vec_filtration)
    plt.axis('off')
    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, 'Filtration_magnitude.pdf'))
    plt.close()


if __name__ == '__main__':
    main()
