import time
import collections
import copy
import functools
import logging

import haiku as hk
import jax
from jax import lax
import jax.numpy as jnp
import jax.random as jrandom
import numpy as onp
import optax
import os
import pickle
import ray
from ray import tune
import rlax
import tree

from algorithms import utils
from algorithms.actor import Actor, ActorOutput
from algorithms.haiku_nets import torso_network
from vec_env import VecFrameStack

# Outputs of one ActorCriticAuxNet call: the core output (the torso output
# when no RNN is used), policy logits, state value, and the pixel-control
# action-value predictions consumed by the auxiliary loss.
AgentOutput = collections.namedtuple(
    'AgentOutput',
    ['state', 'logits', 'value', 'aux_pred'],
)

# Diagnostics returned by the A2C update. `entropy` is the negated entropy
# loss; `grad_norm` and `update_norm` are placeholders in the loss function
# and are filled in after the optimizer step.
A2CLog = collections.namedtuple(
    'A2CLog',
    [
        'entropy',
        'value',
        'ret',
        'pg_loss',
        'baseline_loss',
        'state_norm',
        'theta_norm',
        'grad_norm',
        'update_norm',
    ],
)

# Diagnostics returned by the pixel-control update: the predicted action
# values, the scalar auxiliary loss, and gradient/update norms that are
# filled in after the optimizer step.
AuxLog = collections.namedtuple(
    'AuxLog',
    ['pred', 'aux_loss', 'grad_norm', 'update_norm'],
)


class ActorCriticAuxNet(hk.RNNCore):
    """Actor-critic network with a dueling pixel-control auxiliary head.

    Observations are encoded by ``torso_network``. With ``use_rnn`` the torso
    features, the one-hot previous action and the previous reward go through a
    512-unit GRU wrapped in ``hk.ResetCore`` (reset on ``timesteps.first``);
    otherwise the torso features are used directly and the state is passed
    through unchanged.

    Two heads read the core output:
      * a pixel-control head: Linear -> reshape -> two transposed convolutions
        combined dueling-style (``v + a - mean(a)``). Its input is rescaled so
        the forward value is unchanged but only ``scale / sqrt(2)`` of the
        gradient flows back through it;
      * a policy/value MLP head whose input is optionally gradient-stopped
        (``stop_ac_grad``).

    Args:
        num_actions: number of discrete actions.
        torso_type: torso identifier forwarded to ``torso_network``.
        torso_kwargs: keyword arguments forwarded to ``torso_network``.
        use_rnn: use a GRU core instead of an identity core.
        head_layers: hidden widths of the policy/value MLP.
        deconv_layers: ``(channels, height, width, kernel, stride)`` of the
            pixel-control deconvolution.
        stop_ac_grad: stop policy/value gradients from reaching the core.
        scale: gradient multiplier for the pixel-control head input.
        name: optional Haiku module name.
    """

    def __init__(self, num_actions, torso_type, torso_kwargs, use_rnn, head_layers,
                 deconv_layers, stop_ac_grad, scale, name=None):
        super().__init__(name=name)
        self._num_actions = num_actions
        self._torso_type = torso_type
        self._torso_kwargs = torso_kwargs
        self._use_rnn = use_rnn
        # NOTE: Haiku derives parameter names from module creation order; keep
        # the construction order of all submodules stable for checkpoint compat.
        inner_core = (hk.GRU(512, w_h_init=hk.initializers.Orthogonal())
                      if use_rnn else hk.IdentityCore())
        self._core = hk.ResetCore(inner_core)
        self._head_layers = head_layers
        self._deconv_layers = deconv_layers
        self._stop_ac_grad = stop_ac_grad
        self._scale = scale

    def __call__(self, timesteps, state):
        """Apply the network to a time-major sequence of timesteps.

        Args:
            timesteps: structure exposing ``observation``, ``action_tm1``,
                ``reward`` and ``first``; presumably time-major ``[T, ...]``
                (see ``Agent.unroll``) — TODO confirm.
            state: recurrent core state at the start of the sequence.

        Returns:
            ``(AgentOutput, next_state)``.
        """
        features = torso_network(self._torso_type, **self._torso_kwargs)(timesteps.observation)

        if not self._use_rnn:
            core_output, next_state = features, state
        else:
            # Condition the recurrent core on the previous action and reward.
            prev_action = hk.one_hot(timesteps.action_tm1, self._num_actions)
            prev_reward = timesteps.reward[:, None]
            core_input = jnp.concatenate([prev_action, prev_reward, features], axis=1)
            core_output, next_state = hk.dynamic_unroll(
                self._core, (core_input, timesteps.first), state)

        # --- Pixel-control head -------------------------------------------
        pc_channels, pc_height, pc_width, pc_kernel, pc_stride = self._deconv_layers
        pc_hidden = jax.nn.relu(hk.Linear(pc_channels * pc_height * pc_width)(core_output))
        pc_hidden = jnp.reshape(pc_hidden, pc_hidden.shape[:-1] + (pc_height, pc_width, pc_channels))
        # Identity in the forward pass; only `grad_scale` of the gradient flows
        # back. The 1/sqrt(2) is presumably for the two dueling heads that share
        # this input.
        grad_scale = self._scale / jnp.sqrt(2.)
        pc_hidden = grad_scale * pc_hidden + lax.stop_gradient((1. - grad_scale) * pc_hidden)
        state_value = hk.Conv2DTranspose(
            1, kernel_shape=pc_kernel, stride=pc_stride, padding='VALID')(pc_hidden)
        advantage = hk.Conv2DTranspose(
            self._num_actions, kernel_shape=pc_kernel, stride=pc_stride, padding='VALID')(pc_hidden)
        aux_pred = state_value + advantage - advantage.mean(axis=-1, keepdims=True)

        # --- Policy / value head ------------------------------------------
        head_modules = [lax.stop_gradient] if self._stop_ac_grad else []
        for units in self._head_layers:
            head_modules.extend([hk.Linear(units), jax.nn.relu])
        head_hidden = hk.Sequential(head_modules)(core_output)
        logits = hk.Linear(self._num_actions)(head_hidden)
        value = hk.Linear(1)(head_hidden)

        return AgentOutput(
            state=core_output,
            logits=logits,
            value=value.squeeze(-1),
            aux_pred=aux_pred,
        ), next_state

    def initial_state(self, batch_size):
        """Initial state of the (reset-wrapped) core for ``batch_size``."""
        return self._core.initial_state(batch_size)


