import torch

import lfrl.torch.pytorch_util as ptu
from lfrl.trainers.qpg.mbpo import MBPOTrainer
from lfrl.util.eval_util import create_stats_ordered_dict


class PenaltyMBPOTrainer(MBPOTrainer):
    """MBPO trainer that penalizes predicted rewards by model disagreement.

    Every reward predicted by the dynamics model is reduced by
    ``reward_penalty * disagreement``, where ``disagreement`` is the second
    value returned by ``dynamics_model.sample_with_disagreement``.
    """

    def __init__(
            self,
            reward_penalty,
            *args,
            **kwargs,
    ):
        """Set up the base trainer and store the penalty coefficient.

        Args:
            reward_penalty: Multiplier applied to the model disagreement
                before it is subtracted from the predicted reward.
            *args: Forwarded unchanged to ``MBPOTrainer.__init__``.
            **kwargs: Forwarded unchanged to ``MBPOTrainer.__init__``.
        """
        super().__init__(*args, **kwargs)

        self.reward_penalty = reward_penalty

        # Predicted terminal values strictly above this cutoff count as done.
        self.terminal_cutoff = 0.5
        # Disagreement from the most recent predict_transition call; stays
        # None until the first call and is only used for diagnostics.
        self.last_disagreement = None

    def predict_transition(self, state_actions):
        """Sample a transition and apply the disagreement penalty.

        Args:
            state_actions: Input handed to
                ``dynamics_model.sample_with_disagreement``.

        Returns:
            A tuple ``(r, d, obs_delta)`` sliced from the sampled
            transitions: ``r`` is column 0 minus the scaled disagreement,
            ``d`` is 1.0 where column 1 exceeds ``terminal_cutoff`` and 0.0
            elsewhere, and ``obs_delta`` is every remaining column.
        """
        transitions, disagreement = (
            self.dynamics_model.sample_with_disagreement(state_actions)
        )
        # Column vector so it broadcasts against the (N, 1) reward slice.
        disagreement = disagreement.view(-1, 1)
        # Cache a detached copy: it is only read for logging, and detaching
        # keeps the cache from holding on to any autograd graph.
        self.last_disagreement = disagreement.detach()

        r = transitions[:, :1] - self.reward_penalty * disagreement
        d = (transitions[:, 1:2] > self.terminal_cutoff).float()
        obs_delta = transitions[:, 2:]
        return r, d, obs_delta

    def get_diagnostics(self):
        """Return eval statistics, including model disagreement if available.

        The policy trainer's statistics are merged into
        ``self.eval_statistics``. 'Model Disagreement' statistics are added
        only once ``predict_transition`` has recorded a disagreement.
        """
        self.eval_statistics.update(self.policy_trainer.eval_statistics)
        # No rollout may have happened yet; converting None to numpy would
        # fail, so the disagreement statistics are skipped in that case.
        if self.last_disagreement is not None:
            self.eval_statistics.update(create_stats_ordered_dict(
                'Model Disagreement',
                ptu.get_numpy(self.last_disagreement),
            ))
        return self.eval_statistics

    def end_epoch(self, epoch):
        """Flag eval statistics for refresh and notify the policy trainer."""
        self._need_to_update_eval_statistics = True
        self.policy_trainer.end_epoch(epoch)

    @property
    def networks(self):
        """Networks exposed by the wrapped policy trainer."""
        return self.policy_trainer.networks

    def get_snapshot(self):
        """Return the policy trainer's snapshot plus the dynamics model.

        The policy trainer's entries are applied last, so they win if they
        also define a ``dynamics_model`` key.
        """
        mbpo_snapshot = dict(
            dynamics_model=self.dynamics_model
        )
        mbpo_snapshot.update(self.policy_trainer.get_snapshot())
        return mbpo_snapshot
