""" 
Based on PureJaxRL Implementation of PPO
"""

import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
import optax
from flax.linen.initializers import constant, orthogonal
from flax import serialization
from typing import Sequence, NamedTuple, Any
from flax.training.train_state import TrainState
import distrax
from gymnax.wrappers.purerl import LogWrapper, FlattenObservationWrapper

import pickle
import os
import jaxmarl
from jaxmarl.wrappers.baselines import LogWrapper
from jaxmarl.environments.overcooked import overcooked_layouts
from jaxmarl.viz.overcooked_visualizer import OvercookedVisualizer
import hydra
from omegaconf import OmegaConf
import wandb
import matplotlib.pyplot as plt

class ActorCritic(nn.Module):
    """Feed-forward actor-critic with separate policy and value towers.

    Each tower is two 64-unit ``Dense`` layers (orthogonal init, gain
    ``sqrt(2)``) followed by a linear head.  The policy head emits
    ``action_dim`` logits wrapped in a ``distrax.Categorical``; the value head
    emits a single unit whose trailing axis is squeezed away.

    Layers are constructed in a fixed order (policy tower first, then value
    tower).  Keep that order: compact modules name their submodules by
    creation order, and parameters restored from disk are matched against
    those names.

    Returns:
        ``(pi, value)`` where ``pi`` is the categorical action distribution
        and ``value`` is the critic output with its last axis removed.
    """

    # NOTE(review): annotated as a sequence, but the value is handed straight
    # to ``nn.Dense`` as a feature count and callers in this file pass
    # ``env.action_space().n``.
    action_dim: Sequence[int]
    # "relu" selects ``nn.relu``; any other value falls back to ``nn.tanh``.
    activation: str = "tanh"

    @nn.compact
    def __call__(self, x):
        nonlinearity = nn.relu if self.activation == "relu" else nn.tanh

        def hidden_layer(inputs):
            # Build the layer first, then apply it, so creation order is
            # explicit and independent of argument evaluation order.
            dense = nn.Dense(
                64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
            )
            return nonlinearity(dense(inputs))

        # Policy tower: two hidden layers, then a small-gain logits head.
        policy_features = hidden_layer(hidden_layer(x))
        policy_head = nn.Dense(
            self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
        )
        pi = distrax.Categorical(logits=policy_head(policy_features))

        # Value tower: same shape, independent parameters, scalar head.
        value_features = hidden_layer(hidden_layer(x))
        value_head = nn.Dense(
            1, kernel_init=orthogonal(1.0), bias_init=constant(0.0)
        )
        value = value_head(value_features)

        return pi, jnp.squeeze(value, axis=-1)
    

class Transition(NamedTuple):
    """Per-step rollout record, batched over all actors.

    ``make_train`` creates one instance per environment step inside
    ``jax.lax.scan``, so the stacked trajectory it produces carries an extra
    leading step axis on every field.
    """

    # Per-actor done flags (``batchify`` output, squeezed).
    done: jnp.ndarray
    # Actions that were sent to the environment for this step.
    action: jnp.ndarray
    # Value estimates; ``make_train`` currently stores zeros here.
    value: jnp.ndarray
    # Per-actor rewards (``batchify`` output, squeezed).
    reward: jnp.ndarray
    # Log-probability of each stored action.
    log_prob: jnp.ndarray
    # Batched observations seen before the step was taken.
    obs: jnp.ndarray
    # Info returned by ``env.step``.  ``make_train`` indexes it by key
    # (e.g. 'returned_episode_returns'), so in practice it is a dict of arrays
    # rather than a single array.
    info: jnp.ndarray

