import jax.numpy as jnp
import jax
import itertools
import numpy as np
from configs.datasets_config import get_dataset_info
import copy
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist
import bond_analyze
from qm9.rdkit_functions import BasicMolecularMetrics

"""
QM9 configuration

num_atoms   (batch_size,)
charges     (batch_size, n_nodes, 1)
positions   (batch_size, n_nodes, 3)
one_hot     (batch_size, n_nodes, 5)
atom_mask   (batch_size, n_nodes)
edge_mask   (batch_size * n_nodes * n_nodes, 1)

index       (batch_size,)
A           (batch_size,)
B           (batch_size,)
C           (batch_size,)
mu          (batch_size,)
alpha       (batch_size,)
homo        (batch_size,)
lumo        (batch_size,)
gap         (batch_size,)
r2          (batch_size,)
zpve        (batch_size,)
U0          (batch_size,)
U           (batch_size,)
H           (batch_size,)
G           (batch_size,)
Cv          (batch_size,)
omega1      (batch_size,)
zpve_thermo (batch_size,)
U0_thermo   (batch_size,)
U_thermo    (batch_size,)
H_thermo    (batch_size,)
G_thermo    (batch_size,)
Cv_thermo   (batch_size,)
"""

def remove_mean_with_mask(x, node_mask):
  """
  Subtract the per-sample mean over the node axis (axis 1) from the valid nodes of x.

  Input
    x:         array whose axis 1 indexes nodes; masked-out entries must already be zero
    node_mask: mask broadcastable against x, True / non-zero on valid nodes.
               Boolean, integer and float masks are all accepted.
  Output
    x with the mean over the valid nodes removed; masked-out entries are left untouched.
  Raises
    AssertionError if x carries non-zero values on masked-out nodes.
  """
  # `1 - node_mask` selects the padding for every mask dtype. The previous
  # bitwise `~node_mask` raised on float masks and produced -1 / -2 on integer masks.
  masked_abs_sum = jnp.sum(jnp.abs(x * (1 - node_mask)))
  assert masked_abs_sum < 1e-5, f'Error {masked_abs_sum} too high'
  N = jnp.sum(node_mask, axis=1, keepdims=True) # number of valid nodes per sample

  # Padding is zero (asserted above), so the plain sum only sees valid nodes.
  mean = jnp.sum(x, axis=1, keepdims=True) / N

  # Only shift valid nodes so that the padding stays exactly zero.
  x = x - mean * node_mask
  return x

def sample_center_gravity_zero_gaussian_with_mask(rng, shape, node_mask):
  """
  Draw masked standard-normal noise and remove its mean over the node axis.

  Input
    rng:       JAX PRNG key
    shape:     3-tuple giving the shape of the sample
    node_mask: mask multiplied into the sample before the mean is removed
  Output
    the masked sample after `remove_mean_with_mask`
  """
  assert len(shape) == 3
  # Zero the padding first: remove_mean_with_mask asserts that it is clean.
  gaussian = jax.random.normal(rng, shape) * node_mask
  return remove_mean_with_mask(gaussian, node_mask)

def check_mask_correct(variables, node_mask):
  """
  Assert that every non-empty entry of `variables` is zero wherever node_mask is zero.

  Input
    variables: iterable of arrays (Python scalars are accepted as well)
    node_mask: mask handed to assert_correctly_masked for each entry
  Raises
    AssertionError (from assert_correctly_masked) for an incorrectly masked entry.
  """
  for variable in variables:
    # np.size also works for 0-d values such as a plain `0`, where `len()`
    # raises TypeError; entries without any element are skipped.
    if np.size(variable) > 0:
      assert_correctly_masked(variable, node_mask)

def assert_correctly_masked(variable, node_mask):
  """
  Assert that `variable` is (numerically) zero on every masked-out node.

  The absolute values found where `1 - node_mask` is non-zero are summed and
  the total must stay below 1e-8.
  """
  padding = 1 - node_mask
  leaked = jnp.abs(variable * padding).sum()
  assert leaked < 1e-8

