from functools import partial
from io import StringIO
import math

# Using jax for its jit and autograd capabilities.  Will not get around to GPU backends for now.
import jax.numpy as jnp
from jax import value_and_grad, jit, vmap
from jax import random

def periodic_vector_mean_sequence(input_dimensions, period, displacement, total_iterations, key):
    """Return a (total_iterations, input_dimensions) array of means that flips sign periodically.

    A direction is sampled from a standard multivariate normal and rescaled to
    length ``displacement``.  Within each period the first ``floor(period / 2)``
    rows hold that vector and the remaining ``ceil(period / 2)`` rows hold its
    negation; the period is repeated and truncated to ``total_iterations`` rows.
    """
    sampled = random.multivariate_normal(
        key,
        jnp.zeros(input_dimensions),
        jnp.eye(input_dimensions),
    )
    scaled_mean = sampled / jnp.linalg.norm(sampled) * displacement

    positive_rows = math.floor(period / 2)
    negative_rows = math.ceil(period / 2)
    one_period = jnp.concatenate([
        jnp.tile(scaled_mean, (positive_rows, 1)),
        jnp.tile(-1 * scaled_mean, (negative_rows, 1)),
    ])

    # Repeat enough whole periods to cover the request, then trim the excess.
    repeats_needed = math.ceil(total_iterations / period)
    full_sequence = jnp.tile(one_period, (repeats_needed, 1))
    return full_sequence[:total_iterations]

def test_periodic_vector_mean_sequence():
    """Check the shape, periodic sign pattern, and norm of the periodic mean sequence."""
    # period=5 -> rows 0-1 hold +mean, rows 2-4 hold -mean, then the pattern repeats.
    test_means = periodic_vector_mean_sequence(3, 5, 0.1, 31, random.PRNGKey(1))
    assert jnp.array_equal(test_means[0], test_means[1])
    assert jnp.array_equal(test_means[1], -1 * test_means[2])
    assert jnp.array_equal(test_means[2], test_means[3])
    assert jnp.array_equal(test_means[3], test_means[4])
    assert jnp.array_equal(test_means[0], test_means[5])
    assert test_means.shape[0] == 31
    assert test_means.shape[1] == 3
    # The norm is computed in floating point, so compare with a tolerance
    # rather than exact equality.
    assert jnp.isclose(jnp.linalg.norm(test_means[0]), 0.1)

    
def stochastic_vector_mean_sequence(input_dimensions, switching_interval, variance, total_iterations, key):
    """Return a (total_iterations, input_dimensions) array of piecewise-constant random means.

    One mean vector is drawn per segment from a zero-mean normal with covariance
    ``variance * I``; each is held for ``switching_interval`` consecutive rows and
    the result is truncated to ``total_iterations`` rows.
    """
    segment_count = int(jnp.ceil(total_iterations / switching_interval))
    standard_draws = random.multivariate_normal(
        key,
        jnp.zeros(input_dimensions),
        jnp.eye(input_dimensions),
        shape=(segment_count,),
    )
    segment_means = jnp.sqrt(variance) * standard_draws

    # Hold every segment's mean for `switching_interval` consecutive rows.
    held = jnp.repeat(segment_means, switching_interval, axis=0)
    return held[:total_iterations]

def test_stochastic_vector_mean_sequence():
    """Check the shape and piecewise-constant structure of the stochastic mean sequence."""
    # switching_interval=3 -> rows 0-2 share a mean, rows 3-5 share another, and so on.
    test_means = stochastic_vector_mean_sequence(2, 3, 0.1, 10, random.PRNGKey(2))
    assert test_means.shape[0] == 10
    assert test_means.shape[1] == 2
    assert jnp.array_equal(test_means[0], test_means[1])
    assert jnp.array_equal(test_means[1], test_means[2])
    assert not jnp.array_equal(test_means[2], test_means[3])
    assert jnp.array_equal(test_means[3], test_means[4])
    assert jnp.array_equal(test_means[4], test_means[5])
    assert not jnp.array_equal(test_means[5], test_means[6])
    

