import jax
import jax.numpy as jnp
import numpy as np
from gym import spaces

from . import agent, jaxutils
from . import ninjax as nj

tree_map = jax.tree_util.tree_map


def sg(x):
    """Return ``x`` with ``jax.lax.stop_gradient`` applied to every pytree leaf."""
    return tree_map(jax.lax.stop_gradient, x)


class Hansome(nj.Module):
    """Actor-critic with a single "manager" critic and a periodically resampled command.

    On every policy call the actor's output distribution is given a prior
    sample (``command``). That command is resampled via the distribution's
    ``get_aco_prior_sample`` every ``m_horizon`` steps and held fixed in
    between. Training optimizes the actor on imagined rollouts and then trains
    each critic on the same rollouts.
    """

    # Assumed action space is 15 + n_commands

    def __init__(self, wm, act_space, config):
        # Reward for the manager critic: mean of the world model's reward head
        # with the first entry along the leading axis dropped (presumably the
        # time axis of an imagined trajectory -- TODO confirm).
        self.manager_reward_func = lambda s: wm.heads["reward"](s).mean()[1:]
        if config.critic_type == "vfunction":
            critics = {"manager": agent.VFunction(self.manager_reward_func, config, name="critic")}
        else:
            raise NotImplementedError(config.critic_type)
        self.ac = agent.ImagActorCritic(critics, {"manager": 1.0}, act_space, config, name="ac")
        self.imagine = wm.imagine_carry
        self.config = config

    def initial(self, batch_size):
        """Return the initial policy carry for a batch of ``batch_size``.

        ``step`` counts policy calls made with this carry. ``command`` holds the
        active command; it starts at zeros and is replaced on the first policy
        call, because ``0 % m_horizon == 0``.
        """
        return {
            # int32, not int8: ``policy`` increments this on every call and the
            # dtype is preserved, so an int8 counter wraps after 127 steps. That
            # corrupts the ``step % m_horizon`` schedule for any horizon that
            # does not divide 256.
            "step": np.zeros((batch_size,), np.int32),
            "command": jnp.zeros((batch_size, self.config.n_commands), jnp.float32),
        }

    def set_m_parameters(self, **kwargs):
        """Store manager hyper-parameters as module state.

        Each ``key=value`` is saved under ``m_<key>`` as a 1-element array.
        ``policy`` reads ``m_unimix``, ``m_horizon`` and ``m_ent_scale``.
        """
        for key, value in kwargs.items():
            self.put(f"m_{key}", jnp.array([value]))

    def policy(self, latent, carry):
        """Run the actor for one step and update the command carry.

        Args:
          latent: input passed to ``self.ac.actor``.
          carry: dict with ``step`` (per-row counter) and ``command`` (per-row
            active command), as produced by ``initial`` or a previous call.

        Returns:
          ``(outs, carry)``. ``outs["action"]`` is the actor's output
          distribution with its prior sample set to the (possibly resampled)
          command. ``carry`` has ``step + 1`` and the new command.
        """
        m_unimix = self.get("m_unimix", jnp.array, [1.0])
        m_horizon = self.get("m_horizon", jnp.array, [128])
        m_ent_scale = self.get("m_ent_scale", jnp.array, [1.0])
        # Per-row flag: resample the command on rows whose step is a multiple
        # of the horizon.
        update = (carry["step"] % jax.lax.stop_gradient(m_horizon[0])) == 0
        act = self.ac.actor(latent)

        def switch(keep, new):
            # Row-wise select: ``new`` where ``update`` is set, else ``keep``.
            # jnp.where (instead of the 0/1-weighted sum (1-u)*keep + u*new)
            # stops NaN/Inf in the unselected branch from leaking through
            # 0 * inf; values and gradients are unchanged for finite inputs.
            mask = jnp.reshape(update, jnp.shape(update) + (1,) * (jnp.ndim(keep) - 1))
            return jnp.where(mask, new, keep)

        prior_sample = act.get_aco_prior_sample(
            jax.lax.stop_gradient(m_unimix[0]),
            jax.lax.stop_gradient(m_ent_scale[0]),
        )
        command = switch(carry["command"], prior_sample)
        act.set_prior_sample(command)
        outs = {"action": act}
        carry = {"step": carry["step"] + 1, "command": command}
        return outs, carry

    def train(self, imagine, start, data):
        """Train the actor on imagined rollouts from ``start``, then every critic.

        NOTE: the ``imagine`` and ``data`` arguments are unused. Rollouts come
        from ``self.imagine`` (the world model's ``imagine_carry``). They are
        kept for interface compatibility.

        Returns:
          ``(traj, metrics)``: the imagined trajectory and a flat metrics dict.
          Critic metrics are prefixed ``<critic key>_critic_``.
        """

        def loss(start):
            def policy_carry(latent, carry):
                # Gradients are stopped at the policy's latent input.
                outs, new_carry = self.policy(sg(latent), carry)
                return outs["action"].sample(seed=nj.rng()), new_carry

            carry = self.initial(len(start["is_first"]))
            traj = self.imagine(policy_carry, start, self.config.imag_horizon, carry)
            actor_loss, metrics = self.ac.loss(traj)
            return actor_loss, (traj, metrics)

        mets, (traj, metrics) = self.ac.opt(self.ac.actor, loss, start, has_aux=True)
        metrics.update(mets)
        for key, critic in self.ac.critics.items():
            mets = critic.train(traj, self.ac.actor)
            metrics.update({f"{key}_critic_{k}": v for k, v in mets.items()})
        return traj, metrics

    def report(self, data):
        """Return no report metrics; this module produces none."""
        return {}