def assert_mean_zero_with_mask(x, node_mask, eps=1e-10):
  """
  Assert that x is correctly masked and sums to (approximately) zero over axis 1.

  The largest absolute axis-1 sum is compared against the largest absolute
  entry of x (plus eps, to avoid dividing by zero); the ratio must be below 1e-2.
  """
  assert_correctly_masked(x, node_mask)
  node_sums = jnp.sum(x, axis=1, keepdims=True)
  error = jnp.abs(node_sums).max()
  scale = jnp.abs(x).max() + eps
  rel_error = error / scale
  assert rel_error < 1e-2, f'Mean is not zero, relative_error {rel_error}'

def sample_gaussian_with_mask(rng, shape, node_mask):
  """
  Draw standard-normal noise of the given 3-D shape and multiply it by node_mask.

  Input
    rng:       JAX PRNG key
    shape:     3-tuple giving the shape of the sample
    node_mask: mask broadcast against the sample
  """
  assert len(shape) == 3
  return jax.random.normal(rng, shape) * node_mask

def sum_except_batch(x):
  """Sum x over every axis except the leading (batch) axis; returns shape (x.shape[0],)."""
  batch = x.shape[0]
  flattened = jnp.reshape(x, (batch, -1))
  return flattened.sum(axis=-1)

# Dequantizer
def uniformdequantizer(rng, tensor, node_mask, edge_mask, context):
  """
  Add uniform noise in [-0.5, 0.5) to the discrete features.

  Input
    rng:       JAX PRNG key, split once (first half: categorical, second half: integer)
    tensor:    dict with the entries 'categorical' and 'integer'
    node_mask: optional mask multiplied into both noisy outputs (skipped when None)
    edge_mask: unused
    context:   unused
  Output
    (out, zeros): out is a dict with the dequantized 'categorical' and 'integer'
    entries, zeros is a zero vector with one entry per row of the leading axis
    of the integer features.
  """
  category = tensor['categorical']
  integer = tensor['integer']
  cat_rng, int_rng = jax.random.split(rng)

  def _jitter(key, values):
    # Shift the unit-interval sample so the noise is centred on the value.
    noisy = values + jax.random.uniform(key, values.shape) - 0.5
    if node_mask is None:
      return noisy
    return noisy * node_mask

  out = {
    'categorical': _jitter(cat_rng, category),
    'integer':     _jitter(int_rng, integer),
  }
  return out, jnp.zeros((integer.shape[0],))

def uniformdequantizer_reverse(tensor):
  """
  Undo the dequantization by rounding both feature groups with jnp.round.

  Input
    tensor: dict with the entries 'categorical' and 'integer'
  Output
    new dict with the same two keys holding the rounded arrays
  """
  return {
    name: jnp.round(tensor[name])
    for name in ('categorical', 'integer')
  }
