#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Jan 24 18:31:14 2024

@author: anonymous
"""

import matplotlib.pyplot as plt
import numpy as np
import pickle
import tqdm
import os
import time
import gudhi
from gudhi.representations import PersistenceImage


###############################################################################



def convert(diag, k):
    """ Converts Gudhi persistence diagram diag to numpy array of points in
        homology degree k.

    Parameters
    ----------
    diag : sequence of (int, (float, float))
            persistence diagram given as one ``(degree, (birth, death))``
            entry per persistence pair
    k : Int
            homology degree to select

    Returns
    ----------
    Numpy array
            array of shape ``(n, 2)`` whose rows are the ``(birth, death)``
            pairs of degree k, in the order they appear in diag; the shape
            is ``(0, 2)`` when no pair has degree k
    """
    points = [[entry[1][0], entry[1][1]] for entry in diag if entry[0] == k]

    # np.array([]) alone is 1-D with shape (0,); forcing two columns keeps the
    # result a valid (n, 2) point array even when no pair matches.
    return np.array(points).reshape(-1, 2)



###############################################################################



def generate_persistence_images(directory, outfile):
    """ Computes persistence images in homology degree 1
        of alpha filtrations of point clouds in directory
        and writes them to outfile.

    The point clouds are read from ``directory/0.txt``, ``directory/1.txt``,
    ..., one per regular file found in directory, so the folder is expected
    to contain only these consecutively numbered files. Only the first two
    columns of each file are used as point coordinates.

    The images are collected into one numpy array (one entry of shape
    ``(1, 100, 100)`` per point cloud) that is pickled to outfile. The
    accumulated time spent on persistence computation and image creation
    (excluding file loading and complex construction) is printed at the end.

    Parameters
    ----------
    directory : String
            path of the folder containing the filtration-files
    outfile : String
            path of the file containing the persistence images
    """
    side = 100                      # images are side x side pixels
    image_shape = (1, side, side)   # leading axis of length 1 per image

    PI = PersistenceImage(bandwidth=0.1, weight=lambda x: 1,
                          resolution=[side, side], im_range=[0, 1, 0, 1])

    number_of_files = sum(
        os.path.isfile(os.path.join(directory, item))
        for item in os.listdir(directory)
    )

    output = []
    time_count = 0

    for i in tqdm.tqdm(range(number_of_files)):
        P = np.loadtxt(os.path.join(directory, str(i) + '.txt'))[:, :2]
        alpha = gudhi.AlphaComplex(P)
        simplex_tree = alpha.create_simplex_tree()

        # Only persistence computation and image creation are timed;
        # perf_counter is monotonic, so the elapsed time cannot go negative.
        start = time.perf_counter()

        diag = convert(simplex_tree.persistence(homology_coeff_field=2, min_persistence=0), 1)

        if len(diag) > 0:
            # The image is flipped along its first axis before the leading
            # axis is added.
            PM = np.flip(np.reshape(PI(diag), [side, side]), 0).reshape(image_shape)
        else:
            # No degree-1 pairs: use an all-zero image of the same shape.
            PM = np.zeros(image_shape)

        time_count += time.perf_counter() - start

        output.append(PM)

    # The context manager closes (and thereby flushes) the file even if
    # pickling raises.
    with open(outfile, 'wb') as handle:
        pickle.dump(np.array(output), handle)

    print('Time persistence images: ', time_count, '\n')



###############################################################################



def show_persistence_images(directory):
    """ Displays the degree-1 persistence image of the alpha filtration
        of one randomly chosen point cloud in directory.

    A file index is drawn uniformly from ``0 .. number of regular files - 1``
    and ``directory/<index>.txt`` is loaded; only its first two columns are
    used as point coordinates. If the diagram has no degree-1 pairs, an
    all-zero image is shown instead.

    Parameters
    ----------
    directory : String
            path of the folder containing the filtration-files
    """
    imager = PersistenceImage(
        bandwidth=0.1,
        weight=lambda x: 1,
        resolution=[100, 100],
        im_range=[0, 1, 0, 1],
    )

    # Count the regular files to know the range of valid file indices.
    file_count = 0
    for entry in os.listdir(directory):
        if os.path.isfile(os.path.join(directory, entry)):
            file_count += 1
    index = np.random.randint(0, file_count)

    points = np.loadtxt(directory + '/' + str(index) + '.txt')[:, :2]
    tree = gudhi.AlphaComplex(points).create_simplex_tree()
    pairs = convert(
        tree.persistence(homology_coeff_field=2, min_persistence=0),
        1,
    )

    if len(pairs) == 0:
        image = np.zeros((100, 100))
    else:
        # Flip along the first axis before displaying.
        image = np.flip(np.reshape(imager(pairs), [100, 100]), 0)

    plt.imshow(image)
    plt.show()

###############################################################################



if __name__ == "__main__":
    # Run the batch job only when this file is executed as a script, so that
    # importing the module (e.g. to reuse convert) does not trigger it.
    # NOTE: the output is a binary pickle despite the '.txt' extension.
    generate_persistence_images('Data/Pointcloud_Dataset','Data/persistence_images.txt')

    # Uncomment to display the persistence image of a random point cloud:
    # show_persistence_images('Data/Pointcloud_Dataset')