from collections import OrderedDict

from lfrl.core.rl_algorithms.torch_rl_algorithm import TorchTrainer


class DoubleTrainer(TorchTrainer):
    """Composite trainer that drives two wrapped trainers from one batch.

    Each call to ``train_from_torch`` invokes ``trainer_1.train_from_torch``
    ``trainer_1_steps`` times and then ``trainer_2.train_from_torch``
    ``trainer_2_steps`` times, always with the same batch. Diagnostics and
    snapshot entries of the wrapped trainers are re-exported under the
    ``trainer_1/`` and ``trainer_2/`` key prefixes so they cannot collide.
    """

    def __init__(
            self,
            trainer_1,
            trainer_2,
            trainer_1_steps=1,
            trainer_2_steps=1,
    ):
        """Store the two wrapped trainers and their per-batch call counts.

        Args:
            trainer_1: First wrapped trainer. It must provide
                ``train_from_torch``, ``get_diagnostics``, ``end_epoch``,
                ``get_snapshot`` and a ``networks`` attribute.
            trainer_2: Second wrapped trainer, same requirements.
            trainer_1_steps: How many times ``trainer_1`` is trained on each
                batch.
            trainer_2_steps: How many times ``trainer_2`` is trained on each
                batch.
        """
        super().__init__()

        self.trainer_1 = trainer_1
        self.trainer_2 = trainer_2
        self.trainer_1_steps = trainer_1_steps
        self.trainer_2_steps = trainer_2_steps

        # Merged, prefixed diagnostics of both trainers; refreshed on every
        # ``train_from_torch`` call and handed out by ``get_diagnostics``.
        self.eval_statistics = OrderedDict()

    @staticmethod
    def _merge_prefixed(target, prefix, entries):
        """Copy ``entries`` into ``target`` under ``'<prefix>/<key>'`` keys."""
        for key, value in entries.items():
            target['%s/%s' % (prefix, key)] = value

    def train_from_torch(self, batch):
        """Train both wrapped trainers on ``batch`` and refresh diagnostics.

        ``trainer_1`` finishes all of its steps before ``trainer_2`` starts.
        """
        for _ in range(self.trainer_1_steps):
            self.trainer_1.train_from_torch(batch)
        for _ in range(self.trainer_2_steps):
            self.trainer_2.train_from_torch(batch)

        # Existing keys are overwritten in place; keys a wrapped trainer no
        # longer reports are left untouched.
        self._merge_prefixed(
            self.eval_statistics, 'trainer_1', self.trainer_1.get_diagnostics())
        self._merge_prefixed(
            self.eval_statistics, 'trainer_2', self.trainer_2.get_diagnostics())

    def get_diagnostics(self):
        """Return the merged diagnostics gathered by the last training call.

        The internal ``OrderedDict`` itself is returned, not a copy.
        """
        return self.eval_statistics

    def end_epoch(self, epoch):
        """Forward the end-of-epoch notification to both wrapped trainers."""
        self.trainer_1.end_epoch(epoch)
        self.trainer_2.end_epoch(epoch)

    @property
    def networks(self):
        """Return the networks of both trainers as one list, trainer_1 first."""
        # Materialise each side before joining: a bare ``+`` raises TypeError
        # when the two trainers expose different container types (e.g. list
        # and tuple) or a non-sequence iterable.
        return list(self.trainer_1.networks) + list(self.trainer_2.networks)

    def get_snapshot(self):
        """Return a fresh dict with both trainers' snapshots, keys prefixed."""
        snapshot = dict()
        self._merge_prefixed(
            snapshot, 'trainer_1', self.trainer_1.get_snapshot())
        self._merge_prefixed(
            snapshot, 'trainer_2', self.trainer_2.get_snapshot())
        return snapshot
