#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Jan 24 15:52:25 2024

@author: anonymous
"""

import pandas as pd
import pickle
import tqdm
import numpy as np
import torch
import os


###############################################################################


def _read_filtration_rows(path, skiprows, nrows):
    """Read ``nrows`` lines of a filtration file after skipping ``skiprows`` lines.

    Fields are split on spaces, commas and semicolons. For a line such as
    ``x y ; a b c`` this yields the columns ``x, y, <empty>, <empty>, a, b, c``:
    the two empty fields come from the separators surrounding ``;``. That is
    why callers take grades from columns ``[:2]`` and boundary indices from
    columns ``[4:]``. (Line layout assumed from the original parsing code;
    TODO confirm against function_delaunay's output spec.)

    Returns a 2-D numpy array. For ``nrows == 0`` it returns an empty
    ``(0, 4)`` float array, so the column slices used by callers stay valid.
    """
    if nrows == 0:
        # Avoid relying on read_csv's column inference (or EmptyDataError at EOF)
        # for an empty block, e.g. a complex without triangles.
        return np.empty((0, 4))
    df = pd.read_csv(path, sep='[ ,;]', skiprows=skiprows, nrows=nrows,
                     header=None, engine='python')
    return df.to_numpy()


def _triangle_vertices(edges, triangles):
    """Convert triangles given by edge indices into vertex-index lists.

    Each triangle's vertex list is its first boundary edge's vertex list
    followed by the endpoint of its second boundary edge that is not the
    second edge's first vertex (when that first vertex is shared with the first
    edge) -- assumes the first two boundary edges share exactly one vertex.

    Parameters
    ----------
    edges : list of list of int
        Vertex indices of every edge.
    triangles : list of list of int
        Edge indices (into ``edges``) of every triangle.
    """
    faces = []
    for tri in triangles:
        first = edges[tri[0]]
        second = edges[tri[1]]
        if second[0] in (first[0], first[1]):
            third_vertex = second[1]
        else:
            third_vertex = second[0]
        faces.append(list(first) + [third_vertex])
    return faces


def _gril_input_from_file(path):
    """Parse one filtration file into ``(f, simplices)`` for GRIL.

    Returns
    -------
    f : torch.Tensor
        float64 tensor with two grade columns. Its rows are ordered vertices,
        then edges, then triangles. Each column is divided by its maximum
        (columns whose maximum is 0 are left unscaled). Edges are then shifted
        by 1e-6 and triangles by 2e-6 in both columns.
    simplices : list of list of int
        Vertex-index lists aligned row-for-row with ``f``.
    """
    # Line 3 of the file holds the four block sizes; data blocks follow in the
    # order: d3 rows, triangles (d2), edges (d1), vertices (d0).
    header = _read_filtration_rows(path, 2, 1)
    d3, d2, d1, d0 = (int(v) for v in header[0, :4])

    tri_rows = _read_filtration_rows(path, 3 + d3, d2)
    edge_rows = _read_filtration_rows(path, 3 + d3 + d2, d1)
    vert_rows = _read_filtration_rows(path, 3 + d3 + d2 + d1, d0)

    # Force float64 so the in-place scaling/offsets below can never be
    # truncated by an integer dtype.
    grades = np.concatenate(
        (vert_rows[:, :2], edge_rows[:, :2], tri_rows[:, :2])
    ).astype(np.float64)

    vertices = [[i] for i in range(len(vert_rows))]
    edges = edge_rows[:, 4:].astype(int).tolist()
    faces = _triangle_vertices(edges, tri_rows[:, 4:].astype(int).tolist())
    simplices = vertices + edges + faces

    # Scale each grade column by its maximum; a zero maximum would give 0/0.
    col_max = grades.max(axis=0)
    grades = grades / np.where(col_max != 0, col_max, 1.0)

    # Small per-dimension offsets (edges +1e-6, triangles +2e-6); presumably
    # meant to keep higher-dimensional simplices strictly after their faces --
    # TODO confirm intent.
    offset_by_size = {2: 0.000001, 3: 0.000002}
    offsets = np.array([offset_by_size.get(len(s), 0.0) for s in simplices])
    grades += offsets[:, np.newaxis]

    return torch.from_numpy(grades), simplices


def gril_inputs_from_filtrations(directory, outfile):
    """Compute GRIL inputs from function_delaunay filtration files and pickle them.

    The directory is expected to hold files named ``0.txt``, ``1.txt``, ...,
    ``<N-1>.txt``. Other entries are ignored when counting. Files are processed
    in numeric order, so position ``j`` of the output corresponds to ``j.txt``.
    A missing number in the sequence raises the error from ``pd.read_csv``.

    Parameters
    ----------
    directory : str
        Path of the folder containing the filtration files.
    outfile : str
        Path of the file the GRIL inputs are pickled to.

    The pickled object is a list with one entry per file. Each entry is a
    one-element list ``[(f, simplices)]``, where ``f`` is a float64 torch tensor
    of normalised grades and ``simplices`` the matching vertex-index lists.
    """
    number_of_files = sum(
        1
        for name in os.listdir(directory)
        if name.endswith('.txt')
        and name[:-4].isdigit()
        and os.path.isfile(os.path.join(directory, name))
    )

    output = [
        [_gril_input_from_file(os.path.join(directory, f'{j}.txt'))]
        for j in tqdm.tqdm(range(number_of_files))
    ]

    # Context manager guarantees the file is flushed and closed.
    with open(outfile, 'wb') as handle:
        pickle.dump(output, handle)
    
    
###############################################################################


# Run the conversion only when executed as a script, not on import.
if __name__ == "__main__":
    gril_inputs_from_filtrations('Data/Filtration_Dataset', 'Data/gril_inputs.txt')