"""Evaluates one XPIDs together with all hardcoded agents.
"""
import argparse
import pickle
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
from flax.linen.initializers import constant, orthogonal
from typing import Sequence
import distrax
from flax import struct

from typing import Sequence

from src.envs import make_env


from src.agents.actors import ActorCriticRNN, ScannedRNN


from src.agents.overcooked.agent_policy_wrappers import OvercookedIndependentPolicyWrapper, OvercookedOnionPolicyWrapper, OvercookedPlatePolicyWrapper, OvercookedRandomPolicyWrapper, OvercookedStaticPolicyWrapper
from src.jaxzsc.dpd.dpd_ippo_overcooked_rnn import TrainConfig as TrainConfigDPD
from src.jaxzsc.e3t.e3t_ippo_overcooked_rnn import TrainConfig as TrainConfigE3T
from src.jaxzsc.sp.sp_ippo_overcooked_rnn import TrainConfig as TrainConfigSP

from src.envs.overcooked.augmented_layouts import augmented_layouts as overcooked_layouts

import functools


class RolloutStats(struct.PyTreeNode):
    """Running episode statistics carried through the rollout loop.

    `rollout` starts from a default instance and, on every environment step,
    replaces it with one whose `reward` has the step reward added and whose
    `length` is incremented by one.
    """
    # Accumulated reward so far; a fresh instance starts at 0.0.
    reward: jax.Array = jnp.asarray(0.0)
    # Number of environment steps taken so far; a fresh instance starts at 0.
    length: jax.Array = jnp.asarray(0)


def batchify(x: dict, agent_list, num_actors):
    """Stack per-agent entries of `x` into one `(num_actors, -1)` array.

    Args:
        x: Mapping from agent name to an array.
        agent_list: Agent names, in the order their rows should appear.
        num_actors: Size of the leading axis of the result.

    Returns:
        The entries of `x` for `agent_list`, stacked along a new leading axis
        and reshaped to `(num_actors, -1)`.
    """
    per_agent = [x[agent] for agent in agent_list]
    stacked = jnp.stack(per_agent)
    target_shape = (num_actors, -1)
    return stacked.reshape(target_shape)


def unbatchify(x: jnp.ndarray, agent_list, num_envs, num_actors):
    """Split a batched array back into a per-agent dictionary.

    Args:
        x: Array that can be reshaped to `(num_actors, num_envs, -1)`.
        agent_list: Agent names; the i-th name receives the i-th slice.
        num_envs: Size of the second axis after reshaping.
        num_actors: Size of the leading axis after reshaping.

    Returns:
        Dict mapping each agent name to its `(num_envs, -1)` slice.
    """
    per_actor = x.reshape((num_actors, num_envs, -1))
    by_agent = {}
    for index, agent in enumerate(agent_list):
        by_agent[agent] = per_actor[index]
    return by_agent


# class ScannedRNN(nn.Module):
#     @functools.partial(
#         nn.scan,
#         variable_broadcast="params",
#         in_axes=0,
#         out_axes=0,
#         split_rngs={"params": False},
#     )
#     @nn.compact
#     def __call__(self, carry, x):
#         """Applies the module."""
#         rnn_state = carry
#         ins, resets = x
#         rnn_state = jnp.where(
#             resets[:, np.newaxis],
#             self.initialize_carry(ins.shape[0], ins.shape[1]),
#             rnn_state,
#         )
#         new_rnn_state, y = nn.GRUCell(features=ins.shape[1])(rnn_state, ins)
#         return new_rnn_state, y

#     @staticmethod
#     def initialize_carry(batch_size, hidden_size):
#         # Use a dummy key since the default state init fn is just zeros.
#         cell = nn.GRUCell(features=hidden_size)
#         return cell.initialize_carry(
#             jax.random.PRNGKey(0), (batch_size, hidden_size)
#         )


# class ActorCriticRNN(nn.Module):
#     action_dim: Sequence[int]

