// Reference:
// https://github.com/facebookresearch/DeepSDF/blob/master/src/PreprocessMesh.cpp

#include <cnpy.h>
#include <pangolin/geometry/geometry.h>
#include <pangolin/geometry/glgeometry.h>
#include <pangolin/gl/gl.h>
#include <pangolin/pangolin.h>

#include <CLI/CLI.hpp>
#include <algorithm>
#include <cassert>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <iterator>
#include <random>
#include <string>
#include <vector>

#include "classes.h"
#include "functions.h"

extern pangolin::GlSlProgram GetShaderProgram();

/// Writes SDF samples and the mesh normalization parameters to an .npz file.
///
/// Arrays written (float32):
///   - "positive_sdf_samples": shape (N_pos, 4), rows (x, y, z, sdf) with
///     sdf > 0; omitted when there are no such samples.
///   - "negative_sdf_samples": shape (N_neg, 4), rows with sdf <= 0 (or NaN);
///     omitted when there are no such samples.
///   - "offset": shape (3,), the negated normalization center.
///   - "scale":  shape (1,), 1 / params.max_distance.
///
/// Any existing file at @p path is replaced: the first array is always saved
/// in cnpy's "w" (truncate) mode and the rest in "a" (append) mode.
///
/// @param path    Output .npz path.
/// @param points  Sample positions; must be the same length as @p sdfs.
/// @param sdfs    Signed distance for each entry of @p points.
/// @param params  Normalization that was applied to the mesh.
void write_to_npz(const std::string path,
                  const std::vector<Eigen::Vector3f> &points,
                  const std::vector<float> &sdfs,
                  const GeometryNormalizationParameters &params) {
    assert(points.size() == sdfs.size());

    std::vector<float> positive_sdf_samples;
    std::vector<float> negative_sdf_samples;
    for (std::size_t k = 0; k < sdfs.size(); k++) {
        const Eigen::Vector3f &v = points[k];
        const float distance = sdfs[k];
        // Anything not strictly positive (including NaN) is stored as negative.
        std::vector<float> &bucket =
            distance > 0 ? positive_sdf_samples : negative_sdf_samples;
        bucket.insert(bucket.end(), {v[0], v[1], v[2], distance});
    }
    // Each sample occupies 4 floats; report sample counts, not float counts.
    std::cout << "# positive sdfs: " << positive_sdf_samples.size() / 4
              << std::endl;
    std::cout << "# negative sdfs: " << negative_sdf_samples.size() / 4
              << std::endl;

    // The first save must truncate; otherwise (e.g. when there are no
    // positive samples) everything would be appended to whatever archive a
    // previous run left at `path`.
    std::string mode = "w";
    const auto save = [&](const std::string &name, const float *data,
                          const std::vector<size_t> &shape) {
        cnpy::npz_save(path, name, data, shape, mode);
        mode = "a";
    };

    if (!positive_sdf_samples.empty()) {
        save("positive_sdf_samples", positive_sdf_samples.data(),
             {positive_sdf_samples.size() / 4, 4});
    }
    if (!negative_sdf_samples.empty()) {
        save("negative_sdf_samples", negative_sdf_samples.data(),
             {negative_sdf_samples.size() / 4, 4});
    }

    const Eigen::Vector3f offset = -1 * params.center;
    save("offset", offset.data(), {3});

    const float scale = 1.0 / params.max_distance;
    save("scale", &scale, {1});
}

