import jax
from brax_networks import Actor
from flax.training.train_state import TrainState
from common import Transition
from wrappers import (
    LogWrapper,
    ILBraxGymnaxWrapper,
    VecEnv,
    ClipAction,
)
import flax.linen as nn
import matplotlib.pyplot as plt
import optax
import jax.numpy as jnp
import pickle
import os

def make_train(config, sample_expert_transitions):
    """Build a behavior-cloning (BC) training function.

    The environment named by ``config["ENV_NAME"]`` is created with the
    ``config["BACKEND"]`` backend and wrapped, in order, with ``LogWrapper``,
    ``ClipAction`` and ``VecEnv``.

    Args:
        config: Dict of hyper-parameters. Keys read here: ``ENV_NAME``,
            ``BACKEND``, ``MAX_GRAD_NORM``, ``LR``, ``MINIBATCH_SIZE``,
            ``UPDATE_EPOCHS``, ``NUM_ENVS`` and, optionally,
            ``NUM_EVAL_STEPS`` (number of environment steps rolled out after
            training; defaults to 1000, the value that used to be
            hard-coded).
        sample_expert_transitions: Callable
            ``(batch_size, rng) -> (batch, rng)``. The returned batch must
            expose ``obs`` and ``action`` attributes.

    Returns:
        A function ``train(rng)`` that runs ``UPDATE_EPOCHS`` gradient steps
        on the negative log-likelihood of expert actions, then rolls the
        resulting policy out in the environment. It returns a dict with
        ``"train_state"`` (the trained state), ``"info"`` (the ``info``
        field of the rollout transitions) and ``"bc_loss"`` (the loss of
        every update step).
    """
    env, env_params = ILBraxGymnaxWrapper(config["ENV_NAME"], backend=config['BACKEND']), None
    env = LogWrapper(env)
    env = ClipAction(env)
    env = VecEnv(env)

    # Rollout horizon used after training. Read once here so it is a plain
    # Python int when `jax.lax.scan` is traced.
    num_eval_steps = config.get("NUM_EVAL_STEPS", 1000)

    def train(rng):
        # INIT NETWORK AND OPTIMIZER
        network = Actor(env.action_space(env_params).shape[0], activation=nn.relu)
        rng, _rng = jax.random.split(rng)
        init_x = jnp.zeros(env.observation_space(env_params).shape)
        network_params = network.init(_rng, init_x)
        tx = optax.chain(
            optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
            optax.adam(config["LR"], eps=1e-5),
        )
        train_state = TrainState.create(
            apply_fn=network.apply,
            params=network_params,
            tx=tx,
        )

        def _update_step(update_state, unused):
            """Run one gradient step on a freshly sampled expert minibatch."""
            train_state, rng = update_state

            rng, rng_sample = jax.random.split(rng)
            # The sampler also returns a key; it is discarded because `rng`
            # has already been advanced by the split above.
            expert_batch, _ = sample_expert_transitions(config["MINIBATCH_SIZE"], rng_sample)

            def _bc_loss(params, expert_batch):
                # Negative mean log-likelihood of the expert actions under
                # the distribution the network outputs for the expert obs.
                pi = network.apply(params, expert_batch.obs)
                log_prob = pi.log_prob(expert_batch.action)

                return -log_prob.mean()

            grad_fn = jax.value_and_grad(_bc_loss)
            bc_loss, grads = grad_fn(train_state.params, expert_batch)
            train_state = train_state.apply_gradients(grads=grads)

            return (train_state, rng), bc_loss

        # TRAIN: scan stacks the per-step losses along a new leading axis.
        update_state = (train_state, rng)
        update_state, bc_loss = jax.lax.scan(
            _update_step, update_state, None, length=config["UPDATE_EPOCHS"]
        )

        train_state, rng = update_state

        def _env_step(runner_state, unused):
            """Advance all environments by one step with the trained policy."""
            train_state, env_state, last_obs, rng = runner_state

            # SELECT ACTION (sampled from the policy distribution)
            rng, _rng = jax.random.split(rng)
            pi = network.apply(train_state.params, last_obs)
            action = pi.sample(seed=_rng)
            log_prob = pi.log_prob(action)

            # STEP ENV (one key per environment)
            rng, _rng = jax.random.split(rng)
            rng_step = jax.random.split(_rng, config["NUM_ENVS"])
            obsv, env_state, reward, done, info = env.step(
                rng_step, env_state, action, env_params
            )

            # Positional field order is defined by `common.Transition`,
            # which is not visible in this file; the arguments are kept
            # exactly as they were.
            transition = Transition(
                done, action, None, reward, log_prob, last_obs, obsv, obsv, info
            )
            runner_state = (train_state, env_state, obsv, rng)
            return runner_state, transition

        # ROLLOUT: reset every environment and step the trained policy.
        rng, _rng = jax.random.split(rng)
        reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
        obsv, env_state = env.reset(reset_rng, env_params)
        runner_state = (train_state, env_state, obsv, rng)
        _, transitions = jax.lax.scan(_env_step, runner_state, None, length=num_eval_steps)

        return {"train_state": train_state, "info": transitions.info, "bc_loss": bc_loss}

    return train

