import os
import igl
import scipy as sp

# Parent of the directory containing this file (presumably the project root); the
# data/, constraint/ and datasets/ paths used below are all built relative to it.
abs_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

import plotly.offline as offline
import numpy as np
import torch
import plotly.graph_objs as go
import trimesh



def plot_point_cloud(obj, mesh, samples):
    """Return a Plotly figure of ``samples`` scattered over a translucent ``mesh``.

    Args:
        obj: object name. ``"bunny"`` gets its own camera position; every other
            name uses a shared default view.
        mesh: object exposing ``vertices`` (V, 3) and triangular ``faces`` (F, 3)
            arrays, e.g. a ``trimesh.Trimesh``.
        samples: (N, 3) points, as a NumPy array or a CPU torch tensor (callers in
            this file pass both); normalized with ``np.asarray`` before plotting.

    Returns:
        A 1400x1400 ``go.Figure`` titled ``"<N> scatters"`` whose three axes are
        fixed to the range [-1.05, 1.05].
    """
    verts = np.asarray(mesh.vertices)
    I, J, K = np.asarray(mesh.faces).T
    points = np.asarray(samples)

    trace = [
        go.Mesh3d(x=verts[:, 0], y=verts[:, 1], z=verts[:, 2],
                  i=I, j=J, k=K, alphahull=5, opacity=0.4, color='cyan'),
        go.Scatter3d(x=points[:, 0], y=points[:, 1], z=points[:, 2],
                     mode='markers', marker=dict(size=3)),
    ]
    fig = go.Figure(data=trace)

    # Only the camera eye depends on the object; axes, aspect ratio, up vector
    # and look-at center are shared by every view.
    eye = dict(x=-0.5, y=0, z=-2) if obj == "bunny" else dict(x=-1, y=1, z=1)
    axis_range = dict(range=(-1.05, 1.05), autorange=False)
    scene_dict = dict(
        xaxis=dict(axis_range),
        yaxis=dict(axis_range),
        zaxis=dict(axis_range),
        aspectratio=dict(x=1, y=1, z=1),
        camera=dict(eye=eye, up=dict(x=0, y=1, z=0), center=dict(x=0, y=0, z=0)),
    )

    fig.update_layout(title=f'{points.shape[0]} scatters',
                      scene=scene_dict, width=1400, height=1400, showlegend=False)

    return fig


@torch.no_grad()
def refine_dataset_SDF(sdf, samples0, tol=1e-5, max_iter_n=1000, step_size=1e-1, keep_quiet=False):
    """Iteratively move points onto the zero level set of ``sdf``.

    Every iteration evaluates ``sdf`` on the points whose ``|sdf|`` is still
    ``>= tol`` and moves only those points by
    ``x <- x - step_size * sdf(x) * grad sdf(x)``; converged points are frozen.

    Args:
        sdf: callable mapping an (M, D) tensor to an (M, 1) tensor of signed
            distances. Assumed to evaluate each point independently, since the
            per-point gradient is taken through ``sdf(x).sum()`` — TODO confirm
            for the models used.
        samples0: (N, D) points. A torch tensor is cloned (not modified) and
            refined on its own device; anything else goes through
            ``torch.tensor(..., dtype=torch.float32)`` and is refined on CPU.
        tol: absolute ``|sdf|`` below which a point counts as converged.
        max_iter_n: maximum number of refinement iterations.
        step_size: scale applied to every gradient step.
        keep_quiet: suppress the progress and summary messages. The warning
            about an unmet tolerance is always printed.

    Returns:
        The refined (N, D) tensor, detached and on CPU.
    """
    @torch.enable_grad()
    def sdf_grad(points):
        # The enclosing function runs under no_grad, so autograd is re-enabled here.
        # No higher-order graph is needed: the gradient is used as a plain step direction.
        points.requires_grad_(True)
        (grad,) = torch.autograd.grad(outputs=sdf(points).sum(), inputs=points)
        return grad.detach()

    if isinstance(samples0, torch.Tensor):
        samples = samples0.clone()
        device = samples.device
    else:
        samples = torch.tensor(samples0, dtype=torch.float32)
        device = torch.device("cpu")

    if samples.shape[0] == 0:
        # Nothing to refine; also avoids torch.max on an empty tensor below.
        return samples.detach().cpu()

    # Indices (into `samples`) of the points that have not converged yet.
    active_idx = torch.arange(samples.shape[0], dtype=torch.int64, device=device)

    iter_n = 0
    while iter_n < max_iter_n:
        xi_vals = sdf(samples[active_idx, :])
        error = torch.abs(xi_vals).squeeze(dim=1)
        bad_idx = error >= tol
        n_bad = int(bad_idx.sum())
        if n_bad == 0:
            break
        if iter_n % 50 == 0 and not keep_quiet:
            print(f'iter {iter_n}: max_err={torch.max(error):.3e}, {n_bad} bad states, tol={tol:.3e}')
        # Shrink the active set, then step the remaining points along the SDF gradient.
        active_idx = active_idx[bad_idx]
        samples[active_idx, :] = samples[active_idx, :] - xi_vals[bad_idx, :] * sdf_grad(samples[active_idx, :]) * step_size
        iter_n += 1

    max_error = torch.max(torch.abs(sdf(samples)))

    if not keep_quiet:
        print(f'Total steps={iter_n}, final error: {max_error: .3e}.')
    if max_error > tol * 1.1:
        print(f'Warning: tolerance ({tol: .3e}) not reached!')

    return samples.detach().cpu()


