from __future__ import annotations
import os
import copy
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial
from typing import NamedTuple, Dict, Union, Any, Tuple, Optional, List
from dataclasses import field
import chex
import optax
import flax.linen as nn
from flax.linen.initializers import constant, orthogonal
from flax.training.train_state import TrainState
from flax import traverse_util
from gymnax.wrappers.purerl import LogWrapper
import hydra
from omegaconf import OmegaConf
import gymnax
import flashbax as fbx
import wandb

from jaxmarl import make
from jaxmarl.environments.smax import map_name_to_scenario
from jaxmarl.environments.overcooked import overcooked_layouts
from jaxmarl.wrappers.baselines_v2 import (
    SMAXLogWrapper,
    MPELogWrapper,
    LogWrapper,
    CTRolloutManager,
)

from safetensors import safe_open

from functools import partial
import jax
import jax.numpy as jnp
import flax.linen as nn

# Observation Preprocessor
class ObservationPreprocessor:
    """Holds the total agent count used when preprocessing observations.

    The preprocessing logic itself is elided in this file: the jitted method
    below is a placeholder that returns ``None``.
    """

    def __init__(self, num_all_agents):
        # total number of agents the preprocessing is meant to account for
        self.num_all_agents = num_all_agents

    @partial(jax.jit, static_argnums=(0,))
    def preprocess_observation(self, observation_vector):
        """Placeholder for padding/concatenating observation features.

        ``self`` is a static (hashable) argument to ``jax.jit``. The original
        implementation is omitted; this currently returns ``None``.
        """
        return None

# Observation Encoder: Generates logits for the adjacency matrix
class ObservationEncoder(nn.Module):
    """Flax module intended to encode observation vectors into adjacency-matrix logits.

    NOTE(review): ``__call__`` is a placeholder (``pass``) in this file, so
    applying the module currently returns ``None``.
    """
    num_nodes: int  # Number of nodes in the graph
    num_layers: int = 2  # Number of layers in the network
    hidden_dim: int = 10  # Hidden units in each layer

    @nn.compact
    def __call__(self, observations):
        """Encode observation vectors.

        Placeholder: the implementation is elided here and nothing is returned.
        """
        pass

# Gumbel-Softmax for generating an adjacency matrix
class GumbelSoftmaxAdjMatrixModel(nn.Module):
    """Flax module intended to turn logits into a soft adjacency matrix via Gumbel-Softmax.

    NOTE(review): ``__call__`` is a placeholder (``pass``) in this file.
    """
    temperature: float = 0.5  # Gumbel-Softmax temperature (lower -> closer to one-hot)

    @nn.compact
    def __call__(self, logits):
        """Generate a soft adjacency matrix using Gumbel-Softmax.

        Placeholder: the implementation is elided here and nothing is returned.
        """
        pass

# GCN Layer
class GCNLayer(nn.Module):
    """Single graph-convolution layer mapping node features to ``c_out`` features.

    NOTE(review): ``__call__`` is a placeholder (``pass``) in this file.
    """
    c_out: int  # Output feature size

    @nn.compact
    def __call__(self, node_feats, adj_matrix):
        """Graph Convolutional Network (GCN) layer.

        Args:
            node_feats: node feature array (shape not fixed by this file — TODO confirm).
            adj_matrix: adjacency matrix used to aggregate neighbour features.

        Placeholder: the implementation is elided here and nothing is returned.
        """
        pass

# Graph Network Readout Layer
class GraphReadout(nn.Module):
    """Readout that pools node features into a graph-level embedding.

    NOTE(review): ``__call__`` is a placeholder (``pass``) in this file.
    """
    @nn.compact
    def __call__(self, node_feats):
        """Perform readout operation, e.g., mean pooling.

        Placeholder: the implementation is elided here and nothing is returned.
        """
        pass

# End-to-end GCN model
class End2EndGCN(nn.Module):
    """End-to-end graph model: observations -> learned adjacency -> graph embedding.

    NOTE(review): both ``setup`` and ``__call__`` are placeholders (``pass``)
    in this file. The module defines ``setup`` *and* an ``@nn.compact`` method;
    confirm the real implementation declares submodules in only one of them.
    """
    c_out: int  # output feature size of the graph embedding
    num_all_agents: int  # number of agents (graph nodes) — presumably; TODO confirm
    temperature: float = 0.5  # Gumbel-Softmax temperature

    def setup(self):
        """Setup components, including observation encoder and Gumbel-Softmax model."""
        pass

    @nn.compact
    def __call__(self, observations):
        """Process observations and produce graph embeddings.

        Placeholder: the implementation is elided here and nothing is returned.
        """
        pass

