from collections import OrderedDict
import gtimer as gt

import numpy as np
import torch

import lfrl.torch.pytorch_util as ptu
from lfrl.core.rl_algorithms.torch_rl_algorithm import TorchTrainer
from lfrl.samplers.utils.torch_rollout_functions import model_policy_rollout_torch


def always_one(train_steps):
    """Constant rollout-length schedule.

    Args:
        train_steps: number of training steps taken so far (unused).

    Returns:
        Always 1, regardless of training progress.
    """
    del train_steps  # schedule is constant; argument kept for the callable interface
    return 1


class MBPOTrainer(TorchTrainer):
    """Model-based policy optimization (MBPO) trainer.

    Wraps a model-free ``policy_trainer``. Every ``rollout_generation_freq``
    training steps it rolls the current policy out inside ``dynamics_model``.
    Start states are sampled from the real ``replay_buffer``, and the resulting
    paths are stored in ``generated_data_buffer``. On every step it then runs
    ``num_policy_updates`` updates of ``policy_trainer`` on batches that mix
    real and generated transitions.

    The dynamics model is only sampled from here; this class does not train it.
    """

    # Batch keys merged from the real and the generated buffers for each update.
    _MIXED_BATCH_KEYS = (
        'rewards', 'terminals', 'observations', 'actions', 'next_observations',
    )

    def __init__(
            self,
            policy_trainer,
            dynamics_model,  # MBPOTrainer not responsible for training this
            replay_buffer,
            generated_data_buffer,
            num_model_rollouts=400,
            rollout_generation_freq=250,
            num_policy_updates=20,
            rollout_len_func=always_one,
            rollout_batch_size=int(1e3),
            real_data_pct=0.05,
            **kwargs
    ):
        """Store components and hyperparameters.

        Args:
            policy_trainer: model-free trainer that is updated on the mixed
                batches. Its ``policy`` is used for the model rollouts.
            dynamics_model: learned model whose ``sample(state_actions)`` is
                called by ``predict_transition``. It is not trained here.
            replay_buffer: buffer of real transitions. Start states and the
                real share of each policy batch are sampled from it.
            generated_data_buffer: buffer that receives model-generated paths.
            num_model_rollouts: rollouts budgeted per training step. One
                generation round produces up to
                ``rollout_generation_freq * num_model_rollouts * rollout_len``
                transitions, capped at the generated buffer's capacity.
            rollout_generation_freq: generate new rollouts every this many
                training steps.
            num_policy_updates: number of ``policy_trainer`` updates per call
                to ``train_from_torch``.
            rollout_len_func: maps the current train-step count to the model
                rollout length.
            rollout_batch_size: maximum number of start states passed to a
                single rollout call.
            real_data_pct: fraction of each policy batch drawn from
                ``replay_buffer``. The rest is drawn from
                ``generated_data_buffer``.
            **kwargs: accepted and ignored.
        """
        super().__init__()

        self.policy_trainer = policy_trainer
        self.policy = policy_trainer.policy
        self.dynamics_model = dynamics_model
        self.replay_buffer = replay_buffer
        self.generated_data_buffer = generated_data_buffer

        self.num_model_rollouts = num_model_rollouts
        self.rollout_generation_freq = rollout_generation_freq
        self.num_policy_updates = num_policy_updates
        self.rollout_len_func = rollout_len_func
        self.rollout_batch_size = rollout_batch_size
        self.real_data_pct = real_data_pct

        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True
        self.eval_statistics = OrderedDict()
        # Model output column 1 is thresholded at this value to form done flags.
        self.terminal_cutoff = 0.5
        # Statistics of the most recent rollout-generation round. They are kept
        # on the instance because eval statistics may be requested on a step
        # that did not generate rollouts. None until the first round runs.
        self._last_rollout_stats = None

    def predict_transition(self, state_actions):
        """Sample a transition from the dynamics model.

        NaN entries in the model output are replaced by 0, with a warning.

        Args:
            state_actions: model input passed straight to
                ``dynamics_model.sample``. Presumably concatenated
                (state, action) rows; confirm against the rollout function.

        Returns:
            Tuple ``(r, d, obs_delta)`` sliced from the 2-D model output:
            column 0 as the reward, column 1 thresholded at
            ``terminal_cutoff`` and cast to float as the done flag, and the
            remaining columns as the observation delta.
        """
        transitions = self.dynamics_model.sample(state_actions)
        nan_mask = torch.isnan(transitions)
        if nan_mask.any():
            print('warning: nan transitions')
            transitions[nan_mask] = 0
        r = transitions[:, :1]
        d = (transitions[:, 1:2] > self.terminal_cutoff).float()
        obs_delta = transitions[:, 2:]
        return r, d, obs_delta

    def _generate_rollouts(self, rollout_len):
        """Fill ``generated_data_buffer`` with model rollouts of ``rollout_len``.

        Returns:
            OrderedDict of rollout statistics over every path generated in this
            round. The return and termination entries are present only if at
            least one path was produced.
        """
        total_samples = min(
            self.rollout_generation_freq * self.num_model_rollouts * rollout_len,
            self.generated_data_buffer.max_replay_buffer_size()
        )

        returns, terminated = [], []
        num_samples = 0
        while num_samples < total_samples:
            batch_samples = min(self.rollout_batch_size, total_samples - num_samples)
            real_batch = self.replay_buffer.random_batch(batch_samples)
            start_states = real_batch['observations']

            with torch.no_grad():
                paths = model_policy_rollout_torch(
                    self.dynamics_model,
                    self.policy_trainer.policy,
                    start_states,
                    max_path_length=rollout_len,
                    predict_transition=self.predict_transition,
                )

            round_samples = 0
            for path in paths:
                self.generated_data_buffer.add_path(path)
                round_samples += len(path['observations'])
                returns.append(np.sum(path['rewards']))
                terminated.append(np.sum(path['terminals']) > 0)

            if round_samples == 0:
                # The rollout produced nothing; stop rather than loop forever.
                break
            num_samples += round_samples

        stats = OrderedDict()
        stats['MBPO Rollout Length'] = rollout_len
        if returns:
            average_return = sum(returns) / len(returns)
            stats['MBPO Rollout Return Average'] = average_return
            stats['MBPO Rollout Return Per Timestep'] = average_return / rollout_len
            stats['MBPO Rollout Terminated Percent'] = sum(terminated) / len(terminated)
        return stats

    def _update_policy(self, batch_size):
        """Run ``num_policy_updates`` policy updates on mixed real/generated batches."""
        n_real_data = int(self.real_data_pct * batch_size)
        n_generated_data = batch_size - n_real_data

        for _ in range(self.num_policy_updates):
            mixed_batch = self.replay_buffer.random_batch(n_real_data)
            generated_batch = self.generated_data_buffer.random_batch(
                n_generated_data)

            # Real rows first, then generated rows, converted to torch tensors.
            for k in self._MIXED_BATCH_KEYS:
                mixed_batch[k] = ptu.from_numpy(
                    np.concatenate((mixed_batch[k], generated_batch[k]), axis=0))

            self.policy_trainer.train_from_torch(mixed_batch)

    def train_from_torch(self, batch):
        """Run one MBPO training step.

        Args:
            batch: a training batch. Only its size
                (``batch['observations'].shape[0]``) is used. Its contents are
                ignored (a slight violation of the trainer abstraction), because
                fresh real and generated data are sampled for every update.
        """
        if self._n_train_steps_total % self.rollout_generation_freq == 0:
            rollout_len = self.rollout_len_func(self._n_train_steps_total)
            self._last_rollout_stats = self._generate_rollouts(rollout_len)
            gt.stamp('generating rollouts', unique=False)

        self._update_policy(batch['observations'].shape[0])

        # Report statistics from the latest generation round, which need not
        # have happened on this step.
        if (self._need_to_update_eval_statistics
                and self._last_rollout_stats is not None):
            self._need_to_update_eval_statistics = False
            self.eval_statistics.update(self._last_rollout_stats)

        self._n_train_steps_total += 1

    def get_diagnostics(self):
        """Merge the policy trainer's statistics into ours and return the dict."""
        self.eval_statistics.update(self.policy_trainer.eval_statistics)
        return self.eval_statistics

    def end_epoch(self, epoch):
        """Re-arm eval-statistics collection and forward to the policy trainer."""
        self._need_to_update_eval_statistics = True
        self.policy_trainer.end_epoch(epoch)

    @property
    def networks(self):
        """Networks of the wrapped policy trainer (the dynamics model is excluded)."""
        return self.policy_trainer.networks

    def get_snapshot(self):
        """Return the policy trainer's snapshot plus the dynamics model.

        The policy trainer's entries are merged last, so they win on key clashes.
        """
        mbpo_snapshot = dict(
            dynamics_model=self.dynamics_model
        )
        mbpo_snapshot.update(self.policy_trainer.get_snapshot())
        return mbpo_snapshot
