import abc

import gtimer as gt
from lfrl.core.rl_algorithms.offline.offline_rl_algorithm import OfflineRLAlgorithm


class OfflineMBRLAlgorithm(OfflineRLAlgorithm, metaclass=abc.ABCMeta):
    """Offline RL loop that fits a model once, then trains the policy.

    ``_train`` first fits ``model_trainer`` on the whole replay buffer, then
    runs the usual offline epochs. Each epoch does evaluation sampling and
    then ``num_train_loops_per_epoch * num_trains_per_train_loop`` policy
    updates on random batches drawn from the replay buffer.

    Args:
        model_trainer: Trainer for the model. It must provide
            ``train_from_buffer``, ``get_diagnostics``, ``get_snapshot`` and
            ``end_epoch``.
        model_batch_size: Stored as ``self.model_batch_size``. The methods of
            this class do not read it.
        model_train_freq: Stored as ``self.model_train_freq``. The methods of
            this class do not read it, because periodic model retraining
            inside the policy loop is not implemented here.
        *args: Forwarded to ``OfflineRLAlgorithm.__init__``.
        model_max_grad_steps: Keyword-only. Passed as ``max_grad_steps`` to
            ``model_trainer.train_from_buffer``. Defaults to 250000.
        model_max_epochs_since_update: Keyword-only. Passed as
            ``max_epochs_since_update`` to ``model_trainer.train_from_buffer``.
            Defaults to 50.
        **kwargs: Forwarded to ``OfflineRLAlgorithm.__init__``.
    """

    def __init__(
            self,
            model_trainer,
            model_batch_size,
            model_train_freq,
            *args,
            model_max_grad_steps=250000,
            model_max_epochs_since_update=50,
            **kwargs,
    ):
        super().__init__(*args, **kwargs)

        self.model_trainer = model_trainer
        self.model_batch_size = model_batch_size
        self.model_train_freq = model_train_freq
        # Budget for the single up-front model fit performed in _train.
        self.model_max_grad_steps = model_max_grad_steps
        self.model_max_epochs_since_update = model_max_epochs_since_update

    def _train(self):
        """Fit the model once on the replay buffer, then run policy epochs."""
        # The dataset is fixed (offline), so the model is fit a single time
        # before any policy training.
        self.model_trainer.train_from_buffer(
            self.replay_buffer,
            max_grad_steps=self.model_max_grad_steps,
            max_epochs_since_update=self.model_max_epochs_since_update,
        )
        gt.stamp('model training', unique=False)

        for epoch in gt.timed_for(
            range(self._start_epoch, self.num_epochs),
            save_itrs=True,
        ):
            self.eval_data_collector.collect_new_paths(
                self.max_path_length,
                self.num_eval_steps_per_epoch,
                discard_incomplete_paths=True,
            )
            gt.stamp('evaluation sampling')

            self.training_mode(True)
            for _ in range(self.num_train_loops_per_epoch):
                for _ in range(self.num_trains_per_train_loop):
                    train_data = self.replay_buffer.random_batch(self.batch_size)
                    self.trainer.train(train_data)
                    # unique=False: this stamp is hit many times per epoch.
                    gt.stamp('policy training', unique=False)
            self.training_mode(False)

            self._end_epoch(epoch)

    def _get_training_diagnostics_dict(self):
        """Extend the base diagnostics with the model trainer's diagnostics."""
        training_diagnostics = super()._get_training_diagnostics_dict()
        training_diagnostics['model_trainer'] = self.model_trainer.get_diagnostics()
        return training_diagnostics

    def _get_snapshot(self):
        """Return the base snapshot plus model-trainer entries under 'model/'."""
        snapshot = super()._get_snapshot()
        for k, v in self.model_trainer.get_snapshot().items():
            snapshot['model/' + k] = v
        return snapshot

    def _end_epochs(self, epoch):
        """Run the base end-of-epoch hook, then notify the model trainer."""
        # NOTE(review): presumably a hook that the base class's _end_epoch
        # invokes (super() is expected to define _end_epochs); confirm.
        super()._end_epochs(epoch)
        self.model_trainer.end_epoch(epoch)
