import copy
import time
from typing import Any, Callable, Dict, Tuple

import chex
import flax
import hydra
import jax
import jax.numpy as jnp
import numpy as np
import tensorflow_probability.substrates.jax.distributions as tfd
from chex import PRNGKey
from colorama import Fore, Style
from flax.core.frozen_dict import FrozenDict
from flax.core.frozen_dict import FrozenDict as Params
from jax import tree
from jumanji.types import TimeStep
from omegaconf import DictConfig, OmegaConf
from rich.pretty import pprint

from mava.evaluator import ActorState, COMPASSEvalActFn, _EvalEnvStepState
from mava.networks import RecurrentCOMPASSActor as Actor
from mava.networks import ScannedRNN
from mava.networks.distributions import IdentityTransformation
from mava.networks.heads import DiscreteLogitHead
from mava.systems.ppo.types import (
    HiddenStates,
)
from mava.types import Action, CompassRecActorApply, MarlEnv, Metrics
from mava.utils import make_env as environments
from mava.utils.checkpointing import Checkpointer
from mava.utils.cmaes_utils import (
    CMAPoolEmitter,
)
from mava.utils.logger import LogEvent, MavaLogger
from mava.utils.model_downloads import delete_local_checkpoints, unzip_local_checkpoints
from mava.utils.network_utils import _DISCRETE, get_action_head


class InferenceTimeLogger(MavaLogger):
    """Logger variant used by the inference-time search loop.

    Each metric leaf is collapsed with ``np.mean`` before being passed to the wrapped logger.
    Win-rate computation is not provided by this class.
    """

    def log(self, metrics: Dict, t: int, t_eval: int, event: LogEvent) -> None:
        """Average every leaf of `metrics` and forward the result to `self.logger.log_dict`.

        Args:
        ----
            metrics: pytree (dictionary) of metric values.
            t: step value forwarded unchanged to the underlying logger.
            t_eval: evaluation step value forwarded unchanged to the underlying logger.
            event: the kind of event being logged.
        """
        averaged = jax.tree.map(np.mean, metrics)
        self.logger.log_dict(averaged, t, t_eval, event)

    def calc_winrate(self, _episode_metrics: Dict, _event: LogEvent) -> Dict:
        """Unsupported for this logger: always raises `NotImplementedError`."""
        raise NotImplementedError


def make_ppo_execution_fn(
    env: MarlEnv, keys: chex.Array, config: DictConfig
) -> Tuple[Actor, Params]:
    """Build the recurrent COMPASS actor network and obtain its parameters.

    The network is initialised on dummy inputs and, when checkpoint loading is enabled in the
    config, the freshly initialised parameters are replaced by the restored ones.

    Args:
    ----
        env: environment used to read the number of agents and the observation/action specs.
        keys: a pair of PRNG keys; only the second one is used (for network initialisation).
        config: experiment configuration. `config.system.num_agents` is overwritten with the
            number of agents reported by `env`.

    Returns:
    -------
        A tuple `(actor_network, actor_params)`.

    Raises:
    ------
        NotImplementedError: if the environment's action space is not discrete.
    """
    n_agents = env.num_agents
    # Store the agent count on the config so it can be read back from there.
    config.system.num_agents = n_agents

    # The first key of the pair is not needed in this function.
    _, actor_net_key = keys

    # Assemble the actor from its configured torsos and a discrete logit head.
    pre_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso)
    post_torso = hydra.utils.instantiate(config.network.actor_network.post_torso)

    _, action_space_type = get_action_head(env.action_spec())
    if action_space_type != _DISCRETE:
        raise NotImplementedError("COMPASS REINFORCE systems only support discrete action spaces")

    actor_network = Actor(
        pre_torso=pre_torso,
        post_torso=post_torso,
        hidden_state_dim=config.network.hidden_state_dim,
        action_head=DiscreteLogitHead(env.action_dim),
    )

    n_envs = config.arch.num_envs

    def _with_time_and_batch_axes(leaf: chex.Array) -> chex.Array:
        """Tile a single-env leaf over the batch axis, then prepend a time axis of length 1."""
        batched = jnp.repeat(leaf[jnp.newaxis, ...], n_envs, axis=0)
        return batched[jnp.newaxis, ...]

    # Dummy network inputs, laid out as (time, batch, agents, ...).
    dummy_obs = tree.map(_with_time_and_batch_axes, env.observation_spec().generate_value())
    dummy_done = jnp.zeros((1, n_envs, n_agents), dtype=bool)
    dummy_inputs = (dummy_obs, dummy_done)

    # The latent carries no time axis because the same latent is fed at every timestep of a
    # trajectory: (batch, agents, latent_dim).
    dummy_latent = jnp.ones(
        (config.arch.num_envs, config.system.num_agents, config.arch.compass_latent_dim)
    )

    # Initial recurrent carry: (batch, agents, ...).
    dummy_hstate = ScannedRNN.initialize_carry(
        (n_envs, n_agents), config.network.hidden_state_dim
    )

    actor_params = actor_network.init(actor_net_key, dummy_hstate, dummy_inputs, dummy_latent)

    checkpoint_cfg = config.logger.checkpointing

    # Optionally unpack a zipped local checkpoint before attempting to load it.
    if checkpoint_cfg.unzip_local_model:
        unzip_local_checkpoints(
            checkpoint_rel_dir="checkpoints",
            model_name=config.logger.system_name,
            run_id=checkpoint_cfg.download_args.neptune_run_name,
        )

    # Optionally swap the fresh parameters for the ones stored in a checkpoint.
    if checkpoint_cfg.load_model:
        checkpointer = Checkpointer(
            model_name=config.logger.system_name,
            **checkpoint_cfg.load_args,  # Other checkpoint args
        )
        actor_params, _ = checkpointer.restore_params(
            input_params=actor_params, restore_hstates=False, THiddenState=HiddenStates
        )

    return actor_network, actor_params


