from multiprocessing import Pool, TimeoutError
from pathlib import Path
from typing import Callable, Tuple

import numpy as np
import pyvista as pv

from .config import Config
from .model import Mesh


def make_data_generator(cfg: Config, stop: Callable[[], bool] = lambda: False):
    """Yield training examples produced by a process pool, one round ahead.

    Each round submits ``cfg.n_workers * cfg.n_workers`` calls to
    ``load_from_file`` (``cfg.n_workers`` fresh files, each submitted
    ``cfg.n_workers`` times). It then yields the examples of the previous round
    while the new round is computed in the background.

    Args:
        cfg: configuration; reads ``n_workers``, ``batch_size``,
            ``data.training_data``, ``data.random_augmentation`` and
            ``data.source_multiplier``.
        stop: polled before every round; the generator finishes once it
            returns True.

    Yields:
        Examples as returned by ``load_from_file``. Failed loads, which come
        back as ``None``, are skipped.
    """
    import signal as sg

    batch = []
    files = file_generator(cfg.data.training_data)
    # Workers are forked while SIGINT is ignored, so they inherit SIG_IGN and
    # Ctrl-C is handled only by the parent. The previous handler is restored
    # in `finally` so a failing Pool() cannot leave SIGINT ignored for good.
    handler = sg.signal(sg.SIGINT, sg.SIG_IGN)
    try:
        pool = Pool(cfg.n_workers)
    finally:
        sg.signal(sg.SIGINT, handler)
    with pool:
        try:
            while not stop():
                tasks = [
                    (
                        next(files),
                        cfg.batch_size,
                        cfg.data.random_augmentation,
                        cfg.data.source_multiplier,
                    )
                    for _ in range(cfg.n_workers)
                ]
                # NOTE(review): the task list is repeated n_workers times, so
                # each file is loaded n_workers times per round (with
                # independent random augmentation when enabled). Kept as-is;
                # confirm this is intended.
                next_batch = pool.starmap_async(load_from_file, tasks * cfg.n_workers)
                # gen_example_from_mesh returns None when an example fails;
                # don't hand those to the consumer.
                yield from (example for example in batch if example is not None)
                # Poll with a timeout instead of one blocking get(); presumably
                # so the parent stays responsive to KeyboardInterrupt.
                while True:
                    try:
                        batch = next_batch.get(20)
                        break
                    except TimeoutError:
                        pass
        except KeyboardInterrupt:
            pool.terminate()


def file_generator(path: str, root: str = ".."):
    """Yield paths matching a glob pattern forever, reshuffled every epoch.

    Args:
        path: glob pattern, interpreted relative to ``root``.
        root: directory prefix joined to ``path`` as ``f"{root}/{path}"``.
            Defaults to ``".."`` (the parent of the current working
            directory), which was previously hard-coded.

    Yields:
        Matching file paths. Every epoch visits each file once, in an order
        drawn from NumPy's global RNG.

    Raises:
        FileNotFoundError: if the pattern matches no files.
    """
    from glob import glob

    pattern = f"{root}/{path}"
    # glob order depends on the filesystem; sorting makes the shuffled
    # sequence reproducible under np.random.seed.
    files = sorted(glob(pattern))
    if not files:
        # Explicit check rather than `assert`: under `python -O` the assert is
        # stripped and the loop below would spin forever without yielding.
        raise FileNotFoundError(f"No files provided! (pattern: {pattern})")
    while True:
        np.random.shuffle(files)
        yield from files


def source(x):
    """Evaluate the PDE source term ``-2 * exp(-||2 x||^2)`` at each point.

    The sum runs over the last axis of ``x`` and that axis is kept with size 1,
    so an ``(n, d)`` input gives an ``(n, 1)`` result.
    """
    # Alternative analytic source kept for reference:
    # return -2 * np.cos(x[:, 0]) * np.sin(x[:, 1])
    squared_norm = np.square(2 * x).sum(-1, keepdims=True)
    return -2 * np.exp(-squared_norm)


