// langevin-gpu/src/langevin_kernel.cu
#include "langevin_kernel.h"
#include <curand_kernel.h>

// NOTE(review): TOL is not referenced by any code visible in this file;
// presumably a leftover numerical tolerance -- confirm before removing.
#define TOL 1e-10f
// Upper bound on the number of sub-step iterations the simulation kernels run
// inside a single time step before giving up on consuming the remaining dt.
#define MAX_ITERATIONS_PER_STEP 100
// Number of base_dt time steps each simulation kernel launch advances.
#define STEPS_PER_KERNEL 1000
// Values for the potential_type argument. compute_drift only tests for
// POTENTIAL_QUADRATIC; every other value takes the constant-drift path.
#define POTENTIAL_LINEAR 0
#define POTENTIAL_QUADRATIC 1

extern "C" {
// Drift term for a particle on the given edge.
//   potential_type : POTENTIAL_QUADRATIC selects a drift proportional to the
//                    position; any other value yields a constant drift.
//   edge_index     : index into drift_coeffs (not range-checked here).
//   position       : current coordinate of the particle along the edge.
//   drift_coeffs   : one coefficient per edge.
// Returns drift_coeffs[edge_index] * position for the quadratic potential and
// drift_coeffs[edge_index] otherwise.
__device__ float compute_drift(const int potential_type, const int edge_index,
                               const float position,
                               const float *drift_coeffs) {
  const float edge_coeff = drift_coeffs[edge_index];
  return (potential_type == POTENTIAL_QUADRATIC) ? edge_coeff * position
                                                 : edge_coeff;
}

// Returns one root of A*s^2 + B*s + C = 0, evaluated with the cancellation-free
// form of the quadratic formula.
//
// The root returned is always the algebraic "minus" root
// (-B - sqrt(D)) / (2A); for B <= 0 it is computed through the equivalent
// expression 2C / (-B + sqrt(D)) so that two nearly equal numbers are never
// subtracted. A negative discriminant is clamped to zero.
//
// Degenerate inputs never produce a division by zero:
//   * A == 0, B != 0 : the linear solution -C / B.
//   * A == 0, B == 0 : no equation is left to solve; returns 0.
//   * B == 0 with a zero (or clamped) discriminant : the double root
//     -B / (2A), which is 0.
__device__ float solve_quadratic(float A, float B, float C) {
  if (A == 0.0f) {
    // Linear case; guard against B == 0 to avoid returning NaN/Inf.
    return (B == 0.0f) ? 0.0f : -C / B;
  }

  const float root_disc = sqrtf(fmaxf(B * B - 4.0f * A * C, 0.0f));
  if (B > 0.0f) {
    // -B and -root_disc have the same sign: no cancellation.
    return (-B - root_disc) / (2.0f * A);
  }

  // B <= 0: -B and root_disc are both non-negative, so the sum is only zero
  // when B == 0 and the discriminant is zero (or was clamped to zero).
  const float denom = -B + root_disc;
  if (denom == 0.0f) {
    return 0.0f;
  }
  return (2.0f * C) / denom;
}

// Advances each particle by STEPS_PER_KERNEL time steps of size base_dt.
//
// Launch layout: 1D grid of 1D blocks, one thread per particle
// (tid = blockIdx.x * blockDim.x + threadIdx.x; threads with
// tid >= num_particles exit immediately). No shared memory is used.
//
// Per-particle state (read at entry, written back at exit):
//   edges[tid]            current edge index
//   positions[tid]        coordinate along that edge
//   bounces[tid]          incremented on every crossing of position 0
//   bounce_instances[tid] incremented on a crossing only when the particle was
//                         not already sitting exactly at 0
//   states[tid]           cuRAND generator state
//
// Read-only inputs:
//   edge_lengths[e]  length of edge e (a non-positive value is reported and
//                    replaced by 1 for that iteration)
//   jump_weights[e]  thresholds compared against a uniform draw to pick the
//                    next edge; the first e with draw <= jump_weights[e] wins,
//                    otherwise the last edge
//                    (presumably cumulative probabilities -- confirm with caller)
//   drift_coeffs[e]  per-edge drift coefficient, see compute_drift
//
// Each time step proposes x + dt*drift + sigma*sqrt(dt)*w with w ~ N(0,1).
// A proposal inside (0, length] is accepted. A proposal beyond the far end is
// mirrored back into the edge. A proposal at or below 0 is treated as a
// bounce: the fraction of dt spent reaching 0 is removed, a new edge is drawn
// and the particle continues from 0 with the remaining dt, for at most
// MAX_ITERATIONS_PER_STEP iterations per step.
__global__ void
langevin_multi_step_kernel(int *edges, float *positions, int *bounces,
                           int *bounce_instances, const float *edge_lengths,
                           const float *jump_weights, const float *drift_coeffs,
                           const int potential_type, const float base_dt,
                           const float sigma, const int num_edges,
                           const int num_particles, curandState *states) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= num_particles)
    return;

  // Work on local copies of the particle state; stored back once at the end.
  int edge = edges[tid];
  float x = positions[tid];
  int bounce_count = bounces[tid];
  int bounce_instance = bounce_instances[tid];
  curandState local_state = states[tid];

  if (edge < 0 || edge >= num_edges) {
    printf("Invalid initial edge %d for particle %d\n", edge, tid);
    edge = 0;
  }

  for (int step = 0; step < STEPS_PER_KERNEL; ++step) {
    float dt = base_dt;
    int iterations = 0;

    // Keep going until the whole time step is consumed or the iteration cap
    // is reached.
    while (dt > 0.0f && iterations++ < MAX_ITERATIONS_PER_STEP) {
      float w = curand_normal(&local_state);

      float drift = compute_drift(potential_type, edge, x, drift_coeffs);
      float sqrt_dt = sqrtf(dt);
      float x_next = x + dt * drift + sigma * sqrt_dt * w;
      float current_length = edge_lengths[edge];

      if (current_length <= 0.0f) {
        printf("Invalid edge length %f for edge %d\n", current_length, edge);
        current_length = 1.0f;
      }

      if (x_next > 0.0f && x_next <= current_length) {
        // Proposal stays inside the edge: accept it.
        x = x_next;
        dt = 0.0f;
      } else if (x_next > current_length) {
        // Overshoot past the far end: mirror back into the edge (not counted
        // as a bounce). A jump longer than the whole edge would mirror to a
        // negative coordinate, so clamp at 0 to keep x inside the edge.
        x = fmaxf(2.0f * current_length - x_next, 0.0f);
        dt = 0.0f;
      } else {
        // x_next <= 0: the particle reached position 0 during this sub-step.
        if (x != 0.0f) {
          // Only count a new instance when the particle arrives from the
          // interior, not when it is bounced again straight from 0.
          bounce_instance++;
        }
        bounce_count++;

        // Solve x + a*s^2 + b*s = 0 for s = sqrt(fraction of dt elapsed when
        // position 0 is reached).
        float a = drift * dt;
        float b = sigma * sqrt_dt * w;
        float sqrt_alpha = solve_quadratic(a, b, x);
        float alpha = sqrt_alpha * sqrt_alpha;

        // Remaining time after reaching 0.
        dt *= (1.0f - alpha);

        // Draw the next edge: first index whose threshold is not exceeded by
        // the uniform draw, falling back to the last edge.
        float rand_val = curand_uniform(&local_state);
        int new_edge = 0;
        while (new_edge < num_edges - 1 && rand_val > jump_weights[new_edge]) {
          new_edge++;
        }

        if (new_edge < 0 || new_edge >= num_edges) {
          printf("Invalid new_edge %d, clamping to 0\n", new_edge);
          new_edge = 0;
        }

        edge = new_edge;
        x = 0.0f;
      }
    }
  }

  edges[tid] = edge;
  positions[tid] = x;
  bounces[tid] = bounce_count;
  bounce_instances[tid] = bounce_instance;
  states[tid] = local_state;
}

