import numpy as np
import torch
import os
import wandb

from utils import compute_running_mean, scatter_batch, make_dir

# Scales the argument of the tanh schedule in UpdatingRewardModule.temperature.
TANH_CONST: float = 4.0
# Reward statistics are pushed to wandb whenever step % LOG_FREQ == 0.
LOG_FREQ: int = 100
# Epsilon guarding the velocity normalisation and log(0).
EPS_CONST: float = 1e-9


class UpdatingRewardModule(object):
    """
    Reward module for the incremental skill-learning objective.

    Three trajectory stores are maintained:

    * a *running* ring buffer with trajectories of the skill currently
      being learned (filled by ``add_current``),
    * a *collected* buffer with rollouts of the skill that was just
      learned (filled by ``add_collected_trajectory``), and
    * a *saved* bank accumulating the collected rollouts of every finished
      skill (appended to by ``add_new_skill``).

    ``get_rewards`` scores a replay batch as a diversity bonus (distance to
    saved trajectories of earlier skills) minus a consistency penalty
    (distance to running trajectories of the current skill).
    """

    def __init__(self, transformed_obs_shape, max_episode_timesteps,
                 saved_latent_per_skill, total_skills,
                 max_running_obses=10, slow_update_coeff=10,
                 device='cuda', alpha=1.0, beta=1.0, use_t_obs=True,
                 use_t_vel=False, use_timesteps=True,
                 use_cosine_sim_vel=False, use_entropy=True,
                 topk=3, entropy_use_ln=False):
        """Allocate all trajectory buffers.

        Args:
            transformed_obs_shape: shape of one transformed observation.
            max_episode_timesteps: episode length; size of the time axis of
                every buffer when ``use_timesteps`` is True.
            saved_latent_per_skill: trajectories archived per skill.
            total_skills: number of skills the saved bank is sized for.
            max_running_obses: rows in the running (current-skill) buffer.
            slow_update_coeff: divisor controlling how many running rows are
                sampled once enough exist (see ``current_len``).
            device: torch device used for the reward computation.
            alpha: weight of the consistency penalty.
            beta: weight of the diversity bonus.
            use_t_obs: compare transformed observations.
            use_t_vel: compare observation differences (next - current).
            use_timesteps: store whole trajectories indexed by timestep.
                When False every transition is its own length-1 row and the
                capacities are multiplied by ``max_episode_timesteps``.
            use_cosine_sim_vel: L2-normalise velocities before comparing.
            use_entropy: use the k-th nearest distance instead of the mean.
            topk: ``k`` for the nearest-neighbour estimate.
            entropy_use_ln: take the log of that k-th distance.
        """
        if use_timesteps:
            self.capacity = saved_latent_per_skill
            self.max_timesteps = max_episode_timesteps
            self.max_running_obses = max_running_obses
        else:
            self.capacity = saved_latent_per_skill * max_episode_timesteps
            self.max_running_obses = max_running_obses * max_episode_timesteps
            self.max_timesteps = 1

        self._use_timesteps = use_timesteps
        self.device = device
        self._slow_update_coeff = slow_update_coeff
        self._use_t_vel = use_t_vel
        self._use_t_obs = use_t_obs
        self._use_cosine_sim_vel = use_cosine_sim_vel
        self._use_entropy = use_entropy
        self._topk = topk
        self._entropy_use_ln = entropy_use_ln

        assert use_t_obs or use_t_vel, \
            "At least one of obs or vel must be stored"

        assert not (use_cosine_sim_vel and use_entropy), \
            "Using both cosine similarity and entropy may become unprincipled"

        # 1-D (vector) observations are stored as float32, anything of
        # higher rank (e.g. pixels) as uint8.
        obs_dtype = np.float32 if len(transformed_obs_shape) == 1 else np.uint8

        # Rollouts of the skill that was just learned. The extra row
        # (index ``capacity``) is scratch space; only the first ``capacity``
        # rows are archived.
        self.collected_obses = np.empty((self.capacity + 1,
                                         self.max_timesteps,
                                         *transformed_obs_shape),
                                        dtype=obs_dtype)
        self.collected_vel = np.empty((self.capacity + 1,
                                       self.max_timesteps,
                                       *transformed_obs_shape),
                                      dtype=obs_dtype)
        self.collected_obs_exists = np.zeros((self.capacity + 1,
                                              self.max_timesteps, 1))

        # Ring buffer of trajectories from the skill currently being learned.
        self.running_obses = np.empty((self.max_running_obses + 1,
                                       self.max_timesteps,
                                       *transformed_obs_shape),
                                      dtype=obs_dtype)
        self.running_vel = np.empty((self.max_running_obses + 1,
                                     self.max_timesteps,
                                     *transformed_obs_shape),
                                    dtype=obs_dtype)
        self.running_obses_exists = np.zeros((self.max_running_obses + 1,
                                              self.max_timesteps, 1))

        # Bank of archived trajectories from all finished skills.
        total_saved_latents = self.capacity * total_skills
        self.saved_latents = np.empty((total_saved_latents,
                                       self.max_timesteps,
                                       *transformed_obs_shape),
                                      dtype=np.float32)
        self.saved_vel_latents = np.empty((total_saved_latents,
                                           self.max_timesteps,
                                           *transformed_obs_shape),
                                          dtype=np.float32)
        self.saved_latents_exists = np.zeros((total_saved_latents,
                                              self.max_timesteps, 1))

        self.idx = 0            # next row of the running buffer
        self.collect_idx = 0    # next row of the collected buffer
        self.saved_until = 0    # number of filled rows in the saved bank
        self.full = False       # whether the running buffer has wrapped
        self.skill_idx = 0
        self.saved_latent_per_skill = saved_latent_per_skill

        self.alpha = alpha
        self.beta = beta

        self._reward_calls = 0.
        self._saved_reward_call = 0.
        self._diversity_normalizer = 1.
        self._compactness_normalizer = 1.
        # Per-skill running averages; the last entry is the current skill.
        self.average_rewards = [0]
        self.average_compactness_penalty = [0]
        self.average_diversity_reward = [0]

    def register_logger(self, logger):
        """Store ``logger`` on the instance (not used elsewhere here)."""
        self.logger = logger

    def current_len(self):
        """Return the sampling range for running trajectories.

        Before more than ``max_running_obses // slow_update_coeff`` rows
        have been completed (and before the ring buffer wraps), this is the
        number of completed rows ``self.idx``. Afterwards it is capped at
        ``max_running_obses // slow_update_coeff``.
        """
        slow_update_length = self.max_running_obses // self._slow_update_coeff
        if self.idx > slow_update_length or self.full:
            return slow_update_length
        return self.idx

    def add_current(self, transformed_obs, obs,
                    next_transformed_obs, next_obs,
                    timestep, done):
        """Record one transition of the skill currently being learned.

        ``obs`` and ``next_obs`` are accepted for interface symmetry but are
        not used. The row pointer advances at episode end, or after every
        transition when timesteps are ignored. The ring buffer wraps at
        ``max_running_obses``.
        """
        t = timestep - 1 if self._use_timesteps else 0
        if self._use_t_vel:
            t_vel = next_transformed_obs - transformed_obs
            np.copyto(self.running_vel[self.idx, t], t_vel)
        if self._use_t_obs:
            np.copyto(self.running_obses[self.idx, t], next_transformed_obs)
        np.copyto(self.running_obses_exists[self.idx, t], 1.)

        if done or not self._use_timesteps:
            self.idx += 1
            self.idx %= self.max_running_obses
            if self.idx == 0:
                self.full = True

    def add_collected_trajectory(self, transformed_obs, obs,
                                 next_transformed_obs, next_obs,
                                 timestep, done):
        """Record one transition of a rollout of the just-learned skill.

        ``obs`` and ``next_obs`` are not used. At most ``capacity`` rollouts
        are kept. Once the buffer is full, further rollouts are written to
        the spare scratch row and are never archived, instead of indexing
        past the end of the buffer.
        """
        t = timestep - 1 if self._use_timesteps else 0
        if self._use_t_vel:
            t_vel = next_transformed_obs - transformed_obs
            np.copyto(self.collected_vel[self.collect_idx, t],
                      t_vel)
        if self._use_t_obs:
            np.copyto(self.collected_obses[self.collect_idx, t],
                      next_transformed_obs)
        np.copyto(self.collected_obs_exists[self.collect_idx, t], 1.)

        if done or (not self._use_timesteps):
            self.collect_idx = min(self.collect_idx + 1, self.capacity)

    @property
    def temperature(self):
        """Annealing weight applied to the consistency penalty.

        Returns 1 until a previous skill length is known. After that it
        returns ``tanh((2 * t - 1) * TANH_CONST)``, where ``t`` is the number
        of logged reward calls for the current skill divided by the stored
        length of the previous one. The value therefore starts at
        ``tanh(-TANH_CONST)``, crosses 0 at ``t = 0.5`` and grows from there.
        """
        if self._saved_reward_call == 0.:
            return 1.
        progress = self._reward_calls / self._saved_reward_call
        return np.tanh((2 * progress - 1) * TANH_CONST)

    @staticmethod
    def _inverse_scale(weight, average):
        """Return ``1 / |weight * average|``, or 1 when that product is 0.

        The magnitude is used so that a negative average (possible when the
        entropy estimate is log-scaled) cannot flip the sign of a reward
        term.
        """
        denominator = abs(weight * average)
        if denominator == 0:
            return 1.
        return 1. / denominator

    def add_new_skill(self, num_steps_next_skill=None):
        """Archive the collected rollouts and reset state for a new skill.

        Args:
            num_steps_next_skill: expected number of reward calls for the
                next skill, used by ``temperature``. If falsy, the number of
                reward calls logged for the finished skill is used instead.
        """
        self.skill_idx += 1
        to_save = min(self.capacity, self.collect_idx)
        start, end = self.saved_until, self.saved_until + to_save
        if self._use_t_obs:
            self.saved_latents[start:end] = self.collected_obses[:to_save]
        if self._use_t_vel:
            self.saved_vel_latents[start:end] = self.collected_vel[:to_save]
        self.saved_latents_exists[start:end] = (
            self.collected_obs_exists[:to_save]
        )

        self.saved_until += to_save
        self.collect_idx = 0
        self.idx = 0
        self.full = False
        # Both buffers start over for the next skill. Clear their validity
        # flags so shorter episodes do not inherit stale timesteps.
        self.collected_obs_exists[...] = 0.
        self.running_obses_exists[...] = 0.

        if num_steps_next_skill:
            self._saved_reward_call = num_steps_next_skill
        else:
            self._saved_reward_call = self._reward_calls
        self._reward_calls = 0.
        # Rescale each reward term of the next skill by the finished skill's
        # average magnitude.
        self._diversity_normalizer = self._inverse_scale(
            self.beta, self.average_diversity_reward[-1])
        self._compactness_normalizer = self._inverse_scale(
            self.alpha, self.average_compactness_penalty[-1])
        self.average_rewards.append(0)
        self.average_compactness_penalty.append(0)
        self.average_diversity_reward.append(0)

    def process_replay_obses(self, next_t_obs, t_obs, timesteps):
        """Scatter a replay batch onto the timestep axis.

        Returns:
            ``(trans_obs_batch, trans_vel_batch, truth_table)`` as produced
            by ``scatter_batch``. A disabled modality is returned as
            ``None``. Velocities are ``next_t_obs - t_obs``, optionally
            L2-normalised over the last axis.
        """
        trans_obs_batch, trans_vel_batch, truth_table = None, None, None
        if self._use_t_obs:
            trans_obs_batch, truth_table = scatter_batch(
                next_t_obs, timesteps, self.max_timesteps
            )
        if self._use_t_vel:
            t_vel = next_t_obs - t_obs
            if self._use_cosine_sim_vel:
                # Normalize the velocities before calculating distance
                vel_norm = (torch.norm(t_vel, p=2, dim=-1, keepdim=True) +
                            EPS_CONST)
                t_vel /= vel_norm
            trans_vel_batch, truth_table = scatter_batch(
                t_vel, timesteps, self.max_timesteps
            )
        return trans_obs_batch, trans_vel_batch, truth_table

    def _masked_traj_distance(self, obs_batch, vel_batch, truth_table,
                              bank_obs, bank_vel, bank_exists, idxs):
        """Summed per-timestep L2 distance from replay to bank trajectories.

        Rows ``idxs`` of the numpy bank arrays are moved to ``self.device``.
        For every enabled modality, the L2 distance over the last axis is
        accumulated. Timesteps missing from either side are zeroed out, and
        the result is summed over the timestep axis.

        Returns:
            Tensor of shape ``n_replay_batch x len(idxs)``.
        """
        bank_truth = torch.as_tensor(
            bank_exists[idxs], device=self.device
        ).squeeze(dim=-1)
        distance = 0.
        if self._use_t_obs:
            bank_trajs = torch.as_tensor(bank_obs[idxs], device=self.device)
            distance += torch.norm(
                obs_batch[:, None, :, :] - bank_trajs[None, :, :, :],
                dim=3, p=2
            )
        if self._use_t_vel:
            bank_vels = torch.as_tensor(bank_vel[idxs], device=self.device)
            distance += torch.norm(
                vel_batch[:, None, :, :] - bank_vels[None, :, :, :],
                dim=3, p=2
            )
        # zero out timesteps absent from the replay sample or the bank entry
        distance *= truth_table[:, None, :]
        distance *= bank_truth[None, :, :]
        return distance.sum(dim=-1)

    def _kth_nearest(self, distances):
        """Entropy-style score: the k-th smallest distance of every row.

        ``k`` is capped at the number of columns so that small or
        placeholder matrices (e.g. ``(n, 1)`` when a bank is empty) do not
        make ``torch.topk`` fail.
        """
        k = min(self._topk, distances.shape[1])
        nearest, _ = torch.topk(distances, k=k, dim=1, largest=False)
        return self._entropy_norm(nearest[:, -1:])

    def get_rewards(self, next_t_obs, t_obs,
                    timesteps, batch_size=512,
                    step=0, eval=False):
        """Compute the intrinsic reward for a batch of replay transitions.

        Args:
            next_t_obs: transformed next observations (torch tensor).
            t_obs: transformed current observations, same shape.
            timesteps: per-transition timesteps, forwarded to
                ``scatter_batch``.
            batch_size: number of stored trajectories sampled from each
                bank for comparison.
            step: training step, used to throttle wandb logging.
            eval: when True, skip the running statistics and logging.

        Returns:
            An ``(n_replay, 1)`` tensor equal to
            ``beta * diversity_normalizer * bonus
            - alpha * temperature * compactness_normalizer * penalty``.
        """
        trans_obs_batch, trans_vel_batch, truth_table = \
            self.process_replay_obses(next_t_obs, t_obs, timesteps)

        if not self._use_timesteps:
            batch_shape = (trans_obs_batch.shape if self._use_t_obs
                           else trans_vel_batch.shape)
            timestep_dim = batch_shape[1]
            assert timestep_dim == 1, \
                   ("Timestep provided while the reward module ignores it: "
                    f"{batch_shape}")

        sample_len = self.current_len()
        if sample_len > 0:
            current_idxs = np.random.randint(0,
                                             sample_len // 2 + 1,
                                             size=batch_size)
            consistency_penalty = self._masked_traj_distance(
                trans_obs_batch, trans_vel_batch, truth_table,
                self.running_obses, self.running_vel,
                self.running_obses_exists, current_idxs
            )
        else:
            reference = (trans_obs_batch if trans_obs_batch is not None
                         else trans_vel_batch)
            consistency_penalty = torch.zeros((len(reference), 1),
                                              device=self.device)

        if self.saved_until > 0:
            past_traj_idxs = np.random.randint(0,
                                               self.saved_until,
                                               size=batch_size)
            diversity_reward = self._masked_traj_distance(
                trans_obs_batch, trans_vel_batch, truth_table,
                self.saved_latents, self.saved_vel_latents,
                self.saved_latents_exists, past_traj_idxs
            )
        else:
            # Nothing archived yet: a constant bonus of 1 per replay sample.
            diversity_reward = torch.ones_like(consistency_penalty)

        if self._use_entropy:
            penalty = self._kth_nearest(consistency_penalty)
            bonus = self._kth_nearest(diversity_reward)
        else:
            penalty = consistency_penalty.mean(axis=1, keepdim=True)
            bonus = diversity_reward.mean(axis=1, keepdim=True)

        normalized_penalty = (-self.alpha *
                              self.temperature *
                              self._compactness_normalizer *
                              penalty)
        normalized_bonus = self.beta * self._diversity_normalizer * bonus
        reward = normalized_penalty + normalized_bonus

        if not eval:
            self._log(penalty, bonus, reward, step=step)
        return reward

    def _entropy_norm(self, x):
        """Return ``log(x + EPS_CONST)`` if configured, else ``x``."""
        if self._entropy_use_ln:
            return torch.log(x + EPS_CONST)
        return x

    def _log(self, penalty, bonus, reward, step=0):
        """Update the per-skill running averages and log to wandb.

        Also increments the reward-call counter that drives
        ``temperature``. wandb logging happens when
        ``step % LOG_FREQ == 0``.
        """
        self._reward_calls += 1
        mean_penalty = penalty.mean().cpu().item()
        mean_bonus = bonus.mean().cpu().item()
        mean_reward = reward.mean().cpu().item()
        self.average_compactness_penalty[-1] = compute_running_mean(
            self.average_compactness_penalty[-1], mean_penalty,
            self._reward_calls
        )
        self.average_diversity_reward[-1] = compute_running_mean(
            self.average_diversity_reward[-1], mean_bonus,
            self._reward_calls
        )
        self.average_rewards[-1] = compute_running_mean(
            self.average_rewards[-1], mean_reward,
            self._reward_calls
        )
        if step % LOG_FREQ == 0:
            wandb.log({
                'reward/mean_diversity_bonus': mean_bonus,
                'reward/mean_consistency_penalty': mean_penalty,
                'reward/mean_average_reward': mean_reward,
                'update_step': step,
            })

    def save_buffers(self, path=None):
        """Write the saved bank and statistics to ``<path>/reward_buffer.npz``.

        ``path`` defaults to ``<cwd>/reward_buffer``. Note the asymmetry
        with ``load_buffers``, which expects the *parent* of this directory.
        """
        if not path:
            path = os.path.join(os.getcwd(), 'reward_buffer')
        make_dir(path)
        to_save_dict = {
            'saved_latents': self.saved_latents,
            'saved_vel_latents': self.saved_vel_latents,
            'saved_latents_exists': self.saved_latents_exists,
            'saved_until': self.saved_until,
            'last_skill_length': self._saved_reward_call,
            'avg_rewards': self.average_rewards,
            'avg_compactness_penalty': self.average_compactness_penalty,
            'avg_diversity_reward': self.average_diversity_reward,
            'skill_idx': self.skill_idx,
            'diversity_normalizer': self._diversity_normalizer,
            'compactness_normalizer': self._compactness_normalizer,
        }
        np.savez_compressed(file=os.path.join(path, 'reward_buffer.npz'),
                            **to_save_dict)

    def load_buffers(self, path=None):
        """Restore state from ``<path>/reward_buffer/reward_buffer.npz``.

        ``path`` defaults to the current working directory. Scalars are
        converted back to Python numbers and the per-skill average histories
        back to lists, so later ``add_new_skill`` calls can append to them.
        """
        if not path:
            path = os.getcwd()
        rb_file = os.path.join(path, 'reward_buffer', 'reward_buffer.npz')
        with np.load(rb_file) as loaded:
            # NOTE(review): observation latents are clipped to [-4.2, 4.2]
            # only on load (velocities are not); rationale not visible here.
            self.saved_latents = loaded['saved_latents'].clip(-4.2, 4.2)
            self.saved_vel_latents = loaded['saved_vel_latents']
            self.saved_latents_exists = loaded['saved_latents_exists']
            self.saved_until = int(loaded['saved_until'])
            self._saved_reward_call = int(loaded['last_skill_length'])
            self.average_rewards = loaded['avg_rewards'].tolist()
            self.average_compactness_penalty = \
                loaded['avg_compactness_penalty'].tolist()
            self.average_diversity_reward = \
                loaded['avg_diversity_reward'].tolist()
            self.skill_idx = int(loaded['skill_idx'])
            self._diversity_normalizer = float(
                loaded['diversity_normalizer'])
            self._compactness_normalizer = float(
                loaded['compactness_normalizer'])