def boundary(x, n):
    """Return the boundary-condition target: zero at every boundary point.

    Args:
        x: boundary points; only ``len(x)`` is used.
        n: boundary normals; accepted for interface symmetry but unused.

    Returns:
        A float array of zeros with shape ``(len(x), 1)``.
    """
    # Alternative analytic boundary values kept for reference:
    # y = 2 * np.cos(x[:, 0]) * np.sin(x[:, 1])
    # return y.reshape((-1, 1))
    point_count = len(x)
    return np.zeros((point_count, 1))


def load_from_file(
    fname: str, n_points: int, random_augmentation: bool, mult: float = 1.0
):
    """Read a mesh file and build one training example from it.

    Args:
        fname: path of a mesh file readable by ``pv.read``.
        n_points: number of interior and of boundary sample points.
        random_augmentation: whether to randomly transform the mesh first.
        mult: scale factor for the PDE source target. It is forwarded to
            ``gen_example_from_mesh``; previously it was silently dropped, so
            ``cfg.data.source_multiplier`` had no effect.

    Returns:
        Whatever ``gen_example_from_mesh`` returns (``None`` on failure).
    """
    mesh = pv.read(fname)
    return gen_example_from_mesh(mesh, n_points, random_augmentation, mult=mult)


def gen_example_from_mesh(
    mesh: pv.PolyData, n_points: int, random_augmentation: bool, mult: float = 1.0
):
    """Build one ``(inputs, targets)`` training example from a surface mesh.

    The mesh is optionally randomly transformed first. ``n_points`` interior
    points and ``n_points`` boundary points (with face normals) are sampled.
    The mesh is then simplified to at most ~1024 cells and turned into edge
    features.

    Returns:
        ``(inputs, targets)`` where ``inputs["x_int"] == (x_int, features)``,
        ``inputs["x_bc"] == (x_bc, normals, features)``, and
        ``targets["target"]`` holds ``"pde"`` (``mult * source(x_int)``) and
        ``"bc"`` (``boundary(x_bc, normals)``). If anything raises, the error
        and its traceback are printed and ``None`` is returned.
    """
    try:
        if random_augmentation:
            mesh = random_transform(mesh)
        x_int = sample_interior(mesh, n_points)
        if random_augmentation:
            # Re-centre on a sampled interior point so the origin lies inside
            # the domain.
            shift = x_int[np.random.choice(n_points)]
            x_int = x_int - shift
            mesh = mesh.translate(-shift, inplace=False)
        x_bc, normals = sample_boundary(mesh, n_points)
        features = extract_features(simplify(mesh, target_cells=1024))

        inputs = {
            "x_int": (x_int, features),
            "x_bc": (x_bc, normals, features),
        }
        targets = {
            "target": {
                "pde": mult * source(x_int),
                "bc": boundary(x_bc, normals),
            }
        }
        return inputs, targets
    except Exception as ex:
        import traceback

        # Best effort: report and signal failure with None instead of raising.
        print(ex)
        print(traceback.format_exc())
        return None


def make_box(widths=(1, 1, 1), center=(0, 0, 0), segments=(10, 10, 10)) -> pv.PolyData:
    """Create a triangulated, cleaned box surface.

    Args:
        widths: edge length along x, y, z.
        center: where the box centre is placed.
        segments: number of grid subdivisions along each axis.
    """
    # A grid with `segments` cells per axis needs `segments + 1` points per axis.
    point_dims = tuple(n + 1 for n in segments)
    surface = pv.UniformGrid(dims=point_dims).extract_surface()

    # Move the box to the origin, stretch it to `widths`, then place it at `center`.
    surface.translate(-np.array(surface.center), inplace=True)
    per_axis_scale = tuple(w / n for w, n in zip(widths, segments))
    surface.scale(per_axis_scale, inplace=True)
    surface.translate(np.array(center), inplace=True)

    return surface.triangulate().clean()


