// Reference:
// https://github.com/facebookresearch/DeepSDF/blob/master/src/PreprocessMesh.cpp

#include <cnpy.h>
#include <pangolin/geometry/geometry.h>
#include <pangolin/geometry/glgeometry.h>
#include <pangolin/gl/gl.h>
#include <pangolin/pangolin.h>

#include <CLI/CLI.hpp>
#include <algorithm>
#include <array>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <functional>
#include <iostream>
#include <numeric>
#include <random>
#include <set>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>

#include "classes.h"
#include "functions.h"
#include "geometry.h"

extern pangolin::GlSlProgram GetShaderProgram();

/// Gathers elements of `old_vector` at the positions listed in `indices`.
///
/// The result has `indices.size()` elements, with
/// `result[n] == old_vector[indices[n]]`. Indices may repeat or be omitted.
/// They are not bounds-checked; each must be in [0, old_vector.size()).
///
/// Both arguments are taken by const reference so that gathering a large
/// point cloud does not first copy it. `T` need not be default-constructible.
template <typename T>
std::vector<T> take(const std::vector<T> &old_vector,
                    const std::vector<int> &indices) {
    std::vector<T> new_vector;
    new_vector.reserve(indices.size());
    for (const int index : indices) {
        new_vector.push_back(old_vector[index]);
    }
    return new_vector;
}

/// Converts a list of uint32_t indices to element type `T`.
///
/// The conversion is an explicit `static_cast`. The caller must ensure every
/// index is representable in `T`; otherwise the value wraps or truncates
/// according to the usual integral conversion rules.
template <typename T>
std::vector<T> cast_indices(const std::vector<uint32_t> &indices) {
    std::vector<T> new_vector;
    new_vector.reserve(indices.size());
    for (const uint32_t index : indices) {
        new_vector.push_back(static_cast<T>(index));
    }
    return new_vector;
}
/// Writes sampled surface points and per-view visibility indices to an .npz
/// archive using cnpy.
///
/// Arrays are written in this order. The first uses mode "w" and the rest
/// append with "a":
///   "vertices"                   float, shape (points.size(), 3)
///   "vertex_normals"             float, shape (normals.size(), 3)
///   "num_viewpoints"             int,   shape (1)
///   "partial_point_indices_<v>"  for each v in [0, num_viewpoints): indices
///                                into "vertices". Stored as uint16 when
///                                points.size() <= 65536, otherwise uint32.
///   "offset"                     float, shape (3), equal to -params.center
///   "scale"                      float, shape (1), equal to 1 / params.max_distance
///
/// @param path            destination .npz file.
/// @param normals         per-point normals (presumably parallel to `points`).
/// @param points          point positions.
/// @param params          normalization that was applied to the mesh.
/// @param num_viewpoints  number of views. `map_view_points` must hold at
///                        least this many entries.
/// @param map_view_points for each view, indices of the points it observed.
void write_to_npz(const std::string path,
                  const std::vector<Eigen::Vector3f> &normals,
                  const std::vector<Eigen::Vector3f> &points,
                  const GeometryNormalizationParameters &params,
                  const int num_viewpoints,
                  const std::vector<std::vector<uint32_t>> &map_view_points) {
    // Packs 3-vectors into a row-major (size, 3) float buffer.
    const auto flatten = [](const std::vector<Eigen::Vector3f> &vectors) {
        std::vector<float> flat;
        flat.reserve(vectors.size() * 3);
        for (const Eigen::Vector3f &v : vectors) {
            flat.push_back(v[0]);
            flat.push_back(v[1]);
            flat.push_back(v[2]);
        }
        return flat;
    };
    const std::vector<float> points_data = flatten(points);
    const std::vector<float> normals_data = flatten(normals);

    cnpy::npz_save(path, "vertices", points_data.data(),
                   {points.size(), std::size_t{3}}, "w");
    cnpy::npz_save(path, "vertex_normals", normals_data.data(),
                   {normals.size(), std::size_t{3}}, "a");
    cnpy::npz_save(path, "num_viewpoints", &num_viewpoints, {std::size_t{1}},
                   "a");

    // Point indices lie in [0, points.size() - 1]. With at most 65536 points,
    // every index fits in uint16_t (max 65535), which halves the storage for
    // the index arrays.
    const bool use_uint16 = points.size() <= 65536;
    for (int view_index = 0; view_index < num_viewpoints; view_index++) {
        const std::vector<uint32_t> &indices = map_view_points[view_index];
        const std::string key =
            "partial_point_indices_" + std::to_string(view_index);
        if (use_uint16) {
            const std::vector<uint16_t> narrow_indices =
                cast_indices<uint16_t>(indices);
            cnpy::npz_save(path, key, narrow_indices.data(),
                           {narrow_indices.size()}, "a");
        } else {
            cnpy::npz_save(path, key, indices.data(), {indices.size()}, "a");
        }
    }
    if (use_uint16) {
        std::cout << "Converted indices to uint16_t" << std::endl;
    }

    const Eigen::Vector3f offset = -1 * params.center;
    cnpy::npz_save(path, "offset", offset.data(), {std::size_t{3}}, "a");

    const float scale = 1.0 / params.max_distance;
    cnpy::npz_save(path, "scale", &scale, {std::size_t{1}}, "a");
}