class Agent(object):
    """Haiku-transformed ActorCriticAuxNet with jitted init/step helpers.

    Args:
        ob_space: observation space; its ``shape``/``dtype`` build dummy inputs.
        action_space: discrete action space (``action_space.n`` actions).
        torso_type, torso_kwargs, head_layers, deconv_layers, use_rnn,
        stop_ac_grad, scale: forwarded unchanged to ``ActorCriticAuxNet``.
    """

    def __init__(self, ob_space, action_space, torso_type, torso_kwargs, head_layers, deconv_layers, use_rnn,
                 stop_ac_grad, scale):
        self._ob_space = ob_space
        num_actions = action_space.n

        def make_net():
            # Single definition of the network configuration, shared by both
            # transforms below so they cannot drift apart.
            return ActorCriticAuxNet(
                num_actions=num_actions,
                torso_type=torso_type,
                torso_kwargs=torso_kwargs,
                use_rnn=use_rnn,
                head_layers=head_layers,
                deconv_layers=deconv_layers,
                stop_ac_grad=stop_ac_grad,
                scale=scale,
            )

        _, self._initial_state_apply_fn = hk.without_apply_rng(
            hk.transform(lambda batch_size: make_net().initial_state(batch_size))
        )
        self._init_fn, self._apply_fn = hk.without_apply_rng(
            hk.transform(lambda inputs, state: make_net()(inputs, state))
        )

    @functools.partial(jax.jit, static_argnums=(0,))
    def init(self, rngkey):
        """Initialize network parameters from a length-1 dummy sequence."""
        dummy_observation = tree.map_structure(lambda t: jnp.zeros(t.shape, t.dtype), self._ob_space)
        dummy_observation = tree.map_structure(lambda t: t[None], dummy_observation)
        dummy_reward = jnp.zeros((1,), dtype=jnp.float32)
        dummy_action = jnp.zeros((1,), dtype=jnp.int32)
        dummy_discount = jnp.zeros((1,), dtype=jnp.float32)
        dummy_first = jnp.zeros((1,), dtype=jnp.float32)
        # batch_size=None: presumably an unbatched core state — TODO confirm.
        dummy_state = self.initial_state(None)
        dummy_input = ActorOutput(
            rnn_state=dummy_state,
            action_tm1=dummy_action,
            reward=dummy_reward,
            discount=dummy_discount,
            first=dummy_first,
            observation=dummy_observation,
        )
        return self._init_fn(rngkey, dummy_input, dummy_state)

    @functools.partial(jax.jit, static_argnums=(0, 1))
    def initial_state(self, batch_size):
        """Initial core state; ``batch_size`` is static under jit."""
        return self._initial_state_apply_fn(None, batch_size)

    @functools.partial(jax.jit, static_argnums=(0,))
    def step(self, rngkey, params, timesteps, states):
        """Run one acting step for a batch and sample actions.

        Returns:
            ``(rngkey, action, agent_output, next_states)`` where ``rngkey`` is
            the advanced key and ``agent_output`` has the time axis removed.
        """
        rngkey, subkey = jrandom.split(rngkey)
        timesteps = tree.map_structure(lambda t: t[:, None, ...], timesteps)  # [B, 1, ...]
        agent_output, next_states = jax.vmap(self._apply_fn, (None, 0, 0))(params, timesteps, states)
        agent_output = tree.map_structure(lambda t: t.squeeze(axis=1), agent_output)  # [B, ...]
        action = hk.multinomial(subkey, agent_output.logits, num_samples=1).squeeze(axis=-1)
        return rngkey, action, agent_output, next_states

    def unroll(self, params, timesteps, state):
        """Apply the network to a single time-major trajectory."""
        return self._apply_fn(params, timesteps, state)  # [T, ...]


