import os
# os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
import optax
from flax.linen.initializers import constant, orthogonal
from typing import Sequence, NamedTuple, Any, Callable
from flax.training.train_state import TrainState
import distrax
from utils import load_config
from wrappers import (
    LogWrapper,
    BraxGymnaxWrapper,
    VecEnv,
    NormalizeVecObservation,
    NormalizeVecReward,
    ClipAction,
)
import pickle
import matplotlib.pyplot as plt

class ActorCritic(nn.Module):
    """Diagonal-Gaussian policy and state-value network with separate trunks.

    Each trunk is two 256-unit dense layers, with ``activation`` applied
    after each of them. The policy trunk ends in an ``action_dim``-wide
    layer producing the action mean; the scale comes from a learned
    ``log_std`` parameter that does not depend on the input. The value
    trunk ends in a single output unit.

    Calling the module returns ``(pi, value)``: ``pi`` is a
    ``distrax.MultivariateNormalDiag`` and ``value`` is the critic output
    with its trailing size-1 axis squeezed away.
    """

    # NOTE(review): annotated as a sequence, but it is used directly as the
    # Dense feature count and as a shape entry below; the caller in this
    # file passes a single int.
    action_dim: Sequence[int]
    activation: Callable

    @nn.compact
    def __call__(self, x):
        act = self.activation
        hidden_kernel = orthogonal(np.sqrt(2))
        zero_bias = constant(0.0)

        # Creation order fixes flax's auto-generated layer names, so the
        # policy layers must be built before the value layers.

        # --- policy head ---
        hidden = nn.Dense(256, kernel_init=hidden_kernel, bias_init=zero_bias)(x)
        hidden = act(hidden)
        hidden = nn.Dense(256, kernel_init=hidden_kernel, bias_init=zero_bias)(
            hidden
        )
        hidden = act(hidden)
        # Small init scale on the output layer keeps initial means near zero.
        mean = nn.Dense(
            self.action_dim, kernel_init=orthogonal(0.01), bias_init=zero_bias
        )(hidden)
        log_std = self.param("log_std", nn.initializers.zeros, (self.action_dim,))
        pi = distrax.MultivariateNormalDiag(mean, jnp.exp(log_std))

        # --- value head ---
        v = nn.Dense(256, kernel_init=hidden_kernel, bias_init=zero_bias)(x)
        v = act(v)
        v = nn.Dense(256, kernel_init=hidden_kernel, bias_init=zero_bias)(v)
        v = act(v)
        v = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=zero_bias)(v)

        return pi, jnp.squeeze(v, axis=-1)


class Transition(NamedTuple):
    """One environment step of rollout data, as recorded by the rollout scan.

    Instances are built positionally, so the field order is part of the
    interface.
    """

    done: jnp.ndarray  # `done` flag returned by env.step
    action: jnp.ndarray  # action sampled from the policy
    value: jnp.ndarray  # critic output for `obs`
    reward: jnp.ndarray  # reward returned by env.step
    log_prob: jnp.ndarray  # log-probability of `action` under the sampling policy
    obs: jnp.ndarray  # observation the action was computed from (pre-step)
    # NOTE(review): this file indexes `info` by string key, so it is
    # presumably a dict of arrays rather than a single array -- confirm
    # against the wrappers.
    info: jnp.ndarray