def mesh_eigen(obj, v, f, idx, upsample, nsamples):
    """Sample points on a mesh with density given by a clamped Laplacian eigenfunction.

    The mesh is optionally subdivided, the generalized eigenproblem
    ``L x = lambda M x`` (negated cotangent Laplacian ``L``, Voronoi mass matrix
    ``M``) is solved with ``scipy.sparse.linalg.eigsh``, the constant mode is
    dropped, and eigenfunction ``idx`` (its positive part only) weights the faces
    from which ``nsamples`` points are drawn with uniform barycentric coordinates.

    Args:
        obj: mesh name; together with ``idx`` it selects the hard-coded sign flip.
        v: vertex positions (NumPy array), f: triangle vertex indices (NumPy array).
        idx: eigenfunction index, counted after removing the zero-eigenvalue mode.
        upsample: number of ``igl.upsample`` rounds applied first (0 = none).
        nsamples: number of points to draw.

    Returns:
        samples: NumPy array with ``nsamples`` rows, points on the (upsampled) mesh.
        vals: torch tensor of per-face sampling weights (face-averaged positive
            part of the eigenfunction times face area); unnormalized.
    """
    np.random.seed(777)
    # The sampling below uses torch's RNG (multinomial + uniform barycentric
    # coordinates), so seeding NumPy alone did not make the samples reproducible.
    torch.manual_seed(777)
    numeigs = idx + 20

    if upsample > 0:
        v_np, f_np = igl.upsample(v, f, upsample)
    else:
        v_np, f_np = v, f
    v, f = torch.tensor(v_np), torch.tensor(f_np)
    print("#vertices: ", v.shape[0], "#faces: ", f.shape[0])

    # preprocess_eigenfunctions
    M = igl.massmatrix(v_np, f_np, igl.MASSMATRIX_TYPE_VORONOI)
    L = -igl.cotmatrix(v_np, f_np)

    eigvals, eigfns = sp.sparse.linalg.eigsh(
        L, numeigs + 1, M, sigma=0, which="LM", maxiter=100000
    )
    # eigsh does not document the order of its results; sort ascending so that
    # column 0 really is the zero mode removed below. A stable sort leaves
    # already-ordered output untouched.
    order = np.argsort(eigvals, kind="stable")
    eigvals, eigfns = eigvals[order], eigfns[:, order]

    # NOTE(review): eigenvectors are only defined up to sign. This hand-picked
    # list presumably pins the orientation observed for these specific runs and is
    # fragile; TODO replace with a deterministic sign convention.
    if (obj, idx) in [("bunny", 49), ("bunny", 99), ("spot", 99)]:
        print("times -1")
        eigfns = eigfns * -1.0

    # Remove the zero eigenvalue (constant mode).
    eigvals = eigvals[..., 1:]
    eigfns = eigfns[..., 1:]
    print(f"largest eigval: {eigvals.max().item()}, smallest eigval: {eigvals.min().item()}")

    # Per-face weight: mean of the clamped per-vertex values over the face, times face area.
    vertex_vals = torch.tensor(eigfns[:, idx]).clamp(min=0.0)
    vals = vertex_vals[f].mean(dim=1)
    vals = vals * torch.tensor(igl.doublearea(v_np, f_np)).reshape(-1) / 2

    f_idx = torch.multinomial(vals, nsamples, replacement=True)
    barycoords = sample_simplex_uniform(2, (nsamples,))
    samples = torch.sum(v[f[f_idx]] * barycoords[..., None], axis=1)
    samples = samples.cpu().detach().numpy()
    return samples, vals