def gen_a2c_update_fn(agent, opt_update, gamma, vf_coef, entropy_reg, use_mask):
    """Build an A2C update function (not jitted).

    Args:
        agent: ``Agent`` whose ``unroll`` is vmapped over the batch axis.
        opt_update: optax ``update`` function for the A2C optimizer.
        gamma: discount multiplied into ``trajs.discount``.
        vf_coef: weight of the value (baseline) loss.
        entropy_reg: weight of the entropy loss.
        use_mask: weight each step by ``trajs.discount[:, :-1]``; otherwise all
            steps get weight 1.

    Returns:
        ``a2c_update(theta, opt_state, trajs) -> (new_theta, new_opt_state, A2CLog)``.
    """

    def a2c_loss(theta, trajs):
        # Unroll each trajectory from the RNN state recorded at its first step.
        start_states = tree.map_structure(lambda x: x[:, 0], trajs.rnn_state)
        out, _ = jax.vmap(agent.unroll, (None, 0, 0))(theta, trajs, start_states)  # [B, T + 1, ...]
        values = out.value
        policy_logits = out.logits[:, :-1]

        # n-step returns bootstrapped from the value of the final step.
        returns = jax.vmap(rlax.discounted_returns)(
            trajs.reward[:, 1:], trajs.discount[:, 1:] * gamma, values[:, -1])
        advantages = returns - values[:, :-1]

        step_discounts = trajs.discount[:, :-1]
        masks = step_discounts if use_mask else jnp.ones_like(step_discounts)

        pg_loss = jax.vmap(rlax.policy_gradient_loss)(
            policy_logits, trajs.action_tm1[:, 1:], advantages, masks)
        ent_loss = jax.vmap(rlax.entropy_loss)(policy_logits, masks)
        baseline_loss = 0.5 * jnp.mean(
            jnp.square(values[:, :-1] - lax.stop_gradient(returns)) * masks, axis=1)
        total_loss = jnp.mean(pg_loss + vf_coef * baseline_loss + entropy_reg * ent_loss)

        diagnostics = A2CLog(
            entropy=-ent_loss,
            value=values,
            ret=returns,
            pg_loss=pg_loss,
            baseline_loss=baseline_loss,
            state_norm=jnp.sqrt(jnp.sum(jnp.square(out.state), axis=-1)),
            theta_norm=optax.global_norm(theta),
            grad_norm=0.,  # filled in by a2c_update
            update_norm=0.,  # filled in by a2c_update
        )
        return total_loss, diagnostics

    def a2c_update(theta, opt_state, trajs):
        grads, diagnostics = jax.grad(a2c_loss, has_aux=True)(theta, trajs)
        updates, next_opt_state = opt_update(grads, opt_state)
        diagnostics = diagnostics._replace(
            grad_norm=optax.global_norm(grads),
            update_norm=optax.global_norm(updates),
        )
        return optax.apply_updates(theta, updates), next_opt_state, diagnostics

    return a2c_update


