import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import ot as pot
import scipy

import flax.linen as nn
import functools
from flax.training import train_state
import optax

def sample_moons(rng, N, std=0.2):
  """Draw ``N`` labelled points from a two-moons layout.

  The first ``N // 2`` points lie on an upper half circle (label 0); the
  remaining points lie on a flipped, shifted half circle (label 1). Noise
  drawn uniformly from ``[0, std)`` is added to every coordinate, and the
  whole cloud is then mapped through ``3 * X - 1``.

  Args:
    rng: JAX PRNG key; a sub-key split from it drives the noise.
    N: total number of points.
    std: scale applied to the uniform noise.

  Returns:
    ``(X, y)`` where ``X`` has shape ``(N, 2)`` and ``y`` holds int32 labels
    of shape ``(N,)``.
  """
  n_upper = N // 2
  n_lower = N - n_upper

  theta_upper = jnp.linspace(0, jnp.pi, n_upper)
  theta_lower = jnp.linspace(0, jnp.pi, n_lower)

  # Upper moon first, lower moon second, for both coordinates and labels.
  xs = jnp.concatenate([jnp.cos(theta_upper), 1 - jnp.cos(theta_lower)])
  ys = jnp.concatenate([jnp.sin(theta_upper), 1 - jnp.sin(theta_lower) - 0.5])
  points = jnp.stack([xs, ys], axis=1)

  labels = jnp.concatenate([
      jnp.zeros(n_upper, dtype=jnp.int32),
      jnp.ones(n_lower, dtype=jnp.int32),
  ])

  # Only the second half of the split is consumed.
  _, noise_key = jax.random.split(rng)
  jitter = jax.random.uniform(noise_key, (N, 2)) * std

  points = (points + jitter) * 3 - 1
  return points, labels

def sample_8gaussian(rng, N, scale=5, var=0.1):
  """Sample ``N`` points from a mixture of eight isotropic Gaussians.

  The component means sit on a circle of radius ``scale``: four on the axes
  and four on the diagonals. Each point picks one component uniformly at
  random and adds Gaussian noise with variance ``var`` per coordinate.

  Args:
    rng: JAX PRNG key, split once into a noise key and a component key.
    N: number of points.
    scale: radius of the circle holding the component means.
    var: per-coordinate noise variance.

  Returns:
    Array of shape ``(N, 2)``.
  """
  diag = 1.0 / jnp.sqrt(2)
  unit_means = jnp.array([
    (1, 0),
    (-1, 0),
    (0, 1),
    (0, -1),
    (diag, diag),
    (diag, -diag),
    (-diag, diag),
    (-diag, -diag),
  ])  # (8, 2)
  means = unit_means * scale

  noise_key, component_key = jax.random.split(rng)
  jitter = jnp.sqrt(var) * jax.random.normal(noise_key, (N, 2))
  component = jax.random.randint(component_key, (N,), minval=0, maxval=8)
  return means[component] + jitter

def plot_samples(s1, s2):
  """Scatter two point sets on the current figure and show it.

  ``s1`` is drawn in blue and ``s2`` in red; the legend labels them
  'moons' and '8gaussians' in that order. Only columns 0 and 1 are plotted.
  """
  for points, colour in ((s1, "blue"), (s2, "red")):
    plt.scatter(points[:, 0], points[:, 1], s=4, alpha=1, c=colour)
  plt.legend(['moons', '8gaussians'])

  # No tick marks on either axis.
  for hide_ticks in (plt.xticks, plt.yticks):
    hide_ticks([])

  plt.xlim(-5, 5)
  plt.ylim(-5, 5)
  plt.axis('equal')
  plt.show()


def plot_trajectories(traj):
  """Plot every sample's path through a trajectory array and show the figure.

  ``traj`` is indexed as ``[step, sample, coordinate]``; only coordinates 0
  and 1 are drawn. The first step is drawn in black, all steps as faint
  olive dots, and the last step in blue.
  """
  start, end = traj[0], traj[-1]

  plt.figure(figsize=(6, 6))
  plt.scatter(start[:, 0], start[:, 1], s=10, alpha=0.8, c="black")
  plt.scatter(traj[:, :, 0], traj[:, :, 1], s=0.2, alpha=0.2, c="olive")
  plt.scatter(end[:, 0], end[:, 1], s=4, alpha=1, c="blue")
  plt.legend(["Prior sample z(S)", "Flow", "z(0)"])

  # No tick marks on either axis.
  for hide_ticks in (plt.xticks, plt.yticks):
    hide_ticks([])

  plt.xlim(-5, 5)
  plt.ylim(-5, 5)
  plt.axis('equal')
  plt.show()