// Advances each particle by STEPS_PER_KERNEL time steps of size base_dt on a
// graph whose edges have two end vertices.
//
// Launch layout: 1D grid of 1D blocks, one thread per particle
// (tid = blockIdx.x * blockDim.x + threadIdx.x; threads with
// tid >= num_particles exit immediately). No shared memory is used.
//
// Per-particle state (read at entry, written back at exit):
//   edges[tid], positions[tid]  current edge and coordinate along it
//   bounces[tid]                incremented on every vertex hit
//   bounce_instances[tid]       incremented on a vertex hit only when the
//                               particle was not already exactly at 0 or at
//                               the edge length
//   states[tid]                 cuRAND generator state
//
// Read-only inputs:
//   edge_lengths[e]     length of edge e (a non-positive value is reported and
//                       replaced by 1 for that iteration)
//   drift_coeffs[e]     per-edge drift coefficient, see compute_drift
//   edge_vertices[e]    .x = vertex at coordinate 0, .y = vertex at the far end
//   vertex_edge_offsets entries [v] .. [v + 1] delimit vertex v's slice of the
//                       three vertex_edge_* arrays below
//   vertex_edge_indices      edge reached through each slice entry
//   vertex_edge_orientations 0 places the particle at coordinate 0 of the new
//                            edge, anything else at its far end
//   vertex_edge_cumweights   thresholds compared against a uniform draw; the
//                            first entry with draw <= threshold is taken,
//                            otherwise the last entry of the slice
//                            (presumably cumulative probabilities -- confirm)
//
// A proposal inside (0, length] is accepted. Otherwise the particle hit a
// vertex: the fraction of dt needed to reach that vertex is removed, an
// outgoing edge is drawn and the particle continues from the matching end of
// the new edge with the remaining dt, for at most MAX_ITERATIONS_PER_STEP
// iterations per step. A vertex without outgoing entries (or with a negative
// id) leaves the particle on the boundary of its current edge.
__global__ void langevin_multi_step_graph_kernel(
    int *edges, float *positions, int *bounces, int *bounce_instances,
    const float *edge_lengths, const float *drift_coeffs,
    const int2 *edge_vertices, const int *vertex_edge_offsets,
    const int *vertex_edge_indices, const int *vertex_edge_orientations,
    const float *vertex_edge_cumweights, const int potential_type,
    const float base_dt, const float sigma, const int num_edges,
    const int num_particles, curandState *states) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= num_particles)
    return;

  // Work on local copies of the particle state; stored back once at the end.
  int edge = edges[tid];
  float x = positions[tid];
  int bounce_count = bounces[tid];
  int bounce_instance = bounce_instances[tid];
  curandState local_state = states[tid];

  if (edge < 0 || edge >= num_edges) {
    printf("Invalid initial edge %d for particle %d\n", edge, tid);
    edge = 0;
  }

  for (int step = 0; step < STEPS_PER_KERNEL; ++step) {
    float dt = base_dt;
    int iterations = 0;

    while (dt > 0.0f && iterations++ < MAX_ITERATIONS_PER_STEP) {
      float w = curand_normal(&local_state);

      float drift = compute_drift(potential_type, edge, x, drift_coeffs);
      float sqrt_dt = sqrtf(dt);
      float x_next = x + dt * drift + sigma * sqrt_dt * w;
      float current_length = edge_lengths[edge];

      if (current_length <= 0.0f) {
        printf("Invalid edge length %f for edge %d\n", current_length, edge);
        current_length = 1.0f;
      }

      if (x_next > 0.0f && x_next <= current_length) {
        // Proposal stays inside the edge: accept it.
        x = x_next;
        dt = 0.0f;
      } else {
        // The proposal left the edge through one of its two end vertices.
        const bool hit_start = x_next <= 0.0f;
        const int bounce_vertex =
            hit_start ? edge_vertices[edge].x : edge_vertices[edge].y;

        if (x != 0.0f && x != current_length) {
          // Arrived from the interior rather than being re-bounced while
          // already sitting on a vertex.
          bounce_instance++;
        }
        bounce_count++;

        // Fraction of dt elapsed when the vertex is reached, with
        // s = sqrt(fraction):
        //   start vertex: x + a*s^2 + b*s = 0
        //   end vertex  : x + a*s^2 + b*s = L  <=>  -a*s^2 - b*s + (L - x) = 0
        // Both are passed in the form A*s^2 + B*s + gap = 0 with gap >= 0 and
        // a negative value at s = 1, for which the root returned by
        // solve_quadratic is the first crossing in [0, 1).
        const float a = drift * dt;
        const float b = sigma * sqrt_dt * w;
        const float toward = hit_start ? 1.0f : -1.0f;
        const float gap = hit_start ? x : current_length - x;
        const float sqrt_alpha = solve_quadratic(toward * a, toward * b, gap);
        const float alpha = sqrt_alpha * sqrt_alpha;
        dt *= (1.0f - alpha);

        // Slice of outgoing entries for the vertex. A negative vertex id
        // cannot index vertex_edge_offsets, so it is handled like a vertex
        // with no outgoing entries.
        int start = 0;
        int end = 0;
        if (bounce_vertex >= 0) {
          start = vertex_edge_offsets[bounce_vertex];
          end = vertex_edge_offsets[bounce_vertex + 1];
        }
        const int degree = end - start;

        if (degree > 0) {
          // First entry whose threshold is not exceeded by the draw; the last
          // entry is the fallback.
          float rand_val = curand_uniform(&local_state);
          int choice = end - 1;
          for (int idx = start; idx < end; ++idx) {
            if (rand_val <= vertex_edge_cumweights[idx]) {
              choice = idx;
              break;
            }
          }
          int new_edge = vertex_edge_indices[choice];
          int orientation = vertex_edge_orientations[choice];
          if (new_edge < 0 || new_edge >= num_edges) {
            // Corrupt table entry: stay on the current edge.
            new_edge = edge;
          }
          edge = new_edge;
          float new_length = edge_lengths[new_edge];
          x = (orientation == 0) ? 0.0f : new_length;
        } else {
          // Dead end: remain on the boundary that was hit.
          x = hit_start ? 0.0f : current_length;
        }
      }
    }
  }

  edges[tid] = edge;
  positions[tid] = x;
  bounces[tid] = bounce_count;
  bounce_instances[tid] = bounce_instance;
  states[tid] = local_state;
}

