"""Evaluate two XPIDs against each other in cross-play. Both checkpoints must
share the same network architecture.
"""
import argparse
import pickle
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
from flax.linen.initializers import constant, orthogonal
from typing import Sequence
import distrax
from flax import struct

from typing import Sequence

from src.envs import make_env
from src.envs.log_wrapper import LogWrapper

from src.jaxzsc.dpd.dpd_ippo_overcooked_rnn import TrainConfig as TrainConfigDPD
from src.jaxzsc.e3t.e3t_ippo_overcooked_rnn import TrainConfig as TrainConfigE3T

from src.agents.actors import ActorCriticRNN, ScannedRNN


class RolloutStats(struct.PyTreeNode):
    """Running totals for a single episode, carried through the rollout loop."""

    # Accumulated reward over the episode so far.
    reward: jax.Array = jnp.array(0.0)
    # Number of environment steps taken so far.
    length: jax.Array = jnp.array(0)


def batchify(x: dict, agent_list, num_actors):
    """Stack ``x[a]`` for every ``a`` in ``agent_list`` and reshape to rows.

    The per-agent arrays are stacked along a new leading axis (in
    ``agent_list`` order) and the result is reshaped to
    ``(num_actors, -1)``.
    """
    per_agent = [x[agent] for agent in agent_list]
    stacked = jnp.stack(per_agent)
    return stacked.reshape(num_actors, -1)


def unbatchify(x: jnp.ndarray, agent_list, num_envs, num_actors):
    """Split a batched array back into a per-agent dict (counterpart of ``batchify``).

    ``x`` is reshaped to ``(num_actors, num_envs, -1)`` and row ``i`` of the
    result is assigned to ``agent_list[i]``.
    """
    rows = x.reshape((num_actors, num_envs, -1))
    per_agent = {}
    for index, agent in enumerate(agent_list):
        per_agent[agent] = rows[index]
    return per_agent


def rollout(rng, env, network, params, params2, hidden_size):
    """Run one cross-play episode and return its total reward and length.

    ``params`` drives ``env.agents[0]`` and ``params2`` drives
    ``env.agents[1]``. Both parameter sets are applied with the same
    ``network`` module, so the two checkpoints must share an architecture.
    Each parameter set keeps its own recurrent hidden state. The episode is
    stepped with ``jax.lax.while_loop`` until every agent reports done.

    Args:
        rng: PRNG key for this episode.
        env: Multi-agent environment exposing ``reset``, ``step``,
            ``agents`` and ``num_agents``.
        network: Flax module applied as
            ``network.apply(p, hstate, (obs, done, sa_pairs))`` and returning
            ``(hstate, pi, value, other_pi)``.
        params: Network parameters for the first agent.
        params2: Network parameters for the second agent.
        hidden_size: Recurrent hidden size passed to
            ``ScannedRNN.initialize_carry`` (presumably must be a static
            Python int, since it sets an array shape -- TODO confirm).

    Returns:
        ``(reward, length)``: the summed ``agent_0`` reward and the number of
        environment steps, both squeezed.
    """

    def _push(buffer, value):
        """Shift ``buffer`` left along axis 1 and write ``value`` into the last slot."""
        return buffer.at[:, :-1].set(buffer[:, 1:]).at[:, -1].set(value)

    def _policy_step(net_params, hstate, ac_in, key):
        """Apply ``network`` with ``net_params`` and sample one action per batch row."""
        hstate, pi, _value, _other_pi = network.apply(net_params, hstate, ac_in)
        return hstate, pi.sample(seed=key).squeeze()

    def _cond_fn(carry):
        done = carry[-1]
        # Keep stepping while at least one agent is still active.
        return jnp.logical_not(done).any()

    def _body_fn(carry):
        (rng, env_state, stats, last_obs, hstate_a, hstate_b,
         past_5_sa_pairs, done) = carry

        rng, rng_action_a, rng_action_b, rng_step = jax.random.split(rng, 4)

        obs_batch = batchify(last_obs, env.agents, 2)
        # NOTE(review): `batchify_nested_dics` is neither defined nor imported
        # in this file; it must be provided elsewhere for this to run.
        batched_sa_pairs = batchify_nested_dics(
            past_5_sa_pairs, env.agents, (1, 2, 5))
        ac_in = (
            obs_batch[np.newaxis, :],
            done[np.newaxis, :],
            batched_sa_pairs,
        )

        # Both checkpoints see the same two-agent batch and each keeps its
        # own hidden state. Only the row belonging to each checkpoint's agent
        # becomes that agent's action, so params drives agent 0 and params2
        # drives agent 1.
        hstate_a, action_a = _policy_step(params, hstate_a, ac_in, rng_action_a)
        hstate_b, action_b = _policy_step(params2, hstate_b, ac_in, rng_action_b)
        action = jnp.stack([action_a[0], action_b[1]])

        env_act = unbatchify(action, env.agents, 1, env.num_agents)
        env_act = {k: v.flatten().squeeze() for k, v in env_act.items()}

        # Slide each agent's 5-step (obs, action) window: drop the oldest
        # entry and append the current observation and chosen action.
        past_5_sa_pairs = {
            agent: {
                'obs': _push(past_5_sa_pairs[agent]['obs'], last_obs[agent]),
                'action': _push(past_5_sa_pairs[agent]['action'], env_act[agent]),
            }
            for agent in ('agent_0', 'agent_1')
        }

        obsv, env_state, reward, done, _info = env.step(
            rng_step, env_state, env_act
        )

        stats = stats.replace(
            reward=stats.reward + reward["agent_0"],
            length=stats.length + 1,
        )
        done = batchify(done, env.agents, 2).squeeze()
        return (rng, env_state, stats, obsv, hstate_a, hstate_b,
                past_5_sa_pairs, done)

    # Use a fresh key for the reset and carry the other half forward, so the
    # key consumed by the split is not reused inside the loop.
    rng, key_reset = jax.random.split(rng)
    obs, state = env.reset(key_reset)

    # Seed each agent's 5-step history with its first observation repeated
    # and the placeholder action value 4 in every slot.
    past_5_sa_pairs = {
        agent: {
            'obs': obs[agent][:, None].repeat(5, axis=1),
            'action': jnp.ones((1, 5)) * 4,
        }
        for agent in ('agent_0', 'agent_1')
    }

    # Hidden state with a batch of 2 (hardcoded, presumably one row per
    # agent). It is shared as the starting point for both parameter sets.
    init_hstate = ScannedRNN.initialize_carry(2, hidden_size)
    init_carry = (rng, state, RolloutStats(), obs, init_hstate, init_hstate,
                  past_5_sa_pairs, jnp.array([False, False]))

    final_carry = jax.lax.while_loop(_cond_fn, _body_fn, init_val=init_carry)
    final_stats = final_carry[2]
    return final_stats.reward.squeeze(), final_stats.length.squeeze()