def make_sphere(radius=0.5, center=(0, 0, 0), resolution=10) -> pv.PolyData:
    """Approximate a sphere by projecting a subdivided box onto it.

    Args:
        radius: sphere radius.
        center: sphere centre.
        resolution: box subdivisions per axis before projection.

    The box is built at the origin, its points are rescaled to distance
    ``radius`` from the origin, and only then is it moved to ``center``.
    Previously the box was placed at ``center`` before normalising about the
    world origin, so any non-zero ``center`` produced a distorted shape.
    """
    sphere = make_box(segments=[resolution] * 3)
    sphere.points *= radius / np.linalg.norm(sphere.points, axis=1, keepdims=True)
    sphere.translate(np.array(center), inplace=True)
    return sphere


def make_cylinder(radius=0.5, height=1) -> pv.PolyData:
    """Approximate a cylinder with a decimated parametric super-ellipsoid.

    ``radius`` is used for the x and y radii and ``height`` for the z radius.
    NOTE(review): since this is a radius, the z-extent is presumably
    ``2 * height``; confirm. A small ``n1`` (0.1) squares off the shape along
    z, which is what makes it cylinder-like.
    """
    ellipsoid = pv.ParametricSuperEllipsoid(
        xradius=radius,
        yradius=radius,
        zradius=height,
        n1=0.1,
    )
    # Target reduction 0.9: drop roughly 90% of the triangles.
    return ellipsoid.decimate_pro(0.9)


def extract_features(mesh: pv.PolyData) -> Mesh:
    """Compute per-edge features of a triangle mesh and wrap them in ``Mesh``.

    Each feature row is ``[edge midpoint, edge length, mean of the two vertex
    normals, mean of the two vertex curvatures]``. The neighbour table from
    ``extract_mesh_adjacency`` is passed through unchanged.
    """
    edges, neighbours = extract_mesh_adjacency(mesh)
    head, tail = edges[:, 0], edges[:, 1]

    points = mesh.points
    vertex_normals = mesh.point_normals
    vertex_curvature = mesh.curvature()

    midpoints = (points[head] + points[tail]) / 2
    lengths = np.linalg.norm(points[head] - points[tail], axis=-1, keepdims=True)
    edge_normals = (vertex_normals[head] + vertex_normals[tail]) / 2
    edge_curvature = (vertex_curvature[head] + vertex_curvature[tail]) / 2

    features = np.concatenate(
        (midpoints, lengths, edge_normals, edge_curvature[..., None]), axis=-1
    )
    return Mesh(features, neighbours)