// Initialises one cuRAND state per particle.
// Launch layout: 1D grid of 1D blocks, one thread per particle; surplus
// threads do nothing. Thread i seeds states[i] with (seed + i), using
// subsequence 0 and offset 0.
__global__ void setup_kernel(curandState *states, unsigned long long seed,
                             int num_particles) {
  const int particle = blockIdx.x * blockDim.x + threadIdx.x;
  if (particle < num_particles) {
    curand_init(seed + particle, 0, 0, states + particle);
  }
}

// New histogram kernel
// Adds each particle to the histogram bin that contains its position, for
// edges whose bins may have different lengths.
//
// Launch layout: 1D grid of 1D blocks, one thread per particle; surplus
// threads do nothing.
//
//   edges / positions  per-particle edge index and coordinate
//   edge_lengths[e]    particles with a coordinate below 0 or above this value
//                      are skipped
//   bin_offsets        entries [e] .. [e + 1] delimit edge e's bins inside
//                      bin_lengths / histograms; an empty range skips the
//                      particle
//   bin_lengths[i]     length of bin i; non-positive entries are ignored
//   histograms[i]      counter incremented atomically
//
// Particles with an edge index outside [0, num_edges) are skipped. A position
// beyond the summed bin lengths lands in the last positive-length bin; if no
// bin has positive length the first bin of the edge is used.
__global__ void compute_histogram_kernel(const int *edges,
                                         const float *positions,
                                         const float *edge_lengths,
                                         const int *bin_offsets,
                                         const float *bin_lengths,
                                         int *histograms,
                                         int num_edges, int num_particles) {
  const int particle = blockIdx.x * blockDim.x + threadIdx.x;
  if (particle >= num_particles)
    return;

  const int edge_id = edges[particle];
  const float coord = positions[particle];
  if (edge_id < 0 || edge_id >= num_edges)
    return;

  const float edge_len = edge_lengths[edge_id];
  if (coord < 0.0f || coord > edge_len)
    return;

  const int first_bin = bin_offsets[edge_id];
  const int past_last_bin = bin_offsets[edge_id + 1];
  if (first_bin >= past_last_bin)
    return;

  // Walk the bins, tracking the running upper boundary; stop at the first
  // positive-length bin whose upper boundary is not below the coordinate.
  int target = first_bin;
  float upper = 0.0f;
  bool placed = false;
  for (int b = first_bin; b < past_last_bin && !placed; ++b) {
    const float bin_len = bin_lengths[b];
    if (!(bin_len <= 0.0f)) {
      upper += bin_len;
      target = b;
      placed = (coord <= upper);
    }
  }

  atomicAdd(histograms + target, 1);
}