/// A triangle identified by three vertex indices. check_duplicated_faces
/// sorts the indices before building a key, so winding order is ignored there.
using face_t = std::tuple<int, int, int>;

/// Hash functor for face_t keys in unordered containers.
///
/// The three components are mixed with the boost::hash_combine recipe. Plain
/// XOR is symmetric and collides heavily on sorted index triples: (1,2,3) and
/// (4,5,1) both hash to 0. The functor no longer derives from
/// std::unary_function, which was deprecated in C++11 and removed in C++17.
struct key_hash {
    std::size_t operator()(const face_t &k) const {
        const std::hash<int> hash_int{};
        std::size_t seed = hash_int(std::get<0>(k));
        seed ^= hash_int(std::get<1>(k)) + 0x9e3779b9 + (seed << 6) +
                (seed >> 2);
        seed ^= hash_int(std::get<2>(k)) + 0x9e3779b9 + (seed << 6) +
                (seed >> 2);
        return seed;
    }
};

/// Equality functor for face_t keys. It compares the triples component-wise.
struct key_equal {
    bool operator()(const face_t &v0, const face_t &v1) const {
        return v0 == v1;
    }
};

/// Debug helper: reports faces whose vertex-index set repeats an earlier face.
///
/// Each face's three indices are sorted before comparison, so the same
/// triangle with a different winding counts as a duplicate. Faces flagged in
/// `ignore_face` are skipped. Every duplicate is logged as
/// "face <1-based index> is duplicated". The number of distinct non-ignored
/// faces is printed at the end.
///
/// @param geometry     mesh whose faces come from get_geometry_faces().
/// @param ignore_face  per-face skip flags. Presumably one entry per face in
///                     `geometry`; TODO confirm.
/// @param face_normals per-face normals. They are stored with the first
///                     occurrence of each face but not otherwise used.
void check_duplicated_faces(pangolin::Geometry &geometry,
                            std::vector<bool> &ignore_face,
                            std::vector<Eigen::Vector3f> &face_normals) {
    const std::vector<Eigen::Vector3i> faces = get_geometry_faces(geometry);

    std::unordered_map<face_t, Eigen::Vector3f, key_hash, key_equal>
        unique_faces;
    for (std::size_t face_index = 0; face_index < faces.size(); ++face_index) {
        if (ignore_face[face_index]) {
            continue;
        }
        const Eigen::Vector3i &face = faces[face_index];
        std::array<int, 3> v = {face(0), face(1), face(2)};
        std::sort(v.begin(), v.end());
        const face_t key = std::make_tuple(v[0], v[1], v[2]);
        // insert() leaves an existing entry untouched and reports it through
        // .second. A single hash lookup therefore both records the face and
        // detects a duplicate.
        if (!unique_faces.insert({key, face_normals[face_index]}).second) {
            std::cout << "face " << (face_index + 1) << " is duplicated"
                      << std::endl;
        }
    }
    std::cout << unique_faces.size() << std::endl;
}