def extract_mesh_adjacency(mesh: pv.PolyData) -> Tuple[np.ndarray, np.ndarray]:
    """Enumerate the undirected edges of a closed triangle mesh and their neighbours.

    Each triangle ``(a, b, c)`` contributes the directed half-edges ``(a, b)``,
    ``(b, c)`` and ``(c, a)``. Two half-edges that are reverses of each other
    form one edge. Edge ids follow the order of the half-edge belonging to the
    lower-indexed face.

    Returns:
        edges: ``(n_edges, 2)`` vertex ids, in the direction used by the
            lower-indexed face.
        neighbours: ``(n_edges, 4)`` ids of the other two edges of each of the
            two faces sharing the edge.

    Raises:
        AssertionError: if the mesh is not all-triangular or not manifold, or
            the half-edge pairing does not give every edge exactly 4 neighbours.

    Note: the pairing compares every half-edge with every other, so time and
    memory are O(n_faces**2).
    """
    assert mesh.is_all_triangles(), "mesh must be triangular"
    assert mesh.is_manifold, "mesh is not a valid manifold"

    tris = mesh.faces.reshape(-1, 4)[:, 1:]
    edges = np.stack((tris[:, :2], tris[:, 1:], tris[:, ::-2]), axis=1)
    # ^ shape: (n_faces, 3, 2); tris[:, ::-2] is (c, a), closing each loop
    is_edge_dup = (edges == edges[:, :, None, None, ::-1]).all(-1)
    # ^ shape: (n_faces, 3, n_faces, 3); [i, j, k, l] is True when half-edge
    #   (k, l) is the reverse of half-edge (i, j)
    # argwhere gives the same row-major (i, j, k, l) rows as masking a stacked
    # np.indices grid, without materialising that int grid (32x the mask size).
    dups_idx = np.argwhere(is_edge_dup)
    # ^ shape: (3*n_faces, 4)
    # choose edge with lower index
    is_lower = dups_idx[:, 0] < dups_idx[:, 2]
    uniq_idx = dups_idx[is_lower]
    # ^ shape: (n_edges, 4)
    edge_ids = -np.ones(edges.shape[:-1], dtype=int)
    edge_ids[uniq_idx[:, 0], uniq_idx[:, 1]] = np.arange(uniq_idx.shape[0])
    edge_ids[uniq_idx[:, 2], uniq_idx[:, 3]] = np.arange(uniq_idx.shape[0])
    assert not (edge_ids < 0).any()

    # The 3 edges of each of the two adjacent faces; the edge itself appears
    # once in each triple and is removed below, leaving 4 neighbours.
    neighbours = np.concatenate(
        (edge_ids[uniq_idx[:, 0]], edge_ids[uniq_idx[:, 2]]),
        axis=-1,
    )
    non_self = neighbours != np.arange(neighbours.shape[0]).reshape(-1, 1)
    assert (non_self.sum(-1) == 4).all()
    neighbours = neighbours[non_self].reshape(-1, 4)
    # TODO assert correct triangles in neighbours list
    return edges[uniq_idx[:, 0], uniq_idx[:, 1]], neighbours


def sample_interior(mesh: pv.PolyData, N=1024) -> np.ndarray:
    """Sample ``N`` points from the volume enclosed by a closed surface mesh.

    The bounding box is voxelized on a 31^3 grid, clipped to the mesh and
    split into tetrahedra. A tetrahedron is picked with probability
    proportional to its volume. A point is then drawn inside it using
    Dirichlet(1, 1, 1, 1) barycentric weights. Together this samples
    uniformly over the voxelized interior.

    Returns:
        ``(N, 3)`` array of points.

    Raises:
        AssertionError: if the mesh is not manifold.
        ValueError: if the voxelization yields no cells with a valid volume.
    """
    # from tetgen import TetGen
    # tets = TetGen(mesh)
    # tets.tetrahedralize(quality=False, nobisect=True)
    # voxels = tets.grid

    assert mesh.is_manifold

    grid = pv.create_grid(mesh, (31, 31, 31))
    grid_center = np.array(grid.center)
    voxels = (
        # Enlarge the grid slightly so it strictly covers the mesh bounds.
        # scale() acts about the origin, which also moves the grid centre to
        # 1.001 * grid_center; translate back so the enlargement is symmetric
        # even when the mesh is not centred at the origin.
        grid.scale(1.001, inplace=False)
        .translate(-0.001 * grid_center, inplace=False)
        .clip_surface(mesh)
        .triangulate()
    )
    volumes = np.abs(voxels.compute_cell_sizes()["Volume"])
    non_nan = ~np.isnan(volumes)
    volumes = volumes[non_nan]
    if volumes.size == 0:
        raise ValueError("voxelization of the mesh produced no cells")
    # assumes triangulate() left only tetrahedra (4 vertex ids per cell)
    cells = voxels.cells.reshape(-1, 5)[:, 1:]
    cells = cells[non_nan]
    idx = np.random.choice(cells.shape[0], N, p=volumes / np.sum(volumes))
    corners = voxels.points[cells[idx]]
    weights = np.random.dirichlet(np.ones(4), corners.shape[0])
    return (corners * weights[..., None]).sum(axis=-2)


