// Reference:
// https://github.com/facebookresearch/DeepSDF/blob/master/src/PreprocessMesh.cpp

#include <cnpy.h>
#include <pangolin/geometry/geometry.h>
#include <pangolin/geometry/glgeometry.h>
#include <pangolin/gl/gl.h>
#include <pangolin/pangolin.h>

#include <CLI/CLI.hpp>
#include <algorithm>
#include <array>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <fstream>
#include <functional>
#include <iostream>
#include <iterator>
#include <numeric>
#include <random>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>

#include "classes.h"
#include "functions.h"

// Returns the GLSL program that main() hands to SurfaceRenderer::render().
// Only declared here; the definition lives in another translation unit.
pangolin::GlSlProgram GetShaderProgram();

/// Writes sampled surface points, their normals and the normalization
/// parameters to an .npz archive.
///
/// Arrays written (all float32):
///   - "vertices":       (N, 3) sampled points
///   - "vertex_normals": (M, 3) sampled normals (presumably M == N, one normal
///                       per point, as built in main)
///   - "offset":         (3,)   -params.center
///   - "scale":          (1,)   1 / params.max_distance
///
/// The first array is saved with mode "w" (creates/overwrites the archive);
/// the rest are appended with mode "a".
///
/// @param path             Output .npz file path.
/// @param sampled_normals  Sampled normals.
/// @param sampled_points   Sampled surface points.
/// @param params           Normalization parameters (as returned by
///                         normalize_geometry in main).
void write_to_npz(const std::string &path,
                  const std::vector<Eigen::Vector3f> &sampled_normals,
                  const std::vector<Eigen::Vector3f> &sampled_points,
                  const GeometryNormalizationParameters &params) {
    // Flattens a list of 3-vectors into a contiguous row-major float buffer.
    const auto flatten = [](const std::vector<Eigen::Vector3f> &rows) {
        std::vector<float> data;
        data.reserve(rows.size() * 3);
        for (const auto &row : rows) {
            data.push_back(row[0]);
            data.push_back(row[1]);
            data.push_back(row[2]);
        }
        return data;
    };
    const std::vector<float> points_data = flatten(sampled_points);
    const std::vector<float> normals_data = flatten(sampled_normals);

    cnpy::npz_save(path, "vertices", points_data.data(),
                   {sampled_points.size(), 3}, "w");
    cnpy::npz_save(path, "vertex_normals", normals_data.data(),
                   {sampled_normals.size(), 3}, "a");

    // NOTE(review): presumably (p + offset) * scale reproduces the mapping
    // applied by normalize_geometry — confirm against its definition.
    const Eigen::Vector3f offset = -params.center;
    cnpy::npz_save(path, "offset", offset.data(), {3}, "a");

    // Divide in double (as before) and narrow once, so the stored value is
    // unchanged.
    const float scale = static_cast<float>(1.0 / params.max_distance);
    cnpy::npz_save(path, "scale", &scale, {1}, "a");
}

/// Key identifying a triangle by its three vertex indices. check_duplicated_faces
/// sorts the indices before building a key, so the same three vertices in any
/// order map to the same key.
using face_t = std::tuple<int, int, int>;

/// Hash functor for face_t (e.g. the Hash argument of std::unordered_map).
///
/// Does not derive from std::unary_function (deprecated in C++11, removed in
/// C++17); the nested typedefs it used to provide are declared explicitly.
/// The indices are mixed boost::hash_combine-style. A plain XOR of the indices
/// would send many distinct keys to the same value (e.g. (1, 2, 3) and
/// (0, 0, 0) both XOR to 0).
struct key_hash {
    using argument_type = face_t;
    using result_type = std::size_t;

    std::size_t operator()(const face_t &k) const {
        std::size_t seed = 0;
        combine(seed, std::get<0>(k));
        combine(seed, std::get<1>(k));
        combine(seed, std::get<2>(k));
        return seed;
    }

