import os
import random
import importlib
from scipy.sparse import csc_matrix
import scipy.sparse as sparse
import scipy
import numpy as np
import open3d as o3d
from functools import wraps
import time
from scipy.optimize import linear_sum_assignment
import torch
from torch.autograd import Variable

# Small positive constant. NOTE(review): the name suggests it is meant to be
# added to denominators to avoid division by zero; no use is visible in this
# part of the file — confirm against callers.
DIVISION_EPS = 1e-10


def parameter_count(model):
    """Print the total element count of ``model.parameters()`` in millions.

    Args:
        model: Any object exposing ``parameters()`` that yields items with a
            ``numel()`` method.
    """
    total_elements = 0
    for tensor in model.parameters():
        total_elements += tensor.numel()

    # Report in units of 1e6 elements.
    millions = total_elements / 1e6
    print('parameters number:', millions, ' M')


def cuda_time():
    """Return the current wall-clock time after pending CUDA work finishes.

    When CUDA is available, ``torch.cuda.synchronize()`` is called first so
    the timestamp is taken after queued GPU work has completed. On machines
    without CUDA the synchronisation is skipped instead of raising, so the
    helper is usable in CPU-only runs.

    Returns:
        float: The value of ``time.time()``.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()


def timing(func):
    """Decorator that prints how long each call to ``func`` took.

    The message has the form ``<module>.<name> : <seconds>`` and is printed
    only when the call returns normally; the wrapped function's return value
    is passed through unchanged.
    """

    @wraps(func)
    def timed_call(*args, **kwargs):
        began = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - began
        print('{}.{} : {}'.format(func.__module__, func.__name__, elapsed))
        return result

    return timed_call


def v(var, cuda=True, volatile=False):
    """Wrap a tensor or NumPy array in an autograd ``Variable``.

    Args:
        var: A ``torch.Tensor`` (cast with ``.float()`` before wrapping) or an
            ``np.ndarray`` (converted with ``torch.from_numpy``; its dtype is
            kept as-is, no cast is applied). Subclasses of either are
            accepted.
        cuda: When true, the result is moved with ``.cuda()``.
        volatile: Forwarded unchanged to the ``Variable`` constructor.
            NOTE(review): recent PyTorch releases are documented to ignore
            this flag in favour of ``torch.no_grad()`` — confirm for the
            version in use.

    Returns:
        The wrapped (and optionally GPU-resident) variable.

    Raises:
        TypeError: If ``var`` is neither a tensor nor a NumPy array.
    """
    # isinstance (rather than exact type equality) also admits tensor/ndarray
    # subclasses, which previously fell through both branches.
    if isinstance(var, torch.Tensor):
        res = Variable(var.float(), volatile=volatile)
    elif isinstance(var, np.ndarray):
        res = Variable(torch.from_numpy(var), volatile=volatile)
    else:
        # Fail with a clear message instead of an UnboundLocalError on `res`.
        raise TypeError(
            'v() expects a torch.Tensor or np.ndarray, got {}'.format(
                type(var).__name__))
    if cuda:
        res = res.cuda()
    return res


def npy(var):
    """Return the contents of ``var`` as a NumPy array.

    Takes ``var.data``, moves it to the CPU and converts it with ``.numpy()``.
    """
    host_tensor = var.data.cpu()
    return host_tensor.numpy()


def get_model_module(model_version):
    """Import and return the module named by ``model_version``.

    Import caches are invalidated first so that modules created after the
    interpreter started can still be found.
    """
    importlib.invalidate_caches()
    module = importlib.import_module(model_version)
    return module


def write_ply(fn, point, normal=None, color=None):
    """Write a point cloud to ``fn`` through Open3D.

    Args:
        fn: Output path handed to ``o3d.io.write_point_cloud``.
        point: Point coordinates, converted with ``Vector3dVector``.
        normal: Optional per-point normals, stored when not ``None``.
        color: Optional per-point colors, stored when not ``None``.
    """
    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(point)

    # Optional attributes, assigned in the same order as before: colors, then normals.
    for attribute, values in (("colors", color), ("normals", normal)):
        if values is not None:
            setattr(cloud, attribute, o3d.utility.Vector3dVector(values))

    o3d.io.write_point_cloud(fn, cloud)


def write_xyz_files(output_path, point, normal=None):
    """Write points (and optionally normals) to a whitespace-separated text file.

    Each output line is ``x y z`` or, when ``normal`` is given,
    ``x y z nx ny nz``, with every value formatted using ``%f``.

    Args:
        output_path: Path of the file to create or overwrite.
        point: Sequence of rows (NumPy array or nested sequence); the first
            three entries of each row are written.
        normal: Optional sequence of rows indexed in step with ``point``; the
            first three entries of each row are written.

    Raises:
        IndexError: If ``normal`` has fewer rows than ``point`` or a row has
            fewer than three entries.
    """
    # The context manager guarantees the handle is closed even if formatting
    # or writing a row raises part-way through.
    with open(output_path, "w") as fout:
        if normal is not None:
            fout.writelines(
                "%f %f %f %f %f %f\n" % (
                    p[0], p[1], p[2],
                    normal[i][0], normal[i][1], normal[i][2])
                for i, p in enumerate(point))
        else:
            fout.writelines(
                "%f %f %f\n" % (p[0], p[1], p[2]) for p in point)


def read_xyz_files(filename, normal=True):
    """Read a whitespace-separated point file into NumPy arrays.

    Every non-blank line is split on whitespace and parsed as floats. The
    first three values are taken as the position and, when ``normal`` is
    true, values four to six as the normal. Blank lines are skipped so they
    cannot introduce empty rows into the result.

    Args:
        filename: Path of the text file to read.
        normal: When true, also collect and return the normals.

    Returns:
        ``(positions, normals)`` as two arrays when ``normal`` is true,
        otherwise only the positions array.

    Raises:
        ValueError: If a token cannot be parsed as a float.
    """
    pc_pos = []
    pc_norm = []
    with open(filename, 'r') as f:
        # Stream the file line by line instead of materialising all lines.
        for raw_line in f:
            values = [float(token) for token in raw_line.split()]
            if not values:
                # Blank or whitespace-only line: nothing to record.
                continue
            pc_pos.append(values[:3])
            if normal:
                pc_norm.append(values[3:6])

    pc_pos = np.array(pc_pos)

    if normal:
        return pc_pos, np.array(pc_norm)

    return pc_pos
