import numpy as np
import torch

import gtimer as gt

import lfrl.torch.pytorch_util as ptu
from lfrl.samplers.utils import model_rollout_functions as rollout_functions
from lfrl.trainers.dads.dads_model import DADSModelTrainer
import lfrl.util.pythonplusplus as ppp
from lfrl.trainers.dads.empowerment_functions import calculate_contrastive_empowerment
from lfrl.util.eval_util import create_stats_ordered_dict


class DADSModelPriorTrainer(DADSModelTrainer):
    """DADS model trainer whose skill latents are sampled from a learned prior.

    Compared to ``DADSModelTrainer`` this trainer:
      * samples latents from ``latent_prior`` (optionally mixing in the
        parent's sampler via epsilon-greedy replacement),
      * computes contrastive-empowerment intrinsic rewards using that prior,
      * trains the prior with ``prior_trainer`` on replay-buffer data, and
      * optionally attaches importance weights to policy batches that
        re-weight from the prior/epsilon mixture towards an assumed
        Uniform(-1, 1)^latent_dim latent distribution.
    """

    def __init__(
            self,
            latent_prior,
            prior_trainer,
            prior_train_steps=32,
            prior_batch_size=256,
            prior_type='start',
            num_unif_train_calls=0,
            epsilon_greedy=0,
            use_importance_weighting=False,
            importance_weighting_clip=10,
            *args,
            **kwargs,
    ):
        """
        Args:
            latent_prior: Module called on an observation tensor whose first
                output is a batch of latents; must also expose
                ``get_log_probs(obs, latents)``.
            prior_trainer: Trainer for ``latent_prior``; must expose
                ``train_from_torch``, ``get_diagnostics``, ``networks``,
                ``get_snapshot`` and ``end_epoch``.
            prior_train_steps: Prior gradient steps per (non-skipped)
                ``train_from_torch`` call.
            prior_batch_size: Batch size for prior training.
            prior_type: ``'start'`` (delegate path generation to the parent)
                or ``'every'`` (roll out with the latent prior).
            num_unif_train_calls: Number of initial training calls during
                which the parent's latent sampler is used instead of the prior.
            epsilon_greedy: Per-sample probability of replacing a prior
                latent with one from the parent's sampler.
            use_importance_weighting: Attach importance weights to policy
                training batches.
            importance_weighting_clip: Weights are clipped to
                ``[1 / clip, clip]``.
            *args, **kwargs: Forwarded to ``DADSModelTrainer``.
        """
        super().__init__(*args, **kwargs)

        self.latent_prior = latent_prior
        self.prior_trainer = prior_trainer
        self.prior_train_steps = prior_train_steps
        self.prior_batch_size = prior_batch_size
        self.prior_type = prior_type
        self.num_unif_train_calls = num_unif_train_calls
        self.epsilon_greedy = epsilon_greedy
        self.use_importance_weighting = use_importance_weighting
        self.importance_weighting_clip = importance_weighting_clip

    def generate_latents(self, obs):
        """Sample one latent per row of ``obs`` (numpy in, numpy out).

        During the first ``num_unif_train_calls`` training calls the parent's
        sampler is used. Afterwards latents come from ``latent_prior``; if
        ``epsilon_greedy > 0`` each latent is independently replaced, with
        that probability, by a latent from the parent's sampler.
        """
        if self._train_calls < self.num_unif_train_calls:
            return super().generate_latents(obs)

        # Pure sampling: no autograd graph is needed.
        with torch.no_grad():
            latents, *_ = self.latent_prior(ptu.from_numpy(obs))
        latents = ptu.get_numpy(latents)

        if self.epsilon_greedy > 0:
            eps_replace = np.random.uniform(0, 1, size=latents.shape[0]) < self.epsilon_greedy
            # Avoid calling the parent sampler with an empty batch.
            if eps_replace.any():
                latents[eps_replace] = super().generate_latents(obs[eps_replace])
        return latents

    def generate_paths(self, **kwargs):
        """Generate model rollouts according to ``self.prior_type``.

        ``'start'`` delegates to the parent. ``'every'`` rolls out with
        ``rollout_functions.policy_latent_prior_with_disagreement`` using
        ``latent_prior`` (presumably re-sampling latents during the rollout;
        see that function). It reads the ``dynamics_model``,
        ``control_policy``, ``start_states`` and ``rollout_len`` keys of
        ``kwargs``.

        Raises:
            NotImplementedError: for any other ``prior_type``.
        """
        if self.prior_type == 'start':
            return super().generate_paths(**kwargs)
        if self.prior_type == 'every':
            return rollout_functions.policy_latent_prior_with_disagreement(
                dynamics_model=kwargs['dynamics_model'],
                policy=kwargs['control_policy'],
                latent_prior=self.latent_prior,
                start_states=kwargs['start_states'],
                max_path_length=kwargs['rollout_len'],
                terminal_cutoff=None,
            )
        raise NotImplementedError(f'Unknown prior_type: {self.prior_type!r}')

    def calculate_intrinsic_rewards(self, states, next_states, latents, *args, **kwargs):
        """Compute contrastive-empowerment rewards with ``latent_prior`` as prior.

        Inputs are truncated to the first ``restrict_input_size`` features
        when that is positive. NaN rewards are set to -10 and all rewards
        are clipped to [-10, 10] for numerical stability.

        Returns:
            ``(rewards, (logp, logp_altz, denom), reward_diagnostics)`` as
            produced by ``calculate_contrastive_empowerment``.
        """
        if self.restrict_input_size > 0:
            states = states[:, :self.restrict_input_size]
            next_states = next_states[:, :self.restrict_input_size]
        rewards, (logp, logp_altz, denom), reward_diagnostics = calculate_contrastive_empowerment(
            self.discriminator,
            states,
            next_states,
            latents,
            num_prior_samples=self.num_prior_samples,
            obs_mean=self.obs_mean,
            obs_std=self.obs_std,
            return_diagnostics=True,
            prior=self.latent_prior,
        )
        rewards[rewards != rewards] = -10  # NaN is the only value != itself
        rewards = np.clip(rewards, -10, 10)  # numerical stability
        return rewards, (logp, logp_altz, denom), reward_diagnostics

    def train_from_torch(self, batch):
        """Run the parent's training step, then (every ``train_every`` calls)
        train the latent prior on replay-buffer samples."""
        super().train_from_torch(batch)

        if self._train_calls % self.train_every > 0:
            return

        size = self._cur_replay_size
        for _ in range(self.prior_train_steps):
            # The prior is trained with latents in the "actions" slot.
            prior_batch = ppp.sample_batch(
                self.prior_batch_size,
                observations=self._obs[:size],
                next_observations=self._next_obs[:size],
                actions=self._latents[:size],
                rewards=self._rewards[:size],
            )
            self.prior_trainer.train_from_torch(ptu.np_to_pytorch_batch(prior_batch))

        for k, v in self.prior_trainer.get_diagnostics().items():
            self.eval_statistics['prior_trainer/' + k] = v

    def _compute_importance_weights(self, size):
        """Return clipped importance weights (numpy) for the first ``size`` entries.

        The behaviour density is modelled as the mixture
        ``eps * U + (1 - eps) * prior(z | s)`` and the target density as
        ``U = 0.5 ** latent_dim`` (Uniform(-1, 1)^latent_dim, an assumption).
        Weights ``U / mixture`` are clipped to
        ``[1 / importance_weighting_clip, importance_weighting_clip]``.
        """
        unif_density = 0.5 ** self.latent_dim  # assume Uniform(-1, 1)^n
        # Evaluated over the whole filled buffer: no autograd graph needed.
        with torch.no_grad():
            prior_log_probs = self.latent_prior.get_log_probs(
                ptu.from_numpy(self._obs[:size]),
                ptu.from_numpy(self._latents[:size]),
            )
            mixture_probs = self.epsilon_greedy * unif_density + \
                (1 - self.epsilon_greedy) * prior_log_probs.exp()
            mixture_log_probs = torch.log(mixture_probs + 1e-6)
            unif_lp = unif_density * ptu.ones(*mixture_log_probs.shape)
            iw = torch.exp(torch.log(unif_lp) - mixture_log_probs)
            iw = torch.clamp(iw, 1 / self.importance_weighting_clip,
                             self.importance_weighting_clip)
        return ptu.get_numpy(iw)

    def _recent_epoch_indices(self, buffer_size):
        """Return indices of the (up to) ``_epoch_size`` most recent buffer rows.

        Treats the first ``buffer_size`` rows as a circular buffer with
        ``self._ptr`` as the next write position (assumption — confirm
        against the parent), i.e. returns ``ptr - n, ..., ptr - 1`` modulo
        ``buffer_size`` with ``n = min(_epoch_size, buffer_size)``. This
        yields exactly ``n`` distinct rows, including when the pointer has
        wrapped around.
        """
        n = min(self._epoch_size, buffer_size)
        if n <= 0:
            return np.arange(0)
        return (self._ptr - n + np.arange(n)) % buffer_size

    def train_from_buffer(self, reward_kwargs=None):
        """Optionally relabel intrinsic rewards, then train the policy.

        1. If ``relabel_rewards``: recompute intrinsic rewards (an
           approximate lower bound to I(s'; z | s)) for the filled replay
           buffer, post-process them, and store them in ``self._rewards``.
        2. Run ``num_policy_updates`` policy updates on (state || latent)
           observations, with importance weights if enabled.
        3. Update ``eval_statistics`` with diagnostics.

        Args:
            reward_kwargs: Forwarded to ``calculate_intrinsic_rewards`` and
                ``reward_postprocessing``.
        """
        size = self._cur_replay_size

        # Compute intrinsic reward: approximate lower bound to I(s'; z | s)
        if self.relabel_rewards:
            rewards, (logp, logp_altz, denom), reward_diagnostics = self.calculate_intrinsic_rewards(
                self._obs[:size],
                self._next_obs[:size],
                self._latents[:size],
                reward_kwargs=reward_kwargs,
            )
            orig_rewards = rewards.copy()
            rewards, postproc_dict = self.reward_postprocessing(rewards, reward_kwargs=reward_kwargs)
            reward_diagnostics.update(postproc_dict)
            self._rewards[:size] = np.expand_dims(rewards, axis=-1)

            gt.stamp('intrinsic reward calculation', unique=False)

        # Train policy. Slice before concatenating so only the filled part of
        # the buffer is copied.
        latents = self._latents[:size]
        sample_kwargs = dict(
            observations=np.concatenate([self._obs[:size], latents], axis=-1),
            next_observations=np.concatenate([self._true_next_obs[:size], latents], axis=-1),
            actions=self._actions[:size],
            rewards=self._rewards[:size],
        )
        iw = None
        if self.use_importance_weighting:
            iw = self._compute_importance_weights(size)
            sample_kwargs['importance_weights'] = iw

        for _ in range(self.num_policy_updates):
            batch = ppp.sample_batch(self.policy_batch_size, **sample_kwargs)
            self.policy_trainer.train_from_torch(ptu.np_to_pytorch_batch(batch))

        gt.stamp('policy training', unique=False)

        # Diagnostics. The update flag is intentionally not cleared here.
        if self._need_to_update_eval_statistics:
            self.eval_statistics.update(self.policy_trainer.eval_statistics)

            if self.relabel_rewards:
                self.eval_statistics.update(reward_diagnostics)
                for name, values in (
                        ('Discriminator Log Pis', logp),
                        ('Discriminator Alt Log Pis', logp_altz),
                        ('Intrinsic Reward Denominator', denom),
                ):
                    self.eval_statistics.update(create_stats_ordered_dict(name, values))

                inds = self._recent_epoch_indices(len(rewards))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'Intrinsic Rewards (Original)',
                    orig_rewards[inds],
                ))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'Intrinsic Rewards (Processed)',
                    rewards[inds],
                ))

            # Importance weights do not depend on reward relabeling.
            if iw is not None:
                self.eval_statistics.update(create_stats_ordered_dict(
                    'Prior Importance Weights',
                    iw,
                ))

        self._n_train_steps_total += 1

    @property
    def networks(self):
        """Networks of the prior trainer, the policy trainer, the discriminator
        and the latent prior (in that order)."""
        return self.prior_trainer.networks + self.policy_trainer.networks + [
            self.discriminator, self.latent_prior,
        ]

    def get_snapshot(self):
        """Parent snapshot plus the latent prior and the prior trainer's
        snapshot entries (prefixed with ``'prior_trainer/'``)."""
        snapshot = super().get_snapshot()
        snapshot['latent_prior'] = self.latent_prior

        for k, v in self.prior_trainer.get_snapshot().items():
            snapshot['prior_trainer/' + k] = v

        return snapshot

    def end_epoch(self, epoch):
        """Propagate end-of-epoch to the parent and to the prior trainer."""
        super().end_epoch(epoch)
        self.prior_trainer.end_epoch(epoch)
