"""Very simple toy dataset for divisibility."""
import dataclasses
from typing import Optional

import numpy as np
import tensorflow as tf


# Default inclusive range of divisors (defaults for DivisibilityDatasetConfig
# and _make; _make samples divisors from [min, max] via randint(high=max + 1)).
_MIN_DIVISOR = 2
_MAX_DIVISOR = 30

# Default inclusive range of dividends.
_MIN_DIVIDEND = 2
_MAX_DIVIDEND = 999_999


@dataclasses.dataclass
class DivisibilityDatasetConfig:
    """Parameters that determine the generated divisibility examples.

    Only settings that change the data itself belong here; pipeline knobs such
    as the sampling buffer size are passed separately.

    An encoded example is a vector of decimal digits: the first
    `n_divisor_digits` entries encode the divisor and the last
    `n_dividend_digits` entries encode the dividend.
    """
    # Range of divisors.
    min_divisor: int = _MIN_DIVISOR
    max_divisor: int = _MAX_DIVISOR
    # Range of dividends.
    min_dividend: int = _MIN_DIVIDEND
    max_dividend: int = _MAX_DIVIDEND

    # Optional explicit digit widths; when None, the width is derived from the
    # number of digits in max_divisor / max_dividend respectively.
    force_n_divisor_digits: Optional[int] = None
    force_n_dividend_digits: Optional[int] = None

    @property
    def n_divisor_digits(self) -> int:
        """Number of digit positions reserved for the divisor."""
        forced = self.force_n_divisor_digits
        return _n_digits(self.max_divisor) if forced is None else forced

    @property
    def n_dividend_digits(self) -> int:
        """Number of digit positions reserved for the dividend."""
        forced = self.force_n_dividend_digits
        return _n_digits(self.max_dividend) if forced is None else forced

    @property
    def n_total_digits(self) -> int:
        """Length of one encoded example: divisor digits plus dividend digits."""
        return self.n_divisor_digits + self.n_dividend_digits

    def get_divisors_from_examples(self, examples: np.ndarray) -> np.ndarray:
        """Decode the leading divisor digits of each example into integers."""
        width = self.n_divisor_digits
        # Place values most-significant first: [10**(width-1), ..., 10, 1].
        place_values = 10 ** np.arange(width - 1, -1, -1, dtype=np.int64)
        return examples[..., :width] @ place_values

    def get_dividends_from_examples(self, examples: np.ndarray) -> np.ndarray:
        """Decode the trailing dividend digits of each example into integers."""
        width = self.n_dividend_digits
        # Place values most-significant first: [10**(width-1), ..., 10, 1].
        place_values = 10 ** np.arange(width - 1, -1, -1, dtype=np.int64)
        return examples[..., -width:] @ place_values


def _n_digits(number: int) -> int:
    """Return the length of the decimal representation of ``int(number)``.

    Non-integer inputs are truncated toward zero by ``int`` first; for negative
    values the minus sign is included in the count.
    """
    as_text = f"{int(number)}"
    return len(as_text)