#==========================================================================#
def preprocess_batch(data, config, evaluation=False, jit=True):
  """
  preprocess_batch
  Input
    data:       raw batch (dict); must provide "positions", "atom_mask", "edge_mask",
                "one_hot" and, when config.data.include_charges is set, "charges".
                The caller's dict is not modified.
    config:     reads data.include_charges, data.conditioning, training.batch_size,
                training.n_jitted_steps and eval.batch_size
    evaluation: use config.eval.batch_size instead of config.training.batch_size
    jit:        when True, an extra axis of length config.training.n_jitted_steps is
                inserted between the device axis and the per-device batch axis
  Output
    (dict_pmap, data_pmap)
      dict_pmap: {'x', 'h', 'node_mask', 'edge_mask', 'context'} with centred positions
      data_pmap: the raw entries converted to jnp arrays
    In both, the leading batch axis of every array is reshaped to
    (local_device_count, [n_jitted_steps,] per_device_batch_size).
  Raises
    NotImplementedError if config.data.conditioning is non-empty.
  """
  # Convert into a fresh dict so the caller's batch is left untouched.
  data = {k: jnp.array(v) for k, v in data.items()}

  x = data["positions"]
  node_mask = jnp.expand_dims(data["atom_mask"], axis=2)
  edge_mask = data["edge_mask"]
  one_hot = data["one_hot"]

  all_batch_size, n_max_atoms, _ = x.shape

  if config.data.include_charges:
    charges = data["charges"]
  else:
    # An empty feature axis (instead of the scalar 0) keeps every later step
    # array-based: the mask check, the reshape below and feature concatenation.
    charges = jnp.zeros((all_batch_size, n_max_atoms, 0), dtype=x.dtype)

  x = remove_mean_with_mask(x, node_mask)
  # if config.data.augment_noise > 0:
  #   raise NotImplementedError()
  #   rng, step_rng = jax.random.split(rng)
  #   eps = sample_center_gravity_zero_gaussian_with_mask(step_rng, x.shape, node_mask)
  #   x = x + config.data.augment_noise * eps
  #   x = remove_mean_with_mask(x, jnp.expand_dims(jnp.array(data['atom_mask']), 2))

  check_mask_correct([x, one_hot, charges], node_mask)
  assert_mean_zero_with_mask(x, node_mask)

  # Finalize batch
  h = {"categorical": one_hot, "integer": charges}

  if len(config.data.conditioning) > 0:
    # TODO: conditioning argument. Reference (PyTorch) implementation:
    #   context = qm9utils.prepare_context(
    #       args.conditioning, data, property_norms
    #   ).to(device, dtype)
    #   assert_correctly_masked(context, node_mask)
    raise NotImplementedError()
  # Unconditional case: context with an empty feature axis.
  context = jnp.zeros((all_batch_size, n_max_atoms, 0))

  edge_mask_shape = (all_batch_size, n_max_atoms * n_max_atoms, 1)
  dict_before_pmap = {
    'x': x,
    'h': h, # dict with the 'categorical' and 'integer' features
    'node_mask': node_mask,
    'edge_mask': jnp.reshape(edge_mask, edge_mask_shape),
    'context': context,
  }
  data['edge_mask'] = jnp.reshape(data['edge_mask'], edge_mask_shape)

  batch_size = config.training.batch_size if not evaluation else config.eval.batch_size
  # NOTE(review): divisibility is checked against the global device count while
  # the split below uses the local device count — confirm for multi-host runs.
  assert batch_size % jax.device_count() == 0
  per_device_batch_size = batch_size // jax.local_device_count()

  if jit:
    jit_tuple = (jax.local_device_count(), config.training.n_jitted_steps, per_device_batch_size)
  else:
    jit_tuple = (jax.local_device_count(), per_device_batch_size)

  def _split_batch_axis(arr):
    # Replace the leading batch axis by the pmap (and jitted-steps) axes.
    return jnp.reshape(arr, jit_tuple + arr.shape[1:])

  dict_pmap = jax.tree_util.tree_map(_split_batch_axis, dict_before_pmap)
  data_pmap = jax.tree_util.tree_map(_split_batch_axis, data)
  return dict_pmap, data_pmap


def get_adj_matrix(n_nodes, bs):
  """
  Get the fully connected edge list (self-loops included) of a minibatch
  Input
    n_nodes: number of nodes (max nodes in the minibatch)
    bs: minibatch size
  Output
    [rows, cols]: two integer arrays of length bs * n_nodes * n_nodes.
    Node i of sample b has the global index i + b * n_nodes; edges are ordered
    by sample, then by source node, then by target node. Edges never connect
    nodes of different samples.
  """
  # Global node index of every (sample, node) pair.
  node_ids = jnp.reshape(jnp.arange(bs * n_nodes), (bs, n_nodes))
  pair_shape = (bs, n_nodes, n_nodes)
  n_edges = bs * n_nodes * n_nodes
  # Source index varies along axis 1, target index along axis 2.
  rows = jnp.reshape(jnp.broadcast_to(node_ids[:, :, None], pair_shape), (n_edges,))
  cols = jnp.reshape(jnp.broadcast_to(node_ids[:, None, :], pair_shape), (n_edges,))
  edges = [rows, cols]
  return edges