/// Generates signed-distance samples around a surface and inside a cube.
///
/// Sample positions:
///   - for every point in @p target_surface_points, two copies perturbed by
///     independent per-axis Gaussian noise, one with variance @p variance and
///     one with variance @p second_variance;
///   - @p num_unit_cube_samples points drawn uniformly from the axis-aligned
///     cube of side @p bounding_cube_dim centred at the origin.
///
/// The SDF of each sample p is estimated from its nearest visible surface
/// points q (queried through @p kd_tree, expected to index
/// @p visible_surface_points; @p visible_surface_normals holds the matching
/// normals n):
///   - magnitude: if the nearest q is closer than sqrt(@p variance), the
///     point-to-plane distance |n . (p - q)|, otherwise |p - q|;
///   - sign: negative only when n . (p - q) <= 0 for every neighbour;
///     positive otherwise, including when the neighbours disagree.
///
/// Results are appended to @p sampled_sdf_points / @p sampled_sdf_distances.
///
/// @param geometry                           Unused; kept for interface compatibility.
/// @param visible_surface_point_face_indices Unused; kept for interface compatibility.
/// @param top_k  Number of neighbours voting on the sign; clamped to
///               [1, visible_surface_points.size()].
///
/// If @p visible_surface_points is empty no sample can be evaluated: a
/// warning is printed and nothing is appended.
void sample_sdf_near_surface(
    pangolin::Geometry &geometry, KdVertexListTree &kd_tree,
    std::vector<Eigen::Vector3f> &visible_surface_points,
    std::vector<Eigen::Vector3f> &visible_surface_normals,
    std::vector<int> &visible_surface_point_face_indices,
    std::vector<Eigen::Vector3f> &target_surface_points,
    std::vector<Eigen::Vector3f> &sampled_sdf_points,
    std::vector<float> &sampled_sdf_distances, int num_unit_cube_samples,
    float variance, float second_variance, float bounding_cube_dim, int top_k) {
    // An earlier experimental sign test (direction voting and a Möller–Trumbore
    // ray/face intersection) needed these; the active test does not.
    (void)geometry;
    (void)visible_surface_point_face_indices;

    if (visible_surface_points.empty()) {
        std::cerr << "sample_sdf_near_surface: no visible surface points, "
                     "no SDF samples generated"
                  << std::endl;
        return;
    }
    // At least one neighbour so `sdf` is always assigned, and never more than
    // the tree holds so every returned index is valid.
    const int num_votes = static_cast<int>(std::min<std::size_t>(
        static_cast<std::size_t>(std::max(top_k, 1)),
        visible_surface_points.size()));

    const float stddev = std::sqrt(variance);

    std::random_device seeder;
    std::mt19937 generator(seeder());
    std::uniform_real_distribution<float> rand_dist(0.0, 1.0);

    std::random_device rd;
    std::mt19937 rng(rd());
    std::normal_distribution<float> perterb_norm(0, stddev);
    std::normal_distribution<float> perterb_second(0,
                                                   std::sqrt(second_variance));

    std::vector<Eigen::Vector3f> tmp_sdf_points;
    tmp_sdf_points.reserve(
        2 * target_surface_points.size() +
        static_cast<std::size_t>(std::max(num_unit_cube_samples, 0)));

    // Two perturbed copies of every target surface point.
    for (const Eigen::Vector3f &surface_p : target_surface_points) {
        Eigen::Vector3f samp1 = surface_p;
        Eigen::Vector3f samp2 = surface_p;
        for (int j = 0; j < 3; j++) {
            samp1[j] += perterb_norm(rng);
            samp2[j] += perterb_second(rng);
        }
        tmp_sdf_points.push_back(samp1);
        tmp_sdf_points.push_back(samp2);
    }

    // Uniform samples in the cube [-dim/2, dim/2]^3 (x, y, z drawn in order).
    for (int s = 0; s < num_unit_cube_samples; s++) {
        const float x =
            rand_dist(generator) * bounding_cube_dim - bounding_cube_dim / 2;
        const float y =
            rand_dist(generator) * bounding_cube_dim - bounding_cube_dim / 2;
        const float z =
            rand_dist(generator) * bounding_cube_dim - bounding_cube_dim / 2;
        tmp_sdf_points.emplace_back(x, y, z);
    }

    // Buffers reused across queries instead of re-allocated per sample.
    std::vector<int> closest_indices(num_votes);
    std::vector<float> closest_distances(num_votes);
    sampled_sdf_points.reserve(sampled_sdf_points.size() +
                               tmp_sdf_points.size());
    sampled_sdf_distances.reserve(sampled_sdf_distances.size() +
                                  tmp_sdf_points.size());

    for (const Eigen::Vector3f &sdf_point : tmp_sdf_points) {
        kd_tree.knnSearch(sdf_point.data(), num_votes, closest_indices.data(),
                          closest_distances.data());

        float sdf = 0.0f;
        int num_positive = 0;
        for (int k = 0; k < num_votes; k++) {
            const int point_index = closest_indices[k];
            const Eigen::Vector3f ray_vec =
                sdf_point - visible_surface_points[point_index];
            const float q = visible_surface_normals[point_index].dot(ray_vec);
            if (q > 0) {
                num_positive++;
            }
            // Magnitude comes from the nearest neighbour only: near the
            // surface use the point-to-plane distance.
            if (k == 0) {
                const float ray_vec_length = ray_vec.norm();
                sdf = ray_vec_length < stddev ? std::fabs(q) : ray_vec_length;
            }
        }

        // `sdf` is non-negative here. It is negated only when no neighbour
        // sees the sample in front of its surface; split votes stay positive.
        sampled_sdf_points.push_back(sdf_point);
        sampled_sdf_distances.push_back(num_positive == 0 ? -sdf : sdf);
    }
}