def make_train(config, activation):
    """Build a PPO training function for a Brax environment.

    The derived values ``NUM_UPDATES`` and ``MINIBATCH_SIZE`` are written
    back into ``config``, i.e. the caller's dict is mutated.

    Args:
        config: Hyperparameter dict. Keys read: ``TOTAL_TIMESTEPS``,
            ``NUM_STEPS``, ``NUM_ENVS``, ``NUM_MINIBATCHES``, ``ENV_NAME``,
            ``BACKEND`` (or the legacy lowercase ``backend``),
            ``NORMALIZE_ENV``, ``GAMMA``, ``GAE_LAMBDA``, ``LR``,
            ``ANNEAL_LR``, ``UPDATE_EPOCHS``, ``MAX_GRAD_NORM``,
            ``CLIP_EPS``, ``VF_COEF``, ``ENT_COEF`` and, optionally,
            ``DEBUG``.
        activation: Callable passed to ``ActorCritic`` as its activation.

    Returns:
        A function ``train(rng)`` that runs ``NUM_UPDATES`` rollout/update
        iterations and returns ``{"runner_state": ..., "metrics": ...}``.
        ``runner_state`` is ``(train_state, env_state, last_obs, rng)``;
        ``metrics`` is the rollout ``info`` with ``entropy``,
        ``actor_loss`` and ``value_loss`` added, stacked over updates.

    Raises:
        KeyError: If a required key is missing, including when neither
            ``BACKEND`` nor ``backend`` is present.
    """
    config["NUM_UPDATES"] = (
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    )
    config["MINIBATCH_SIZE"] = (
        config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
    )
    # Every other config key is upper-case and the launcher defines
    # "BACKEND"; looking up only the lowercase "backend" raised KeyError.
    # Prefer the upper-case key and fall back to the lowercase spelling.
    backend = config["BACKEND"] if "BACKEND" in config else config["backend"]
    env, env_params = BraxGymnaxWrapper(config["ENV_NAME"], backend=backend), None

    env = LogWrapper(env)
    env = ClipAction(env)
    env = VecEnv(env)
    if config["NORMALIZE_ENV"]:
        env = NormalizeVecObservation(env)
        env = NormalizeVecReward(env, config["GAMMA"])

    def linear_schedule(count):
        """Learning rate decaying linearly with the number of finished updates.

        ``count`` is the optimizer step count; one update performs
        ``NUM_MINIBATCHES * UPDATE_EPOCHS`` optimizer steps.
        """
        frac = (
            1.0
            - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"]))
            / config["NUM_UPDATES"]
        )
        return config["LR"] * frac

    def train(rng):
        """Run the full training loop from the PRNG key ``rng``."""
        # INIT NETWORK
        network = ActorCritic(
            env.action_space(env_params).shape[0], activation=activation
        )
        rng, _rng = jax.random.split(rng)
        init_x = jnp.zeros(env.observation_space(env_params).shape)
        network_params = network.init(_rng, init_x)
        if config["ANNEAL_LR"]:
            tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(learning_rate=linear_schedule, eps=1e-5),
            )
        else:
            tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(config["LR"], eps=1e-5),
            )
        train_state = TrainState.create(
            apply_fn=network.apply,
            params=network_params,
            tx=tx,
        )

        # INIT ENV
        rng, _rng = jax.random.split(rng)
        reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
        obsv, env_state = env.reset(reset_rng, env_params)

        # TRAIN LOOP
        def _update_step(runner_state, unused):
            """Collect ``NUM_STEPS`` of experience, then run the PPO epochs."""

            # COLLECT TRAJECTORIES
            def _env_step(runner_state, unused):
                train_state, env_state, last_obs, rng = runner_state

                # SELECT ACTION
                rng, _rng = jax.random.split(rng)
                pi, value = network.apply(train_state.params, last_obs)
                action = pi.sample(seed=_rng)
                log_prob = pi.log_prob(action)

                # STEP ENV
                rng, _rng = jax.random.split(rng)
                rng_step = jax.random.split(_rng, config["NUM_ENVS"])
                obsv, env_state, reward, done, info = env.step(
                    rng_step, env_state, action, env_params
                )
                # Store the observation the action was computed from
                # (last_obs), not the one returned by the step.
                transition = Transition(
                    done, action, value, reward, log_prob, last_obs, info
                )
                runner_state = (train_state, env_state, obsv, rng)
                return runner_state, transition

            runner_state, traj_batch = jax.lax.scan(
                _env_step, runner_state, None, config["NUM_STEPS"]
            )

            # CALCULATE ADVANTAGE
            train_state, env_state, last_obs, rng = runner_state
            # Value of the observation after the last collected step, used
            # to bootstrap the advantage recursion.
            _, last_val = network.apply(train_state.params, last_obs)

            def _calculate_gae(traj_batch, last_val):
                """Return ``(advantages, value_targets)`` for the rollout."""

                def _get_advantages(gae_and_next_value, transition):
                    gae, next_value = gae_and_next_value
                    done, value, reward = (
                        transition.done,
                        transition.value,
                        transition.reward,
                    )
                    # (1 - done) drops both the bootstrap value and the
                    # accumulated advantage across a `done` step.
                    delta = reward + config["GAMMA"] * next_value * (1 - done) - value
                    gae = (
                        delta
                        + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
                    )
                    # This step's value is the "next value" for the step
                    # before it, since the scan runs in reverse.
                    return (gae, value), gae

                _, advantages = jax.lax.scan(
                    _get_advantages,
                    (jnp.zeros_like(last_val), last_val),
                    traj_batch,
                    reverse=True,
                    unroll=16,
                )
                return advantages, advantages + traj_batch.value

            advantages, targets = _calculate_gae(traj_batch, last_val)

            # UPDATE NETWORK
            def _update_epoch(update_state, unused):
                """One pass over the rollout in shuffled minibatches."""

                def _update_minbatch(train_state, batch_info):
                    traj_batch, advantages, targets = batch_info

                    def _loss_fn(params, traj_batch, gae, targets):
                        # RERUN NETWORK
                        pi, value = network.apply(params, traj_batch.obs)
                        log_prob = pi.log_prob(traj_batch.action)

                        # CALCULATE VALUE LOSS
                        # Clipped value loss: the new prediction may move
                        # at most CLIP_EPS from the rollout-time value;
                        # the larger of the two squared errors is used.
                        value_pred_clipped = traj_batch.value + (
                            value - traj_batch.value
                        ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
                        value_losses = jnp.square(value - targets)
                        value_losses_clipped = jnp.square(value_pred_clipped - targets)
                        value_loss = (
                            0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
                        )

                        # CALCULATE ACTOR LOSS
                        ratio = jnp.exp(log_prob - traj_batch.log_prob)
                        # Advantages are normalised per minibatch.
                        gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                        loss_actor1 = ratio * gae
                        loss_actor2 = (
                            jnp.clip(
                                ratio,
                                1.0 - config["CLIP_EPS"],
                                1.0 + config["CLIP_EPS"],
                            )
                            * gae
                        )
                        loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
                        loss_actor = loss_actor.mean()
                        entropy = pi.entropy().mean()

                        total_loss = (
                            loss_actor
                            + config["VF_COEF"] * value_loss
                            - config["ENT_COEF"] * entropy
                        )
                        return total_loss, (value_loss, loss_actor, entropy)

                    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                    # With has_aux=True, `total_loss` here is the pair
                    # (loss, (value_loss, loss_actor, entropy)).
                    total_loss, grads = grad_fn(
                        train_state.params, traj_batch, advantages, targets
                    )
                    train_state = train_state.apply_gradients(grads=grads)
                    return train_state, total_loss

                train_state, traj_batch, advantages, targets, rng = update_state
                rng, _rng = jax.random.split(rng)
                batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
                # Fails when NUM_STEPS * NUM_ENVS is not divisible by
                # NUM_MINIBATCHES (MINIBATCH_SIZE is floor-divided above).
                assert (
                    batch_size == config["NUM_STEPS"] * config["NUM_ENVS"]
                ), "batch size must be equal to number of steps * number of envs"
                permutation = jax.random.permutation(_rng, batch_size)
                batch = (traj_batch, advantages, targets)
                # Merge the two leading axes (steps, envs) into one.
                batch = jax.tree_util.tree_map(
                    lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
                )
                shuffled_batch = jax.tree_util.tree_map(
                    lambda x: jnp.take(x, permutation, axis=0), batch
                )
                minibatches = jax.tree_util.tree_map(
                    lambda x: jnp.reshape(
                        x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
                    ),
                    shuffled_batch,
                )
                train_state, total_loss = jax.lax.scan(
                    _update_minbatch, train_state, minibatches
                )
                update_state = (train_state, traj_batch, advantages, targets, rng)
                return update_state, total_loss

            update_state = (train_state, traj_batch, advantages, targets, rng)
            update_state, loss_info = jax.lax.scan(
                _update_epoch, update_state, None, config["UPDATE_EPOCHS"]
            )
            train_state = update_state[0]
            metric = traj_batch.info
            # loss_info is (total_loss, (value_loss, loss_actor, entropy)),
            # each stacked over epochs and minibatches, so loss_info[-1]
            # is the auxiliary tuple.
            metric['entropy'] = loss_info[-1][2].mean()
            metric['actor_loss'] = loss_info[-1][1].mean()
            metric['value_loss'] = loss_info[-1][0].mean()
            rng = update_state[-1]
            if config.get("DEBUG"):

                def callback(info):
                    """Print the return of each entry flagged by ``returned_episode``."""
                    return_values = info["returned_episode_returns"][
                        info["returned_episode"]
                    ]
                    timesteps = (
                        info["timestep"][info["returned_episode"]] * config["NUM_ENVS"]
                    )
                    for timestep, return_value in zip(timesteps, return_values):
                        print(
                            f"global step={timestep}, episodic return={return_value}"
                        )

                jax.debug.callback(callback, metric)

            runner_state = (train_state, env_state, last_obs, rng)
            return runner_state, metric

        rng, _rng = jax.random.split(rng)
        runner_state = (train_state, env_state, obsv, _rng)
        runner_state, metric = jax.lax.scan(
            _update_step, runner_state, None, config["NUM_UPDATES"]
        )
        return {"runner_state": runner_state, "metrics": metric}

    return train

if __name__ == '__main__':

    def main(config):
        """Train over ``config['NUM_SEEDS']`` seeds and save the results.

        Files written under ``../main_results/<ENV_NAME>/ppo/``:
        ``params.pkl`` (parameters of the best seed), ``mean.npy`` and
        ``var.npy`` (only when ``NORMALIZE_ENV`` is set),
        ``episode_returns.png``, ``entropy.npy`` and ``entropy.png``.

        Args:
            config: Hyperparameter dict forwarded to ``make_train``. This
                function itself reads ``ENV_NAME``, ``SEED``, ``NUM_SEEDS``
                and ``NORMALIZE_ENV``.
        """
        print('Training: ', config['ENV_NAME'])
        rng = jax.random.PRNGKey(config['SEED'])
        train = make_train(config, jax.nn.tanh)

        # One independent run per seed, vmapped and compiled as one program.
        rngs = jax.random.split(rng, config['NUM_SEEDS'])
        train_vjit = jax.jit(jax.vmap(train))
        results = train_vjit(rngs)

        save_dir = f"../main_results/{config['ENV_NAME']}/ppo/"
        os.makedirs(save_dir, exist_ok=True)

        # Reduce the trailing two axes (steps and envs of each rollout)
        # once; everything below only needs the (seed, update) table, which
        # is small enough to keep on the host.
        returns_per_update = np.asarray(
            results['metrics']['returned_episode_returns'].mean(axis=(-1, -2))
        )

        # get best expert seed (for demos)
        # NOTE(review): seeds are ranked by return averaged over the whole
        # run, not by final performance -- confirm this is intended.
        best_seed = np.argmax(returns_per_update.mean(axis=1))
        print(f"Best seed: {best_seed}")

        # save params
        params = jax.tree_util.tree_map(
            lambda x: x[best_seed], results['runner_state'][0].params
        )
        with open(os.path.join(save_dir, 'params.pkl'), 'wb') as f:
            pickle.dump(params, f)

        # The normalisation statistics only exist when the normalising
        # wrappers were applied; without this guard the attribute access
        # below would fail after training has already finished.
        if config['NORMALIZE_ENV']:
            env_state = results['runner_state'][1]
            # Assumes the outer state belongs to the reward normaliser and
            # its `.env_state` to the observation normaliser (the wrapping
            # order in make_train) -- TODO confirm against wrappers.
            obs_stats = env_state.env_state
            mean = obs_stats.mean[best_seed][0]
            var = obs_stats.var[best_seed][0]
            np.save(os.path.join(save_dir, 'mean.npy'), np.asarray(mean))
            np.save(os.path.join(save_dir, 'var.npy'), np.asarray(var))

        # plot returns
        avg_returns_per_update = returns_per_update.mean(axis=0)
        std_returns_per_update = returns_per_update.std(axis=0)
        print(f"Final return: {avg_returns_per_update[-1]}")
        plt.plot(avg_returns_per_update, label='Average Return')
        plt.fill_between(
            range(len(avg_returns_per_update)),
            avg_returns_per_update - std_returns_per_update,
            avg_returns_per_update + std_returns_per_update,
            alpha=0.2,
        )
        plt.plot(returns_per_update[best_seed], label='Best Seed')
        plt.legend()
        plt.title("Episode Return")
        plt.xlabel("Update")
        plt.ylabel("Average Return")
        plt.savefig(os.path.join(save_dir, 'episode_returns.png'))
        plt.close()

        # save entropy (one value per seed and update)
        entropy = np.asarray(results['metrics']['entropy'])
        jnp.save(os.path.join(save_dir, 'entropy.npy'), entropy)

        # plot entropy: mean +/- std across seeds
        entropy_mean = entropy.mean(0)
        entropy_std = entropy.std(0)
        plt.plot(entropy_mean)
        plt.fill_between(
            range(len(entropy_mean)),
            entropy_mean - entropy_std,
            entropy_mean + entropy_std,
            alpha=0.3,
        )
        plt.title("Entropy")
        plt.xlabel("Update")
        plt.ylabel("Entropy")
        plt.savefig(os.path.join(save_dir, 'entropy.png'))
        plt.close()

        # results['metrics'] also holds 'actor_loss' and 'value_loss'; they
        # can be saved the same way as the entropy if needed.

    import time

    # Hyperparameters shared by every run launched below. `main` forwards
    # this dict to `make_train`, which adds derived entries to it.
    config = {
        # -- optimisation / rollout sizes --
        'LR': 3.0e-4,
        'NUM_ENVS': 2048,
        'NUM_STEPS': 10,
        'TOTAL_TIMESTEPS': 50_000_000,
        'ENV_NAME': "halfcheetah",
        'UPDATE_EPOCHS': 4,
        'NUM_MINIBATCHES': 32,
        # -- PPO / GAE coefficients --
        'GAMMA': 0.99,
        'GAE_LAMBDA': 0.95,
        'CLIP_EPS': 0.2,
        'ENT_COEF': 0.0,
        'VF_COEF': 0.5,
        'MAX_GRAD_NORM': 0.5,
        # -- switches --
        'ANNEAL_LR': False,
        'NORMALIZE_ENV': True,
        'DEBUG': False,
        # -- experiment setup --
        'NUM_SEEDS': 16,
        'BACKEND': 'mjx',
        'SEED': 42,
    }

    # Environments to train on, one after another, with the same settings.
    envs = ['halfcheetah']
    config['SEED'] = 42
    for env_name in envs:
        config['ENV_NAME'] = env_name
        started_at = time.time()
        main(config)
        elapsed = time.time() - started_at
        print(f"Time taken: {elapsed}")

    
    


