import torch
from MeshUtils import load_off, campos_to_R_T, RasterizationSettings, Meshes, MeshRenderer, MeshRasterizer, HardPhongShader, PointLights, pre_process_mesh_pascal, Textures
from pytorch3d.renderer import OpenGLPerspectiveCameras, PerspectiveCameras
import BboxTools as bbt
import numpy as np
from PIL import Image


def get_img(theta, campos, crop_size, render_image_size):
    """Render the module-level mesh from one camera pose and crop the result.

    Args:
        theta: forwarded to ``campos_to_R_T`` together with ``campos``.
        campos: camera position forwarded to ``campos_to_R_T``.
        crop_size: shape handed to ``bbt.box_by_shape`` to build the crop box.
        render_image_size: side of the rendered image; the crop box is built
            around ``(render_image_size // 2, render_image_size // 2)``.

    Returns:
        A ``np.uint8`` array holding the first three channels of the cropped
        render. All singleton axes are squeezed out. Values are divided by
        their maximum and scaled by 255; the cast truncates fractions.

    Note:
        Reads the module-level globals ``device``, ``phong_renderer`` and
        ``meshes``. The mesh is cloned before every render.
    """
    rotation, translation = campos_to_R_T(campos, theta, device=device)
    rendered = phong_renderer(meshes_world=meshes.clone(), R=rotation, T=translation)

    # Crop box around the image centre. box_by_shape presumably centres a
    # crop_size box on the given point -- confirm against BboxTools.
    centre = render_image_size // 2
    bbox = bbt.box_by_shape(crop_size, (centre, centre)).bbox
    rows = slice(bbox[0][0], bbox[0][1])
    cols = slice(bbox[1][0], bbox[1][1])

    # Keep only the first three channels (presumably RGB of an RGBA render).
    cropped = rendered[:, rows, cols, :3]

    pixels = cropped.squeeze().detach().cpu().numpy()
    # Normalise so the brightest value becomes 255 before the uint8 cast.
    return np.asarray(pixels / pixels.max() * 255).astype(np.uint8)


# ---- Configuration ---------------------------------------------------------
# Object category. It selects both the CAD model and the render size below.
cate = 'aeroplane'
mesh_path = '../PASCAL3D/PASCAL3D+_release1.1/CAD_buildn/{}/01.off'.format(cate)

# Rendering device; use 'cpu' instead when no GPU is available.
device = 'cuda:0'

# Per-category pair of image dimensions. The renderer uses the larger of the
# two as the side of a square image.
image_size = {
    'car': (256, 672),
    'bus': (320, 800),
    'motorbike': (512, 512),
    'boat': (384, 704),
    'aeroplane': (288, 768),
    'bicycle': (512, 512),
}[cate]

# Template whose vertices serve as camera sample points. Other templates that
# have been used: './ICOsamples3r4p.off', './ICOsamples2r8p.off'.
template_path = './UVsamples4p1.off'
# Output file: the template path without its '.off' extension, plus a suffix
# naming the category.
save_name = '{}_mask_{}_cub.npz'.format(template_path.split('.off')[0], cate)
print(save_name)

# ---- Inputs ----------------------------------------------------------------
# Camera sample points come from the template file. The second value returned
# by load_off is not used.
samples, _ = load_off(template_path)

# CAD model for the category, passed through the project's PASCAL3D
# pre-processing helper.
verts, faces = load_off(mesh_path, to_torch=True)
verts = pre_process_mesh_pascal(verts)

# ---- Mesh ------------------------------------------------------------------
# Constant all-ones per-vertex colour with a leading batch axis: (1, V, 3).
verts_rgb = torch.ones_like(verts).unsqueeze(0)
textures = Textures(verts_features=verts_rgb.to(device))
meshes = Meshes(verts=[verts], faces=[faces], textures=textures).to(device)

# ---- Camera and rasteriser ---------------------------------------------------
# Square render whose side is the larger category dimension; the principal
# point sits at its centre.
# NOTE(review): principal_point and image_size are given in pixels. Depending
# on the installed PyTorch3D version this may also need in_ndc=False -- confirm.
cameras = PerspectiveCameras(
    focal_length=3000.0,
    principal_point=((max(image_size) / 2, max(image_size) / 2),),
    image_size=((max(image_size), max(image_size)),),
    device=device,
)

raster_settings = RasterizationSettings(
    image_size=(max(image_size), max(image_size)),
    blur_radius=0.0,
    faces_per_pixel=1,
    bin_size=0,
)

# ---- Shading -----------------------------------------------------------------
lights = PointLights(device=device, location=((2.0, 2.0, -2.0),))
phong_renderer = MeshRenderer(
    rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
    shader=HardPhongShader(device=device, lights=lights, cameras=cameras),
)

# ---- Render one mask per sample point and save them all ----------------------
out = {}
print(samples.shape)
for i, sample_point in enumerate(samples):
    print(i)
    # Camera position = sample point scaled by 5, with a leading batch axis.
    cam_pos = (torch.from_numpy(sample_point).unsqueeze(0) * 5).to(device)
    img = get_img(torch.zeros(1), cam_pos, image_size, max(image_size))

    # Foreground = every pixel that is not pure white (255, 255, 255).
    mask = (img != 255).any(axis=2)
    out[str(i)] = mask
np.savez(save_name, **out)
