#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <tuple>
#include <type_traits>
#include <vector>

extern "C" {
#include <ccd/ccd.h>
#include <ccd/vec3.h>
}

#include <nanobind/nanobind.h>
#include <nanobind/ndarray.h>
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"

namespace nb = nanobind;
namespace ffi = xla::ffi;

// A point cloud handed to libccd as an opaque shape object.
// Only the vertices are stored; supportPoints() scans them linearly to answer
// support queries. Nothing in this file deduplicates the points or verifies
// that they describe a convex shape.
struct ConvexMesh {
  std::vector<ccd_vec3_t> points;  // Vertices, one ccd_vec3_t per 3D point.
};

// Support function for ConvexMesh.
// It returns (via p) the point in the mesh with the maximum dot product in the given direction.
static void supportPoints(const void* obj, const ccd_vec3_t* dir, ccd_vec3_t* p) {
  const ConvexMesh* mesh = static_cast<const ConvexMesh*>(obj);
  ccd_real_t bestDot = -CCD_REAL_MAX;
  bool found = false;
  ccd_vec3_t best;
  for (size_t i = 0; i < mesh->points.size(); ++i) {
    const ccd_vec3_t &v = mesh->points[i];
    ccd_real_t d = ccdVec3Dot(&v, dir);
    if (!found || d > bestDot) {
      bestDot = d;
      best = v;
      found = true;
    }
  }
  if (found)
    ccdVec3Copy(p, &best);
}

// Helper: summarise a buffer's shape as
//   (total element count, second-to-last dimension, last dimension).
// A buffer with fewer than two dimensions has no second-to-last extent, so it
// yields (0, 0, 0) instead of indexing before the start of the dimension array.
std::tuple<int64_t, int64_t, int64_t> GetDims(const ffi::AnyBuffer buffer) {
  const ffi::AnyBuffer::Dimensions dims = buffer.dimensions();
  if (dims.size() < 2) {
    return std::make_tuple(0, 0, 0);
  }
  return std::make_tuple(buffer.element_count(), dims.end()[-2], dims.end()[-1]);
}