def get_eval_fn(
    env: MarlEnv,
    act_fn: Any,
    n_envs: int = 1,
    log_win_rate: bool = False,
) -> Callable:
    """Creates a function that can be used to evaluate agents on a given environment.

    Args:
    ----
        env: an environment that conforms to the mava environment spec.
        act_fn: a function called as `act_fn(params, timestep, key, actor_state, latents)` that
                returns a tuple `(action, actor_state)`.
        n_envs: the number of environments to run in parallel.
        log_win_rate: whether to log the win rate of the agents.

    Returns:
    -------
        `eval_fn(params, env_key, acting_key, init_act_state, latents)`, which rolls out
        `n_envs` environments in parallel and returns their episode metrics.
    """

    def eval_fn(
        params: FrozenDict,
        env_key: PRNGKey,
        acting_key: PRNGKey,
        init_act_state: ActorState,
        latents: chex.Array,
    ) -> Metrics:
        """Evaluates the given params on an environment and returns relevant
        metrics.

        Args:
        ----
            params: network parameters handed to `act_fn`.
            env_key: key that is split into `n_envs` keys used to reset the environments.
            acting_key: key threaded through the rollout and split once per step for `act_fn`.
            init_act_state: actor state passed to `act_fn` on the first step.
            latents: latent vectors passed unchanged to `act_fn` at every step.

        Returns: Dict[str, Array] - dictionary of metric name to metric values
        for each episode.
        """

        def _env_step(eval_state: _EvalEnvStepState, _: Any) -> Tuple[_EvalEnvStepState, TimeStep]:
            """Performs a single environment step"""
            env_state, ts, key, actor_state = eval_state

            # Fresh acting key for this step; the remainder is carried to the next step.
            key, act_key = jax.random.split(key)
            action, actor_state = act_fn(params, ts, act_key, actor_state, latents)

            # Step all parallel environments at once.
            env_state, ts = jax.vmap(env.step)(env_state, action)

            return (env_state, ts, key, actor_state), ts

        def _episode(env_key: PRNGKey) -> Metrics:
            """Simulates `n_envs` episodes."""
            # One reset key per parallel environment.
            reset_keys = jax.random.split(env_key, n_envs)

            # get initial env states
            env_state, ts = jax.vmap(env.reset)(reset_keys)

            # Note: the key of the step state must be handled outside
            step_state = env_state, ts, acting_key, init_act_state
            # Always scan for `env.time_limit + 1` steps; the relevant step is selected below.
            _, timesteps = jax.lax.scan(_env_step, step_state, jnp.arange(env.time_limit + 1))

            metrics = timesteps.extras["episode_metrics"]
            if log_win_rate:
                metrics["won_episode"] = timesteps.extras["won_episode"]

            # find the first instance of done to get the metrics at that timestep
            # we don't care about subsequent steps
            # (argmax over the time axis returns the first index at which `last()` is True)
            done_idx = jnp.argmax(timesteps.last(), axis=0)
            metrics = jax.tree.map(
                lambda m: m[done_idx, jnp.arange(n_envs)],
                metrics,
            )
            del metrics["is_terminal_step"]  # not needed for logging

            return metrics

        return _episode(env_key)

    return eval_fn