  private:
    // Folds `value` into `seed` using the boost::hash_combine mixing step.
    static void combine(std::size_t &seed, int value) {
        seed ^= std::hash<int>{}(value) + std::size_t{0x9e3779b9} +
                (seed << 6) + (seed >> 2);
    }
};

/// Equality functor for face_t: element-wise comparison of the three indices
/// (equivalent to std::equal_to<face_t>). Does not derive from the removed
/// std::binary_function; its typedefs are declared explicitly.
struct key_equal {
    using first_argument_type = face_t;
    using second_argument_type = face_t;
    using result_type = bool;

    bool operator()(const face_t &v0, const face_t &v1) const {
        return v0 == v1;
    }
};

/// Reports faces that use the same three vertices as an earlier face.
///
/// Faces are keyed by their sorted vertex indices, so faces with the same
/// vertices in a different order count as duplicates. Faces whose
/// `ignore_face` flag is true are skipped. For each duplicate, its 1-based
/// face number is printed. At the end, the number of distinct non-ignored
/// faces is printed.
///
/// @param geometry      Mesh whose faces are read via get_geometry_faces().
/// @param ignore_face   Per-face skip flags; must have an entry for every face.
/// @param face_normals  Per-face normals; must have an entry for every
///                      non-ignored face. The normal of the first occurrence
///                      of each face is recorded in the lookup table.
void check_duplicated_faces(pangolin::Geometry &geometry,
                            std::vector<bool> &ignore_face,
                            std::vector<Eigen::Vector3f> &face_normals) {
    const std::vector<Eigen::Vector3i> faces = get_geometry_faces(geometry);

    // Sorted vertex indices -> normal of the first face seen with them.
    std::unordered_map<face_t, Eigen::Vector3f, key_hash, key_equal>
        unique_faces;
    for (std::size_t face_index = 0; face_index < faces.size(); ++face_index) {
        if (ignore_face[face_index]) {
            continue;
        }
        const Eigen::Vector3i &face = faces[face_index];
        std::array<int, 3> v = {face(0), face(1), face(2)};
        std::sort(v.begin(), v.end());
        const face_t key = std::make_tuple(v[0], v[1], v[2]);

        // Single lookup: inserts only when the key is new.
        const bool inserted =
            unique_faces.try_emplace(key, face_normals[face_index]).second;
        if (!inserted) {
            std::cout << "face " << (face_index + 1) << " is duplicated"
                      << std::endl;
        }
    }
    std::cout << unique_faces.size() << std::endl;
}

