from typing import Optional

import semidiscrete_pointcloud as sdpc

import jax
import jax.numpy as jnp
import jax.tree_util as jtu

from ott.problems.linear import linear_problem

# Public names exported by ``from ... import *``.
__all__ = [
    "SemidiscreteLinearProblem",
]


@jtu.register_pytree_node_class
class SemidiscreteLinearProblem:
  """Semidiscrete linear OT problem.

  Instances of this problem can be sampled using the :meth:`sample` method,
  which returns a :class:`~ott.problems.linear.linear_problem.LinearProblem`
  built on a geometry sampled from ``geom``.

  Args:
    geom: Semidiscrete point cloud geometry.
    b: The second marginal. If :obj:`None`, it will be uniform.
    tau_b: If :math:`< 1`, defines how unbalanced the problem is on the
      second marginal. Currently not implemented: only ``1.0`` is accepted.

  Raises:
    NotImplementedError: If ``tau_b`` is not ``1.0``.
  """

  def __init__(
      self,
      geom: sdpc.SemidiscretePointCloud,
      b: Optional[jax.Array] = None,
      tau_b: float = 1.0,
  ):
    # Explicit raise instead of ``assert``: assertions are stripped when
    # Python runs with ``-O``, which would silently accept an unbalanced
    # problem that the rest of this class does not support.
    if tau_b != 1.0:
      raise NotImplementedError(
          "Unbalanced semidiscrete problem is not supported."
      )
    self.geom = geom
    # Kept private so that ``None`` (meaning "uniform") survives flattening;
    # the public :attr:`b` property materializes the uniform default.
    self._b = b
    self.tau_b = tau_b

  def sample(
      self,
      rng: jax.Array,
      num_samples: int,
  ) -> linear_problem.LinearProblem:
    """Sample a linear OT problem.

    Args:
      rng: Random key used for seeding.
      num_samples: Number of samples, forwarded to ``self.geom.sample``.

    Returns:
      The sampled linear problem. The first marginal ``a`` is passed as
      :obj:`None`, and ``b`` is the marginal given at construction time
      (possibly :obj:`None`). Both are left for
      :class:`~ott.problems.linear.linear_problem.LinearProblem` to
      interpret.
    """
    geom = self.geom.sample(rng, num_samples)
    return linear_problem.LinearProblem(
        geom, a=None, b=self._b, tau_a=1.0, tau_b=self.tau_b
    )

  @property
  def b(self) -> jax.Array:
    """Second marginal.

    Returns the user-provided marginal if one was given. Otherwise returns
    a uniform vector of length ``m`` (the second entry of
    ``self.geom.shape``) with the dtype of ``self.geom.y``.
    """
    if self._b is not None:
      return self._b
    _, m = self.geom.shape
    return jnp.full((m,), fill_value=1.0 / m, dtype=self.geom.y.dtype)

  def tree_flatten(self):  # noqa: D102
    # ``geom`` and ``_b`` are pytree children; ``tau_b`` is static aux data.
    return (self.geom, self._b), {"tau_b": self.tau_b}

  @classmethod
  def tree_unflatten(  # noqa: D102
      cls, aux_data, children
  ) -> "SemidiscreteLinearProblem":
    # Re-runs ``__init__``. Its only check is on ``tau_b``, which is static
    # aux data, so it is safe even when ``children`` are tracers.
    return cls(*children, **aux_data)