#     gru_hidden_dim_size: int = 256
#     fc_dim_size: int = 256

#     embedding_layers: int = 1
#     actor_layers: int = 4
#     critic_layers: int = 4

#     other_agent_prediction: bool = False

#     use_layernorm: bool = False

#     @nn.compact
#     def __call__(self, hidden, x):
#         if self.other_agent_prediction:
#             obs, dones, past_5_sa_pairs = x
#         else:
#             obs, dones, past_5_sa_pairs = x

#         embedding = obs

#         embedding = nn.Dense(
#             self.fc_dim_size * 2, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
#         )(embedding)
#         if self.use_layernorm:
#             embedding = nn.LayerNorm()(embedding)
#         embedding = nn.relu(embedding)

#         for _ in range(self.embedding_layers):
#             embedding = nn.Dense(
#                 self.gru_hidden_dim_size, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
#             )(embedding)
#             if self.use_layernorm:
#                 embedding = nn.LayerNorm()(embedding)
#             embedding = nn.relu(embedding)

#         rnn_in = (embedding, dones)
#         hidden, embedding = ScannedRNN()(hidden, rnn_in)

#         #########
#         # Model of other agent
#         #########
#         if self.other_agent_prediction:
#             b_size, num_time = past_5_sa_pairs['action'].shape[0], past_5_sa_pairs['action'].shape[1]

#             other_graph_vals = nn.Dense(self.fc_dim_size, kernel_init=orthogonal(np.sqrt(
#                 2)), bias_init=constant(0.0))(past_5_sa_pairs['obs'].reshape((b_size, num_time, -1)))
#             other_graph_vals = nn.relu(other_graph_vals)

#             embeddings = nn.Embed(num_embeddings=self.action_dim, features=self.fc_dim_size)(
#                 past_5_sa_pairs['action'].astype(jnp.int32))
#             # remove extraneous dimension
#             embeddings = embeddings.reshape((b_size, num_time, -1))
#             # concatenate along feature dimension
#             other_actor_mean = jnp.concatenate(
#                 [other_graph_vals, embeddings], axis=-1)

#             ####
#             prediction_other = nn.Dense(64, kernel_init=orthogonal(
#                 np.sqrt(2)), bias_init=constant(0.0))(other_actor_mean)
#             if self.use_layernorm:
#                 prediction_other = nn.LayerNorm()(prediction_other)
#             prediction_other = nn.leaky_relu(prediction_other)

#             #####
#             prediction_other = nn.Dense(64, kernel_init=orthogonal(
#                 np.sqrt(2)), bias_init=constant(0.0))(prediction_other)
#             if self.use_layernorm:
#                 prediction_other = nn.LayerNorm()(prediction_other)
#             prediction_other = nn.leaky_relu(prediction_other)

#             #####
#             prediction_other = nn.Dense(64, kernel_init=orthogonal(
#                 np.sqrt(2)), bias_init=constant(0.0))(prediction_other)
#             if self.use_layernorm:
#                 prediction_other = nn.LayerNorm()(prediction_other)
#             prediction_other = nn.leaky_relu(prediction_other)

#             #####
#             prediction_other = nn.Dense(64, kernel_init=orthogonal(
#                 np.sqrt(2)), bias_init=constant(0.0))(prediction_other)
#             if self.use_layernorm:
#                 prediction_other = nn.LayerNorm()(prediction_other)
#             prediction_other = nn.tanh(prediction_other)

#             ####
#             prediction_other = nn.Dense(self.action_dim, kernel_init=orthogonal(
#                 np.sqrt(2)), bias_init=constant(0.0))(prediction_other)

#             prediction_other = prediction_other / \
#                 jnp.sqrt(jnp.sum(prediction_other**2, axis=-1,
#                                  keepdims=True) + 1e-10)  # L2 normalization

#             other_pi = distrax.Categorical(logits=prediction_other)
#             actor_embedding = jnp.concatenate(
#                 [embedding, jax.lax.stop_gradient(prediction_other)], axis=-1)
#         else:
#             other_pi = distrax.Categorical(
#                 logits=jnp.zeros((self.action_dim,)))
#             actor_embedding = embedding