# Scanned RNN Module
class ScannedRNN(nn.Module):
    """Recurrent cell applied over a sequence with ``nn.scan``.

    NOTE(review): ``__call__`` and ``initialize_carry`` are placeholders
    (``pass``) in this file.
    """
    # nn.scan config: parameters are shared (broadcast) across scan steps,
    # inputs/outputs are scanned along axis 0, and the "params" RNG is not
    # split per step.
    @partial(
        nn.scan,
        variable_broadcast="params",
        in_axes=0,
        out_axes=0,
        split_rngs={"params": False},
    )
    @nn.compact
    def __call__(self, carry, x):
        """Apply the RNN in a scanned manner.

        Args:
            carry: recurrent hidden state.
            x: per-step input (callers in this file appear to pass an
                ``(obs, dones)`` pair — TODO confirm).
        """
        pass

    @staticmethod
    def initialize_carry(hidden_size, *batch_size):
        """Initialize hidden state.

        Callers in this file pass ``(hidden_size, num_agents, batch)``, i.e.
        the variadic ``batch_size`` holds the leading dimensions of the carry.
        """
        pass

# Agent RNN Model
class AgentRNN(nn.Module):
    """Recurrent Q-network for a single agent.

    NOTE(review): ``__call__`` is a placeholder (``pass``) in this file.
    """
    action_dim: int  # number of discrete actions (size of the Q-value head)
    hidden_dim: int  # RNN hidden size
    init_scale: float  # weight-initialisation scale — presumably for orthogonal init; TODO confirm

    @nn.compact
    def __call__(self, hidden, obs, dones):
        """Agent RNN processes observations, hidden states, and actions.

        Placeholder: the implementation is elided here and nothing is returned.
        """
        pass

# Single Agent Model
class AgentModel(nn.Module):
    """Single-agent model combining a graph component and an RNN.

    NOTE(review): ``setup`` and ``__call__`` are placeholders (``pass``) in
    this file.
    """
    action_dim: int  # number of discrete actions
    hidden_dim: int  # RNN hidden size
    num_agents: int  # number of controlled agents
    num_all_agents: int  # number of agents overall (see CombinedModel usage)
    c_out: int = 16  # graph-embedding feature size
    temperature: float = 1.0  # Gumbel-Softmax temperature

    def setup(self):
        """Setup the agent model, including GNN and RNN components."""
        pass

    def __call__(self, hidden, obs, dones, agent_id, train_pre):
        """Call the agent model.

        ``train_pre`` is a phase flag mirroring the one toggled in
        ``make_train`` — its effect is not visible in this file.
        """
        pass

# Multi-agent combined model
class CombinedModel(nn.Module):
    """Multi-agent network wrapping pre-policy and standard agent models.

    ``make_train`` uses it as ``apply(params, hidden, obs, dones[, train_pre])``
    and unpacks the result as ``(new_hidden, q_vals)``. There, ``obs`` is built
    as (num_agents, time, batch, obs_size) and ``q_vals`` is indexed as
    (num_agents, time, batch, num_actions) — TODO confirm against the real
    implementation. Parameters under a ``pre_policy_agent`` path are optimised
    separately by ``make_train``.

    NOTE(review): ``setup`` and ``__call__`` are placeholders (``pass``) in
    this file.
    """
    action_dim: int
    hidden_dim: int
    num_agents: int   # Default number of agents
    num_all_agents: int
    c_out: int = 16
    num_pre_policy_agents: int = 3
    temperature: float = 1.0

    def setup(self):
        """Setup multiple agent models, including pre-policy and standard agent models."""
        pass

    def __call__(self, hidden, obs, dones, train_pre=False):
        """Forward pass for multiple agents."""
        pass

@chex.dataclass(frozen=True)
class Timestep:
    """One transition record stored in the replay buffers.

    Every field is a dict. Per-agent entries are read via ``batchify`` in
    ``make_train``; ``rewards`` and ``dones`` also carry an ``"__all__"`` key
    used for the team (VDN) target.
    """
    obs: dict  # observation *before* the step, keyed by agent
    actions: dict  # actions taken, keyed by agent
    rewards: dict  # rewards (possibly scaled by REW_SCALE), incl. "__all__"
    dones: dict  # done flags after the step, incl. "__all__"
    avail_actions: dict  # action-availability masks, keyed by agent


class CustomTrainState(TrainState):
    """Flax ``TrainState`` extended with a target network and training counters."""
    target_network_params: Any  # Polyak-averaged copy of params (optax.incremental_update)
    timesteps: int = 0  # env steps collected (NUM_STEPS * NUM_ENVS per update)
    n_updates: int = 0  # number of outer update steps performed
    grad_steps: int = 0  # number of gradient steps applied


