from collections import OrderedDict
import gtimer as gt

import numpy as np
import torch

import lfrl.torch.pytorch_util as ptu
from lfrl.core.rl_algorithms.torch_rl_algorithm import TorchTrainer
from lfrl.samplers import model_policy_rollout_with_disagreement_torch

def always_one(train_steps):
    """Constant rollout-length schedule that returns 1 for every step.

    ``train_steps`` is accepted only so this function fits the
    ``rollout_len_func`` callable slot; its value is ignored.
    """
    del train_steps  # intentionally unused
    return 1

class MBPOTrainer(TorchTrainer):
    """Trainer that mixes real and dynamics-model-generated data (MBPO style).

    Every ``rollout_generation_freq`` training steps it rolls out
    ``policy_trainer.policy`` inside ``dynamics_model`` from start states
    sampled from ``replay_buffer`` and writes the resulting paths into
    ``generated_data_buffer``. On every step it then runs
    ``num_policy_updates`` updates of ``policy_trainer`` on batches that
    combine real and generated transitions.

    The dynamics model is only queried here; it is not trained by this class.
    """

    def __init__(
            self,
            policy_trainer,
            dynamics_model,  # MBPOTrainer not responsible for training this
            replay_buffer,
            generated_data_buffer,
            num_model_rollouts=400,
            rollout_generation_freq=250,
            num_policy_updates=20,
            rollout_len_func=always_one,
            rollout_batch_size=int(1e3),
            max_rollout_len=100,
            disagreement_threshold=0.1,
            real_data_pct=0.05,
            **kwargs
    ):
        """
        Args:
            policy_trainer: Trainer whose ``policy`` generates rollouts and
                which is updated via ``train_from_torch`` on mixed batches.
            dynamics_model: Model used for rollouts (not trained here).
            replay_buffer: Buffer of real transitions; supplies rollout
                start states and the real part of each training batch.
            generated_data_buffer: Buffer that receives model-generated paths.
            num_model_rollouts: Generated-sample budget per training step.
                Each refresh targets
                ``rollout_generation_freq * num_model_rollouts`` samples,
                capped by the generated buffer's capacity.
            rollout_generation_freq: Generate new rollouts every this many
                calls to ``train_from_torch``.
            num_policy_updates: Policy updates per ``train_from_torch`` call.
            rollout_len_func: Stored as-is; not read anywhere in this class.
            rollout_batch_size: Stored as-is. The chunk size actually used
                for rollouts comes from ``_rollout_batch_size``.
            max_rollout_len: Passed as ``max_path_length`` to the rollout
                helper.
            disagreement_threshold: Passed through to the rollout helper.
            real_data_pct: Fraction of each training batch drawn from
                ``replay_buffer``; the rest comes from
                ``generated_data_buffer``.
            **kwargs: Accepted and ignored.
        """
        super().__init__()

        self.policy_trainer = policy_trainer
        self.policy = policy_trainer.policy
        self.dynamics_model = dynamics_model
        self.replay_buffer = replay_buffer
        self.generated_data_buffer = generated_data_buffer

        self.num_model_rollouts = num_model_rollouts
        self.rollout_generation_freq = rollout_generation_freq
        self.num_policy_updates = num_policy_updates
        self.rollout_len_func = rollout_len_func
        # Previously overwritten with a hard-coded debug value (5); keep the
        # caller's value so the attribute reflects the configuration.
        self.rollout_batch_size = rollout_batch_size
        self.max_rollout_len = max_rollout_len
        self.disagreement_threshold = disagreement_threshold
        self.real_data_pct = real_data_pct

        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True
        self.eval_statistics = OrderedDict()
        # (return, terminated, length) for each path added during the most
        # recent rollout generation. Kept on the instance so statistics can
        # be reported on steps that do not generate rollouts.
        self._rollout_summaries = []

    def train_from_torch(self, batch):
        """Run one MBPO training step.

        Args:
            batch: Batch sampled by the outer algorithm. Only its size
                (``batch['observations'].shape[0]``) is used. Fresh real and
                generated batches are drawn from the buffers instead (a
                slight violation of the trainer abstraction).
        """
        if self._n_train_steps_total % self.rollout_generation_freq == 0:
            self._generate_model_rollouts()
            gt.stamp('generating rollouts', unique=False)

        self._update_policy(batch['observations'].shape[0])

        # Report only once summaries exist; otherwise keep the flag set and
        # retry on a later step instead of crashing or logging NaNs.
        if self._need_to_update_eval_statistics and self._rollout_summaries:
            self._need_to_update_eval_statistics = False
            self._record_rollout_statistics()

        self._n_train_steps_total += 1

    def _rollout_batch_size(self):
        """Return how many start states to roll out per model call."""
        # Getting this right improves computation time but does not affect
        # learning performance.
        if self._n_train_steps_total < 1000:
            return 1024
        return 512

    def _generate_model_rollouts(self):
        """Fill ``generated_data_buffer`` with dynamics-model rollouts.

        Adds paths until the target sample count is reached, then stores
        summary statistics for the paths that were added.
        """
        total_samples = min(
            self.rollout_generation_freq * self.num_model_rollouts,
            self.generated_data_buffer.max_replay_buffer_size(),
        )
        rollout_batch_size = self._rollout_batch_size()

        summaries = []
        num_samples = 0
        while num_samples < total_samples:
            batch_samples = min(rollout_batch_size, total_samples - num_samples)
            real_batch = self.replay_buffer.random_batch(batch_samples)
            start_states = real_batch['observations']

            with torch.no_grad():
                paths = model_policy_rollout_with_disagreement_torch(
                    self.dynamics_model,
                    self.policy_trainer.policy,
                    start_states,
                    disagreement_threshold=self.disagreement_threshold,
                    max_path_length=self.max_rollout_len,
                )

            samples_before = num_samples
            for path in paths:
                self.generated_data_buffer.add_path(path)
                num_samples += len(path['observations'])
                summaries.append((
                    np.sum(path['rewards']),
                    np.sum(path['terminals']) > 0,
                    len(path['rewards']),
                ))
                if num_samples >= total_samples:
                    break

            if num_samples == samples_before:
                # A whole batch of start states produced no transitions;
                # stop rather than risk looping indefinitely.
                break

        self._rollout_summaries = summaries

    def _update_policy(self, batch_size):
        """Run ``num_policy_updates`` updates on mixed real/generated batches."""
        n_real_data = int(self.real_data_pct * batch_size)
        n_generated_data = batch_size - n_real_data

        for _ in range(self.num_policy_updates):
            mixed_batch = self.replay_buffer.random_batch(n_real_data)
            generated_batch = self.generated_data_buffer.random_batch(
                n_generated_data)

            for k in ('rewards', 'terminals', 'observations',
                      'actions', 'next_observations'):
                mixed_batch[k] = ptu.from_numpy(np.concatenate(
                    (mixed_batch[k], generated_batch[k]), axis=0))

            self.policy_trainer.train_from_torch(mixed_batch)

    def _record_rollout_statistics(self):
        """Write rollout summary statistics into ``eval_statistics``."""
        returns, terminated, rollout_lens = zip(*self._rollout_summaries)

        average_return = np.mean(returns)
        terminated_pct = np.mean(terminated)
        avg_rollout_len = np.mean(rollout_lens)

        self.eval_statistics['MBPO Rollout Length'] = avg_rollout_len
        self.eval_statistics['MBPO Rollout Return Average'] = average_return
        self.eval_statistics['MBPO Rollout Return Per Timestep'] = \
            average_return / avg_rollout_len
        self.eval_statistics['MBPO Rollout Terminated Percent'] = terminated_pct

    def get_diagnostics(self):
        """Merge the policy trainer's statistics into ours and return them.

        Note: this updates ``self.eval_statistics`` in place.
        """
        self.eval_statistics.update(self.policy_trainer.eval_statistics)
        return self.eval_statistics

    def end_epoch(self, epoch):
        """Request a statistics refresh and forward ``end_epoch``."""
        self._need_to_update_eval_statistics = True
        self.policy_trainer.end_epoch(epoch)

    @property
    def networks(self):
        """Networks owned by the wrapped policy trainer."""
        return self.policy_trainer.networks

    def get_snapshot(self):
        """Return the policy trainer's snapshot plus ``dynamics_model``."""
        mbpo_snapshot = dict(
            dynamics_model=self.dynamics_model
        )
        mbpo_snapshot.update(self.policy_trainer.get_snapshot())
        return mbpo_snapshot

