"""
Based on PureJaxRL Implementation of PPO.
"""
import os
from pathlib import Path
import pickle
import jax
import jax.numpy as jnp
import wandb

import numpy as np
import optax
import pyrallis
from typing import NamedTuple
from flax.training.train_state import TrainState
import distrax
from flax import struct

from typing import NamedTuple
from dataclasses import asdict, dataclass

from src.envs import make_env
from src.envs.log_wrapper import LogWrapper

from src.agents.actors import ScannedRNN, ActorCriticRNN, ActorWithConditionalCritic


@dataclass
class TrainConfig:
    """Hyper-parameters for E3T-style IPPO training on Overcooked.

    ``__post_init__`` adds the derived fields ``num_actors``, ``num_updates``,
    ``num_reward_shaping_updates`` and ``minibatch_size``. For the five known
    layouts it also overrides ``e3t_epsilon`` with a per-layout value.
    """

    # --- Wandb and other logging ---
    project: str = "JaxZSC"
    mode: str = "disabled"  # one of "online", "offline", "disabled"
    entity: str = ""
    checkpoint_path: str = "checkpoints"
    checkpoint_freq: int = 25  # checkpoint every N updates

    # --- Overcooked ---
    env_name: str = "overcooked"
    # one of "cramped_room", "asymm_advantages", "coord_ring", "forced_coord", "counter_circuit"
    layout_name: str = "cramped_room"
    random_reset: bool = False
    # environment steps over which the shaped-reward weight decays to zero
    rew_shaping_horizon: float = 3e7

    # --- Actor-Critic ---
    activation: str = "tanh"
    fc_dim_size: int = 256
    gru_hidden_dim: int = 256

    embedding_layers: int = 2
    actor_layers: int = 4
    critic_layers: int = 4

    use_layernorm: bool = True

    # auxiliary loss predicting the partner's action, weighted by moa_coef
    other_agent_prediction: bool = True
    moa_coef: float = 0.5

    # --- Training ---
    seed: int = 42
    lr: float = 2.5e-4  # alternative: 1e-3
    anneal_lr: bool = True

    num_envs: int = 512
    num_steps: int = 400

    total_timesteps: float = 5e7
    update_epochs: int = 6
    num_minibatches: int = 8
    gamma: float = 0.99
    gae_lambda: float = 0.95
    clip_eps: float = 0.2
    ent_coef: float = 0.01
    vf_coef: float = 0.5
    max_grad_norm: float = 0.5

    # --- E3T ---
    # weight of the uniform-random component of the mixed policy;
    # replaced by a per-layout value in __post_init__ for known layouts
    e3t_epsilon: float = 0.55

    eval_against_pop: bool = True

    def __post_init__(self):
        # Two agents per environment, each counted as one PPO actor.
        self.num_actors = 2 * self.num_envs
        self.num_updates = int(
            self.total_timesteps // self.num_steps // self.num_envs)
        self.num_reward_shaping_updates = int(
            self.rew_shaping_horizon // self.num_steps // self.num_envs)
        self.minibatch_size = (
            self.num_actors * self.num_steps // self.num_minibatches)

        # Per-layout epsilon, see Figure 12 of the E3T paper:
        # https://papers.nips.cc/paper_files/paper/2023/file/07a363fd2263091c2063998e0034999c-Paper-Conference.pdf
        # For coord_ring / counter_circuit the paper uses several values (e.g. 0.3).
        paper_epsilon = {
            "cramped_room": 0.5,
            "asymm_advantages": 0.5,
            "coord_ring": 0.5,
            "counter_circuit": 0.5,
            "forced_coord": 0.0,
        }
        if self.layout_name in paper_epsilon:
            self.e3t_epsilon = paper_epsilon[self.layout_name]

        print("Number of updates: ", self.num_updates)


def rollout(rng, env, network, params, hidden_size):
    """Play one self-play episode with the learner controlling both agents.

    Args:
        rng: PRNG key for this episode (the caller vmaps over a batch of keys).
        env: a single (non-vectorised) environment instance.
        network: recurrent actor-critic; ``network.apply`` returns
            ``(hstate, pi, value, other_pi)``.
        params: parameters for ``network``.
        hidden_size: size of the RNN carry passed to ``ScannedRNN.initialize_carry``.

    Returns:
        ``(episode_return, episode_length)``. The return is the sum of
        ``reward["agent_0"]`` over the episode; both values are squeezed.
    """

    def _push(buffer, new_entry):
        # Drop the oldest entry along axis 1 and write ``new_entry`` into the last slot.
        buffer = buffer.at[:, :-1].set(buffer[:, 1:])
        return buffer.at[:, -1].set(new_entry)

    def _cond_fn(carry):
        done = carry[-1]
        # ``done`` is a bool array (fixed by the initial carry); keep going
        # while any agent is not done.
        return jnp.logical_not(done).any()

    def _body_fn(carry):
        rng, env_state, stats, last_obs, hstate, past_5_sa_pairs, done = carry
        rng, rng_action, rng_step = jax.random.split(rng, 3)

        obs_batch = batchify(last_obs, env.agents, 2)
        batched_sa_pairs = batchify_nested_dics(
            past_5_sa_pairs, env.agents, (1, 2, 5))
        # Leading singleton time axis for the recurrent network.
        ac_in = (
            obs_batch[np.newaxis, :],
            done[np.newaxis, :],
            batched_sa_pairs,
        )

        hstate, pi, _, _ = network.apply(params, hstate, ac_in)
        action = pi.sample(seed=rng_action).squeeze()

        env_act = unbatchify(action, env.agents, 1, env.num_agents)
        env_act = {k: v.flatten().squeeze() for k, v in env_act.items()}

        # Slide each agent's 5-step (obs, action) history forward by one step.
        for agent in ("agent_0", "agent_1"):
            past_5_sa_pairs[agent]["obs"] = _push(
                past_5_sa_pairs[agent]["obs"], last_obs[agent])
            past_5_sa_pairs[agent]["action"] = _push(
                past_5_sa_pairs[agent]["action"], env_act[agent])

        obsv, env_state, reward, done, _ = env.step(
            rng_step, env_state, env_act
        )

        stats = stats.replace(
            reward=stats.reward + reward["agent_0"],
            length=stats.length + 1
        )
        done = batchify(done, env.agents, 2)
        return (rng, env_state, stats, obsv, hstate,
                past_5_sa_pairs, done.squeeze())

    key, key_r = jax.random.split(rng)
    obs, state = env.reset(key_r)

    # Seed each agent's history with the first observation repeated 5 times
    # and action 4 (presumably the "stay" action -- TODO confirm).
    # NOTE(review): ``obs[agent][:, None].repeat(5, axis=1)`` has layout
    # (1, 5, D) only if each agent's obs carries a leading axis of size 1;
    # confirm it matches the (num_envs, 5, D) layout used during training.
    past_5_sa_pairs = {
        agent: {
            "obs": obs[agent][:, None].repeat(5, axis=1),
            "action": jnp.ones((1, 5)) * 4,
        }
        for agent in ("agent_0", "agent_1")
    }

    init_hstate = ScannedRNN.initialize_carry(
        2, hidden_size)  # one carry per agent
    # Use ``key``, not ``rng``: ``rng`` was already consumed by the split above,
    # and reusing it would correlate the reset with the in-episode randomness.
    init_carry = (key, state, RolloutStats(), obs,
                  init_hstate, past_5_sa_pairs, jnp.array([False, False]))

    final_carry = jax.lax.while_loop(_cond_fn, _body_fn, init_val=init_carry)
    stats = final_carry[2]
    return stats.reward.squeeze(), stats.length.squeeze()

def rollout_single_l(rng, env, network, params, other_network, other_params, hidden_size, popsize):
    """Play one episode: learner as ``agent_0``, population partner as ``agent_1``.

    Args:
        rng: PRNG key for this episode.
        env: a single (non-vectorised) environment instance.
        network, params: the learner's recurrent actor-critic; only its policy is used.
        other_network, other_params: the partner policy. It is called with
            ``(obs, zeros(popsize))`` and returns ``(pi, _)``. The meaning of the
            second input is not visible here.
        hidden_size: size of the learner's RNN carry.
        popsize: length of the zero vector passed to the partner network.

    Returns:
        ``(episode_return, episode_length)``, where the return sums
        ``reward["agent_0"]``; both values are squeezed.
    """

    def _push(buffer, new_entry):
        # Drop the oldest entry along axis 1 and write ``new_entry`` into the last slot.
        buffer = buffer.at[:, :-1].set(buffer[:, 1:])
        return buffer.at[:, -1].set(new_entry)

    def _cond_fn(carry):
        done = carry[-1]
        # ``done`` is a bool array; keep going while any agent is not done.
        return jnp.logical_not(done).any()

    def _body_fn(carry):
        rng, env_state, stats, last_obs, hstate, past_5_sa_pairs, done = carry
        rng, rng_action, rng_o_action, rng_step = jax.random.split(rng, 4)

        # Learner (agent_0): add singleton time/batch axes for the RNN.
        in_past_sa_pairs = jax.tree.map(
            lambda x: x[np.newaxis], past_5_sa_pairs["agent_0"])
        ac_in = (
            last_obs["agent_0"][np.newaxis, np.newaxis, :],
            done[np.newaxis, ...][:, 0:1],
            in_past_sa_pairs,
        )
        hstate, pi, _, _ = network.apply(params, hstate, ac_in)
        ego_action = pi.sample(seed=rng_action).squeeze()

        # Partner (agent_1).
        partner_pi, _ = other_network.apply(
            other_params,
            (
                last_obs["agent_1"][np.newaxis, :],
                jnp.zeros(popsize)[np.newaxis, :]
            )
        )
        partner_action = partner_pi.sample(seed=rng_o_action).squeeze()

        env_act = {
            "agent_0": ego_action,
            "agent_1": partner_action,
        }

        # Only the learner keeps an (obs, action) history.
        past_5_sa_pairs["agent_0"]["obs"] = _push(
            past_5_sa_pairs["agent_0"]["obs"], last_obs["agent_0"])
        past_5_sa_pairs["agent_0"]["action"] = _push(
            past_5_sa_pairs["agent_0"]["action"], env_act["agent_0"])

        obsv, env_state, reward, done, _ = env.step(
            rng_step, env_state, env_act
        )

        stats = stats.replace(
            reward=stats.reward + reward["agent_0"],
            length=stats.length + 1
        )
        done = batchify(done, env.agents, 2)
        return (rng, env_state, stats, obsv, hstate,
                past_5_sa_pairs, done.squeeze())

    key, key_r = jax.random.split(rng)
    obs, state = env.reset(key_r)

    # Seed the learner's history with the first observation repeated 5 times
    # and action 4 (presumably the "stay" action -- TODO confirm).
    past_5_sa_pairs = {
        "agent_0": {
            "obs": obs["agent_0"][:, None].repeat(5, axis=1),
            "action": jnp.ones((1, 5)) * 4,
        },
    }

    init_hstate = ScannedRNN.initialize_carry(
        1, hidden_size)  # single learner carry
    # Use ``key``, not the already-split ``rng``, to avoid PRNG key reuse.
    init_carry = (key, state, RolloutStats(), obs,
                  init_hstate, past_5_sa_pairs, jnp.array([False, False]))

    final_carry = jax.lax.while_loop(_cond_fn, _body_fn, init_val=init_carry)
    stats = final_carry[2]
    return stats.reward.squeeze(), stats.length.squeeze()


def rollout_single_r(rng, env, network, params, other_network, other_params, hidden_size, popsize):
    """Play one episode: population partner as ``agent_0``, learner as ``agent_1``.

    Mirror image of ``rollout_single_l`` with the seats swapped.

    Args:
        rng: PRNG key for this episode.
        env: a single (non-vectorised) environment instance.
        network, params: the learner's recurrent actor-critic; only its policy is used.
        other_network, other_params: the partner policy. It is called with
            ``(obs, zeros(popsize))`` and returns ``(pi, _)``.
        hidden_size: size of the learner's RNN carry.
        popsize: length of the zero vector passed to the partner network.

    Returns:
        ``(episode_return, episode_length)``. The return sums
        ``reward["agent_0"]`` even though the learner is agent_1.
        NOTE(review): presumably identical for both agents under a shared
        team reward -- confirm.
    """

    def _push(buffer, new_entry):
        # Drop the oldest entry along axis 1 and write ``new_entry`` into the last slot.
        buffer = buffer.at[:, :-1].set(buffer[:, 1:])
        return buffer.at[:, -1].set(new_entry)

    def _cond_fn(carry):
        done = carry[-1]
        # ``done`` is a bool array; keep going while any agent is not done.
        return jnp.logical_not(done).any()

    def _body_fn(carry):
        rng, env_state, stats, last_obs, hstate, past_5_sa_pairs, done = carry
        rng, rng_action, rng_o_action, rng_step = jax.random.split(rng, 4)

        # Learner (agent_1): add singleton time/batch axes for the RNN.
        in_past_sa_pairs = jax.tree.map(
            lambda x: x[np.newaxis], past_5_sa_pairs["agent_1"])
        ac_in = (
            last_obs["agent_1"][np.newaxis, np.newaxis, :],
            done[np.newaxis, ...][:, 0:1],
            in_past_sa_pairs,
        )
        hstate, pi, _, _ = network.apply(params, hstate, ac_in)
        ego_action = pi.sample(seed=rng_action).squeeze()

        # Partner (agent_0).
        partner_pi, _ = other_network.apply(
            other_params,
            (
                last_obs["agent_0"][np.newaxis, :],
                jnp.zeros(popsize)[np.newaxis, :]
            )
        )
        partner_action = partner_pi.sample(seed=rng_o_action).squeeze()

        env_act = {
            "agent_0": partner_action,
            "agent_1": ego_action,
        }

        # Only the learner keeps an (obs, action) history.
        past_5_sa_pairs["agent_1"]["obs"] = _push(
            past_5_sa_pairs["agent_1"]["obs"], last_obs["agent_1"])
        past_5_sa_pairs["agent_1"]["action"] = _push(
            past_5_sa_pairs["agent_1"]["action"], env_act["agent_1"])

        obsv, env_state, reward, done, _ = env.step(
            rng_step, env_state, env_act
        )

        stats = stats.replace(
            reward=stats.reward + reward["agent_0"],
            length=stats.length + 1
        )
        done = batchify(done, env.agents, 2)
        return (rng, env_state, stats, obsv, hstate,
                past_5_sa_pairs, done.squeeze())

    key, key_r = jax.random.split(rng)
    obs, state = env.reset(key_r)

    # Seed the learner's history with the first observation repeated 5 times
    # and action 4 (presumably the "stay" action -- TODO confirm).
    past_5_sa_pairs = {
        "agent_1": {
            "obs": obs["agent_1"][:, None].repeat(5, axis=1),
            "action": jnp.ones((1, 5)) * 4,
        },
    }

    init_hstate = ScannedRNN.initialize_carry(
        1, hidden_size)  # single learner carry
    # Use ``key``, not the already-split ``rng``, to avoid PRNG key reuse.
    init_carry = (key, state, RolloutStats(), obs,
                  init_hstate, past_5_sa_pairs, jnp.array([False, False]))

    final_carry = jax.lax.while_loop(_cond_fn, _body_fn, init_val=init_carry)
    stats = final_carry[2]
    return stats.reward.squeeze(), stats.length.squeeze()




def rollout_both_ways(eval_rng, env, network, params, partner_pop_actor, partner_pop_params, gru_hidden_dim, popsize):
    """Mean return of the learner with one partner, playing from both seats.

    One episode is played per key in ``eval_rng``, first with the learner as
    agent_0 (``rollout_single_l``) and then as agent_1 (``rollout_single_r``).
    The mean return over all of those episodes is returned.
    """
    # Only the PRNG keys are batched; every other argument is broadcast.
    key_axes = (0,) + (None,) * 7
    shared_args = (env, network, params, partner_pop_actor,
                   partner_pop_params, gru_hidden_dim, popsize)

    left_seat_returns, _ = jax.vmap(rollout_single_l, in_axes=key_axes)(
        eval_rng, *shared_args)
    right_seat_returns, _ = jax.vmap(rollout_single_r, in_axes=key_axes)(
        eval_rng, *shared_args)

    both_seats = jnp.array([left_seat_returns, right_seat_returns])
    return both_seats.mean()


class Transition(NamedTuple):
    """One collected environment step, laid out per PPO actor.

    Field order matters: instances may be built positionally.
    """

    global_done: jax.Array  # episode-level done flag, tiled over the agents
    done: jax.Array  # done flag observed before this step was taken
    action: jax.Array  # action sampled by the actor
    value: jax.Array  # critic estimate for the pre-step observation
    reward: jax.Array  # (possibly shaped) reward received for this step
    log_prob: jax.Array  # log-probability of ``action`` under the acting policy
    obs: jax.Array  # observation the action was chosen from
    info: jax.Array  # per-actor environment info pytree
    entropy: jax.Array  # entropy of the acting policy
    hstate: jax.Array  # recurrent state returned by the network this step
    other_action: jax.Array  # action taken by the partner in the same env
    past_sa_pairs: jax.Array  # recent (obs, action) history pytree


class RolloutStats(struct.PyTreeNode):
    """Running totals for one evaluation episode.

    Carried through ``jax.lax.while_loop`` by the rollout helpers, so it must
    be a pytree; update it with ``.replace(...)``.
    """

    # Sum of rewards collected so far (the rollout helpers add reward["agent_0"]).
    reward: jax.Array = jnp.asarray(0.0)
    # Number of environment steps taken so far.
    length: jax.Array = jnp.asarray(0)


def batchify_nested_dics(x: dict, agent_list, shape):
    """Stack per-agent nested dicts leaf-wise and reshape each leaf.

    ``x[agent]`` must share one pytree structure for every agent. Leaves are
    stacked in ``agent_list`` order along a new leading axis, then reshaped to
    ``(*shape, -1)``.
    """
    per_agent_trees = [x[agent] for agent in agent_list]
    stacked = jax.tree.map(lambda *leaves: jnp.stack(leaves), *per_agent_trees)
    target_shape = tuple(shape) + (-1,)
    return jax.tree.map(lambda leaf: leaf.reshape(target_shape), stacked)


def batchify(x: dict, agent_list, num_actors):
    """Stack ``x[agent]`` in ``agent_list`` order and flatten to ``(num_actors, -1)``."""
    stacked = jnp.stack([x[agent] for agent in agent_list])
    return jnp.reshape(stacked, (num_actors, -1))


def unbatchify(x: jnp.ndarray, agent_list, num_envs, num_actors):
    """Split a flat actor batch back into a per-agent dict.

    ``x`` is reshaped to ``(num_actors, num_envs, -1)``; entry ``i`` of the
    leading axis goes to ``agent_list[i]``.
    """
    grouped = jnp.reshape(x, (num_actors, num_envs, -1))
    per_agent = {}
    for index, agent in enumerate(agent_list):
        per_agent[agent] = grouped[index]
    return per_agent


def soft_blend_logits(logits, epsilon):
    """Blend a categorical policy with the uniform policy; return log-probabilities.

    The result is ``log((1 - epsilon) * softmax(logits) + epsilon / n + 1e-8)``,
    where ``n`` is the size of the last axis. The ``1e-8`` keeps the log finite.
    """
    n_actions = logits.shape[-1]
    policy = jax.nn.softmax(logits)
    uniform = jnp.ones_like(policy) / n_actions
    blended = (1 - epsilon) * policy + epsilon * uniform
    return jnp.log(blended + 1e-8)


def make_update_fn(config, env, network, partner_pop_actor, partner_pop_params, pop_size):
    """Build the per-update training step for E3T-style IPPO.

    The returned ``_update_step(update_runner_state, unused)`` is meant for
    ``jax.lax.scan``. ``update_runner_state`` is ``(runner_state, update_steps)``
    with ``runner_state = (train_state, env_state, obs, done_batch, hstate,
    past_5_sa_pairs, rng)``.

    One call of ``_update_step``:
      1. collects ``config.num_steps`` steps from ``config.num_envs`` vectorised
         envs. Both agents share ``train_state``; in each env one randomly chosen
         agent samples from the epsilon-mixed E3T policy.
      2. computes GAE advantages and value targets.
      3. runs ``config.update_epochs`` x ``config.num_minibatches`` PPO updates,
         adding the partner-action prediction loss when
         ``config.other_agent_prediction`` is set.
      4. evaluates in self-play and, if ``config.eval_against_pop``, against each
         member of ``partner_pop_params``.
      5. logs metrics to wandb via ``jax.debug.callback``.

    Returns:
        ``_update_step``, which returns ``((runner_state, update_steps + 1), metric)``.
    """

    def _reward_shaping_frac(update_step):
        """Shaped-reward weight: decays linearly from 1 to 0 over
        ``config.num_reward_shaping_updates`` updates."""
        if config.num_reward_shaping_updates <= 0:
            # Shaping horizon shorter than one update. Without this guard the
            # first update computes 0 / 0 = NaN, which poisons every reward.
            return jnp.asarray(0.0)
        return jnp.maximum(
            0.0, 1.0 - (update_step / config.num_reward_shaping_updates))

    def _push(buffer, new_entry):
        """Shift a history buffer left along axis 1 and write ``new_entry`` last."""
        buffer = buffer.at[:, :-1].set(buffer[:, 1:])
        return buffer.at[:, -1].set(new_entry)

    # TRAIN LOOP
    def _update_step(update_runner_state, unused):
        runner_state, update_steps = update_runner_state

        # COLLECT TRAJECTORIES
        def _env_step(runner_state, unused):
            (train_state, env_state, last_obs, last_done, hstate, rng,
             update_step, epsilon_agent, past_5_sa_pairs) = runner_state
            rng, _rng = jax.random.split(rng)
            obs_batch = batchify(last_obs, env.agents, config.num_actors)

            def get_e3t_action(args):
                # Mix the ego policy with the uniform policy (weight e3t_epsilon).
                pi_ego, k = args
                pi_random = distrax.Categorical(logits=jnp.zeros_like(
                    pi_ego.logits))  # uniform random policy
                pi_e3t_probs = (
                    1 - config.e3t_epsilon) * pi_ego.probs + config.e3t_epsilon * pi_random.probs
                pi_e3t = distrax.Categorical(probs=pi_e3t_probs)
                sampled_a = pi_e3t.sample(seed=k)
                log_prob_a = pi_e3t.log_prob(sampled_a)
                entropy_a = pi_e3t.entropy()
                return sampled_a, log_prob_a, entropy_a

            def get_base_action(args):
                pi_ego, k = args
                sampled_a = pi_ego.sample(seed=k)
                log_prob_a = pi_ego.log_prob(sampled_a)
                entropy_a = pi_ego.entropy()
                return sampled_a, log_prob_a, entropy_a

            batched_sa_pairs = batchify_nested_dics(
                past_5_sa_pairs, env.agents, (1, config.num_actors, 5))
            ac_in = (
                obs_batch[np.newaxis, :],
                last_done[np.newaxis, :],
                batched_sa_pairs,
            )
            hstate, pi, value, _ = train_state.apply_fn(
                train_state.params, hstate, ac_in
            )

            # Both candidates share a key; only one is kept per actor below.
            e3t_action, e3t_log_prob, e3t_entropy = get_e3t_action((pi, _rng))
            base_action, base_log_prob, base_entropy = get_base_action(
                (pi, _rng))

            # Actor order is [agent_0 of every env, agent_1 of every env], so
            # exactly one agent per env uses the epsilon-mixed policy.
            epsilon_agent_both = jnp.concatenate(
                [epsilon_agent, ~epsilon_agent], axis=0)
            action = jnp.where(epsilon_agent_both,
                               e3t_action, base_action).squeeze()
            log_prob = jnp.where(epsilon_agent_both,
                                 e3t_log_prob, base_log_prob).squeeze()
            entropy = jnp.where(epsilon_agent_both,
                                e3t_entropy, base_entropy).squeeze()

            env_act = unbatchify(action, env.agents,
                                 config.num_envs, env.num_agents)
            env_act = {k: v.flatten() for k, v in env_act.items()}

            # The history the policy actually conditioned on this step, as
            # (num_actors, 5, ...). It is stored in the transition so the loss
            # re-evaluates the network on the acting-time input. Storing the
            # history *after* the push below would leak the current action
            # into the input used to score that same action.
            acted_sa_pairs = jax.tree.map(lambda x: x[0], batched_sa_pairs)

            # Push this step's observation and action into each agent's history.
            for agent in ("agent_0", "agent_1"):
                past_5_sa_pairs[agent]["obs"] = _push(
                    past_5_sa_pairs[agent]["obs"], last_obs[agent])
                past_5_sa_pairs[agent]["action"] = _push(
                    past_5_sa_pairs[agent]["action"], env_act[agent])

            # STEP ENV
            rng, _rng = jax.random.split(rng)
            rng_step = jax.random.split(_rng, config.num_envs)
            obsv, env_state, reward, done, info = jax.vmap(
                env.step, in_axes=(0, 0, 0)
            )(rng_step, env_state, env_act)

            # Add the annealed shaped reward, then drop it from the logged info.
            shaped_reward = unbatchify(
                info["shaped_reward"].transpose(1, 0), env.agents,
                config.num_envs, env.num_agents)
            shaped_reward = {k: v.squeeze() for k, v in shaped_reward.items()}
            shaping_frac = _reward_shaping_frac(update_step)
            reward = jax.tree.map(
                lambda r, s: r + s * shaping_frac, reward, shaped_reward)
            del info["shaped_reward"]

            info = jax.tree.map(lambda x: x.reshape(
                (config.num_actors)), info)
            done_batch = batchify(
                done, env.agents, config.num_actors).squeeze()

            # Partner action per actor: agent_0's partner is agent_1 and vice versa.
            other_action = jnp.concatenate([
                env_act["agent_1"], env_act["agent_0"]
            ], axis=-1)
            transition = Transition(
                global_done=jnp.tile(done["__all__"], env.num_agents),
                done=last_done,  # pre-step flag: the RNN reset signal
                action=action.squeeze(),
                value=value.squeeze(),
                reward=batchify(reward, env.agents,
                                config.num_actors).squeeze(),
                log_prob=log_prob.squeeze(),
                obs=obs_batch,
                info=info,
                entropy=entropy,
                hstate=hstate,
                other_action=other_action.squeeze(),
                past_sa_pairs=acted_sa_pairs,
            )
            runner_state = (train_state, env_state, obsv, done_batch, hstate,
                            rng, update_step, epsilon_agent, past_5_sa_pairs)
            return runner_state, transition

        (train_state, env_state, obsv, done_batch,
         hstate, past_5_sa_pairs, rng) = runner_state
        # RNN state at the start of the rollout; the loss replays from it.
        initial_hstate = hstate

        # Sample, per env, which agent acts with the epsilon-mixed policy.
        # Split first so this key is not reused downstream.
        rng, _rng = jax.random.split(rng)
        epsilon_agent = jax.random.bernoulli(_rng, shape=(config.num_envs,))

        runner_state = (train_state, env_state, obsv, done_batch, hstate,
                        rng, update_steps, epsilon_agent, past_5_sa_pairs)
        runner_state, traj_batch = jax.lax.scan(
            _env_step, runner_state, None, config.num_steps
        )

        # CALCULATE ADVANTAGE
        (train_state, env_state, last_obs, last_done, hstate, rng,
         update_steps, epsilon_agent, past_5_sa_pairs) = runner_state
        last_obs_batch = batchify(last_obs, env.agents, config.num_actors)
        batched_sa_pairs = batchify_nested_dics(
            past_5_sa_pairs, env.agents, (1, config.num_actors, 5))
        _, _, last_val, _ = train_state.apply_fn(
            train_state.params, hstate,
            (last_obs_batch[np.newaxis, ...], last_done[np.newaxis, ...], batched_sa_pairs))
        last_val = last_val.squeeze()

        def _calculate_gae(traj_batch, last_val, last_done):
            def _get_advantages(carry, transition):
                gae, next_value, next_done = carry
                done, value, reward = (
                    transition.done,
                    transition.value,
                    transition.reward,
                )
                # ``transition.done`` is the flag observed *before* this step.
                # Whether the episode ended *after* it is the next step's
                # ``done`` (``last_done`` for the final step), so bootstrapping
                # is masked with ``next_done``.
                delta = reward + config.gamma * \
                    next_value * (1 - next_done) - value
                gae = (
                    delta
                    + config.gamma * config.gae_lambda * (1 - next_done) * gae
                )
                return (gae, value, done), gae

            _, advantages = jax.lax.scan(
                _get_advantages,
                (jnp.zeros_like(last_val), last_val, last_done),
                traj_batch,
                reverse=True,
                unroll=16,
            )
            return advantages, advantages + traj_batch.value

        advantages, targets = _calculate_gae(traj_batch, last_val, last_done)

        # UPDATE NETWORK
        def _update_epoch(update_state, unused):
            def _update_minbatch(train_state, batch_info):
                init_hstate, traj_batch, advantages, targets = batch_info

                def _loss_fn(params, traj_batch, gae, targets):
                    # RERUN NETWORK on the minibatch from its initial RNN state.
                    _, pi, value, other_pi = network.apply(
                        params,
                        jax.tree.map(lambda h: h.squeeze(), init_hstate),
                        (traj_batch.obs, traj_batch.done, traj_batch.past_sa_pairs),
                    )
                    log_prob = pi.log_prob(traj_batch.action)

                    # CALCULATE VALUE LOSS (clipped)
                    value_pred_clipped = traj_batch.value + (
                        value - traj_batch.value
                    ).clip(-config.clip_eps, config.clip_eps)
                    value_losses = jnp.square(value - targets)
                    value_losses_clipped = jnp.square(
                        value_pred_clipped - targets)
                    value_loss = (
                        0.5 * jnp.maximum(value_losses,
                                          value_losses_clipped).mean()
                    )

                    # CALCULATE ACTOR LOSS (clipped surrogate)
                    logratio = log_prob - traj_batch.log_prob
                    ratio = jnp.exp(logratio)
                    gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                    loss_actor1 = ratio * gae
                    loss_actor2 = (
                        jnp.clip(
                            ratio,
                            1.0 - config.clip_eps,
                            1.0 + config.clip_eps,
                        )
                        * gae
                    )
                    loss_actor = (-jnp.minimum(loss_actor1, loss_actor2)).mean()
                    entropy = pi.entropy().mean()

                    # PARTNER-ACTION PREDICTION (NLL). The flag is a static
                    # Python bool, so branch in Python: ``lax.cond`` traces
                    # both branches, including ``other_pi.log_prob`` even when
                    # prediction is disabled.
                    if config.other_agent_prediction:
                        moa_loss = (-other_pi.log_prob(
                            traj_batch.other_action)).mean()
                    else:
                        moa_loss = jnp.array(0.0)

                    approx_kl = ((ratio - 1) - logratio).mean()
                    clip_frac = jnp.mean(
                        jnp.abs(ratio - 1) > config.clip_eps)

                    total_loss = (
                        loss_actor
                        + config.moa_coef * moa_loss
                        + config.vf_coef * value_loss
                        - config.ent_coef * entropy
                    )
                    return total_loss, (value_loss, loss_actor, entropy, ratio, approx_kl, clip_frac)

                grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                total_loss, grads = grad_fn(
                    train_state.params, traj_batch, advantages, targets
                )
                train_state = train_state.apply_gradients(grads=grads)
                return train_state, total_loss

            (
                train_state,
                init_hstate,
                traj_batch,
                advantages,
                targets,
                rng,
            ) = update_state
            rng, _rng = jax.random.split(rng)
            batch_size = config.minibatch_size * config.num_minibatches
            assert (
                batch_size == config.num_steps * config.num_actors
            ), "batch size must be equal to number of steps * number of actors per population member"
            init_hstate = jax.tree.map(lambda h: jnp.reshape(
                h, (1, config.num_actors, -1)), init_hstate)
            batch = (
                init_hstate,
                traj_batch,
                advantages.squeeze(),
                targets.squeeze(),
            )
            # Shuffle whole actor trajectories (axis 1) so the RNN sees
            # contiguous sequences.
            permutation = jax.random.permutation(_rng, config.num_actors)
            shuffled_batch = jax.tree_util.tree_map(
                lambda x: jnp.take(x, permutation, axis=1), batch
            )

            minibatches = jax.tree_util.tree_map(
                lambda x: jnp.swapaxes(
                    jnp.reshape(
                        x,
                        [x.shape[0], config.num_minibatches, -1]
                        + list(x.shape[2:]),
                    ),
                    1,
                    0,
                ),
                shuffled_batch,
            )
            train_state, total_loss = jax.lax.scan(
                _update_minbatch, train_state, minibatches
            )
            update_state = (
                train_state,
                jax.tree.map(lambda h: h.squeeze(), init_hstate),
                traj_batch,
                advantages,
                targets,
                rng,
            )
            return update_state, total_loss

        # Give the PPO epochs their own key so it is not shared with the evals.
        rng, _rng = jax.random.split(rng)
        update_state = (
            train_state,
            initial_hstate,
            traj_batch,
            advantages,
            targets,
            _rng,
        )
        update_state, loss_info = jax.lax.scan(
            _update_epoch, update_state, None, config.update_epochs
        )
        train_state = update_state[0]

        # EVALUATE: 100 self-play episodes.
        rng, _rng = jax.random.split(rng)
        eval_rng = jax.random.split(_rng, 100)
        eval_returns, _ = jax.vmap(rollout, in_axes=(0, None, None, None, None))(
            eval_rng, env, network, train_state.params, config.gru_hidden_dim)

        if config.eval_against_pop:
            rng, _rng = jax.random.split(rng)
            eval_rng = jax.random.split(_rng, 100)
            # vmap over partners; each partner plays 100 games from each seat.
            eval_pop_returns = jax.vmap(
                rollout_both_ways,
                in_axes=(None, None, None, None, None, 0, None, None)
            )(
                eval_rng, env, network, train_state.params, partner_pop_actor,
                partner_pop_params, config.gru_hidden_dim, pop_size)
        else:
            eval_pop_returns = jnp.asarray(0.0)

        # Env info from the last collected step, averaged over envs and agents.
        metric = jax.tree.map(
            lambda x: x.reshape(
                (config.num_steps, config.num_envs, env.num_agents)
            )[-1, ...].mean(),
            traj_batch.info,
        )

        # Ratio on the very first minibatch of the first epoch (should be ~1).
        ratio_0 = loss_info[1][3].at[0, 0].get().mean()
        loss_info = jax.tree.map(lambda x: x.mean(), loss_info)
        metric["loss"] = {
            "total_loss": loss_info[0],
            "value_loss": loss_info[1][0],
            "actor_loss": loss_info[1][1],
            "entropy": loss_info[1][2],
            "ratio": loss_info[1][3],
            "ratio_0": ratio_0,
            "approx_kl": loss_info[1][4],
            "clip_frac": loss_info[1][5],
        }
        metric["eval"] = {
            "eval_sp_return": eval_returns.mean(),
            "eval_pop_returns": eval_pop_returns.mean(),
        }
        metric["reward_shaping_factor"] = _reward_shaping_frac(update_steps)

        def callback(metric):
            wandb.log(metric)

        update_steps = update_steps + 1
        metric = jax.tree.map(lambda x: x.mean(), metric)
        metric["update_step"] = update_steps
        metric["env_step"] = update_steps * config.num_steps * config.num_envs
        jax.debug.callback(callback, metric)

        runner_state = (train_state, env_state, last_obs,
                        last_done, hstate, past_5_sa_pairs, rng)  # hstate resets automatically
        return (runner_state, update_steps), metric
    return _update_step


def get_run_string(config: TrainConfig):
    """Return the wandb group name for this run, keyed by Overcooked layout."""
    prefix = "FF_RNN_E3T_IPPO_Overcooked_"
    return prefix + f"{config.layout_name}"


def _build_network(config: TrainConfig, env) -> ActorCriticRNN:
    """Return the ``ActorCriticRNN`` described by ``config`` for ``env``'s action space.

    Shared by the training setup and the checkpoint evaluation so both always
    construct the same architecture.
    """
    return ActorCriticRNN(
        env.action_space("agent_0").n,
        gru_hidden_dim_size=config.gru_hidden_dim,
        fc_dim_size=config.fc_dim_size,
        embedding_layers=config.embedding_layers,
        actor_layers=config.actor_layers,
        critic_layers=config.critic_layers,
        other_agent_prediction=config.other_agent_prediction,
        use_layernorm=config.use_layernorm,
    )


def _load_eval_population(config: TrainConfig, env):
    """Load the frozen partner population used for cross-play evaluation.

    Reads every file whose name contains ``"param"`` from
    ``eval_populations/FF_BRDiv/<layout_name>``. It takes the ``"actor_params"``
    entry of each file and stacks those trees leaf-wise along a new leading
    axis. The population size is read from that directory's ``config.pckl``.

    Returns:
        ``(partner_pop_actor, partner_pop_params, pop_size)``.

    Raises:
        FileNotFoundError: if the directory contains no parameter files.
    """
    pop_dir = Path("eval_populations/FF_BRDiv") / config.layout_name

    # Sort so the stacking order (i.e. each partner's population index) is
    # reproducible instead of depending on filesystem iteration order.
    param_files = sorted(f for f in pop_dir.iterdir() if "param" in f.name)
    if not param_files:
        raise FileNotFoundError(
            f"No partner parameter files found in {pop_dir}")

    # NOTE(review): pickle can execute arbitrary code on load; only point this
    # at trusted checkpoint directories.
    partner_params_list = []
    for param_file in param_files:
        with open(param_file, "rb") as f:
            partner_params_list.append(pickle.load(f)["actor_params"])

    with open(pop_dir / "config.pckl", "rb") as f:
        other_config = pickle.load(f)

    partner_pop_params = jax.tree.map(
        lambda *x: jnp.stack(x), *partner_params_list)
    partner_pop_actor = ActorWithConditionalCritic(
        env.action_space(env.agents[0]).n)

    # NOTE(review): pop_size comes from the saved config, not from the number
    # of files found above; presumably these agree - confirm.
    pop_size = other_config["partner_pop_size"]
    return partner_pop_actor, partner_pop_params, pop_size


def _save_checkpoint(save_dir: str, step: int, params, total_reward) -> None:
    """Pickle ``{"actor_params": params}`` to a step-tagged file and to ``params.pt``.

    The tagged file is ``params_<step>_<total_reward>.pt``; ``params.pt`` is
    overwritten on every call so it always holds the latest parameters.
    """
    os.makedirs(save_dir, exist_ok=True)
    payload = {"actor_params": params}
    tagged_path = os.path.join(save_dir, f"params_{step}_{total_reward}.pt")
    latest_path = os.path.join(save_dir, "params.pt")
    for out_path in (tagged_path, latest_path):
        with open(out_path, "wb") as f:
            pickle.dump(payload, f)
    print(
        f"Saved params for agent with total reward {tagged_path}", total_reward)


@pyrallis.wrap()
def train(config: TrainConfig):
    """Train the E3T IPPO recurrent actor-critic on Overcooked.

    Setup steps:

    * Initialise a wandb run.
    * Pickle the config into the checkpoint directory.
    * Build the env and network.
    * Optionally load a frozen partner population for evaluation.

    It then runs ``config.num_updates`` jitted update steps. Every
    ``checkpoint_freq`` updates (and after the last update), it evaluates the
    current parameters with ``rollout`` and pickles them.

    Returns:
        ``{"runner_state": ..., "metrics": ...}``. ``metrics`` comes from the
        final update and is ``None`` if ``num_updates == 0``.
    """
    ##### WANDB and other setup #####
    tags = ["FF", "RNN", "E3T", "IPPO", config.layout_name]

    group_string = get_run_string(config)
    run_string = f"{group_string}_SEED_{config.seed}"

    run = wandb.init(
        project=config.project,
        group=group_string,
        mode=config.mode,
        config=asdict(config),
        save_code=True,
        tags=tags,
    )

    # Inside a sweep, prefix with the sweep id instead of the generated name.
    name_prefix = run.sweep_id if run.sweep_id is not None else run.name
    run.name = name_prefix + "___" + run_string

    print("XPID ID name:")
    print(run.name)
    print("-------------")

    #### Setup and check saving before training ####
    save_dir = None
    if config.checkpoint_path is not None:
        save_dir = os.path.join(config.checkpoint_path, run.name)
        # Make sure we can write the checkpoint later _before_ we wait 1 day for training!
        os.makedirs(save_dir, exist_ok=True)
        with open(os.path.join(save_dir, "config.pckl"), "wb") as f:
            pickle.dump(asdict(config), f)

    env = make_env(
        "overcooked-v1", {
            "layout": config.layout_name,
            "random_reset": config.random_reset,
        }
    )
    env = LogWrapper(env, replace_info=False)

    def linear_schedule(count):
        # ``count`` counts optimizer steps; each update performs
        # num_minibatches * update_epochs of them, so the LR decays linearly
        # with the number of completed updates.
        frac = 1.0 - (count // (config.num_minibatches *
                                config.update_epochs)) / config.num_updates
        return config.lr * frac

    rng = jax.random.PRNGKey(config.seed)

    # INIT NETWORK
    network = _build_network(config, env)

    obs_dim = jnp.zeros(env.observation_space("agent_0").shape).flatten().shape[0]

    # Dummy 5-step state/action history, used only for parameter initialisation.
    dummy_sa_pairs = {
        agent: {
            "obs": jnp.zeros((1, config.num_envs, 5, obs_dim)),
            "action": jnp.zeros((1, config.num_envs, 5, 1)),
        }
        for agent in ("agent_0", "agent_1")
    }
    batched_sa_pairs = batchify_nested_dics(
        dummy_sa_pairs, env.agents, (1, config.num_actors, 5))

    init_x = (
        jnp.zeros((1, config.num_actors, obs_dim)),
        jnp.zeros((1, config.num_actors)),
        batched_sa_pairs,
    )
    init_hstate = ScannedRNN.initialize_carry(
        config.num_actors, config.gru_hidden_dim)

    rng, _rng = jax.random.split(rng)
    network_params = network.init(_rng, init_hstate, init_x)

    learning_rate = linear_schedule if config.anneal_lr else config.lr
    tx = optax.chain(
        optax.clip_by_global_norm(config.max_grad_norm),
        optax.adam(learning_rate=learning_rate, eps=1e-5),
    )

    train_state = TrainState.create(
        apply_fn=network.apply,
        params=network_params,
        tx=tx,
    )

    if config.eval_against_pop:
        partner_pop_actor, partner_pop_params, pop_size = _load_eval_population(
            config, env)
    else:
        partner_pop_actor, partner_pop_params, pop_size = None, None, 0

    # INIT UPDATE FUNCTION
    jitted_update_step = jax.jit(make_update_fn(
        config, env, network, partner_pop_actor, partner_pop_params, pop_size))

    # INIT ENV
    rng, _rng = jax.random.split(rng)
    reset_rng = jax.random.split(_rng, config.num_envs)

    obsv, env_state = jax.vmap(env.reset, in_axes=(0,))(reset_rng)

    # Seed each agent's 5-step history with its first observation repeated.
    # Each step is paired with action index 4.
    # NOTE(review): action 4 is presumably the no-op/stay action - confirm.
    init_past_5_sa_pairs = {
        agent: {
            "obs": obsv[agent][:, None, :].repeat(5, axis=1),
            "action": jnp.ones((config.num_envs, 5)) * 4,
        }
        for agent in ("agent_0", "agent_1")
    }

    runner_state = (
        train_state,
        env_state,
        obsv,
        jnp.zeros((config.num_actors,), dtype=bool),
        init_hstate,
        init_past_5_sa_pairs,
        rng,
    )
    update_steps = 0

    if save_dir is not None:
        # Evaluation uses an unwrapped env without random resets. It is built
        # once here instead of at every checkpoint.
        eval_env = make_env(
            "overcooked-v1", {"layout": config.layout_name, "random_reset": False})
        eval_network = _build_network(config, eval_env)

    metric = None  # stays None when num_updates == 0
    for i in range(config.num_updates):
        (runner_state, update_steps), metric = jitted_update_step(
            (runner_state, update_steps), None)

        # Remarkably, saving is among the most expensive operations
        if save_dir is None:
            continue
        periodic = i != 0 and i % config.checkpoint_freq == 0
        if periodic or i == config.num_updates - 1:
            params = runner_state[0].params
            # NOTE(review): ``rng`` is not advanced in this loop, so every
            # checkpoint evaluation uses the same key; presumably intended
            # for comparability - confirm.
            total_r, _total_l = rollout(
                rng, eval_env, eval_network, params, config.gru_hidden_dim)
            _save_checkpoint(save_dir, i, params, total_r)

    return {"runner_state": runner_state, "metrics": metric}


if __name__ == "__main__":
    # ``train`` takes a TrainConfig but is called with no arguments here;
    # the ``pyrallis.wrap`` decorator supplies the config.
    train()