int main(int argc, char **argv) {
    std::string mesh_path;
    std::string npz_path;
    int max_num_samples = 65536;
    int num_viewpoints = 100;
    size_t image_width = 400;
    size_t image_height = 400;

    CLI::App app{"sample_surface_points"};
    app.add_option("--input-mesh-path", mesh_path)->required();
    app.add_option("--output-npz-path", npz_path)->required();
    app.add_option("--num-samples", max_num_samples);
    app.add_option("--image-width", image_width);
    app.add_option("--image-height", image_height);
    app.add_option("--num-viewpoints", num_viewpoints);
    CLI11_PARSE(app, argc, argv);

    assert(image_width == image_height);

    glPixelStorei(GL_UNPACK_ALIGNMENT, 1);
    glPixelStorei(GL_UNPACK_ROW_LENGTH, 0);
    glPixelStorei(GL_UNPACK_SKIP_PIXELS, 0);
    glPixelStorei(GL_UNPACK_SKIP_ROWS, 0);

    pangolin::Geometry geometry = pangolin::LoadGeometry(mesh_path);
    int total_num_faces;
    merge_objects(geometry, total_num_faces);

    auto normalization_params = normalize_geometry(geometry, true);
    float max_distance = normalization_params.max_distance;
    std::cout << "center: " << normalization_params.center[0] << ", "
              << normalization_params.center[1] << ", "
              << normalization_params.center[2] << std::endl;
    std::cout << "max_distance: " << max_distance << std::endl;

    pangolin::CreateWindowAndBind("Main", 1, 1);
    glEnable(GL_DEPTH_TEST);
    glDisable(GL_DITHER);
    glDisable(GL_POINT_SMOOTH);
    glDisable(GL_LINE_SMOOTH);
    glDisable(GL_POLYGON_SMOOTH);
    glHint(GL_POINT_SMOOTH, GL_DONT_CARE);
    glHint(GL_LINE_SMOOTH, GL_DONT_CARE);
    glHint(GL_POLYGON_SMOOTH_HINT, GL_DONT_CARE);
    glDisable(GL_MULTISAMPLE_ARB);
    glShadeModel(GL_FLAT);

    // Define Projection and initial ModelView matrix
    float camera_distanmce = 1;
    pangolin::OpenGlRenderState camera(
        pangolin::ProjectionMatrixOrthographic(
            -camera_distanmce, camera_distanmce, camera_distanmce,
            -camera_distanmce, 0, 2.5),
        pangolin::ModelViewLookAt(0, 0, -1, 0, 0, 0, pangolin::AxisY));
    pangolin::GlGeometry gl_geometry = pangolin::ToGlGeometry(geometry);
    pangolin::GlSlProgram program = GetShaderProgram();

    pangolin::GlRenderBuffer z_buffer(image_width, image_height,
                                      GL_DEPTH_COMPONENT32);
    pangolin::GlTexture texture_normals(image_width, image_height, GL_RGBA32F);
    pangolin::GlTexture texture_vertices(image_width, image_height, GL_RGBA32F);
    pangolin::GlFramebuffer frame_buffer(texture_normals, texture_vertices,
                                         z_buffer);

    SurfaceRenderer renderer(geometry, image_width, image_height,
                             camera_distanmce);
    SurfaceRendererState state(total_num_faces, num_viewpoints,
                               camera_distanmce * 1.1);
    renderer.render(state, program, texture_normals, texture_vertices,
                    frame_buffer);

    std::vector<bool> ignore_face(total_num_faces);
    for (size_t k = 0; k < total_num_faces; k++) {
        ignore_face[k] = true;
    }
    int num_unobserved_faces = 0;
    for (unsigned int face_index = 0;
         face_index < state.invisible_face_test.size(); face_index++) {
        if (state.invisible_face_test[face_index] == true) {
            num_unobserved_faces++;
        } else {
            ignore_face[face_index] = false;
        }
    }
    std::vector<float> face_area = compute_face_area(geometry, ignore_face);

    int num_ignored_backfaces = 0;
    int num_observed_backfaces = 0;
    float pixel_area = 1.0f / image_width / image_height;
    for (unsigned int face_index = 0;
         face_index < state.face_normal_test.size(); face_index++) {
        if (state.face_observation_count[face_index] == 0) {
            continue;
        }
        int count = state.face_observation_count[face_index];
        float observed_area = pixel_area * count;
        float ratio = observed_area / face_area[face_index];
        if (ratio < 0.05) {
            ignore_face[face_index] = true;
            num_ignored_backfaces++;
        }
        if (state.observed_backface_test[face_index] == true) {
            num_observed_backfaces++;
        }
    }

    std::vector<Eigen::Vector3f> face_normals(total_num_faces);
    for (unsigned int k = 0; k < state.face_normal_test.size(); k++) {
        if (ignore_face[k] == false) {
            const Eigen::Vector3f normal = state.face_normal_test[k].head<3>();
            face_normals[k] = normal;
        }
    }

    std::cout << mesh_path << std::endl;
    std::cout << "# faces: " << total_num_faces
              << " # unobserved faces: " << num_unobserved_faces
              << " # ignored backfaces: " << num_ignored_backfaces
              << " # observed backfaces: " << num_observed_backfaces
              << std::endl;

    std::vector<Eigen::Vector3f> visible_surface_points;
    std::vector<Eigen::Vector3f> visible_surface_normals;
    std::vector<int> visible_surface_points_view_index;
    std::vector<int> visible_surface_points_face_index;
    for (unsigned int n = 0; n < state.observed_points.size(); n++) {
        const size_t face_index =
            static_cast<std::size_t>(state.observed_normals[n][3] + 0.01f) - 1;
        if (ignore_face[face_index] == false) {
            Eigen::Vector3f vertex = state.observed_points[n].head<3>();
            if (std::isnan(vertex[0])) {
                continue;
            }
            if (std::isnan(vertex[1])) {
                continue;
            }
            if (std::isnan(vertex[2])) {
                continue;
            }
            Eigen::Vector3f normal = state.observed_normals[n].head<3>();
            if (std::isnan(normal[0])) {
                continue;
            }
            if (std::isnan(normal[1])) {
                continue;
            }
            if (std::isnan(normal[2])) {
                continue;
            }
            visible_surface_points.push_back(
                state.observed_points[n].head<3>());
            visible_surface_normals.push_back(normal);
            visible_surface_points_view_index.push_back(
                state.observed_points_view_index[n]);
            visible_surface_points_face_index.push_back(face_index);
        }
    }
    int num_visible_points = visible_surface_points.size();

    // shuffle
    std::vector<int> rand_indices(num_visible_points);
    for (unsigned int n = 0; n < num_visible_points; n++) {
        rand_indices[n] = n;
    }
    std::shuffle(rand_indices.begin(), rand_indices.end(),
                 std::mt19937{std::random_device{}()});

    visible_surface_points = take(visible_surface_points, rand_indices);
    visible_surface_normals = take(visible_surface_normals, rand_indices);
    visible_surface_points_view_index =
        take(visible_surface_points_view_index, rand_indices);
    visible_surface_points_face_index =
        take(visible_surface_points_face_index, rand_indices);

    std::vector<bool> remaining_points(num_visible_points, false);
    std::vector<bool> deleted_points(num_visible_points, false);
    std::vector<std::set<uint32_t>> map_point_view(num_visible_points);

    std::cout << "# visible points: " << num_visible_points << std::endl;
    KdVertexList kdVertex(visible_surface_points);
    KdVertexListTree kd_tree(3, kdVertex);
    kd_tree.buildIndex();

    for (int n = 0; n < num_visible_points; n++) {
        if (deleted_points[n] == true) {
            continue;
        }
        int view_index = visible_surface_points_view_index[n];
        remaining_points[n] = true;
        map_point_view[n].insert(view_index);
        Eigen::Vector3f point = visible_surface_points[n];
        Eigen::Vector3f normal = visible_surface_normals[n];
        int top_k = 50;
        std::vector<int> closest_indices(top_k);
        std::vector<float> closest_distances(top_k);
        kd_tree.knnSearch(point.data(), top_k, closest_indices.data(),
                          closest_distances.data());
        for (int k = 0; k < top_k; k++) {
            int target_index = closest_indices[k];
            if (n == target_index) {
                continue;
            }
            float closest_distance = closest_distances[k];
            if (closest_distance > 5e-5) {
                break;
            }
            if (remaining_points[target_index] == true) {
                continue;
            }
            int target_view_index =
                visible_surface_points_view_index[target_index];
            if (view_index == target_view_index) {
                continue;
            }
            Eigen::Vector3f target_normal =
                visible_surface_normals[target_index];
            float dot = normal.dot(target_normal);
            if (dot <= 0) {
                continue;
            }
            map_point_view[n].insert(target_view_index);
            deleted_points[target_index] = true;
        }
    }

    std::vector<Eigen::Vector3f> trimmed_surface_points;
    std::vector<Eigen::Vector3f> trimmed_surface_normals;
    std::vector<std::vector<uint32_t>> map_view_points(num_viewpoints);
    for (int n = 0; n < num_visible_points; n++) {
        if (remaining_points[n] == true) {
            const Eigen::Vector3f &point = visible_surface_points[n];
            const Eigen::Vector3f &normal = visible_surface_normals[n];
            const auto &view_indices = map_point_view[n];
            int new_point_index = trimmed_surface_points.size();
            trimmed_surface_points.push_back(point);
            trimmed_surface_normals.push_back(normal);
            for (int view_index : view_indices) {
                map_view_points[view_index].push_back(new_point_index);
            }
            if (trimmed_surface_points.size() == max_num_samples) {
                break;
            }
        }
    }

    std::cout << "Reduced points from "
              << visible_surface_points_view_index.size() << " to "
              << max_num_samples << std::endl;

    write_to_npz(npz_path, trimmed_surface_normals, trimmed_surface_points,
                 normalization_params, num_viewpoints, map_view_points);

    return 0;
}
