"""Compute the floor of the square root of a binary number."""

import math
import random

import chex
import jax.nn as jnn
import jax.numpy as jnp

from tasks import task
from tasks.cs import binary_addition


class ComputeSqrt(task.GeneralizationTask):
  """A task whose goal is to compute the square root of a binary number.

  The input is a number in binary (big-endian), and the output is the floor of
  the square root of this number, also in binary (big-endian).
  Note the output length, i.e. the length of the square root in binary, is
  always ceil(input_length / 2) (because log(sqrt(x)) = 1/2 log(x)). Shorter
  results are encoded at that fixed length.

  Examples:
   100101 = 37 -> square root is 6.08... -> floor(6.08) = 6 -> 110
   111 = 7 -> square root is 2.64 -> floor(2.64) = 2 -> 10
  """

  def sample_batch(self, rng: chex.PRNGKey, batch_size: int,
                   length: int) -> task.Batch:
    """Returns a batch of binary numbers and their square roots, in binary.

    Integers are drawn uniformly from [1, 2**length - 1] (zero is never
    sampled) with Python's `random` module. `rng` is discarded, so the batch
    depends on the global `random` state rather than on the JAX key.

    Args:
      rng: Unused.
      batch_size: Number of examples to sample.
      length: Number of bits of each input number. Must be at least 1: with
        `length == 0` the sampling range is empty and `random.randint` raises
        a `ValueError`.

    Returns:
      A dict holding the one-hot encoded binary numbers under 'input' and the
      one-hot encoded binary floor square roots under 'output'. The last axis
      has size `input_size` and `output_size` respectively.
    """
    # The JAX key is intentionally dropped: sampling below goes through the
    # standard-library `random` module, which works on arbitrary-size ints.
    del rng
    numbers = [random.randint(1, 2**length - 1) for _ in range(batch_size)]
    binary_numbers = binary_addition.numbers_to_fixed_length_binary(
        numbers, length=length, little_endian=False)

    # math.isqrt is exact integer arithmetic (no float rounding for large
    # numbers) and already returns the floor of the square root.
    sqrts = [math.isqrt(number) for number in numbers]
    binary_sqrts = binary_addition.numbers_to_fixed_length_binary(
        sqrts, length=self.output_length(length), little_endian=False)

    binary_numbers = jnp.array(binary_numbers, jnp.int32)
    binary_sqrts = jnp.array(binary_sqrts, jnp.int32)

    inputs = jnn.one_hot(binary_numbers, self.input_size)
    output = jnn.one_hot(binary_sqrts, self.output_size)
    return {'input': inputs, 'output': output}

  @property
  def input_size(self) -> int:
    """Returns the input size for the models (two classes: bits 0 and 1)."""
    return 2

  @property
  def output_size(self) -> int:
    """Returns the output size for the models (two classes: bits 0 and 1)."""
    return 2

  def output_length(self, input_length: int) -> int:
    """Returns the output length for an input length: ceil(input_length / 2)."""
    return math.ceil(input_length / 2)