int main(int argc, char **argv) {
    std::string mesh_path;
    std::string npz_path;
    int num_sample = 500000;
    int num_viewpoints = 100;
    size_t image_width = 400;
    size_t image_height = 400;

    CLI::App app{"sample_surface_points"};
    app.add_option("--input-mesh-path", mesh_path)->required();
    app.add_option("--output-npz-path", npz_path)->required();
    app.add_option("--num-samples", num_sample);
    app.add_option("--image-width", image_width);
    app.add_option("--image-height", image_height);
    app.add_option("--num-viewpoints", num_viewpoints);
    CLI11_PARSE(app, argc, argv);

    assert(image_width == image_height);

    glPixelStorei(GL_UNPACK_ALIGNMENT, 1);
    glPixelStorei(GL_UNPACK_ROW_LENGTH, 0);
    glPixelStorei(GL_UNPACK_SKIP_PIXELS, 0);
    glPixelStorei(GL_UNPACK_SKIP_ROWS, 0);

    pangolin::Geometry geometry = pangolin::LoadGeometry(mesh_path);
    std::cout << geometry.objects.size() << " objects" << std::endl;
    geometry.textures.clear();

    int total_num_faces = 0;
    for (const auto &object : geometry.objects) {
        auto vertext_indices = object.second.attributes.find("vertex_indices");
        if (vertext_indices != object.second.attributes.end()) {
            pangolin::Image<uint32_t> ibo =
                pangolin::get<pangolin::Image<uint32_t>>(
                    vertext_indices->second);

            total_num_faces += ibo.h;
        }
    }
    std::cout << total_num_faces << " faces" << std::endl;

    pangolin::ManagedImage<uint8_t> new_buffer(3 * sizeof(uint32_t),
                                               total_num_faces);
    pangolin::Image<uint32_t> new_ibo =
        new_buffer.UnsafeReinterpret<uint32_t>().SubImage(0, 0, 3,
                                                          total_num_faces);
    int new_index = 0;
    for (const auto &object : geometry.objects) {
        auto vertext_indices = object.second.attributes.find("vertex_indices");
        if (vertext_indices != object.second.attributes.end()) {
            pangolin::Image<uint32_t> ibo =
                pangolin::get<pangolin::Image<uint32_t>>(
                    vertext_indices->second);
            for (int k = 0; k < ibo.h; ++k) {
                new_ibo.Row(new_index).CopyFrom(ibo.Row(k));
                new_index++;
            }
        }
    }
    geometry.objects.clear();

    auto faces = geometry.objects.emplace(std::string("mesh"),
                                          pangolin::Geometry::Element());
    faces->second.Reinitialise(3 * sizeof(uint32_t), total_num_faces);
    faces->second.CopyFrom(new_buffer);
    new_ibo = faces->second.UnsafeReinterpret<uint32_t>().SubImage(
        0, 0, 3, total_num_faces);
    faces->second.attributes["vertex_indices"] = new_ibo;

    auto normalization_params = normalize_geometry(geometry, true);
    float max_distance = normalization_params.max_distance;
    std::cout << "center: " << normalization_params.center[0] << ", "
              << normalization_params.center[1] << ", "
              << normalization_params.center[2] << std::endl;
    std::cout << "max_distance: " << max_distance << std::endl;

    pangolin::CreateWindowAndBind("Main", 1, 1);
    glEnable(GL_DEPTH_TEST);
    glDisable(GL_DITHER);
    glDisable(GL_POINT_SMOOTH);
    glDisable(GL_LINE_SMOOTH);
    glDisable(GL_POLYGON_SMOOTH);
    glHint(GL_POINT_SMOOTH, GL_DONT_CARE);
    glHint(GL_LINE_SMOOTH, GL_DONT_CARE);
    glHint(GL_POLYGON_SMOOTH_HINT, GL_DONT_CARE);
    glDisable(GL_MULTISAMPLE_ARB);
    glShadeModel(GL_FLAT);

    // Define Projection and initial ModelView matrix
    float camera_distanmce = 1;
    pangolin::OpenGlRenderState camera(
        pangolin::ProjectionMatrixOrthographic(
            -camera_distanmce, camera_distanmce, camera_distanmce,
            -camera_distanmce, 0, 2.5),
        pangolin::ModelViewLookAt(0, 0, -1, 0, 0, 0, pangolin::AxisY));
    pangolin::GlGeometry gl_geometry = pangolin::ToGlGeometry(geometry);
    pangolin::GlSlProgram program = GetShaderProgram();

    pangolin::GlRenderBuffer z_buffer(image_width, image_height,
                                      GL_DEPTH_COMPONENT32);
    pangolin::GlTexture texture_normals(image_width, image_height, GL_RGBA32F);
    pangolin::GlTexture texture_vertices(image_width, image_height, GL_RGBA32F);
    pangolin::GlFramebuffer frame_buffer(texture_normals, texture_vertices,
                                         z_buffer);

    SurfaceRenderer renderer(geometry, image_width, image_height,
                             camera_distanmce);
    SurfaceRendererState state(total_num_faces, num_viewpoints,
                               camera_distanmce * 1.1);
    renderer.render(state, program, texture_normals, texture_vertices,
                    frame_buffer);

    std::vector<bool> ignore_face(total_num_faces);
    for (size_t k = 0; k < total_num_faces; k++) {
        ignore_face[k] = true;
    }
    int num_unobserved_faces = 0;
    for (unsigned int face_index = 0;
         face_index < state.invisible_face_test.size(); face_index++) {
        if (state.invisible_face_test[face_index] == true) {
            num_unobserved_faces++;
        } else {
            ignore_face[face_index] = false;
        }
    }
    std::vector<float> face_area = compute_face_area(geometry, ignore_face);

    int num_ignored_backfaces = 0;
    int num_observed_backfaces = 0;
    float pixel_area = 1.0f / image_width / image_height;
    for (unsigned int face_index = 0;
         face_index < state.face_normal_test.size(); face_index++) {
        if (state.face_observation_count[face_index] == 0) {
            continue;
        }
        int count = state.face_observation_count[face_index];
        float observed_area = pixel_area * count;
        float ratio = observed_area / face_area[face_index];
        if (ratio < 0.05) {
            ignore_face[face_index] = true;
            num_ignored_backfaces++;
        }
        if (state.observed_backface_test[face_index] == true) {
            num_observed_backfaces++;
        }
    }

    std::vector<Eigen::Vector3f> face_normals(total_num_faces);
    for (unsigned int k = 0; k < state.face_normal_test.size(); k++) {
        if (ignore_face[k] == false) {
            const Eigen::Vector3f normal = state.face_normal_test[k].head<3>();
            face_normals[k] = normal;
        }
    }

    std::cout << mesh_path << std::endl;
    std::cout << "# faces: " << total_num_faces
              << " # unobserved faces: " << num_unobserved_faces
              << " # ignored backfaces: " << num_ignored_backfaces
              << " # observed backfaces: " << num_observed_backfaces
              << std::endl;

    std::vector<Eigen::Vector3f> visible_surface_points;
    std::vector<Eigen::Vector3f> visible_surface_normals;
    for (unsigned int n = 0; n < state.observed_points.size(); n++) {
        const auto face_index =
            static_cast<std::size_t>(state.observed_normals[n][3] + 0.01f) - 1;
        if (ignore_face[face_index] == false) {
            Eigen::Vector3f vertex = state.observed_points[n].head<3>();
            if (std::isnan(vertex[0])) {
                continue;
            }
            if (std::isnan(vertex[1])) {
                continue;
            }
            if (std::isnan(vertex[2])) {
                continue;
            }
            Eigen::Vector3f normal = state.observed_normals[n].head<3>();
            if (std::isnan(normal[0])) {
                continue;
            }
            if (std::isnan(normal[1])) {
                continue;
            }
            if (std::isnan(normal[2])) {
                continue;
            }
            visible_surface_points.push_back(
                state.observed_points[n].head<3>());
            visible_surface_normals.push_back(normal);
        }
    }
    std::vector<int> indices(visible_surface_points.size());
    for (int k = 0; k < visible_surface_points.size(); k++) {
        indices[k] = k;
    }
    std::vector<int> rand_indices;
    std::sample(indices.begin(), indices.end(),
                std::back_inserter(rand_indices), num_sample,
                std::mt19937{std::random_device{}()});
    std::vector<Eigen::Vector3f> sampled_points;
    std::vector<Eigen::Vector3f> sampled_normals;
    for (int index : rand_indices) {
        sampled_points.push_back(visible_surface_points[index].head<3>());
        sampled_normals.push_back(visible_surface_normals[index].head<3>());
    }

    // std::vector<Eigen::Vector3f> sampled_points;
    // std::vector<Eigen::Vector3f> sampled_normals;
    // sample_from_surface(geometry, ignore_face, face_normals, sampled_normals,
    //                     sampled_points, num_sample);

    write_to_npz(npz_path, sampled_normals, sampled_points,
                 normalization_params);

    return 0;
}
