import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
import optax
from flax.linen.initializers import constant, orthogonal
from typing import Sequence, NamedTuple, Any, Callable
from flax.training.train_state import TrainState
import distrax
import gymnax
from gymnax.wrappers.purerl import LogWrapper, FlattenObservationWrapper
import os
from utils import load_config

class ActorCritic(nn.Module):
    """Actor-critic MLP with separate policy and value towers.

    Each tower has two hidden ``nn.Dense(64)`` layers followed by the
    configured activation. The two towers read the same input ``x`` but share
    no parameters.

    Attributes:
        action_dim: number of discrete actions (width of the policy logits).
        activation: elementwise nonlinearity applied after each hidden layer.
    """

    # Number of discrete actions; used as the `features` of the logits layer.
    action_dim: int
    activation: Callable

    @nn.compact
    def __call__(self, x):
        """Return ``(pi, value)`` for input ``x``.

        ``pi`` is a ``distrax.Categorical`` built from ``action_dim`` logits.
        ``value`` is the critic output with its trailing size-1 axis squeezed
        away, i.e. it has shape ``x.shape[:-1]``.
        """
        activation = self.activation
        # NOTE: Flax auto-names layers by creation order (Dense_0, Dense_1, ...);
        # keep this order stable so saved parameter trees stay compatible.
        actor_mean = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(x)
        actor_mean = activation(actor_mean)
        actor_mean = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(actor_mean)
        actor_mean = activation(actor_mean)
        # Small-gain init keeps initial logits near zero (near-uniform policy).
        actor_mean = nn.Dense(
            self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
        )(actor_mean)
        pi = distrax.Categorical(logits=actor_mean)

        critic = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(x)
        critic = activation(critic)
        critic = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(critic)
        critic = activation(critic)
        critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
            critic
        )

        return pi, jnp.squeeze(critic, axis=-1)


class Transition(NamedTuple):
    """Data recorded for one rollout step across all parallel environments.

    When collected with ``jax.lax.scan`` each field gains a leading time axis.
    """

    done: jnp.ndarray  # termination flag returned by env.step
    action: jnp.ndarray  # action sampled from the policy
    value: jnp.ndarray  # critic estimate for `obs`
    reward: jnp.ndarray  # reward returned by env.step
    log_prob: jnp.ndarray  # log-probability of `action` at collection time
    obs: jnp.ndarray  # observation the action was chosen from
    info: dict  # info dict returned by env.step (e.g. LogWrapper episode stats)