int main(int argc, char **argv) {
    std::string mesh_path;
    std::string npz_path;
    float variance = 0.005;
    int num_sample = 500000;
    size_t image_width = 400;
    size_t image_height = 400;
    int num_viewpoints = 100;
    bool test_mode = false;

    CLI::App app{"sample_sdf"};
    app.add_option("--input-mesh-path", mesh_path)->required();
    app.add_option("--output-npz-path", npz_path)->required();
    app.add_option("--num-samples", num_sample);
    app.add_option("--variance", variance);
    app.add_option("--image-width", image_width);
    app.add_option("--image-height", image_height);
    app.add_option("--num-viewpoints", num_viewpoints);
    app.add_flag("--test", test_mode);
    CLI11_PARSE(app, argc, argv);

    std::cout << mesh_path << " -> " << npz_path << std::endl;

    float second_variance = variance / 10;
    if (test_mode) {
        variance = 0.05;
        second_variance = variance / 100;
        num_sample /= 2;
    }
    std::cout << "variance: " << variance << " and " << second_variance
              << std::endl;

    glPixelStorei(GL_UNPACK_ALIGNMENT, 1);
    glPixelStorei(GL_UNPACK_ROW_LENGTH, 0);
    glPixelStorei(GL_UNPACK_SKIP_PIXELS, 0);
    glPixelStorei(GL_UNPACK_SKIP_ROWS, 0);

    pangolin::Geometry geometry = pangolin::LoadGeometry(mesh_path);
    std::cout << geometry.objects.size() << " objects" << std::endl;
    geometry.textures.clear();

    int total_num_faces = 0;
    for (const auto &object : geometry.objects) {
        auto vertext_indices = object.second.attributes.find("vertex_indices");
        if (vertext_indices != object.second.attributes.end()) {
            pangolin::Image<uint32_t> ibo =
                pangolin::get<pangolin::Image<uint32_t>>(
                    vertext_indices->second);

            total_num_faces += ibo.h;
        }
    }
    std::cout << total_num_faces << " faces" << std::endl;

    pangolin::ManagedImage<uint8_t> new_buffer(3 * sizeof(uint32_t),
                                               total_num_faces);
    pangolin::Image<uint32_t> new_ibo =
        new_buffer.UnsafeReinterpret<uint32_t>().SubImage(0, 0, 3,
                                                          total_num_faces);
    int new_index = 0;
    for (const auto &object : geometry.objects) {
        auto vertext_indices = object.second.attributes.find("vertex_indices");
        if (vertext_indices != object.second.attributes.end()) {
            pangolin::Image<uint32_t> ibo =
                pangolin::get<pangolin::Image<uint32_t>>(
                    vertext_indices->second);
            for (int k = 0; k < ibo.h; ++k) {
                new_ibo.Row(new_index).CopyFrom(ibo.Row(k));
                new_index++;
            }
        }
    }
    geometry.objects.clear();

    auto faces = geometry.objects.emplace(std::string("mesh"),
                                          pangolin::Geometry::Element());
    faces->second.Reinitialise(3 * sizeof(uint32_t), total_num_faces);
    faces->second.CopyFrom(new_buffer);
    new_ibo = faces->second.UnsafeReinterpret<uint32_t>().SubImage(
        0, 0, 3, total_num_faces);
    faces->second.attributes["vertex_indices"] = new_ibo;

    auto normalization_params = normalize_geometry(geometry, true);
    float max_distance = normalization_params.max_distance;
    std::cout << "center: " << normalization_params.center[0] << ", "
              << normalization_params.center[1] << ", "
              << normalization_params.center[2] << std::endl;
    std::cout << "max_distance: " << max_distance << std::endl;

    pangolin::CreateWindowAndBind("Main", 1, 1);
    glEnable(GL_DEPTH_TEST);
    glDisable(GL_DITHER);
    glDisable(GL_POINT_SMOOTH);
    glDisable(GL_LINE_SMOOTH);
    glDisable(GL_POLYGON_SMOOTH);
    glHint(GL_POINT_SMOOTH, GL_DONT_CARE);
    glHint(GL_LINE_SMOOTH, GL_DONT_CARE);
    glHint(GL_POLYGON_SMOOTH_HINT, GL_DONT_CARE);
    glDisable(GL_MULTISAMPLE_ARB);
    glShadeModel(GL_FLAT);

    // Define Projection and initial ModelView matrix
    float camera_distanmce = 1;
    pangolin::OpenGlRenderState camera(
        pangolin::ProjectionMatrixOrthographic(
            -camera_distanmce, camera_distanmce, camera_distanmce,
            -camera_distanmce, 0, 2.5),
        pangolin::ModelViewLookAt(0, 0, -1, 0, 0, 0, pangolin::AxisY));
    pangolin::GlGeometry gl_geometry = pangolin::ToGlGeometry(geometry);
    pangolin::GlSlProgram program = GetShaderProgram();

    pangolin::GlRenderBuffer z_buffer(image_width, image_height,
                                      GL_DEPTH_COMPONENT32);
    pangolin::GlTexture texture_normals(image_width, image_height, GL_RGBA32F);
    pangolin::GlTexture texture_vertices(image_width, image_height, GL_RGBA32F);
    pangolin::GlFramebuffer frame_buffer(texture_normals, texture_vertices,
                                         z_buffer);

    SurfaceRenderer renderer(geometry, image_width, image_height,
                             camera_distanmce);
    SurfaceRendererState state(total_num_faces, num_viewpoints,
                               camera_distanmce * 1.1);
    renderer.render(state, program, texture_normals, texture_vertices,
                    frame_buffer);

    std::vector<bool> ignore_face(total_num_faces);
    for (size_t k = 0; k < total_num_faces; k++) {
        ignore_face[k] = true;
    }
    int num_unobserved_faces = 0;
    for (unsigned int face_index = 0;
         face_index < state.invisible_face_test.size(); face_index++) {
        if (state.invisible_face_test[face_index] == true) {
            num_unobserved_faces++;
        } else {
            ignore_face[face_index] = false;
        }
    }
    std::vector<float> face_area = compute_face_area(geometry, ignore_face);

    int num_ignored_backfaces = 0;
    int num_observed_backfaces = 0;
    float pixel_area = 1.0f / image_width / image_height;
    for (unsigned int face_index = 0;
         face_index < state.face_normal_test.size(); face_index++) {
        if (state.face_observation_count[face_index] == 0) {
            continue;
        }
        int count = state.face_observation_count[face_index];
        float observed_area = pixel_area * count;
        float ratio = observed_area / face_area[face_index];
        if (ratio < 0.05) {
            ignore_face[face_index] = true;
            num_ignored_backfaces++;
        }
        if (state.observed_backface_test[face_index] == true) {
            num_observed_backfaces++;
        }
    }

    std::vector<Eigen::Vector3f> face_normals(total_num_faces);
    for (unsigned int k = 0; k < state.face_normal_test.size(); k++) {
        if (ignore_face[k] == false) {
            const Eigen::Vector3f normal = state.face_normal_test[k].head<3>();
            face_normals[k] = normal;
        }
    }
    std::cout << "# faces: " << total_num_faces
              << " # unobserved faces: " << num_unobserved_faces
              << " # ignored backfaces: " << num_ignored_backfaces
              << " # observed backfaces: " << num_observed_backfaces
              << std::endl;

    std::vector<Eigen::Vector3f> visible_surface_points;
    std::vector<Eigen::Vector3f> visible_surface_normals;
    std::vector<int> visible_surface_point_face_indices;
    for (unsigned int n = 0; n < state.observed_points.size(); n++) {
        const auto face_index =
            static_cast<std::size_t>(state.observed_normals[n][3] + 0.01f) - 1;
        if (ignore_face[face_index] == false) {
            Eigen::Vector3f vertex = state.observed_points[n].head<3>();
            if (std::isnan(vertex[0])) {
                continue;
            }
            if (std::isnan(vertex[1])) {
                continue;
            }
            if (std::isnan(vertex[2])) {
                continue;
            }
            Eigen::Vector3f normal = state.observed_normals[n].head<3>();
            if (std::isnan(normal[0])) {
                continue;
            }
            if (std::isnan(normal[1])) {
                continue;
            }
            if (std::isnan(normal[2])) {
                continue;
            }
            visible_surface_points.push_back(
                state.observed_points[n].head<3>());
            visible_surface_normals.push_back(normal);
            visible_surface_point_face_indices.push_back(face_index);
        }
    }
    std::cout << "# visible points: " << visible_surface_points.size()
              << std::endl;

    KdVertexList kdVertex(visible_surface_points);
    KdVertexListTree kd_tree(3, kdVertex);
    kd_tree.buildIndex();

    int num_near_surface_samples = (int)(47 * num_sample / 50);
    int num_surface_samples = num_near_surface_samples /
                              2;  // we sample twice with different variances
    // std::vector<Eigen::Vector3f> sampled_visible_surface_normals;
    // std::vector<Eigen::Vector3f> target_surface_points;
    // sample_from_surface(geometry, ignore_face, face_normals,
    //                     sampled_visible_surface_normals,
    //                     target_surface_points, num_surface_samples);

    std::vector<Eigen::Vector3f> target_surface_points;
    std::sample(visible_surface_points.begin(), visible_surface_points.end(),
                std::back_inserter(target_surface_points), num_surface_samples,
                std::mt19937{std::random_device{}()});

    std::vector<Eigen::Vector3f> sampled_sdf_points;
    std::vector<float> sampled_sdf_distances;
    sample_sdf_near_surface(
        geometry, kd_tree, visible_surface_points, visible_surface_normals,
        visible_surface_point_face_indices, target_surface_points,
        sampled_sdf_points, sampled_sdf_distances,
        num_sample - num_near_surface_samples, variance, second_variance, 2,
        10);

    std::cout << "# sdf samples: " << sampled_sdf_points.size() << std::endl;

    write_to_npz(npz_path, sampled_sdf_points, sampled_sdf_distances,
                 normalization_params);
    return 0;
}