def make_train(config, env):
    """Build the training function for the agent / pre-policy VDN learner.

    Args:
        config: Flat config dict (alg keys already merged in). Side effect:
            ``config["NUM_UPDATES"]`` is computed and written back into it.
        env: Pair ``(env_agent, env_pre)``. ``env_agent`` backs the standard
            rollouts and the greedy evaluation; ``env_pre`` backs the
            pre-policy rollouts. Both are wrapped in ``CTRolloutManager``.

    Returns:
        ``train(rng)``, a pure-JAX function (suitable for ``jax.jit`` /
        ``jax.vmap`` over seeds) returning
        ``{"runner_state": ..., "metrics": ...}``.
    """
    env_agent, env_pre = env
    config["NUM_UPDATES"] = (
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    )

    eps_scheduler = optax.linear_schedule(
        init_value=config["EPS_START"],
        end_value=config["EPS_FINISH"],
        transition_steps=config["EPS_DECAY"] * config["NUM_UPDATES"],
    )

    def get_greedy_actions(q_vals, valid_actions):
        """Return argmax actions over the last axis, masking invalid ones."""
        # subtract a large constant so unavailable actions never win the argmax
        unavail_actions = 1 - valid_actions
        q_vals = q_vals - (unavail_actions * 1e10)
        return jnp.argmax(q_vals, axis=-1)

    def eps_greedy_exploration(rng, q_vals, eps, valid_actions):
        """Epsilon-greedy selection restricted to valid actions.

        Each row along the leading axis of ``valid_actions`` independently
        takes a uniformly random valid action with probability ``eps`` and
        the greedy action otherwise.
        """
        # one key for sampling random actions, one for the explore/exploit draw
        rng_a, rng_e = jax.random.split(rng)

        greedy_actions = get_greedy_actions(q_vals, valid_actions)

        def get_random_actions(rng, val_action):
            """Sample uniformly among the valid actions of a single row."""
            return jax.random.choice(
                rng,
                jnp.arange(val_action.shape[-1]),
                p=val_action * 1.0 / jnp.sum(val_action, axis=-1),
            )

        _rngs = jax.random.split(rng_a, valid_actions.shape[0])
        random_actions = jax.vmap(get_random_actions)(_rngs, valid_actions)

        explore = jax.random.uniform(rng_e, greedy_actions.shape) < eps
        return jnp.where(explore, random_actions, greedy_actions)

    def batchify(x: dict):
        """Stack per-agent entries of ``x`` (in ``env_agent.agents`` order) on axis 0."""
        return jnp.stack([x[agent] for agent in env_agent.agents], axis=0)

    def unbatchify(x: jnp.ndarray):
        """Inverse of ``batchify``: split axis 0 back into a per-agent dict."""
        return {agent: x[i] for i, agent in enumerate(env_agent.agents)}

    def train(rng):

        # INIT ENV
        original_seed = rng[0]  # used to tag per-seed wandb metrics
        rng, _rng = jax.random.split(rng)

        wrapped_agent_env = CTRolloutManager(env_agent, batch_size=config["NUM_ENVS"])
        wrapped_pre_env = CTRolloutManager(env_pre, batch_size=config["NUM_ENVS"])

        wrapped_env = CTRolloutManager(env_agent, batch_size=config["NUM_ENVS"])
        test_env = CTRolloutManager(
            env_agent, batch_size=config["TEST_NUM_ENVS"]
        )  # batched env for testing (has different batch size)

        network = CombinedModel(
            action_dim=wrapped_env.max_action_space,
            hidden_dim=config["HIDDEN_SIZE"],
            c_out=config["C_OUT"],
            num_agents=len(wrapped_env.agents),
            num_all_agents=len(wrapped_env.all_agents),
            temperature=config['TEMPERATURE'],
        )

        # Training starts in the standard-agent phase; `_update_step` flips
        # this flag every SWITCH_INTERVAL updates.
        train_pre = False

        def create_agent(rng):
            """Initialise the network and a TrainState with one optimizer per param group."""
            num_agents = len(wrapped_env.agents)
            init_batch_x = (
                jnp.zeros(
                    (num_agents, 1, 1, wrapped_env.obs_size)
                ),  # (num_agents, time_steps, batch_size, obs_size)
                jnp.zeros((num_agents, 1, 1)),  # (num_agents, time_steps, batch_size)
            )
            init_batch_hs = ScannedRNN.initialize_carry(
                config["HIDDEN_SIZE"], num_agents, 1
            )

            network_params = network.init(rng, init_batch_hs, *init_batch_x, False)

            agent_lr = config["AGENT_LR"]
            pre_policy_lr = config["PRE_POLICY_LR"]
            if config.get("LR_LINEAR_DECAY", False):
                # Decay both learning rates linearly over every gradient step.
                # (Constant learning rates are used otherwise.)
                decay_steps = config["NUM_EPOCHS"] * config["NUM_UPDATES"]
                agent_lr = optax.linear_schedule(
                    init_value=agent_lr,
                    end_value=1e-10,
                    transition_steps=decay_steps,
                )
                pre_policy_lr = optax.linear_schedule(
                    init_value=pre_policy_lr,
                    end_value=1e-10,
                    transition_steps=decay_steps,
                )

            def get_optimizer(optimizer_name, lr, config):
                """Return a gradient-norm-clipped optax optimizer selected by name.

                Raises:
                    ValueError: if ``optimizer_name`` is not adam/sgd/radam.
                """
                if optimizer_name == 'adam':
                    return optax.chain(
                        optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                        optax.adam(learning_rate=lr)
                    )
                elif optimizer_name == 'sgd':
                    return optax.chain(
                        optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                        optax.sgd(learning_rate=lr, momentum=config["MOMENTUM"])
                    )
                elif optimizer_name == 'radam':
                    return optax.chain(
                        optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                        optax.radam(learning_rate=lr),
                    )
                else:
                    raise ValueError(f"Unsupported optimizer: {optimizer_name}")

            optimizers = {
                'pre_policy_opt': get_optimizer(config['PRE_POLICY_OPT'], pre_policy_lr, config),
                'agent_opt': get_optimizer(config['AGENT_OPT'], agent_lr, config)
            }

            # Route every leaf whose path contains "pre_policy_agent" to the
            # pre-policy optimizer; everything else uses the agent optimizer.
            param_partitions = traverse_util.path_aware_map(
                lambda path, _: 'pre_policy_opt' if 'pre_policy_agent' in path else 'agent_opt',
                network_params,
            )

            tx = optax.multi_transform(optimizers, param_partitions)

            return CustomTrainState.create(
                apply_fn=network.apply,
                params=network_params,
                target_network_params=network_params,
                tx=tx,
            )

        rng, _rng = jax.random.split(rng)
        # use the fresh subkey so `rng` is not consumed twice
        train_state = create_agent(_rng)

        # INIT BUFFERS
        def _init_buffer(wrapped, agents, num_agents, reset_rng):
            """Create a flashbax trajectory buffer and its initial state for ``wrapped``.

            Initialising the buffer requires an example item to learn its
            structure, so a NUM_STEPS rollout with random actions is traced
            only for its shapes (a fixed dummy key drives the steps).
            """

            def _env_sample_step(env_state, unused):
                _, key_a, key_s = jax.random.split(
                    jax.random.PRNGKey(0), 3
                )  # use a dummy rng here
                key_a = jax.random.split(key_a, num_agents)
                actions = {
                    agent: wrapped.batch_sample(key_a[i], agent)
                    for i, agent in enumerate(agents)
                }
                avail_actions = wrapped.get_valid_actions(env_state.env_state)
                obs, env_state, rewards, dones, infos = wrapped.batch_step(
                    key_s, env_state, actions
                )
                timestep = Timestep(
                    obs=obs,
                    actions=actions,
                    rewards=rewards,
                    dones=dones,
                    avail_actions=avail_actions,
                )
                return env_state, timestep

            _, env_state = wrapped.batch_reset(reset_rng)
            _, sample_traj = jax.lax.scan(
                _env_sample_step, env_state, None, config["NUM_STEPS"]
            )
            sample_traj_unbatched = jax.tree_util.tree_map(
                lambda x: x[:, 0], sample_traj
            )  # remove the NUM_ENV dim
            buffer = fbx.make_trajectory_buffer(
                max_length_time_axis=config["BUFFER_SIZE"] // config["NUM_ENVS"],
                min_length_time_axis=config["BUFFER_BATCH_SIZE"],
                sample_batch_size=config["BUFFER_BATCH_SIZE"],
                add_batch_size=config["NUM_ENVS"],
                sample_sequence_length=1,
                period=1,
            )
            return buffer, buffer.init(sample_traj_unbatched)

        buffer_agent, buffer_agent_state = _init_buffer(
            wrapped_agent_env, env_agent.agents, env_agent.num_agents, rng
        )
        buffer_pre, buffer_pre_state = _init_buffer(
            wrapped_pre_env, wrapped_pre_env.agents, env_pre.num_agents, rng
        )

        # TRAINING LOOP
        def _update_step(runner_state, unused):

            train_state, buffer_agent_state, buffer_pre_state, test_state, rng, train_pre = runner_state

            # SAMPLE PHASE
            def _step_env(carry, _):
                hs, last_obs, last_dones, env_state, rng = carry
                rng, rng_a, rng_s = jax.random.split(rng, 3)

                # (num_agents, 1 (dummy time), num_envs, obs_size)
                _obs = batchify(last_obs)[:, np.newaxis]
                _dones = batchify(last_dones)[:, np.newaxis]

                # NOTE(review): `train_pre` is not forwarded, so the model runs
                # with its default (train_pre=False) even though the env below
                # is selected by `train_pre` — confirm this is intended.
                new_hs, q_vals = network.apply(train_state.params, hs, _obs, _dones)

                q_vals = q_vals.squeeze(
                    axis=1
                )  # (num_agents, num_envs, num_actions) remove the time dim

                # explore
                avail_actions = wrapped_env.get_valid_actions(env_state.env_state)

                eps = eps_scheduler(train_state.n_updates)
                _rngs = jax.random.split(rng_a, wrapped_env.num_agents)
                actions = jax.vmap(eps_greedy_exploration, in_axes=(0, 0, None, 0))(
                    _rngs, q_vals, eps, batchify(avail_actions)
                )
                actions = unbatchify(actions)

                # step whichever env belongs to the current phase
                new_obs, new_env_state, rewards, dones, infos = jax.lax.cond(
                    train_pre,
                    lambda _: wrapped_pre_env.batch_step(rng_s, env_state, actions),
                    lambda _: wrapped_agent_env.batch_step(rng_s, env_state, actions),
                    operand=None
                )

                timestep = Timestep(
                    obs=last_obs,
                    actions=actions,
                    rewards=jax.tree_util.tree_map(
                        lambda x: config.get("REW_SCALE", 1) * x, rewards
                    ),
                    dones=dones,
                    avail_actions=avail_actions,
                )
                return (new_hs, new_obs, dones, new_env_state, rng), (timestep, infos)

            # step the env (should be a complete rollout)
            rng, _rng = jax.random.split(rng)

            init_obs, env_state = jax.lax.cond(
                train_pre,
                lambda _: wrapped_pre_env.batch_reset(_rng),
                lambda _: wrapped_agent_env.batch_reset(_rng),
                operand=None
            )

            init_dones = {
                agent: jnp.zeros((config["NUM_ENVS"]), dtype=bool)
                for agent in wrapped_env.agents + ["__all__"]
            }
            init_hs = ScannedRNN.initialize_carry(
                config["HIDDEN_SIZE"], len(wrapped_env.agents), config["NUM_ENVS"]
            )
            expl_state = (init_hs, init_obs, init_dones, env_state)
            rng, _rng = jax.random.split(rng)
            _, (timesteps, infos) = jax.lax.scan(
                _step_env,
                (*expl_state, _rng),
                None,
                config["NUM_STEPS"],
            )

            train_state = train_state.replace(
                timesteps=train_state.timesteps
                + config["NUM_STEPS"] * config["NUM_ENVS"]
            )  # update timesteps count

            # BUFFER UPDATE
            buffer_traj_batch = jax.tree_util.tree_map(
                lambda x: jnp.swapaxes(x, 0, 1)[
                    :, np.newaxis
                ],  # put the batch dim first and add a dummy sequence dim
                timesteps,
            )  # (num_envs, 1, time_steps, ...)

            # only the buffer of the current phase receives the new rollout
            buffer_agent_state, buffer_pre_state = jax.lax.cond(
                train_pre,
                lambda _: (buffer_agent_state, buffer_pre.add(buffer_pre_state, buffer_traj_batch)),
                lambda _: (buffer_agent.add(buffer_agent_state, buffer_traj_batch), buffer_pre_state),
                operand=None
            )

            # NETWORKS UPDATE
            def _learn_phase(carry, _):

                train_state, rng, train_pre = carry
                rng, _rng = jax.random.split(rng)

                # sample from the buffer of the current phase
                minibatch = jax.lax.cond(
                    train_pre,
                    lambda _: buffer_pre.sample(buffer_pre_state, _rng).experience,
                    lambda _: buffer_agent.sample(buffer_agent_state, _rng).experience,
                    operand=None
                )

                minibatch = jax.tree_util.tree_map(
                    lambda x: jnp.swapaxes(
                        x[:, 0], 0, 1
                    ),  # remove the dummy sequence dim (1) and swap batch and temporal dims
                    minibatch,
                )  # (max_time_steps, batch_size, ...)

                # preprocess network input
                init_hs = ScannedRNN.initialize_carry(
                    config["HIDDEN_SIZE"],
                    len(wrapped_agent_env.agents),
                    config["BUFFER_BATCH_SIZE"],
                )
                # num_agents, timesteps, batch_size, ...
                _obs = batchify(minibatch.obs)
                _dones = batchify(minibatch.dones)
                _actions = batchify(minibatch.actions)
                _avail_actions = batchify(minibatch.avail_actions)

                _, q_next_target = network.apply(
                    train_state.target_network_params, init_hs, _obs, _dones
                )

                def _loss_fn(params):
                    _, q_vals = network.apply(params, init_hs, _obs, _dones)

                    # get logits of the chosen actions
                    chosen_action_q_vals = jnp.take_along_axis(
                        q_vals,
                        _actions[..., np.newaxis],
                        axis=-1,
                    ).squeeze(-1)  # (num_agents, timesteps, batch_size,)

                    unavailable_actions = 1 - _avail_actions
                    valid_q_vals = q_vals - (unavailable_actions * 1e10)

                    # double-Q: online argmax, target-network value
                    q_next = jnp.take_along_axis(
                        q_next_target,
                        jnp.argmax(valid_q_vals, axis=-1)[..., np.newaxis],
                        axis=-1,
                    ).squeeze(-1)  # (num_agents, timesteps, batch_size,)

                    vdn_target = (
                        minibatch.rewards["__all__"][:-1]
                        + (
                            1 - minibatch.dones["__all__"][:-1]
                        )  # use next done because last done was saved for rnn re-init
                        * config["GAMMA"]
                        * jnp.sum(q_next, axis=0)[1:]  # sum over agents
                    )

                    chosen_action_q_vals = jnp.sum(chosen_action_q_vals, axis=0)[:-1]
                    loss = jnp.mean(
                        (chosen_action_q_vals - jax.lax.stop_gradient(vdn_target)) ** 2
                    )

                    return loss, chosen_action_q_vals.mean()

                (loss, qvals), grads = jax.value_and_grad(_loss_fn, has_aux=True)(
                    train_state.params
                )

                train_state = train_state.apply_gradients(grads=grads)
                train_state = train_state.replace(
                    grad_steps=train_state.grad_steps + 1,
                )
                return (train_state, rng, train_pre), (loss, qvals)

            rng, _rng = jax.random.split(rng)

            # learn only once both buffers can be sampled and warm-up is over
            is_learn_time = (
                buffer_agent.can_sample(buffer_agent_state)
            ) & (  # enough experience in buffer
                train_state.timesteps > config["LEARNING_STARTS"]
            ) & (buffer_pre.can_sample(buffer_pre_state))

            (train_state, rng, train_pre), (loss, qvals) = jax.lax.cond(
                is_learn_time,
                lambda train_state, rng, train_pre: jax.lax.scan(
                    _learn_phase, (train_state, rng, train_pre), None, config["NUM_EPOCHS"]
                ),
                lambda train_state, rng, train_pre: (
                    (train_state, rng, train_pre),
                    (
                        jnp.zeros(config["NUM_EPOCHS"]),
                        jnp.zeros(config["NUM_EPOCHS"]),
                    ),
                ),  # do nothing
                train_state,
                rng,
                train_pre
            )

            # update target network
            train_state = jax.lax.cond(
                train_state.n_updates % config["TARGET_UPDATE_INTERVAL"] == 0,
                lambda train_state: train_state.replace(
                    target_network_params=optax.incremental_update(
                        train_state.params,
                        train_state.target_network_params,
                        config["TAU"],
                    )
                ),
                lambda train_state: train_state,
                operand=train_state,
            )

            # UPDATE METRICS

            train_state = train_state.replace(n_updates=train_state.n_updates + 1)

            # alternate between the agent and pre-policy phases
            train_pre = jax.lax.cond(
                train_state.n_updates % config['SWITCH_INTERVAL'] == 0,
                lambda _: jnp.logical_not(train_pre),
                lambda _: train_pre,
                operand=train_pre
            )

            metrics = {
                "env_step": train_state.timesteps,
                "update_steps": train_state.n_updates,
                "grad_steps": train_state.grad_steps,
                "loss": loss.mean(),
                "qvals": qvals.mean(),
            }

            metrics.update(jax.tree_util.tree_map(lambda x: x.mean(), infos))

            if config.get("TEST_DURING_TRAINING", True):
                rng, _rng = jax.random.split(rng)
                # guard against a zero interval when NUM_UPDATES * TEST_INTERVAL < 1
                test_interval = max(
                    1, int(config["NUM_UPDATES"] * config["TEST_INTERVAL"])
                )
                test_state = jax.lax.cond(
                    train_state.n_updates % test_interval == 0,
                    lambda _: get_greedy_metrics(_rng, train_state),
                    lambda _: test_state,
                    operand=None,
                )
                metrics.update({"test_" + k: v for k, v in test_state.items()})

            # report on wandb if required
            if config["WANDB_MODE"] != "disabled":

                def callback(metrics, original_seed):
                    if config.get('WANDB_LOG_ALL_SEEDS', False):
                        metrics.update(
                            {f"rng{int(original_seed)}/{k}": v for k, v in metrics.items()}
                        )
                    wandb.log(metrics)

                jax.debug.callback(callback, metrics, original_seed)

            runner_state = (train_state, buffer_agent_state, buffer_pre_state, test_state, rng, train_pre)

            return runner_state, None

        def get_greedy_metrics(rng, train_state):
            """Run the greedy policy on ``test_env`` and return mean episode metrics.

            Returns ``None`` when TEST_DURING_TRAINING is disabled.
            """
            if not config.get("TEST_DURING_TRAINING", True):
                return None

            params = train_state.params

            def _greedy_env_step(step_state, unused):
                params, env_state, last_obs, last_dones, hstate, rng = step_state
                rng, key_s = jax.random.split(rng)
                _obs = batchify(last_obs)[:, np.newaxis]
                _dones = batchify(last_dones)[:, np.newaxis]

                # evaluation always runs in the standard-agent mode
                # (test_env wraps env_agent)
                hstate, q_vals = network.apply(params, hstate, _obs, _dones, False)

                q_vals = q_vals.squeeze(axis=1)
                valid_actions = test_env.get_valid_actions(env_state.env_state)
                actions = get_greedy_actions(q_vals, batchify(valid_actions))
                actions = unbatchify(actions)
                obs, env_state, rewards, dones, infos = test_env.batch_step(
                    key_s, env_state, actions
                )
                step_state = (params, env_state, obs, dones, hstate, rng)
                return step_state, (rewards, dones, infos)

            rng, _rng = jax.random.split(rng)
            init_obs, env_state = test_env.batch_reset(_rng)
            init_dones = {
                agent: jnp.zeros((config["TEST_NUM_ENVS"]), dtype=bool)
                for agent in wrapped_env.agents + ["__all__"]
            }
            rng, _rng = jax.random.split(rng)
            hstate = ScannedRNN.initialize_carry(
                config["HIDDEN_SIZE"], len(wrapped_agent_env.agents), config["TEST_NUM_ENVS"]
            )
            step_state = (
                params,
                env_state,
                init_obs,
                init_dones,
                hstate,
                _rng,
            )
            step_state, (rewards, dones, infos) = jax.lax.scan(
                _greedy_env_step, step_state, None, config["TEST_NUM_STEPS"]
            )
            # average each info only over steps where an episode just ended
            metrics = jax.tree_util.tree_map(
                lambda x: jnp.nanmean(
                    jnp.where(
                        infos["returned_episode"],
                        x,
                        jnp.nan,
                    )
                ),
                infos,
            )
            return metrics

        rng, _rng = jax.random.split(rng)
        test_state = get_greedy_metrics(_rng, train_state)

        # train
        rng, _rng = jax.random.split(rng)
        runner_state = (train_state, buffer_agent_state, buffer_pre_state, test_state, _rng, train_pre)

        runner_state, metrics = jax.lax.scan(
            _update_step, runner_state, None, config["NUM_UPDATES"]
        )

        return {"runner_state": runner_state, "metrics": metrics}

    return train