def get_rollout(train_state, config):
    """Play one episode with the policy in ``train_state`` and record its states.

    A fresh, unwrapped environment is built from ``config`` and stepped until
    the environment reports ``done["__all__"]``.  Every agent listed in
    ``env.agents`` samples its action from the shared ``ActorCritic`` network
    applied to that agent's flattened observation.

    The loop is a plain Python ``while`` on the environment's done flag, so
    this function must be called eagerly (not under ``jax.jit``).

    Args:
        train_state: Object exposing ``params`` compatible with
            ``ActorCritic``.
        config: Mapping providing ``ENV_NAME``, ``ENV_KWARGS`` and
            ``ACTIVATION``.

    Returns:
        List of environment states: the reset state followed by the state
        after every step, in order.
    """
    env = jaxmarl.make(config["ENV_NAME"], **config["ENV_KWARGS"])
    network = ActorCritic(env.action_space().n, activation=config["ACTIVATION"])
    network_params = train_state.params

    # Fixed seed: the rollout is reproducible for a given set of parameters.
    # The three-way split is kept (third key unused) so the reset key and the
    # carried key are the same as before the throwaway ``network.init`` call
    # was removed.
    key = jax.random.PRNGKey(0)
    key, key_r, _ = jax.random.split(key, 3)

    agents = list(env.agents)

    obs, state = env.reset(key_r)
    state_seq = [state]
    done = False
    while not done:
        # Layout: [carry, one key per agent (in env.agents order), env step].
        # For two agents this matches a split into four keys.
        step_keys = jax.random.split(key, len(agents) + 2)
        key, key_s = step_keys[0], step_keys[-1]
        agent_keys = step_keys[1:-1]

        actions = {}
        for agent, agent_key in zip(agents, agent_keys):
            pi, _ = network.apply(network_params, obs[agent].flatten())
            actions[agent] = pi.sample(seed=agent_key)

        # STEP ENV
        obs, state, _, done, _ = env.step(key_s, state, actions)
        done = done["__all__"]

        state_seq.append(state)

    return state_seq

def batchify(x: dict, agent_list, num_actors):
    """Stack per-agent arrays and flatten them to ``(num_actors, -1)``.

    Entries of ``x`` are gathered in ``agent_list`` order along a new leading
    axis, then reshaped so that the first axis has length ``num_actors`` and
    everything else is folded into the second axis.
    """
    stacked = jnp.stack(tuple(x[agent] for agent in agent_list), axis=0)
    return jnp.reshape(stacked, (num_actors, -1))


def unbatchify(x: jnp.ndarray, agent_list, num_envs, num_actors):
    """Split a flat actor batch back into a per-agent dict.

    ``x`` is reshaped to ``(num_actors, num_envs, -1)`` and slice ``i`` along
    the first axis is assigned to the ``i``-th name in ``agent_list``.

    Note: the leading axis is indexed per agent; callers in this file pass
    ``env.num_agents`` as ``num_actors``.
    """
    per_agent = jnp.reshape(x, (num_actors, num_envs, -1))
    result = {}
    for index, agent in enumerate(agent_list):
        result[agent] = per_agent[index]
    return result