def coord2diff(x, edge_index, norm_constant=1):
  """
  Squared length and normalised direction of every edge.

  Input
    x:             coordinates, indexed along axis 0 by the entries of edge_index
    edge_index:    pair (row, col) of index arrays
    norm_constant: added to the edge length before dividing
  Output
    radial:     squared distance x[row] - x[col], summed over axis 1, shape (n_edges, 1)
    coord_diff: (x[row] - x[col]) / (sqrt(radial + 1e-8) + norm_constant)
  """
  source, target = edge_index
  delta = x[source] - x[target]
  radial = jnp.sum(delta ** 2, axis=1)
  radial = jnp.expand_dims(radial, axis=1)
  # The small epsilon keeps the square root well-behaved for zero-length edges.
  length = jnp.sqrt(radial + 1e-8)
  return radial, delta / (length + norm_constant)

def sin_embedding(x):
  """
  Sinusoidal embedding of (squared) distances.

  Input
    x: array whose last axis broadcasts against the frequency axis (e.g. size 1),
       or None to query only the embedding size
  Output
    x is None: the embedding dimension (2 * number of frequencies)
    otherwise: concat(sin, cos) of sqrt(x + 1e-8) * frequencies along the last axis
  """
  max_res = 15.0
  min_res = 15.0 / 2000.0
  div_factor = 4

  # Geometric ladder of frequencies spanning min_res .. max_res.
  resolution_ratio = max_res / min_res
  n_frequencies = int(jnp.log(resolution_ratio) / jnp.log(div_factor)) + 1
  frequencies = 2 * jnp.pi * div_factor ** jnp.arange(n_frequencies) / max_res

  if x is None:
    return len(frequencies) * 2

  # forward
  phase = jnp.sqrt(x + 1e-8) * frequencies[None, :]
  return jnp.concatenate([jnp.sin(phase), jnp.cos(phase)], axis=-1)

def normalize(config, x, h, node_mask):
  """
  Rescale positions and features with the constants stored in config.data.

  Input
    config:    reads data.norm_values, data.norm_biases and data.include_charges
    x:         positions, divided by norm_values[0]
    h:         dict with 'categorical' and 'integer' features
    node_mask: multiplied into the categorical features, and into the integer
               features only when config.data.include_charges is set
  Output
    (x, h) with h a new dict holding the normalised features
  """
  scales = config.data.norm_values
  shifts = config.data.norm_biases

  # Normalize x values
  x = x / scales[0]

  categorical = (h["categorical"] - shifts[1]) / scales[1]
  categorical = categorical * node_mask

  integer = (h["integer"] - shifts[2]) / scales[2]
  if config.data.include_charges:
    integer = integer * node_mask

  return x, {"categorical": categorical, "integer": integer}

def icp(A, B, max_iterations=100, tolerance=0.001):
  """
  Rotation-only Iterative Closest Point: rotates points A towards points B.
  Each iteration assigns every source point to a destination point
  (get_assignments), fits a transform to that assignment (best_fit_transform)
  and applies only its rotation to the current source points.
  Input:
      A: Nxm numpy array of source mD points
      B: Nxm numpy array of destination mD points
      max_iterations: exit algorithm after max_iterations
      tolerance: stop once the mean assignment distance changes by less than this
  Output:
      R: mxm rotation fitted between the original A and the aligned source points
      A_rotated: Nxm array, A with R applied
      indices: assignment (destination index per source point) of the last iteration
  """

  assert A.shape == B.shape

  src = np.copy(A)
  dst = np.copy(B)

  # Identity assignment, so the return value is defined even when the loop
  # body never runs (max_iterations <= 0).
  indices = np.arange(src.shape[0])
  prev_error = 0

  for _ in range(max_iterations):
    # get assignments
    distances, indices = get_assignments(src, dst)

    # compute the transformation between the current source and assigned destination points
    _, R, _ = best_fit_transform(src, dst[indices, :])

    # rotate and update the current source
    src = np.dot(R, src.T).T

    # check error
    mean_error = np.mean(distances)
    if np.abs(prev_error - mean_error) < tolerance:
      break
    prev_error = mean_error
  else:
    # Reached only when the loop ends without the convergence `break`; the
    # previous index comparison after the loop could never be true.
    print("out of iteration")

  # calculate final transformation
  _, R, _ = best_fit_transform(A, src)
  A_rotated = np.dot(R, A.T).T
  return R, A_rotated, indices