"""
Expert Replay Buffer
"""
def make_expert_transitions(config):
    """Load expert transitions from disk and return a minibatch sampler.

    Args:
        config: Dict of settings. Keys read here: ``BACKEND`` (``'positional'``
            selects ``../experts_new``, anything else ``../experts_mjx``),
            ``ENV_NAME``, ``N_EXPERT_TRAJS`` (how many entries to keep along
            axis 1 of every array) and ``SUB_SAMPLE_RATE`` (stride applied
            along axis 0 when greater than 1).

    Returns:
        A function ``sample_expert_transitions(batch_size, rng)`` returning
        ``(batch, rng)``: ``batch`` has the same structure as the loaded
        transitions with ``batch_size`` rows drawn uniformly at random
        (repeats possible) from the flattened data, and ``rng`` is a fresh
        key split off the one passed in.
    """
    # load expert transitions
    if config['BACKEND'] == 'positional':
        expert_root = "../experts_new"
    else:
        expert_root = "../experts_mjx"
    expert_path = f"{expert_root}/{config['ENV_NAME']}/transitions_sorted.pkl"

    # NOTE: unpickling can execute arbitrary code; only load files from a
    # trusted source. The context manager closes the handle right after
    # loading instead of leaving it to the garbage collector.
    with open(expert_path, 'rb') as expert_file:
        expert_transitions = pickle.load(expert_file)

    # Keep only the first N_EXPERT_TRAJS entries along axis 1
    # (presumably one entry per trajectory - confirm against the data).
    expert_transitions = jax.tree_util.tree_map(
        lambda x: x[:, :config["N_EXPERT_TRAJS"]], expert_transitions
    )
    if config["SUB_SAMPLE_RATE"] > 1:
        # Keep every SUB_SAMPLE_RATE-th entry along axis 0 ...
        expert_transitions_subset = jax.tree_util.tree_map(
            lambda x: x[::config["SUB_SAMPLE_RATE"]], expert_transitions
        )
        # ... and always append the last entry of the un-strided data, so
        # index -1 along axis 0 (read by the prints below) is still the
        # original final entry.
        expert_transitions = jax.tree_util.tree_map(
            lambda x, y: jnp.concatenate([x, y[-1:]]),
            expert_transitions_subset, expert_transitions
        )

    final_returns = expert_transitions.info['returned_episode_returns'][-1, :]
    print('Returns:', final_returns)
    print(f'Mean: {jnp.mean(final_returns)} Std: {jnp.std(final_returns)}')
    print('Obs shape:', expert_transitions.unnorm_obs.shape)
    print('Next Obs shape:', expert_transitions.unnorm_next_obs.shape)
    print('Action shape:', expert_transitions.action.shape)

    # flatten transitions: merge the first two axes into a single leading
    # axis (axis 1 becomes the slow-moving index)
    expert_transitions = jax.tree_util.tree_map(
        lambda x: x.swapaxes(1, 0).reshape((-1,) + x.shape[2:]), expert_transitions
    )
    expert_size = expert_transitions.done.shape[0]

    def sample_expert_transitions(batch_size, rng):
        """Draw ``batch_size`` rows uniformly at random, with replacement."""
        rng, _rng = jax.random.split(rng)

        # Independent uniform draws in [0, expert_size), so the same row
        # can appear more than once in a batch.
        indices = jax.random.randint(_rng, (batch_size,), 0, expert_size)

        # Apply sampled indices to all elements in expert_transitions
        batch = jax.tree_util.tree_map(lambda x: x[indices], expert_transitions)

        return batch, rng

    return sample_expert_transitions

