import numpy as np
import torch

from GDT.decision_transformer2.training.trainer import Trainer
from torch.nn import functional as F  # noqa


class SequenceTrainer(Trainer):
    """Trainer that fits a stochastic action head plus state and reward heads.

    The attributes used here (``get_batch``, ``batch_size``, ``model``,
    ``optimizer``, ``log_temperature_optimizer`` and ``diagnostics``) are not
    defined in this class; they are expected to be provided by ``Trainer`` or
    by whoever constructs the instance.
    """

    # Weights of the auxiliary prediction losses relative to the action loss.
    _STATE_LOSS_WEIGHT = 0.02
    _REWARD_LOSS_WEIGHT = 0.02

    @staticmethod
    def _masked_mse(pred, target, attention_mask):
        """Return the element-wise MSE with masked-out positions zeroed.

        The mask is broadcast over the trailing feature axis and the mean is
        taken over *all* elements, so positions where the mask is zero
        contribute nothing to the numerator but still count in the
        denominator.  No gradient flows into ``target``.
        """
        per_element = F.mse_loss(pred, target.detach(), reduction="none")
        return (per_element * attention_mask.unsqueeze(-1)).mean()

    def train_step(self, vent):
        """Run one optimisation step and return the total loss as a float.

        The step performs two updates in order:

        1. ``self.optimizer`` minimises
           ``act_loss + 0.02 * state_loss + 0.02 * reward_loss`` where
           ``act_loss = -(mean log-likelihood + temperature * mean entropy)``
           over positions with ``attention_mask > 0``.
        2. ``self.log_temperature_optimizer`` minimises
           ``temperature * (entropy - target_entropy)`` with the entropy term
           detached, so only the temperature receives a gradient.

        Scalar values of every loss term are written to ``self.diagnostics``
        under the ``G_training/`` prefix.

        Args:
            vent: Currently ignored.  It used to select a deterministic
                (rounded and clamped) action loss, but that branch is
                disabled; the parameter is kept so existing callers keep
                working.

        Returns:
            The total loss of the first update as a Python float.
        """
        # get_batch yields: states, actions, rewards, dones, returns-to-go,
        # timesteps, attention mask, next states.  ``dones`` is not used here.
        (
            states,
            actions,
            rewards,
            _dones,
            rtg,
            timesteps,
            attention_mask,
            s_next,
        ) = self.get_batch(self.batch_size)

        # Snapshot the reward target before the forward pass, in case the
        # model modifies ``rewards`` in place.
        reward_target = torch.clone(rewards)

        # ``action_dist`` is a distribution-like object: it must expose
        # ``log_prob`` and ``entropy``.
        state_preds, action_dist, reward_preds = self.model.forward(
            states, actions, rewards, rtg, timesteps, attention_mask=attention_mask,
        )

        valid = attention_mask > 0
        log_likelihood = action_dist.log_prob(actions)[valid].mean()
        entropy = action_dist.entropy()[valid].mean()

        # The temperature (entropy coefficient) is detached so the main
        # optimizer step does not update it; it is trained separately below.
        entropy_reg = self.model.temperature().detach()
        entropy_reg_item = entropy_reg.item()

        act_loss = -(log_likelihood + entropy_reg * entropy)
        reward_loss = self._masked_mse(reward_preds, reward_target, attention_mask)
        # State predictions are regressed onto the *next* states of the batch.
        state_loss = self._masked_mse(state_preds, s_next, attention_mask)

        loss = (
            act_loss
            + self._STATE_LOSS_WEIGHT * state_loss
            + self._REWARD_LOSS_WEIGHT * reward_loss
        )

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Temperature (lambda) update: the entropy gap is detached, so the
        # gradient reaches only the temperature.
        self.log_temperature_optimizer.zero_grad()
        temperature_loss = (
            self.model.temperature()
            * (entropy - self.model.target_entropy).detach()
        )
        temperature_loss.backward()
        self.log_temperature_optimizer.step()

        with torch.no_grad():
            self.diagnostics['G_training/action_error'] = act_loss.item()
            self.diagnostics['G_training/state_error'] = state_loss.item()
            self.diagnostics['G_training/reward_error'] = reward_loss.item()
            self.diagnostics['G_training/loss'] = loss.item()
            self.diagnostics['G_training/entropy_H'] = entropy.item()
            # Temperature value used for this step (read before its update).
            self.diagnostics['G_training/entropy_lambda'] = entropy_reg_item
            self.diagnostics['G_training/temperature_loss'] = temperature_loss.item()
        return loss.detach().cpu().item()
