import trimesh
from trimesh.visual.color import uv_to_color
import numpy as np
from PIL import Image


def load_mesh(path):
    """Load every geometry in a mesh file and merge them into flat arrays.

    The file is loaded with ``trimesh.load(..., force="scene", process=False)``.
    Each geometry's vertices, faces, face normals and vertex normals are
    concatenated. Face indices are offset so they index into the merged vertex
    array.

    For each geometry whose visual exposes ``uv`` and a material with
    ``image``/``ambient``/``diffuse``/``specular``/``glossiness`` attributes,
    the function also collects:
      * the UVs and the image;
      * the ambient/diffuse/specular RGB colors, divided by 255;
      * the glossiness.
    Every such UV gets a texture id equal to that geometry's position in
    ``images`` and in the ``Kas``/``Kds``/``Kss``/``Nss`` arrays. Geometries
    missing any of these attributes have the error printed. They still
    contribute vertices and faces, but no UV/material data.

    Args:
        path: Path of a file readable by ``trimesh.load``.

    Returns:
        dict with keys:
            "vs", "fs", "fns", "vns": concatenated vertices, faces (indices
                into "vs"), face normals and vertex normals.
            "uvs", "tex_ids": concatenated UVs and (n, 1) int32 texture ids,
                or None if no geometry provided them.
            "images": list of material images. Entries may be None;
                ``uvs_to_colors`` falls back to Kd for those.
            "Kas", "Kds", "Kss": stacked RGB material colors, or None if no
                geometry provided a material.
            "Nss": stacked glossiness values, or None likewise.

    Raises:
        ValueError: if the file contains no geometry.

    NOTE(review): "uvs"/"tex_ids" only cover textured geometries, so they
    line up row-for-row with "vs" only when every geometry is textured.
    """
    scene = trimesh.load(path, force="scene", process=False)
    vs = []
    fs = []
    fns = []
    vns = []
    uvs = []
    tex_ids = []
    images = []
    Kas = []
    Kds = []
    Kss = []
    Nss = []

    vertex_offset = 0
    for name, geo in scene.geometry.items():
        print(name)
        v = geo.vertices
        # Build a new array: `f = geo.faces; f += offset` would modify the
        # geometry's own face buffer in place.
        f = geo.faces + vertex_offset
        vs.append(v)
        fs.append(f)
        fns.append(geo.face_normals)
        vns.append(geo.vertex_normals)
        vertex_offset += v.shape[0]

        try:
            uv = geo.visual.uv
            material = geo.visual.material
            image = material.image
            # uv = np.clip(uv, 0, 1)

            # The id indexes the per-material lists below, so it must count
            # only geometries that actually contributed a material.
            tex_id = np.full((uv.shape[0], 1), len(images), dtype=np.int32)

            Ka = np.array(material.ambient)[:3] / 255.0
            Kd = np.array(material.diffuse)[:3] / 255.0
            Ks = np.array(material.specular)[:3] / 255.0
            Ns = np.array(material.glossiness)
        except Exception as e:
            # Best effort: geometries without UVs or a simple material are
            # kept for their vertices/faces only.
            print(e)
        else:
            uvs.append(uv)
            images.append(image)
            tex_ids.append(tex_id)
            Kas.append(Ka)
            Kds.append(Kd)
            Kss.append(Ks)
            Nss.append(Ns)

    if not vs:
        raise ValueError(f"no geometry found in {path!r}")

    ret = {
        "vs": np.concatenate(vs, axis=0),
        "fs": np.concatenate(fs, axis=0),
        "fns": np.concatenate(fns, axis=0),
        "vns": np.concatenate(vns, axis=0),
        "uvs": np.concatenate(uvs, axis=0) if uvs else None,
        "tex_ids": np.concatenate(tex_ids, axis=0) if tex_ids else None,
        "images": images,
        # Stack only when at least one material was found (np.stack([]) raises).
        "Kas": np.stack(Kas, axis=0) if Kas else None,
        "Kds": np.stack(Kds, axis=0) if Kds else None,
        "Kss": np.stack(Kss, axis=0) if Kss else None,
        "Nss": np.stack(Nss, axis=0) if Nss else None,
    }
    return ret


# def combine_uv_images(uvs, images):
#     # FIXME: deprecated
#     # horizontally stack PIL images
#     # https://stackoverflow.com/questions/30227466/combine-several-images-horizontally-with-python
#     widths, heights = zip(*(i.size for i in images))

#     total_width = sum(widths)
#     max_height = max(heights)

#     new_im = Image.new('RGB', (total_width, max_height))

#     x_offset = 0
#     new_uvs = []
#     for i, im in enumerate(images):
#         w, h = im.size
#         u = uvs[i][:, 0] % 1.0
#         v = uvs[i][:, 1] % 1.0
#         u = u * w / total_width + x_offset / total_width
#         v = v * h / max_height + 1.0 - h / max_height
#         new_uv = np.stack([u, v], axis=1)
#         new_uvs.append(new_uv)

#         new_im.paste(im, (x_offset,0))
#         x_offset += im.size[0]

#     new_uvs = np.concatenate(new_uvs, axis=0)
#     return new_uvs, new_im