def sample_simplex_uniform(K, shape=(), dtype=torch.float32, device="cpu"):
    """Draw points uniformly from the K-simplex (K + 1 non-negative weights summing to 1).

    Spacings construction: sort K uniforms from [0, 1), pad them with 0 and 1,
    and return the gaps between consecutive values.

    Args:
        K: simplex dimension; every sample has K + 1 coordinates.
        shape: batch shape of the output; must be a tuple.
        dtype: dtype of the generated tensors.
        device: device of the generated tensors.

    Returns:
        Tensor of shape ``shape + (K + 1,)``.
    """
    cuts, _ = torch.rand(shape + (K,), dtype=dtype, device=device).sort(dim=-1)
    edge_shape = (*shape, 1)
    knots = torch.cat(
        (
            torch.zeros(edge_shape, dtype=dtype, device=device),
            cuts,
            torch.ones(edge_shape, dtype=dtype, device=device),
        ),
        dim=-1,
    )
    return torch.diff(knots, dim=-1)


def create_eigfn(obj):
    """Generate eigenfunction-weighted point-cloud datasets and preview figures for ``obj``.

    Reads ``{abs_path}/data/{obj}/{obj}_simp.obj`` and writes into that directory:
    ``{obj}_mesh_simple.ply``, ``{obj}_mesh_complex.ply`` (3 rounds of
    ``igl.upsample``), ``{obj}_whole.npy`` (float32 vertices of the upsampled
    mesh), and for each eigenfunction index in (49, 99) the samples
    ``{obj}_eigfn{idx:03d}.npy`` plus per-face weights ``..._color.npy``.
    HTML/PNG previews are written to ``{abs_path}/datasets/figs/mesh``.
    """
    np.random.seed(777)

    data_dir = f"{abs_path}/data/{obj}"
    fig_dir = f"{abs_path}/datasets/figs/mesh"
    # Create the figure directory up front so writing the previews cannot fail
    # just because it is missing.
    os.makedirs(fig_dir, exist_ok=True)

    v, f = igl.read_triangle_mesh(f"{data_dir}/{obj}_simp.obj")

    if obj == "bunny":
        # NOTE(review): presumably rescales the raw bunny to fit the [-1.05, 1.05]
        # plotting box used by plot_point_cloud — confirm.
        v = v / 250.

    mesh_simple = trimesh.Trimesh(vertices=v, faces=f)
    mesh_simple.export(f"{data_dir}/{obj}_mesh_simple.ply", 'ply')

    v_np3, f_np3 = igl.upsample(mesh_simple.vertices, mesh_simple.faces, 3)
    mesh_complex = trimesh.Trimesh(vertices=v_np3, faces=f_np3)
    mesh_complex.export(f"{data_dir}/{obj}_mesh_complex.ply", 'ply')
    np.save(f"{data_dir}/{obj}_whole.npy", np.array(mesh_complex.vertices, dtype=np.float32))

    for idx in [49, 99]:
        # NOTE(review): mesh_eigen upsamples the raw (v, f), while mesh_complex is
        # built from mesh_simple, which trimesh may have processed; the two meshes
        # are not guaranteed to be vertex-identical.
        samples, vals = mesh_eigen(obj, v, f, idx=idx, upsample=3, nsamples=500000)
        dataset_name = f"{obj}_eigfn{idx:03d}"
        np.save(f"{data_dir}/{dataset_name}.npy", samples)
        np.save(f"{data_dir}/{dataset_name}_color.npy", vals.numpy())

        fig = plot_point_cloud(obj, mesh_complex, samples)
        offline.plot(fig, filename=f'{fig_dir}/{dataset_name}.html', auto_open=False)
        fig.write_image(f"{fig_dir}/{dataset_name}.png")