def run_experiment(_config: DictConfig) -> float:
    """Run the inference-time latent search and return the best performance found.

    For a fixed set of evaluation environments, batches of `n_attempts` latent vectors per
    environment are evaluated repeatedly. Latents come either from a per-environment CMA-ES
    emitter (`"cmaes"`) or from uniform sampling (`"uniform_sampling"`). The best episode
    return seen so far is tracked per environment.

    Args:
    ----
        _config: experiment configuration. Its `logger.system_name` is overwritten in place;
            all other changes are applied to a deep copy.

    Returns:
    -------
        The mean over environments of the best episode return found for each environment.

    Raises:
    ------
        NotImplementedError: if `config.system.recurrent_chunk_size` is set.
        ValueError: if `config.inference.n_envs` is not divisible by the number of devices, or
            if the configured search strategy is neither `"cmaes"` nor `"uniform_sampling"`.
    """

    exp_start_time = time.time()

    # NOTE: the system name is written onto the caller's config before the copy is taken.
    _config.logger.system_name = "rec_ireinforce_compass_cmaes"
    config = copy.deepcopy(_config)

    if config.system.recurrent_chunk_size is not None:
        raise NotImplementedError("Compass systems do not support recurrent rollout chunking.")

    # Chunking is rejected above, so a chunk always spans the whole rollout.
    config.system.recurrent_chunk_size = config.system.rollout_length

    # Create the environments; only the evaluation environment is used here.
    _, eval_env = environments.make(config, fixed_reset=True)

    # PRNG keys.
    key, eval_key, actor_net_key, emitter_key = jax.random.split(
        jax.random.PRNGKey(config.system.seed), num=4
    )

    # Build the actor network and get its (possibly restored) parameters.
    actor_network, params = make_ppo_execution_fn(eval_env, (key, actor_net_key), config)

    # duplicate params across devices
    params = flax.jax_utils.replicate(params, devices=jax.devices())

    def make_rec_eval_act_fn(
        actor_apply_fn: CompassRecActorApply,
        config: DictConfig,
    ) -> COMPASSEvalActFn:
        """Makes an act function that conforms to the evaluator API given a standard
        recurrent mava actor network."""

        _hidden_state = "hidden_state"

        def eval_act_fn(
            params: FrozenDict,
            timestep: TimeStep,
            key: chex.PRNGKey,
            actor_state: ActorState,
            latent: chex.Array,
        ) -> Tuple[Action, Dict]:
            """Select actions for one step and return them with the updated hidden state."""
            hidden_state = actor_state[_hidden_state]

            n_agents = timestep.observation.agents_view.shape[1]
            # Broadcast the per-env done flag to every agent.
            last_done = timestep.last()[..., jnp.newaxis].repeat(n_agents, axis=-1)
            ac_in = (timestep.observation, last_done)
            # Add the leading time axis; the network is initialised on
            # (time, batch, agents, ...) inputs.
            ac_in = tree.map(lambda x: x[jnp.newaxis], ac_in)

            hidden_state, actor_logits = actor_apply_fn(params, hidden_state, ac_in, latent)
            pi = IdentityTransformation(distribution=tfd.Categorical(logits=actor_logits))
            action = pi.mode() if config.arch.evaluation_greedy else pi.sample(seed=key)
            # Drop the leading time axis again before handing actions to the environment.
            return action.squeeze(0), {_hidden_state: hidden_state}

        return eval_act_fn

    eval_act_fn = make_rec_eval_act_fn(actor_network.apply, config)

    # get the main "inference time" parameters here
    n_devices = jax.device_count()
    n_envs = config.inference.n_envs  # typically 32

    budget = config.inference.budget  # typically 6400
    n_attempts = config.inference.n_attempts

    # Number of sequential search iterations; each one evaluates `n_attempts` per env.
    sequential_budget = budget // n_attempts

    if n_envs % n_devices != 0:
        raise ValueError(f"Num eval episodes must be divisible by num devices ({n_devices})")

    n_envs_per_device = n_envs // n_devices

    # define the evaluator
    evaluator = get_eval_fn(
        eval_env,
        eval_act_fn,
        n_envs_per_device,
        config.env.log_win_rate,
    )

    # inputs are params, env_key, acting_key, init_act_state, latents
    # with dims going as deep as (n_devices, n_attempts, ...)
    # deepest being acting_key
    # evaluator is already vmapped over parallel envs
    # hence we only vmap over the attempts
    # latent vectors: (n_devices, n_attempts, n_envs_per_device, n_agents, latent_dim)
    evaluator = jax.vmap(evaluator, in_axes=(None, None, 0, None, 0))
    evaluator = jax.pmap(evaluator)

    # Logger setup
    logger = InferenceTimeLogger(config)
    cfg: Dict = OmegaConf.to_container(config, resolve=True)
    cfg["arch"]["devices"] = jax.devices()
    pprint(cfg)

    # Create an initial hidden state used for resetting memory for evaluation
    eval_hs = ScannedRNN.initialize_carry(
        (n_envs_per_device, config.system.num_agents),
        config.network.hidden_state_dim,
    )
    eval_hs = flax.jax_utils.replicate(eval_hs, devices=jax.devices())

    # initialise accumulated metrics
    max_episode_return = -jnp.inf * jnp.ones(n_envs)
    max_episode_win = jnp.zeros(n_envs)

    # create the env keys (fixed envs for the entire evaluation)
    env_key = jax.random.PRNGKey(config.inference.env_seed)
    env_keys = jax.random.split(env_key, n_devices).reshape(n_devices, -1)

    strategy_name = config.inference.search_strategy.name

    emitter = None
    if strategy_name == "cmaes":
        e_kwargs = config.inference.search_strategy.kwargs
        # Half of each population is used to update the CMA-ES state.
        e_kwargs.num_best = n_attempts // 2
        emitter = CMAPoolEmitter(
            num_states=e_kwargs.num_states,
            population_size=n_attempts,
            num_best=e_kwargs.num_best,
            search_dim=config.arch.compass_latent_dim,
            init_sigma=e_kwargs.init_sigma,
            delay_eigen_decomposition=e_kwargs.delay_eigen_decomposition,
            init_minval=-config.arch.latent_amplifier * jnp.ones((config.arch.compass_latent_dim,)),
            init_maxval=config.arch.latent_amplifier * jnp.ones((config.arch.compass_latent_dim,)),
            random_key=emitter_key,
        )

        # One independent CMA-ES state per environment (leading axis of size n_envs).
        cmaes_state = jax.tree.map(
            lambda x: jnp.repeat(jnp.expand_dims(x, axis=0), repeats=n_envs, axis=0),
            emitter.init(),
        )

    for eval_step in range(sequential_budget):
        # Stop early once the optional wall-clock budget is exhausted.
        total_exp_time = time.time() - exp_start_time
        break_condition = config.inference.time_constraint is not None and (
            total_exp_time > config.inference.time_constraint
        )
        if break_condition:
            print(
                f"""Breaking due to time constraint:
                {total_exp_time} > {config.inference.time_constraint}"""
            )
            break

        # prepare acting_keys: (n_devices, n_attempts, ...)
        eval_key, acting_key, latent_key = jax.random.split(eval_key, num=3)
        acting_keys = jax.random.split(acting_key, n_devices * n_attempts).reshape(
            n_devices, n_attempts, -1
        )

        if strategy_name == "cmaes":
            subkeys = jax.random.split(latent_key, num=n_envs)
            latents, _ = jax.vmap(emitter.sample)(cmaes_state, subkeys)  # type: ignore

            # switch attempts & envs dimensions -> (n_attempts, n_envs, latent_dim)
            latents = latents.transpose((1, 0, 2))

        elif strategy_name == "uniform_sampling":
            # prepare the latent vectors - for now, uniform sampling
            latents = jax.random.uniform(
                latent_key,
                (
                    n_attempts,
                    n_envs,
                    config.arch.compass_latent_dim,
                ),
                minval=config.inference.search_strategy.kwargs.minval,
                maxval=config.inference.search_strategy.kwargs.maxval,
            )
        else:
            raise ValueError("Unknown search strategy")

        # Split the env axis over the devices, then move the device axis to the front:
        # (n_attempts, n_envs, d) -> (n_attempts, n_devices, n_envs_per_device, d)
        #                         -> (n_devices, n_attempts, n_envs_per_device, d).
        # A direct reshape to the final shape would mix attempts and envs whenever
        # n_devices > 1, so global env `dev * n_envs_per_device + e` would no longer keep
        # its own latents.
        latents = latents.reshape(
            n_attempts, n_devices, n_envs_per_device, config.arch.compass_latent_dim
        )
        latents = jnp.swapaxes(latents, 0, 1)

        # repeat the latents for agents (TODO: hide it in the network)
        latents = jnp.repeat(jnp.expand_dims(latents, axis=-2), config.system.num_agents, axis=-2)

        start_time = time.time()

        # evaluate & get metrics
        eval_metrics = evaluator(params, env_keys, acting_keys, {"hidden_state": eval_hs}, latents)

        eval_metrics = jax.block_until_ready(eval_metrics)

        end_time = time.time()

        # remove the device dimension for all metrics
        # i.e. merge dim_devices and dim_envs_per_device into one env axis:
        # (n_devices, n_attempts, n_envs_per_device, ...) -> (n_attempts, n_envs, ...).
        # The axes are swapped first so that the merge is the exact inverse of the split above.
        eval_metrics = jax.tree.map(
            lambda x: jnp.swapaxes(x, 0, 1).reshape(n_attempts, -1, *x.shape[3:]),
            eval_metrics,
        )

        # update of the emitter
        if strategy_name == "cmaes":
            # remove the agents dimension
            latents = latents[:, :, :, 0, :]

            # remove the device dimension again for the latents (inverse of the split above)
            latents = jnp.swapaxes(latents, 0, 1).reshape(n_attempts, -1, *latents.shape[3:])

            # re-switch attempts and envs dimensions - (n_envs, n_attempts, #dim)
            latents = latents.transpose((1, 0, 2))

            # Calculate fitness for CMA-ES
            fitness = eval_metrics["episode_return"]

            # also need to reshape the fitness (n_envs, n_attempts)
            fitness = fitness.transpose((1, 0))

            # Sort latents based on fitness (highest return first)
            sorted_indices = jnp.argsort(-fitness, axis=1)
            sorted_latents = jnp.take_along_axis(latents, sorted_indices[:, :, jnp.newaxis], axis=1)

            # Update the CMA-ES emitter state with the best candidates of every environment
            cmaes_state = jax.vmap(
                emitter.update_state  # type: ignore
            )(  # type: ignore
                cmaes_state,
                sorted_candidates=sorted_latents[:, : e_kwargs.num_best, :],
            )

        # get max per episode over the last batch of attempts
        max_latest_batch = np.max(eval_metrics["episode_return"], axis=0)

        # keep track of the best ever per environment
        max_episode_return = np.maximum(max_episode_return, max_latest_batch)

        # metrics post-processing and logging
        total_timesteps = jnp.sum(eval_metrics["episode_length"])
        eval_metrics["steps_per_second"] = total_timesteps / (end_time - start_time)

        # now get min, mean, max, std over the attempts
        eval_metrics["return(mean)"] = np.mean(eval_metrics["episode_return"], axis=0)
        eval_metrics["return(min)"] = np.min(eval_metrics["episode_return"], axis=0)
        eval_metrics["return(max)"] = np.max(eval_metrics["episode_return"], axis=0)
        eval_metrics["return(std)"] = np.std(eval_metrics["episode_return"], axis=0)

        eval_metrics["return(accumulated_max)"] = max_episode_return

        # was the episode won over the last batch
        if "won_episode" in eval_metrics:
            latest_batch_win = np.max(eval_metrics["won_episode"], axis=0)
            max_episode_win = np.maximum(max_episode_win, latest_batch_win)
            eval_metrics["win_rate(accumulated)"] = (jnp.sum(max_episode_win) / n_envs) * 100

            # get the win rate
            n_won_episodes: int = np.sum(eval_metrics["won_episode"])
            n_episodes: int = np.size(eval_metrics["won_episode"])
            win_rate = (n_won_episodes / n_episodes) * 100

            eval_metrics["win_rate"] = win_rate
            eval_metrics.pop("won_episode")

        # log
        logger.log(eval_metrics, (eval_step + 1) * n_attempts, eval_step, LogEvent.EVAL)

    # record the performance for the final evaluation run
    eval_performance = float(np.mean(max_episode_return))

    # stop logging
    logger.stop()

    # Remove the local model checkpoint when the run is complete
    if config.logger.checkpointing.delete_local_checkpoints:
        delete_local_checkpoints(checkpoint_folder_dir="checkpoints")

    return eval_performance


@hydra.main(
    config_path="../../../configs/default",
    config_name="rec_reinforce_compass_eval_cmaes.yaml",
    version_base="1.2",
)
def hydra_entry_point(cfg: DictConfig) -> float:
    """Experiment entry point.

    Args:
    ----
        cfg: the configuration composed by Hydra.

    Returns:
    -------
        The performance value returned by `run_experiment`.
    """
    # Allow dynamic attributes.
    OmegaConf.set_struct(cfg, False)

    # Run experiment.
    eval_performance = run_experiment(cfg)
    # The trailing space keeps the two adjacent literals from fusing into "Searchexperiment".
    print(
        f"{Fore.CYAN}{Style.BRIGHT}Recurrent IREINFORCE COMPASS CMAES Search "
        f"experiment completed{Style.RESET_ALL}"
    )
    return eval_performance


# Only start the experiment when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    hydra_entry_point()