def get_w_distance(source, target):
  """Return the empirical Wasserstein-1 and Wasserstein-2 distances.

  Both point sets receive uniform weights. The pairwise distance matrix from
  ``cdist`` (Euclidean by default) is the W1 cost, and its element-wise
  square is the W2 cost, with a square root taken of that transport cost.

  Args:
    source: array of points, one per row.
    target: array of points with the same shape as ``source``.

  Returns:
    ``(w1, w2)``.
  """
  assert source.shape == target.shape

  weights_src = pot.unif(source.shape[0])
  weights_tgt = pot.unif(target.shape[0])
  cost = scipy.spatial.distance.cdist(source, target)
  max_iter = int(1e7)

  w1 = pot.emd2(weights_src, weights_tgt, cost, numItermax=max_iter)
  w2_squared = pot.emd2(weights_src, weights_tgt, cost ** 2, numItermax=max_iter)
  return w1, jnp.sqrt(w2_squared)

# Root PRNG key; every random draw below uses a sub-key split from it.
rng = jax.random.PRNGKey(42)
# Batch size: points sampled from each distribution per training step.
N = 256
# Number of training iterations (one freshly sampled batch and one optimizer
# step each).
n_epochs = 25000
# Extra coordinates appended to each 2-D point. The augmentation set-up and
# the training-time augmentation only run when this is > 0.
aug_dim = 4
# Weights for the target's augmented coordinates: lambda1 scales the shared
# noise, lambda2 scales the linear projection of (moons + gaussian).
lambda1 = 0.2
lambda2 = 0.2
# NOTE(review): in the code shown this value is never read; the evaluation
# loop rebinds `nfe` before any use.
nfe = 30

class Dense(nn.Module):
  """Time-conditioned MLP: two 64-unit swish layers and a linear head.

  Attributes:
    aug_dim: number of output coordinates produced on top of the base 2.
  """
  aug_dim: int = 0

  @nn.compact
  def __call__(self, x, t):
    """Map points ``x`` and times ``t`` to ``2 + aug_dim`` outputs per row.

    ``t`` is concatenated in front of ``x`` along the last axis, so both
    must agree on all leading dimensions.
    """
    hidden = jnp.concatenate([t, x], axis=-1)
    # Layers are created in the same order as before, so auto-generated
    # parameter names stay the same.
    for _ in range(2):
      hidden = nn.swish(nn.Dense(64)(hidden))
    return nn.Dense(2 + self.aug_dim)(hidden)

# Advance the global key. The sub-key produced here is reassigned by the next
# split before anything reads it, but the split still changes `rng`, so
# removing this line would alter every random draw made afterwards.
rng, step_rng = jax.random.split(rng)
# Velocity network; its output has 2 + aug_dim coordinates per row.
model = Dense(aug_dim=aug_dim)


if aug_dim > 0:
  class AugDense(nn.Module):
    """Bias-free linear layer producing ``aug_dim`` features per row."""
    aug_dim: int = 0

    @nn.compact
    def __call__(self, x):
      return nn.Dense(self.aug_dim, use_bias=False)(x)

  # Initialise the projection once with its own sub-key.
  rng, step_rng = jax.random.split(rng)
  aug_model = AugDense(aug_dim=aug_dim)
  aug_variables = aug_model.init(step_rng, jnp.zeros((100, 2)))

  # `aug_dense_fn(x)` applies the projection with the parameters bound here.
  aug_dense_fn = functools.partial(
      aug_model.apply,
      {'params': aug_variables['params']},
  )


def init_train_state(rng) -> train_state.TrainState:
  """Create a fresh ``TrainState`` for the module-level ``model``.

  Parameters are initialised from zero-filled dummy inputs of 100 rows
  (``2 + aug_dim`` point columns and one time column) and paired with an
  Adam optimizer at learning rate 2e-4.

  Args:
    rng: JAX PRNG key used for parameter initialisation.

  Returns:
    A ``TrainState`` whose ``apply_fn`` is ``model.apply``.
  """
  dummy_x = jnp.zeros((100, 2 + aug_dim))
  dummy_t = jnp.zeros((100, 1))
  initial_params = model.init(rng, x=dummy_x, t=dummy_t)['params']

  return train_state.TrainState.create(
      apply_fn=model.apply,
      params=initial_params,
      tx=optax.adam(2e-4),
  )


def draw_traj(state, input, nfe):
  """Integrate the learned flow from t=0 to t=1 with explicit Euler steps.

  Args:
    state: object exposing ``apply_fn`` and ``params``. ``apply_fn`` is
      called as ``apply_fn({'params': params}, x, t)`` and must return an
      array that broadcasts against ``x``.
    input: starting points; the last axis holds the coordinates.
    nfe: number of Euler steps (one ``apply_fn`` evaluation per step).

  Returns:
    Array of shape ``(nfe + 1,) + input.shape``: the starting points
    followed by the state after each step.
  """
  # Use a local name rather than rebinding the builtin-shadowing parameter.
  x = jnp.asarray(input)
  times = jnp.linspace(0, 1, nfe + 1)
  step_sizes = jnp.diff(times)

  # Loop invariants: shape of the per-row time column and the variables dict.
  t_shape = x.shape[:-1] + (1,)
  variables = {'params': state.params}

  # Collect the states and stack once at the end; concatenating onto the
  # growing trajectory on every step copied it again each time.
  states = [x]
  for current_t, dt in zip(times[:-1], step_sizes):
    vec_t = jnp.full(t_shape, current_t)
    x = x + dt * state.apply_fn(variables, x, vec_t)
    states.append(x)
  return jnp.stack(states, axis=0)