@torch.no_grad()
def refine_mesh_data(obj):
    """Project the saved eigenfunction samples of ``obj`` onto a learned SDF's zero level set.

    For ``{obj}_eigfn049`` and ``{obj}_eigfn099`` in ``{abs_path}/data/{obj}``:
    takes the first 50000 points of a random permutation, refines them with
    ``refine_dataset_SDF``, saves them as ``..._refined.npy``, writes before/after
    HTML+PNG previews to ``{abs_path}/datasets/figs/mesh`` and prints SDF and
    displacement statistics. The SDF model is loaded from
    ``{abs_path}/constraint/model/{obj}_whole_sdf.pt``.
    """
    import inspect

    device = torch.device('cpu')
    data_dir = f"{abs_path}/data/{obj}"
    fig_dir = f"{abs_path}/datasets/figs/mesh"
    os.makedirs(fig_dir, exist_ok=True)

    mesh = trimesh.load(f"{data_dir}/{obj}_mesh_simple.ply")

    # The checkpoint holds a whole pickled model (it is called directly below).
    # torch >= 2.6 refuses to load such checkpoints under its default
    # weights_only=True, so opt out where the argument exists. This unpickles
    # arbitrary objects: only load checkpoints from trusted sources.
    load_kwargs = {"map_location": 'cpu'}
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    sdf = torch.load(f"{abs_path}/constraint/model/{obj}_whole_sdf.pt", **load_kwargs).to(device)

    data_name_list = [f"{obj}_eigfn049", f"{obj}_eigfn099"]
    for data_name in data_name_list:
        # NOTE(review): NumPy's global RNG is not seeded here, so the 50000-point
        # subset changes between runs unless the caller seeds it.
        samples = np.random.permutation(np.load(f"{data_dir}/{data_name}.npy"))[:50000]

        print(f"----Refine data set {data_name}----")

        refined_samples = refine_dataset_SDF(sdf, samples)
        np.save(f"{data_dir}/{data_name}_refined.npy", refined_samples.numpy())

        for tag, points in (("refined", refined_samples), ("before", samples)):
            fig = plot_point_cloud(obj, mesh, points)
            offline.plot(fig, filename=f'{fig_dir}/{data_name}_{tag}.html', auto_open=False)
            fig.write_image(f"{fig_dir}/{data_name}_{tag}.png")

        # NOTE(review): "Before" is the SDF evaluated on the simplified mesh's
        # vertices, not on the unrefined samples — confirm which baseline is intended.
        sdf_before = sdf(torch.tensor(mesh.vertices).float())
        sdf_after = sdf(refined_samples)

        # Compute the displacement statistics in torch explicitly (promoting to the
        # samples' dtype) instead of mixing torch tensors with NumPy arrays.
        samples_t = torch.as_tensor(samples)
        displacement = refined_samples.to(samples_t.dtype) - samples_t
        refined_error = displacement.norm(dim=1)
        refined_rel_error = refined_error / (samples_t.abs().max(dim=1).values + 1.0e-3)

        print(f"\nBefore: mean {sdf_before.mean():.6f}, max {sdf_before.abs().max():.6f}")
        print(f"After: mean {sdf_after.mean():.6f}, max {sdf_after.abs().max():.6f}")
        print(f"Relative error of states: max {refined_rel_error.max():.4f}")
        print(f"Refine error: max {refined_error.max():.6f}, mean {refined_error.mean():.6f}\n")