def gen_pc_update_fn(agent, opt_update, preprocess, gamma, use_mask, cell_size=4):
    """Build a pixel-control auxiliary update function (not jitted).

    Args:
        agent: ``Agent`` whose ``unroll`` is vmapped over the batch axis.
        opt_update: optax ``update`` function for the auxiliary optimizer.
        preprocess: maps ``trajs.observation`` to the pixel tensor passed to
            ``rlax.pixel_control_loss``.
        gamma: pixel-control discount multiplied into ``trajs.discount``.
        use_mask: weight each step by ``trajs.discount[:, :-1]``; otherwise all
            steps get weight 1.
        cell_size: side length of a pixel-control cell (previously hard-coded
            to 4). Per the rlax contract, the spatial size of ``aux_pred`` must
            equal the preprocessed height/width divided by ``cell_size``.

    Returns:
        ``pc_update(theta, opt_state, trajs) -> (new_theta, new_opt_state, AuxLog)``.
    """

    def pc_loss(theta, trajs):
        # Unroll each trajectory from the RNN state recorded at its first step.
        rnn_states = tree.map_structure(lambda t: t[:, 0], trajs.rnn_state)
        agent_output, _ = jax.vmap(agent.unroll, (None, 0, 0))(theta, trajs, rnn_states)  # [B, T+1, ...]

        observations = preprocess(trajs.observation)
        actions = trajs.action_tm1[:, 1:]
        action_value = agent_output.aux_pred
        discounts = trajs.discount[:, 1:] * gamma
        # cell_size is unmapped (in_axes=None), so it stays a Python int.
        aux_loss = jax.vmap(rlax.pixel_control_loss, (0, 0, 0, 0, None))(
            observations, actions, action_value, discounts, cell_size)

        if use_mask:
            masks = trajs.discount[:, :-1]
        else:
            masks = jnp.ones_like(trajs.discount[:, :-1])
        # Sum over the two cell-grid axes, weight each step, then average.
        aux_loss = jnp.mean(jnp.sum(aux_loss, axis=(2, 3)) * masks)

        aux_log = AuxLog(
            pred=action_value,
            aux_loss=aux_loss,
            grad_norm=0.,  # filled in by pc_update
            update_norm=0.,  # filled in by pc_update
        )
        return aux_loss, aux_log

    def pc_update(theta, opt_state, trajs):
        grads, logs = jax.grad(pc_loss, has_aux=True)(theta, trajs)
        updates, new_opt_state = opt_update(grads, opt_state)
        logs = logs._replace(
            grad_norm=optax.global_norm(grads),
            update_norm=optax.global_norm(updates),
        )
        new_theta = optax.apply_updates(theta, updates)
        return new_theta, new_opt_state, logs

    return pc_update


