import jax.numpy as jnp

from utils.path_utils import project_path

def avg_stddiv_across_marginals(samples):
    # In accordance to https://arxiv.org/abs/2307.01198
    # We compute the stddiv estimate
    # Input: A (batch_size, dim) tensor of terminal examples
    # Output: The computed value 
    d = samples.shape[-1]
    res = 0
    for i in range(d):
        res += jnp.std(samples[:,i])
    return res/d 

def moving_averages(dictionary, window_size=5):
    mov_avgs = {}
    for key, value in dictionary.items():
        try:
            if not 'mov_avg' in key:
                mov_avgs[f'{key}_mov_avg'] = [jnp.mean(jnp.array(value[-min(len(value), window_size):]), axis=0)]
        except:
            pass
    return mov_avgs


def extract_last_entry(dictionary):
    last_entries = {}
    for key, value in dictionary.items():
        try:
            last_entries[key] = value[-min(len(value), 1)]
        except:
            pass
    return last_entries


def save_samples(cfg, logger, samples):
    if len(logger['KL/elbo']) > 1:
        if logger['KL/elbo'][-1] >= jnp.max(jnp.array(logger['KL/elbo'][:-1])):
            jnp.save(project_path(f'{cfg.log_dir}/samples_{cfg.algorithm.name}_{cfg.target.name}_{cfg.target.dim}D_seed{cfg.seed}'), samples)
        else:
            return
    else:
        jnp.save(project_path(f'{cfg.log_dir}/samples_{cfg.algorithm.name}_{cfg.target.name}_{cfg.target.dim}D_seed{cfg.seed}'),
                 samples)


def compute_reverse_ess(log_weights, eval_samples):
    # Subtract the maximum log weight for numerical stability
    max_log_weight = jnp.max(log_weights)
    stable_log_weights = log_weights - max_log_weight

    # Compute the importance weights in a numerically stable way
    is_weights = jnp.exp(stable_log_weights)

    # Compute the sums needed for ESS
    sum_is_weights = jnp.sum(is_weights)
    sum_is_weights_squared = jnp.sum(is_weights ** 2)

    # Calculate the effective sample size (ESS)
    ess = (sum_is_weights ** 2) / (eval_samples * sum_is_weights_squared)

    return ess


if __name__ == '__main__':
    # Example dictionary
    example_dict = {
        'key1': [1, 2, 3, 4],
        'key2': [5, 6],
        'key3': []
    }

    # Convert the dictionary values to JAX arrays
    jax_example_dict = {key: jnp.array(value) for key, value in example_dict.items()}

    # Compute moving average over the last five entries
    result_dict = extract_last_entry(jax_example_dict)

    print(result_dict)