def best_fit_transform(A, B):
  """
  Least-squares rigid transform (rotation + translation) mapping points A onto
  the corresponding points B in m spatial dimensions.
  Input:
    A: Nxm numpy array of corresponding points
    B: Nxm numpy array of corresponding points
  Returns:
    T: (m+1)x(m+1) homogeneous transformation matrix that maps A on to B
    R: mxm rotation matrix
    t: translation vector of length m
  """

  assert A.shape == B.shape

  dim = A.shape[1]

  # centre both point sets on their centroids
  mu_a = np.mean(A, axis=0)
  mu_b = np.mean(B, axis=0)

  # rotation from the SVD of the cross-covariance of the centred points
  cross_cov = np.dot((A - mu_a).T, B - mu_b)
  U, _, Vt = np.linalg.svd(cross_cov)
  rotation = np.dot(Vt.T, U.T)

  # a negative determinant means a reflection: flip the last right-singular
  # vector and rebuild the rotation
  if np.linalg.det(rotation) < 0:
    Vt[dim - 1, :] *= -1
    rotation = np.dot(Vt.T, U.T)

  # translation that carries the rotated centroid of A onto the centroid of B
  translation = mu_b.T - np.dot(rotation, mu_a.T)

  # pack rotation and translation into one homogeneous matrix
  transform = np.identity(dim + 1)
  transform[:dim, :dim] = rotation
  transform[:dim, dim] = translation

  return transform, rotation, translation

def get_assignments(src, dst):
  """
  Minimum-cost matching between src and dst points under Euclidean distance.

  Output
    distances: distance between src point k and its matched dst point,
               for k = 0 .. len(dest_ind) - 1
    dest_ind:  column indices returned by linear_sum_assignment
  """
  pairwise = cdist(src, dst, metric="euclidean")
  dest_ind = linear_sum_assignment(pairwise, maximize=False)[1]
  src_ind = range(len(dest_ind))
  return pairwise[src_ind, dest_ind], dest_ind