// FFI handler computing penetration data for a batch of convex point-cloud pairs.
//
// Args:
//   mesh1, mesh2: buffers of shape (..., n, 3) storing ccd_real_t coordinates. All
//     leading dimensions are flattened into a single batch dimension, which must be
//     the same for both meshes; the point counts n may differ between them.
//   depth_out: receives one penetration depth per batch element.
//   penetration_dir_out: receives one 3-vector penetration direction per batch element.
//   contact_point_out: receives one 3-vector contact position per batch element.
//
// Batch elements for which ccdGJKPenetration returns a non-zero code are written as
// all zeros (depth, direction and contact point).
//
// Returns InvalidArgument when an input has fewer than two dimensions, a last
// dimension other than 3, no points, a mismatching batch size, an element type other
// than ccd_real_t, or when an output buffer is too small; Success otherwise.
ffi::Error PenetrationDispatch(ffi::AnyBuffer mesh1, ffi::AnyBuffer mesh2,
                               ffi::Result<ffi::AnyBuffer> depth_out,
                               ffi::Result<ffi::AnyBuffer> penetration_dir_out,
                               ffi::Result<ffi::AnyBuffer> contact_point_out) {
  // Reject low-rank inputs up front: the (n, 3) trailing shape is read below.
  if (mesh1.dimensions().size() < 2 || mesh2.dimensions().size() < 2) {
    return ffi::Error::InvalidArgument("Each input mesh must have at least two dimensions (..., n, 3).");
  }

  // Verify that each mesh has points of dimension 3.
  const auto [total1, n1, lastDim1] = GetDims(mesh1);
  const auto [total2, n2, lastDim2] = GetDims(mesh2);
  if (lastDim1 != 3 || lastDim2 != 3) {
    return ffi::Error::InvalidArgument("Each input mesh must have last dimension of size 3 (3D points).");
  }
  // An empty point set would make the batch-size division below divide by zero.
  if (n1 <= 0 || n2 <= 0) {
    return ffi::Error::InvalidArgument("Each input mesh must contain at least one point.");
  }

  // The raw data is reinterpreted as ccd_real_t, so every buffer must really hold
  // the floating-point type libccd was compiled with.
  constexpr ffi::DataType kRealType =
      std::is_same_v<ccd_real_t, double> ? ffi::DataType::F64 : ffi::DataType::F32;
  if (mesh1.element_type() != kRealType || mesh2.element_type() != kRealType ||
      depth_out->element_type() != kRealType ||
      penetration_dir_out->element_type() != kRealType ||
      contact_point_out->element_type() != kRealType) {
    return ffi::Error::InvalidArgument("All buffers must use the floating-point type libccd was built with.");
  }

  const size_t points1 = static_cast<size_t>(n1);
  const size_t points2 = static_cast<size_t>(n2);
  const size_t batch_size = static_cast<size_t>(total1) / (points1 * 3);
  const size_t batch_size_2 = static_cast<size_t>(total2) / (points2 * 3);
  if (batch_size != batch_size_2) {
    return ffi::Error::InvalidArgument("Input meshes must have the same batch size.");
  }

  // Guard the indexed writes at the end of each batch iteration.
  if (depth_out->element_count() < batch_size ||
      penetration_dir_out->element_count() < 3 * batch_size ||
      contact_point_out->element_count() < 3 * batch_size) {
    return ffi::Error::InvalidArgument("Output buffers are too small for the batch size.");
  }

  const ccd_real_t* data1 = mesh1.typed_data<ccd_real_t>();
  const ccd_real_t* data2 = mesh2.typed_data<ccd_real_t>();
  ccd_real_t* depth_ptr = depth_out->typed_data<ccd_real_t>();
  ccd_real_t* pd_ptr = penetration_dir_out->typed_data<ccd_real_t>();
  ccd_real_t* cp_ptr = contact_point_out->typed_data<ccd_real_t>();

  // The solver configuration is identical for every batch element, so build it once.
  // CCD_INIT fills in libccd's default settings; the fields below override them.
  ccd_t ccd;
  CCD_INIT(&ccd);
  ccd.max_iterations = 100;
  ccd.epa_tolerance  = 0.0001;
  ccd.support1 = supportPoints;
  ccd.support2 = supportPoints;

  // Point storage is allocated once and overwritten for each batch element.
  ConvexMesh m1, m2;
  m1.points.resize(points1);
  m2.points.resize(points2);

  // Copies dst.size() packed xyz triples from src into dst.
  const auto load_points = [](const ccd_real_t* src, std::vector<ccd_vec3_t>& dst) {
    for (size_t i = 0; i < dst.size(); ++i) {
      dst[i].v[0] = src[3 * i + 0];
      dst[i].v[1] = src[3 * i + 1];
      dst[i].v[2] = src[3 * i + 2];
    }
  };

  for (size_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
    load_points(data1 + batch_idx * points1 * 3, m1.points);
    load_points(data2 + batch_idx * points2 * 3, m2.points);

    // Variables to hold the penetration results.
    ccd_real_t depth = 0;
    ccd_vec3_t penetrationDir;
    ccd_vec3_t contactPoint;

    // NOTE(review): every non-zero return code is reported as "no penetration",
    // matching the previous behaviour; confirm whether failure codes other than
    // "no intersection" should surface as an error instead.
    const int res = ccdGJKPenetration(&m1, &m2, &ccd, &depth, &penetrationDir, &contactPoint);
    const bool hit = (res == 0);

    // Write results into the provided output buffers; the result vectors are only
    // read when the query succeeded.
    depth_ptr[batch_idx] = hit ? depth : ccd_real_t(0);
    for (size_t k = 0; k < 3; ++k) {
      pd_ptr[batch_idx * 3 + k] = hit ? penetrationDir.v[k] : ccd_real_t(0);
      cp_ptr[batch_idx * 3 + k] = hit ? contactPoint.v[k] : ccd_real_t(0);
    }
  }

  return ffi::Error::Success();
}

// Defines the FFI handler symbol `ComputePenetration`, which forwards to
// PenetrationDispatch. The binding declares two buffer arguments followed by three
// buffer results, in the same order as PenetrationDispatch's parameters. The symbol
// is what the nanobind module below hands out under "compute_penetration".
XLA_FFI_DEFINE_HANDLER_SYMBOL(ComputePenetration, PenetrationDispatch,
  ffi::Ffi::Bind()
      .Arg<ffi::AnyBuffer>()   // mesh1: first point cloud
      .Arg<ffi::AnyBuffer>()   // mesh2: second point cloud
      .Ret<ffi::AnyBuffer>()   // depth_out: penetration depth
      .Ret<ffi::AnyBuffer>()   // penetration_dir_out: penetration direction
      .Ret<ffi::AnyBuffer>()   // contact_point_out: contact point
);

// Wraps a raw XLA FFI handler function pointer in a nanobind capsule.
// Compilation fails unless T can be invoked with an XLA_FFI_CallFrame* and yields
// something convertible to XLA_FFI_Error*.
template <typename T>
nb::capsule EncapsulateFfiHandler(T *fn) {
  constexpr bool kIsFfiHandler =
      std::is_invocable_r_v<XLA_FFI_Error *, T, XLA_FFI_CallFrame *>;
  static_assert(kIsFfiHandler, "Encapsulated function must be and XLA FFI handler");

  // The capsule stores the handler as an untyped pointer.
  void *opaque = reinterpret_cast<void *>(fn);
  return nb::capsule(opaque);
}
  
// Python extension module `_gjk_epa_module`.
// Exposes a single function, registrations(), returning a dict that maps each FFI
// target name to a capsule wrapping the corresponding handler symbol.
NB_MODULE(_gjk_epa_module, m) {
  m.def("registrations", []() {
    nb::dict handlers;
    handlers["compute_penetration"] = EncapsulateFfiHandler(ComputePenetration);
    return handlers;
  });
}
