import copy
from typing import Any

import flax
import jax
import jax.numpy as jnp
import ml_collections as mlc
import optax
from utils.encoders import GCEncoder, encoder_modules
from utils.flax_utils import ModuleDict, TrainState, nonpytree_field
from utils.networks import (
    GCActor,
    GCDiscreteActor,
    GCDiscreteCritic,
    GCValue,
    GCBilinearValue,
    ActorVectorField,
)


class COEAgent(flax.struct.PyTreeNode):
    """Goal-conditioned agent with a subgoal generator, BCE critic and DDPG+BC actor.

    All sub-networks (``critic``, ``target_critic``, ``oracle_critic``,
    ``actor`` and ``generator``) live in a single ``ModuleDict`` train state.
    ``critic_loss`` trains the critic with binary cross-entropy on logits. Its
    bootstrap target is the product of two target-critic values routed through
    a generated subgoal. ``actor_loss`` trains the actor with a normalized Q
    term plus an ``alpha``-weighted behavior-cloning term.
    """

    rng: Any  # PRNG key; split once per `update` call.
    network: Any  # TrainState holding the ModuleDict of all sub-networks.
    # Hyperparameters. Not a pytree leaf: jitted methods branch on it in Python.
    config: Any = nonpytree_field()

    @staticmethod
    def bce_loss(pred_logit, target):
        """Compute the element-wise binary cross-entropy between logits and targets.

        Both terms use ``log_sigmoid`` for numerical stability.

        Args:
            pred_logit: Predicted logits.
            target: Soft targets (probabilities), broadcastable to ``pred_logit``.

        Returns:
            The unreduced per-element loss.
        """
        log_pred = jax.nn.log_sigmoid(pred_logit)
        log_not_pred = jax.nn.log_sigmoid(-pred_logit)
        loss = -(log_pred * target + log_not_pred * (1 - target))
        return loss

    def critic_loss(self, batch, grad_params):
        """Compute the critic, subgoal-generator and optional distillation losses.

        Args:
            batch: Dict of batched arrays. Always reads ``observations``,
                ``actions``, ``value_goals`` and ``value_offsets``. When
                ``config["oracle_distill"]`` is set, it also reads
                ``next_observations``, ``value_goal_observations`` and
                ``value_subgoal_observations``. Otherwise it reads
                ``value_next_goals`` and ``value_subgoal_goals``.
            grad_params: Parameters being differentiated.

        Returns:
            ``(total_loss, info)``, where ``info`` maps names to scalar diagnostics.
        """
        oracle = self.config["oracle_distill"]

        # One-step term: Q(s, a, next goal) is regressed towards the discount.
        next_goal_key = "next_observations" if oracle else "value_next_goals"
        next_q_logits = self.network.select("critic")(
            batch["observations"],
            goals=batch[next_goal_key],
            actions=batch["actions"],
            params=grad_params,
        )
        next_q_loss = self.bce_loss(next_q_logits, self.config["discount"]).mean()

        goal_key = "value_goal_observations" if oracle else "value_goals"
        q_logits = self.network.select("critic")(
            batch["observations"],
            goals=batch[goal_key],
            actions=batch["actions"],
            params=grad_params,
        )
        qs = jax.nn.sigmoid(q_logits)

        # Propose a subgoal from (s, a) towards the goal, using the
        # distribution's mode.
        subg_dist = self.network.select("generator")(
            jnp.concatenate([batch["observations"], batch["actions"]], -1),
            batch[goal_key],
            params=grad_params,
        )
        subg = subg_dist.mode()
        # NOTE(review): the actor is evaluated with `grad_params`, so
        # `generator_loss` also back-propagates into the actor's parameters
        # through `subg_action`. Confirm this is intended.
        subg_action = jnp.clip(
            self.network.select("actor")(
                subg, batch["value_goals"], params=grad_params
            ).mode(),
            -1,
            1,
        )

        # Composed target: Q_target(s, a, subg) * Q_target(subg, pi(subg), goal).
        # The target critic is called without `grad_params`, so it receives
        # no gradient. `subg` (and hence the generator) does.
        first_q_logits = self.network.select("target_critic")(
            batch["observations"],
            goals=subg,
            actions=batch["actions"],
        )
        second_q_logits = self.network.select("target_critic")(
            subg,
            goals=batch[goal_key],
            actions=subg_action,
        )
        target = jax.nn.sigmoid(first_q_logits) * jax.nn.sigmoid(second_q_logits)

        # Samples with value_offsets <= 1 are excluded from the generator
        # objective. Presumably the goal is at most one step away for these
        # samples; TODO confirm against the dataset.
        d = batch["value_offsets"] <= 1
        generator_loss = -((1 - d) * target).mean()
        q_loss = self.bce_loss(q_logits, jax.lax.stop_gradient(target)).mean()

        # Likelihood regularizer pulling the generator towards
        # batch[rand_goal_key] (presumably dataset subgoals), weighted by beta.
        rand_goal_key = (
            "value_subgoal_observations" if oracle else "value_subgoal_goals"
        )
        log_prob = subg_dist.log_prob(batch[rand_goal_key])
        reg_loss = -(self.config["beta"] * log_prob).mean()

        total_loss = q_loss + next_q_loss + reg_loss + generator_loss

        info = {
            "q_loss": q_loss,
            "next_q_loss": next_q_loss,
            "reg_loss": reg_loss,
            "generator_loss": generator_loss,
            "q_mean": qs.mean(),
            "q_max": qs.max(),
            "q_min": qs.min(),
        }

        if oracle:
            # Distill the critic's (stop-gradient) values into the oracle
            # critic, which is queried with `value_goals` instead.
            distill_q_logits = self.network.select("oracle_critic")(
                batch["observations"],
                goals=batch["value_goals"],
                actions=batch["actions"],
                params=grad_params,
            )
            distill_loss = self.bce_loss(
                distill_q_logits, jax.lax.stop_gradient(qs)
            ).mean()

            total_loss = total_loss + distill_loss
            info["distill_loss"] = distill_loss

        return total_loss, {"total_loss": total_loss, **info}

    def actor_loss(self, batch, grad_params, rng=None):
        """Compute the DDPG+BC actor loss.

        Args:
            batch: Dict of batched arrays. Reads ``observations``,
                ``actor_goals`` and ``actions``.
            grad_params: Parameters being differentiated.
            rng: PRNG key for sampling actions. It is only used when
                ``config["const_std"]`` is False.

        Returns:
            ``(actor_loss, info)``, where ``info`` maps names to scalar diagnostics.
        """

        # DDPG+BC loss.
        dist = self.network.select("actor")(
            batch["observations"], batch["actor_goals"], params=grad_params
        )
        if self.config["const_std"]:
            q_actions = jnp.clip(dist.mode(), -1, 1)
        else:
            q_actions = jnp.clip(dist.sample(seed=rng), -1, 1)
        critic_module = "oracle_critic" if self.config["oracle_distill"] else "critic"
        # NOTE(review): the critic is trained with BCE on logits (see
        # critic_loss), so q1/q2 here are raw logits rather than sigmoid
        # probabilities. Confirm this is the intended scale.
        q1, q2 = self.network.select(critic_module)(
            batch["observations"], batch["actor_goals"], q_actions
        )
        q = jnp.minimum(q1, q2)

        # Normalize Q values by the absolute mean to make the loss scale invariant.
        q_loss = -q.mean() / jax.lax.stop_gradient(jnp.abs(q).mean() + 1e-6)
        log_prob = dist.log_prob(batch["actions"])

        bc_loss = -(self.config["alpha"] * log_prob).mean()

        actor_loss = q_loss + bc_loss

        return actor_loss, {
            "actor_loss": actor_loss,
            "q_loss": q_loss,
            "bc_loss": bc_loss,
            "q_mean": q.mean(),
            "q_abs_mean": jnp.abs(q).mean(),
            "bc_log_prob": log_prob.mean(),
            "mse": jnp.mean((dist.mode() - batch["actions"]) ** 2),
            "std": jnp.mean(dist.scale_diag),
        }

    @jax.jit
    def total_loss(self, batch, grad_params, rng=None):
        """Return the sum of critic and actor losses with prefixed diagnostics.

        Falls back to ``self.rng`` when ``rng`` is None.
        """
        info = {}
        rng = rng if rng is not None else self.rng

        critic_loss, critic_info = self.critic_loss(batch, grad_params)
        for k, v in critic_info.items():
            info[f"critic/{k}"] = v

        rng, actor_rng = jax.random.split(rng)
        actor_loss, actor_info = self.actor_loss(batch, grad_params, actor_rng)
        for k, v in actor_info.items():
            info[f"actor/{k}"] = v

        loss = critic_loss + actor_loss
        return loss, info

    def target_update(self, network, module_name):
        """Polyak-average ``modules_<name>`` into ``modules_target_<name>``.

        Source parameters are read from ``self.network``, i.e. from before
        this step's gradient update. The result is written in place into
        ``network.params``.
        """
        new_target_params = jax.tree_util.tree_map(
            lambda p, tp: p * self.config["tau"] + tp * (1 - self.config["tau"]),
            self.network.params[f"modules_{module_name}"],
            self.network.params[f"modules_target_{module_name}"],
        )
        network.params[f"modules_target_{module_name}"] = new_target_params

    @jax.jit
    def update(self, batch):
        """Take one gradient step and update the target critic.

        Returns:
            ``(new_agent, info)`` with a fresh PRNG key in ``new_agent``.
        """
        new_rng, rng = jax.random.split(self.rng)

        def loss_fn(grad_params):
            return self.total_loss(batch, grad_params, rng=rng)

        new_network, info = self.network.apply_loss_fn(loss_fn=loss_fn)
        self.target_update(new_network, "critic")

        return self.replace(network=new_network, rng=new_rng), info

    @jax.jit
    def sample_actions(
        self,
        observations,
        goals=None,
        seed=None,
        temperature=1.0,
    ):
        """Sample actions from the actor and clip them to [-1, 1]."""

        dist = self.network.select("actor")(
            observations, goals, temperature=temperature
        )
        actions = dist.sample(seed=seed)

        actions = jnp.clip(actions, -1, 1)

        return actions

    @classmethod
    def create(
        cls,
        seed,
        example_batch,
        config,
    ):
        """Create a new agent.

        Args:
            seed: Random seed.
            example_batch: Example batch used for shape inference. It must
                contain ``observations``, ``actions`` and ``actor_goals``.
            config: Configuration dictionary. ``config["action_dim"]`` is
                written in place.

        Returns:
            A freshly initialized ``COEAgent``.

        Raises:
            NotImplementedError: If ``config["encoder"]`` is not None.
        """
        rng = jax.random.PRNGKey(seed)
        rng, init_rng = jax.random.split(rng, 2)

        ex_observations = example_batch["observations"]
        ex_actions = example_batch["actions"]
        ex_goals = example_batch["actor_goals"]
        action_dim = ex_actions.shape[-1]

        # Visual encoders are not supported yet, so every network below
        # receives `gc_encoder=None`.
        if config["encoder"] is not None:
            raise NotImplementedError(
                f"COEAgent does not support encoders yet (got encoder={config['encoder']!r})."
            )
        encoders = dict()

        critic_def = GCValue(
            hidden_dims=config["value_hidden_dims"],
            layer_norm=config["layer_norm"],
            num_ensembles=2,
            gc_encoder=encoders.get("critic"),
        )
        oracle_critic_def = GCValue(
            hidden_dims=config["value_hidden_dims"],
            layer_norm=config["layer_norm"],
            num_ensembles=2,
            gc_encoder=encoders.get("critic"),
        )

        actor_def = GCActor(
            hidden_dims=config["actor_hidden_dims"],
            action_dim=action_dim,
            layer_norm=config["layer_norm"],
            state_dependent_std=False,
            const_std=config["const_std"],
            gc_encoder=encoders.get("actor"),
        )

        # The generator outputs subgoals in observation space.
        generator_def = GCActor(
            hidden_dims=config["generator_hidden_dims"],
            action_dim=ex_observations.shape[-1],
            layer_norm=config["layer_norm"],
            state_dependent_std=False,
            const_std=True,
            gc_encoder=None,
        )

        # In oracle-distill mode the critic and generator consume
        # observation-space goals; otherwise they consume the actor goals.
        ex_critic_goals = ex_observations if config["oracle_distill"] else ex_goals

        network_info = dict(
            critic=(critic_def, (ex_observations, ex_critic_goals, ex_actions)),
            target_critic=(
                copy.deepcopy(critic_def),
                (ex_observations, ex_critic_goals, ex_actions),
            ),
            oracle_critic=(oracle_critic_def, (ex_observations, ex_goals, ex_actions)),
            actor=(
                actor_def,
                (
                    ex_observations,
                    ex_goals,
                ),
            ),
            generator=(
                generator_def,
                (jnp.concatenate([ex_observations, ex_actions], -1), ex_critic_goals),
            ),
        )
        networks = {k: v[0] for k, v in network_info.items()}
        network_args = {k: v[1] for k, v in network_info.items()}

        network_def = ModuleDict(networks)
        network_tx = optax.adam(learning_rate=config["lr"])
        network_params = network_def.init(init_rng, **network_args)["params"]
        # Start the target critic as an exact copy of the online critic. This
        # is done before building the TrainState so the copy does not rely on
        # the TrainState aliasing this dict.
        network_params = {
            **network_params,
            "modules_target_critic": network_params["modules_critic"],
        }

        network = TrainState.create(network_def, network_params, tx=network_tx)

        config["action_dim"] = action_dim
        return cls(rng, network=network, config=flax.core.FrozenDict(**config))