class Experiment(tune.Trainable):
    """Ray Tune trainable: A2C with a pixel-control auxiliary loss.

    ``setup`` builds the vectorized environments, the agent, two optimizers
    (A2C and auxiliary) and their jitted update functions. Each ``step`` runs
    ``config['log_interval']`` rollout/update iterations and returns a metrics
    dict. ``_save``/``_restore`` checkpoint the parameters and both optimizer
    states.
    """

    def setup(self, config):
        """Create environments, agent, optimizers and initial training state.

        Raises:
            KeyError: unsupported ``env_id`` or optimizer type, or a missing
                required config key.
            ValueError: ``log_interval`` smaller than 1.
        """
        self._config = config
        platform = jax.lib.xla_bridge.get_backend().platform
        logging.warning("Running on %s", platform)

        # `step` reads its logs from the last loop iteration, so it needs >= 1.
        if config['log_interval'] < 1:
            raise ValueError(f"log_interval must be >= 1, got {config['log_interval']!r}")

        # Gradient scale for the pixel-control head. Environments without a
        # tuned value (maze, procgen) use 1, i.e. no scaling.
        scale = 1.
        if config['env_id'] == 'maze':
            import environments.maze.vec_env_utils as maze_vec_env
            self._envs = maze_vec_env.make_vec_env(
                config['nenvs'],
                config['seed'],
                env_kwargs=config['env_kwargs'],
            )
            self._frame_skip = 1
            use_mask = True
            pc_preprocess = lambda x: x / 255.
        elif config['env_id'].startswith('procgen/'):
            import environments.procgen.vec_env_utils as procgen_vec_env
            env_id = config['env_id'][len('procgen/'):]
            self._envs = procgen_vec_env.make_vec_env(
                env_id,
                config['nenvs'],
                env_kwargs=config['env_kwargs'],
            )
            self._frame_skip = 1
            use_mask = False
            pc_preprocess = lambda x: x / 255.
        elif config['env_id'].startswith('dmlab/'):
            import environments.dmlab.vec_env_utils as dmlab_vec_env
            env_id = config['env_id'][len('dmlab/'):]
            gpu_id = ray.get_gpu_ids()[0]
            env_kwargs = copy.deepcopy(config['env_kwargs'])
            env_kwargs['gpuDeviceIndex'] = gpu_id
            self._envs = dmlab_vec_env.make_vec_env(
                env_id, config['cache'], config['noop_max'], config['nenvs'], config['seed'], env_kwargs)
            self._frame_skip = 4
            use_mask = True
            pc_preprocess = lambda x: x / 255.
            scale = 1. / onp.sqrt((72 / 4) * (96 / 4))
        elif config['env_id'].endswith('NoFrameskip-v4'):
            import environments.atari.vec_env_utils as atari_vec_env
            envs = atari_vec_env.make_vec_env(
                config['env_id'],
                config['nenvs'],
                config['seed'],
            )
            # Feed-forward agents see a stack of the last 4 frames.
            if config['use_rnn']:
                self._envs = envs
            else:
                self._envs = VecFrameStack(envs, 4)
            self._frame_skip = 4
            use_mask = True
            # Crop 2 pixels per border and keep only the last stacked frame.
            pc_preprocess = lambda x: x[..., 2:-2, 2:-2, -1:] / 255.
            scale = 1. / 20.
        else:
            raise KeyError(f"Unsupported env_id: {config['env_id']!r}")
        self._nsteps = config['nsteps']

        if not config.get('scale_gradient', True):
            scale = 1.

        # NOTE: drawn from NumPy's global RNG, not from config['seed'].
        jax_seed = onp.random.randint(2 ** 31 - 1)
        self._rngkey = jrandom.PRNGKey(jax_seed)

        agent = Agent(
            ob_space=self._envs.observation_space,
            action_space=self._envs.action_space,
            torso_type=config['torso_type'],
            torso_kwargs=config['torso_kwargs'],
            use_rnn=config['use_rnn'],
            head_layers=config['head_layers'],
            deconv_layers=config['deconv_layers'],
            stop_ac_grad=config['stop_ac_grad'],
            scale=scale,
        )
        self._actor = Actor(self._envs, agent, self._nsteps)

        if config['a2c_opt_type'] == 'adam':
            a2c_opt = optax.adam(**config['a2c_opt_kwargs'])
        elif config['a2c_opt_type'] == 'rmsprop':
            a2c_opt = optax.rmsprop(**config['a2c_opt_kwargs'])
        else:
            raise KeyError(f"Unsupported a2c_opt_type: {config['a2c_opt_type']!r}")
        if config['max_a2c_grad_norm'] > 0:
            a2c_opt = optax.chain(
                optax.clip_by_global_norm(config['max_a2c_grad_norm']),
                a2c_opt,
            )
        a2c_opt_init, a2c_opt_update = a2c_opt

        if config['aux_opt_type'] == 'adam':
            aux_opt_kwargs = config['aux_opt_kwargs'].copy()
            learning_rate = aux_opt_kwargs.pop('learning_rate')
            aux_opt = optax.chain(
                optax.scale_by_adam(**aux_opt_kwargs),
                optax.scale(-learning_rate),
            )
        elif config['aux_opt_type'] == 'rmsprop':
            aux_opt = optax.rmsprop(**config['aux_opt_kwargs'])
        else:
            raise KeyError(f"Unsupported aux_opt_type: {config['aux_opt_type']!r}")
        if config['max_aux_grad_norm'] > 0:
            aux_opt = optax.chain(
                optax.clip_by_global_norm(config['max_aux_grad_norm']),
                aux_opt,
            )
        aux_opt_init, aux_opt_update = aux_opt

        a2c_update_fn = gen_a2c_update_fn(
            agent=agent,
            opt_update=a2c_opt_update,
            gamma=config['gamma'],
            vf_coef=config['vf_coef'],
            entropy_reg=config['entropy_reg'],
            use_mask=use_mask,
        )

        aux_update_fn = gen_pc_update_fn(
            agent=agent,
            opt_update=aux_opt_update,
            preprocess=pc_preprocess,
            gamma=config['pc_gamma'],
            use_mask=use_mask,
        )
        self._a2c_update_fn = jax.jit(a2c_update_fn)
        self._aux_update_fn = jax.jit(aux_update_fn)

        self._rngkey, subkey = jrandom.split(self._rngkey)
        self._theta = agent.init(subkey)
        self._a2c_opt_state = a2c_opt_init(self._theta)
        self._aux_opt_state = aux_opt_init(self._theta)

        self._epinfo_buf = collections.deque(maxlen=100)  # last 100 episodes
        self._num_iter = 0
        self._num_frames = 0
        self._tstart = time.time()

    def step(self):
        """Run ``log_interval`` rollout + A2C update + aux update iterations.

        Returns:
            Metrics dict. Episode statistics average the last 100 finished
            episodes; loss/value statistics come from the final iteration.
        """
        t0 = time.time()
        rngkey = self._rngkey
        theta = self._theta
        num_frames_this_iter = 0
        for _ in range(self._config['log_interval']):
            rngkey, trajs, epinfos = self._actor.rollout(rngkey, theta)
            self._epinfo_buf.extend(epinfos)

            trajs = jax.device_put(trajs)
            # Both updates are applied sequentially to the same parameters.
            theta, self._a2c_opt_state, a2c_log = self._a2c_update_fn(
                theta, self._a2c_opt_state, trajs)
            theta, self._aux_opt_state, aux_log = self._aux_update_fn(
                theta, self._aux_opt_state, trajs)

            self._num_iter += 1
            num_frames_this_iter += self._config['nenvs'] * self._nsteps * self._frame_skip
        self._rngkey = rngkey
        self._theta = theta
        self._num_frames += num_frames_this_iter

        a2c_log = jax.device_get(a2c_log)
        aux_log = jax.device_get(aux_log)
        ev = utils.explained_variance(a2c_log.value[:, :-1].flatten(), a2c_log.ret.flatten())
        log = {
            'label': self._config['label'],
            'episode_return': onp.mean([epinfo['r'] for epinfo in self._epinfo_buf]),
            'episode_length': onp.mean([epinfo['l'] for epinfo in self._epinfo_buf]),
            'entropy': a2c_log.entropy.mean(),
            'explained_variance': ev,
            'pg_loss': a2c_log.pg_loss.mean(),
            'baseline_loss': a2c_log.baseline_loss.mean(),
            'value_mean': a2c_log.value.mean(),
            'value_std': a2c_log.value.std(),
            'return_mean': a2c_log.ret.mean(),
            'return_std': a2c_log.ret.std(),
            'state_norm': onp.mean(a2c_log.state_norm),
            'a2c_grad_norm': a2c_log.grad_norm,
            'a2c_update_norm': a2c_log.update_norm,
            'param_norm': a2c_log.theta_norm,
            'aux_loss': aux_log.aux_loss.mean(),
            'aux_grad_norm': aux_log.grad_norm,
            'aux_update_norm': aux_log.update_norm,
            'num_iterations': self._num_iter,
            'num_frames': self._num_frames,
            'fps': num_frames_this_iter / (time.time() - t0),
        }
        return log

    def _save(self, tmp_checkpoint_dir):
        """Pickle ``(theta, a2c_opt_state, aux_opt_state)`` and return its path."""
        theta = jax.device_get(self._theta)
        a2c_opt_state = jax.device_get(self._a2c_opt_state)
        aux_opt_state = jax.device_get(self._aux_opt_state)
        checkpoint_path = os.path.join(tmp_checkpoint_dir, 'model.chk')
        with open(checkpoint_path, 'wb') as checkpoint_file:
            pickle.dump((theta, a2c_opt_state, aux_opt_state), checkpoint_file)
        return checkpoint_path

    def _restore(self, checkpoint):
        """Load a checkpoint written by ``_save``.

        NOTE: pickle can execute arbitrary code; only restore trusted files.
        """
        with open(checkpoint, 'rb') as checkpoint_file:
            # Must mirror the 3-tuple written by `_save`.
            theta, a2c_opt_state, aux_opt_state = pickle.load(checkpoint_file)
        self._theta = theta
        self._a2c_opt_state = a2c_opt_state
        self._aux_opt_state = aux_opt_state


