import torch
from ATAC.bp import BehaviorPretraining as ATAC_BehaviorPretraining
from armor.simple_armor import SimpleARMOR
from armor.util import to_transition_batch
from ATAC.util import DEFAULT_DEVICE
# Margin used by BehaviorPretraining.compute_policy_loss to clip reference
# actions strictly inside (-1, 1), since a tanh-squashed policy cannot output
# exactly +/-1.
EPS=1e-6

class BehaviorPretraining(ATAC_BehaviorPretraining):
    """ATAC behavior pretraining extended with a reference policy and a model.

    Compared to the parent class, this variant can
      * clone an arbitrary ``reference`` policy instead of the dataset actions
        (see ``compute_policy_loss``), and
      * jointly fit a learned ``model``: its fitting loss is added to the
        total loss minimized in ``update`` (see ``compute_model_loss``).
    """

    def __init__(self, *,
                 reference=None,
                 model=None,
                 rng=None,
                 **kwargs):
        """
        Args:
            reference: optional callable mapping observations to either a
                ``torch.distributions.Distribution`` (which is sampled) or a
                tensor of actions. When given, the policy clones it instead of
                the dataset actions.
            model: optional model exposing ``loss(transition_batch)`` and
                ``parameters()``. When given, it is trained jointly with the
                policy / Q / V networks by a single Adam optimizer.
            rng: ``torch.Generator`` stored as ``self._rng``. When omitted (or
                None), a fresh generator on ``DEFAULT_DEVICE`` is created for
                this instance.
            **kwargs: forwarded to the ATAC ``BehaviorPretraining`` constructor.
        """
        super().__init__(**kwargs)
        self.reference = reference
        # Both names refer to the same object; compute_model_loss reads ``_model``.
        self.model = self._model = model
        # Built here instead of as a default argument: a default expression is
        # evaluated once at class-definition time, which would make every
        # instance share a single generator created at import.
        self._rng = rng if rng is not None else torch.Generator(device=DEFAULT_DEVICE)

        if self.model is not None:
            # Rebuild the optimizer so that it also covers the model parameters.
            parameters = [p
                          for module in (self.policy, self.qf, self.vf, self.model)
                          if module is not None
                          for p in module.parameters()]
            if 'lr' in kwargs:
                lr = kwargs['lr']
            else:
                # The caller relied on the parent's default learning rate; reuse
                # the one the parent configured its optimizer with (assumes the
                # parent constructor created ``self.optimizer``).
                lr = self.optimizer.defaults['lr']
            self.optimizer = torch.optim.Adam(parameters, lr=lr)

    def update(self, **batch):
        """Take one joint gradient step on every enabled loss.

        The policy, Q, V and model losses are each computed only when the
        corresponding component is not None; disabled components contribute a
        zero loss. Their sum is minimized with ``self.optimizer``.

        Returns:
            dict: the merged info dicts of the computed losses.
        """
        zero = torch.tensor(0., device=batch['observations'].device)
        qf_loss = vf_loss = policy_loss = fitting_loss = zero
        qf_info_dict, vf_info_dict, policy_info_dict, model_info_dict = {}, {}, {}, {}
        # Compute loss
        if self.policy is not None:
            policy_loss, policy_info_dict = self.compute_policy_loss(**batch)
        if self.qf is not None:
            qf_loss, qf_info_dict = self.compute_qf_loss(**batch)
        if self.vf is not None:
            vf_loss, vf_info_dict = self.compute_vf_loss(**batch)
        if self.model is not None:
            fitting_loss, model_info_dict = self.compute_model_loss(**batch)

        # Update
        loss = policy_loss + qf_loss + vf_loss + fitting_loss
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        # Log
        return {**qf_info_dict, **vf_info_dict, **policy_info_dict, **model_info_dict}

    def compute_policy_loss(self, observations, actions, **kwargs):
        """Behavior-cloning loss, optionally targeting the reference policy.

        When ``self.reference`` is set, the dataset ``actions`` are replaced
        (without gradient) by the reference's output: a sample if it returns a
        distribution, or the tensor itself. If it returns anything else, the
        dataset actions are kept. In all reference cases the target actions are
        clipped to ``[-1 + EPS, 1 - EPS]`` before calling the parent loss.
        """
        if self.reference is not None:  # overwrite the data actions to clone the reference
            with torch.no_grad():
                ref_outs = self.reference(observations)
                if isinstance(ref_outs, torch.distributions.Distribution):
                    actions = ref_outs.sample()
                elif torch.is_tensor(ref_outs):
                    actions = ref_outs
                # Keep targets strictly inside (-1, 1) because the policy output is tanh-squashed.
                actions = torch.clip(actions, -1 + EPS, 1 - EPS)
        return super().compute_policy_loss(observations, actions, **kwargs)

    def get_model_prediction(self, observations, actions, model_state=None):
        """Delegate to ``SimpleARMOR.get_model_prediction`` using this object as ``self``.

        ``model_state`` is forwarded unchanged; previously it was dropped and
        replaced by None.
        """
        return SimpleARMOR.get_model_prediction(self, observations, actions, model_state=model_state)

    def compute_qf_loss(self, observations, actions, next_observations, rewards, terminals, **kwargs):
        """Parent Q-function loss, unchanged.

        The model fitting error is not added here; it is computed separately by
        ``compute_model_loss`` and summed into the total loss in ``update``.
        """
        return super().compute_qf_loss(observations, actions, next_observations, rewards, terminals, **kwargs)

    def compute_model_loss(self, observations, actions, next_observations, rewards, terminals, **kwargs):
        """Fitting loss of ``self._model`` on one batch of transitions.

        The tensors are packed into a dict, converted with
        ``to_transition_batch``, and passed to ``self._model.loss``; the first
        element of its return value is used as the loss (presumably a
        ``(loss, metadata)`` pair, to be confirmed against the model class).

        Returns:
            tuple: ``(fitting_loss, {'model_loss': float})``.
        """
        batch = dict(observations=observations,
                     actions=actions,
                     next_observations=next_observations,
                     rewards=rewards,
                     terminals=terminals)
        fitting_loss = self._model.loss(to_transition_batch(batch))[0]
        info_dict = {'model_loss': fitting_loss.item()}
        return fitting_loss, info_dict