from collections import OrderedDict
import copy

from lfrl.core.rl_algorithms.torch_rl_algorithm import TorchTrainer


class RENTrainer(TorchTrainer):
    """Trainer that keeps a rolling history of policy snapshots in an ensemble.

    Every call to :meth:`train_from_torch` delegates the optimisation step to
    ``policy_trainer``. Every ``update_ensemble_every`` calls, starting with
    the very first call, the snapshots stored in ``ren_ensemble.policies[1:]``
    are shifted one slot towards the end. The last snapshot is dropped, and a
    deep copy of ``ren_ensemble.policies[0]`` is stored in slot 1.

    NOTE(review): this presumably relies on ``policy_trainer`` updating the
    same object that lives at ``ren_ensemble.policies[0]`` -- confirm against
    the code that constructs this trainer.
    """

    def __init__(self, ren_ensemble, policy_trainer, update_ensemble_every=25):
        """Store collaborators and initialise bookkeeping.

        Args:
            ren_ensemble: Object exposing an indexable, item-assignable
                ``policies`` sequence and an ``end_epoch(epoch)`` method.
            policy_trainer: Object whose ``train_from_torch(batch)`` performs
                the actual training step.
            update_ensemble_every: Number of ``train_from_torch`` calls
                between snapshot refreshes (used as a modulus, so it must be
                non-zero).
        """
        super().__init__()
        self.ren_ensemble = ren_ensemble
        self.policy_trainer = policy_trainer
        self.update_ensemble_every = update_ensemble_every

        # NOTE(review): nothing in this class writes to eval_statistics or
        # reads _need_to_update_eval_statistics; presumably a subclass or the
        # base class does -- confirm.
        self.eval_statistics = OrderedDict()
        self._need_to_update_eval_statistics = True
        # Number of train_from_torch calls so far; drives the snapshot schedule.
        self._num_timesteps_trained = 0

    def train_from_torch(self, batch):
        """Run one policy update and, on schedule, refresh the snapshot history.

        Args:
            batch: Passed through unchanged to ``policy_trainer.train_from_torch``.
        """
        self.policy_trainer.train_from_torch(batch)
        if self._num_timesteps_trained % self.update_ensemble_every == 0:
            self._shift_snapshots()
        self._num_timesteps_trained += 1

    def _shift_snapshots(self):
        """Move snapshots 1..n-2 up one slot and put a copy of slot 0 in slot 1.

        Slot 0 is never modified. The snapshot previously in the last slot is
        discarded. With fewer than two policies this is a no-op.
        """
        policies = self.ren_ensemble.policies
        # Walk from the end backwards. Walking forwards would read each slot
        # after it had just been overwritten, copying slot 1 into every later
        # slot and leaving all older snapshots aliased to a single object.
        for i in range(len(policies) - 2, 0, -1):
            policies[i + 1] = policies[i]
        if len(policies) > 1:
            # Deep copy so the snapshot is frozen while policies[0] keeps training.
            policies[1] = copy.deepcopy(policies[0])

    def get_diagnostics(self):
        """Return the evaluation-statistics mapping.

        NOTE(review): this class never adds entries, so it is empty unless
        populated elsewhere.
        """
        return self.eval_statistics

    def end_epoch(self, epoch):
        """Forward the end-of-epoch notification to the ensemble.

        Also marks evaluation statistics as needing a refresh.

        Args:
            epoch: Passed through unchanged to ``ren_ensemble.end_epoch``.
        """
        self.ren_ensemble.end_epoch(epoch)
        self._need_to_update_eval_statistics = True

    @property
    def networks(self):
        """Modules owned by this trainer; only ``ren_ensemble`` is listed.

        NOTE(review): presumably consumed by the ``TorchTrainer`` machinery
        (e.g. device placement / train-mode toggling). Networks held by
        ``policy_trainer`` are not included here -- confirm that is intended.
        """
        return [self.ren_ensemble]

    def get_snapshot(self):
        """Return the objects to checkpoint, keyed by name."""
        return dict(
            ren_ensemble=self.ren_ensemble,
        )