def env_from_config(config):
    """Instantiate the environment(s) described by ``config``.

    SMAX returns a pair ``(agent_env, pre_env)``. ``pre_env`` is built with the
    same ``ENV_KWARGS`` plus ``TRAIN_PRE=True``. The other families return a
    single wrapped env.

    NOTE(review): ``make_train`` unpacks its ``env`` argument as a 2-tuple, so
    only the SMAX branch is compatible with it as written.

    Side effect: ``config["ENV_KWARGS"]`` is mutated (SMAX scenario /
    Overcooked layout); pass a copy if the caller needs it unchanged.

    Returns:
        ``(env, env_name)`` where ``env_name`` is used for logging/paths.
    """
    env_name = config["ENV_NAME"]
    # smax init needs a scenario
    if "smax" in env_name.lower():
        config["ENV_KWARGS"]["scenario"] = map_name_to_scenario(config["MAP_NAME"])

        env_name = f"{config['ENV_NAME']}_{config['MAP_NAME']}"
        env1 = make(config["ENV_NAME"], **config["ENV_KWARGS"])
        env1 = SMAXLogWrapper(env1)

        # The pre-policy env must be built from the *copied* kwargs so that the
        # TRAIN_PRE flag set here is actually passed to `make`.
        config2 = copy.deepcopy(config)
        config2["ENV_KWARGS"]["TRAIN_PRE"] = True
        env2 = make(config2["ENV_NAME"], **config2["ENV_KWARGS"])
        env2 = SMAXLogWrapper(env2)
        env = (env1, env2)

    # overcooked needs a layout
    elif "overcooked" in env_name.lower():
        env_name = f"{config['ENV_NAME']}_{config['ENV_KWARGS']['layout']}"
        config["ENV_KWARGS"]["layout"] = overcooked_layouts[
            config["ENV_KWARGS"]["layout"]
        ]
        env = make(config["ENV_NAME"], **config["ENV_KWARGS"])
        env = LogWrapper(env)
    elif "mpe" in env_name.lower():
        env = make(config["ENV_NAME"], **config["ENV_KWARGS"])
        env = MPELogWrapper(env)
    else:
        env = make(config["ENV_NAME"], **config["ENV_KWARGS"])
        env = LogWrapper(env)
    return env, env_name