def stochastic_scalar_mean_sequence(switching_interval, variance, total_iterations, key):
    """Return a length-``total_iterations`` vector of piecewise-constant scalar means.

    Each scalar is drawn from a zero-mean normal with the given variance and held
    for ``switching_interval`` consecutive entries.  For multidimensional problems
    the scalar is consumed as the mean of all inputs.
    """
    # TODO: Unsure if this is still compatible with train_step below.
    segment_count = int(jnp.ceil(total_iterations / switching_interval))
    segment_means = jnp.sqrt(variance) * random.normal(key, (segment_count,))

    # Hold each drawn value for `switching_interval` entries, then trim the tail.
    held = jnp.repeat(segment_means, switching_interval)
    return held[:total_iterations]

    
def random_weights(input_dim, output_dim, variance, key):
    """Draw an (output_dim, input_dim) weight matrix of zero-mean normals with the given variance."""
    standard_draws = random.normal(key, (output_dim, input_dim))
    scale = jnp.sqrt(variance)
    return scale * standard_draws


def target(input_sample, weights, variance, key):
    """Return ``weights . input_sample`` plus zero-mean normal noise of the given variance."""
    clean_output = jnp.dot(weights, input_sample)
    noise_scale = jnp.sqrt(variance)
    return clean_output + noise_scale * random.normal(key, clean_output.shape)


# Batched form: maps over the leading axis of the inputs and of the keys,
# sharing the weights and the noise variance across the batch.
batched_target = vmap(target, in_axes=(0, None, None, 0))


def predict(input_sample, weights):
    """Linear prediction for one sample: the product of ``weights`` and ``input_sample``."""
    prediction = jnp.dot(weights, input_sample)
    return prediction


# Batched form: maps over the leading axis of the inputs with shared weights.
batched_predict = vmap(predict, in_axes=(0, None))


def mse_loss(input_batch, weights, targets, batched_predict=batched_predict):
    """Mean squared error between ``targets`` and the batched predictions.

    The summed squared error is normalised by both the batch size
    (``input_batch.shape[0]``) and the output dimensionality
    (``targets.shape[1]``), i.e. by their product.

    Args:
        input_batch: Batch of inputs; the leading axis is the batch axis.
        weights: Model parameters forwarded to ``batched_predict``.
        targets: Array with at least two axes; ``targets.shape[1]`` is used as
            the output dimensionality.
        batched_predict: Callable ``(input_batch, weights) -> predictions``.
            The default is the linear predictor; a non-default predictor is
            useful for neural net models.

    Returns:
        Scalar loss.
    """
    predictions = batched_predict(input_batch, weights)
    squared_error = jnp.sum((targets - predictions) ** 2)
    # Normalise w.r.t. batch size AND output dimensionality (their product, not
    # their sum).  Otherwise the effective step size varies with these shapes,
    # which is critical to control for these experiments.
    normalizer = input_batch.shape[0] * targets.shape[1]
    return squared_error / normalizer


