"""This module contains simple helper functions """
from __future__ import print_function
import re
import os
import sys
import shutil
import subprocess


import h5py
import numpy as np
#import torch
from PIL import Image
from scipy.interpolate import RegularGridInterpolator

from collections.abc import Iterable
from tqdm import tqdm


# Tiny positive constant used by the numeric helpers below (offsetting inputs, near-zero distance checks).
EPS = 1e-9
# math utils
def sigmoid(x):
    """Element-wise logistic sigmoid, ``1 / (1 + exp(-x))``."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator
def logistic(x,x0=.5,L=1,k=20):
    """Logistic curve ``L / (1 + exp(-k * (x - x0)))``.

    Parameters:
        x  -- scalar or numpy array to evaluate
        x0 -- midpoint of the curve
        L  -- upper asymptote
        k  -- steepness
    """
    exponent = -k * (x - x0)
    return L / (1 + np.exp(exponent))
def s_shape_curve_0to1(x, beta=3.):
    """S-shaped curve ``1 / (1 + (x / (1 - x)) ** -beta)``.

    The input is offset by the module constant EPS before the ratio is formed.

    Parameters:
        x    -- scalar or numpy array (presumably in [0, 1) -- confirm with callers)
        beta -- exponent applied to the ratio
    """
    shifted = x + EPS
    odds = shifted / (1 - shifted)
    return 1 / (1 + np.power(odds, -beta))
def normalize(x):
    """Divide *x* by the square root of the sum of its squared entries."""
    squared = x * x
    magnitude = np.sqrt(squared.sum())
    return x / magnitude
def length(x):
    """Return ``np.linalg.norm(x)`` (Euclidean length for a 1-D vector)."""
    magnitude = np.linalg.norm(x)
    return magnitude

def point2lineDistance(q, p1, p2):
    """Distance from point *q* to the infinite line through *p1* and *p2*.

    Computed as |(p2 - p1) x (p1 - q)| / |p2 - p1|.
    """
    direction = p2 - p1
    cross_norm = np.linalg.norm(np.cross(direction, p1 - q))
    return cross_norm / np.linalg.norm(direction)
def get2DRotMat(theta=90, mode='degree'):
    """Build the 2x2 rotation matrix ``[[cos, -sin], [sin, cos]]``.

    Parameters:
        theta -- rotation angle
        mode  -- 'degree' converts theta to radians first; any other value uses theta as-is
    """
    angle = np.radians(theta) if mode == 'degree' else theta
    cos_t, sin_t = np.cos(angle), np.sin(angle)
    return np.array([[cos_t, -sin_t], [sin_t, cos_t]])
def pointSegDistance(q, p1, p2):
    """Distance from point *q* to the segment [p1, p2], and the closest point on it.

    Parameters:
        q  (numpy array) -- query point
        p1 (numpy array) -- first end point of the segment
        p2 (numpy array) -- second end point of the segment

    Returns:
        (dist, nearest) -- distance to the segment and the nearest point on it

    A zero-length segment (p1 == p2) is treated as the single point p1
    instead of dividing by zero.
    """
    line_vec = p2 - p1
    pnt_vec = q - p1
    sq_len = np.dot(line_vec, line_vec)
    if sq_len == 0:
        # Degenerate segment: the only candidate is p1 itself.
        return (np.linalg.norm(pnt_vec), np.array(p1, dtype=float))
    # Projection parameter of q along the segment, clamped to [0, 1].
    t = np.dot(line_vec, pnt_vec) / sq_len
    t = min(max(t, 0.0), 1.0)
    nearest = line_vec * t
    dist = np.linalg.norm(nearest - pnt_vec)
    return (dist, nearest + p1)
def rescale(vertex):
    """Centre the vertices and scale them so the largest |coordinate| maps to 40 around (50, 50).

    The mean is taken over axis -2 and the scale is the largest absolute
    coordinate over the last two axes. The offset is a length-2 array, so the
    last axis is expected to hold 2-D coordinates. Prints the scale and the
    per-axis maxima as a debugging aid.
    """
    centered = vertex - vertex.mean(axis=-2, keepdims=True)
    scope = np.abs(centered).max(axis=(-1, -2), keepdims=True)
    vertex = centered / scope * 40 + np.array([50.,50.])
    print('checking scope: ' , scope)
    print('checking vertex: ', vertex.max(-2))
    return vertex
def cm2inch(*tupl):
    """Convert centimetres to inches.

    Accepts either several numbers (``cm2inch(a, b)``) or one tuple
    (``cm2inch((a, b))``) and returns a tuple of converted values.
    """
    cm_per_inch = 2.54
    values = tupl[0] if isinstance(tupl[0], tuple) else tupl
    return tuple(value / cm_per_inch for value in values)
def sampleTriangle(vertices, sampleNum=10, noVert=False):
    """Draw random points inside a triangle.

    Parameters:
        vertices  -- the three corner points; vertices[0..2] must support subtraction
        sampleNum -- total number of returned points
        noVert    -- when False, the first three returned points are the corners
                     themselves and only sampleNum-3 points are random

    Returns:
        numpy array with one sampled point per row
    """
    include_corners = (noVert == False)
    n_random = sampleNum-3 if include_corners else sampleNum
    rd_a, rd_b = np.random.rand(n_random), np.random.rand(n_random)
    # Reflect coefficient pairs that fall outside the triangle (a + b > 1) back inside.
    outside = (rd_a + rd_b > 1.)
    rd_a[outside] = 1 - rd_a[outside]
    rd_b[outside] = 1 - rd_b[outside]
    if include_corners:
        rd_a = np.r_[0,1,0,rd_a]
        rd_b = np.r_[0,0,1,rd_b]
    origin = vertices[0]
    edge_u = vertices[1]-origin
    edge_v = vertices[2]-origin
    return np.array([origin + a*edge_u + b*edge_v for a, b in zip(rd_a, rd_b)])
def randQuat(N=1):
    """Draw N random unit quaternions.

    Returns:
        numpy array of shape (4, N) ordered as (w, x, y, z); each column has unit norm.
    """
    u1, u2, u3 = np.random.rand(3,N)
    sigma1 = np.sqrt(1.0 - u1)
    sigma2 = np.sqrt(u1)
    theta1 = 2*np.pi*u2
    theta2 = 2*np.pi*u3
    components = [
        np.cos(theta2)*sigma2,  # w
        np.sin(theta1)*sigma1,  # x
        np.cos(theta1)*sigma1,  # y
        np.sin(theta2)*sigma2,  # z
    ]
    return np.array(components)
def multQuat(Q1,Q2):
    """Multiply two quaternions given as (w, x, y, z) sequences.

    The expansion below equals the Hamilton product Q2 * Q1 (note the operand order).

    Returns:
        list [w, x, y, z]
    """
    w0, x0, y0, z0 = Q1
    w1, x1, y1, z1 = Q2
    w = -x1*x0 - y1*y0 - z1*z0 + w1*w0
    x = x1*w0 + y1*z0 - z1*y0 + w1*x0
    y = -x1*z0 + y1*w0 + z1*x0 + w1*y0
    z = x1*y0 - y1*x0 + z1*w0 + w1*z0
    return [w, x, y, z]
def conjugateQuat(Q):
    """Return the conjugate (w, -x, -y, -z) of quaternion *Q* as a numpy array."""
    scalar = Q[0]
    negated_vector = [-Q[k] for k in (1, 2, 3)]
    return np.array([scalar, *negated_vector])
def applyQuat(V, Q):
    """Transform 3-vector *V* with quaternion *Q*.

    V is embedded as the pure quaternion (0, V) and sandwiched between Q and
    its conjugate using multQuat.
    NOTE(review): the handedness of the resulting rotation depends on
    multQuat's operand order -- confirm against callers.

    Returns:
        list with the three vector components of the product
    """
    pure = np.array([0., V[0], V[1], V[2]])
    inner = multQuat(pure, conjugateQuat(Q))
    rotated = multQuat(Q, inner)
    return rotated[1:4]
def fibonacci_sphere(samples=1000):
    """Place *samples* points on the unit sphere.

    Heights are spread evenly along y and the azimuth advances by
    pi * (3 - sqrt(5)) per point.

    Returns:
        list of [x, y, z] triples
    """
    step = 2./samples
    angle_step = np.pi * (3. - np.sqrt(5.))
    sphere = []
    for k in range(samples):
        height = ((k * step) - 1) + (step / 2)
        radius = np.sqrt(1 - np.power(height,2))
        azimuth = ((k + 1.) % samples) * angle_step
        sphere.append([np.cos(azimuth) * radius, height, np.sin(azimuth) * radius])
    return sphere

def setDiffND(array1, array2, precision=1):
    """Row-wise set difference: rows of *array1* that do not appear in *array2*.

    Both inputs are rounded to *precision* decimals first, then each row is
    viewed as one structured element so np.setdiff1d can compare whole rows.
    Prints the rounded inputs as a debugging aid.

    Returns:
        2-D numpy array of the remaining rows, as ordered by np.setdiff1d
    """
    array1 = np.round(array1, decimals=precision)
    array2 = np.round(array2, decimals=precision)
    print(array1, array2)
    n_cols1, n_cols2 = array1.shape[1], array2.shape[1]
    rows1 = array1.view([('', array1.dtype)] * n_cols1)
    rows2 = array2.view([('', array2.dtype)] * n_cols2)
    remaining = np.setdiff1d(rows1, rows2)
    return remaining.view(array1.dtype).reshape(-1, n_cols2)
def split_by_category(array, category):
    """Group the entries of *array* by their label in *category*.

    Groups are emitted in order of each label's first appearance.

    Parameters:
        array    -- numpy array or torch tensor indexed along its first axis
        category -- label per entry, same length as array

    Returns:
        For numpy input: a stacked array when every group has the same shape,
        otherwise a 1-D object array holding one array per group.
        For any other input: ``torch.stack`` of the groups.
    """
    _, first_index = np.unique(category, return_index=True)
    ordered_labels = category[np.sort(first_index)]
    groups = [array[category == label] for label in ordered_labels]
    if isinstance(array, np.ndarray):
        try:
            return np.array(groups)
        except ValueError:
            # Groups of different sizes: NumPy >= 1.24 refuses to build a ragged
            # array implicitly, so store one array per slot instead.
            ragged = np.empty(len(groups), dtype=object)
            for i, group in enumerate(groups):
                ragged[i] = group
            return ragged
    # torch is not imported at module level; import it only on the tensor path.
    import torch
    return torch.stack(groups)
def array2chunks(array, chunk_size=2):
    """Reshape *array* into consecutive chunks of *chunk_size* along axis 0.

    Trailing entries that do not fill a whole chunk are dropped (a warning is printed).

    Returns:
        array of shape (len // chunk_size, chunk_size, *array.shape[1:])
    """
    full_shape = array.shape
    n_chunks = full_shape[0]//chunk_size
    usable = n_chunks*chunk_size
    trimmed = array[:usable]
    if usable != full_shape[0]:
        print("Warning! some data will be dropped during array2chunks (len mod chunk_size!=0)")
    return trimmed.reshape(n_chunks, chunk_size, *full_shape[1:])
# legacy from CycleGAN
def save_image(image_numpy, image_path, aspect_ratio=1.0):
    """Save a numpy image to the disk

    Parameters:
        image_numpy (numpy array) -- input numpy array; its shape must unpack into three values
        image_path (str)          -- the path of the image
        aspect_ratio (float)      -- values != 1.0 trigger a bicubic resize before saving
    """
    picture = Image.fromarray(image_numpy)
    rows, cols, _ = image_numpy.shape

    # NOTE(review): PIL's resize takes (width, height); the sizes below lead with
    # the row count -- confirm this ordering is intentional.
    target_size = None
    if aspect_ratio > 1.0:
        target_size = (rows, int(cols * aspect_ratio))
    elif aspect_ratio < 1.0:
        target_size = (int(rows / aspect_ratio), cols)
    if target_size is not None:
        picture = picture.resize(target_size, Image.BICUBIC)
    picture.save(image_path)
def print_numpy(x, title="", val=True, printAll=False, shp=True):
    """Print summary statistics of a numpy array (after casting to float64)

    Parameters:
        x (numpy array) -- array to describe
        title (str)     -- optional heading printed first
        val (bool)      -- if print mean, min, max, median and std of the flattened array
        printAll (bool) -- if print the whole array
        shp (bool)      -- if print the shape and dtype of the array
    """
    if title:
        print(title)
    x = x.astype(np.float64)
    if shp:
        print('shape,', x.shape, x.dtype)
    if printAll:
        print(x)
    if not val:
        return
    flat = x.flatten()
    stats = (np.mean(flat), np.min(flat), np.max(flat), np.median(flat), np.std(flat))
    print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % stats)
# torch tools
def diagnose_network(net, name='network'):
    """Calculate and print the mean of average absolute(gradients)

    Parameters:
        net (torch network) -- Torch network
        name (str) -- the name of the network

    Parameters whose ``grad`` is None are skipped. The statistics use the
    tensors' own ``abs()`` / ``mean()`` methods because ``torch`` is not
    imported at module level in this file.
    """
    mean = 0.0
    count = 0
    for param in net.parameters():
        if param.grad is not None:
            mean += param.grad.data.abs().mean()
            count += 1
    if count > 0:
        mean = mean / count
    print(name)
    print(mean)
def tensor2im(input_image, imtype=np.uint8):
    """Converts a Tensor array into a numpy image array.

    Parameters:
        input_image (tensor) --  the input image tensor array
        imtype (type)        --  the desired type of the converted numpy array

    Returns:
        * numpy input: the same data cast to imtype
        * torch.Tensor input: the first batch element, tiled to 3 channels if
          it has one, moved to channel-last and mapped with (x + 1) / 2 * 255
        * anything else: the input unchanged
    """
    if isinstance(input_image, np.ndarray):
        return input_image.astype(imtype)
    # torch is not imported at module level. A torch.Tensor can only exist if
    # torch has already been imported somewhere, so consult sys.modules.
    torch = sys.modules.get('torch')
    if torch is None or not isinstance(input_image, torch.Tensor):
        return input_image
    image_numpy = input_image.data[0].cpu().float().numpy()  # convert it into a numpy array
    if image_numpy.shape[0] == 1:  # grayscale to RGB
        image_numpy = np.tile(image_numpy, (3, 1, 1))
    # post-processing: transpose to channel-last and rescale
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
    return image_numpy.astype(imtype)

## python utils
def dictAdd(dict1, dict2):
    """Accumulate *dict2* into *dict1* in place with ``+=`` and return dict1.

    Keys missing from dict1 are inserted with dict2's value.
    """
    for key, addend in dict2.items():
        if key not in dict1:
            dict1[key] = addend
            continue
        dict1[key] += addend
    return dict1
def dictAppend(dict1, dict2):
    """Append each value of *dict2* to the list stored under the same key in *dict1*.

    Keys missing from dict1 start a new one-element list. dict1 is modified in
    place and returned.
    """
    for key, item in dict2.items():
        if key not in dict1:
            dict1[key] = [item]
            continue
        dict1[key].append(item)
    return dict1
def dictApply(func, dic):
    """Replace every value of *dic* with ``func(value)`` in place; returns None."""
    for key, current in dic.items():
        dic[key] = func(current)
def dictMean(dicts):
    """Average values key-by-key across a list of dicts.

    Only the keys of the first dict are considered; dicts lacking a key are
    skipped for that key. Each result is the mean of the stacked values.
    """
    averaged = {}
    for key in dicts[0].keys():
        samples = [entry[key] for entry in dicts if key in entry]
        averaged[key] = np.stack(samples).mean()
    return averaged
def prefixDictKey(dic, prefix=''):
    """Return a new dict whose keys are those of *dic* with *prefix* prepended."""
    return {prefix + key: dic[key] for key in dic}
# Short alias for os.path.join.
pj=os.path.join
def strNumericalKey(s):
    """Sort key: the first run of digits in *s* as an int.

    Returns -1 when *s* is empty, contains no digits, or is not a string, so
    the result is always an int and keys stay mutually comparable.
    """
    if not s:
        return -1
    try:
        match = re.search(r'\d+', s)
    except TypeError:
        # Non-string input cannot carry a numeric part.
        return -1
    return int(match.group()) if match else -1
# List dir and sort by numericals (if applicable)
def listdir(directory, return_path=True):
    """List *directory*, ordered by strNumericalKey (first number in each name).

    Returns:
        (filenames, paths) when return_path is True, otherwise just filenames.
    """
    filenames = sorted(os.listdir(directory), key=strNumericalKey)
    if return_path!=True:
        return filenames
    paths = [os.path.join(directory, name) for name in filenames]
    return filenames, paths
def mkdirs(paths):
    """create empty directories if they don't exist

    Parameters:
        paths (str list) -- a list of directory paths; any non-list value is
                            passed to mkdir as a single path
    """
    if not isinstance(paths, list):
        mkdir(paths)
        return
    for single_path in paths:
        mkdir(single_path)
def mkdir(path):
    """create a single empty directory if it didn't exist

    Parameters:
        path (str) -- a single directory path

    Missing parent directories are created too. Nothing happens if *path*
    already exists.
    """
    if not os.path.exists(path):
        # exist_ok covers the case where another process creates the directory
        # between the check above and this call.
        os.makedirs(path, exist_ok=True)
def runCommand(command):
    """Run a command line and capture its output.

    The string is tokenised with bashlex.split and executed via subprocess.run
    without a shell.

    Returns:
        (stdout, stderr, returncode) with both streams decoded as UTF-8
    """
    import bashlex
    argv = list(bashlex.split(command))
    completed = subprocess.run(argv, capture_output=True) # capture_output=True -> capture stdout&stderr
    out_text = completed.stdout.decode('utf-8')
    err_text = completed.stderr.decode('utf-8')
    return out_text, err_text, completed.returncode
def array2batches(array, chunkSize=4):
    """Split *array* into consecutive slices of at most *chunkSize* items.

    The last batch may be shorter; empty input yields an empty list.
    """
    n_batches = (len(array)+chunkSize-1)//chunkSize
    batches = []
    for b in range(n_batches):
        start = b*chunkSize
        batches.append(array[start:start+chunkSize])
    return batches
def progbar(array, total=None):
    """Wrap *array* in an ASCII tqdm progress bar.

    When *total* is not given it is taken from ``len(array)``; an ``enumerate``
    object is first materialised into a list so it has a length.
    """
    if total is not None:
        return tqdm(array,total=total, ascii=True)
    if isinstance(array, enumerate):
        array = list(array)
    return tqdm(array,total=len(array), ascii=True)
def parallelMap(func, args, batchFunc=None, zippedIn=True, zippedOut=False, cores=-1, quiet=False):
    """Parallel map using multiprocessing library Pathos

    Args:
        func (function): function applied to every task.
        args (list): with zippedIn==True, one argument tuple per call:
            [(arg1, ..., argn), ...]. With zippedIn==False, one sequence per
            parameter: [arg1s, arg2s, ..., argns].
        batchFunc (func, optional): TODO. Accepted for compatibility but
            currently not applied. Defaults to None.
        zippedIn (bool, optional): See [args]. Defaults to True.
        zippedOut (bool, optional): See [Returns]. Defaults to False.
        cores (int, optional): How many processes; also the number of tasks
            submitted per batch. -1 means os.cpu_count(). Defaults to -1.
        quiet (bool, optional): if do not show the progress bar. Defaults to False.

    Returns:
        list: [out1s, out2s, ..., outns] (zippedOut==False) or the per-call
        results [out, ...] as returned by func (zippedOut==True). Non-tuple
        results count as a single output. An empty task list returns [].
    """
    from pathos.multiprocessing import ProcessingPool
    if zippedIn==True:
        args = list(map(list, zip(*args))) # transpose to one sequence per parameter
    if len(args) == 0 or len(args[0]) == 0:
        # Nothing to run; avoids indexing into empty inputs/outputs below.
        return []
    if cores==-1:
        cores = os.cpu_count()
    pool = ProcessingPool(nodes=cores)
    batches = array2batches(list(range(len(args[0]))), cores)
    out = []
    iterations = batches if quiet==True else progbar(batches)
    for batch in iterations:
        batch_args = [[arg[j] for j in batch] for arg in args]
        out.extend( pool.map(func, *batch_args) )
    if zippedOut == False:
        if type(out[0]) is not tuple:
            out=[(item,) for item in out]
        out = list(map(list, zip(*out)))
    return out
def imgs2video(targetDir, folderName, frameRate=6):
    ''' Making video from a sequence of images

        Make a video from images with index, e.g. 1.png 2.png 3.png ...
        the images should be in targetDir/folderName/
        the output will be targetDir/folderName.mp4 .
        Args:
            targetDir: the output directory
            folderName: a folder in targetDir which contains the images
            frameRate: input frame rate passed to ffmpeg
        Returns:
            None. On a non-zero ffmpeg exit code the code and stderr are printed.
    '''
    frame_pattern = os.path.join(os.path.join(targetDir, folderName), '%d.png')
    video_path = os.path.join(targetDir, '%s.mp4'%folderName)
    template = 'ffmpeg -framerate {2} -f image2 -i {0} -c:v libx264 -pix_fmt yuv420p -r 25 {1} -y'
    command = template.format(frame_pattern, video_path, frameRate)
    print('Executing command: ', command)
    _, stderr, returncode = runCommand(command)
    if returncode!=0:
        print("ERROR happened during making visRecon video:\n error code:%d"%returncode, stderr)

def imgs2video2(imgs_dir, out_path, frameRate=6):
    ''' Making video from a sequence of images

        Make a video from images with index, e.g. 1.png 2.png 3.png ...
        the images should be in imgs_dir
        the output will be at out_path
        Args:
            imgs_dir: folder which contains the images
            out_path: path of the video file to write
            frameRate: input frame rate passed to ffmpeg
        Returns:
            None. On a non-zero ffmpeg exit code the code and stderr are printed.
    '''
    frame_pattern = os.path.join(imgs_dir,'%d.png')
    template = 'ffmpeg -framerate {framerate} -f image2 -i {imgs} -c:v libx264 -pix_fmt yuv420p -r 25 {outpath} -y'
    command = template.format(framerate=frameRate, imgs=frame_pattern, outpath=out_path)
    print('Executing command: ', command)
    _, stderr, returncode = runCommand(command)
    if returncode!=0:
        print("ERROR happened during making video:\n error code:%d"%returncode, stderr)

def imgarray2video(targetPath, img_list):
    """Encode a list of images into a video at *targetPath*.

    The frames are written to a temporary folder in the user's home directory
    with vis.saveImgs, encoded with imgs2video2, and the folder is removed
    afterwards, even when saving or encoding fails.

    Raises:
        OSError: if the temporary folder already exists.
    """
    import vis
    temp_dir = os.path.expanduser('~/.temp_imgarray2video/')
    if os.path.exists(temp_dir):
        raise OSError('Temp folder already exists!')
    try:
        vis.saveImgs(targetDir=temp_dir, baseName='', imgs=img_list)
        imgs2video2(temp_dir, targetPath)
    finally:
        # Remove the frames without spawning a shell; a missing folder is ignored.
        shutil.rmtree(temp_dir, ignore_errors=True)

def sh2option(filepath, parser, quiet=False):
    """Parse options from the command line stored in a shell script.

    The file content is joined into one line and tokenised with bashlex. The
    first token is dropped and the remainder temporarily replaces sys.argv
    while ``parser.parse(quiet=quiet)`` runs. sys.argv is restored afterwards,
    even if parsing raises.

    Parameters:
        filepath (str) -- path of the script holding the command line
        parser         -- object exposing ``parse(quiet=...)``
        quiet (bool)   -- forwarded to parser.parse

    Returns:
        whatever parser.parse returns
    """
    # bashlex is not imported at module level in this file.
    import bashlex
    with open(filepath, 'r') as file:
        data = file.read().replace('\n', ' ')
    args = [string.lstrip().rstrip() for string in bashlex.split(data)][1:]
    args = [string for string in args if string != '']
    previous_argv = sys.argv
    sys.argv = args
    try:
        opt = parser.parse(quiet=quiet)
    finally:
        sys.argv = previous_argv
    return opt
def isIterable(object):
    """Tell whether *object* is an instance of collections.abc.Iterable."""
    iterable = isinstance(object, Iterable)
    return iterable
def make_funcdict(d=None, **kwargs):
    """Create a function that stores key/value pairs as its own attributes.

    Calling the returned function with a dict and/or keyword arguments merges
    them into its attribute dict and returns that dict; the stored values are
    also readable as attributes (``fd.key``).
    """
    def funcdict(d=None, **kwargs):
        store = vars(funcdict)
        if d is not None:
            store.update(d)
        store.update(kwargs)
        return store
    funcdict(d, **kwargs)
    return funcdict
def pickleCopy(obj):
    """Return a copy of *obj* made by pickling it to bytes and unpickling again."""
    import pickle
    payload = pickle.dumps(obj)
    return pickle.loads(payload)
def makeArchive(folder, outpath, format='zip'):
    """Archive *folder* into ``outpath.<format>``.

    The archive contains the folder itself (its base name is the top-level
    entry). An existing archive file at the target location is removed first.
    """
    parent, folder_name = os.path.split(os.path.abspath(folder))
    archive_file = outpath + '.%s'%format
    if os.path.exists(archive_file):
        os.remove(archive_file)
    shutil.make_archive(base_name=outpath, format=format,
                        root_dir=parent,
                        base_dir=folder_name)

# linear algebra & np array
def rdchoice(array, num, replace=False):
    """Pick *num* random entries along the first axis, kept in ascending index order."""
    picked = np.sort(np.random.choice(array.shape[0], num, replace=replace))
    return array[picked]
def arraySupersample(array, weight = None, multiplier=10):
    """Linearly interpolate between consecutive entries of *array* along axis 0.

    Every adjacent pair contributes *multiplier* samples blended with *weight*,
    which defaults to ``linspace(0, 1, multiplier)``. A custom weight vector
    must have *multiplier* entries.

    Returns:
        array of shape ((len(array) - 1) * multiplier, *array.shape[1:])
    """
    if weight is None:
        weight = np.linspace(0.,1.,multiplier)
    n_segments = array.shape[0]-1
    starts = np.repeat(array[:-1], multiplier ,axis=0)
    ends = np.repeat(array[1:], multiplier ,axis=0)
    blend = np.tile(weight, n_segments)
    # Trailing singleton axes let the weights broadcast over the entry shape.
    blend = blend.reshape(blend.shape + (1,)*len(array.shape[1:]))
    return starts*(1-blend) + ends*blend

def pointInterp(queryP, points, values):
    """Inverse-distance interpolation at *queryP* from the *values* of *points*

    Args:
        queryP (np array): position to be interpolated. The distance is a
            single norm of (point - queryP), so this is handled as one query point.
        points (NxC1 np array): base points
        values (NxC2 np array): value of base points

    Returns:
        the value of the base point itself when queryP lies within EPS of it,
        otherwise the 1/distance weighted average of all values
    """
    n_points = len(points)
    distances = np.zeros(n_points)
    inv_weights = np.zeros(n_points)
    blended = 0.
    for idx, base in enumerate(points):
        distances[idx] = np.linalg.norm(base-queryP)
        if distances[idx]<EPS:
            return values[idx]
        inv_weights[idx] = 1/distances[idx]
        blended += inv_weights[idx] * values[idx]
    blended /= inv_weights.sum()
    return blended

def gridInterp(queries, grid, gridRange=np.array([[0.,1.],[0.,1.],[0.,1.]]), return_func=False):
    """Interpolate a regular grid at *queries* with scipy's RegularGridInterpolator.

    Args:
        queries (MxD np array): positions to evaluate.
        grid (np array): sampled values; axis i spans gridRange[i].
        gridRange (Dx2 np array): [lower, upper] per axis. Axis i is sampled at
            lower + k * (upper - lower) / grid.shape[i] for k = 0..grid.shape[i]-1,
            so the upper bound itself is not a sample.
        return_func (bool): also return the interpolator object.

    Returns:
        interpolated values, or (values, interpolator) when return_func is True.

    The axes are built from integer indices so each one always has exactly
    grid.shape[i] points, for any range width.
    """
    gridRange = np.asarray(gridRange, dtype=float)
    axes = []
    for i in range(gridRange.shape[0]):
        lower, upper = gridRange[i]
        n_points = grid.shape[i]
        axes.append(lower + np.arange(n_points) * ((upper - lower) / n_points))
    func = RegularGridInterpolator(axes, grid)
    interp = func(queries)
    if return_func:
        return interp, func
    return interp

def array2mesh(array, thresh=0., dim=3, coords=None):
    """Extract an iso-surface mesh from a flattened grid of values.

    The array is reshaped to a cube with array2NDCube and meshed with
    mcubes.marching_cubes at level *thresh*. Vertex columns are reordered to
    [1, 0, 2] and divided by (cube side - 1). When *coords* is given the
    vertices are mapped into its bounding box.

    Returns:
        (verts, faces) with faces cast to int
    """
    import mcubes
    cube = array2NDCube(array, N=dim)
    verts, faces = mcubes.marching_cubes(cube, thresh)
    verts = verts[:,[1,0,2]]/(cube.shape[0]-1) # rearrange order and rescale
    if coords is None:
        return verts, faces.astype(int)
    bbmin, bbmax = arrayBBox(coords)
    verts = verts*(bbmax-bbmin) + bbmin
    return verts, faces.astype(int)
def sampleMesh(vert, face, sampleN):
    """Sample *sampleN* random points on the surface of a triangle mesh.

    Uses igl.random_points_on_mesh and turns the returned barycentric
    coordinates and face indices into positions. Sampling is retried until
    exactly sampleN points are obtained; failures are reported on the original
    stdout.

    igl is imported here, outside the retry loop, so a missing module raises
    ImportError instead of being swallowed by the retry handler.
    NOTE(review): a sampling error that repeats on every attempt still loops
    forever.

    Returns:
        (sampleN x C) numpy array of sampled positions
    """
    import igl
    while True:
        try:
            B,FI    = igl.random_points_on_mesh(sampleN, vert, face)
            sampled =   B[:,0:1]*vert[face[FI,0]] + \
                        B[:,1:2]*vert[face[FI,1]] + \
                        B[:,2:3]*vert[face[FI,2]]
        except Exception as e:
            print('Error encountered during mesh sampling:', e, file=sys.__stdout__)
            print('Now resampling...', file=sys.__stdout__)
            continue
        if sampled.shape[0] == sampleN:
            return sampled
        print('Failed to sample "sampleN" points, now resampling...', file=sys.__stdout__)
def arrayBBox(array):
    """Bounding box over the last axis of *array*.

    Returns:
        numpy array [[min_0, min_1, ...], [max_0, max_1, ...]], one column per
        coordinate of the last axis.
    """
    columns = [array[..., d] for d in range(array.shape[-1])]
    lower = [column.min() for column in columns]
    upper = [column.max() for column in columns]
    return np.array([lower, upper])

def makeGrid(bb_min=[0,0,0], bb_max=[1,1,1], shape=[10,10,10], flatten=True):
    """ Make a grid of coordinates

    Args:
        bb_min (list or np.array): least coordinate for each dimension
        bb_max (list or np.array): maximum coordinate for each dimension
        shape (list or int): list of coordinate number along each dimension. If it is an int, the number
                            same for all dimensions
        flatten (bool, optional): Return as list of points or as a grid. Defaults to True.

    Returns:
        np.array: return coordinates, the format is specified by flatten. The
                  grid comes from np.meshgrid with its default indexing.
    """
    lower = np.array(bb_min)
    upper = np.array(bb_max)
    if type(shape) is int:
        shape = np.array([shape]*lower.shape[0])
    axes = [np.linspace(lower[d], upper[d], count) for d, count in enumerate(shape)]
    grid = np.stack(np.meshgrid(*axes,sparse=False), axis=-1)
    if flatten==True:
        grid = grid.reshape(-1,grid.shape[-1])
    return grid
def array2NDCube(array, N=3):
    """Reshape a flat array into an N-dimensional cube.

    The side length is the ceiling of the N-th root of the first dimension.
    """
    side = np.ceil( np.power(array.shape[0], 1./N) ).astype(int)
    cube_shape = tuple(side for _ in range(N))
    return array.reshape(cube_shape)
def gridPoints2levelPlanes(points, level_axis=0, decimals=8):
    """Group grid points into planes sharing the same coordinate on *level_axis*.

    Coordinates are rounded to *decimals* first. The points of each distinct
    level are stacked into one array and the *level_axis* column is then
    deleted from the last axis. Prints the shape of every level as a debugging aid.

    Returns:
        numpy array with one entry per distinct level value (ascending)
    """
    points = points.round(decimals=decimals)
    axis_values = points[..., level_axis]
    levels = np.array([points[axis_values == value] for value in np.unique(axis_values)])
    levels = np.delete(levels, level_axis, axis=-1)
    print([level.shape for level in levels])
    return levels
def inverseWhere(where, shape):
    """Build a boolean mask of *shape* that is True exactly at the *where* indices."""
    mask = np.full(shape, False, dtype='bool')
    mask[where] = True
    return mask
def subsampleArray(array, sampleN, axis=0):
    """Randomly keep *sampleN* entries of *array* along *axis*, without replacement.

    When the axis holds fewer than sampleN entries the input is returned unchanged.
    """
    available = array.shape[axis]
    if available < sampleN:
        return array
    picked = np.random.choice(available, sampleN, replace=False)
    return np.take(array, picked, axis=axis)
def subsampleBoolArray(array, sampleN, ratio=None):
    """Keep at most *sampleN* of the True entries of a 1-D boolean array.

    When more than sampleN entries are True, sampleN of them are chosen at
    random and a new mask containing only those is returned. Otherwise, or
    whenever *ratio* is given, the input is returned unchanged.
    """
    if ratio is not None:
        return array
    pos = np.where(array)[0]
    if pos.shape[0] <= sampleN:
        return array
    keep = np.random.choice(pos.shape[0], sampleN, replace=False)
    return inverseWhere(pos[keep], array.shape[0])
def array2slices(array):
    """Turn a vector of extents into an index tuple.

    Each positive extent ``n`` becomes ``slice(0, n)``; every other extent
    becomes the plain index 0.

    Parameters:
        array -- 1-D array of any integer dtype

    Raises:
        ValueError: if the array's dtype is not an integer type.
    """
    array = np.asarray(array)
    if not np.issubdtype(array.dtype, np.integer):
        raise ValueError("Array's type is not int")
    return tuple([slice(0, int(si)) if si > 0 else 0 for si in array])
def padArray(objArray):
    """Zero-pad a collection of differently shaped arrays into one dense array.

    Shapes are right-padded with zeros to the highest rank present. Each array
    is written into the corner of its slot, and a padded axis is addressed
    with index 0. The output dtype is that of the first array.

    Returns:
        (padded, paddedShapes): the dense array of shape (len(objArray), *max_shape)
        and the per-item padded shape vectors.

    Raises:
        ValueError: if objArray is a numpy array whose elements are not all numpy arrays.
    """
    if type(objArray) is np.ndarray and any(type(obj) is not np.ndarray for obj in objArray):
        raise ValueError('Invalid objArray! Not every array element is np.ndarray')
    dtype = objArray[0].dtype
    max_rank = np.array([len(item.shape) for item in objArray]).max()
    paddedShapes = []
    for item in objArray:
        padded_shape = np.zeros(max_rank).astype(int)
        padded_shape[0:len(item.shape)] = np.array(item.shape)
        paddedShapes.append(padded_shape)
    regions = [array2slices(padded_shape) for padded_shape in paddedShapes]
    paddedShapes = np.array(paddedShapes)
    targetShape = paddedShapes.max(axis=0).astype(int)
    targetArrays = np.zeros((len(objArray), *tuple(targetShape)), dtype=dtype)
    for slot, region, item in zip(targetArrays, regions, objArray):
        slot[region] = item
    return targetArrays, paddedShapes
def padAsShape(array, targetShape):
    """Zero-pad *array* up to *targetShape*.

    Pads it with padArray together with a zeros array of the target shape.
    Since padArray takes its dtype from the first item (np.zeros default), the
    result is floating point.
    """
    pair = np.empty(2,dtype='O')
    pair[0] = np.zeros(targetShape)
    pair[1] = array
    padded, _ = padArray(pair)
    return padded[1]
def padAs(array, targetArray):
    """Zero-pad *array* to the shape of *targetArray* (see padAsShape)."""
    target_shape = targetArray.shape
    return padAsShape(array, target_shape)
def padded2array(padded, shapes, single=False):
    """Undo padArray: cut every padded slot back to its recorded shape.

    Parameters:
        padded -- stack of padded arrays (indexed along the first axis), or one
                  padded array when single is True
        shapes -- padded shape vector(s) as produced by padArray
        single -- treat the inputs as one item and return one array

    Returns:
        With single=True, the recovered array. Otherwise a stacked array when
        all recovered arrays share a shape, else a 1-D object array with one
        array per slot (NumPy >= 1.24 no longer builds ragged arrays implicitly).
    """
    if single==True:
        padded = np.array([padded])
        shapes = np.array([shapes])
    targetArrays = []
    for i in range(padded.shape[0]):
        slices = array2slices(shapes[i])
        targetArrays.append( padded[i][slices] )
    if single==True:
        return np.array(targetArrays)[0]
    try:
        return np.array(targetArrays)
    except ValueError:
        ragged = np.empty(len(targetArrays), dtype=object)
        for i, item in enumerate(targetArrays):
            ragged[i] = item
        return ragged
def serializeArray(objArray):
    """Flatten a collection of arrays into one data vector plus bookkeeping.

    Returns:
        (serialData, dataBias, serialShape, shapeBias) where
        serialData  -- all items flattened and concatenated
        dataBias    -- offsets into serialData; item i is [dataBias[i], dataBias[i+1])
        serialShape -- all item shapes concatenated
        shapeBias   -- offsets into serialShape; shape i is [shapeBias[i], shapeBias[i+1])

    Raises:
        ValueError: if objArray is a numpy array whose elements are not all numpy arrays.
    """
    if type(objArray) is np.ndarray and any(type(obj) is not np.ndarray for obj in objArray):
        raise ValueError('Invalid objArray! Not every array element is np.ndarray')
    flat_chunks, shape_chunks = [], []
    data_offsets, shape_offsets = [0], [0]
    for item in objArray:
        flat = item.reshape(-1)
        dims = np.array(item.shape).reshape(-1)
        flat_chunks.append(flat)
        shape_chunks.append(dims)
        shape_offsets.append(shape_offsets[-1] + dims.shape[0])
        data_offsets.append(data_offsets[-1] + flat.shape[0])
    serialData = np.concatenate(flat_chunks)
    serialShape = np.concatenate(shape_chunks)
    return serialData, np.array(data_offsets), serialShape, np.array(shape_offsets)
def serialized2array(serialData, dataBias, serialShape, shapeBias):
    """Rebuild the arrays flattened by serializeArray.

    Parameters:
        serialData  -- concatenated flattened data
        dataBias    -- offsets into serialData (one more entry than items)
        serialShape -- concatenated shapes
        shapeBias   -- offsets into serialShape (one more entry than items)

    Returns:
        A stacked array when every item has the same shape, else a 1-D object
        array with one array per item (NumPy >= 1.24 no longer builds ragged
        arrays implicitly). A progress bar is shown while decoding.
    """
    targetArrays=[]
    for i in progbar(range(shapeBias.shape[0]-1)):
        shape       = serialShape[shapeBias[i]:shapeBias[i+1]]
        data_serial = serialData[dataBias[i]:dataBias[i+1]]
        targetArrays.append(data_serial.reshape(shape))
    try:
        return np.array(targetArrays)
    except ValueError:
        ragged = np.empty(len(targetArrays), dtype=object)
        for i, item in enumerate(targetArrays):
            ragged[i] = item
        return ragged

# H5 dataset
class H5DataDict():
    """Dict-like, lazily read view of an HDF5 file with a per-dataset row cache.

    ``d[name]`` returns the H5Var for a dataset. ``d[name, index]`` reads the
    entry and memoises it until *cache_max* entries have been cached.
    """
    def __init__(self, path, cache_max = 10000000):
        self.path = path
        f = H5File(path)
        try:
            self.fkeys = f.fkeys
        finally:
            f.close()
        self.dict = {key: H5Var(path, key) for key in self.fkeys}
        self.cache= {key: {} for key in self.fkeys}
        # Number of entries currently stored across all per-dataset caches.
        self.cache_counter=0
        self.cache_max = cache_max
    def keys(self):
        """Return the dataset names found when the file was opened."""
        return self.fkeys
    def __getitem__(self,values):
        """Look up ``name`` (-> H5Var) or ``(name, index)`` (-> data, cached).

        Raises:
            ValueError: if the dataset name does not exist.
        """
        if type(values) is not tuple:
            if values in self.fkeys:
                return self.dict[values]
            raise ValueError('%s does not exist'%values)
        name, index = values[0], values[1]
        if name not in self.fkeys:
            # Wrap in a 1-tuple: formatting a 2-tuple with a single %s raises TypeError.
            raise ValueError('%s does not exist'%(values,))
        bucket = self.cache[name]
        try:
            cached = index in bucket
        except TypeError:
            # Unhashable index (e.g. list or ndarray): read through, do not cache.
            return self.dict[name][index]
        if cached:
            return bucket[index]
        data = self.dict[name][index]
        if self.cache_counter < self.cache_max:
            bucket[index] = data
            self.cache_counter += 1
        return data
class H5Var():
    """Handle on one dataset of an HDF5 file; the file is reopened for every operation."""
    def __init__(self, path, datasetName):
        self.path, self.dname=path, datasetName
    def __getitem__(self, index):
        """Read entry *index* of the dataset; ``None`` reads the whole dataset."""
        f = H5File(self.path)
        try:
            if index is None:
                return f[(self.dname,)]
            return f[self.dname, index]
        finally:
            f.close()
    def __len__(self):
        """Number of entries, as reported by H5File.getLen."""
        f = H5File(self.path)
        try:
            return f.getLen(self.dname)
        finally:
            f.close()
    @property
    def shape(self):
        # NOTE: returns the entry count (an int), not a shape tuple.
        return len(self)
    def append(self, array):
        """Append a batch of arrays to the dataset in serialized form.

        The dataset and its ``_serial_*`` companions are created on first use.
        Integer input is stored as 'i8', floating-point input as 'f8'.

        Raises:
            ValueError: for element dtypes other than integer / floating.
            NotImplementedError: if the dataset exists in non-serialized form.
        """
        f = H5File(self.path, mode='a')
        try:
            if self.dname not in f.f.keys():
                if np.issubdtype(array[0].dtype, np.integer):
                    dtype = 'i8' # 'i' means 'i4' which only support number < 2147483648
                elif np.issubdtype(array[0].dtype, np.floating):
                    dtype = 'f8'
                else:
                    raise ValueError('Unsupported dtype %s'%array[0].dtype)
                f.f.create_dataset(self.dname, (0,), dtype=dtype, maxshape=(None,), chunks=(102400,))
                # NOTE(review): the offset datasets use 'i' (see the i4 remark above);
                # confirm this is wide enough for the expected data volume.
                f.f.create_dataset('%s_serial_dataBias'%self.dname, (0,), dtype='i', maxshape=(None,), chunks=(102400,))
                f.f.create_dataset('%s_serial_shape'%self.dname, (0,), dtype='i', maxshape=(None,), chunks=(102400,))
                f.f.create_dataset('%s_serial_shapeBias'%self.dname, (0,), dtype='i', maxshape=(None,), chunks=(102400,))
            if "%s_serial_dataBias"%self.dname not in f.f.keys():
                raise NotImplementedError('Appending for Non-serialized form is not implemented')
            serialData, dataBias, serialShape, shapeBias = serializeArray(array)
            # Payloads (flattened data and shapes) are appended verbatim.
            payloads = ((self.dname, serialData), ('%s_serial_shape'%self.dname, serialShape))
            for key, chunk in payloads:
                oshape = f.f[key].shape[0]
                f.f[key].resize((oshape+chunk.shape[0],))
                f.f[key][oshape:oshape+chunk.shape[0]] = chunk
            # Offsets continue from the last stored offset; the leading 0 of the
            # new batch is dropped except for the very first batch.
            offsets = (('%s_serial_dataBias'%self.dname, dataBias), ('%s_serial_shapeBias'%self.dname, shapeBias))
            for key, bias in offsets:
                oshape = f.f[key].shape[0]
                if oshape ==0:
                    f.f[key].resize((bias.shape[0],))
                    f.f[key][:] = bias
                else:
                    tshape = oshape+bias.shape[0]-1
                    f.f[key].resize((tshape,))
                    f.f[key][oshape:tshape] = bias[1:]+f.f[key][oshape-1]
        finally:
            f.close()
def get_h5row(path, datasetName, index):
    """Open the HDF5 file at *path*, read ``datasetName[index]`` and close the file.

    The file is closed even if the read raises.
    """
    f = H5File(path)
    try:
        return f[datasetName, index]
    finally:
        f.close()
class H5File():
    """Thin wrapper over h5py.File that understands this module's storage layouts.

    A dataset ``key`` is stored in one of three forms:
      * serialized: ``key`` plus ``key_serial_dataBias`` / ``key_serial_shape`` /
        ``key_serial_shapeBias`` (see serializeArray)
      * padded: ``key`` plus ``key_shapes`` (see padArray)
      * plain: just ``key``
    """
    def __init__(self, path, mode='r'):
        self.f = h5py.File(path, mode)
        # Snapshot of the top-level names taken when the file is opened.
        self.fkeys = list(self.f.keys())
    def keys(self):
        return self.fkeys
    @staticmethod
    def _stack(arrays):
        """Stack same-shaped arrays; otherwise return a 1-D object array of them."""
        try:
            return np.array(arrays)
        except ValueError:
            # NumPy >= 1.24 refuses to build ragged arrays implicitly.
            ragged = np.empty(len(arrays), dtype=object)
            for i, item in enumerate(arrays):
                ragged[i] = item
            return ragged
    def get_serial_data(self, key, index):
        """Read entry *index* of a serialized dataset and restore its shape."""
        f = self.f
        serial_data = f[key]
        shapeBias = f['%s_serial_shapeBias'%key]
        dataBias = f['%s_serial_dataBias'%key]
        serial_shape = f['%s_serial_shape'%key]
        shape = np.array( serial_shape[shapeBias[index]:shapeBias[index+1]] )

        data = np.array(serial_data[ dataBias[index]:dataBias[index+1]]).reshape(shape)
        return data
    def __getitem__(self, value):
        """Read ``(key,)`` (whole dataset) or ``(key, index)``.

        *index* may be a scalar, a slice (padded / plain forms) or an iterable
        of indices.
        """
        f, fkeys = self.f, self.fkeys
        key = value[0]
        if "%s_serial_dataBias"%key in fkeys:
            if len(value)==1:
                serialData, dataBias, serialShape, shapeBias = np.array(f[key]), np.array(f['%s_serial_dataBias'%key]), np.array(f['%s_serial_shape'%key]), np.array(f['%s_serial_shapeBias'%key])
                item = serialized2array(serialData, dataBias, serialShape, shapeBias)
            else:
                if isinstance(value[1], Iterable):
                    item = self._stack([self.get_serial_data(key, ind) for ind in value[1]])
                else:
                    item = self.get_serial_data(key,value[1])
        elif "%s_shapes"%key in fkeys:
            shapes = f["%s_shapes"%key]
            # padded2array pairs padded[i] with shapes[i], so the shape rows must
            # be selected with the same index as the padded rows.
            if len(value)==1:
                item = padded2array(f[key], shapes)
            elif isinstance(value[1], slice):
                item = padded2array(np.array(f[key][value[1]]), np.array(shapes[value[1]]))
            elif isinstance(value[1], Iterable):
                item = padded2array(np.array([f[key][ind] for ind in value[1]]),
                                    np.array([shapes[ind] for ind in value[1]]))
            else:
                item = padded2array(f[key][value[1]], shapes[value[1]], single=True)
        else:
            if len(value)==1:
                item = np.array(f[key])
            else:
                if isinstance(value[1], Iterable):
                    # Read each distinct row once, in increasing order, then expand
                    # back to the requested order (including duplicates).
                    ind  = np.array(value[1])
                    uind, inverse = np.unique(ind, return_inverse=True)
                    item = np.array(f[key][ list(uind) ])
                    item = item[inverse]
                else:
                    item = np.array(f[key][value[1]])
        return item
    def getLen(self, key):
        """Number of entries of *key*: offsets - 1 for serialized data, else the first dimension."""
        if "%s_serial_dataBias"%key in self.fkeys:
            leng = self.f["%s_serial_dataBias"%key].shape[0] - 1
        else:
            leng = self.f[key].shape[0]
        return leng
    def append(self, dname, array):
        # Not implemented; H5Var.append writes through the underlying h5py file instead.
        pass

    def close(self):
        self.f.close()
def readh5(path):
    """Load every dataset of an HDF5 file into a dict of numpy arrays.

    Names containing "_serial" or "_shapes" are treated as bookkeeping and
    skipped. A dataset with serialized or padded companions is decoded with
    serialized2array / padded2array; any other dataset is read as-is.
    """
    dataDict={}
    with h5py.File(path,'r') as f:
        fkeys = f.keys()
        for key in fkeys:
            if "_serial" in key or "_shapes" in key:
                continue
            if "%s_serial_dataBias"%key in fkeys:
                serialData  = np.array(f[key])
                dataBias    = np.array(f['%s_serial_dataBias'%key])
                serialShape = np.array(f['%s_serial_shape'%key])
                shapeBias   = np.array(f['%s_serial_shapeBias'%key])
                dataDict[key] = serialized2array(serialData, dataBias, serialShape, shapeBias)
            elif "%s_shapes"%key in fkeys:
                dataDict[key] = padded2array(f[key], f["%s_shapes"%key])
            else:
                dataDict[key] = np.array(f[key])
    return dataDict
def writeh5(path, dataDict, mode='w', compactForm='serial', quiet=False):
    """Write a dict of numpy arrays to an HDF5 file.

    Parameters:
        path (str)        -- output file
        dataDict (dict)   -- name -> numpy array
        mode (str)        -- h5py file mode
        compactForm (str) -- storage of object arrays (arrays of arrays):
                             'serial' (see serializeArray) or 'padding' (see padArray)
        quiet (bool)      -- suppress the confirmation message

    String arrays are stored UTF-8 encoded. An existing dataset with the same
    name, and any companion dataset left from an earlier write, is replaced.

    Returns:
        dataDict, unchanged

    Raises:
        ValueError: if an object array is given with an unknown compactForm.
    """
    aux_suffixes = ('_shapes', '_serial_dataBias', '_serial_shape', '_serial_shapeBias')
    with h5py.File(path, mode) as f:
        for key in dataDict.keys():
            # Overwrite: drop the old dataset and stale companions, then write.
            for name in [key] + ['%s%s'%(key, suffix) for suffix in aux_suffixes]:
                if name in f:
                    del f[name]
            value = dataDict[key]
            if value.dtype == np.dtype('O'):
                if compactForm=='padding':
                    padded, shapes = padArray(value)
                    f[key] = padded
                    f['%s_shapes'%key] = shapes
                elif compactForm=='serial':
                    serialData, dataBias, serialShape, shapeBias = serializeArray(value)
                    f[key] = serialData
                    f['%s_serial_dataBias'%key] = dataBias
                    f['%s_serial_shape'%key]    = serialShape
                    f['%s_serial_shapeBias'%key]= shapeBias
                else:
                    raise ValueError('Unknown compactForm: %s'%compactForm)
            elif value.dtype.type is np.str_:
                f[key] = np.char.encode( value, 'UTF-8' )
            else:
                f[key] = value
    if quiet==False:
        print(path, 'is successfully written.')
    return dataDict

def readply(path, scaleFactor=1/256.):
    """Read vertices and faces from a PLY file.

    Parameters:
        path (str)          -- PLY file to read
        scaleFactor (float) -- multiplier applied to the vertex coordinates

    Returns:
        (vert, face, valid): scaled (N x 3) vertices, integer face indices and
        True. If reading fails, zero placeholders of shape (10, 3) and False.
    """
    from plyfile import PlyData

    try:
        with open(path, 'rb') as f:
            plydata = PlyData.read(f)
        vertex = plydata['vertex']
        vert = np.array([vertex['x'], vertex['y'], vertex['z']]).T
        face = np.array(list(plydata['face']['vertex_index'])).astype(int)
    except Exception as e:
        # Best effort: report the failure and flag the entry invalid. Unlike a
        # bare except, KeyboardInterrupt / SystemExit still propagate.
        print('read error', path, e)
        return np.zeros((10,3)), np.zeros((10,3)).astype(int), False
    return vert*scaleFactor, face, True
def plys2h5(plyDir, h5path, indices=None, scaleFactor=1/256.):
    """Convert the PLY meshes of a directory into one HDF5 file.

    Parameters:
        plyDir (str)        -- directory with the PLY files (ordered by listdir)
        h5path (str)        -- output HDF5 file
        indices (array)     -- positions in the listing to convert; None means all
        scaleFactor (float) -- forwarded to readply

    The file gets the datasets 'vert', 'face' (serialized form) and 'index'.
    The positions of files that failed to load are saved to 'inv_ind.txt' in
    the current working directory.

    Returns:
        numpy array of positions (within the selection) that failed to load
    """
    plynames, plypaths = listdir(plyDir, return_path=True)
    plypaths = np.array(plypaths)
    if indices is None:
        indices = np.arange(len(plynames))
    indices = np.asarray(indices)
    print('Total shapes: %d, selected shapes: %d'%(len(plynames), indices.shape[0]))
    args = [(plypath,) for plypath in plypaths[indices]]
    # Bind scaleFactor so the workers actually apply the requested scale.
    func = lambda path: readply(path, scaleFactor=scaleFactor)
    verts, faces, valids = parallelMap(func, args, zippedOut=False)
    # valids is a plain list; convert it before the element-wise negation.
    inv_ind = np.where(~np.array(valids, dtype=bool))[0]
    print('inv_ind', inv_ind)
    np.savetxt('inv_ind.txt', inv_ind)

    def as_object_array(items):
        # One array per slot: meshes differ in size and NumPy >= 1.24 refuses to
        # build ragged arrays implicitly.
        out = np.empty(len(items), dtype=object)
        for i, item in enumerate(items):
            out[i] = item
        return out

    writeh5(h5path, dataDict={'vert':as_object_array(verts), 'face':as_object_array(faces), 'index':np.array(indices)}, compactForm='serial')
    return inv_ind

# geometry
def signed_distance(queries, vert, face): # remove NAN's
    """Signed distance of *queries* to a mesh via igl.signed_distance.

    Returns:
        (S, I, C) as returned by igl, with NaNs in S replaced by np.nan_to_num.
    """
    # igl is not imported at module level in this file.
    import igl
    S, I, C = igl.signed_distance(queries, vert,face)
    return np.nan_to_num(S), I, C
def shape2sdf(shapePath, shapeInd, gridDim=256, disturb=False):
    """Sample the signed distance field of one stored shape on a unit-cube grid.

    Loads entry *shapeInd* of the 'vert' and 'face' datasets in *shapePath* and
    evaluates it exactly like mesh2sdf.

    Returns:
        (gridDim**3 x 4) array of [x, y, z, sdf] rows
    """
    vert = H5Var(shapePath, 'vert')[shapeInd]
    face = H5Var(shapePath, 'face')[shapeInd]
    return mesh2sdf(vert, face, gridDim=gridDim, disturb=disturb)
def mesh2sdf(vert, face, gridDim=64, disturb=False):
    """Sample the signed distance field of a mesh on a gridDim^3 grid over [0, 1]^3.

    With disturb=True each sample is shifted by a random offset in
    [0, 1/gridDim) per coordinate.

    Returns:
        (gridDim**3 x 4) array of [x, y, z, sdf] rows
    """
    ticks = np.linspace(0,1,gridDim)
    lattice = np.stack(np.meshgrid(ticks,ticks,ticks,sparse=False), axis=-1)
    all_samples = lattice.reshape(-1,3)
    if disturb==True:
        all_samples += np.random.rand(all_samples.shape[0],3)/gridDim
    S, _, _ = signed_distance(all_samples, vert, face)
    return np.concatenate([all_samples,S[:,None]], axis=-1)
def pc2sdf():
    """Placeholder for a point-cloud to SDF conversion; not implemented (does nothing)."""
    #TODO
    pass
# grid sampling (inefficient)
def shapes2sdfs(shapePath, sdfPath, indices=np.arange(10), gridDim=256, disturb=False):
    """Compute grid-sampled SDFs for several stored shapes in parallel.

    An existing *sdfPath* is removed first (it must end with '.h5'). Each index
    is processed with shape2sdf through parallelMap.
    NOTE(review): the batch callback that appends to sdfPath is passed to
    parallelMap, which currently does not apply batchFunc, so nothing is
    written to sdfPath here.

    Returns:
        numpy array with one SDF sample table per requested index

    Raises:
        ValueError: if sdfPath exists and does not end with '.h5'.
    """
    if os.path.exists(sdfPath):
        if sdfPath[-3:] != '.h5':
            raise ValueError('sdfPath must ended with .h5')
        os.remove(sdfPath)
    sdf_of = lambda index:shape2sdf(shapePath, index,gridDim=gridDim, disturb=disturb)
    append_batch = lambda batchOut: [H5Var(sdfPath, 'SDF').append(np.array(batchOut))]
    mapped = parallelMap(sdf_of, [indices], zippedIn=False, batchFunc=append_batch)
    ret = np.array(mapped)[0]
    print(ret.shape)
    return ret

class Obj():
    """Plain attribute container: dict entries and keyword arguments become attributes."""
    def __init__(self, dataDict, **kwargs):
        for source in (dataDict, kwargs):
            self.update(source)
    def update(self, dataDict):
        """Set one attribute per key of *dataDict*, overwriting existing ones."""
        vars(self).update(dataDict)