import functools
from typing import Any, Dict, Mapping, NamedTuple, Optional, Text, Tuple, Union

import jsonlines
from absl import logging

import numpy as np

import jax
import rlax
from jax import numpy as jnp
from jaxline import experiment
from jaxline import utils

import haiku as hk

from acme.jax import utils as acme_utils

import optax

from rl import networks
from rl import offline
from rl import losses
from rl.utils import checkpointing


def _get_first(xs):
    """Return ``x[0]`` for every leaf ``x`` of the pytree ``xs``.

    Used in this module to pull the first device's copy out of pmapped
    (device-replicated) values such as the step scalars and ``global_step``.
    """
    # `jax.tree_map` is a deprecated alias (removed in recent JAX releases);
    # `jax.tree_util.tree_map` is the stable spelling used elsewhere here.
    return jax.tree_util.tree_map(lambda x: x[0], xs)


# Training state threaded through the pmapped update step:
#   online_params - parameters of the network being optimised,
#   target_params - copy of the parameters refreshed via rlax.periodic_update,
#   opt_state     - optax optimizer state for the online parameters.
# Being a NamedTuple, it is a JAX pytree and can be passed through jax.pmap.
_PolicyEvalExperimentState = NamedTuple(
    '_PolicyEvalExperimentState',
    [
        ('online_params', hk.Params),
        ('target_params', hk.Params),
        ('opt_state', optax.OptState),
    ],
)


