from collections import OrderedDict
import gtimer as gt
import time

import numpy as np
import torch

import lfrl.torch.pytorch_util as ptu
from lfrl.torch.pytorch_util import np_to_pytorch_batch
from lfrl.core.rl_algorithms.torch_rl_algorithm import TorchTrainer
from lfrl.models.uncertainty.rnd import RandomNetworkDistillation
from lfrl.samplers import model_policy_rollout_torch_rnd

def always_one(train_steps):
    """Constant rollout-length schedule.

    Args:
        train_steps: current number of training steps (ignored).

    Returns:
        Always ``1``, regardless of ``train_steps``.
    """
    constant_len = 1
    return constant_len

class MBPOTrainer(TorchTrainer):
    """MBPO-style trainer that mixes real data with dynamics-model rollouts.

    Each call to ``train_from_torch``:
      1. fits input stats for, and trains, an RND model on transitions read
         from ``replay_buffer``;
      2. rolls the policy out inside ``dynamics_model`` from real start
         states (via ``model_policy_rollout_torch_rnd``, which also receives
         the RND model and a threshold) and adds the resulting paths to
         ``generated_data_buffer``;
      3. runs ``num_policy_updates`` updates of ``policy_trainer`` on batches
         concatenating real and generated transitions.

    The dynamics model itself is not trained here.
    """

    def __init__(
            self,
            policy_trainer,
            dynamics_model, # MBPOTrainer not responsible for training this
            replay_buffer,
            generated_data_buffer,
            num_model_rollouts=400,
            rollout_generation_freq=250,
            num_policy_updates=20,
            rollout_len_func=always_one,
            rollout_batch_size=int(1e3),
            real_data_pct=0.05,
            rnd_grad_steps=1,
            max_rollout_len=200,
            **kwargs
    ):
        """
        Args:
            policy_trainer: trainer exposing ``policy``, ``train_from_torch``,
                ``eval_statistics``, ``end_epoch``, ``networks`` and
                ``get_snapshot``.
            dynamics_model: model exposing ``obs_dim`` and ``plan_dim``; used
                for rollouts only.
            replay_buffer: real-data buffer providing ``get_transitions()``
                and ``random_batch(n)``.
            generated_data_buffer: buffer for model rollouts providing
                ``add_path``, ``random_batch(n)`` and
                ``max_replay_buffer_size()``.
            num_model_rollouts: target number of generated transitions per
                call, capped by the generated buffer's capacity.
            rollout_generation_freq: number of most recent real transitions
                used for each round of RND training (despite the name, it is
                used as a slice length here).
            num_policy_updates: policy-trainer updates per call.
            rollout_len_func: stored but currently unused.
            rollout_batch_size: maximum number of start states per rollout
                batch.
            real_data_pct: fraction of each policy batch drawn from real data.
            rnd_grad_steps: RND ``train`` calls per iteration.
            max_rollout_len: ``max_path_length`` passed to the rollout after
                the first training step (the very first step uses 1).
            **kwargs: accepted and ignored.
        """
        super().__init__()

        self.policy_trainer = policy_trainer
        self.policy = policy_trainer.policy
        self.dynamics_model = dynamics_model
        self.replay_buffer = replay_buffer
        self.generated_data_buffer = generated_data_buffer

        self.num_model_rollouts = num_model_rollouts
        self.rollout_generation_freq = rollout_generation_freq
        self.num_policy_updates = num_policy_updates
        self.rollout_len_func = rollout_len_func
        self.rollout_batch_size = rollout_batch_size
        self.real_data_pct = real_data_pct
        self.max_rollout_len = max_rollout_len

        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True
        self.eval_statistics = OrderedDict()

        self.rnd = RandomNetworkDistillation(
            ensemble_size=3,
            input_size=self.dynamics_model.obs_dim + self.dynamics_model.plan_dim,
            hidden_sizes=[64, 64],
            init_w=1,
        )
        self.rnd_grad_steps = rnd_grad_steps

        self.start_time = time.time()

    def train_from_torch(self, batch):
        """Run one iteration: RND update, model rollouts, policy updates.

        Args:
            batch: batch supplied by the training loop. Only its size (the
                leading dimension of ``batch['observations']``) is used; the
                actual training data is resampled from the buffers.
        """
        # Generate data using the dynamics model.
        rnd_threshold = self._train_rnd()
        gt.stamp('training rnd', unique=False)

        paths, rollout_lens = self._generate_rollouts(rnd_threshold)
        gt.stamp('generating rollouts', unique=False)

        # nan (not an exception) when no rollouts were generated this call.
        avg_rollout_len = np.mean(rollout_lens) if rollout_lens else np.nan

        if self._n_train_steps_total % 50 == 0:
            print('%d | avg rollout len: %.4f | %.4f s' % \
                (self._n_train_steps_total, avg_rollout_len, time.time() - self.start_time))
            self._evaluate_rnd()

        gt.stamp('training rnd', unique=False)

        # Update the policy on both real and generated data.
        batch_size = batch['observations'].shape[0]
        self._update_policy(batch_size)

        # Save statistics for eval (once per epoch, see end_epoch).
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            self._record_rollout_statistics(paths, avg_rollout_len)

        self._n_train_steps_total += 1

    def _train_rnd(self):
        """Fit RND input stats, train RND, and return the rollout threshold."""
        data = self.replay_buffer.get_transitions()
        # The first obs_dim + plan_dim columns form the RND input; presumably
        # observations followed by actions -- confirm against the buffer layout.
        x = data[:, :self.dynamics_model.obs_dim + self.dynamics_model.plan_dim]
        self.rnd.fit_input_stats(x)

        # Train only on the most recent `rollout_generation_freq` rows.
        sa = ptu.from_numpy(x[-self.rollout_generation_freq:])
        for _ in range(self.rnd_grad_steps):
            self.rnd.train(sa)

        # Threshold = mean RND prediction over the most recent 4096 rows.
        with torch.no_grad():
            rnd_pred = self.rnd.get_prediction(ptu.from_numpy(x[-4096:]))
            rnd_threshold = np.mean(ptu.get_numpy(rnd_pred.mean()))
        return rnd_threshold

    def _generate_rollouts(self, rnd_threshold):
        """Roll the policy out in the model and add every path to the buffer.

        Returns:
            ``(paths, rollout_lens)``: every path produced in this call, and
            the number of observations in each.
        """
        total_samples = min(
            self.num_model_rollouts,
            self.generated_data_buffer.max_replay_buffer_size()
        )
        # The very first call uses length-1 rollouts.
        max_path_length = self.max_rollout_len if self._n_train_steps_total > 0 else 1

        all_paths, rollout_lens = [], []
        num_samples = 0
        # Compare against the *capped* target. Comparing against
        # num_model_rollouts would request <= 0 start states forever once the
        # buffer capacity is the binding limit.
        while num_samples < total_samples:
            batch_samples = min(self.rollout_batch_size, total_samples - num_samples)
            real_batch = self.replay_buffer.random_batch(batch_samples)
            start_states = real_batch['observations']

            with torch.no_grad():
                paths = model_policy_rollout_torch_rnd(
                    self.dynamics_model,
                    self.policy_trainer.policy,
                    self.rnd,
                    start_states,
                    rnd_threshold,
                    max_path_length=max_path_length,
                )

            new_samples = 0
            for path in paths:
                self.generated_data_buffer.add_path(path)
                path_len = len(path['observations'])
                rollout_lens.append(path_len)
                all_paths.append(path)
                new_samples += path_len

            if new_samples == 0:
                # No progress is possible; stop instead of looping forever.
                break
            num_samples += new_samples

        return all_paths, rollout_lens

    def _evaluate_rnd(self, test_batch_size=4096):
        """Record and print mean RND prediction on generated vs. real pairs."""
        with torch.no_grad():
            for stat_key, label, buffer in (
                    ('RND Prediction - Fake', 'fake', self.generated_data_buffer),
                    ('RND Prediction - True', 'real', self.replay_buffer),
            ):
                torch_batch = np_to_pytorch_batch(buffer.random_batch(test_batch_size))
                sa = torch.cat(
                    (torch_batch['observations'], torch_batch['actions']), dim=-1)
                rnd_pred = self.rnd.get_prediction(sa)
                pred = np.mean(ptu.get_numpy(rnd_pred.mean()))
                self.eval_statistics[stat_key] = pred
                print('%s: %.4f' % (label, pred))

    def _update_policy(self, batch_size):
        """Run `num_policy_updates` policy steps on mixed real/generated batches."""
        n_real_data = int(self.real_data_pct * batch_size)
        n_generated_data = batch_size - n_real_data

        for _ in range(self.num_policy_updates):
            mixed_batch = self.replay_buffer.random_batch(n_real_data)
            generated_batch = self.generated_data_buffer.random_batch(
                n_generated_data)

            for k in ('rewards', 'terminals', 'observations',
                      'actions', 'next_observations'):
                mixed_batch[k] = ptu.from_numpy(
                    np.concatenate((mixed_batch[k], generated_batch[k]), axis=0))

            self.policy_trainer.train_from_torch(mixed_batch)

    def _record_rollout_statistics(self, paths, avg_rollout_len):
        """Store rollout length/return/termination stats over all `paths`."""
        self.eval_statistics['MBPO Rollout Length'] = avg_rollout_len
        if not paths:
            # Nothing to average; avoid division by zero.
            return

        rewards = [np.mean(path['rewards']) for path in paths]
        returns = [np.sum(path['rewards']) for path in paths]
        terminated = [np.sum(path['terminals']) > 0 for path in paths]

        self.eval_statistics['MBPO Rollout Return Average'] = sum(returns) / len(paths)
        self.eval_statistics['MBPO Rollout Return Per Timestep'] = np.mean(rewards)
        self.eval_statistics['MBPO Rollout Terminated Percent'] = sum(terminated) / len(paths)

    def get_diagnostics(self):
        """Merge the policy trainer's statistics into ours and return them."""
        self.eval_statistics.update(self.policy_trainer.eval_statistics)
        return self.eval_statistics

    def end_epoch(self, epoch):
        """Re-arm statistics collection and forward end_epoch to the policy trainer."""
        self._need_to_update_eval_statistics = True
        self.policy_trainer.end_epoch(epoch)

    @property
    def networks(self):
        """Networks of the wrapped policy trainer."""
        return self.policy_trainer.networks

    def get_snapshot(self):
        """Return dynamics model, RND model and the policy trainer's snapshot."""
        mbpo_snapshot = dict(
            dynamics_model=self.dynamics_model,
            rnd=self.rnd,
        )
        mbpo_snapshot.update(self.policy_trainer.get_snapshot())
        return mbpo_snapshot

