# pylint: disable=g-bare-generic
import math
import torch

# Module-level particle layout: one group holding four particles.
num_groups = 1
num_particles_per_group = 4

def leapfrog_step(x_0, v_0, gradient_target, step_size, mass_diag_sqrt, num_steps: int):
  """ Run `num_steps` leapfrog updates (no metropolis correction).

  Works on detached copies of `x_0` and `v_0`, so the inputs are left untouched.
  `gradient_target` is called on the current position twice per step.
  When `mass_diag_sqrt` is None a unit diagonal mass is used.
  Returns the final `(x, v)` pair as detached tensors.
  """
  position = x_0.detach().clone()
  momentum = v_0.detach().clone()
  if mass_diag_sqrt is None:
    mass_diag_sqrt = torch.ones_like(position)
  mass = mass_diag_sqrt ** 2.
  half_step = 0.5 * step_size

  # Each iteration is kick / drift / kick; consecutive kicks are not merged.
  for _ in range(num_steps):
    momentum.add_(half_step * gradient_target(position))  # first half kick
    position.add_(step_size * momentum / mass)  # full drift
    momentum.add_(half_step * gradient_target(position))  # second half kick

  return position.detach(), momentum.detach()

def _leader_pull(x_k, energy_target, elastic_force, batch_size, num_groups, num_particles_per_group):
  """ Pull of every particle towards the highest-`energy_target` particle of its batch element.

  Rows of `x_k` are read as (batch, group, particle) in row-major order.
  Returns a tensor shaped like `x_k`: `strength * (leader - x_k)`, where
  `strength = elastic_force * (max softmax weight in the group - own softmax weight)`.
  """
  group_shape = (batch_size, num_groups, num_particles_per_group)
  x_grouped = x_k.reshape(*group_shape, -1)  # (batch, groups, particles, features)
  log_energy = energy_target(x_k).reshape(group_shape)

  # Best particle inside each group, then the best group inside each batch element.
  local_rank = log_energy.argmax(dim=2)  # (batch, groups)
  local_best = log_energy.gather(2, local_rank.unsqueeze(2)).squeeze(2)  # (batch, groups)
  global_group = local_best.argmax(dim=1)  # (batch,)
  batch_index = torch.arange(batch_size, device=x_k.device)
  global_rank = local_rank[batch_index, global_group]  # (batch,)
  leader = x_grouped[batch_index, global_group, global_rank]  # (batch, features)
  leader = leader.reshape(batch_size, *x_k.shape[1:])
  # One copy of the leader per particle of the batch element, aligned with the rows of x_k.
  leader = leader.repeat_interleave(num_groups * num_particles_per_group, dim=0)

  # softmax is the stable form of exp(e) / exp(e).sum(): no 0/0 or inf/inf for extreme values.
  weights = torch.softmax(log_energy, dim=2)
  best_weight, _ = weights.max(dim=2, keepdim=True)
  pulling_scale = best_weight - weights  # zero for the best particle of each group
  pulling_strength = (elastic_force * pulling_scale).reshape([x_k.size(0)] + [1] * (x_k.dim() - 1))
  return pulling_strength * (leader - x_k)


def leader_leapfrog_step(x_0, v_0, elastic_force, energy_target, gradient_target, step_size, mass_diag_sqrt, num_steps: int, num_particles):
  """ Multiple leader-pulled leapfrog steps with no metropolis correction.

  Every half step in v adds, besides the usual gradient kick, half of a pull
  towards the particle with the largest `energy_target` value in the same batch
  element. The pull is scaled by `elastic_force` and is not multiplied by
  `step_size`.

  The particle layout is hard-coded to 1 group of 4 particles, so `x_0.size(0)`
  must be a multiple of 4. `num_particles` is accepted but not used.
  `energy_target(x)` must return one value per row of `x`; larger means better.
  Works on detached copies of `x_0` / `v_0` and returns detached `(x, v)`.
  """
  x_k = x_0.detach().clone().data
  v_k = v_0.detach().clone().data
  if mass_diag_sqrt is None:
    mass_diag_sqrt = torch.ones_like(x_k)
  mass_diag = mass_diag_sqrt ** 2.

  num_groups = 1
  num_particles_per_group = 4
  batch_size = x_k.size(0) // (num_groups * num_particles_per_group)

  # Inefficient version - should combine half steps. The targets are evaluated
  # separately for each half step, exactly as many times as before.
  for _ in range(num_steps):
    v_k += 0.5 * _leader_pull(x_k, energy_target, elastic_force, batch_size, num_groups, num_particles_per_group)
    v_k += 0.5 * step_size * gradient_target(x_k)  # half step in v

    x_k += step_size * v_k / mass_diag  # Step in x

    v_k += 0.5 * _leader_pull(x_k, energy_target, elastic_force, batch_size, num_groups, num_particles_per_group)
    v_k += 0.5 * step_size * gradient_target(x_k)  # half step in v

  return x_k.detach().data, v_k.detach().data

def elastic_leapfrog_step(y_0, x_0, v_0, elastic_force, gradient_target, step_size, mass_diag_sqrt, num_steps: int):
  """ Run `num_steps` elastic leapfrog updates (no metropolis correction).

  The force used for each half kick is
  `gradient_target(x) - elastic_force * (x - y_0)`, detached from autograd.
  Works on detached copies of `x_0` and `v_0`; `y_0` is only read.
  When `mass_diag_sqrt` is None a unit diagonal mass is used.
  Returns the final `(x, v)` pair as detached tensors.
  """
  position = x_0.detach().clone()
  momentum = v_0.detach().clone()
  if mass_diag_sqrt is None:
    mass_diag_sqrt = torch.ones_like(position)
  mass = mass_diag_sqrt ** 2.
  half_step = 0.5 * step_size

  def anchored_force(at):
    # Target gradient plus a linear restoring term towards y_0.
    return (gradient_target(at) - elastic_force * (at - y_0)).detach()

  # Each iteration is kick / drift / kick; consecutive kicks are not merged.
  for _ in range(num_steps):
    momentum.add_(half_step * anchored_force(position))  # first half kick
    position.add_(step_size * momentum / mass)  # full drift
    momentum.add_(half_step * anchored_force(position))  # second half kick

  return position.detach(), momentum.detach()