import gtimer as gt
import numpy as np
import torch

from collections import OrderedDict

from lfrl.core.rl_algorithms.torch_rl_algorithm import TorchTrainer
import lfrl.torch.pytorch_util as ptu
from lfrl.util.eval_util import create_stats_ordered_dict
import lfrl.util.pythonplusplus as ppp


class SACEmpTrainer(TorchTrainer):

    """
    Sharma et al. 2019. "Dynamics-Aware Discovery of Skills".

    Relabels collected transitions with intrinsic rewards from ``reward_func``,
    stores them in a private ring buffer sized like ``replay_buffer``, and
    trains ``policy_trainer`` on minibatches sampled from that buffer.
    """

    def __init__(
            self,
            policy,
            replay_buffer,
            policy_trainer,
            reward_func,
            emp_networks,
            num_policy_updates=1000,
            policy_batch_size=256,
    ):
        """
        Args:
            policy: Policy being trained (stored as ``self.policy``).
            replay_buffer: Only queried for ``obs_dim()``, ``action_dim()`` and
                ``max_replay_buffer_size()`` to size the internal buffer.
            policy_trainer: Trainer exposing ``train_from_torch(batch)``,
                ``eval_statistics``, ``end_epoch(epoch)``, ``networks`` and
                ``get_snapshot()``.
            reward_func: Callable applied to a batch of observations (converted
                with ``ptu.from_numpy``) returning one reward per observation.
            emp_networks: Extra networks appended to ``policy_trainer.networks``
                in the ``networks`` property.
            num_policy_updates: Policy updates per ``train_from_buffer`` call.
            policy_batch_size: Minibatch size of each policy update.
        """
        super().__init__()

        self.policy = policy
        self.replay_buffer = replay_buffer
        self.policy_trainer = policy_trainer
        self.reward_func = reward_func
        self.emp_networks = emp_networks
        self.num_policy_updates = num_policy_updates
        self.policy_batch_size = policy_batch_size

        self.obs_dim = replay_buffer.obs_dim()
        self.action_dim = replay_buffer.action_dim()

        # Private ring buffer holding relabeled transitions.
        replay_size = self.replay_buffer.max_replay_buffer_size()
        self._obs = np.zeros((replay_size, self.obs_dim))
        self._next_obs = np.zeros((replay_size, self.obs_dim))
        self._actions = np.zeros((replay_size, self.action_dim))
        self._rewards = np.zeros((replay_size, 1))
        self._ptr = 0  # next slot to write
        self.replay_size = replay_size
        self._cur_replay_size = 0  # number of filled slots

        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True
        # Number of transitions added by the last train_from_paths call.
        self._epoch_size = None
        self.eval_statistics = OrderedDict()

    def add_sample(self, obs, next_obs, action, reward, **kwargs):
        """
        Write one transition into the ring buffer, overwriting the oldest
        entry once the buffer is full. Extra keyword arguments are ignored.
        """
        self._obs[self._ptr] = obs
        self._next_obs[self._ptr] = next_obs
        self._actions[self._ptr] = action
        self._rewards[self._ptr] = reward

        self._ptr = (self._ptr + 1) % self.replay_size
        self._cur_replay_size = min(self._cur_replay_size + 1, self.replay_size)

    def _compute_rewards(self, obs):
        """Evaluate ``reward_func`` on ``obs`` and return an ``(N, k)`` numpy array."""
        rewards = self.reward_func(ptu.from_numpy(obs))
        if isinstance(rewards, torch.Tensor):
            # Detach and move to host: storing a CUDA or grad-tracking tensor
            # directly into the numpy buffer raises.
            rewards = rewards.detach().cpu().numpy()
        return np.asarray(rewards).reshape(len(obs), -1)

    def _recent_rewards(self):
        """
        Return the stored rewards of the transitions added by the most recent
        ``train_from_paths`` call, in insertion order. Accounts for ring-buffer
        wrap-around; empty if no such call has added samples.
        """
        num_recent = min(self._epoch_size or 0, self._cur_replay_size)
        idx = (self._ptr - num_recent + np.arange(num_recent)) % self.replay_size
        return self._rewards[idx]

    def train_from_paths(self, paths, train_discrim=True, train_policy=True):
        """
        Relabel new paths with intrinsic rewards, store them, then train.

        Note that this is equivalent to on-policy training when the internal
        buffer size equals the total length of the paths.

        Args:
            paths: Iterable of dicts with ``'observations'``,
                ``'next_observations'`` and ``'actions'`` entries indexed by
                time step along their first axis.
            train_discrim: Unused here; kept for interface compatibility.
            train_policy: If False, the relabeled samples are stored but no
                policy updates are run.

        Raises:
            ValueError: if a path has fewer next_observations or actions
                than observations.
        """
        obs_chunks, next_obs_chunks, action_chunks = [], [], []
        for path in paths:
            path_len = len(path['observations'])
            if path_len == 0:
                continue
            obs_chunks.append(path['observations'])
            # Each transition is keyed by its observation, so truncate the
            # companion arrays to the observation count.
            next_obs_chunks.append(path['next_observations'][:path_len])
            action_chunks.append(path['actions'][:path_len])

        self._epoch_size = 0
        if obs_chunks:
            epoch_obs = np.concatenate(obs_chunks, axis=0)
            epoch_next_obs = np.concatenate(next_obs_chunks, axis=0)
            epoch_actions = np.concatenate(action_chunks, axis=0)
            if not (len(epoch_obs) == len(epoch_next_obs) == len(epoch_actions)):
                raise ValueError(
                    'paths must have as many next_observations and actions '
                    'as observations'
                )
            epoch_rewards = self._compute_rewards(epoch_obs)

            for sample in zip(epoch_obs, epoch_next_obs, epoch_actions, epoch_rewards):
                self.add_sample(*sample)
            self._epoch_size = len(epoch_obs)

        # NOTE(review): this stamp precedes the actual policy updates, so the
        # relabeling work above is presumably booked under 'policy training'.
        gt.stamp('policy training', unique=False)

        if train_policy:
            self.train_from_buffer()

    def train_from_buffer(self, reward_kwargs=None):
        """
        Train the policy on minibatches sampled from the internal buffer.

        Runs ``num_policy_updates`` calls to ``policy_trainer.train_from_torch``
        and then refreshes ``eval_statistics``. Does nothing while the buffer
        is empty.

        Args:
            reward_kwargs: Unused here; kept for interface compatibility.
        """
        num_stored = self._cur_replay_size
        if num_stored == 0:
            return

        # Views over the filled part of the buffer; invariant across updates.
        stored = dict(
            observations=self._obs[:num_stored],
            next_observations=self._next_obs[:num_stored],
            actions=self._actions[:num_stored],
            rewards=self._rewards[:num_stored],
        )
        for _ in range(self.num_policy_updates):
            batch = ppp.sample_batch(self.policy_batch_size, **stored)
            self.policy_trainer.train_from_torch(ptu.np_to_pytorch_batch(batch))

        gt.stamp('policy training', unique=False)

        # Diagnostics.
        # NOTE(review): this flag is never cleared in this class, so statistics
        # are refreshed on every call; confirm whether that is intended.
        if self._need_to_update_eval_statistics:
            self.eval_statistics.update(self.policy_trainer.eval_statistics)
            recent_rewards = self._recent_rewards()
            if len(recent_rewards) > 0:
                self.eval_statistics.update(create_stats_ordered_dict(
                    'Intrinsic Rewards (Processed)',
                    recent_rewards,
                ))

        self._n_train_steps_total += 1

    def get_diagnostics(self):
        """Return the accumulated evaluation statistics."""
        return self.eval_statistics

    def end_epoch(self, epoch):
        """Mark statistics for refresh and forward the epoch end to the policy trainer."""
        self._need_to_update_eval_statistics = True
        self.policy_trainer.end_epoch(epoch)

    @property
    def networks(self):
        """Policy-trainer networks followed by ``emp_networks``."""
        return self.policy_trainer.networks + self.emp_networks

    def get_snapshot(self):
        """
        Return the policy trainer's snapshot with keys prefixed by
        ``'policy_trainer/'``. ``emp_networks`` are not included.
        """
        snapshot = dict()

        for k, v in self.policy_trainer.get_snapshot().items():
            snapshot['policy_trainer/' + k] = v

        return snapshot