if __name__ == '__main__':
    # Example run: A2C + pixel control on Atari Breakout with a feed-forward
    # agent (use_rnn=False, so Experiment stacks the last 4 frames).
    config = {
        'label': 'a2c-pc',
        'env_id': 'BreakoutNoFrameskip-v4',
        'env_kwargs': {},

        'torso_type': 'atari_shallow',
        'torso_kwargs': {
            'dense_layers': (),
        },
        'use_rnn': False,
        'head_layers': (512,),
        # (channels, height, width, kernel, stride) of the pixel-control deconvolution.
        # 'deconv_layers': (32, 8, 11, 4, 2),  # For DM Lab
        'deconv_layers': (32, 9, 9, 4, 2),  # For Atari
        'stop_ac_grad': True,
        # Read by Experiment.setup: scale the pixel-control head's gradient by
        # the per-environment factor (1 disables scaling).
        'scale_gradient': True,

        'nenvs': 16,
        'nsteps': 20,
        'gamma': 0.99,
        'vf_coef': 0.5,
        'entropy_reg': 0.01,

        'a2c_opt_type': 'rmsprop',
        'a2c_opt_kwargs': {
            'learning_rate': 7E-4,
            'decay': 0.99,
            'eps': 1E-5,
        },
        'max_a2c_grad_norm': 0.5,

        'aux_opt_type': 'adam',
        'aux_opt_kwargs': {
            'learning_rate': 7E-4,
            'b1': 0.,
            'b2': 0.99,
            'eps_root': 1E-5,
        },
        'max_aux_grad_norm': 0.,

        'log_interval': 100,
        'seed': 42,

        'pc_gamma': 0.9,
    }
    analysis = tune.run(
        Experiment,
        name='debug',
        config=config,
        stop={
            'num_frames': 200 * 10 ** 6,
        },
        resources_per_trial={
            'gpu': 1,
        },
    )
