"""
Copyright (c) ANONYMOUS
All rights reserved.

MIT License

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""

from functools import partial

import matplotlib.cm as cmx
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as onp
import wandb
import numpy as np

import optax
from ccoa import envs, utils, contribution
from ccoa.experiment import CallbackState


def get_parallel_causal(rng, ctx, state, get_metric_fn):
    """Run ``get_metric_fn`` on every causal contribution module and merge the results.

    When ``ctx.contribution`` is a ``ParallelContribution``, only the entries of
    its ``contribution_dict`` that are ``CausalContributionBase`` instances are
    evaluated, each paired with ``state.contribution.state[name]``. Otherwise
    ``ctx.contribution`` itself is evaluated under the empty name ``""``.

    Args:
        rng: Random key forwarded untouched to ``get_metric_fn``.
        ctx: Experiment context exposing ``contribution``.
        state: Callback state exposing ``contribution``.
        get_metric_fn: Callable ``(rng, ctx_contribution, state_contribution)``
            returning a dict of metrics.

    Returns:
        A dict in which every metric name is suffixed with ``"_" + name`` of the
        module that produced it (so the single-module case ends in ``"_"``).
    """
    if isinstance(ctx.contribution, contribution.parallel.ParallelContribution):
        causal_modules = {
            name: (module, state.contribution.state[name])
            for name, module in ctx.contribution.contribution_dict.items()
            if isinstance(module, contribution.causal.CausalContributionBase)
        }
    else:
        causal_modules = {"": (ctx.contribution, state.contribution)}

    merged = {}
    for name, (module, module_state) in causal_modules.items():
        module_metrics = get_metric_fn(rng, module, module_state)
        for metric_name, value in module_metrics.items():
            merged[metric_name + "_" + name] = value

    return merged


def make_contribution_bias_callback(mdp, first_state_only):
    """Build a callback measuring how far learned contribution coefficients are from the ground truth.

    Args:
        mdp: Environment description handed to
            ``get_groundtruth_contribution_coeff``; its ``mdp_observation`` is
            fed through the policy. When ``None`` a no-op callback returning an
            empty dict is produced.
        first_state_only: When true, only the coefficients of the first
            timestep of each trajectory enter the metrics.

    Returns:
        A callable ``(rng, ctx, state) -> dict`` of scalar metrics. All ``*_log``
        metrics are ``10 * log10`` of the underlying quantity.
    """
    if mdp is None:
        return lambda x, y, z: dict()

    def to_db(value):
        # 10 * log10(value), written with natural logs.
        return 10 / jnp.log(10) * jnp.log(value)

    def relative_db(value, reference):
        # 10 * log10(value / reference), written as a difference of logs.
        return 10 / jnp.log(10) * (jnp.log(value) - jnp.log(reference))

    def masked_mean(values, mask, per_group):
        """Average ``values`` over the entries selected by ``mask`` (NaN when nothing is selected).

        The last axis of ``values`` is averaged, all remaining selected entries
        are summed and divided by the number of selected entries. With
        ``per_group`` the leading axis of ``mask`` is kept (one result per
        group); otherwise a scalar is returned. The same mask is used for the
        emptiness guard and for the denominator.
        """
        mask_axes = (1, 2, 3, 4) if per_group else None
        value_axes = (1, 2, 3) if per_group else None
        count = mask.sum(axis=mask_axes)
        total = (values * mask).mean(-1).sum(axis=value_axes)
        return jnp.where(count == 0, float('nan'), total / count)

    @partial(jax.jit, static_argnums=1)
    def callback_contribution_bias(rng, ctx, state: CallbackState):
        """
        Compute the MSE of the contribution along the sampled trajectories.
        """
        policy_prob = jax.nn.softmax(
            ctx.agent.forward_policy_train(state.agent.params, mdp.mdp_observation)
        )

        trajectory = state.trajectory
        num_timesteps = trajectory.observations.shape[1]

        def get_metric_fn(rng, ctx_contribution, state_contribution):
            groundtruth_contribution_coeff = jax.vmap(
                ctx_contribution.get_groundtruth_contribution_coeff, in_axes=(None, 0, None, None)
            )(state_contribution, trajectory, mdp, policy_prob)
            contribution_coeff = jax.vmap(ctx_contribution.get_contribution_coeff, in_axes=(None, 0))(
                state_contribution, trajectory
            )

            # Keeps the entries (i, j) with j >= i of the two timestep axes.
            mask_tril = jnp.expand_dims(1 - jnp.tri(num_timesteps, num_timesteps, k=-1), -1)

            # We also group the contribution coeff by the reward corresponding to s', to get
            # some more insights: one boolean mask per entry of `reward_values`.
            reward_masks = jax.vmap(jax.vmap(lambda x, y: x == y, in_axes=(0, None)), in_axes=(None, 0))(
                trajectory.rewards,
                ctx_contribution.reward_values)
            reward_masks = jnp.expand_dims(reward_masks, 2)
            reward_masks = jnp.expand_dims(reward_masks, -1)
            reward_masks = reward_masks * mask_tril

            # NOTE(review): observation entry 8 is presumably the "agent holds the key"
            # flag -- confirm against the environment's observation layout.
            has_key_mask = jax.vmap(jax.vmap(lambda x: x[8] == 1))(trajectory.observations)
            has_key_mask = jnp.expand_dims(has_key_mask, 1)
            has_key_mask = jnp.expand_dims(has_key_mask, -1)
            has_no_key_mask = (1 - has_key_mask) * mask_tril
            has_key_mask = has_key_mask * mask_tril

            if first_state_only:
                groundtruth_contribution_coeff = groundtruth_contribution_coeff[:, 0:1, :, :]
                contribution_coeff = contribution_coeff[:, 0:1, :, :]
                mask_tril = mask_tril[0:1, :, :]
                reward_masks = reward_masks[:, :, 0:1, :, :]
                has_key_mask = has_key_mask[:, 0:1, :, :]
                has_no_key_mask = has_no_key_mask[:, 0:1, :, :]

            squared_error = (groundtruth_contribution_coeff - contribution_coeff) ** 2
            groundtruth_power = groundtruth_contribution_coeff ** 2

            # Overall error: averaged over batch and last axis, normalised by the mask size.
            mse = (squared_error * mask_tril).mean((0, -1)).sum() / mask_tril.sum()
            contribution_coeff_norm = (groundtruth_power * mask_tril).mean((0, -1)).sum() / mask_tril.sum()

            # Per reward value.
            mse_reward_grouped_mean = masked_mean(squared_error, reward_masks, True)
            contribution_coeff_reward_grouped_mean = masked_mean(groundtruth_power, reward_masks, True)

            # Per reward value, split by key possession.
            reward_has_key_masks = reward_masks * has_key_mask
            reward_has_no_key_masks = reward_masks * has_no_key_mask
            mse_reward_grouped_has_key_mean = masked_mean(squared_error, reward_has_key_masks, True)
            contribution_coeff_reward_grouped_has_key_mean = masked_mean(
                groundtruth_power, reward_has_key_masks, True)
            mse_reward_grouped_has_no_key_mean = masked_mean(squared_error, reward_has_no_key_masks, True)
            contribution_coeff_reward_grouped_has_no_key_mean = masked_mean(
                groundtruth_power, reward_has_no_key_masks, True)

            # Split by key possession only. The "no key" averages are normalised by the
            # size of the "no key" mask (they used to be divided by the "has key" count).
            mse_has_key_mean = masked_mean(squared_error, has_key_mask, False)
            mse_has_no_key_mean = masked_mean(squared_error, has_no_key_mask, False)
            contribution_coeff_has_key_mean = masked_mean(groundtruth_power, has_key_mask, False)
            contribution_coeff_has_no_key_mean = masked_mean(groundtruth_power, has_no_key_mask, False)

            metric = {
                "contribution_mse": mse,
                "contribution_relative_mse": mse / contribution_coeff_norm,
                "contribution_mse_log": to_db(mse),
                "contribution_relative_mse_log": relative_db(mse, contribution_coeff_norm),
                "contribution_mse_log_has_key": to_db(mse_has_key_mean),
                "contribution_mse_log_has_no_key": to_db(mse_has_no_key_mean),
                "contribution_relative_mse_log_has_key": relative_db(
                    mse_has_key_mean, contribution_coeff_has_key_mean),
                "contribution_relative_mse_log_has_no_key": relative_db(
                    mse_has_no_key_mean, contribution_coeff_has_no_key_mean),
            }

            # `ctx` is a static jit argument, so the reward values are concrete here and
            # can be turned into metric names.
            for i, r in enumerate(np.array(ctx_contribution.reward_values)):
                tag = "reward_{:.3f}".format(float(r))
                metric["contribution_mse_log_" + tag] = to_db(mse_reward_grouped_mean[i])
                metric["contribution_relative_mse_log_" + tag] = relative_db(
                    mse_reward_grouped_mean[i], contribution_coeff_reward_grouped_mean[i])
                metric["contribution_relative_mse_log_has_key_" + tag] = relative_db(
                    mse_reward_grouped_has_key_mean[i], contribution_coeff_reward_grouped_has_key_mean[i])
                metric["contribution_relative_mse_log_has_no_key_" + tag] = relative_db(
                    mse_reward_grouped_has_no_key_mean[i], contribution_coeff_reward_grouped_has_no_key_mean[i])
                metric["contribution_mse_log_has_key_" + tag] = to_db(mse_reward_grouped_has_key_mean[i])
                metric["contribution_mse_log_has_no_key_" + tag] = to_db(mse_reward_grouped_has_no_key_mean[i])

            return metric

        return get_parallel_causal(rng, ctx, state, get_metric_fn)

    return callback_contribution_bias


def make_policy_var_callback(mdp, first_state_only, prefix):
    """Build a callback reporting bias/variance/SNR statistics of the policy gradient.

    Args:
        mdp: Environment description providing ``mdp_observation``,
            ``mdp_transition``, ``init_state``, ``max_trial`` and
            ``get_advantage``. When ``None`` a no-op callback returning an empty
            dict is produced.
        first_state_only: When true, advantages and return contributions of all
            but the first state/timestep are zeroed before gradients are taken.
        prefix: String prepended to every metric name.

    Returns:
        A callable ``(rng, ctx, state) -> dict``. Metric names have the form
        ``prefix + name + "_" + key`` where ``key`` is the estimator's name in a
        ``ParallelContribution`` and ``""`` otherwise.
    """
    if mdp is None:
        return lambda x, y, z: dict()

    def get_policy_gradient_mean(agent, agent_state, policy_prob, advantage):
        """Gradient w.r.t. the agent parameters of the advantage-weighted policy objective."""
        # State transitions under the current policy; built from the fixed `policy_prob`,
        # so no gradient flows through the state distribution.
        policy_transition = jax.vmap(lambda a, b: a @ b)(policy_prob, mdp.mdp_transition)
        batch_inner_prod = jax.vmap(lambda a, b: a @ b)

        def get_loss(params):
            def get_summed_loss(curr_state, timestep):
                # Only this re-evaluation of the policy depends on `params`.
                _policy_prob = jax.nn.softmax(agent.forward_policy_train(params, mdp.mdp_observation))
                curr_loss = curr_state @ batch_inner_prod(_policy_prob, advantage)
                next_state = curr_state @ policy_transition
                return next_state, curr_loss

            _, timestep_loss = jax.lax.scan(
                get_summed_loss, mdp.init_state, jnp.arange(mdp.max_trial)
            )
            return -timestep_loss.mean()

        return jax.grad(get_loss)(agent_state.params)

    def cosine_similarity(a, b):
        # Defined as 0 when the inner product vanishes (covers zero vectors).
        inner = a @ b
        return jnp.where(inner == 0, 0, inner / (jnp.linalg.norm(a) * jnp.linalg.norm(b)))

    def signal_noise_norms(a, u):
        # Signed squared length of `a` along `u`, and squared length orthogonal to `u`.
        cos = cosine_similarity(a, u)
        squared_norm = jnp.linalg.norm(a) ** 2
        signal = squared_norm * cos**2 * jnp.sign(cos)
        noise = squared_norm * (1 - cos**2)
        return signal, noise

    def signal_noise_clipped_norms(a, u):
        # Same split, but a negative cosine counts as zero signal and full noise.
        clipped_cos_sq = jnp.maximum(cosine_similarity(a, u), 0) ** 2
        squared_norm = jnp.linalg.norm(a) ** 2
        return squared_norm * clipped_cos_sq, squared_norm * (1 - clipped_cos_sq)

    def get_metrics(flat_gt_grad, flat_expected_grad, flat_grads):
        """Compare per-trajectory gradients with the expected and ground-truth gradients.

        ``flat_gt_grad`` and ``flat_expected_grad`` are flat parameter vectors;
        ``flat_grads`` holds one flat gradient per trajectory along axis 0.
        """
        per_sample_cos = jax.vmap(cosine_similarity, in_axes=(0, None))
        per_sample_split = jax.vmap(signal_noise_norms, in_axes=(0, None))
        per_sample_clipped_split = jax.vmap(signal_noise_clipped_norms, in_axes=(0, None))

        sample_mean_grad = jnp.mean(flat_grads, 0)

        bias_cosine_sim = cosine_similarity(flat_expected_grad, flat_gt_grad)
        avr_cosine_sim_vs_mean = per_sample_cos(flat_grads, flat_expected_grad)
        avr_cosine_sim_vs_gt = per_sample_cos(flat_grads, flat_gt_grad)
        signal, noise = per_sample_split(flat_grads, flat_gt_grad)
        signal_clipped, noise_clipped = per_sample_clipped_split(flat_grads, flat_gt_grad)
        signal_clipped_vs_mean, noise_clipped_vs_mean = per_sample_clipped_split(
            flat_grads, flat_expected_grad)
        signal_wtr_mean, noise_wtr_mean = per_sample_split(flat_grads, flat_expected_grad)
        sample_cosine_sim = cosine_similarity(sample_mean_grad, flat_gt_grad)
        sample_avr_cosine_sim_vs_mean = per_sample_cos(flat_grads, sample_mean_grad)
        sample_signal_wtr_mean, sample_noise_wtr_mean = per_sample_split(flat_grads, sample_mean_grad)

        db20 = 20 / jnp.log(10)
        db10 = 10 / jnp.log(10)

        gt_power = jnp.mean(flat_gt_grad**2)
        log_gt_norm = jnp.log(jnp.linalg.norm(flat_gt_grad))

        deviation_from_expected = flat_grads - flat_expected_grad
        var_unnormalized = jnp.mean(deviation_from_expected ** 2)
        log_mean_dev_expected = jnp.log(jnp.mean(jnp.linalg.norm(deviation_from_expected, axis=1)))
        log_mean_dev_gt = jnp.log(jnp.mean(jnp.linalg.norm(flat_grads - flat_gt_grad, axis=1)))

        bias_vector = flat_gt_grad - flat_expected_grad
        bias_unnormalized = jnp.mean(bias_vector ** 2)
        log_bias_norm = jnp.log(jnp.linalg.norm(bias_vector))

        var_sample_unnormalized = jnp.mean(jnp.var(flat_grads, axis=0))
        bias_sample_unnormalized = jnp.mean((sample_mean_grad - flat_gt_grad) ** 2)

        return {
            "policy_grad_var": var_unnormalized / gt_power,
            "policy_grad_var_dB": db20 * (-log_gt_norm + log_mean_dev_expected),
            "policy_grad_var_unnormalized": var_unnormalized,
            "policy_grad_var_unnormalized_dB": db20 * log_mean_dev_expected,
            "policy_grad_snr_dB": db20 * (log_gt_norm - log_mean_dev_gt),
            "policy_grad_snr_dB_vs_mean": db20 * (
                jnp.log(jnp.linalg.norm(flat_expected_grad)) - log_mean_dev_expected),
            "policy_grad_snr_cos_clipped_dB": db10 * (
                jnp.log(jnp.mean(signal_clipped)) - jnp.log(jnp.mean(noise_clipped))),
            "policy_grad_snr_cos_clipped_dB_vs_mean": db10 * (
                jnp.log(jnp.mean(signal_clipped_vs_mean)) - jnp.log(jnp.mean(noise_clipped_vs_mean))),
            "policy_grad_bias": bias_unnormalized / gt_power,
            "policy_grad_bias_dB": db20 * (-log_gt_norm + log_bias_norm),
            "policy_grad_bias_unnormalized_dB": db20 * log_bias_norm,
            "policy_grad_bias_unnormalized": bias_unnormalized,
            "policy_grad_cos_sim_bias": bias_cosine_sim,
            "policy_grad_avr_cos_sim_vs_mean": jnp.mean(avr_cosine_sim_vs_mean),
            "policy_grad_snr_var_vs_mean": jnp.mean(signal_wtr_mean) / (
                jnp.mean(noise_wtr_mean) + jnp.mean(jnp.abs(signal_wtr_mean))),
            # NOTE(review): unlike its non-sample counterpart this one uses the
            # unsigned signal in the numerator; kept as is.
            "policy_grad_snr_var_vs_mean_sample": jnp.mean(jnp.abs(sample_signal_wtr_mean)) / (
                jnp.mean(sample_noise_wtr_mean) + jnp.mean(jnp.abs(sample_signal_wtr_mean))),
            "policy_grad_var_sample": var_sample_unnormalized / gt_power,
            "policy_grad_var_sample_unnormalized": var_sample_unnormalized,
            "policy_grad_bias_sample": bias_sample_unnormalized / gt_power,
            "policy_grad_bias_sample_unnormalized": bias_sample_unnormalized,
            "policy_grad_cos_sim_bias_sample": sample_cosine_sim,
            "policy_grad_avr_cos_sim_vs_mean_sample": jnp.mean(sample_avr_cosine_sim_vs_mean),
            "policy_grad_avr_cos_sim_vs_gt": jnp.mean(avr_cosine_sim_vs_gt),
            "policy_grad_snr_var": jnp.mean(signal) / (jnp.mean(noise) + jnp.mean(jnp.abs(signal))),
        }

    @partial(jax.jit, static_argnums=1)
    def callback_policy_grad_var(rng, ctx, state: CallbackState):
        """
        Compute bias, variance and signal-to-noise statistics of the per-trajectory
        policy gradients within a batch, for every contribution estimator.
        """
        (rng_grad,) = jax.random.split(rng, 1)
        policy_prob = jax.nn.softmax(
            ctx.agent.forward_policy_train(state.agent.params, mdp.mdp_observation)
        )
        gt_advantage = mdp.get_advantage(policy_prob)
        if first_state_only:
            gt_advantage = gt_advantage.at[1:].set(0)

        # Flatten the ground-truth gradient once, before looping over the estimators.
        flat_gt_grad = utils.flatcat(
            get_policy_gradient_mean(ctx.agent, state.agent, policy_prob, gt_advantage)
        )

        trajectory = state.trajectory
        batch_size = trajectory.observations.shape[0]

        # Inputs shared by all estimators: every trajectory becomes its own batch of
        # size one so that the agent's gradient is computed per trajectory.
        trajectory_expanded = jtu.tree_map(partial(jnp.expand_dims, axis=1), trajectory)
        rngs_grad = jax.random.split(rng_grad, batch_size)
        batched_agent_grad = jax.vmap(ctx.agent.grad, in_axes=(None, 0, 0, 0, None))
        batched_flatcat = jax.vmap(utils.flatcat)

        fn_dict = dict()
        if isinstance(ctx.contribution, contribution.parallel.ParallelContribution):
            for k in ctx.contribution.contribution_dict.keys():
                fn_dict[k] = (
                    partial(ctx.contribution.expected_advantage, key=k),
                    partial(ctx.contribution.__call__, key=k),
                )
        else:
            fn_dict[""] = (ctx.contribution.expected_advantage, ctx.contribution.__call__)

        all_metric = dict()
        for key, (expected_advantage_fn, contribution_fn) in fn_dict.items():
            expected_advantage = expected_advantage_fn(state.contribution, mdp, policy_prob)
            if first_state_only:
                expected_advantage = expected_advantage.at[1:].set(0)

            flat_expected_grad = utils.flatcat(
                get_policy_gradient_mean(ctx.agent, state.agent, policy_prob, expected_advantage)
            )

            # Get the current batch of trajectories and compute corresponding return contributions
            return_contribution = jax.vmap(contribution_fn, in_axes=(None, 0))(
                state.contribution, trajectory
            )
            if first_state_only:
                return_contribution = return_contribution.at[:, 1:].set(0)

            # Compute the agent's parameter grad for each trajectory in the batch individually
            return_contribution_expanded = jnp.expand_dims(return_contribution, axis=1)
            _, grads = batched_agent_grad(
                state.agent.params, rngs_grad, trajectory_expanded, return_contribution_expanded, 0
            )

            metric = get_metrics(flat_gt_grad, flat_expected_grad, batched_flatcat(grads))
            all_metric.update({prefix + _k + "_" + key: _v for (_k, _v) in metric.items()})

        return all_metric

    return callback_policy_grad_var


@partial(jax.jit, static_argnums=1)
def get_contribution(rng, ctx, state, trajectory):
    """Return the contribution coefficients of ``trajectory`` for every causal contribution module.

    The result comes from ``get_parallel_causal``: one entry per module, keyed
    ``"contribution_" + name``. ``ctx`` is a static jit argument.
    """
    def coefficients_of(rng, ctx_contribution, state_contribution):
        coefficients = ctx_contribution.get_contribution_coeff(state_contribution, trajectory)
        return {"contribution": coefficients}

    return get_parallel_causal(rng, ctx, state, coefficients_of)


@partial(jax.jit, static_argnums=0)
def compute_discount_equivalent(ctx, state, trajectory, contributions):
    """Summarise contribution coefficients into one scalar per (timestep, contribution row).

    For a logit vector, the taken action and a contribution vector over actions,
    three gradients with respect to the logits are formed: of the probability of
    the taken action, of its log-probability, and of the contribution-weighted
    policy ``sum(softmax(logit) * contribution)``. The returned value is the
    inner product of the first and third divided by the inner product of the
    first and second.

    Args:
        ctx: Static (hashable) context; only ``ctx.env.num_actions`` is read.
        state: Not used by this function.
        trajectory: Provides ``logits`` and ``actions``, mapped over axis 0.
        contributions: Mapped over axis 0 together with the trajectory, and
            over its next axis for a fixed logit/action pair.

    Returns:
        An array with the two leading axes of ``contributions``.
    """
    num_actions = ctx.env.num_actions

    def ratio_for_pair(logit, contribution, action):
        taken = jax.nn.one_hot(action, num_actions)

        grad_prob = jax.grad(lambda l: (jax.nn.softmax(l) * taken).sum())(logit)
        grad_log_prob = jax.grad(lambda l: (jax.nn.log_softmax(l) * taken).sum())(logit)
        grad_contribution = jax.grad(lambda l: (jax.nn.softmax(l) * contribution).sum())(logit)

        numerator = (grad_prob * grad_contribution).sum()
        denominator = (grad_prob * grad_log_prob).sum()
        return numerator / denominator

    # Inner map: all contribution rows of one timestep share that timestep's logit/action.
    over_rows = jax.vmap(ratio_for_pair, in_axes=(None, 0, None))
    over_timesteps = jax.vmap(over_rows, in_axes=(0, 0, 0))
    return over_timesteps(trajectory.logits, contributions, trajectory.actions)

def visualize_contribution(rng, ctx, state: CallbackState, num_samples=10):
    """Render contribution coefficients of one buffered trajectory as wandb images.

    NOTE: This callback is slow.

    ``num_samples`` trajectories are sampled from the buffer and the first one
    is visualised. For every entry returned by ``get_contribution`` the
    coefficients and their ``compute_discount_equivalent`` summary are
    colour-mapped at three positions of the second axis: "treasure" (last
    index, shifted back by one when the env is a treasure conveyor with
    ``treasure_at_door`` set), "middle" (``int(ctx.env.length / 2) - 1``) and
    "begin" (index 1).

    Returns:
        A dict of ``wandb.Image`` objects, including the trajectory's
        observations and action probabilities.
    """
    sampled = ctx.buffer.sample(rng, state.buffer, num_samples)
    trajectory = jtu.tree_map(lambda leaf: leaf[0], sampled)
    contribution_dict = get_contribution(rng, ctx, state, trajectory)

    is_treasure_conveyor = (
        isinstance(ctx.env, envs.treasure_conveyor.ConveyorTreasure)
        or isinstance(ctx.env, envs.treasure_conveyor.MultiConveyorTreasure)
    )
    offset = is_treasure_conveyor and ctx.env.treasure_at_door

    all_dict = dict()
    for name, coefficients in contribution_dict.items():
        # Compute discount equivalent (summary of the contribution coefficients to a scalar)
        summary = compute_discount_equivalent(ctx, state, trajectory, coefficients)

        # Colour maps: data-dependent upper bound for the coefficients, fixed [-1, 1] for the summary.
        coefficient_norm = colors.TwoSlopeNorm(
            vcenter=0.0, vmin=-1.0, vmax=max([jnp.max(coefficients), 3.0])
        )
        summary_norm = colors.TwoSlopeNorm(vcenter=0.0, vmin=-1, vmax=1)
        coefficient_map = cmx.ScalarMappable(norm=coefficient_norm, cmap=plt.get_cmap("PiYG"))
        summary_map = cmx.ScalarMappable(norm=summary_norm, cmap=plt.get_cmap("PiYG"))

        positions = {
            "treasure": -1 - offset,
            "middle": int(ctx.env.length / 2) - 1,
            "begin": 1,
        }

        coefficient_rgba = {
            label: coefficient_map.to_rgba(coefficients[:, index, :].T)
            for label, index in positions.items()
        }

        summary_rgba = {}
        for label, index in positions.items():
            row = summary_map.to_rgba(jnp.expand_dims(summary[:, index], 0))
            # Weird hack: wandb.Image refuses images with a single pixel row, so the row is duplicated.
            summary_rgba[label] = onp.concatenate([row, row], axis=0)

        all_dict.update({
            name + "_treasure": wandb.Image(coefficient_rgba["treasure"]),
            name + "_middle": wandb.Image(coefficient_rgba["middle"]),
            name + "_begin": wandb.Image(coefficient_rgba["begin"]),
            name + "_discount_equivalent_treasure": wandb.Image(summary_rgba["treasure"], mode='RGBA'),
            name + "_discount_equivalent_middle": wandb.Image(summary_rgba["middle"]),
            name + "_discount_equivalent_begin": wandb.Image(summary_rgba["begin"]),
        })

    observation_image = wandb.Image(
        onp.array(trajectory.observations.reshape(-1, ctx.env.observation_shape[-1])).T
    )
    action_image = wandb.Image(onp.array(jax.nn.softmax(trajectory.logits, axis=1)).T)
    all_dict.update({"observation": observation_image, "action": action_image})
    return all_dict


def visualize_reward_prediction(rng, ctx, state: CallbackState):
    """Plot observed against predicted rewards for one trajectory.

    Only the first entry along the leading axis of ``state.trajectory`` is
    visualised.

    Args:
        rng: Unused; accepted so the function matches the callback signature.
        ctx: Experiment context; ``ctx.contribution.get_reward_prediction`` is
            queried for the predicted rewards.
        state: Callback state providing ``trajectory`` and ``contribution``.

    Returns:
        A dict with the single key ``"reward_vs_predicted_reward"`` mapped to
        a wandb line-series plot containing both curves.
    """
    # Index 0 along the leading axis of every leaf picks the first trajectory.
    first_trajectory = jtu.tree_map(lambda leaf: leaf[0], state.trajectory)
    true_rewards = first_trajectory.rewards

    prediction = ctx.contribution.get_reward_prediction(
        state.contribution, first_trajectory
    )
    # Remove singleton axes so the prediction can be plotted as a flat series.
    prediction = prediction.squeeze()

    steps = list(range(true_rewards.shape[0]))
    curves = [onp.array(true_rewards), onp.array(prediction)]
    chart = wandb.plot.line_series(
        xs=steps,
        ys=curves,
        keys=["trajectory_rewards", "predicted_rewards"],
        title="reward_vs_predicted_reward",
        xname="trajectory_step",
    )
    return {"reward_vs_predicted_reward": chart}


def performance_conveyor(rng, ctx, state: CallbackState):
    """Summarise distractor and treasure pick-ups for conveyor environments.

    For ``ConveyorTreasure`` / ``MultiConveyorTreasure`` environments the
    counts come from ``ctx.env.reward_info`` vmapped over ``state.env``.
    For any other environment they are reconstructed from the trajectory
    rewards and ``state.env.grid``.

    Args:
        rng: Unused; accepted so the function matches the callback signature.
        ctx: Experiment context holding the environment in ``ctx.env``.
        state: Callback state providing ``env`` and ``trajectory``.

    Returns:
        A dict with the mean fraction of distractors picked up, the mean
        treasure pick-up value, and the (un-averaged) total distractor count.
    """
    env = ctx.env
    treasure_conveyor_types = (
        envs.treasure_conveyor.ConveyorTreasure,
        envs.treasure_conveyor.MultiConveyorTreasure,
    )

    if isinstance(env, treasure_conveyor_types):
        reward_info = jax.vmap(env.reward_info)
        picked_up_distractors, picked_up_treasure, total_distractors = reward_info(state.env)
    else:
        # NOTE(review): rewards is indexed as [trajectory, step] here — confirm.
        rewards = state.trajectory.rewards
        grid = state.env.grid
        things = envs.conveyor.THINGS

        # Any non-zero reward before the last step counts as a picked-up distractor.
        picked_up_distractors = (rewards[:, :-1] != 0).sum(axis=1)

        # Apples still present in the grid were not picked up.
        apples_left = (grid == things.APPLE_LEFT.value).sum(axis=1)
        apples_right = (grid == things.APPLE_RIGHT.value).sum(axis=1)
        missed_distractors = apples_left + apples_right

        total_distractors = picked_up_distractors + missed_distractors
        # The last-step reward, normalised by the treasure reward.
        picked_up_treasure = rewards[:, -1] / env.reward_treasure

    distractor_fraction = picked_up_distractors / total_distractors
    return {
        "fraction_distractor_rewards": distractor_fraction.mean(),
        "percentage_treasure_rewards": picked_up_treasure.mean(),
        "total_distractors": total_distractors,
    }


def performance_multiple_conveyor(rng, ctx, state: CallbackState):
    """Summarise distractor and treasure pick-ups from trajectory rewards.

    Args:
        rng: Unused; accepted so the function matches the callback signature.
        ctx: Experiment context; ``ctx.env`` supplies ``reward_treasure``,
            ``reward_distractor``, ``length`` and ``num_distractor``.
        state: Callback state providing ``trajectory.rewards``.

    Returns:
        A dict with the mean fraction of distractors picked up and the mean
        treasure pick-up value.
    """
    env = ctx.env
    # NOTE(review): rewards is indexed as [trajectory, step] here — confirm.
    rewards = state.trajectory.rewards
    distractor_rewards = rewards[:, :-1]
    treasure_rewards = rewards[:, -1]

    # Dividing by the per-item reward turns summed rewards into item counts.
    picked_up_distractors = distractor_rewards.sum(axis=1) / env.reward_distractor
    picked_up_treasure = treasure_rewards / env.reward_treasure

    # Number of distractors available, computed from the environment settings.
    total_distractors = (env.length - 2) * env.num_distractor

    distractor_fraction = picked_up_distractors / total_distractors
    return {
        "fraction_distractor_rewards": distractor_fraction.mean(),
        "percentage_treasure_rewards": picked_up_treasure.mean(),
    }

def get_track_reward_feature(mdp, feature_module, identical_threshold=0.999):
    """Build a callback that logs cosine-similarity statistics of reward features.

    The returned callback evaluates ``feature_module`` on every MDP state for
    actions 0-3, groups the resulting feature vectors by the reward outcomes
    reachable from each (state, action), and reports how similar the features
    are within and across groups.

    Args:
        mdp: Object exposing ``num_state``, ``mdp_observation`` and
            ``mdp_reward_probs``.
        feature_module: Callable ``(rng, feature_state, obs, action)`` that
            returns a feature vector for one observation.
        identical_threshold: Cosine similarity above which two feature vectors
            are counted as identical. The same threshold is applied to the
            within-group and the cross-group statistics; an exact ``== 1``
            comparison is avoided because floating-point cosine similarities
            of equal vectors are not guaranteed to be exactly one.

    Returns:
        A callback ``track_reward_feature(rng, ctx, state)`` returning a dict
        of scalar metrics.
    """

    @jax.jit
    def get_features(rng, feature_state, action):
        """Return per-state features for ``action`` and a reward-group index per state."""
        get_features_action = jax.vmap(
            lambda rng, feature_state, obs, action: feature_module(rng, feature_state, obs, action),
            in_axes=(0, None, 0, None))
        features = get_features_action(
            jax.random.split(rng, mdp.num_state), feature_state, mdp.mdp_observation, action)

        # Encode, per state, which reward outcomes have non-zero probability
        # as a 4-bit mask (one bit per reward outcome).
        reward_ids = (mdp.mdp_reward_probs[:, action] > 0) @ jnp.array((1, 2, 4, 8))

        # Replace each bitmask by its rank among the sorted distinct bitmasks;
        # ``size=3`` keeps the output shape static under jit.
        _, unique_inv = jnp.unique(reward_ids, return_inverse=True, size=3)
        return features, unique_inv

    def pairwise_cossim(x, y):
        """Return the matrix of cosine similarities between rows of ``y`` and rows of ``x``."""
        cos_sim_fn = jax.vmap(
            jax.vmap(optax.cosine_similarity, in_axes=(0, None)), in_axes=(None, 0)
        )
        return cos_sim_fn(x, y)

    @jax.jit
    def average_cossim(x, y):
        """Return (mean similarity, fraction of near-identical pairs) over all (x, y) pairs."""
        cossim_matrix = pairwise_cossim(x, y)
        mean_cossim = jnp.mean(cossim_matrix)
        identical_fraction = jnp.mean(cossim_matrix > identical_threshold)
        return mean_cossim, identical_fraction

    @jax.jit
    def self_average_cossim(x):
        """Return (mean similarity, fraction of near-identical pairs) over distinct pairs of ``x``."""
        cossim_matrix = pairwise_cossim(x, x)
        # Strict lower triangle: every unordered pair once, self-pairs excluded.
        mask = jnp.tril(jnp.ones_like(cossim_matrix), k=-1)
        num_pairs = jnp.sum(mask)
        mean_cossim = jnp.sum(cossim_matrix * mask) / num_pairs
        # Threshold first, then mask, so excluded entries never count.
        identical_fraction = jnp.sum((cossim_matrix > identical_threshold) * mask) / num_pairs
        return mean_cossim, identical_fraction

    def track_reward_feature(rng, ctx, state: CallbackState):
        """Compute the feature-similarity metrics for the current feature state.

        ``ctx`` is unused; the same ``rng`` is used for every action.
        """
        feature_state = state.contribution.features

        (feat_0, group_0), (feat_1, group_1), (feat_2, group_2), (feat_3, group_3) = [
            get_features(rng, feature_state, action) for action in range(4)
        ]

        # NOTE(review): which group index corresponds to treasure / apple for
        # each action is taken from the original variable naming — verify
        # against the MDP's reward layout.
        treasures = jnp.concatenate(
            [
                feat_0[group_0 == 1],
                feat_1[group_1 == 2],
                feat_2[group_2 == 2],
                feat_3[group_3 == 1],
            ],
            axis=0,
        )
        apples_left = feat_1[group_1 == 1]
        apples_right = feat_2[group_2 == 1]
        apples_all = jnp.concatenate([apples_left, apples_right], axis=0)

        # Evaluate each statistic once and unpack both outputs.
        treasure_mean, treasure_identical = self_average_cossim(treasures)
        left_mean, left_identical = self_average_cossim(apples_left)
        right_mean, right_identical = self_average_cossim(apples_right)
        left_right_mean, left_right_identical = average_cossim(apples_left, apples_right)
        treasure_apple_mean, treasure_apple_identical = average_cossim(treasures, apples_all)

        return {
            "treasure_self_cossim": treasure_mean,
            "treasure_self_cossim_identical": treasure_identical,
            "apples_left_self_cossim": left_mean,
            "apples_left_self_cossim_identical": left_identical,
            "apples_right_self_cossim": right_mean,
            "apples_right_self_cossim_identical": right_identical,
            "apples_left_vs_right_cossim": left_right_mean,
            "apples_left_vs_right_cossim_identical": left_right_identical,
            "treasure_vs_apples_cossim": treasure_apple_mean,
            "treasure_vs_apples_cossim_identical": treasure_apple_identical,
        }

    return track_reward_feature