def main():
    """Cross-play two checkpoints and print their mean episode reward.

    The training config is loaded from ``--xpid1``. Its type (E3T or DPD) is
    chosen by a substring of that XPID. Actor parameters are loaded from both
    XPIDs. ``rollout`` is then vmapped over ``--num_episodes`` PRNG keys twice:
    once with xpid1's parameters passed first and once with xpid2's passed
    first. The printed value is the mean over episodes of the two runs'
    averaged rewards.

    Raises:
        ValueError: if the config type cannot be inferred from ``--xpid1``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--xpid1',
        type=str,
        required=True,
        help='First XPID')
    parser.add_argument(
        '--xpid2',
        type=str,
        required=True,
        help='Second XPID')
    parser.add_argument(
        '--num_episodes',
        type=int,
        default=1024,
        help='Number of episodes per parameter ordering')
    parser.add_argument(
        '--seed',
        type=int,
        default=0,
        help='PRNG seed for the evaluation episodes')
    args = parser.parse_args()

    save_dir1 = f"checkpoints/{args.xpid1}"
    save_dir2 = f"checkpoints/{args.xpid2}"

    # NOTE: pickle can execute arbitrary code; only load trusted checkpoints.
    with open(f"{save_dir1}/config.pckl", 'rb') as f:
        loaded_dict = pickle.load(f)

    if "E3T" in args.xpid1:
        config = TrainConfigE3T(**loaded_dict)
    elif "DPD" in args.xpid1:
        config = TrainConfigDPD(**loaded_dict)
    else:
        raise ValueError(
            f"Cannot infer the training config type from --xpid1={args.xpid1!r}; "
            "expected it to contain 'E3T' or 'DPD'.")

    def _load_actor_params(save_dir):
        """Load the pickled ``actor_params`` entry from ``save_dir/params.pt``."""
        with open(f"{save_dir}/params.pt", 'rb') as f:
            return pickle.load(f)["actor_params"]

    params1 = _load_actor_params(save_dir1)
    params2 = _load_actor_params(save_dir2)

    env = make_env(
        "overcooked-v1", {"layout": config.layout_name})

    network = ActorCriticRNN(env.action_space("agent_0").n, config=config)
    hidden_size = config.gru_hidden_dim

    rngs = jax.random.split(jax.random.PRNGKey(args.seed), args.num_episodes)

    def _evaluate(params_a, params_b):
        """Vectorized rollouts over ``rngs`` with ``params_a`` passed first."""
        # Close over the non-array arguments so only the PRNG keys are
        # vmapped. rollout takes exactly six arguments, and hidden_size is
        # not traced.
        return jax.vmap(
            lambda key: rollout(key, env, network, params_a, params_b, hidden_size)
        )(rngs)

    reward_1, _episode_len_1 = _evaluate(params1, params2)
    reward_2, _episode_len_2 = _evaluate(params2, params1)

    print(((reward_1 + reward_2) / 2).mean())


if __name__ == '__main__':
    # Run the evaluation only when executed as a script, not on import.
    main()