@partial(jit, static_argnums=(2, 3))
def train_step(
    input_mean,
    input_variance,
    batch_size,
    input_dimensions,
    target_weights,
    observation_noise_variance,
    step_size,
    momentum,
    weights,
    velocity,
    key,
):
    """Run one SGD-with-momentum step on a freshly sampled batch.

    A batch of ``batch_size`` inputs is drawn as
    ``input_mean + sqrt(input_variance) * N(0, 1)`` with ``input_dimensions``
    features, and a constant column of ones is appended.  Noisy targets are
    produced with ``batched_target`` from ``target_weights``, and the gradient
    of ``mse_loss`` with respect to ``weights`` drives the update.

    ``batch_size`` and ``input_dimensions`` are static under ``jit`` (they
    determine array shapes), so a new value triggers recompilation.

    Args:
        input_mean: Mean added to every sampled input (broadcast against the
            ``(batch_size, input_dimensions)`` sample).
        input_variance: Variance of the sampled inputs.
        batch_size: Number of samples in the batch (static).
        input_dimensions: Number of input features before the ones column
            (static).
        target_weights: Weights of the data-generating linear map; applied to
            inputs that include the appended ones column.
        observation_noise_variance: Variance of the noise added to the targets.
        step_size: Learning rate.
        momentum: Momentum coefficient applied to the previous velocity.
        weights: Current model weights.
        velocity: Current velocity, same shape as ``weights``.
        key: PRNG key; consumed internally and not returned, so the caller must
            supply a fresh key on every call.

    Returns:
        Tuple ``(loss, new_weights, new_velocity)``.
    """
    key, subkey = random.split(key)
    input_batch = jnp.concatenate([
        input_mean + jnp.sqrt(input_variance) * random.normal(subkey, (batch_size, input_dimensions)),
        jnp.ones((batch_size, 1)),
    ], 1)

    # One noise key per sample.  Slicing the split array yields the same keys
    # as unpacking it into a Python list and re-stacking, without the list.
    sample_keys = random.split(key, batch_size + 1)[1:]
    targets = batched_target(input_batch, target_weights, observation_noise_variance, sample_keys)

    loss, grads = value_and_grad(mse_loss, argnums=1)(input_batch, weights, targets)

    new_velocity = momentum * velocity - step_size * grads
    # Apply the freshly updated velocity; using the stale one would delay every
    # gradient by a full step.
    new_weights = weights + new_velocity

    return loss, new_weights, new_velocity


@partial(jit, static_argnums=(2, 3))
def train_step_adam(
    input_mean,
    input_variance,
    batch_size,
    input_dimensions,
    target_weights,
    observation_noise_variance,
    step_size,
    beta_1,
    beta_2,
    weights,
    m_t,
    v_t,
    step,
    key,
):
    """Run one Adam step on a freshly sampled batch.

    The batch is drawn as ``input_mean + sqrt(input_variance) * N(0, 1)`` with
    shape ``(batch_size, input_dimensions)`` and a column of ones is appended.
    Noisy targets come from ``batched_target`` and the gradient of ``mse_loss``
    with respect to ``weights`` feeds the moment estimates.

    ``batch_size`` and ``input_dimensions`` are static under ``jit``.  ``step``
    appears as an exponent in the bias corrections, so a value of 0 makes both
    correction denominators zero.

    Returns:
        Tuple ``(loss, new_weights, new_m_t, new_v_t)``; the moment estimates
        are returned without bias correction.
    """
    key, input_key = random.split(key)
    sampled_inputs = input_mean + jnp.sqrt(input_variance) * random.normal(
        input_key, (batch_size, input_dimensions)
    )
    ones_column = jnp.ones((batch_size, 1))
    input_batch = jnp.concatenate([sampled_inputs, ones_column], 1)

    # Entry 0 of the split is the carry key; the rest are per-sample noise keys.
    noise_keys = random.split(key, batch_size + 1)[1:]
    targets = batched_target(
        input_batch, target_weights, observation_noise_variance, noise_keys
    )

    loss_and_grad = value_and_grad(mse_loss, argnums=1)
    loss, grads = loss_and_grad(input_batch, weights, targets)

    # Exponential moving averages of the gradient and of its square.
    first_moment = beta_1 * m_t + (1 - beta_1) * grads
    second_moment = beta_2 * v_t + (1 - beta_2) * grads ** 2

    # Bias-corrected estimates.
    first_hat = first_moment / (1 - beta_1 ** step)
    second_hat = second_moment / (1 - beta_2 ** step)

    denominator = jnp.sqrt(second_hat) + 1e-8
    updated_weights = weights - step_size * first_hat / denominator

    return loss, updated_weights, first_moment, second_moment


if __name__ == "__main__":
    # Executing the module as a script runs the periodic-sequence self-check.
    test_periodic_vector_mean_sequence()