def sample_boundary(mesh: pv.PolyData, N=1024) -> Tuple[np.ndarray, np.ndarray]:
    """Sample ``N`` points uniformly over the surface of a triangle mesh.

    A face is picked with probability proportional to its area. A point is
    then drawn inside it with Dirichlet(1, 1, 1) barycentric weights.

    Returns:
        ``(points, normals)``: the sampled points and, for each one, the
        normal of the face it was drawn from.
    """
    assert mesh.is_all_triangles()

    face_areas = mesh.compute_cell_sizes()["Area"]
    triangles = mesh.faces.reshape(-1, 4)[:, 1:]
    chosen = np.random.choice(
        triangles.shape[0], N, p=face_areas / np.sum(face_areas)
    )
    corners = mesh.points[triangles[chosen]]
    weights = np.random.dirichlet(np.ones(3), corners.shape[0])
    samples = np.sum(corners * weights[..., None], axis=-2)
    return samples, mesh.cell_normals[chosen]


def tesselate(mesh: pv.PolyData, resolution: int = 101) -> pv.UnstructuredGrid:
    """Fill a closed triangle surface with tetrahedra using TetGen.

    Args:
        mesh: closed, manifold, all-triangle surface. Duplicate points are
            merged first via ``clean(tolerance=0)``.
        resolution: unused; only the commented-out grid-based approach below
            referenced it. Kept for signature compatibility.

    Returns:
        Tetrahedral grid with every cell positively oriented (see
        ``ensure_possitive_orientation``). If the surface carries a
        ``"GroupIds"`` array with non-zero entries, the grid is additionally
        clipped against that sub-surface.
    """
    # tets = pv.create_grid(mesh, [resolution] * 3).clip_surface(mesh).triangulate()
    from tetgen import TetGen

    surface = mesh.clean(
        tolerance=0,
    )
    assert surface.is_manifold
    assert surface.is_all_triangles()

    generator = TetGen(surface)
    # TetGen switches; per the TetGen manual, Y keeps the input surface, q1.1
    # bounds element quality, a0.001 caps tet volume. Confirm against the
    # installed tetgen version.
    generator.tetrahedralize(switches="Yq1.1a0.001")
    volume_mesh = generator.grid

    if "GroupIds" in surface.array_names:
        inner_surface = surface.extract_cells(surface["GroupIds"] != 0)
        if inner_surface.n_points:
            # NOTE(review): confirm which side invert=True keeps relative to
            # the GroupIds != 0 sub-surface.
            volume_mesh = volume_mesh.clip_surface(
                inner_surface, invert=True, crinkle=True
            )
    return ensure_possitive_orientation(volume_mesh)


def ensure_possitive_orientation(tets: pv.UnstructuredGrid) -> pv.UnstructuredGrid:
    """Return a tetrahedral grid in which every tet has non-negative signed volume.

    Tets with negative signed volume are flipped by swapping their last two
    vertices. The result is a new ``UnstructuredGrid`` built from the
    corrected connectivity, the original cell types and the original points.
    Point and cell data arrays are not carried over. The input grid is not
    modified.

    Raises:
        AssertionError: if some cell is not a tetrahedron, or negative volumes
            remain after flipping.
    """
    # Copy so the in-place flip below never writes into the input grid's
    # connectivity (tets.cells may be a view of the VTK array).
    cells = tets.cells.reshape(-1, 5).copy()
    assert (cells[:, 0] == 4).all(), "not all cells are tetrahedral"

    def signed_volumes(points: np.ndarray) -> np.ndarray:
        # det([p1 - p0, p2 - p0, p3 - p0]) / 6, using the current `cells`
        corners = points[cells[:, 1:]]  # (n_cells, 4, 3)
        return np.linalg.det(corners[:, 1:] - corners[:, :1]) / 6

    bad_orientation = signed_volumes(tets.points) < 0
    # Swapping two vertices reverses a tetrahedron's orientation.
    cells[bad_orientation, 1:] = cells[bad_orientation, 1:][:, [0, 1, 3, 2]]
    fixed = pv.UnstructuredGrid(cells.reshape(-1), tets.celltypes, tets.points)

    assert not (signed_volumes(fixed.points) < 0).any(), "mesh has negative volumes"
    return fixed


