"""
directly based on https://github.com/kzl/decision-transformer/blob/master/gym/decision_transformer/training/seq_trainer.py
"""
import numpy as np
import torch

from decision_transformer.training.trainer import Trainer


class SequenceTrainer(Trainer):
    """Trainer whose step optimises the model's action predictions only.

    ``train_step`` passes ``None`` for the state and reward slots of
    ``self.loss_fn``, so only the action head contributes to the loss.
    """

    # Maximum gradient norm applied before every optimiser step; subclasses
    # may override this attribute to change the clipping threshold.
    GRAD_CLIP_NORM = .25

    def train_step(self):
        """Run one optimisation step on a fresh batch.

        Draws a batch via ``self.get_batch(self.batch_size)``, runs the model,
        keeps only the positions where ``attention_mask > 0``, evaluates
        ``self.loss_fn`` on the action predictions versus the batch actions,
        then back-propagates, clips gradients and steps the optimiser.

        The model's action output may be either a single tensor or a 3-tuple
        ``(pi, mu, sigma)`` (presumably mixture weights, means and scales of
        a mixture head -- confirm against the model). In the tuple case each
        component is flattened and masked separately and the tuple is handed
        to ``self.loss_fn`` as-is.

        Returns:
            The loss value as a Python float.
        """
        states, actions, rewards, _dones, rtg, timesteps, attention_mask = self.get_batch(self.batch_size)
        # Snapshot the targets before the forward pass so the loss never sees
        # a tensor the model might have touched in place.
        action_target = torch.clone(actions)

        # Call the module itself rather than ``.forward`` so that any
        # registered forward hooks are executed.
        # The last entry of ``rtg`` along dim 1 is dropped before it is fed in
        # (presumably rtg carries one extra step -- confirm in get_batch).
        _, action_preds, _ = self.model(
            states, actions, rewards, rtg[:, :-1], timesteps, attention_mask=attention_mask,
        )

        act_dim = actions.shape[2]
        # Flat boolean selector over all (batch, position) slots; computed
        # once and reused for every tensor that needs masking.
        valid = attention_mask.reshape(-1) > 0

        if isinstance(action_preds, tuple) and len(action_preds) == 3:
            pi, mu, sigma = action_preds
            action_preds = (
                pi.reshape(-1, pi.shape[-1])[valid],  # [valid_tokens, n_mix]
                mu.reshape(-1, mu.shape[-2], mu.shape[-1])[valid],  # [valid_tokens, n_mix, act_dim]
                sigma.reshape(-1, sigma.shape[-2], sigma.shape[-1])[valid],  # [valid_tokens, n_mix, act_dim]
            )
        else:
            action_preds = action_preds.reshape(-1, act_dim)[valid]
        action_target = action_target.reshape(-1, act_dim)[valid]

        # State and reward slots are intentionally None: only actions are trained.
        loss = self.loss_fn(
            None, action_preds, None,
            None, action_target, None,
        )

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.GRAD_CLIP_NORM)
        self.optimizer.step()

        # ``item()`` already yields a detached host-side Python number.
        return loss.item()