def transform_data(rng, config, batch):
  """
  Turn a preprocessed batch into the inputs of the CNF and of the EGNN.

  Input - preprocessed batch
    x:         (B, N, 3)
    h:         {'categorical': (B, N, 5), 'integer': (B, N, 1)}
    node_mask: (B, N, 1)
    edge_mask: (B, N * N, 1)
    context:   (B, N, C)

  Output - dict with the entries
    'raw':  the input batch
    'cnf':  (Part 1) CNF input
      'x':         x,
      'h':         h (dequantized),
      'node_mask': node_mask,
      'edge_mask': edge_mask reshaped to (B, N * N),
      'context':   context,
    'egnn': (Part 3) EGNN input, flattened over batch and nodes
      h:          (B * N, 5 (atom-cat) + 1 (charge-int))
      x:          (B * N, 3 (xyz))
      edges:      [(B * N * N,), (B * N * N,)]
      node_mask:  (B * N, 1)
      edge_mask:  (B * N * N, 1)
      noise_h:    (B * N, 6), Gaussian noise for h
      noise_x:    (B * N, 3), zero-mean Gaussian noise for x, aligned to x
      t:          (B * N, 1) per-node time, or None when
                  config.data.condition_time is off

  The noise is aligned to the data per molecule: its coordinates are first
  rotated with ICP, then its atoms are permuted with a linear sum assignment
  on the pairwise coordinate distances.
  """
  # NOTE(review): the result is only needed by the commented-out in_node_nf
  # computation below; the lookup is kept so any error it raises for the
  # configured dataset still surfaces here — confirm whether it can be dropped.
  dataset_info = get_dataset_info(config.data.dataset, config.data.remove_h)

  # base: preprocessed input
  h = batch['h']                  # {'categorical': (B, N, 5), 'integer': (B, N, 1)}
  x = batch['x']                  # (B, N, 3)
  node_mask = batch['node_mask']  # (B, N, 1)
  edge_mask = batch['edge_mask']  # (B, N * N, 1)
  context = batch['context']      # (B, N, C)

  assert len(x.shape) == 3
  B, N, D = x.shape # B: batch size, N: n_atoms

  # Number of valid atoms per molecule as plain Python ints. The explicit
  # reshape (instead of a bare squeeze) keeps the batch axis when B == 1 or N == 1.
  n_atoms = [int(n) for n in np.asarray(jnp.sum(jnp.reshape(node_mask, (B, N)), axis=-1))]

  # (1) Preprocessed input (losses.compute_loss_and_nll) --> CNF forward input
  rng, step_rng = jax.random.split(rng)
  h, _ = uniformdequantizer(step_rng, h, node_mask, edge_mask, x)
  edge_mask = jnp.reshape(edge_mask, (B, N * N))
  assert_correctly_masked(x, node_mask)
  cnf_batch = {
    'x':         x,
    'h':         h,
    'node_mask': node_mask,
    'edge_mask': edge_mask,
    'context':   context,
  }

  # (2) CNF forward --> Dynamics forward input
  # (2-1) Normalization
  x, h = normalize(config, x, h, node_mask)
  # in_node_nf = len(dataset_info["atom_decoder"]) + int(config.data.include_charges)
  # (2-2) Preprocess
  assert len(x.shape) == 3
  xh = jnp.concatenate([x, h['categorical'], h['integer']], axis=2)
  # (2-3) Get [x-centered-gaussian, h-gaussian] noise. shape (b, n, n_dims + in_node_nf)
  rng, step_rng = jax.random.split(rng)
  noise_x = sample_center_gravity_zero_gaussian_with_mask(
    step_rng,
    shape=(xh.shape[0], xh.shape[1], x.shape[2]),
    node_mask=node_mask,
  )
  rng, step_rng = jax.random.split(rng)
  noise_h = sample_gaussian_with_mask(
    step_rng,
    shape=(xh.shape[0], xh.shape[1], xh.shape[2] - x.shape[2]),
    node_mask=node_mask,
  )
  noise = jnp.concatenate([noise_x, noise_h], axis=2)

  # (2-4) Iterative Closest Point Method: Find best-fit transform from [noise] to [x].
  # Each molecule only reads and writes its own row, so rebinding `noise`
  # (JAX arrays are immutable; `.at[...]` returns a new array) is safe.
  for idx, l in enumerate(n_atoms):
    _, z_rotated, _ = icp(noise[idx, :l, :3], xh[idx, :l, :3])
    noise = noise.at[idx, :l, :3].set(z_rotated)

  # (2-5) Linear sum assignment (Hungarian algorithm)
  distance_matrices = np.asarray(jnp.sqrt(
    jnp.sum((jnp.expand_dims(xh[:, :, :3], axis=2) - jnp.expand_dims(noise[:, :, :3], axis=1)) ** 2, axis=-1)
  ))
  matched_noise = noise
  for idx, l in enumerate(n_atoms):
    _, col_ind = linear_sum_assignment(distance_matrices[idx, :l, :l], maximize=False)
    # Data atom i receives the noise atom col_ind[i] (coordinates and features).
    matched_noise = matched_noise.at[idx, :l, :].set(noise[idx, col_ind, :])
  noise = matched_noise

  rng, step_rng = jax.random.split(rng)
  t = jax.random.uniform(step_rng, (xh.shape[0], 1, 1))

  if config.data.on_hold_batch > 0 and config.data.cat_loss_step > 0:
    # Functional update: in-place item assignment is not supported on JAX arrays.
    t = t.at[-config.data.on_hold_batch:, :, :].multiply(config.data.cat_loss_step)
  # (2-6) different loss function for discrete path (discrete_path)
  # eps = 0 # TODO: original value was 1e-4
  # y = (1 - t) * xh + (eps + (1 - eps) * t) * noise # y = (1 - t) * xh + t * noise (noisy data)
  # field_z = (1 - eps) * noise - xh                 # field_z = noise - xh         (conditional vector field)
  # dynamics_batch = {
  #   # 't':         t,          # 'time'
  #   # 'y':         y,          # 'noisy data'
  #   'node_mask': node_mask,  # 'node mask'
  #   'edge_mask': edge_mask,
  #   'context':   context,
  # }

  # (3) Dynamics input --> EGNN input
  # (3-1) Forward with the network
  assert len(xh.shape) == 3
  B, N, _ = xh.shape
  edges = get_adj_matrix(N, B)
  node_mask = jnp.reshape(node_mask, (B * N, 1))
  edge_mask = jnp.reshape(edge_mask, (B * N * N, 1))
  xh = jnp.reshape(xh, (B * N, -1)) * node_mask
  noise = jnp.reshape(noise, (B * N, -1)) * node_mask
  noise_x, noise_h = jnp.split(noise, [D], axis=-1)
  x, h = jnp.split(xh, [D], axis=-1)

  # (3-2) Per-node time feature
  if config.data.condition_time:
    # Repeat each molecule's time for all of its N nodes. This covers B == 1
    # too and always yields shape (B * N, 1).
    h_time = jnp.repeat(jnp.reshape(t, (B, 1)), N, axis=1)
    h_time = jnp.reshape(h_time, (B * N, 1))
    # h = jnp.concatenate([h, h_time], axis=1)
  else:
    # NOTE(review): previously unbound in this case (NameError below); None marks
    # "no time feature" — confirm that consumers of 't' accept it.
    h_time = None

  # (3-3) Context: for conditional generation # TODO
  # if context.size > 0:
  #   context = jnp.reshape(context, (bs * n_nodes, self.context_node_nf))

  egnn_batch = {
    'h'        : h,         # h value,   shape (B * N, 6)
    'x'        : x,         # x value,   shape (B * N, 3)
    'edges'    : edges,     # edges,     shape [(B * N * N,), (B * N * N,)]
    'node_mask': node_mask, # node_mask, shape (B * N, 1)
    'edge_mask': edge_mask, # edge_mask, shape (B * N * N, 1)
    'noise_h'  : noise_h,   # h noise,   shape (B * N, 6)
    'noise_x'  : noise_x,   # x noise,   shape (B * N, 3)
    't'        : h_time,    # time,      shape (B * N, 1) or None
  }

  all_data = {
    'raw':      batch,
    'cnf':      cnf_batch,
    # 'dynamics': dynamics_batch,
    'egnn':     egnn_batch,
    # 'xh_true':  xh, # get raw xh data here.
    # 'z':        noise,
  }

  return all_data