if __name__ == "__main__":
    import time

    # Hyper-parameters shared by every run. ENV_NAME is only a placeholder
    # here: the loop at the bottom overwrites it for each environment.
    config = {
        'ENV_NAME': 'ant',
        'BACKEND': 'mjx',
        'LR': 5*1e-3,
        'NUM_ENVS': 16,
        'UPDATE_EPOCHS': int(1e4),
        'NUM_SEEDS': 16,
        'MAX_GRAD_NORM': 1.0,
        'MINIBATCH_SIZE': 512,
        'SUB_SAMPLE_RATE': 20,
        'N_EXPERT_TRAJS': 10,
        'DISC_INP': 'sa',
    }

    envs = ['ant', 'halfcheetah', 'walker2d', 'hopper']

    def main(config):
        """Train BC on ``config['ENV_NAME']`` for NUM_SEEDS seeds and save results.

        Writes ``bc_loss.png`` and ``returns.npy`` into the run directory
        derived from ENV_NAME, DISC_INP and N_EXPERT_TRAJS, and prints the
        wall-clock time and the final evaluation return.
        """
        rng = jax.random.PRNGKey(0)
        rngs = jax.random.split(rng, config['NUM_SEEDS'])
        sample_expert_transitions = make_expert_transitions(config)
        # batch_size (argument 0) fixes array shapes, so it must be static.
        sample_expert_transitions = jax.jit(sample_expert_transitions, static_argnums=(0,))
        # One independent training run per seed, vectorised over the keys.
        train = jax.jit(jax.vmap(make_train(config, sample_expert_transitions), in_axes=(0)))
        start = time.time()
        results = train(rngs)
        # JAX dispatches work asynchronously: without waiting here the timer
        # could stop before the computation has actually finished.
        results = jax.block_until_ready(results)
        print(f"Time taken: {time.time() - start}")

        save_dir = f"../../main_results/{config['ENV_NAME']}/{config['DISC_INP']}/{config['N_EXPERT_TRAJS']}/bc/"
        os.makedirs(save_dir, exist_ok=True)

        # plot bc_loss: mean over seeds (axis 0) with a +-1 std band
        loss_mean = results['bc_loss'].mean(0)
        loss_std = results['bc_loss'].std(0)
        plt.plot(loss_mean)
        plt.fill_between(range(len(loss_mean)), loss_mean - loss_std, loss_mean + loss_std, alpha=0.2)
        plt.savefig(save_dir + 'bc_loss.png')
        plt.close()

        # Select the returns at the entries flagged by `returned_episode`
        # and regroup them per seed.
        # NOTE(review): the reshape assumes every seed contributes the same
        # number of flagged entries. If episodes can end early (so counts
        # differ between seeds) this either raises or mixes seeds - confirm
        # against the environment wrappers.
        info = results['info']
        finished = info['returned_episode']
        eval_returns = info['returned_episode_returns'][finished].reshape((config['NUM_SEEDS'], -1))
        jnp.save(save_dir + 'returns.npy', eval_returns)
        mean_rews = eval_returns.mean(axis=(-1,0))
        # Std across seeds of the per-seed mean return.
        std_rews = eval_returns.mean(axis=(-1)).std(axis=0)
        print(f'Final return (eval mean, std): {mean_rews} +- {std_rews}')

    for env in envs:
        config['ENV_NAME'] = env
        main(config)
    


             

        





            
            
                
        