# Build the initial training state from a dedicated sub-key.
rng, step_rng = jax.random.split(rng)
state = init_train_state(step_rng)

# Training: each iteration samples a fresh batch from both distributions
# (rows are paired by index only) and takes one optimizer step that regresses
# the network output onto the source-to-target displacement.
for idx in range(1, n_epochs + 1):
  rng, step_rng = jax.random.split(rng)
  moons, _ = sample_moons(step_rng, N, std=0.2)

  rng, step_rng = jax.random.split(rng)
  gaussian = sample_8gaussian(step_rng, N)

  if aug_dim > 0:
    # Augmentation: source rows get aug_dim Gaussian-noise coordinates
    # (scaled by 3.0). Target rows get lambda1 * that same noise plus
    # lambda2 * a linear projection of (moons + gaussian).
    # aug_dense_fn uses the parameters bound at set-up; this loop only
    # updates `state`.
    rng, step_rng = jax.random.split(rng)
    aug_rng = jax.random.normal(step_rng, (N, aug_dim)) * 3.0 # y0
    aug_gaussian = lambda1 * aug_rng + lambda2 * aug_dense_fn(moons + gaussian) 
    moons = jnp.concatenate([moons, aug_rng], axis=-1)
    gaussian = jnp.concatenate([gaussian, aug_gaussian], axis=-1)
  
  # moon --> gaussian
  # Regression target: displacement from each source row to its paired
  # target row.
  flow_ref = gaussian - moons

  # One interpolation time per row, drawn by jax.random.uniform.
  rng, step_rng = jax.random.split(rng)
  t = jax.random.uniform(step_rng, (N, 1))

  perturbed_data = (1 - t) * moons + t * gaussian # 0 moon --> 1 gaussian

  # Mean squared error between predicted and reference displacement; the
  # prediction itself is returned as auxiliary output.
  def loss_fn(params):
    flow = state.apply_fn({'params': params}, perturbed_data, t)
    return jnp.mean(jnp.square(flow - flow_ref)), flow

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  # flow_cal (the auxiliary prediction) is not read again in the code shown.
  (loss_val, flow_cal), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)

  # Every 5000 steps, evaluate sample quality for several Euler step counts.
  if idx % 5000 == 0:
    
    # NOTE: this loop variable rebinds the module-level `nfe`.
    for nfe in [1, 2, 5, 10, 25]:

      w1_list = []
      w2_list = []

      # 50 independent trials, each with 500 fresh source and target points.
      for j in range(50):

        N_draw = 500
        rng, step_rng = jax.random.split(rng)
        input_moon, _ = sample_moons(step_rng, N_draw, std=0.2)
        # Source augmentation uses the same 3.0-scaled Gaussian noise as
        # training.
        rng, step_rng = jax.random.split(rng)
        aug_rng = jax.random.normal(step_rng, (N_draw, aug_dim)) * 3.0
        input_moon_aug = jnp.concatenate([input_moon, aug_rng], axis=-1)
        traj = draw_traj(state, input_moon_aug, nfe) # (nfe + 1, N_draw, 2 + aug_dim), get trajectory
        # Keep the first two coordinates; the augmented ones are split off
        # and not used below.
        traj, aug_traj = jnp.split(traj, [2], axis=-1)
        # Compare the final generated points with fresh 8-Gaussian samples.
        rng, step_rng = jax.random.split(rng)
        gen_target, true_target = traj[-1], sample_8gaussian(step_rng, N_draw)
        w1, w2 = get_w_distance(gen_target, true_target)
        # Per-trial debug print, left disabled:
        # print(f"{idx} loss {loss_val} w1 {w1:.5f} w2 {w2:.5f}")  

        w1_list.append(w1)
        w2_list.append(w2)
        
      # Mean and standard deviation over the trials.
      w1_list = jnp.array(w1_list)
      w2_list = jnp.array(w2_list)
      w1_mean, w1_std = jnp.mean(w1_list), jnp.std(w1_list)
      w2_mean, w2_std = jnp.mean(w2_list), jnp.std(w2_list)

      # loss_val is the loss of the current training step only.
      print(f"{idx} loss {loss_val} nfe {nfe} w1 {w1_mean:.5f} pm {w1_std:.5f} w2 {w2_mean:.5f} pm {w2_std:.5f}")  

