import random
import jax.numpy as jnp
import jax.random as jax_random
import numpy
from optax._src import base
from optax._src import combine
from optax._src import transform
from typing import Any, Optional

import torch

import pcax as px

import jax.numpy as jnp


def set_seed(seed):
    """Seed every random number source this module relies on.

    The same ``seed`` is handed, in order, to torch, numpy, Python's
    ``random`` module and the pcax random key generator (``px.RKG``).
    """
    seeders = (
        torch.manual_seed,
        numpy.random.seed,
        random.seed,
        px.RKG.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def sample_multivariate_Gauss(mean, cov, key):
    """Draw one sample from a Gaussian with a full covariance matrix.

    A standard-normal column vector is drawn with ``key``, multiplied by the
    Cholesky factor of ``cov`` and added to ``mean``.

    Args:
        mean: Array whose first axis has the dimensionality of the Gaussian.
        cov: Covariance matrix passed to ``jnp.linalg.cholesky``.
        key: JAX PRNG key used for the standard-normal draw.

    Returns:
        An array with the same shape as ``mean``.
    """
    chol_factor = jnp.linalg.cholesky(cov)
    dim = mean.shape[0]
    std_normal = jax_random.normal(key, (dim, 1))
    # The product is a (dim, 1) column; fold it back into mean's own shape
    # so the final addition is element-wise rather than broadcast.
    correlated_noise = (chol_factor @ std_normal).reshape(mean.shape)
    return mean + correlated_noise


def sample_multivariate_Gauss_diag_cov(mean, cov, key):
    """Draw one sample from a Gaussian described by per-element variances.

    Args:
        mean: Array; one standard-normal value is drawn per entry of its
            first axis.
        cov: Variances; their square root scales the standard-normal draw.
        key: JAX PRNG key used for the standard-normal draw.

    Returns:
        ``mean`` plus the standard-normal draw scaled by ``sqrt(cov)``.
    """
    noise_shape = (mean.shape[0],)
    std_normal = jax_random.normal(key, noise_shape)
    std_dev = jnp.sqrt(cov)
    return mean + std_normal * std_dev


## define noisy sgd optimiser for MCPC
def sgdld(
    learning_rate: base.ScalarOrSchedule,
    momentum: Optional[float] = None,
    h_var: float = 1.0,
    gamma: float = 0.0,
    nesterov: bool = False,
    accumulator_dtype: Optional[Any] = None,
    activity_decay: Optional[float] = None,
) -> base.GradientTransformation:
    """Build an SGD gradient transformation that adds noise to the gradients.

    The returned transformation chains, in order: ``transform.add_noise``,
    an optional momentum trace, an optional decayed-weights term and the
    learning-rate scaling.

    Args:
        learning_rate: Step size. Must be a scalar: it is used once, at
            construction time, to compute the noise scale, so a callable
            schedule cannot be honoured even though the annotation (kept for
            interface compatibility) allows one.
        momentum: Decay of the momentum trace; ``None`` disables momentum.
        h_var: Multiplier of the noise scale. The value passed to
            ``add_noise`` is ``2 * h_var / learning_rate``, additionally
            multiplied by ``(1 - momentum)`` when momentum is enabled.
        gamma: Forwarded unchanged as the second argument of ``add_noise``.
        nesterov: Forwarded to ``transform.trace``.
        accumulator_dtype: Forwarded to ``transform.trace``.
        activity_decay: ``weight_decay`` of ``transform.add_decayed_weights``;
            ``None`` disables that stage.

    Returns:
        A ``base.GradientTransformation`` whose ``init`` re-seeds the noise
        stage from ``px.RKG`` on every call.

    Raises:
        TypeError: If ``learning_rate`` is callable (i.e. a schedule).
    """
    if callable(learning_rate):
        # Dividing by a schedule would fail anyway with an opaque
        # "unsupported operand" TypeError; fail early with a clear message.
        raise TypeError(
            "sgdld requires a scalar learning_rate: the noise scale is "
            "computed once as 2 * h_var / learning_rate, so a schedule is "
            "not supported."
        )

    # Same evaluation order as 2 * h_var * (1 - momentum) / learning_rate.
    eta = 2 * h_var
    if momentum is not None:
        eta = eta * (1 - momentum)
    eta = eta / learning_rate

    def _noise(seed):
        # Single place where the noise stage is configured, so the chain and
        # the re-seeded init state below cannot drift apart.
        return transform.add_noise(eta, gamma, seed)

    if momentum is not None:
        momentum_stage = transform.trace(
            decay=momentum, nesterov=nesterov, accumulator_dtype=accumulator_dtype
        )
    else:
        momentum_stage = base.identity()

    if activity_decay is not None:
        decay_stage = transform.add_decayed_weights(weight_decay=activity_decay)
    else:
        decay_stage = base.identity()

    grad_transform = combine.chain(
        _noise(0),  # placeholder seed; its state is replaced in init_fn
        momentum_stage,
        decay_stage,
        transform.scale_by_learning_rate(learning_rate),
    )

    def init_fn(params):
        state = grad_transform.init(params)
        rand_int = px.RKG(1)[0]  # generate a random seed for each reinit
        # Swap only the noise stage's state (first entry of the chain state)
        # for one initialised with the fresh seed; keep the others as-is.
        return (_noise(rand_int).init(params), *state[1:])

    def update_fn(updates, state, params=None):
        return grad_transform.update(updates, state, params)

    return base.GradientTransformation(init_fn, update_fn)


## define sgd optimiser with gradient scaling (no noise is added, unlike sgdld)
def sgd_scaled(
    learning_rate: base.ScalarOrSchedule,
    momentum: Optional[float] = None,
    scale: float = 1.0,
    nesterov: bool = False,
    accumulator_dtype: Optional[Any] = None,
    activity_decay: float = None,
) -> base.GradientTransformation:
    """Build an SGD gradient transformation with an extra scaling stage.

    The returned chain applies, in order: ``transform.scale(scale)``, an
    optional momentum trace, an optional decayed-weights term and the
    learning-rate scaling.

    Args:
        learning_rate: Forwarded to ``transform.scale_by_learning_rate``.
        momentum: Decay of the momentum trace; ``None`` disables momentum.
        scale: Factor passed to ``transform.scale`` as the first stage.
        nesterov: Forwarded to ``transform.trace``.
        accumulator_dtype: Forwarded to ``transform.trace``.
        activity_decay: ``weight_decay`` of ``transform.add_decayed_weights``;
            ``None`` disables that stage.

    Returns:
        The chained ``base.GradientTransformation``.
    """
    scale_stage = transform.scale(scale)

    # Disabled stages are replaced by the identity transformation so the
    # chain always has the same four positions.
    if momentum is None:
        momentum_stage = base.identity()
    else:
        momentum_stage = transform.trace(
            decay=momentum,
            nesterov=nesterov,
            accumulator_dtype=accumulator_dtype,
        )

    if activity_decay is None:
        decay_stage = base.identity()
    else:
        decay_stage = transform.add_decayed_weights(weight_decay=activity_decay)

    lr_stage = transform.scale_by_learning_rate(learning_rate)

    return combine.chain(scale_stage, momentum_stage, decay_stage, lr_stage)