#         #########
#         # Actor
#         #########
#         actor_mean = actor_embedding
#         for _ in range(self.actor_layers):
#             actor_mean = nn.Dense(self.fc_dim_size, kernel_init=orthogonal(2), bias_init=constant(0.0))(
#                 actor_mean
#             )
#             if self.use_layernorm:
#                 actor_mean = nn.LayerNorm()(actor_mean)
#             actor_mean = nn.relu(actor_mean)

#         actor_mean = nn.Dense(
#             self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
#         )(actor_mean)

#         pi = distrax.Categorical(logits=actor_mean)

#         #########
#         # Critic
#         #########
#         critic = embedding
#         for _ in range(self.critic_layers):
#             critic = nn.Dense(self.fc_dim_size, kernel_init=orthogonal(2), bias_init=constant(0.0))(
#                 critic
#             )
#             if self.use_layernorm:
#                 critic = nn.LayerNorm()(critic)
#             critic = nn.relu(critic)

#         critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
#             critic
#         )

#         return hidden, pi, jnp.squeeze(critic, axis=-1), other_pi


def rollout(
    rng,
    env,
    network,
    params,
    init_hstate,
    hardcoded_partner,
    agent_switch: bool = False
):
    """Run one episode pairing the learned policy with a hardcoded partner.

    The episode is advanced with `jax.lax.while_loop` until every entry of
    the per-agent done vector is True.

    Args:
        rng: PRNG key. It is split once: one half seeds `env.reset`, the
            other is threaded through the loop and split again every step.
        env: Environment providing `reset(key)`, `step(key, state, actions)`
            and an `agents` list.
        network: Module whose `apply(params, hstate, (obs, done, history))`
            returns `(hstate, pi, value, other_pi)`.
        params: Parameters forwarded to `network.apply`.
        init_hstate: Initial recurrent state for this episode; a leading
            axis is added before it enters the loop.
        hardcoded_partner: Scripted policy providing `init_hstate(...)` and
            `get_action(...)`.
        agent_switch: If False the learned policy controls `agent_1` and the
            partner controls `agent_0`; if True the seats are swapped.

    Returns:
        Tuple `(reward, length)`: the accumulated `agent_0` reward and the
        number of environment steps taken, each passed through `squeeze()`.
    """
    agent_ids = ("agent_0", "agent_1")
    # Seat assignment: "ego" is the learned policy, "partner" the scripted one.
    if agent_switch:
        ego_id, partner_id = "agent_0", "agent_1"
    else:
        ego_id, partner_id = "agent_1", "agent_0"

    history_len = 5
    # Action used to pre-fill the action history before any step was taken.
    # NOTE(review): presumably the environment's no-op action - confirm.
    initial_action = 4

    def _push(buffer, newest):
        """Drop the oldest entry along axis 1 and write `newest` last."""
        shifted = buffer.at[:, :-1].set(buffer[:, 1:])
        return shifted.at[:, -1].set(newest)

    def _cond_fn(carry):
        done = carry[-1]
        # Keep stepping while at least one agent is not done.
        return jnp.logical_not(done).any()

    def _body_fn(carry):
        (rng, env_state, stats, last_obs, hstate,
         other_agent_state, past_5_sa_pairs, done) = carry

        rng, rng_action, rng_step = jax.random.split(rng, 3)

        # The network receives inputs with two extra leading singleton axes
        # on obs/done and one extra leading axis on the history.
        in_past_sa_pairs = jax.tree.map(
            lambda x: x[np.newaxis], past_5_sa_pairs[ego_id])
        ac_in = (
            last_obs[ego_id].reshape(-1)[np.newaxis, np.newaxis, :],
            done[np.newaxis, np.newaxis, 0],
            in_past_sa_pairs,
        )
        # NOTE(review): env_state is presumably a wrapper state whose raw
        # environment state lives in `.env_state` - confirm against the env.
        action_other, other_agent_state = hardcoded_partner.get_action(
            params=None, obs=last_obs[partner_id], done=done[0],
            avail_actions=None, hstate=other_agent_state, rng=None,
            env_state=env_state.env_state)

        hstate, pi, _value, _other_pi = network.apply(params, hstate, ac_in)
        action = pi.sample(seed=rng_action).squeeze()

        # Always keyed in agent_0, agent_1 order regardless of the seats.
        env_act = {
            agent: (action if agent == ego_id else action_other)
            for agent in agent_ids
        }

        # Build a fresh history instead of mutating the incoming carry.
        past_5_sa_pairs = {
            agent: {
                'obs': _push(past_5_sa_pairs[agent]['obs'], last_obs[agent]),
                'action': _push(
                    past_5_sa_pairs[agent]['action'], env_act[agent]),
            }
            for agent in agent_ids
        }

        obsv, env_state, reward, done, info = env.step(
            rng_step, env_state, env_act
        )

        stats = stats.replace(
            reward=stats.reward + reward["agent_0"],
            length=stats.length + 1
        )
        done = batchify(done, env.agents, 2)

        return (rng, env_state, stats, obsv, hstate,
                other_agent_state, past_5_sa_pairs, done.squeeze())

    # Split once and never reuse the parent key: `key_r` seeds the reset and
    # `key` is the one carried through (and re-split inside) the loop.
    key, key_r = jax.random.split(rng)
    obs, state = env.reset(key_r)

    # History starts as the first observation repeated `history_len` times
    # together with a constant initial action.
    past_5_sa_pairs = {
        agent: {
            'obs': obs[agent][:, None].repeat(history_len, axis=1),
            'action': jnp.ones((1, history_len)) * initial_action,
        }
        for agent in agent_ids
    }

    other_agent_id = 1 if agent_switch else 0
    other_agent_state = hardcoded_partner.init_hstate(
        None, aux_info={"agent_id": other_agent_id})

    init_carry = (
        key, state, RolloutStats(), obs,
        init_hstate[np.newaxis, ...], other_agent_state, past_5_sa_pairs,
        jnp.array([False, False]))

    final_carry = jax.lax.while_loop(_cond_fn, _body_fn, init_val=init_carry)
    final_stats = final_carry[2]
    return final_stats.reward.squeeze(), final_stats.length.squeeze()