def make_train(config, activation):
    """Build a jittable minibatch advantage actor-critic (A2C) trainer.

    ``config`` is mutated in place: the derived ``NUM_UPDATES`` and
    ``MINIBATCH_SIZE`` values are written back into it.

    Args:
        config: hyper-parameter dict. Keys read here: TOTAL_TIMESTEPS,
            NUM_STEPS, NUM_ENVS, NUM_MINIBATCHES, UPDATE_EPOCHS, ENV_NAME,
            LR, ANNEAL_LR, MAX_GRAD_NORM, GAMMA, GAE_LAMBDA, VF_COEF, ENT_COEF.
        activation: elementwise nonlinearity for the ActorCritic hidden layers.

    Returns:
        ``train(rng)``, which returns ``{"runner_state": ..., "metrics": ...}``.
        ``metrics`` holds the env-step info dict stacked over updates and
        rollout steps, plus per-update ``value_loss``, ``loss_actor`` and
        ``entropy`` (means over the minibatches of the last epoch).

    Raises:
        ValueError: if NUM_STEPS * NUM_ENVS is not divisible by NUM_MINIBATCHES.
    """
    # Cast to int: TOTAL_TIMESTEPS is often written as a float (e.g. 1e7), and
    # NUM_UPDATES is used as a `lax.scan` length.
    config["NUM_UPDATES"] = int(
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    )
    config["MINIBATCH_SIZE"] = (
        config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
    )
    # Validate before tracing: the minibatch reshape below requires the rollout
    # batch to split evenly into NUM_MINIBATCHES pieces.
    batch_size = config["NUM_STEPS"] * config["NUM_ENVS"]
    if config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"] != batch_size:
        raise ValueError(
            "batch size must be equal to number of steps * number of envs "
            "(NUM_STEPS * NUM_ENVS must be divisible by NUM_MINIBATCHES)"
        )
    env, env_params = gymnax.make(config["ENV_NAME"])
    env = FlattenObservationWrapper(env)
    env = LogWrapper(env)

    def linear_schedule(count):
        # `count` counts optimizer steps; each update performs
        # NUM_MINIBATCHES * UPDATE_EPOCHS of them, so LR decays linearly from
        # LR toward 0 over NUM_UPDATES updates.
        frac = 1.0 - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"])) / config["NUM_UPDATES"]
        return config["LR"] * frac

    def train(rng):
        # INIT NETWORK
        network = ActorCritic(env.action_space(env_params).n, activation)
        rng, _rng = jax.random.split(rng)
        init_x = jnp.zeros(env.observation_space(env_params).shape)
        network_params = network.init(_rng, init_x)
        if config["ANNEAL_LR"]:
            tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(learning_rate=linear_schedule, eps=1e-5),
            )
        else:
            tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(config["LR"], eps=1e-5),
            )
        train_state = TrainState.create(
            apply_fn=network.apply,
            params=network_params,
            tx=tx,
        )

        # INIT ENV
        rng, _rng = jax.random.split(rng)
        reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
        obsv, env_state = jax.vmap(env.reset, in_axes=(0, None))(reset_rng, env_params)

        # TRAIN LOOP
        def _update_step(runner_state, unused):
            """One update: collect NUM_STEPS of rollouts, compute GAE, train."""

            # COLLECT TRAJECTORIES
            def _env_step(runner_state, unused):
                train_state, env_state, last_obs, rng = runner_state

                # SELECT ACTION
                rng, _rng = jax.random.split(rng)
                pi, value = network.apply(train_state.params, last_obs)
                action = pi.sample(seed=_rng)
                log_prob = pi.log_prob(action)

                # STEP ENV
                rng, _rng = jax.random.split(rng)
                rng_step = jax.random.split(_rng, config["NUM_ENVS"])
                obsv, env_state, reward, done, info = jax.vmap(
                    env.step, in_axes=(0, 0, 0, None)
                )(rng_step, env_state, action, env_params)
                transition = Transition(
                    done, action, value, reward, log_prob, last_obs, info
                )
                runner_state = (train_state, env_state, obsv, rng)
                return runner_state, transition

            runner_state, traj_batch = jax.lax.scan(
                _env_step, runner_state, None, config["NUM_STEPS"]
            )

            # CALCULATE ADVANTAGE
            train_state, env_state, last_obs, rng = runner_state
            # Bootstrap value of the observation following the last rollout step.
            _, last_val = network.apply(train_state.params, last_obs)

            def _calculate_gae(traj_batch, last_val):
                def _get_advantages(gae_and_next_value, transition):
                    gae, next_value = gae_and_next_value
                    done, value, reward = (
                        transition.done,
                        transition.value,
                        transition.reward,
                    )
                    # Convert boolean mask to float; it stops bootstrapping and
                    # GAE accumulation across episode boundaries.
                    done_float = jnp.where(done, 1.0, 0.0)
                    delta = reward + config["GAMMA"] * next_value * (1 - done_float) - value
                    gae = (
                        delta
                        + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done_float) * gae
                    )
                    return (gae, value), gae

                # Scan backwards in time: the advantage at t depends on t + 1.
                _, advantages = jax.lax.scan(
                    _get_advantages,
                    (jnp.zeros_like(last_val), last_val),
                    traj_batch,
                    reverse=True,
                    unroll=16,
                )
                # Value targets: advantage plus the value estimate at collection.
                return advantages, advantages + traj_batch.value

            advantages, targets = _calculate_gae(traj_batch, last_val)

            # UPDATE NETWORK
            def _update_epoch(update_state, unused):
                def _update_minbatch(train_state, batch_info):
                    traj_batch, advantages, targets = batch_info

                    def _loss_fn(params, traj_batch, gae, targets):
                        # RERUN NETWORK
                        pi, value = network.apply(params, traj_batch.obs)
                        log_prob = pi.log_prob(traj_batch.action)

                        # CALCULATE VALUE LOSS (unclipped squared error)
                        value_loss = 0.5 * jnp.square(value - targets).mean()

                        # CALCULATE ACTOR LOSS: plain policy gradient on
                        # per-minibatch normalized advantages (no PPO ratio/clip).
                        gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                        loss_actor = (-log_prob * gae).mean()
                        entropy = pi.entropy().mean()

                        total_loss = (
                            loss_actor
                            + config["VF_COEF"] * value_loss
                            - config["ENT_COEF"] * entropy
                        )
                        return total_loss, (value_loss, loss_actor, entropy)

                    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                    (total_loss, (value_loss, loss_actor, entropy)), grads = grad_fn(
                        train_state.params, traj_batch, advantages, targets
                    )
                    train_state = train_state.apply_gradients(grads=grads)
                    return train_state, (total_loss, value_loss, loss_actor, entropy)

                train_state, traj_batch, advantages, targets, rng = update_state
                rng, _rng = jax.random.split(rng)
                # Flatten (NUM_STEPS, NUM_ENVS, ...) -> (batch_size, ...), shuffle,
                # then split into NUM_MINIBATCHES equal minibatches.
                permutation = jax.random.permutation(_rng, batch_size)
                batch = (traj_batch, advantages, targets)
                batch = jax.tree_util.tree_map(
                    lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
                )
                shuffled_batch = jax.tree_util.tree_map(
                    lambda x: jnp.take(x, permutation, axis=0), batch
                )
                minibatches = jax.tree_util.tree_map(
                    lambda x: jnp.reshape(
                        x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
                    ),
                    shuffled_batch,
                )
                train_state, minibatch_losses = jax.lax.scan(
                    _update_minbatch, train_state, minibatches
                )
                update_state = (train_state, traj_batch, advantages, targets, rng)
                return update_state, minibatch_losses

            update_state = (train_state, traj_batch, advantages, targets, rng)
            update_state, losses_per_epoch = jax.lax.scan(
                _update_epoch, update_state, None, config["UPDATE_EPOCHS"]
            )
            train_state = update_state[0]
            rng = update_state[-1]

            # Each loss array has shape (UPDATE_EPOCHS, NUM_MINIBATCHES); report
            # the minibatch mean of the final epoch. Build a fresh dict instead
            # of mutating the trajectory's info dict in place.
            _, value_loss, loss_actor, entropy = losses_per_epoch
            metric = {
                **traj_batch.info,
                "value_loss": value_loss[-1].mean(),
                "loss_actor": loss_actor[-1].mean(),
                "entropy": entropy[-1].mean(),
            }

            runner_state = (train_state, env_state, last_obs, rng)
            return runner_state, metric

        rng, _rng = jax.random.split(rng)
        runner_state = (train_state, env_state, obsv, _rng)
        runner_state, metric = jax.lax.scan(
            _update_step, runner_state, None, config["NUM_UPDATES"]
        )

        return {"runner_state": runner_state, "metrics": metric}

    return train