class PolicyEvalExperiment(experiment.AbstractExperiment):
    """Jaxline experiment that fits a value network on an offline dataset.

    Training samples minibatches from the training prefix of the dataset (or
    the whole dataset when no split is configured). It minimises one of the
    losses from ``rl.losses`` with an optax optimizer and keeps a target copy
    of the parameters that is refreshed with ``rlax.periodic_update``.

    Evaluation sweeps the held-out suffix (or the whole dataset) and reports:
      * the mean squared one-step TD error;
      * the mean ``rlax.l2_loss`` against empirical discounted returns.

    Training and evaluation steps are ``jax.pmap``-ed across all devices.
    """

    def __init__(self,
                 dataset_file: str,
                 max_steps: int,
                 batch_size: int,
                 batch_sample_replace: bool,
                 ridge_coeff: float,
                 random_seed: np.random.SeedSequence,
                 network: str,
                 network_config: Mapping[Text, Any],
                 optimizer: str,
                 optimizer_config: Mapping[Text, Any],
                 lr_schedule: str,
                 lr_schedule_config: Mapping[Text, Any],
                 target_update_period: int,
                 checkpointing_config: Mapping[Text, Any],
                 evaluation_config: Mapping[Text, Any],
                 snapshot_config: Mapping[Text, Any],
                 loss: str,
                 linear_solver,
                 train_test_split: float,
                 hess_average_loss: bool,
                 use_iterative_refinement,
                 tol
                 ):
        """Configure the experiment.

        Args:
            dataset_file: Path passed to ``offline.Dataset.load``.
            max_steps: Checkpoints saved at or beyond this step are final.
            batch_size: Global training batch size; must be divisible by
                ``jax.device_count()``.
            batch_sample_replace: Whether training indices are drawn with
                replacement.
            ridge_coeff: Forwarded to the FRM losses.
            random_seed: Spawned into one seed for parameter initialisation
                and one for numpy minibatch sampling.
            network: Name of a constructor in ``rl.networks``.
            network_config: Keyword arguments for that constructor.
            optimizer: Name of an ``optax`` optimizer factory. It receives
                the learning rate positionally.
            optimizer_config: Extra keyword arguments for the optimizer.
            lr_schedule: Name of an ``optax`` schedule factory.
            lr_schedule_config: Keyword arguments for the schedule.
            target_update_period: ``update_period`` for
                ``rlax.periodic_update`` of the target parameters.
            checkpointing_config: Keyword arguments for
                ``checkpointing.Checkpointer``.
            evaluation_config: Must provide ``"batch_size"`` (the global
                evaluation batch size, divisible by the device count),
                ``"log_path"`` and ``"run_id"``. The last two may be ``None``.
            snapshot_config: Must provide ``"evaluate_snapshots"``.
            loss: One of ``"frm"``, ``"linear_frm"`` or ``"erm"``.
            linear_solver: Forwarded to ``losses.frm_td_loss``.
            train_test_split: Fraction of the dataset (a prefix) used for
                training; the remaining suffix is used for evaluation. A
                falsy value trains and evaluates on the whole dataset.
            hess_average_loss: Forwarded to the FRM losses.
            use_iterative_refinement: Forwarded to ``losses.frm_td_loss``.
            tol: Forwarded to ``losses.frm_td_loss``.

        Raises:
            ValueError: If ``loss`` is not a recognised loss name.
        """
        init_rng_seed, np_rng_seed = random_seed.spawn(2)

        self._np_rng_seed = np_rng_seed
        self._init_rng = jnp.array(init_rng_seed.generate_state(2))

        self._train_input = None
        self._train_state = None
        self.dataset = None
        self.eval_dataset = None
        self._train_test_split = train_test_split

        self._max_steps = max_steps

        self._dataset_file = dataset_file
        self._batch_size = batch_size
        self._eval_batch_size = evaluation_config["batch_size"]
        self._batch_sample_replace = batch_sample_replace

        self._optimizer = optimizer
        self._optimizer_config = optimizer_config

        self._lr_schedule = getattr(optax, lr_schedule)(**lr_schedule_config)
        self._target_update_period = target_update_period

        # Build the transformed network ops; `apply` takes no RNG argument.
        forward_fn = functools.partial(getattr(networks, network), **network_config)
        self.forward = hk.without_apply_rng(hk.transform(forward_fn))
        # Training and evaluation are pmapped so they can use every device.
        self.update_pmap = jax.pmap(self._update_step, axis_name='i')
        self.eval_pmap = jax.pmap(self._eval_step, axis_name='i')

        self._checkpointer = checkpointing.Checkpointer(**checkpointing_config)

        self._evaluate_snapshots = snapshot_config["evaluate_snapshots"]
        self._evaluate_log_path = evaluation_config["log_path"]
        self._run_id = evaluation_config["run_id"]

        if loss == "frm":
            self._loss_fn = functools.partial(
                losses.frm_td_loss,
                ridge_coeff=ridge_coeff,
                hess_average_loss=hess_average_loss,
                use_iterative_refinement=use_iterative_refinement,
                tol=tol,
                linear_solver=linear_solver,
            )
        elif loss == "linear_frm":
            # NOTE(review): the Hessian inputs use the full dataset, including
            # any held-out evaluation suffix. Confirm this is intended.
            self.dataset = offline.Dataset.load(self._dataset_file)
            self._loss_fn = functools.partial(
                losses.dense_linear_frm_td_loss,
                ridge_coeff=ridge_coeff,
                hessian_inputs=self.dataset.obs_t,
                hess_average_loss=hess_average_loss,
            )
        elif loss == "erm":
            self._loss_fn = losses.mean_td_learning
        else:
            # Fail fast rather than with an AttributeError at the first step.
            raise ValueError(
                f"Unknown loss {loss!r}; expected 'frm', 'linear_frm' or 'erm'.")

    @property
    def state(self):
        """Current pmapped training state, or ``None`` before initialisation."""
        return self._train_state

    def get_optimizer(self, learning_rate):
        """Build the configured optax optimizer with ``learning_rate``."""
        return getattr(optax, self._optimizer)(learning_rate, **self._optimizer_config)

    @staticmethod
    def _per_device_batch_size(global_batch_size, num_devices):
        """Split a global batch size evenly across ``num_devices``.

        Raises:
            ValueError: If ``global_batch_size`` is not divisible by
                ``num_devices``.
        """
        per_device_batch_size, ragged = divmod(global_batch_size, num_devices)
        if ragged:
            raise ValueError(
                f'Global batch size {global_batch_size} must be divisible by '
                f'num devices {num_devices}')
        return per_device_batch_size

    def _get_dataset(self):
        """Return the offline dataset, loading it from disk on first use.

        Raises:
            NotImplementedError: If the dataset has more than one batch
                dimension.
        """
        if self.dataset is None:
            self.dataset = offline.Dataset.load(self._dataset_file)
        if len(self.dataset.shape) != 1:
            raise NotImplementedError("Datasets has more than 1 batch dimension.")
        return self.dataset

    def _num_train_samples(self):
        """Return the number of leading dataset samples reserved for training.

        With a falsy ``train_test_split`` the whole dataset is used.
        """
        num_samples = self.dataset.shape[0]
        if self._train_test_split:
            num_samples = int(num_samples * self._train_test_split)
        return num_samples

    def _build_train_input(self):
        """Yield an infinite stream of random, device-sharded training batches.

        Every leaf of a yielded batch has leading shape
        ``(num_devices, per_device_batch_size)``. Indices are drawn from the
        training prefix with a fresh numpy generator spawned from the
        experiment's seed sequence.
        """
        num_devices = jax.device_count()
        per_device_batch_size = self._per_device_batch_size(self._batch_size, num_devices)

        dataset = self._get_dataset()
        num_train_samples = self._num_train_samples()
        if self._train_test_split:
            # Use the leading fraction of samples as training data.
            train_data = jax.tree_util.tree_map(lambda x: x[:num_train_samples, ...], dataset)
        else:
            train_data = dataset

        # Split the seed sequence so that each rebuilt input stream differs.
        batch_seed, self._np_rng_seed = self._np_rng_seed.spawn(2)
        rng = np.random.default_rng(batch_seed)

        batch_shape = (num_devices, per_device_batch_size)
        while True:
            idx = rng.choice(num_train_samples, batch_shape, replace=self._batch_sample_replace)
            yield jax.tree_util.tree_map(lambda x: np.take(x, idx, axis=0), train_data)

    def _build_eval_input(self):
        """Yield sequential, device-sharded batches of the evaluation data.

        Each item is ``(dataset_batch, g_t)``, where ``g_t`` holds the empirical
        discounted returns. Every leaf has leading shape
        ``(num_devices, per_device_batch_size)``. Batches use the evaluation
        batch size from ``evaluation_config``. If the final chunk is not
        divisible by the device count, the leftover samples are dropped with
        a warning.
        """
        num_devices = jax.device_count()
        global_batch_size = self._eval_batch_size
        self._per_device_batch_size(global_batch_size, num_devices)

        dataset = self._get_dataset()

        if self.eval_dataset is None:
            if self._train_test_split:
                # Use the samples after the training prefix as test data.
                num_train_samples = self._num_train_samples()
                eval_data = jax.tree_util.tree_map(lambda x: x[num_train_samples:, ...], dataset)
            else:
                eval_data = dataset

            # Compute empirical discounted returns once and cache them.
            g_t = rlax.discounted_returns(eval_data.r_t, eval_data.discount_t, 0.)
            g_t.block_until_ready()
            self.eval_dataset = (eval_data, g_t)

        num_eval_samples = self.eval_dataset[1].shape[0]
        start = 0
        while start < num_eval_samples:
            stop = min(start + global_batch_size, num_eval_samples)

            # If the remainder of the eval set is ragged, run as many samples
            # as possible and discard the rest.
            per_device_batch_size, ragged = divmod(stop - start, num_devices)
            stop -= ragged

            # Skip empty batches (fewer remaining samples than devices).
            if per_device_batch_size > 0:
                batch_shape = (num_devices, per_device_batch_size)
                batch = jax.tree_util.tree_map(lambda x: x[start:stop], self.eval_dataset)
                yield jax.tree_util.tree_map(
                    lambda x: x.reshape(batch_shape + x.shape[1:]), batch)

            if ragged:
                logging.warning(
                    'Skipping the last %d samples during evaluation to allow the '
                    'batch to fit the current number of devices.', ragged)
                break
            start = stop

    def _make_initial_state(self, rng, dummy_inputs):
        """Initialise online/target parameters and the optimizer state."""
        online_rng, target_rng = jax.random.split(rng)
        online_params = self.forward.init(online_rng, dummy_inputs.obs_t)
        target_params = self.forward.init(target_rng, dummy_inputs.obs_t)
        # The learning rate does not affect the optimizer state's structure.
        opt_state = self.get_optimizer(0).init(online_params)

        return _PolicyEvalExperimentState(
            online_params=online_params,
            target_params=target_params,
            opt_state=opt_state
        )

    def _initialize_train(self):
        """Create the training input stream and, if needed, the initial state.

        When no state exists yet (e.g. none was restored from a checkpoint),
        the first training batch is consumed to supply shapes for parameter
        initialisation.
        """
        self._train_input = self._build_train_input()
        # self._train_input = acme_utils.prefetch(self._build_train_input())
        if self._train_state is None:
            init_state = jax.pmap(self._make_initial_state, axis_name='i')
            init_rng = utils.bcast_local_devices(self._init_rng)

            inputs = next(self._train_input)
            # NOTE(review): initialisation runs with JIT disabled; the reason
            # is not visible here. Confirm before changing.
            with jax.disable_jit():
                self._train_state = init_state(rng=init_rng, dummy_inputs=inputs)

    def _update_step(self, state: _PolicyEvalExperimentState, global_step, rng, inputs):
        """One pmapped optimisation step.

        Returns:
            The new training state and the loss function's auxiliary logs,
            averaged across devices.
        """
        r_t, discount_t, obs_t, a_t, obs_tp1 = inputs

        def loss_fn(params):
            return self._loss_fn(
                online_params=params,
                target_params=state.target_params,
                apply_fn=self.forward.apply,
                o_tm1=obs_t,
                r_t=r_t,
                discount_t=discount_t,
                o_t=obs_tp1,
            )

        # Update the online network.
        grad_fn = jax.grad(loss_fn, argnums=0, has_aux=True)
        grads, logs = grad_fn(state.online_params)

        # Cross-device reductions of gradients and logs.
        grads = jax.tree_util.tree_map(lambda v: jax.lax.pmean(v, axis_name='i'), grads)
        logs = jax.tree_util.tree_map(lambda x: jax.lax.pmean(x, axis_name='i'), logs)

        learning_rate = self._lr_schedule(global_step)
        updates, opt_state = self.get_optimizer(learning_rate).update(
            grads, state.opt_state, state.online_params)
        online_params = optax.apply_updates(state.online_params, updates)

        target_params = rlax.periodic_update(
            new_tensors=online_params,
            old_tensors=state.target_params,
            steps=global_step,
            update_period=self._target_update_period,
        )

        return _PolicyEvalExperimentState(
            online_params=online_params,
            target_params=target_params,
            opt_state=opt_state), logs

    def step(self,
             global_step: jnp.ndarray,
             rng: jnp.ndarray,
             **unused_args) -> Dict[str, np.ndarray]:
        """Run one training step and return the first device's scalars."""
        if self._train_input is None:
            self._initialize_train()

        inputs = next(self._train_input)

        self._train_state, scalars = self.update_pmap(
            self._train_state,
            global_step=global_step,
            rng=rng,
            inputs=inputs,
        )

        return _get_first(scalars)

    def _eval_step(self, state: _PolicyEvalExperimentState, global_step, inputs, summed_scalars):
        """Accumulate per-device sums of evaluation errors for one batch.

        Returns:
            ``summed_scalars`` plus this batch's sums of squared TD error
            (``"td_error"``) and ``rlax.l2_loss`` against the discounted
            returns (``"g_error"``). If ``summed_scalars`` is ``None``, returns
            only this batch's sums.
        """
        del global_step
        (r_t, discount_t, obs_t, a_t, obs_tp1), g_t = inputs
        params = state.online_params

        # Only the first output of the network is compared to the returns.
        g_error = rlax.l2_loss(self.forward.apply(params, obs_t)[..., 0], g_t)

        # TD error is measured against the online parameters themselves.
        td_error = losses.vec_td_learning(
            online_params=params,
            target_params=params,
            apply_fn=self.forward.apply,
            o_tm1=obs_t,
            r_t=r_t,
            discount_t=discount_t,
            o_t=obs_tp1,
        )

        scalars = {"td_error": td_error**2, "g_error": g_error}
        scalars = jax.tree_util.tree_map(lambda x: jnp.sum(x, axis=0), scalars)

        # Accumulate the sum of scalars across evaluation batches.
        if summed_scalars is None:
            summed_scalars = scalars
        else:
            summed_scalars = jax.tree_util.tree_map(jnp.add, summed_scalars, scalars)

        return summed_scalars

    def evaluate(self,
                 global_step: jnp.ndarray,
                 **unused_args) -> Optional[Dict[str, np.ndarray]]:
        """Evaluate the online parameters on the evaluation data.

        Results are also appended to ``evaluation_config["log_path"]`` (as a
        JSON line) when a path is configured.

        Returns:
            Mean ``"td_error"`` and ``"g_error"`` over all evaluated samples.

        Raises:
            ValueError: If no evaluation samples could be processed.
        """
        if self._train_state is None:
            self._initialize_train()

        summed_scalars = None
        num_samples = 0
        # eval_inputs = acme_utils.prefetch(self._build_eval_input())
        for inputs in self._build_eval_input():
            # Leaves have shape (num_devices, per_device_batch_size, ...).
            # Count per-device samples from the plain g_t array.
            num_samples += inputs[1].shape[1]
            summed_scalars = self.eval_pmap(
                self._train_state, global_step, inputs, summed_scalars)

        if summed_scalars is None:
            raise ValueError(
                'Evaluation produced no batches; the evaluation split is empty '
                'or smaller than the number of devices.')

        # Each leaf holds one sum per device over the same number of samples.
        # Dividing gives per-device means, and averaging gives the overall mean.
        mean_scalars = jax.tree_util.tree_map(
            lambda x: jnp.mean(x / num_samples), summed_scalars)

        logs = jax.tree_util.tree_map(float, mean_scalars)
        logs.update(
            step=int(_get_first(global_step)),
        )

        if self._run_id is not None:
            logs["run_id"] = self._run_id

        if self._evaluate_log_path is not None:
            with jsonlines.open(self._evaluate_log_path, 'a') as writer:
                writer.write(logs)

        return mean_scalars

    def save_checkpoint(self, step: int, rng: jnp.ndarray):
        """Delegate to the checkpointer; the save is final once ``max_steps`` is reached."""
        self._checkpointer.maybe_save_checkpoint(
            self._train_state, step=step, rng=rng, is_final=step >= self._max_steps)

    def save_snapshot(self, global_step: jnp.ndarray, rng: jnp.ndarray):
        """Save a snapshot, optionally evaluating and logging scalars first."""
        step = _get_first(global_step)

        if self._evaluate_snapshots:
            scalars = self.evaluate(global_step=global_step)
            logging.info('Step %d: Eval scalars: %s', step, scalars)

        self._checkpointer.save_snapshot(self._train_state, step=step, rng=_get_first(rng))

    def load_checkpoint(self) -> Union[Tuple[int, jnp.ndarray], None]:
        """Restore the training state, returning ``(step, rng)`` or ``None``."""
        checkpoint_data = self._checkpointer.maybe_load_checkpoint()
        if checkpoint_data is None:
            return None
        self._train_state, step, rng = checkpoint_data
        return step, rng