def export_fem_mesh(mesh: pv.UnstructuredGrid, fname: Path):
    """Write a tetrahedral mesh as an ASCII Medit-style ``.mesh`` file.

    Sections written, in order: ``Vertices`` (every point of ``mesh``),
    ``Tetrahedra`` (every cell), ``Triangles`` (the boundary faces, mapped
    back to ``mesh`` point ids), then ``End``. Indices are 1-based and every
    entity gets reference label 0.
    """
    surface = mesh.extract_surface()
    assert surface.is_all_triangles()
    # Map surface-local vertex ids back to the ids of the volume mesh.
    boundary_tris = surface["vtkOriginalPointIds"][surface.faces.reshape(-1, 4)[:, 1:]]

    with fname.open("w") as fp:
        lines = [
            "MeshVersionFormatted 2\n",
            "\n",
            "Dimension 3\n",
            "\n",
            "Vertices\n",
            str(mesh.n_points),
            "\n",
        ]
        lines += [f"{p[0]} {p[1]} {p[2]} 0\n" for p in mesh.points]
        lines += ["\n", "Tetrahedra\n", str(mesh.n_cells), "\n"]
        tetra = mesh.cells.reshape(-1, 5)[:, 1:] + 1
        lines += [f"{c[0]} {c[1]} {c[2]} {c[3]} 0\n" for c in tetra]
        lines += ["\n", "Triangles\n", str(boundary_tris.shape[0]), "\n"]
        lines += [f"{t[0]} {t[1]} {t[2]} 0\n" for t in boundary_tris + 1]
        lines.append("\nEnd\n")
        fp.writelines(lines)


def simplify(mesh: pv.PolyData, target_cells: int) -> pv.PolyData:
    """Decimate a manifold mesh down to roughly ``target_cells`` cells.

    Returns the decimated mesh only when decimation was needed and its result
    is still manifold. Otherwise the input mesh is returned unchanged: either
    it is already small enough, or decimation broke manifoldness, in which
    case a warning is printed.

    Raises:
        AssertionError: if the input mesh is not manifold.
    """
    assert mesh.is_manifold, "Simplify non manifold"
    if mesh.n_cells <= target_cells:
        return mesh

    reduction = 1 - target_cells / mesh.n_cells
    decimated = mesh.decimate_pro(reduction, splitting=False)
    if decimated.is_manifold:
        return decimated

    print("Warning: simplify failed to create manifold")
    return mesh


# def build_mesh(n_min=1, n_max=5, n_holes=5) -> pv.PolyData:
#     while True:
#         mesh = mesh_block(n_min, n_max).scale(2, inplace=False)

#         factories = [
#             lambda: make_cylinder(height=15, radius=0.3),
#             lambda: make_box((0.6, 0.6, 15)).smooth(),
#             lambda: make_box((2, 2, 2)).smooth(),
#             mesh_blob,
#             mesh_block,
#         ]
#         for _ in range(np.random.randint(1, n_holes + 1)):
#             hole = factories[np.random.randint(len(factories))]()
#             mesh = boolean_diff(mesh, random_transform(hole, scale=0.3))

#         if mesh.is_manifold:
#             break
#         break
#         print("Warning: manifold creation failed")

#     mesh = mesh.smooth()
#     mesh = simplify(mesh, target_cells=2048)
#     return mesh


# def mesh_block(n_min=1, n_max=4) -> pv.PolyData:
#     mesh = pv.merge(
#         [
#             random_transform(make_box().smooth_taubin())
#             for _ in range(np.random.randint(n_min, n_max + 1))
#         ]
#     )
#     return mesh.delaunay_3d().extract_surface().clean().smooth()


# def mesh_blob(n_min=3, n_max=5) -> pv.PolyData:
#     mesh = pv.merge(
#         [
#             random_transform(make_sphere())
#             for _ in range(np.random.randint(n_min, n_max + 1))
#         ]
#     )
#     return mesh.delaunay_3d().extract_surface().clean().smooth()


