import numpy as np
import torch

from ICRL.training.trainer import Trainer
from torch.nn import functional as F  # noqa


class SequenceTrainer(Trainer):
    """Trainer whose step contrasts model predictions on two batches.

    Each step draws one batch from ``get_batch_e`` and one from
    ``get_batch_o``, runs both through ``self.model``, reduces the per-step
    predictions to one value per sequence, and minimises
    ``mean(pred_e) - mean(pred_o)`` after clamping both to ``[0, 1]``.

    Attributes read from ``self`` (all set up outside this class):
    ``batch_size``, ``model``, ``optimizer``, ``use_weighted_sum`` and
    ``train_type`` (one of ``"mean"``, ``"sum"``, ``"last"``).
    """

    @staticmethod
    def _reduce_over_time(pred, batch_size, seq_len, train_type):
        """Collapse per-step predictions to a ``(batch_size, 1)`` column.

        Args:
            pred: Tensor holding ``batch_size * seq_len`` elements.
            batch_size: Number of sequences in the batch.
            seq_len: Number of steps per sequence.
            train_type: ``"mean"`` or ``"sum"`` over the step axis, or
                ``"last"`` to keep only the final step.

        Raises:
            ValueError: If ``train_type`` is not one of the three modes.
        """
        per_step = pred.reshape(batch_size, seq_len)
        if train_type == "mean":
            # NOTE(review): every position is averaged, including any that
            # the attention mask may mark as padding -- confirm this is
            # intended.
            reduced = torch.mean(per_step, dim=1)
        elif train_type == "sum":
            reduced = torch.sum(per_step, dim=1)
        elif train_type == "last":
            reduced = per_step[:, -1]
        else:
            raise ValueError(
                f"Unknown train_type {train_type!r}; "
                "expected 'mean', 'sum' or 'last'."
            )
        return reduced.reshape(-1, 1)

    def train_step(self):
        """Run one optimisation step and return scalar diagnostics.

        Returns:
            A 4-tuple of Python floats:
            ``(loss, mean_pred_e, mean_pred_o, balance)`` where ``balance``
            is the MSE between the two clamped prediction columns. It is
            reported only and is not part of the optimised loss.

        Raises:
            ValueError: If ``self.train_type`` is not a supported mode.
        """
        # The "_e" batch yields eight items (the last being batch indices),
        # the "_o" batch seven; only states, actions, timesteps and the
        # attention mask are consumed here.
        (states_e, actions_e, _rewards_e, _dones_e, _rtg_e,
         timesteps_e, attention_mask_e, _batch_inds) = self.get_batch_e(self.batch_size)
        (states_o, actions_o, _rewards_o, _dones_o, _rtg_o,
         timesteps_o, attention_mask_o) = self.get_batch_o(self.batch_size)

        trans_pred_e, _ = self.model.forward(
            states_e, actions_e, timesteps_e,
            attention_mask=attention_mask_e, training=True,
        )
        trans_pred_o, _ = self.model.forward(
            states_o, actions_o, timesteps_o,
            attention_mask=attention_mask_o, training=True,
        )

        # The model returns a mapping; pick which output head to train on.
        head = "weighted_sum" if self.use_weighted_sum else "value"
        trans_pred_e = trans_pred_e[head]
        trans_pred_o = trans_pred_o[head]

        # Reshape each batch with its own leading dimensions so the two
        # batches are not required to share a sequence length.
        batch_e, steps_e = actions_e.shape[:2]
        batch_o, steps_o = actions_o.shape[:2]
        sum_pred_e = self._reduce_over_time(
            trans_pred_e, batch_e, steps_e, self.train_type
        )
        sum_pred_o = self._reduce_over_time(
            trans_pred_o, batch_o, steps_o, self.train_type
        )

        sum_pred_e = torch.clamp(sum_pred_e, min=0, max=1)
        sum_pred_o = torch.clamp(sum_pred_o, min=0, max=1)

        mean_e = torch.mean(sum_pred_e)
        mean_o = torch.mean(sum_pred_o)
        balance_ = F.mse_loss(sum_pred_e, sum_pred_o)
        # Minimising this lowers predictions on the "_e" batch and raises
        # them on the "_o" batch.
        loss = mean_e - mean_o

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.25)
        self.optimizer.step()

        return (
            loss.detach().cpu().item(),
            mean_e.detach().cpu().item(),
            mean_o.detach().cpu().item(),
            balance_.detach().cpu().item(),
        )