if __name__ == "__main__":
    import matplotlib.pyplot as plt
    import pickle

    # Hidden-layer nonlinearities selectable via config["ACTIVATION"].
    ACTIVATIONS = {"relu": jax.nn.relu, "tanh": jnp.tanh}

    def main(config):
        """Train NUM_SEEDS agents in parallel, then plot and save results.

        Outputs are written to ``../main_results/a2c_results/<ENV_NAME>/a2c/``:
        ``episode_returns.png``, ``entropy.png``, ``entropy.npy`` and, when
        ``config["SAVE_PARAMS"]`` is true, the best seed's ``params.pkl``.

        Raises:
            ValueError: if ``config["ACTIVATION"]`` is not a known name.
        """
        print('Training: ', config['ENV_NAME'])
        activation_name = config.get("ACTIVATION", "relu")
        if activation_name not in ACTIVATIONS:
            raise ValueError(
                f"Unknown ACTIVATION {activation_name!r}; "
                f"expected one of {sorted(ACTIVATIONS)}"
            )
        rng = jax.random.PRNGKey(42)
        train = make_train(config, ACTIVATIONS[activation_name])

        # vmap the training run over seeds, then jit the batched computation.
        rngs = jax.random.split(rng, config['NUM_SEEDS'])
        train_vjit = jax.jit(jax.vmap(train))
        results = train_vjit(rngs)

        save_dir = f"../main_results/a2c_results/{config['ENV_NAME']}/a2c/"
        os.makedirs(save_dir, exist_ok=True)

        # Presumably (NUM_SEEDS, NUM_UPDATES, NUM_STEPS, NUM_ENVS) -- TODO confirm
        # against LogWrapper's info; average steps and envs to one curve per seed.
        returns = np.asarray(results['metrics']["returned_episode_returns"])
        returns_per_seed = returns.mean(axis=(-1, -2))

        # get best expert seed (for demos)
        best_seed = int(np.argmax(returns_per_seed.mean(axis=-1)))
        print(f"Best seed: {best_seed}")
        if config.get("SAVE_PARAMS", False):
            params = jax.tree_util.tree_map(
                lambda x: x[best_seed], results['runner_state'][0].params
            )
            with open(os.path.join(save_dir, 'params.pkl'), 'wb') as f:
                pickle.dump(params, f)

        # plot returns
        avg_returns_per_update = returns_per_seed.mean(axis=0)
        std_returns_per_update = returns_per_seed.std(axis=0)
        print(f"Final return: {avg_returns_per_update[-1]}")
        plt.plot(avg_returns_per_update, label='Average Return')
        plt.fill_between(
            range(len(avg_returns_per_update)),
            avg_returns_per_update - std_returns_per_update,
            avg_returns_per_update + std_returns_per_update,
            alpha=0.2,
        )
        plt.plot(returns_per_seed[best_seed], label='Best Seed')
        plt.legend()
        plt.title("Episode Return")
        plt.xlabel("Update")
        plt.ylabel("Average Return")
        plt.savefig(os.path.join(save_dir, 'episode_returns.png'))
        plt.close()

        # plot and save entropy; one value per (seed, update)
        entropy = np.asarray(results['metrics']["entropy"])
        entropy_mean = entropy.mean(0)
        entropy_std = entropy.std(0)
        plt.plot(entropy_mean)
        plt.fill_between(
            range(len(entropy_mean)),
            entropy_mean - entropy_std,
            entropy_mean + entropy_std,
            alpha=0.3,
        )
        plt.title("Entropy")
        plt.xlabel("Update")
        plt.ylabel("Entropy")
        plt.savefig(os.path.join(save_dir, 'entropy.png'))
        plt.close()
        np.save(os.path.join(save_dir, "entropy.npy"), entropy)

    config = {
        "LR": 3e-4,
        "NUM_ENVS": 64,
        "NUM_STEPS": 16,
        "TOTAL_TIMESTEPS": 1e7,
        "UPDATE_EPOCHS": 1,
        "NUM_MINIBATCHES": 8,
        "GAMMA": 0.99,
        "GAE_LAMBDA": 0.95,
        "ENT_COEF": 0.01,
        "VF_COEF": 5.0,
        "MAX_GRAD_NORM": 10.0,
        "ACTIVATION": "relu",
        "ENV_NAME": "SpaceInvaders-MinAtar",
        "ANNEAL_LR": True,
        "NUM_SEEDS": 1,
        "HP_STUDY": False,
        "DEBUG": True,
        # Set True to pickle the best seed's parameters to params.pkl.
        "SAVE_PARAMS": False,
    }

    main(config)
