import numpy as np
import pyknotid.spacecurves as sp


def _create_knot(arr):
  """Wrap the curve coordinates *arr* in a quiet pyknotid ``Knot``.

  The knot is constructed with ``add_closure=True`` and ``verbose=False``.

  NOTE(review): ``add_closure`` reportedly does not always behave as expected
  when the gap between the curve's endpoints exceeds ~0.02 (see the pyknotid
  source) — confirm before relying on the closure for open curves.
  """
  knot = sp.Knot(arr, verbose=False, add_closure=True)
  return knot


def number_of_crossings(arr):
  """Return the crossing count of the curve *arr*, excluding the closure segment.

  The count is half the length of pyknotid's ``raw_crossings`` list computed
  with ``include_closure=False``.
  """
  raw = _create_knot(arr).raw_crossings(include_closure=False)
  # Halved: presumably raw_crossings records each crossing twice (once per
  # strand involved) — confirm against the pyknotid documentation.
  return len(raw) // 2


def check_cython():
  """Report whether pyknotid's compiled ``chelpers`` module can be imported.

  Returns:
    True if ``pyknotid.spacecurves.chelpers`` imports without an
    ``ImportError``, otherwise False. Other exceptions propagate.
  """
  try:
    from pyknotid.spacecurves import chelpers  # noqa: F401 (the import is the check)
  except ImportError:
    available = False
  else:
    available = True
  return available


def gauss_code(arr):
  """Return pyknotid's Gauss code for the curve *arr*.

  The curve is used exactly as given: no canonicalisation of the starting bead
  or traversal direction is performed. An earlier, currently disabled
  prototype rolled the array to start at the smallest-x bead and reversed it
  when the first three beads turned clockwise.

  NOTE(review): the same knot traversed from a different start or direction
  may therefore yield a different code — confirm this is acceptable for
  callers comparing codes.
  """
  return _create_knot(arr).gauss_code()


def knot_identify(arr):
  """Return the result of pyknotid's ``Knot.identify()`` for the curve *arr*."""
  return _create_knot(arr).identify()


def eq_gauss_code(arr1, arr2, allow_flipped_or_mirrored=False):
  """Compare two curves by their Gauss codes.

  Computes the Gauss code of *arr1*, then of *arr2*, and delegates the
  comparison (including the optional flip/mirror tolerance) to
  ``eq_gauss_code2``.
  """
  first, second = gauss_code(arr1), gauss_code(arr2)
  return eq_gauss_code2(first, second, allow_flipped_or_mirrored)


def eq_gauss_code2(gc1, gc2, allow_flipped_or_mirrored=False):
  """Return True if two Gauss codes are equal when compared as strings.

  Args:
    gc1: reference Gauss code; compared via ``str(gc1)``.
    gc2: candidate Gauss code; compared via ``str``. When
      *allow_flipped_or_mirrored* is True it must also provide ``flipped()``
      and ``mirrored()`` returning objects of the same kind.
    allow_flipped_or_mirrored: if True, *gc2* also matches when its flipped,
      mirrored, flipped-then-mirrored or mirrored-then-flipped variant has the
      same string form as *gc1*.

  Returns:
    bool: whether a match was found.
  """
  target = str(gc1)
  if not allow_flipped_or_mirrored:
    return target == str(gc2)

  def variants():
    # Lazily produced so transformations stop at the first match. Both
    # composition orders are tried since it isn't established here that
    # flipped() and mirrored() commute.
    yield gc2
    yield gc2.flipped()
    yield gc2.mirrored()
    yield gc2.flipped().mirrored()
    yield gc2.mirrored().flipped()

  return any(target == str(candidate) for candidate in variants())


def eq_top(arr1, arr2):
  """Return True if pyknotid's ``identify()`` gives equal results for both curves."""
  lhs = knot_identify(arr1)
  rhs = knot_identify(arr2)
  return lhs == rhs


def eq_coord(arr1, arr2, atol=1e-3):
  """Return True if two curves have element-wise matching coordinates.

  Coordinates are compared with ``np.allclose`` using absolute tolerance
  *atol* and numpy's default relative tolerance.

  Inputs with different shapes are reported as unequal. Without this check,
  ``np.allclose`` would broadcast them, which can spuriously match (e.g. a
  single bead against a whole curve) or raise ``ValueError`` for curves with
  different bead counts.

  Args:
    arr1, arr2: array-likes of coordinates.
    atol: absolute tolerance forwarded to ``np.allclose`` (default 1e-3).

  Returns:
    bool: whether the coordinates match within tolerance.
  """
  # TODO: add invariance to rotation and translation
  a = np.asarray(arr1)
  b = np.asarray(arr2)
  if a.shape != b.shape:
    return False
  return bool(np.allclose(a, b, atol=atol))


def eq_knot(method, arr1, arr2, **kwargs):
  """Compare two curves using the comparison selected by *method*.

  Args:
    method: ``"coord"`` (``eq_coord``), ``"gc"`` (``eq_gauss_code``) or
      ``"top"`` (``eq_top``).
    arr1, arr2: the two curves to compare.
    **kwargs: forwarded to ``eq_gauss_code`` only; ignored by the other methods.

  Raises:
    ValueError: if *method* is not one of the names above.
  """
  if method == "coord":
    result = eq_coord(arr1, arr2)
  elif method == "gc":
    result = eq_gauss_code(arr1, arr2, **kwargs)
  elif method == "top":
    result = eq_top(arr1, arr2)
  else:
    raise ValueError("Invalid method")
  return result


def colorful(x: float) -> np.ndarray:
  """Map *x* (expected in [0, 1]) to an RGBA colour array of shape (4,).

  ``x == 0`` is special-cased to opaque white, marking the start. Otherwise
  each colour channel is ``sin(2*pi*(x + offset)) * 0.5 + 0.5``, using phase
  offsets 0.0, 0.33 and 0.67 for red, green and blue. Alpha is always 1.0.
  """
  if x != 0.0:
    r, g, b = (
      np.sin(2 * np.pi * (x + offset)) * 0.5 + 0.5 for offset in (0.0, 0.33, 0.67)
    )
    return np.array([r, g, b, 1.0])
  return np.array([1.0, 1.0, 1.0, 1.0])  # white for the start