def random_transform(mesh: pv.PolyData, scale=0.1, translate=0.3) -> pv.PolyData:
    """Return a randomly rotated, scaled and translated copy of ``mesh``.

    The mesh is rotated about x, then y, then z, each by an angle drawn
    uniformly from [0, 360) degrees. Each axis is scaled by a factor from
    ``lognormal(0, scale)``. Finally the mesh is translated by an offset from
    ``normal(0, translate)`` per axis. Every step uses ``inplace=False``, so
    the input is untouched.

    Previously the angles came from ``np.random.uniform(360)``, i.e. low=360
    and high=1.0, a reversed range that numpy documents as undefined.
    """
    return (
        mesh.rotate_x(np.random.uniform(0, 360), inplace=False)
        .rotate_y(np.random.uniform(0, 360), inplace=False)
        .rotate_z(np.random.uniform(0, 360), inplace=False)
        .scale(np.random.lognormal(0, scale, size=(3,)), inplace=False)
        .translate(np.random.normal(0, translate, (3,)), inplace=False)
    )


# def random_warp(mesh: pv.PolyData, factor=0.2):
#     mesh.point_data["noise"] = np.random.randn(mesh.n_points)
#     return mesh.warp_by_scalar("noise", factor=factor)


# def boundary(m):
#     m.point_data["OriginIds"] = np.arange(m.n_points)
#     return m.extract_feature_edges(
#         boundary_edges=True,
#         feature_edges=False,
#         manifold_edges=False,
#         non_manifold_edges=False,
#     )


# def boolean_diff(m1, m2):
#     assert m1.is_all_triangles() and m2.is_all_triangles()
#     # if not m1.is_manifold or not m2.is_manifold:
#     #     print("Warning: non manifold")

#     return (
#         pv.create_grid(m1, (11, 11, 11))
#         .clip_surface(m1, invert=True)
#         .clip_surface(m2, invert=False)
#         .extract_surface()
#         .clean()
#         .triangulate()
#     )


if __name__ == "__main__":
    from pathlib import Path

    p = Path("data/raw/fusion/")
    root = Path("data/training/")
    root.mkdir(exist_ok=True)
    n_faces = 1024

    for i, f in enumerate(filter(lambda f: f.stem != "assembly", p.glob("**/*.obj"))):
        try:
            mesh = pv.read(f)
            if mesh.n_points == 0:
                continue

            mesh = mesh.clean(tolerance=0)
            mesh = mesh.triangulate()
            if mesh.n_cells < n_faces:
                mesh = mesh.subdivide(
                    int(np.ceil(np.log2(n_faces / mesh.n_cells) / 2)),
                    "linear",
                )
            mesh = mesh.decimate(1 - n_faces / mesh.n_cells, volume_preservation=True)
            if mesh.n_cells < 0.8 * n_faces or mesh.n_cells > 1.2 * n_faces:
                continue

            if mesh.n_points and mesh.is_manifold:
                mesh.translate(-np.array(mesh.center), inplace=True)
                bounds = np.array(mesh.bounds)
                lo, hi = bounds[::2], bounds[1::2]
                mesh.scale(2 / (hi - lo), inplace=True)
                if mesh.volume < 0.5:
                    continue
                # test voxelization
                voxels = (
                    pv.create_grid(mesh, (31, 31, 31)).clip_surface(mesh).triangulate()
                )
                volumes = np.abs(voxels.compute_cell_sizes()["Volume"])
                non_nan = ~np.isnan(volumes)
                cells = voxels.cells.reshape(-1, 5)[:, 1:]
                cells = cells[non_nan]
                if len(cells) == 0:
                    continue
                mesh.save(root / f"{i}.vtk")
        except KeyboardInterrupt as ctrlc:
            raise ctrlc
        except Exception as ex:
            print(ex)
