"""
Module containing functions for generating synthetic
datasets with known properties, for model testing and
experimentation.
"""

import numpy as np

from spotlight.interactions import Interactions


def _build_transition_matrix(num_items,
                             concentration_parameter,
                             random_state,
                             atol=0.001):

    def _is_doubly_stochastic(matrix, atol):

        return (np.all(np.abs(1.0 - matrix.sum(axis=0)) < atol) and
                np.all(np.abs(1.0 - matrix.sum(axis=1)) < atol))

    transition_matrix = random_state.dirichlet(
        np.repeat(concentration_parameter, num_items),
        num_items)

    for _ in range(100):

        if _is_doubly_stochastic(transition_matrix, atol):
            break

        transition_matrix /= transition_matrix.sum(axis=0)
        transition_matrix /= transition_matrix.sum(1)[:, np.newaxis]

    return transition_matrix


def _generate_sequences(num_steps,
                        transition_matrix,
                        order,
                        random_state):

    elements = []

    num_states = transition_matrix.shape[0]

    transition_matrix = np.cumsum(transition_matrix,
                                  axis=1)

    rvs = random_state.rand(num_steps)
    state = random_state.randint(transition_matrix.shape[0], size=order,
                                 dtype=np.int64)

    for rv in rvs:

        row = transition_matrix[state].mean(axis=0)
        new_state = min(num_states - 1,
                        np.searchsorted(row, rv))

        state[:-1] = state[1:]
        state[-1] = new_state

        elements.append(new_state)

    return np.array(elements, dtype=np.int32)


def generate_sequential(num_users=100,
                        num_items=1000,
                        num_interactions=10000,
                        concentration_parameter=0.1,
                        order=3,
                        random_state=None):
    """
    Generate a dataset of user-item interactions in which the order of
    interactions carries information.

    Items are produced by an n-th order Markov chain with a uniform
    stationary distribution. Its transition probabilities are given by an
    (approximately) doubly-stochastic transition matrix. For chains of
    order n > 1, the next-item probabilities are an equally weighted convex
    combination of the transition probabilities of the last n items.

    The transition matrix is sampled from a Dirichlet distribution with a
    constant concentration parameter. Concentration parameters closer to
    zero generate more predictable sequences.

    Item ID 0 is never generated. The chain has ``num_items - 1`` states,
    and each sampled state is shifted up by one, so item IDs lie in
    ``[1, num_items - 1]``. User IDs are drawn uniformly at random and then
    sorted, so each user's interactions form one contiguous run.
    Timestamps are the positions ``0, 1, 2, ...`` of the interactions, and
    every rating is 1.

    Parameters
    ----------

    num_users: int, optional
        number of users in the dataset
    num_items: int, optional
        number of items in the dataset (including the unused item 0)
    num_interactions: int, optional
        number of interactions to generate
    concentration_parameter: float, optional
        Controls how predictable the sequence is. Values
        closer to zero give more predictable sequences.
    order: int, optional
        order of the Markov chain
    random_state: numpy.random.RandomState, optional
        random state used to generate the data

    Returns
    -------

    Interactions: :class:`spotlight.interactions.Interactions`
        instance of the interactions class
    """

    if random_state is None:
        random_state = np.random.RandomState()

    # One fewer chain state than items: state k becomes item k + 1 below,
    # which leaves item ID 0 unused.
    num_states = num_items - 1

    transition_matrix = _build_transition_matrix(num_states,
                                                 concentration_parameter,
                                                 random_state)

    user_ids = random_state.randint(0, num_users, num_interactions,
                                    dtype=np.int32)
    user_ids.sort()

    sampled_states = _generate_sequences(num_interactions,
                                         transition_matrix,
                                         order,
                                         random_state)
    item_ids = 1 + sampled_states

    num_rows = len(user_ids)

    return Interactions(user_ids,
                        item_ids,
                        ratings=np.ones(num_rows, dtype=np.float32),
                        timestamps=np.arange(num_rows, dtype=np.int32),
                        num_users=num_users,
                        num_items=num_items)