// Histogram when each edge has uniform bin width.
// Adds each particle to a histogram bin when all bins of an edge share one
// width.
//
// Launch layout: 1D grid of 1D blocks, one thread per particle; surplus
// threads do nothing.
//
//   edges / positions  per-particle edge index and coordinate
//   bin_offsets[e]     index of edge e's first bin inside histograms
//   bin_counts[e]      number of bins on edge e
//   bin_widths[e]      width of every bin on edge e
//   histograms[i]      counter incremented atomically
//
// Particles are skipped when the edge index is outside [0, num_edges) or the
// edge has a non-positive bin count or bin width. The local bin index is the
// truncated quotient position / width, clamped into [0, count - 1].
__global__ void compute_histogram_uniform_kernel(const int *edges,
                                                 const float *positions,
                                                 const int *bin_offsets,
                                                 const int *bin_counts,
                                                 const float *bin_widths,
                                                 int *histograms,
                                                 int num_edges,
                                                 int num_particles) {
  const int particle = blockIdx.x * blockDim.x + threadIdx.x;
  if (particle >= num_particles)
    return;

  const int edge_id = edges[particle];
  const bool edge_ok = (edge_id >= 0) && (edge_id < num_edges);
  if (!edge_ok)
    return;

  const int bins_on_edge = bin_counts[edge_id];
  if (bins_on_edge <= 0)
    return;

  const float bin_width = bin_widths[edge_id];
  if (bin_width <= 0.0f)
    return;

  // Truncate to a local bin, then clamp the upper side before the lower side.
  int slot = static_cast<int>(positions[particle] / bin_width);
  slot = (slot >= bins_on_edge) ? bins_on_edge - 1 : slot;
  slot = (slot < 0) ? 0 : slot;

  atomicAdd(histograms + bin_offsets[edge_id] + slot, 1);
}
}