def load_config_and_params(xpid):
    """Load the training config and actor parameters of one experiment.

    Both files are read from `checkpoints/<xpid>/`. The config class is
    chosen by substring match on `xpid`, checked in the order "E3T", "DPD",
    "SP"; the first match wins.

    Args:
        xpid: Experiment id, used as the checkpoint sub-directory name.

    Returns:
        Tuple `(config, params)`, where `params` is the `"actor_params"`
        entry of the pickled parameter file.

    Raises:
        ValueError: If `xpid` contains none of the known algorithm tags.
    """
    checkpoint_dir = f"checkpoints/{xpid}"

    # NOTE: unpickling can execute arbitrary code; only load checkpoints
    # that come from a trusted source.
    with open(f"{checkpoint_dir}/config.pckl", "rb") as config_file:
        config_kwargs = pickle.load(config_file)

    def _config_class():
        if "E3T" in xpid:
            return TrainConfigE3T
        if "DPD" in xpid:
            return TrainConfigDPD
        if "SP" in xpid:
            return TrainConfigSP
        raise ValueError(f"Unknown config type for XPID: {xpid}")

    config = _config_class()(**config_kwargs)

    with open(f"{checkpoint_dir}/params.pt", "rb") as params_file:
        checkpoint = pickle.load(params_file)

    return config, checkpoint["actor_params"]