def get_config():
    """Build the default hyperparameter set for the COE agent.

    Returns:
        An ``ml_collections.ConfigDict`` containing agent and dataset settings.
    """
    # Every MLP (actor, generator, value) uses four hidden layers of width 1024.
    mlp_dims = (1024,) * 4

    settings = {
        # Agent hyperparameters.
        "agent_name": "coe",  # Agent name.
        "lr": 3e-4,  # Learning rate.
        "batch_size": 1024,  # Batch size.
        "actor_hidden_dims": mlp_dims,  # Actor network hidden dimensions.
        "generator_hidden_dims": mlp_dims,  # Generator network hidden dimensions.
        "value_hidden_dims": mlp_dims,  # Value network hidden dimensions.
        "layer_norm": True,  # Whether to use layer normalization.
        "discount": 0.999,  # Discount factor.
        "tau": 0.005,  # Target network update rate.
        # Action dimension (set automatically).
        "action_dim": mlc.config_dict.placeholder(int),
        "alpha": 0.0,  # Weight of the behavior-cloning log-likelihood in the actor loss.
        "const_std": True,  # Constant actor std; the actor loss then evaluates Q at the mode.
        "beta": 0.0,  # Weight of the generator log-likelihood regularizer.
        # Whether to use oracle distillation (only used when using `oraclerep` environments).
        "oracle_distill": False,
        # Visual encoder name (None, 'impala_small', etc.).
        "encoder": mlc.config_dict.placeholder(str),
        # Dataset hyperparameters.
        "dataset_class": "GCDataset",  # Dataset class name.
        "value_p_curgoal": 0.0,  # Probability of using the current state as the value goal.
        "value_p_trajgoal": 1.0,  # Probability of using a future state in the same trajectory as the value goal.
        "value_p_randomgoal": 0.0,  # Probability of using a random state as the value goal.
        "value_geom_sample": True,  # Whether to use geometric sampling for future value goals.
        "actor_p_curgoal": 0.0,  # Probability of using the current state as the actor goal.
        "actor_p_trajgoal": 0.5,  # Probability of using a future state in the same trajectory as the actor goal.
        "actor_p_randomgoal": 0.5,  # Probability of using a random state as the actor goal.
        "actor_geom_sample": True,  # Whether to use geometric sampling for future actor goals.
        "gc_negative": False,  # Whether to use '0 if s == g else -1' (True) or '1 if s == g else 0' (False) as reward.
        "p_aug": 0.0,  # Probability of applying image augmentation.
        "frame_stack": mlc.config_dict.placeholder(int),  # Number of frames to stack.
    }
    return mlc.ConfigDict(settings)