def single_run(config):
    """Train ``NUM_SEEDS`` vmapped seeds once, log to wandb, optionally save params.

    When ``SAVE_PATH`` is set, the merged config is written as YAML and one
    ``.safetensors`` file per vmapped seed is written under
    ``SAVE_PATH/<env_name>``.
    """
    config = {**config, **config["alg"]}  # merge the alg config with the main config
    print("Config:\n", OmegaConf.to_yaml(config))

    alg_name = config.get("ALG_NAME", "vdn_rnn")
    # copy so env construction cannot mutate this config's ENV_KWARGS
    env, env_name = env_from_config(copy.deepcopy(config))

    wandb.init(
        entity=config["ENTITY"],
        project=config["PROJECT"],
        tags=[
            alg_name.upper(),
            env_name.upper(),
            f"jax_{jax.__version__}",
        ],
        name=f"{alg_name}_{env_name}",
        config=config,
        mode=config["WANDB_MODE"],
    )

    rng = jax.random.PRNGKey(config["SEED"])

    rngs = jax.random.split(rng, config["NUM_SEEDS"])
    train_vjit = jax.jit(jax.vmap(make_train(config, env)))
    outs = jax.block_until_ready(train_vjit(rngs))

    # save params
    if config.get("SAVE_PATH", None) is not None:
        from jaxmarl.wrappers.baselines import save_params

        model_state = outs["runner_state"][0]
        save_dir = os.path.join(config["SAVE_PATH"], env_name)
        os.makedirs(save_dir, exist_ok=True)
        OmegaConf.save(
            config,
            os.path.join(
                save_dir, f'{alg_name}_{env_name}_seed{config["SEED"]}_config.yaml'
            ),
        )

        for i in range(len(rngs)):
            # slice the i-th vmapped seed out of every parameter leaf
            params = jax.tree_util.tree_map(lambda x: x[i], model_state.params)
            save_path = os.path.join(
                save_dir,
                f'{alg_name}_{env_name}_seed{config["SEED"]}_vmap{i}.safetensors',
            )
            save_params(params, save_path)


