import jax
import math
import torch
import numpy as np
import jax.numpy as jnp

def toy_gmm(n_comp=8, std=0.075, radius=0.8):
  """Equal-weight mixture of `n_comp` isotropic 2D Gaussians on a ring.

  Component k has mean radius * (cos(2*pi*k/n_comp), sin(2*pi*k/n_comp)) and
  per-coordinate standard deviation `std`.

  Args:
    n_comp: number of mixture components.
    std: per-coordinate standard deviation of every component.
    radius: radius of the circle the component means lie on.

  Returns:
    (nll, sample): `nll(x)` maps a torch tensor of points to a float32 torch
    tensor (see its docstring); `sample(n)` returns a float32 torch tensor of
    shape (n, 2) whose rows are grouped by component (not shuffled).
  """
  angles = 2 * np.pi * np.linspace(0, (n_comp - 1) / n_comp, n_comp)
  centers = radius * np.stack((np.cos(angles), np.sin(angles)), axis=1)
  ring_means = centers.reshape((-1, 1, 2))  # (n_comp, 1, 2) for broadcasting
  weights = np.ones(n_comp) / n_comp
  log_norm = math.log(n_comp * 2 * math.pi * std**2)

  def nll(x):
    """Energy-like score of `x` under the ring mixture.

    Returns exp(logsumexp_k(-0.5 * ||(x - mu_k) / std||^2) + log_norm + log 2)
    as float32, where log_norm = log(n_comp * 2 * pi * std^2). Assumes x is a
    torch tensor of shape (batch, 2), giving a result of shape (batch,) -- TODO
    confirm with callers.

    NOTE(review): log_norm is *added*; the normalized mixture density would
    subtract it (and not multiply by 2). Confirm the intended quantity.
    """
    mu = torch.from_numpy(ring_means).to(x.device)
    log_terms = torch.sum(-0.5 * torch.square((x - mu) / std), dim=2)
    f = torch.logsumexp(log_terms, dim=0) + log_norm
    return torch.exp(f.float() + math.log(2))

  def sample(n_samples):
    """Draw `n_samples` points, returned grouped by component (not shuffled)."""
    counts = np.random.multinomial(n_samples, weights)
    chunks = []
    for k, count in enumerate(counts):
      noise = np.random.randn(2 * count).reshape(-1, 2)
      chunks.append(centers[k] + std * noise)
    return torch.from_numpy(np.concatenate(chunks, axis=0)).float()

  return nll, sample


def toy_gauss(radius=0.5):
  """Single isotropic 2D Gaussian N(0, radius^2 * I) centred at the origin.

  Args:
    radius: per-coordinate standard deviation.

  Returns:
    (nll, sample): `nll(x)` is an elementwise JAX score (see its docstring);
    `sample(n)` returns a float64 NumPy array of shape (n, 2).
  """
  std = radius
  log_norm = np.log(2 * np.pi * std**2)

  def nll(x):
    """Elementwise -0.5 * (x / std)^2 + log(2 * pi * std^2).

    The output has the same shape as `x`; there is no reduction over coordinates.

    NOTE(review): the 2D log density would be
    -0.5 * sum((x / std)^2, axis=-1) - log(2 * pi * std^2). Here nothing is
    summed and the constant is added. Confirm the intent with callers.
    """
    return log_norm - 0.5 * jnp.square(x / std)

  def sample(n_samples):
    """Draw `n_samples` points of shape (n_samples, 2)."""
    return radius * np.random.randn(n_samples, 2)

  return nll, sample


def toy_box(scale=1.0):
  """Uniform distribution on the square [-scale, scale]^2.

  Args:
    scale: half-width of the square.

  Returns:
    (nll, sample): `nll` ignores its input and returns the constant 1;
    `sample(n)` returns a float64 NumPy array of shape (n, 2).
  """

  def nll(x):
    # NOTE(review): constant placeholder, not the box density
    # 1 / (2 * scale)^2; confirm callers only need a constant.
    return 1

  def sample(n_samples):
    return np.random.uniform(low=-scale, high=scale, size=(n_samples, 2))

  return nll, sample