def make_train(config):
    """Build the jittable trajectory-collection function for ``config``.

    Despite the name, the returned ``train`` function performs no gradient
    updates.  It initialises the network and optimiser, overwrites the
    freshly initialised parameters with the ones stored at
    ``config["EXPERT_PARAM_PATH"]``, and then repeatedly rolls out
    ``NUM_STEPS`` steps in ``NUM_ENVS`` parallel environments using uniformly
    random actions.  After each collection epoch the trajectory batch is
    pickled to ``config["DATA_SAVE_PATH"]/step_<epoch>.pkl`` through a host
    callback.

    Side effect: ``config`` is mutated in place; ``NUM_ACTORS``,
    ``NUM_UPDATES`` and ``MINIBATCH_SIZE`` are derived and stored on it.

    Args:
        config: Mutable mapping providing ``ENV_NAME``, ``ENV_KWARGS``,
            ``NUM_ENVS``, ``NUM_STEPS``, ``TOTAL_TIMESTEPS``,
            ``NUM_MINIBATCHES``, ``UPDATE_EPOCHS``, ``LR``, ``ANNEAL_LR``,
            ``MAX_GRAD_NORM``, ``ACTIVATION``, ``EXPERT_PARAM_PATH``,
            ``DATA_SAVE_PATH`` and ``COLLECT_EPOCHS``.

    Returns:
        ``train(rng)``, which returns
        ``{"runner_state": ..., "metrics": ...}`` where ``metrics`` is the
        per-step environment info stacked over collection epochs.
    """
    # Imported explicitly so the callback API does not depend on
    # ``import jax`` happening to load the ``jax.experimental`` subpackage.
    from jax.experimental import io_callback

    env = jaxmarl.make(config["ENV_NAME"], **config["ENV_KWARGS"])

    config["NUM_ACTORS"] = env.num_agents * config["NUM_ENVS"]
    config["NUM_UPDATES"] = (
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    )
    config["MINIBATCH_SIZE"] = (
        config["NUM_ACTORS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
    )

    env = LogWrapper(env)

    def linear_schedule(count):
        """Decay the learning rate linearly per update, from ``LR`` towards 0."""
        frac = 1.0 - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"])) / config["NUM_UPDATES"]
        return config["LR"] * frac

    def train(rng):
        """Load the stored parameters and run ``COLLECT_EPOCHS`` collection epochs."""
        # INIT NETWORK
        network = ActorCritic(env.action_space().n, activation=config["ACTIVATION"])
        rng, _rng = jax.random.split(rng)
        init_x = jnp.zeros(env.observation_space().shape)
        init_x = init_x.flatten()

        # The initialised parameters serve as the structural template (tree
        # layout, shapes, dtypes) for the parameters restored from disk below.
        network_params = network.init(_rng, init_x)
        if config["ANNEAL_LR"]:
            tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(learning_rate=linear_schedule, eps=1e-5),
            )
        else:
            tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(config["LR"], eps=1e-5),
            )

        def _load_model(params):
            """Host-side: restore parameters from ``EXPERT_PARAM_PATH`` into ``params``' structure."""
            reward_model_path = config["EXPERT_PARAM_PATH"]
            with open(reward_model_path, "rb") as f:
                pretrain_params = serialization.from_bytes(params, f.read())
            return pretrain_params

        # File I/O cannot be traced, so the load runs as a host callback; the
        # template parameters double as the result shape/dtype specification.
        network_params = io_callback(_load_model, network_params, network_params)

        train_state = TrainState.create(
            apply_fn=network.apply,
            params=network_params,
            tx=tx,
        )

        # INIT ENV
        rng, _rng = jax.random.split(rng)
        reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
        obsv, env_state = jax.vmap(env.reset, in_axes=(0,))(reset_rng)

        # Collect LOOP
        def _batch_collect_step(runner_state, epoch):
            """Collect ``NUM_STEPS`` steps, save them to disk, return the step info."""

            # COLLECT TRAJECTORIES
            def _env_step(runner_state, unused):
                train_state, env_state, last_obs, rng = runner_state

                # SELECT ACTION
                rng, _rng = jax.random.split(rng)

                obs_batch = batchify(last_obs, env.agents, config["NUM_ACTORS"])

                # The loaded policy is not queried yet ("use random action for
                # now").  To collect with the policy instead:
                # pi, value = network.apply(train_state.params, obs_batch)
                # action = pi.sample(seed=_rng)
                # log_prob = pi.log_prob(action)
                action = jax.random.randint(
                    _rng, (config["NUM_ACTORS"],), 0, env.action_space().n
                )
                # Placeholders consistent with a uniform-random policy: no
                # value estimate, log-probability log(1 / num_actions).
                value = jnp.zeros((config["NUM_ACTORS"],))
                log_prob = jnp.ones((config["NUM_ACTORS"],)) * jnp.log(1/env.action_space().n)
                env_act = unbatchify(action, env.agents, config["NUM_ENVS"], env.num_agents)
                env_act = {k: v.flatten() for k, v in env_act.items()}

                # STEP ENV
                rng, _rng = jax.random.split(rng)
                rng_step = jax.random.split(_rng, config["NUM_ENVS"])

                obsv, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0))(
                    rng_step, env_state, env_act
                )
                # ``jax.tree_util.tree_map`` is the stable spelling; the
                # top-level ``jax.tree_map`` alias is deprecated and absent
                # from newer JAX releases.
                info = jax.tree_util.tree_map(
                    lambda x: x.reshape((config["NUM_ACTORS"])), info
                )
                transition = Transition(
                    batchify(done, env.agents, config["NUM_ACTORS"]).squeeze(),
                    action,
                    value,
                    batchify(reward, env.agents, config["NUM_ACTORS"]).squeeze(),
                    log_prob,
                    obs_batch,
                    info,
                )
                runner_state = (train_state, env_state, obsv, rng)
                return runner_state, transition

            runner_state, traj_batch = jax.lax.scan(
                _env_step, runner_state, None, config["NUM_STEPS"]
            )

            def save_traj_batch(traj_batch, epoch):
                """Host-side: pickle one trajectory batch and print a short summary."""
                # exist_ok avoids the check-then-create race of a separate
                # os.path.exists test.
                os.makedirs(config["DATA_SAVE_PATH"], exist_ok=True)
                traj_batch_filepath = f"{config['DATA_SAVE_PATH']}/step_{epoch}.pkl"
                with open(traj_batch_filepath, 'wb') as f:
                    pickle.dump(traj_batch, f)
                print(f"Saved traj_batch to {traj_batch_filepath}, num_envs: {config['NUM_ENVS']}, num_actors: {config['NUM_ACTORS']}, action_dim: {env.action_space().n}")
                print("obs_dim: ",  traj_batch.obs.shape, "episode_len: ", config["NUM_STEPS"],"episode_return: ", traj_batch.info['returned_episode_returns'][-1, :].mean())

                return None

            # No result is requested (``None``); the callback is kept purely
            # for its side effect of writing the batch to disk.
            io_callback(save_traj_batch, None, traj_batch, epoch)

            metric = traj_batch.info

            return runner_state, metric

        rng, _rng = jax.random.split(rng)
        runner_state = (train_state, env_state, obsv, _rng)
        runner_state, metric = jax.lax.scan(
            _batch_collect_step, runner_state, jnp.arange(config["COLLECT_EPOCHS"])
        )
        return {"runner_state": runner_state, "metrics": metric}

    return train



