import torch

from transfer.training.trainer import Trainer


class ActTrainer(Trainer):
    """Trainer whose action supervision uses only the last timestep's action.

    Each call to ``train_step`` performs a single forward pass, loss
    computation, backward pass and optimizer update on one batch, and returns
    the scalar loss value.
    """

    def train_step(self, data):
        """Run one optimization step on ``data`` and return the loss.

        Args:
            data: A 7-tuple ``(states, actions, rewards, dones, rtg, _,
                attention_mask)``. ``dones`` and the sixth element are not
                used by this step (the sixth is presumably timesteps — TODO
                confirm against the data loader).

        Returns:
            The batch loss as a Python number (via ``Tensor.item()``).
        """
        states, actions, rewards, _dones, rtg, _, attention_mask = data
        # NOTE(review): targets are cloned, presumably as a defensive copy in
        # case the model modifies its inputs in place — confirm if needed.
        state_target = torch.clone(states)
        action_target = torch.clone(actions)
        reward_target = torch.clone(rewards)

        # Invoke the module itself rather than ``.forward`` so that any
        # registered forward hooks / pre-hooks are executed.
        state_preds, action_preds, reward_preds = self.model(
            states,
            actions,
            rewards,
            attention_mask=attention_mask,
            # Condition on the first entry along dim 1 of rtg
            # (assumes rtg is (batch, time, ...) — TODO confirm).
            target_return=rtg[:, 0],
        )

        # Flatten predictions to rows of size act_dim and keep only the final
        # timestep's action as the target. NOTE(review): rows line up with
        # targets only if the model emits one action per sequence
        # (action_preds.shape[1] == 1) — confirm against the model.
        act_dim = action_preds.shape[2]
        action_preds = action_preds.reshape(-1, act_dim)
        action_target = action_target[:, -1].reshape(-1, act_dim)

        loss = self.loss_fn(
            state_preds,
            action_preds,
            reward_preds,
            state_target,
            action_target,
            reward_target,
        )
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # ``item()`` already returns a detached host-side Python number.
        return loss.item()