def toy_bar(scale=0.2):
  """Uniform distribution on the vertical bar [-scale, scale] x [-1, 1].

  Args:
    scale: half-width of the bar along the first coordinate.

  Returns:
    (nll, sample): `nll(x)` returns the bar's uniform density value for a torch
    tensor `x` (see below); `sample(n)` returns a float32 torch tensor of
    shape (n, 2).
  """
  density = 1 / (2 * 2 * scale)  # 1 / area of the bar (2*scale by 2)

  def nll(x):
    # Despite the name, this returns a density value, not a negative log.
    # NOTE(review): only x[:, 0] is tested against the support; points with
    # |x[:, 1]| > 1 still receive the in-bar value.
    inside = torch.abs(x[:, 0]) <= scale
    return inside.float() / (2 * 2 * scale)

  def sample(n_samples):
    pts = np.random.uniform(-1, 1, (n_samples, 2))
    pts[:, 0] *= scale  # squash the first coordinate into [-scale, scale]
    return torch.from_numpy(pts).float()

  del density  # documented above for readers; nll computes it inline
  return nll, sample


def bar_horizontal(scale=0.2):
  """Uniform distribution on the horizontal bar [-1, 1] x [-scale, scale].

  Args:
    scale: half-height of the bar along the second coordinate.

  Returns:
    (nll, sample): `nll` ignores its input and returns the constant 1;
    `sample(n)` returns a float64 NumPy array of shape (n, 2).
  """

  def nll(x):
    # NOTE(review): constant placeholder, not the bar density
    # 1 / (4 * scale); confirm callers only need a constant.
    return 1

  def sample(n_samples):
    pts = np.random.uniform(-1, 1, (n_samples, 2))
    pts[:, 1] *= scale  # squash the second coordinate into [-scale, scale]
    return pts

  return nll, sample


def toy_gmm_left(n_comp=8, std=0.075, radius=0.5):
  """Equal-weight mixture of the first four Gaussians of a 2D ring.

  A ring of `n_comp` means is laid out at angles 2*pi*k/n_comp
  (k = 0..n_comp-1) on a circle of the given radius. Only components
  k = 0, 1, 2, 3 are kept, so `n_comp` must be at least 4; otherwise the
  index selection raises IndexError. With the default n_comp=8 the kept
  components sit at angles 0, 45, 90 and 135 degrees.

  Args:
    n_comp: number of means on the full ring (before selection).
    std: per-coordinate standard deviation of each component.
    radius: radius of the ring.

  Returns:
    (nll, sample): `nll(x)` is a JAX energy function (see its docstring);
    `sample(n)` returns a shuffled float64 NumPy array of shape (n, 2).
  """
  angles = 2 * np.pi * np.linspace(0, (n_comp - 1) / n_comp, n_comp)
  ring = radius * np.stack((np.cos(angles), np.sin(angles)), axis=1)
  mean = ring[[0, 1, 2, 3]]  # (4, 2)
  n_kept = mean.shape[0]
  weights = np.ones(n_kept) / n_kept

  def nll(x):
    """logsumexp_k(-0.5 * ||(x - mu_k) / std||^2) + log(K * 2 * pi * std^2).

    K is the number of kept components (4), and the reduction runs over
    components. Assumes x broadcasts against (K, 1, 2), e.g. shape
    (batch, 2) -> result shape (batch,). TODO confirm with callers.

    NOTE(review): the normalizer is *added*; the mixture log density would
    subtract it. Confirm the intended quantity.
    """
    means = jnp.array(mean.reshape((-1, 1, 2)))
    c = np.log(n_kept * 2 * np.pi * std**2)
    f = jax.nn.logsumexp(
        jnp.sum(-0.5 * jnp.square((x - means) / std), axis=2), axis=0) + c
    return f

  def sample(n_samples):
    """Draw `n_samples` points as a shuffled (n_samples, 2) float64 array."""
    counts = np.random.multinomial(n_samples, weights)
    chunks = [
        mean[i] + std * np.random.randn(2 * counts[i]).reshape(-1, 2)
        for i in range(n_kept)
    ]
    # Concatenate once rather than growing the array per component.
    data = np.concatenate(chunks, axis=0)
    # A single shuffle after all groups are drawn is enough to break the
    # per-component ordering.
    np.random.shuffle(data)
    return data

  return nll, sample