@hydra.main(version_base=None, config_path="config", config_name="ippo_ff_overcooked")
def main(config):
    """Hydra entry point: resolve the layout, start W&B, run the collection.

    The Hydra config is converted to a plain container, the layout name under
    ``ENV_KWARGS`` is replaced by the matching ``overcooked_layouts`` entry,
    an offline W&B run is initialised, and the jitted function built by
    ``make_train`` is executed once with a fixed PRNG seed.
    """
    config = OmegaConf.to_container(config)

    # Swap the layout *name* from the config for the actual layout object.
    layout_name = config["ENV_KWARGS"]["layout"]
    config["ENV_KWARGS"]["layout"] = overcooked_layouts[layout_name]

    rng = jax.random.PRNGKey(30)
    # Only referenced by the disabled multi-seed code below.
    num_seeds = 20

    # W&B mode is pinned to offline; config["WANDB_MODE"] is not consulted.
    wandb_settings = dict(
        entity=config["ENTITY"],
        project=config["PROJECT"],
        tags=["RNN", config["ENV_NAME"]],
        config=config,
        mode='offline',
        name='IPPO_FF_Overcooked',
    )
    wandb.init(**wandb_settings)

    # disable_jit(False) leaves JIT enabled; switch the flag to debug eagerly.
    with jax.disable_jit(False):
        collect = jax.jit(make_train(config))
        # Multi-seed variant (disabled):
        # collect = jax.jit(jax.vmap(make_train(config)))
        # rngs = jax.random.split(rng, num_seeds)
        out = collect(rng)

    # Disabled post-processing (plots mean return over seeds, renders a GIF):
    # print('** Saving Results **')
    # filename = f'{config["ENV_NAME"]}_cramped_room_new'
    # rewards = out["metrics"]["returned_episode_returns"].mean(-1).reshape((num_seeds, -1))
    # reward_mean = rewards.mean(0)  # mean
    # reward_std = rewards.std(0) / np.sqrt(num_seeds)  # standard error

    # plt.plot(reward_mean)
    # plt.fill_between(range(len(reward_mean)), reward_mean - reward_std, reward_mean + reward_std, alpha=0.2)
    # # compute standard error
    # plt.xlabel("Update Step")
    # plt.ylabel("Return")
    # plt.savefig(f'{filename}.png')

    # # animate first seed
    # train_state = jax.tree_map(lambda x: x[0], out["runner_state"][0])
    # state_seq = get_rollout(train_state, config)
    # viz = OvercookedVisualizer()
    # # agent_view_size is hardcoded as it determines the padding around the layout.
    # viz.animate(state_seq, agent_view_size=5, filename=f"{filename}.gif")

# Run the Hydra-decorated entry point only when executed as a script, so
# importing this module has no side effects beyond its definitions.
if __name__ == "__main__":
    main()