#==========================================================================#
# Analysis code
def analyze_stability(molecule_list, dataset_info):
  """
  Compute stability statistics and RDKit metrics for a set of generated molecules.

  Input
    molecule_list: dict with 'one_hot', 'x' and 'node_mask', each indexed by
                   molecule along the first axis
    dataset_info:  dataset description forwarded to check_stability and
                   BasicMolecularMetrics
  Output
    validity_dict: {'mol_stable': fraction of molecules reported stable,
                    'atm_stable': fraction of atoms reported stable}
    rdkit_metrics: result of BasicMolecularMetrics(dataset_info).evaluate on the
                   list of (positions, atom types) pairs
  """
  one_hot, x, node_mask = molecule_list['one_hot'], molecule_list['x'], molecule_list['node_mask']

  n_samples = len(x)

  # Strip the padding: keep only the first `n_valid` atoms of every molecule.
  processed_list = []
  for mol_one_hot, mol_x, mol_mask in zip(one_hot, x, node_mask):
    n_valid = int(jnp.sum(mol_mask, dtype=jnp.int32))
    atom_type = jnp.argmax(mol_one_hot, axis=1)[0:n_valid]
    pos = mol_x[0:n_valid]
    processed_list.append((pos, atom_type))

  molecule_stable = 0
  # Previously initialised under a different name (`nr_sample_bonds`), which
  # made the accumulation below fail with UnboundLocalError.
  nr_stable_bonds = 0
  n_atoms = 0

  for pos, atom_type in processed_list:
    validity_results = check_stability(pos, atom_type, dataset_info)

    molecule_stable += int(validity_results[0])
    nr_stable_bonds += int(validity_results[1])
    n_atoms += int(validity_results[2])

  # Validity results
  fraction_mol_stable = molecule_stable / n_samples
  fraction_atm_stable = nr_stable_bonds / n_atoms
  validity_dict = {
    'mol_stable': fraction_mol_stable,
    'atm_stable': fraction_atm_stable,
  }

  # RD-Kit
  metrics = BasicMolecularMetrics(dataset_info)
  rdkit_metrics = metrics.evaluate(processed_list)
  return validity_dict, rdkit_metrics

  