def toy_gmm_right(n_comp=8, std=0.075, radius=0.5):
  """Equal-weight mixture of components 4-7 of a ring of 2D Gaussians.

  A ring of `n_comp` means is laid out at angles 2*pi*k/n_comp
  (k = 0..n_comp-1) on a circle of the given radius. Only components
  k = 4, 5, 6, 7 are kept, so `n_comp` must be at least 8; otherwise the
  index selection raises IndexError. With the default n_comp=8 the kept
  components sit at angles 180, 225, 270 and 315 degrees.

  Args:
    n_comp: number of means on the full ring (before selection).
    std: per-coordinate standard deviation of each component.
    radius: radius of the ring.

  Returns:
    (nll, sample): `nll(x)` is a JAX energy function (see its docstring);
    `sample(n)` returns a shuffled float64 NumPy array of shape (n, 2).
  """
  angles = 2 * np.pi * np.linspace(0, (n_comp - 1) / n_comp, n_comp)
  ring = radius * np.stack((np.cos(angles), np.sin(angles)), axis=1)
  mean = ring[[4, 5, 6, 7]]  # (4, 2)
  n_kept = mean.shape[0]
  weights = np.ones(n_kept) / n_kept

  def nll(x):
    """logsumexp_k(-0.5 * ||(x - mu_k) / std||^2) + log(K * 2 * pi * std^2).

    K is the number of kept components (4), and the reduction runs over
    components. Assumes x broadcasts against (K, 1, 2), e.g. shape
    (batch, 2) -> result shape (batch,). TODO confirm with callers.

    NOTE(review): the normalizer is *added*; the mixture log density would
    subtract it. Confirm the intended quantity.
    """
    means = jnp.array(mean.reshape((-1, 1, 2)))
    c = np.log(n_kept * 2 * np.pi * std**2)
    f = jax.nn.logsumexp(
        jnp.sum(-0.5 * jnp.square((x - means) / std), axis=2), axis=0) + c
    return f

  def sample(n_samples):
    """Draw `n_samples` points as a shuffled (n_samples, 2) float64 array."""
    counts = np.random.multinomial(n_samples, weights)
    chunks = [
        mean[i] + std * np.random.randn(2 * counts[i]).reshape(-1, 2)
        for i in range(n_kept)
    ]
    # Concatenate once rather than growing the array per component.
    data = np.concatenate(chunks, axis=0)
    # A single shuffle after all groups are drawn is enough to break the
    # per-component ordering.
    np.random.shuffle(data)
    return data

  return nll, sample


def right_bar(scale=0.1, offset=0.2):
  """Uniform distribution on a vertical bar centred at first coordinate `offset`.

  Samples are uniform on [offset - scale, offset + scale] x [-1, 1].

  Args:
    scale: half-width of the bar along the first coordinate.
    offset: centre of the bar along the first coordinate (default 0.2).

  Returns:
    (nll, sample): `nll` ignores its input and returns the constant 1;
    `sample(n)` returns a float64 NumPy array of shape (n, 2).
  """

  def nll(x):
    # NOTE(review): constant placeholder, not the bar density
    # 1 / (4 * scale); confirm callers only need a constant.
    return 1

  def sample(n_samples):
    data = np.random.uniform(-1, 1, (n_samples, 2))
    # Squash the first coordinate into [-scale, scale], then shift it right.
    data[:, 0] = data[:, 0] * scale + offset
    return data

  return nll, sample


def left_bar(scale=0.1, offset=0.2):
  """Uniform distribution on a vertical bar centred at first coordinate `-offset`.

  Samples are uniform on [-offset - scale, -offset + scale] x [-1, 1].

  Args:
    scale: half-width of the bar along the first coordinate.
    offset: distance of the bar's centre to the left of the origin along the
      first coordinate (default 0.2).

  Returns:
    (nll, sample): `nll` ignores its input and returns the constant 1;
    `sample(n)` returns a float64 NumPy array of shape (n, 2).
  """

  def nll(x):
    # NOTE(review): constant placeholder, not the bar density
    # 1 / (4 * scale); confirm callers only need a constant.
    return 1

  def sample(n_samples):
    data = np.random.uniform(-1, 1, (n_samples, 2))
    # Squash the first coordinate into [-scale, scale], then shift it left.
    data[:, 0] = data[:, 0] * scale - offset
    return data

  return nll, sample