def uvs_to_colors(uvs, tex_ids, images, Kds):
    """Return one RGBA color per row of ``uvs``, grouped by texture id.

    For a texture id whose image is None, the color is ``Kds[id]`` with
    alpha 1. Otherwise the color is sampled from the image with trimesh's
    ``uv_to_color`` and divided by 255.

    Args:
        uvs: UV coordinates, one row per sample (presumably (N, 2)).
        tex_ids: integer texture ids with a trailing axis, e.g. (N, 1). Each
            id indexes ``images`` and ``Kds``.
        images: per-texture images, or None where only a diffuse color exists.
        Kds: per-texture RGB diffuse colors.

    Returns:
        Float array of shape (N, 4).
    """
    flat_ids = tex_ids[..., 0]
    out = np.zeros((uvs.shape[0], 4))
    for tex in np.unique(tex_ids):
        sel = flat_ids == tex
        img = images[tex]
        if img is not None:
            out[sel] = uv_to_color(uvs[sel], img) / 255.0
            continue
        # No texture: constant diffuse color, fully opaque.
        rgba = np.ones((4,))
        rgba[:3] = Kds[tex]
        out[sel] = rgba
    return out


def sample_grid_points_aabb(aabb, resolution):
    """Return the cell-center points of a regular grid covering ``aabb``.

    The longest axis gets (nominally) ``resolution`` cells. Every other axis
    gets ``resolution`` scaled by its length relative to the longest axis,
    truncated to an int. Each point sits at a cell center, half a cell in
    from the box faces.

    Args:
        aabb: (6,) array ``[min_x, min_y, min_z, max_x, max_y, max_z]``.
        resolution: number of cells along the longest axis.

    Returns:
        Array of shape (Nx, Ny, Nz, 3), built with ``ij`` indexing.
    """
    lo = aabb[:3]
    hi = aabb[3:]
    extent = hi - lo
    counts = (resolution * extent / extent.max()).astype(np.int32)
    axes = [
        np.linspace(0.5, n - 0.5, n) / n * extent[k] + lo[k]
        for k, n in enumerate(counts)
    ]
    return np.stack(np.meshgrid(*axes, indexing="ij"), axis=-1)


def normalize_aabb(v, reso, enlarge_scale=1.03, mult=8):
    """Compute a normalization for ``v`` and a feature-map-aligned AABB.

    Vertices are meant to be normalized as ``(v + translation) * scale``.
    This centers them on their AABB center and scales them so that the
    longest axis, enlarged by ``enlarge_scale``, spans [-1, 1].

    The returned AABB is symmetric about the origin. Its per-axis half-extent
    is ``fm_size / fm_size.max()``, so the longest axis spans [-1, 1].
    ``fm_size`` is computed in three steps:
      1. scale ``reso`` by each axis' length relative to the longest axis;
      2. round up to a multiple of ``mult``;
      3. raise any axis below ``mult`` (a flat axis) to ``mult``.

    Args:
        v: vertex positions, one row per vertex (presumably (N, 3)).
        reso: feature-map resolution assigned to the longest axis.
        enlarge_scale: padding factor applied to the bounding box size.
        mult: every per-axis feature-map size is a multiple of this.

    Returns:
        (aabb, translation, scale):
            aabb: ``[-x, -y, -z, x, y, z]``.
            translation: ``-center``.
            scale: ``2 / bbox_size``.
    """
    aabb_min = np.min(v, axis=0)
    aabb_max = np.max(v, axis=0)
    center = (aabb_max + aabb_min) / 2
    bbox_size = (aabb_max - aabb_min).max() * enlarge_scale
    print("center:", center)
    print("bbox size", bbox_size)

    translation = -center
    scale = 1.0 / bbox_size * 2
    # v = (v + translation) * scale
    # v = (v - center) / bbox_size * 2
    # Normalized corners of the AABB enlarged by enlarge_scale about its
    # center. Centering must happen before enlarging; scaling the raw corners
    # would shift the box whenever center != 0.
    aabb_min = (aabb_min - center) * enlarge_scale / bbox_size * 2
    aabb_max = (aabb_max - center) * enlarge_scale / bbox_size * 2
    aabb = np.concatenate([aabb_min, aabb_max], axis=0)
    print("v max:", v.max(axis=0), "v min:", v.min(axis=0))
    print("aabb:", aabb)

    aabb_size = aabb_max - aabb_min
    fm_size = (reso * aabb_size / aabb_size.max()).astype(np.int32)
    # Round up to a multiple of `mult`.
    fm_size = (fm_size + mult - 1) // mult * mult
    # A flat (zero-extent) axis would otherwise get size 0, giving a
    # zero-thickness AABB; give it at least one multiple.
    fm_size = np.maximum(fm_size, mult)
    aabb_max = fm_size / fm_size.max()
    aabb = np.concatenate([-aabb_max, aabb_max], axis=0)
    print("aabb:", aabb)
    return aabb, translation, scale


if __name__ == "__main__":
    # Manual check: load the table mesh and show its stacked ambient (Ka) colors.
    mesh_info = load_mesh("mesh-data/table/table.obj")
    ambient_colors = mesh_info["Kas"]
    print(ambient_colors)
