import collections
import jax
import jax.numpy as np

from jax import random
from jax import grad, jit, vmap, value_and_grad
from jax.experimental import optimizers
from jax.util import partial, safe_zip, safe_map, unzip2
from jax.tree_util import tree_map, tree_multimap, tree_flatten, tree_unflatten, tree_reduce
from jax.lax import fori_loop

from jax.scipy import stats

from posteriors.utils import sample_weights_diag
import operator

# Immutable record of the statistics tracked while running SWAG.
#   running_means:         running average of the parameter pytrees seen so far
#   running_second_moment: running average of the element-wise squared parameters
#   n:                     number of parameter snapshots folded into the averages
#   max_n_params:          upper bound on the number of entries kept in `cov_list`
#   cov_list:              list of deviations (parameters minus running mean)
_SWAG_STATE_FIELDS = (
    'running_means',
    'running_second_moment',
    'n',
    'max_n_params',
    'cov_list',
)
SWAGState = collections.namedtuple('SWAGState', _SWAG_STATE_FIELDS)

def _diff(params, mus):
    return tree_multimap(operator.sub, params, mus)

def _square(params):
    return tree_map(lambda p: p**2, params)

def _running_average(params, mus, n):
    return tree_multimap(lambda p, mu: (n * mu + p) / (n + 1), params, mus)


def init_swag(cur_params, max_n_params=20, all_params=None):
    """
    Builds the initial SWAG state from a single parameter snapshot.

    Args:
        cur_params: pytree of parameters. It becomes the initial running mean,
            and its element-wise square the initial running second moment.
        max_n_params: maximum number of deviation entries that `update_swag`
            keeps in `cov_list`.
        all_params: unused; accepted only so existing callers keep working.
            The default is `None` rather than a shared mutable list.

    Returns:
        A `SWAGState` with `n == 1` and an empty deviation history.
    """
    return SWAGState(running_means=cur_params,
                     running_second_moment=_square(cur_params),
                     n=1,
                     max_n_params=max_n_params,
                     # A fresh list per state, so states never share history.
                     cov_list=[],
                     )

def update_swag(swag_state, cur_params):
    """
    Folds a new parameter snapshot into the SWAG statistics.

    The running mean and running second moment are updated with `cur_params`,
    and the deviation of `cur_params` from the *updated* mean is appended to
    the deviation history (`cov_list`). The history holds at most
    `max_n_params` entries; the oldest entries are dropped first.

    The state passed in is left untouched: the returned state owns a new
    history list instead of sharing (and mutating) the old one.

    Args:
        swag_state: the current `SWAGState`.
        cur_params: pytree of parameters with the same structure as
            `swag_state.running_means`.

    Returns:
        A new `SWAGState` with `n` incremented by one.
    """
    n = swag_state.n
    max_n = swag_state.max_n_params

    new_means = _running_average(cur_params, swag_state.running_means, n)
    new_second_moment = _running_average(
        _square(cur_params), swag_state.running_second_moment, n)

    # Copy before editing so the previous state keeps its own history.
    history = list(swag_state.cov_list)
    history.append(_diff(cur_params, new_means))

    # Trim from the front (oldest first). Trimming after the append also
    # covers max_n_params == 0, where popping from an empty list would raise.
    excess = len(history) - max_n
    if excess > 0:
        del history[:excess]

    return SWAGState(running_means=new_means,
                     running_second_moment=new_second_moment,
                     n=n + 1,
                     max_n_params=max_n,
                     cov_list=history)

@jit
def _sample_swag_diag(rng, swag_state, scale):
    # Jit-compiled diagonal path of `sample_swag`.
    mus = swag_state.running_means
    # Per-leaf variance estimate: E[p^2] - (E[p])^2.
    diag_vars = _diff(swag_state.running_second_moment, _square(mus))
    return sample_weights_diag(rng, (mus, diag_vars), scale)


def sample_swag(rng, swag_state, scale=1.0, diag=True):
    """
    Samples a set of weights from the SWAG posterior approximation.

    `diag` selects the code path with an ordinary Python `if`, so it must be
    a concrete Python value. The dispatch therefore happens outside of `jit`
    and only the numerical work of the diagonal path is jit-compiled; this
    lets callers pass `diag` explicitly without hitting a tracer error.

    Args:
        rng: JAX PRNG key.
        swag_state: `SWAGState` holding the running mean and second moment.
        scale: forwarded unchanged to `sample_weights_diag`; see that helper
            for how it is applied.
        diag: if True, sample using only the diagonal variance estimate.

    Returns:
        Whatever `sample_weights_diag` returns for the (mean, variance) pair.

    Raises:
        NotImplementedError: if `diag` is False. The low-rank covariance
            path was never completed; it used to fall off the end of the
            function and return None.
    """
    if diag:
        return _sample_swag_diag(rng, swag_state, scale)

    # TODO: finish the low-rank + diagonal sampler. The unfinished draft:
    #   1. split `rng` into two independent keys (rng1, rng2);
    #   2. draw the diagonal component with rng1 (the draft reused the
    #      parent `rng` here, which correlates it with the split keys),
    #      leaving the factor of 2 to be absorbed into `scale`;
    #   3. draw k = len(swag_state.cov_list) standard normals with rng2;
    #   4. combine them with the stacked deviations in `cov_list`.
    raise NotImplementedError(
        "sample_swag: only diag=True is implemented; the low-rank "
        "covariance sampler is unfinished.")
    

### deprecated
def collect_posterior(swag_state):
    """
    Extracts the diagonal Gaussian summary of a SWAG state.

    Returns a pair `(mean, variance)`: the running mean of the parameters
    (the SWA solution) and, per leaf, the running second moment minus the
    squared running mean.
    """
    mean = swag_state.running_means
    squared_mean = _square(mean)
    variance = _diff(swag_state.running_second_moment, squared_mean)
    return mean, variance