def main():
    """Evaluate every seed of one experiment against each hardcoded partner.

    Command-line arguments:
        --base_xpid: Experiment id containing "Overcooked_" and, normally, a
            "_SEED_<n>" suffix. The text before the last "_SEED_" is the
            checkpoint prefix shared by all seeds.
        --max_seed: Highest seed index; seeds 0..max_seed are evaluated.

    For each (partner, seed) pair the learned policy is rolled out in both
    seats over a fixed batch of PRNG keys, and the mean of the two seat
    rewards is printed and stored. A comma-separated result matrix is printed
    at the end.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_xpid", type=str, required=True)
    parser.add_argument("--max_seed", type=int, required=True)
    args = parser.parse_args()

    prefix = args.base_xpid.rsplit("_SEED_", 1)[0]
    xpid_parts = args.base_xpid.split("Overcooked_")
    if len(xpid_parts) < 2:
        # Fail with a readable CLI error instead of a bare IndexError.
        parser.error(
            f"--base_xpid must contain 'Overcooked_' (got {args.base_xpid!r})")
    task_name = xpid_parts[1].rsplit("_SEED_", 1)[0]

    num_seeds = args.max_seed + 1
    # Number of parallel episodes per (partner, seed, seat) evaluation.
    num_rollouts = 1024

    # The network architecture and layout are taken from seed 0's config.
    config, _ = load_config_and_params(f"{prefix}_SEED_0")

    env = make_env(
        "overcooked-v1", {"layout": config.layout_name})

    network = ActorCriticRNN(
        env.action_space("agent_0").n,
        gru_hidden_dim_size=config.gru_hidden_dim,
        fc_dim_size=config.fc_dim_size,
        embedding_layers=config.embedding_layers,
        actor_layers=config.actor_layers,
        critic_layers=config.critic_layers,
        other_agent_prediction=config.other_agent_prediction,
        use_layernorm=config.use_layernorm,
    )
    rng = jax.random.split(jax.random.PRNGKey(0), num_rollouts)

    # Preload the parameters of every seed.
    params_list = [
        load_config_and_params(f"{prefix}_SEED_{seed}")[1]
        for seed in range(num_seeds)
    ]
    # The initial recurrent state depends only on the batch and hidden
    # sizes, so one copy is shared by all seeds.
    init_hstate = ScannedRNN.initialize_carry(
        num_rollouts, config.gru_hidden_dim)

    layout = overcooked_layouts[config.layout_name]
    other_agents = [
        OvercookedIndependentPolicyWrapper(layout=layout),
        OvercookedOnionPolicyWrapper(layout=layout),
        OvercookedPlatePolicyWrapper(layout=layout),
        OvercookedRandomPolicyWrapper(layout=layout),
        OvercookedStaticPolicyWrapper(layout=layout),
    ]

    results = np.zeros((len(other_agents), num_seeds))

    # Only the PRNG key and the hidden state are batched; everything else
    # is shared across the parallel episodes.
    batched_rollout = jax.vmap(
        rollout, in_axes=(0, None, None, None, 0, None, None))

    for partner_idx, partner in enumerate(other_agents):
        print(partner)
        for seed, params in enumerate(params_list):
            # Evaluate the learned policy in both seats with the same keys.
            reward, _ = batched_rollout(
                rng, env, network, params, init_hstate, partner, False)
            reward2, _ = batched_rollout(
                rng, env, network, params, init_hstate, partner, True)
            avg_reward = ((reward + reward2)/2).mean()
            results[partner_idx, seed] = float(avg_reward)

            print(avg_reward)

    # Print result matrix
    print(f"\nResult Matrix for task '{task_name}' (comma-separated):")
    header = f"{task_name}," + ",".join([f"S{j}" for j in range(num_seeds)])
    print(header)
    for partner_idx, partner in enumerate(other_agents):
        row = f"{partner.name}," + \
            ",".join(
                [f"{results[partner_idx, j]:.4f}" for j in range(num_seeds)])
        print(row)

    if isinstance(config, TrainConfigDPD):
        print(f"{args.base_xpid},{config.learnability_function},{config.layout_name}")

# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()
