# yapf: disable
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ATAC.atac import ATAC, normalized_sum, l2_projection
from ATAC.util import compute_batched, DEFAULT_DEVICE, update_exponential_moving_average, sample_batch, cat_data_dicts
from armor.util import to_transition_batch

class SimpleARMOR(ATAC):
    """ ARMOR: Adversarial Models for Offline Reinforcement Learning.

    Extends ATAC so that the critic's Bellman terms are computed on transitions predicted by a
    learned dynamics `model`. The model is trained jointly with the critic: its data-fitting
    loss is added to the critic loss and both are stepped from one backward pass. The
    pessimism term compares the learner policy against a fixed `reference` policy.

    In addition, the model is rolled out under both the learner policy and the reference
    policy, restarting every `rollout_horizon` updates. The visited (observation, action) pairs
    are stored in a circular buffer, which supplies extra model-generated samples to later
    updates.

    Relies on attributes created by `ATAC.__init__` (e.g. `policy`, `_qf`, `_target_qf`,
    `_qf_lr`, the optimizers, `_discount`, `_Vmin`/`_Vmax`, `beta`, `_log_alpha`, `_debug`).
    """
    def __init__(self, *,
                 reference,
                 model,
                 model_lr=None,
                 rollout_horizon=5,
                 model_buffer_size=10**4,
                 model_batch_size=256,
                 reg_coeff=100.0,
                 rng=None,
                 obs_highs=None,
                 obs_lows=None,
                 wt=0.5,
                 ws=0.5,
                 bellman_model_grad_type='old',  # 'rs', 'rstd'
                 use_data_terminals=False,
                 ignore_model_prediction=False,
                 hybrid=False,
                 atac_mode=False,
                 margin=float('inf'),
                 use_model_grad=True,
                 **kwargs
                 ):
        """
        Args:
            reference: reference policy. Must provide `.act(obs)` and, when called, return a
                distribution with `.sample()`.
            model: dynamics model providing `reset`, `sample`, `loss`, `parameters()` and
                `model.num_members`.
            model_lr: Adam learning rate for the model. Defaults to the critic's `_qf_lr`.
            rollout_horizon: number of updates between restarts of the model rollouts.
                0 disables rollouts.
            model_buffer_size: capacity of the circular buffer of rollout (s, a) pairs.
            model_batch_size: number of buffer pairs added to each update's batch.
            reg_coeff: weight of the model-fitting loss in the critic loss. It is divided by
                the model's ensemble size (`model.model.num_members`).
            rng: torch.Generator used for model sampling. If None, a fresh generator on
                DEFAULT_DEVICE is created for this instance.
            obs_highs, obs_lows: observation bound tensors. They are loosened (see below) and
                used to clip predicted next observations. None disables that side.
            wt, ws: weights of the target-network error and the residual (self-bootstrapped)
                error in the Bellman loss.
            bellman_model_grad_type: 'old', 'rs' or 'rstd'. Controls which Bellman-target
                quantities are traced for gradients (see `_update`).
            use_data_terminals: use the real terminal flags for the data rows instead of the
                model's predictions.
            ignore_model_prediction: debugging switch; train on the real batch only.
            hybrid: append the real batch to the model-predicted batch.
            atac_mode: debugging switch; train on the real batch and use the data actions as
                the reference actions.
            margin: the pessimism term is clamped from below at `-margin`.
            use_model_grad: keep the autograd graph of the model predictions during the
                update, so the Bellman loss also trains the model.
            **kwargs: forwarded to `ATAC.__init__`.
        """
        #############################################################################################
        super().__init__(**kwargs)
        self._reference = reference
        self._model = model
        self._model_optimizer = torch.optim.Adam(
            self._model.parameters(), lr=model_lr if model_lr is not None else self._qf_lr)
        self._step_count = 0
        self._prev_model_obs_pred = None
        self._rollout_horizon = rollout_horizon
        self._model_buffer = None  # A dict of `observations` and `actions`
        self._model_buffer_size = model_buffer_size
        self._model_buffer_ind = 0  # start index
        self._model_buffer_current_size = 0  # model buffer size
        self._model_batch_size = model_batch_size
        self._reg_coeff = reg_coeff / float(self._model.model.num_members)
        self._ws = ws
        self._wt = wt
        self._bellman_model_grad_type = bellman_model_grad_type
        self._use_data_terminals = use_data_terminals
        self._ignore_model_prediction = ignore_model_prediction
        self._use_model_grad = use_model_grad
        self._hybrid = hybrid
        self._margin = margin
        self._atac_mode = atac_mode
        self._model_state = None
        self._num_policy_rollouts = 0  # leading rows of the rollout state driven by `self.policy`
        # Build the generator per instance. A generator created in the signature would be
        # evaluated once at class-definition time and shared by every instance.
        self._rng = rng if rng is not None else torch.Generator(device=DEFAULT_DEVICE)
        # Loosen the bounds: a positive bound b becomes (1 + delta) * b, a negative bound
        # becomes (1 - delta) * b (its sign flips outward), and zero stays zero. Each side is
        # optional.
        delta = 10.0
        self._obs_highs = None if obs_highs is None else (1.0 + obs_highs.sign() * delta) * obs_highs
        self._obs_lows = None if obs_lows is None else (1.0 - obs_lows.sign() * delta) * obs_lows

    def get_model_prediction(self, observations, actions, model_state=None, deterministic=False):
        """ One-step model prediction.

        Args:
            observations: current observations. Used to (re)initialize the model state when
                `model_state` is None.
            actions: actions to apply.
            model_state: model state to continue from, or None to start from `observations`.
            deterministic: forwarded to `self._model.sample`.

        Returns:
            (transition dict with keys observations / actions / next_observations /
            rewards / terminals, next model state). The rewards and terminals are flattened,
            and the predicted next observations are clipped to the loosened bounds when any
            bound is set.
        """
        if model_state is None:
            model_state = self._model.reset(observations, rng=self._rng)
        next_observations_pred, rewards_pred, terminal_pred, next_model_state = self._model.sample(
            actions, model_state, deterministic=deterministic, rng=self._rng)
        if self._obs_highs is not None or self._obs_lows is not None:
            next_observations_pred = torch.clip(next_observations_pred, self._obs_lows, self._obs_highs)
            # Keep the state consistent with the clipped prediction for multi-step rollouts.
            next_model_state["obs"] = next_observations_pred
        assert torch.isclose(next_model_state["obs"], next_observations_pred).all()

        return dict(observations=observations,
                    actions=actions,
                    next_observations=next_observations_pred,
                    rewards=rewards_pred.flatten(),
                    terminals=terminal_pred.flatten()), next_model_state

    def update(self, observations, actions, next_observations, rewards, terminals, **kwargs):
        """ Perform ATAC update with additional model loss to minimize data fitting error and Bellman residual error.

        The critic and model are trained on a "joint" batch. Its next observations, rewards
        and terminals are predicted by the model for the data (s, a) pairs, plus (s, a) pairs
        sampled from earlier model rollouts. The model is additionally fit to the real
        transitions. Afterwards, the model rollouts are advanced by one step.

        Args:
            observations, actions, next_observations, rewards, terminals: a batch of real
                transitions (torch tensors, one row per transition).
            **kwargs: ignored.

        Returns:
            dict of scalar logging statistics, including `<key>_pred_error` (the model's MSE
            on the data batch).
        """
        batch = dict(observations=observations,
                     actions=actions,
                     next_observations=next_observations,
                     rewards=rewards,
                     terminals=terminals)
        batch_size = len(rewards)

        # Query the model on the data (s, a) pairs, plus (s, a) pairs from earlier rollouts.
        batch_joint = dict(observations=observations, actions=actions)
        if self._model_buffer is not None:
            current_model_buffer = {k: v[:self._model_buffer_current_size] for k, v in self._model_buffer.items()}
            batch_model = sample_batch(current_model_buffer, self._model_batch_size)  # SA from prev model rollouts
            batch_joint = cat_data_dicts(batch_joint, batch_model)
        with torch.set_grad_enabled(self._use_model_grad):  # traced so the Bellman loss can train the model
            batch_joint, _ = self.get_model_prediction(**batch_joint)
        if self._use_data_terminals:
            # Real flags for the data rows, model flags for the buffer rows. Flatten the data
            # flags so a (B, 1) tensor concatenates with the model's 1-D flags.
            batch_joint['terminals'] = torch.cat([terminals.flatten(), batch_joint['terminals'].flatten()[batch_size:]])

        if self._hybrid:
            batch_joint = cat_data_dicts(batch_joint, batch)

        if self._ignore_model_prediction or self._atac_mode:  # for debugging
            batch_joint = batch

        # Update
        fitting_loss = self._model.loss(to_transition_batch(batch))[0]
        log_info = self._update(**batch_joint, model_fitting_loss=fitting_loss)

        if self._rollout_horizon > 0:
            self._model_rollout_step(observations)

        # Logging: model prediction error on the data rows. Flatten both sides so a (B, 1)
        # target is not broadcast against a (B,) prediction into a (B, B) matrix.
        pred_errors = {k + '_pred_error': F.mse_loss(batch_joint[k][:batch_size].float().flatten(),
                                                    batch[k].float().flatten()).item()
                       for k in ('rewards', 'next_observations', 'terminals')}
        log_info.update(pred_errors)
        return log_info

    def _model_rollout_step(self, observations):
        """ Advance the model rollouts under the policy and the reference by one step.

        The first `self._num_policy_rollouts` rows of the rollout state are driven by
        `self.policy`, the remaining rows by `self._reference`. Both groups restart from
        `observations` every `self._rollout_horizon` steps, or when no rollout is alive.
        Rollouts whose sampled terminal flag fires are dropped. The (s, a) pairs visited at
        this step are added to the model buffer.
        """
        batch_size = len(observations)
        with torch.no_grad():
            if (self._step_count % self._rollout_horizon == 0
                    or self._model_state is None
                    or len(self._model_state["obs"]) == 0):
                obs_buff = torch.cat([observations, observations], dim=0)
                self._model_state = self._model.reset(obs_buff, rng=self._rng)
                self._num_policy_rollouts = batch_size
            else:
                obs_buff = self._model_state["obs"]
            # Split by the tracked number of surviving policy rollouts, not by `batch_size`:
            # once terminated rows are removed, the policy/reference boundary moves.
            n_policy = self._num_policy_rollouts
            actions_policy = self.policy.act(obs_buff[:n_policy])
            actions_reference = self._reference.act(obs_buff[n_policy:])
            act_buff = torch.cat([actions_policy, actions_reference])
            self._step_count += 1
            # Predict one step and continue from the returned model state.
            pred_data, self._model_state = self.get_model_prediction(obs_buff, act_buff, self._model_state)
            # Sample terminal flags and drop the rollouts that terminated.
            done_arr = torch.rand(pred_data['terminals'].shape, device=DEFAULT_DEVICE) < pred_data['terminals']
            alive = ~done_arr
            self._num_policy_rollouts = int(alive[:n_policy].sum())
            # NOTE(review): only the "obs" entry of the model state is filtered; other per-row
            # entries (if the model keeps any) are left as-is — confirm against the model.
            self._model_state["obs"] = self._model_state["obs"][alive]
            self._prev_model_obs_pred = pred_data['next_observations'][alive]
            new_model_data = {k: pred_data[k].detach().cpu().numpy() for k in ('observations', 'actions')}
        self._add_to_model_buffer(new_model_data)

    def _add_to_model_buffer(self, new_model_data):
        """ Write 2-D numpy arrays (keys `observations`, `actions`) into the circular model buffer.

        The buffer is allocated lazily on the first non-empty insert, once feature sizes are
        known. If more rows arrive than the buffer can hold, only the most recent ones are kept.
        """
        new_batch_size = len(new_model_data['actions'])
        if new_batch_size == 0:
            return
        if new_batch_size > self._model_buffer_size:
            new_model_data = {k: v[-self._model_buffer_size:] for k, v in new_model_data.items()}
            new_batch_size = self._model_buffer_size
        if self._model_buffer is None:  # initialize
            self._model_buffer = {k: np.zeros((self._model_buffer_size, v.shape[1]))
                                  for k, v in new_model_data.items()}
            self._model_buffer_ind = 0
            self._model_buffer_current_size = 0
        ind = np.arange(self._model_buffer_ind, self._model_buffer_ind + new_batch_size) % self._model_buffer_size
        for k in self._model_buffer:
            self._model_buffer[k][ind] = new_model_data[k]
        self._model_buffer_ind = (ind[-1] + 1) % self._model_buffer_size
        self._model_buffer_current_size = min(self._model_buffer_current_size + new_batch_size, self._model_buffer_size)

    def _update(self, observations, actions, next_observations, rewards, terminals, model_fitting_loss):
        """ One gradient step on the critic and model (jointly), then on the entropy coefficient and actor.

        Args:
            observations, actions, next_observations, rewards, terminals: the (possibly
                model-predicted) batch used for the Bellman and pessimism terms.
            model_fitting_loss: the model's data-fitting loss. It is added to each critic
                loss, scaled by `self._reg_coeff`, so one backward pass trains critic and model.

        Returns:
            dict of scalar logging statistics.
        """
        rewards = rewards.flatten()
        terminals = terminals.flatten().float()

        ##### Update Critic AND Model #####
        def compute_bellman_backup(q_pred_next):
            assert rewards.shape == q_pred_next.shape
            return (rewards + (1. - terminals) * self._discount * q_pred_next).clamp(min=self._Vmin, max=self._Vmax)

        # Pre-computation. `bellman_model_grad_type` selects what is traced:
        #   'old':  neither the next actions nor the target values;
        #   'rs':   the reparameterized next actions;
        #   'rstd': the next actions and the target-network values.
        # A tuple is required here: `('rs' or 'rstd')` evaluates to 'rs' and turns the check
        # into a substring test that excludes 'rstd'.
        with torch.set_grad_enabled(self._bellman_model_grad_type in ('rs', 'rstd')):
            new_next_actions = self.policy(next_observations).rsample()
        with torch.set_grad_enabled(self._bellman_model_grad_type == 'rstd'):
            target_q_values = self._target_qf(next_observations, new_next_actions)  # projection
            q_target = compute_bellman_backup(target_q_values.flatten())
            actions_ref = self._reference(observations).sample()
            if self._atac_mode:
                actions_ref = actions

        new_actions_dist = self.policy(observations)  # This will be used to compute the entropy
        new_actions = new_actions_dist.rsample()  # These samples will be used for the actor update too, so they need to be traced.

        qf_pred_both, qf_pred_next_both, qf_new_actions_both, qf_actions_ref_both \
            = compute_batched(self._qf.both, [observations, next_observations, observations,         observations],
                                             [actions,      new_next_actions,  new_actions.detach(), actions_ref])  # XXX ARMOR

        qf_loss = 0
        for qfp, qfpn, qfna, qfar in zip(qf_pred_both, qf_pred_next_both, qf_new_actions_both, qf_actions_ref_both):
            # Compute Bellman error
            assert qfp.shape == qfpn.shape == qfna.shape == q_target.shape == qfar.shape
            target_error = F.mse_loss(qfp, q_target)
            q_target_pred = compute_bellman_backup(qfpn)
            residual_error = F.mse_loss(qfp, q_target_pred)
            qf_bellman_loss = self._wt * target_error + self._ws * residual_error
            # Compute pessimism term: learner vs reference actions, clamped below at -margin.
            pess_loss = torch.clamp((qfna - qfar).mean(), min=-self._margin)
            ## Compute full q loss
            reg_loss = qf_bellman_loss + model_fitting_loss * self._reg_coeff
            qf_loss += normalized_sum(pess_loss, reg_loss, self.beta)

        # Update q AND model from the same backward pass.
        self._model_optimizer.zero_grad()
        self._qf_optimizer.zero_grad()
        qf_loss.backward()
        self._qf_optimizer.step()
        self._qf.apply(l2_projection(self._norm_constraint))
        update_exponential_moving_average(self._target_qf, self._qf, self._tau)
        self._model_optimizer.step()

        ##### Update Actor #####
        # Compute entropy
        log_pi_new_actions = new_actions_dist.log_prob(new_actions)
        policy_entropy = -log_pi_new_actions.mean()

        alpha_loss = 0  # stays a plain number when entropy tuning is disabled
        if self._use_automatic_entropy_tuning:
            alpha_loss = self._log_alpha * (policy_entropy.detach() - self._target_entropy)
            self._alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self._alpha_optimizer.step()

        # Compute performance difference lower bound (policy_loss = - lower_bound - alpha * policy_kl)
        alpha = self._log_alpha.exp().detach()
        self._qf.requires_grad_(False)
        lower_bound = self._qf.both(observations, new_actions)[-1].mean()  # just use one network
        self._qf.requires_grad_(True)
        policy_loss = normalized_sum(-lower_bound, -policy_entropy, alpha)

        # Clearing here also discards any gradients qf_loss.backward() pushed into the policy.
        self._policy_optimizer.zero_grad()
        policy_loss.backward()
        self._policy_optimizer.step()

        # Log (the loop-carried terms report the last critic of the ensemble).
        log_info = dict(policy_loss=policy_loss.item(),
                        qf_loss=qf_loss.item(),
                        qf_bellman_loss=qf_bellman_loss.item(),
                        pess_loss=pess_loss.item(),
                        alpha_loss=float(alpha_loss),  # works for both the int 0 and a 1-element tensor
                        policy_entropy=policy_entropy.item(),
                        alpha=alpha.item(),
                        lower_bound=lower_bound.item(),
                        model_fitting_loss=model_fitting_loss.item(),
                        )

        # For logging
        if self._debug:
            with torch.no_grad():
                debug_log_info = dict(
                        bellman_surrogate=residual_error.item(),
                        qf1_pred_mean=qf_pred_both[0].mean().item(),
                        qf2_pred_mean=qf_pred_both[1].mean().item(),
                        q_target_mean=q_target.mean().item(),
                        target_q_values_mean=target_q_values.mean().item(),
                        qf1_new_actions_mean=qf_new_actions_both[0].mean().item(),
                        qf2_new_actions_mean=qf_new_actions_both[1].mean().item(),
                        qf1_actions_ref_mean=qf_actions_ref_both[0].mean().item(),
                        qf2_actions_ref_mean=qf_actions_ref_both[1].mean().item(),
                        action_diff=torch.mean(torch.norm(actions_ref - new_actions, dim=1)).item()
                        )
            log_info.update(debug_log_info)
        return log_info