def check_stability(positions, atom_type, dataset_info, debug=False):
  """
  Check whether every atom of a molecule has an allowed total bond order.

  Input
    positions:    (n_atoms, 3) coordinates
    atom_type:    n_atoms integer indices into dataset_info['atom_decoder']
    dataset_info: dict with 'atom_decoder' and 'name'; the name selects the bond
                  predictor ('qm9', 'qm9_second_half', 'qm9_first_half' or 'geom')
    debug:        print every atom whose bond count is not allowed
  Output
    (molecule_stable, nr_stable_bonds, n_atoms): whether all atoms are stable,
    the number of stable atoms, and the number of atoms
  Raises
    ValueError if a bond order is needed for an unsupported dataset name.
  """

  assert len(positions.shape) == 2
  assert positions.shape[1] == 3
  atom_decoder = dataset_info['atom_decoder']
  dataset_name = dataset_info['name']

  # Convert once to host-side values instead of indexing device arrays per pair.
  positions = np.asarray(positions)
  atom_type = [int(a) for a in np.asarray(atom_type)]
  n_atoms = len(positions)
  nr_bonds = np.zeros(n_atoms, dtype='int')

  for i in range(n_atoms):
    for j in range(i + 1, n_atoms):
      dist = np.sqrt(np.sum((positions[i] - positions[j]) ** 2))
      atom1, atom2 = atom_decoder[atom_type[i]], atom_decoder[atom_type[j]]
      pair = sorted([atom_type[i], atom_type[j]])
      if dataset_name in ('qm9', 'qm9_second_half', 'qm9_first_half'):
        order = bond_analyze.get_bond_order(atom1, atom2, dist)
      elif dataset_name == 'geom':
        order = bond_analyze.geom_predictor((atom_decoder[pair[0]], atom_decoder[pair[1]]), dist)
      else:
        # Without this branch `order` was unbound on the first pair, or silently
        # reused from the previous pair.
        raise ValueError(f"Unsupported dataset name for bond prediction: {dataset_name!r}")
      nr_bonds[i] += order
      nr_bonds[j] += order

  nr_stable_bonds = 0
  for atom_type_i, nr_bonds_i in zip(atom_type, nr_bonds):
    # allowed_bonds holds either a single int or a collection of allowed counts.
    possible_bonds = bond_analyze.allowed_bonds[atom_decoder[atom_type_i]]
    if isinstance(possible_bonds, int):
      is_stable = possible_bonds == nr_bonds_i
    else:
      is_stable = nr_bonds_i in possible_bonds
    if not is_stable and debug:
      print("Invalid bonds for molecule %s with %d bonds" %
            (atom_decoder[atom_type_i], nr_bonds_i))
    nr_stable_bonds += int(is_stable)

  molecule_stable = nr_stable_bonds == n_atoms
  return molecule_stable, nr_stable_bonds, n_atoms