def tune(default_config):
    """Hyperparameter sweep with wandb.

    Runs a grid sweep over ``SEED`` (0-9). Each trial copies the merged
    default config, overrides it with ``wandb.config`` and trains
    ``NUM_SEEDS`` vmapped seeds.
    """

    default_config = {**default_config, **default_config["alg"]}  # merge the alg config with the main config
    alg_name = default_config.get("ALG_NAME", "vdn_rnn")
    config_id = default_config["TUNED_CONFIG_ID"]
    # copy so env construction cannot mutate ENV_KWARGS of the base config
    env, env_name = env_from_config(copy.deepcopy(default_config))

    def wrapped_make_train():
        wandb.init(project=default_config["PROJECT"])

        # update the default params with the values chosen by the sweep
        config = copy.deepcopy(default_config)
        config.update(dict(wandb.config))

        print("running experiment with params:", config)

        rng = jax.random.PRNGKey(config["SEED"])
        rngs = jax.random.split(rng, config["NUM_SEEDS"])
        train_vjit = jax.jit(jax.vmap(make_train(config, env)))
        # block so the trial (and its wandb logging) finishes before returning
        jax.block_until_ready(train_vjit(rngs))

    sweep_config = {
        "name": f"tuned_{alg_name}_{env_name}_config_{config_id}",
        "method": "grid",
        "metric": {
            "name": "test_returned_won_episode",
            "goal": "maximize",
        },
        "parameters": {
            'SEED': {'values': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}
        }
    }

    wandb.login()
    sweep_id = wandb.sweep(
        sweep_config, entity=default_config["ENTITY"], project=default_config["PROJECT"]
    )

    wandb.agent(sweep_id, wrapped_make_train, entity=default_config["ENTITY"],
                project=default_config["PROJECT"], count=300)


@hydra.main(version_base=None, config_path="./config", config_name="config")
def main(config):
    """Hydra entry point: convert the config to a plain dict and dispatch.

    ``HYP_TUNE`` selects a wandb sweep (``tune``); otherwise a single
    training run (``single_run``) is executed.
    """
    plain_cfg = OmegaConf.to_container(config)
    print("Config:\n", OmegaConf.to_yaml(plain_cfg))
    runner = tune if plain_cfg["HYP_TUNE"] else single_run
    runner(plain_cfg)


if __name__ == "__main__":
    # Hydra supplies `config` from ./config/config.yaml plus CLI overrides.
    main()