def _make(
    min_divisor=_MIN_DIVISOR,
    max_divisor=_MAX_DIVISOR,
    min_dividend=_MIN_DIVIDEND,
    max_dividend=_MAX_DIVIDEND,
    buffer_size=4 * 1024,
    dtype=np.int64,
):
    """Sample a buffer of (divisor, dividend, label) triples.

    Divisors are drawn uniformly from [min_divisor, max_divisor]. Every
    dividend lies in [min_dividend, max_dividend] (both ranges inclusive).
    Each dividend comes either from a uniform pool or from a pool of exact
    multiples of its divisor. The mixing weight is chosen so that roughly half
    the labels are positive whenever that is achievable.

    Args:
      min_divisor: smallest divisor (inclusive).
      max_divisor: largest divisor (inclusive).
      min_dividend: smallest dividend (inclusive).
      max_dividend: largest dividend (inclusive).
      buffer_size: number of triples to sample.
      dtype: integer dtype used for sampling and for the labels.

    Returns:
      (divisors, dividends, labels): divisors and dividends as NumPy string
      arrays of unpadded decimal numbers, and labels as a `dtype` array that
      is 1 where the divisor divides the dividend and 0 otherwise.

    Raises:
      ValueError: if some divisor in range has no multiple within the
        dividend range.
    """
    divisors = np.random.randint(
        low=min_divisor,
        high=max_divisor + 1,
        size=[buffer_size],
        dtype=dtype,
    )

    # Uniform pool: the divisors might or might not divide these.
    possible_dividends1 = np.random.randint(
        low=min_dividend,
        high=max_dividend + 1,
        size=[buffer_size],
        dtype=dtype,
    )

    # Multiples pool: the divisors always divide these. The multiplier k is
    # drawn from [ceil(min_dividend / d), floor(max_dividend / d)] so that
    # k * d stays inside [min_dividend, max_dividend].
    min_multipliers = -(-min_dividend // divisors)  # ceiling division
    max_multipliers = max_dividend // divisors
    if np.any(min_multipliers > max_multipliers):
        raise ValueError(
            f'Some divisor in [{min_divisor}, {max_divisor}] has no multiple '
            f'in [{min_dividend}, {max_dividend}].')
    possible_dividends2 = np.random.randint(
        low=min_multipliers,
        high=max_multipliers + 1,
        size=[buffer_size],
        dtype=dtype,
    )
    possible_dividends2 *= divisors

    # A uniform dividend is divisible by d with probability ~1/d. Averaging
    # over the uniformly drawn divisors gives the chance that a dividend from
    # the uniform pool yields a positive label.
    approx_prob_div1_true = np.mean(
        1 / np.arange(min_divisor, max_divisor + 1, dtype=np.float64))

    # With mixing weight q for the uniform pool:
    #   P(label = 1) = q * p + (1 - q) * 1 = 1/2  =>  q = 1 / (2 * (1 - p)).
    if approx_prob_div1_true >= 0.5:
        # The uniform pool alone is already positive at least half the time,
        # so a 50-50 mix is unreachable; use only the uniform pool. This also
        # avoids dividing by zero when p == 1.
        p_choose1 = 1.0
    else:
        p_choose1 = 1 / (2 * (1 - approx_prob_div1_true))

    choices = np.random.uniform(size=[buffer_size]) < p_choose1

    dividends = np.where(choices, possible_dividends1, possible_dividends2)
    labels = (dividends % divisors == 0).astype(dtype)

    divisors = divisors.astype(str)
    dividends = dividends.astype(str)

    return divisors, dividends, labels


def create_ds(
    cfg: DivisibilityDatasetConfig,
    buffer_size: int = 4 * 1024,
):
    """Build an endless `tf.data.Dataset` of divisibility examples.

    Args:
      cfg: configuration giving the value ranges and digit widths.
      buffer_size: number of triples requested from `_make` per refill.

    Returns:
      A dataset of `(example, label)` pairs. `example` is an int64 vector of
      length `cfg.n_total_digits`: the zero-padded divisor digits followed by
      the zero-padded dividend digits. `label` is an int64 scalar taken from
      `_make`.
    """

    def example_stream():
        while True:
            divisor_strs, dividend_strs, labels = _make(
                min_divisor=cfg.min_divisor,
                max_divisor=cfg.max_divisor,
                min_dividend=cfg.min_dividend,
                max_dividend=cfg.max_dividend,
                buffer_size=buffer_size,
                dtype=np.int64,
            )
            for divisor_str, dividend_str, label in zip(
                    divisor_strs, dividend_strs, labels):
                # Left-pad each number with '0' to its fixed width; strings that
                # are already long enough are left unchanged.
                digits = (divisor_str.rjust(cfg.n_divisor_digits, '0')
                          + dividend_str.rjust(cfg.n_dividend_digits, '0'))
                yield tuple(map(int, digits)), label

    element_spec = (
        tf.TensorSpec(shape=[cfg.n_total_digits], dtype=tf.int64),
        tf.TensorSpec(shape=(), dtype=tf.int64),
    )
    return tf.data.Dataset.from_generator(
        example_stream,
        output_signature=element_spec,
    )
