from typing import Optional, Tuple
import itertools
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
from fsrl.utils import DummyLogger, WandbLogger
from torch.distributions.beta import Beta
from torch.nn import functional as F  # noqa
from tqdm.auto import trange  # noqa
from typing import List
from core.common.net import MHDiagGaussianActor, TransformerBlock, mlp
from torch.distributions.normal import Normal
from torch.distributions.kl import kl_divergence
from torchviz import make_dot


class MOSDT(nn.Module):
    """Multi-agent safe decision transformer with a shared teacher encoder.

    Every agent's trajectory window is turned into a token sequence that
    interleaves, per timestep, a return/cost token (or separate return and cost
    tokens), a state token and an action token. That sequence is encoded by
    the agent's own transformer blocks, or by one parameter-shared stack when
    ``ps`` is True. Actions are predicted from each timestep's state token.

    In training mode the per-agent encoder outputs are also summed and passed
    through extra "teacher" blocks with per-agent heads. The teacher outputs,
    and detached copies of them, are returned for use as distillation targets.

    Args:
        state_dim (List[int]): per-agent state sizes. ``forward`` expects the
            agents' states concatenated along the last axis in this order and
            zero-pads each agent's slice to ``max(state_dim)``.
        action_dim (List[int]): per-agent action sizes. Actions are laid out
            and zero-padded the same way, to ``max(action_dim)``.
        max_action (List[float]): per-agent scale applied to the action head
            output when ``stochastic`` is False.
        seq_len (int): number of timesteps per input window.
        episode_len (int): maximum episode length. The timestep embedding table
            has ``episode_len + seq_len`` rows.
        embedding_dim (int): token embedding width.
        num_layers (int): transformer blocks per agent encoder. Overridden to 1
            when ``sum(state_dim) < 60``.
        total_layers (int): encoder plus teacher depth. The teacher gets
            ``total_layers - num_layers`` blocks (after the override above).
        num_heads (int): attention heads in the agent encoder blocks.
        num_teacher_heads (int): attention heads in the teacher blocks.
        attention_dropout (float): attention dropout for every TransformerBlock.
        residual_dropout (float): residual dropout for every TransformerBlock.
        embedding_dropout (float): dropout applied to the normalised tokens.
        time_emb (bool): add learned timestep embeddings to the tokens.
        cost_transform (bool): replace costs-to-go ``c`` with ``50 - c`` once,
            before they are embedded.
        add_cost_feat (bool): stored only; cost-feature fusion is not applied.
        mul_cost_feat (bool): stored only; cost-feature fusion is not applied.
        cat_cost_feat (bool): stored only; cost-feature fusion is not applied,
            so the heads always consume ``embedding_dim`` features.
        action_head_layers (int): currently unused.
        stochastic (bool): create learnable per-agent entropy temperatures and
            leave the head outputs unscaled.
        init_temperature (float): initial entropy temperature when stochastic.
        target_entropy: stored as ``target_entropy`` and
            ``teacher_target_entropy`` when stochastic. The trainer in this
            file indexes it per agent.
        fix_att (bool): forwarded to every TransformerBlock.
        decision_heads (int): forwarded to every MHDiagGaussianActor.
        cost_classify (bool): embed costs-to-go as int64 ids into a 2-row
            ``nn.Embedding``, instead of projecting the scalar with
            ``nn.Linear``.
        ps (bool): parameter sharing. One set of embedding, encoder and head
            modules serves all agents instead of one set per agent. Teacher
            heads are always per agent.
        be (bool): merge the return and cost embeddings
            (``embedding_dim // 2`` each) into one token, giving 3 tokens per
            timestep instead of 4.
    """

    def __init__(
        self,
        state_dim: List[int],
        action_dim: List[int],
        max_action: List[float],
        seq_len: int = 10,
        episode_len: int = 1000,
        embedding_dim: int = 128,
        num_layers: int = 2,
        total_layers: int = 3,
        num_heads: int = 4,
        num_teacher_heads: int = 8,
        attention_dropout: float = 0.0,
        residual_dropout: float = 0.0,
        embedding_dropout: float = 0.0,
        time_emb: bool = True,
        cost_transform: bool = False,
        add_cost_feat: bool = False,
        mul_cost_feat: bool = False,
        cat_cost_feat: bool = False,
        action_head_layers: int = 1,
        stochastic: bool = False,
        init_temperature=0.1,
        target_entropy=None,
        fix_att: bool = False,
        decision_heads: int = 1,
        cost_classify: bool = True,
        ps: bool = True,
        be: bool = True,
    ):
        super().__init__()
        self.seq_len = seq_len
        self.embedding_dim = embedding_dim
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.episode_len = episode_len
        self.max_action = max_action
        self.agent_num = len(self.state_dim)
        self.cost_transform = (lambda x: 50 - x) if cost_transform else None
        self.add_cost_feat = add_cost_feat
        self.mul_cost_feat = mul_cost_feat
        self.cat_cost_feat = cat_cost_feat
        self.stochastic = stochastic
        # Small joint observation spaces get a single-layer encoder; whatever
        # depth remains out of ``total_layers`` goes to the teacher.
        num_layers = 1 if sum(self.state_dim) < 60 else num_layers
        num_teacher_layers = total_layers - num_layers
        # Prefix sums: agent i owns columns [acc[i], acc[i + 1]) of the
        # concatenated state / action tensors.
        self.acc_state_dim = [0, *itertools.accumulate(self.state_dim)]
        self.acc_action_dim = [0, *itertools.accumulate(self.action_dim)]
        self.ps = ps
        self.be = be
        max_state_dim = max(self.state_dim)
        max_action_dim = max(self.action_dim)

        # NOTE: modules are created in the same order as before this refactor,
        # so parameter order, state_dict keys and seeded initialisation match.
        self.emb_drop = self._per_agent(lambda: nn.Dropout(embedding_dropout))
        self.emb_norm = self._per_agent(lambda: nn.LayerNorm(embedding_dim))
        self.out_norm = self._per_agent(lambda: nn.LayerNorm(embedding_dim))
        self.teacher_out_norm = nn.LayerNorm(embedding_dim)

        self.time_emb = time_emb
        if self.time_emb:
            self.timestep_emb = self._per_agent(
                lambda: nn.Embedding(episode_len + seq_len, embedding_dim))

        self.state_ember = self._per_agent(lambda: nn.Linear(max_state_dim, embedding_dim))
        self.action_ember = self._per_agent(lambda: nn.Linear(max_action_dim, embedding_dim))

        self.cost_classify = cost_classify
        # With ``be`` the return and cost embeddings are concatenated into one
        # token, so each of them is half as wide.
        rc_dim = embedding_dim // 2 if self.be else embedding_dim
        if self.cost_classify:
            self.cost_ember = self._per_agent(lambda: nn.Embedding(2, rc_dim))
        else:
            self.cost_ember = self._per_agent(lambda: nn.Linear(1, rc_dim))
        self.return_ember = self._per_agent(lambda: nn.Linear(1, rc_dim))

        # Tokens per timestep: (return|cost, state, action) or
        # (return, cost, state, action).
        self.seq_repeat = 3 if self.be else 4
        dt_seq_len = self.seq_repeat * seq_len

        def make_blocks(n_layers, n_heads):
            return nn.ModuleList([
                TransformerBlock(
                    seq_len=dt_seq_len,
                    embedding_dim=embedding_dim,
                    num_heads=n_heads,
                    attention_dropout=attention_dropout,
                    residual_dropout=residual_dropout,
                    fix_att=fix_att,
                ) for _ in range(n_layers)
            ])

        self.blocks = self._per_agent(lambda: make_blocks(num_layers, num_heads))
        self.teacher = make_blocks(num_teacher_layers, num_teacher_heads)

        # Cost-feature fusion (add/mul/cat) is not wired into ``forward``, so
        # the heads always see ``embedding_dim`` state features. Sizing them
        # ``2 * embedding_dim`` for cat_cost_feat made ``forward`` fail with a
        # shape mismatch.
        action_emb_dim = embedding_dim

        def make_feat_proj():
            return nn.Sequential(
                nn.Linear(action_emb_dim, action_emb_dim),
                nn.GELU(),
            )

        def make_action_head():
            return MHDiagGaussianActor(action_emb_dim, max_action_dim, decision_heads)

        # The attribute name ``feat_porj`` (sic) is kept for state_dict compatibility.
        self.feat_porj = self._per_agent(make_feat_proj)
        self.teacher_feat_porj = nn.ModuleList([make_feat_proj() for _ in range(self.agent_num)])
        self.action_head = self._per_agent(make_action_head)
        self.teacher_action_head = nn.ModuleList([make_action_head() for _ in range(self.agent_num)])

        if self.stochastic:
            # These are plain tensors, not nn.Parameters. They are excluded from
            # self.parameters() and state_dict(), and .to(device) does not move
            # them. The trainer in this file gives them their own optimizers.
            log_t0 = np.log(init_temperature)
            self.log_temperature = [
                torch.tensor(log_t0, dtype=torch.float, requires_grad=True)
                for _ in range(self.agent_num)
            ]
            self.teacher_log_temperature = [
                torch.tensor(log_t0, dtype=torch.float, requires_grad=True)
                for _ in range(self.agent_num)
            ]
            self.target_entropy = target_entropy
            self.teacher_target_entropy = target_entropy

        self.apply(self._init_weights)

    def _per_agent(self, factory):
        """Return ``factory()`` when sharing parameters, else one module per agent."""
        if self.ps:
            return factory()
        return nn.ModuleList([factory() for _ in range(self.agent_num)])

    def _agent_module(self, modules, agent: int):
        """Select the module that ``_per_agent`` built for ``agent``."""
        return modules if self.ps else modules[agent]

    def temperature(self, agent):
        """Return ``exp(log_temperature[agent])``, or None when not stochastic."""
        return self.log_temperature[agent].exp() if self.stochastic else None

    def teacher_temperature(self, agent):
        """Return ``exp(teacher_log_temperature[agent])``, or None when not stochastic."""
        return self.teacher_log_temperature[agent].exp() if self.stochastic else None

    @staticmethod
    def _init_weights(module: nn.Module):
        """GPT-style init: N(0, 0.02) weights, zero biases, identity LayerNorm."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def forward(
            self,
            states: torch.Tensor,  # [batch_size, seq_len, sum(state_dim)]
            actions: torch.Tensor,  # [batch_size, seq_len, sum(action_dim)]
            returns_to_go: torch.Tensor,  # [batch_size, seq_len, agent_num]
            costs_to_go: torch.Tensor,  # [batch_size, seq_len, agent_num]
            time_steps: torch.Tensor,  # [batch_size, seq_len]
            padding_mask: Optional[torch.Tensor] = None,  # [batch_size, seq_len]
    ) -> Tuple[List, Optional[List], Optional[List], List[torch.Tensor],
               List[torch.Tensor], Optional[List[torch.Tensor]]]:
        """Encode every agent's trajectory window and predict its actions.

        Args:
            states: agents' states concatenated on the last axis in
                ``state_dim`` order.
            actions: agents' actions concatenated in ``action_dim`` order.
            returns_to_go: one column per agent.
            costs_to_go: one column per agent. Cast to int64 class ids when
                ``cost_classify`` is set.
            time_steps: indices into the timestep embedding table.
            padding_mask: optional; repeated once per token of each timestep
                and passed to every TransformerBlock. The trainer in this file
                passes ``True`` for positions to ignore.

        Returns:
            A 6-tuple of per-agent lists: ``(action_preds,
            teacher_action_preds, teacher_action_preds_detach, agent_actions,
            state_feats, teacher_state_feats_detach)``.

            ``agent_actions`` holds each agent's action slice, zero-padded to
            ``max(action_dim)``. The three teacher entries are ``None`` unless
            the module is in training mode.
        """
        batch_size, seq_len = states.shape[0], states.shape[1]
        max_state_dim = max(self.state_dim)
        max_action_dim = max(self.action_dim)

        if padding_mask is not None:
            # Repeat each timestep's flag for every token of that timestep so it
            # lines up with the interleaved sequence built below.
            padding_mask = (
                torch.stack([padding_mask] * self.seq_repeat, dim=1)
                .permute(0, 2, 1)
                .reshape(batch_size, -1)
            )

        # Transform costs once, before any agent embeds them. Previously this
        # ran inside the agent loop after the agent's cost was embedded, so
        # agent 0 saw raw costs and later agents saw it applied repeatedly.
        if self.cost_transform is not None:
            costs_to_go = self.cost_transform(costs_to_go.detach())

        action_preds, state_feats, agent_actions, sequences = [], [], [], []

        for agent in range(self.agent_num):
            if self.time_emb:
                timestep_emb = self._agent_module(self.timestep_emb, agent)(time_steps)
            else:
                timestep_emb = 0.0

            agent_costs = costs_to_go[:, :, agent]
            cost_ember = self._agent_module(self.cost_ember, agent)
            if self.cost_classify:
                cost_emb = cost_ember(agent_costs.type(torch.int64))
            else:
                cost_emb = cost_ember(agent_costs.unsqueeze(-1))

            # Slice this agent's columns and zero-pad to the widest agent so a
            # shared embedder can be used.
            s = states[:, :, self.acc_state_dim[agent]:self.acc_state_dim[agent + 1]]
            s = F.pad(s, (0, max_state_dim - s.shape[-1]), "constant", 0)
            state_emb = self._agent_module(self.state_ember, agent)(s) + timestep_emb

            a = actions[:, :, self.acc_action_dim[agent]:self.acc_action_dim[agent + 1]]
            a = F.pad(a, (0, max_action_dim - a.shape[-1]), "constant", 0)
            act_emb = self._agent_module(self.action_ember, agent)(a) + timestep_emb
            agent_actions.append(a)

            return_emb = self._agent_module(self.return_ember, agent)(
                returns_to_go[:, :, agent].unsqueeze(-1))

            if self.be:
                rc_emb = torch.cat([return_emb, cost_emb], dim=-1) + timestep_emb
                tokens = [rc_emb, state_emb, act_emb]
            else:
                tokens = [return_emb + timestep_emb, cost_emb + timestep_emb, state_emb, act_emb]

            # [B, R, L, E] -> [B, L, R, E] -> [B, L * R, E]: tokens interleaved
            # per timestep.
            sequence = torch.stack(tokens, dim=1).permute(0, 2, 1, 3)
            sequence = sequence.reshape(batch_size, self.seq_repeat * seq_len, self.embedding_dim)

            # LayerNorm and Dropout (!!!) as in the original implementation,
            # while minGPT & huggingface use only embedding dropout.
            sequence = self._agent_module(self.emb_norm, agent)(sequence)
            sequence = self._agent_module(self.emb_drop, agent)(sequence)
            for block in self._agent_module(self.blocks, agent):
                sequence = block(sequence, padding_mask=padding_mask)

            # The teacher consumes encoder outputs taken before the final LayerNorm.
            sequences.append(sequence)
            sequence = self._agent_module(self.out_norm, agent)(sequence)

            # Back to [B, R, L, E]; token index -2 is each timestep's state token.
            sequence = sequence.reshape(
                batch_size, seq_len, self.seq_repeat, self.embedding_dim).permute(0, 2, 1, 3)
            state_feat = self._agent_module(self.feat_porj, agent)(sequence[:, -2])
            state_feats.append(state_feat)

            # Predict the next action from the state token.
            action_pred = self._agent_module(self.action_head, agent)(state_feat)
            if not self.stochastic:
                action_pred = action_pred * self.max_action[agent]
            action_preds.append(action_pred)

        if not self.training:
            return action_preds, None, None, agent_actions, state_feats, None

        # Teacher: runs on the sum of all agents' encoder outputs.
        teacher_sequence = torch.stack(sequences, dim=0).sum(dim=0)
        for block in self.teacher:
            teacher_sequence = block(teacher_sequence, padding_mask=padding_mask)
        teacher_sequence = self.teacher_out_norm(teacher_sequence)
        teacher_sequence = teacher_sequence.reshape(
            batch_size, seq_len, self.seq_repeat, self.embedding_dim).permute(0, 2, 1, 3)
        all_teacher_state_feats = teacher_sequence[:, -2]

        teacher_action_preds, teacher_state_feats_detach = [], []
        for agent in range(self.agent_num):
            teacher_state_feat = self.teacher_feat_porj[agent](all_teacher_state_feats)
            teacher_action_pred = self.teacher_action_head[agent](teacher_state_feat)
            if not self.stochastic:
                teacher_action_pred = teacher_action_pred * self.max_action[agent]
            teacher_action_preds.append(teacher_action_pred)
            teacher_state_feats_detach.append(teacher_state_feat.detach())

        # Gradient-free copies of the teacher distributions, used as
        # distillation targets. NOTE(review): this assumes the heads return
        # distributions exposing ``.mean`` / ``.stddev`` (the stochastic path).
        teacher_action_preds_detach = [
            Normal(loc=pred.mean.detach(), scale=pred.stddev.detach())
            for pred in teacher_action_preds
        ]

        return (
            action_preds,
            teacher_action_preds,
            teacher_action_preds_detach,
            agent_actions,
            state_feats,
            teacher_state_feats_detach,
        )

class MOSDTTrainerDev:
    """Trainer for :class:`MOSDT` with optional teacher self-distillation.

    Args:
        model (MOSDT): A MOSDT model to train.
        env (gym.Env): The OpenAI Gym environment to train the model in.
        logger (WandbLogger or DummyLogger): The logger to use for tracking training progress.
        learning_rate (float): The learning rate for the optimizer.
        weight_decay (float): The weight decay for the optimizer.
        betas (Tuple[float, ...]): The betas for the optimizer.
        clip_grad (float): The clip gradient value.
        lr_warmup_steps (int): The number of warmup steps for the learning rate scheduler.
        reward_scale (float): The scaling factor for the reward signal.
        cost_scale (float): The scaling factor for the constraint cost.
        cost_reverse (bool): Whether to reverse the cost.
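        no_entropy (bool): If True, drop the entropy bonus from the actor losses
            (its temperature weight is treated as 0).
        device (str): The device to use for training (e.g. "cpu" or "cuda").
        sd_weight (float): Weight of the KL self-distillation loss between each
            agent's action distribution and the detached teacher distribution.
        feat_sd_weight (float): Weight of the distance between each agent's state
            features and the detached teacher state features.
        sd (bool): If True, add the teacher and self-distillation losses to the
            optimized loss; otherwise optimize only the agents' action loss.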

    """

    def __init__(
            self,
            model: MOSDT,
            env: gym.Env,
            logger: Optional[WandbLogger] = None,
            # training params
            learning_rate: float = 1e-4,
            weight_decay: float = 1e-4,
            betas: Tuple[float, ...] = (0.9, 0.999),
            clip_grad: float = 0.25,
            lr_warmup_steps: int = 10000,
            reward_scale: float = 1.0,
            cost_scale: float = 1.0,
            cost_reverse: bool = False,
            no_entropy: bool = False,
            device="cpu",
            sd_weight: float = 0.3,
            feat_sd_weight: float = 0.03,
            sd: bool = True,
            ) -> None:
        """Set up the optimizers, the warmup LR schedule and the loss weights.

        ``logger`` defaults to a fresh ``DummyLogger`` for each trainer.
        ``lr_warmup_steps`` values below 1 disable warmup instead of raising
        ZeroDivisionError while the scheduler is being built.
        """
        self.model = model
        # Create the default per instance, instead of sharing one object that is
        # evaluated once when the function is defined.
        self.logger = DummyLogger() if logger is None else logger
        self.env = env
        self.clip_grad = clip_grad
        self.reward_scale = reward_scale
        self.cost_scale = cost_scale
        self.device = device
        self.cost_reverse = cost_reverse
        self.no_entropy = no_entropy
        self.sd_weight = sd_weight
        self.feat_sd_weight = feat_sd_weight
        self.sd = sd

        self.optim = torch.optim.AdamW(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            betas=betas,
        )
        # LambdaLR evaluates the lambda at step 0 during construction, so a zero
        # warmup length would raise ZeroDivisionError there.
        warmup_steps = max(lr_warmup_steps, 1)
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.optim,
            # Linear warmup from 1/warmup_steps up to the full rate, then constant.
            lambda steps: min((steps + 1) / warmup_steps, 1),
        )
        self.stochastic = self.model.stochastic
        if self.stochastic:
            # MOSDT keeps its log-temperatures as plain tensor lists outside
            # model.parameters(), so each set gets a dedicated optimizer.
            self.log_temperature_optimizer = torch.optim.Adam(
                self.model.log_temperature,
                lr=1e-4,
                betas=(0.9, 0.999),
            )
            self.teacher_log_temperature_optimizer = torch.optim.Adam(
                self.model.teacher_log_temperature,
                lr=1e-4,
                betas=(0.9, 0.999),
            )
        self.max_action = self.model.max_action
        self.agent_num = self.model.agent_num

        # Not referenced by train_one_step; presumably used by code elsewhere.
        self.beta_dist = Beta(torch.tensor(2, dtype=torch.float, device=self.device),
                              torch.tensor(5, dtype=torch.float, device=self.device))

    def train_one_step(self, states, actions, returns, costs_return, time_steps, mask, step):
        """Run one optimization step on a batch of trajectory windows.

        ``mask`` marks valid positions with non-zero values. It is inverted into
        the model's ``padding_mask`` and, as ``mask > 0``, selects the positions
        that every loss term averages over.

        Loss terms, accumulated over agents (built only on the stochastic path):
          * ``act_loss``: negative log-likelihood of each agent's zero-padded
            actions under its predicted distribution, minus a
            temperature-weighted entropy bonus (weight 0 when ``no_entropy``).
          * ``teacher_act_loss``: the same objective for the teacher heads.
          * ``sd_loss``: ``KL(student || detached teacher)``, scaled by
            ``sd_weight``.
          * ``feat_sd_loss``: ``torch.dist`` (default p=2, i.e. a norm, not a
            mean) between student and detached teacher state features over
            valid positions, scaled by ``feat_sd_weight``.

        With ``sd`` the sum of all four terms is minimized; otherwise only
        ``act_loss`` is. After the model update, the entropy temperatures are
        adjusted (the teacher's only when ``sd``), statistics are logged under
        ``step``, and the LR scheduler advances.

        NOTE(review): the model only returns teacher outputs in training mode,
        and only the stochastic branch builds a tensor loss. With
        ``stochastic=False``, ``loss`` stays the float ``0.0`` and
        ``loss.backward()`` fails. Confirm callers always train stochastic
        models in train mode.
        """
        # True value indicates that the corresponding key value will be ignored
        padding_mask = ~mask.to(torch.bool)
        action_preds, teacher_action_preds, teacher_action_preds_detach, agent_actions, state_feats, teacher_state_feats_detach = self.model(
            states=states,
            actions=actions,
            returns_to_go=returns,
            costs_to_go=costs_return,
            time_steps=time_steps,
            padding_mask=padding_mask,
        )

        # Per-agent statistics collected for logging and for the temperature updates.
        act_mean_max, act_mean_min, act_std_max, act_std_min, entropys, entropy_reg_items, teacher_entropys, teacher_entropy_reg_items = [], [], [], [], [], [], [], []
        act_loss, teacher_act_loss, sd_loss, feat_sd_loss = 0.0, 0.0, 0.0, 0.0
        for agent in range(self.agent_num):
            if self.stochastic:
                act_mean_max.append(action_preds[agent].loc.max())
                act_mean_min.append(action_preds[agent].loc.min())
                act_std_max.append(action_preds[agent].scale.max())
                act_std_min.append(action_preds[agent].scale.min())

                # Student objective: maximize likelihood plus an entropy bonus weighted
                # by the (detached) temperature.
                log_likelihood = action_preds[agent].log_prob(agent_actions[agent])[mask > 0].mean()
                entropy = action_preds[agent].entropy()[mask > 0].mean()
                entropys.append(entropy)
                if self.no_entropy:
                    entropy_reg = 0.0
                    entropy_reg_items.append(0.0)
                else:
                    entropy_reg = self.model.temperature(agent).detach()
                    entropy_reg_items.append(entropy_reg.item())
                act_loss += -(log_likelihood + entropy_reg * entropy)

                # Teacher objective: the same form, using the teacher heads and temperatures.
                teacher_log_likelihood = teacher_action_preds[agent].log_prob(agent_actions[agent])[mask > 0].mean()
                teacher_entropy = teacher_action_preds[agent].entropy()[mask > 0].mean()
                teacher_entropys.append(teacher_entropy)
                if self.no_entropy:
                    teacher_entropy_reg = 0.0
                    teacher_entropy_reg_items.append(0.0)
                else:
                    teacher_entropy_reg = self.model.teacher_temperature(agent).detach()
                    teacher_entropy_reg_items.append(teacher_entropy_reg.item())
                teacher_act_loss += -(teacher_log_likelihood + teacher_entropy_reg * teacher_entropy)

                # Self-distillation: pull the student toward detached teacher targets,
                # both at the action-distribution level and at the state-feature level.
                sd_loss += kl_divergence(action_preds[agent], teacher_action_preds_detach[agent])[mask > 0].mean() * self.sd_weight
                feat_sd_loss += torch.dist(state_feats[agent][mask > 0], teacher_state_feats_detach[agent][mask > 0]) * self.feat_sd_weight
                # for other_agent in range(self.agent_num):
                #     if other_agent != agent:
                #         feat_sd_loss += F.huber_loss(
                #             F.cosine_similarity(state_feats[agent], state_feats[other_agent], dim=-1),
                #             F.cosine_similarity(teacher_state_feats_detach[agent], teacher_state_feats_detach[other_agent], dim=-1),
                #             reduction='none'
                #         )[mask > 0].mean() * self.feat_sd_weight
        
        # Without sd, only the students are optimized; teacher losses are still
        # computed above and logged below.
        if self.sd:
            loss =  act_loss + teacher_act_loss + sd_loss + feat_sd_loss
        else:
            loss =  act_loss

        # dot = make_dot(loss, params=dict(self.model.named_parameters()))
        # dot.render('model', format='pdf', cleanup=True)

        self.optim.zero_grad()
        loss.backward()
        if self.clip_grad is not None:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_grad)
        self.optim.step()

        if self.stochastic:
            if not self.no_entropy:
                # Temperature update: minimizing temperature * (entropy - target), with
                # the bracket detached, lowers the temperature when entropy exceeds
                # its target and raises it otherwise.
                self.log_temperature_optimizer.zero_grad()
                temperature_loss = 0.0
                for agent in range(self.agent_num):
                    temperature_loss += (self.model.temperature(agent) * (entropys[agent] - self.model.target_entropy[agent]).detach())
                temperature_loss.backward()
                self.log_temperature_optimizer.step()
                self.logger.add_scalar("loss/loss_temperature", temperature_loss.item(), step)

                if self.sd:
                    self.teacher_log_temperature_optimizer.zero_grad()
                    teacher_temperature_loss = 0.0
                    for agent in range(self.agent_num):
                        teacher_temperature_loss += (self.model.teacher_temperature(agent) * (teacher_entropys[agent] - self.model.teacher_target_entropy[agent]).detach())
                    teacher_temperature_loss.backward()
                    self.teacher_log_temperature_optimizer.step()
                    self.logger.add_scalar("loss/loss_teacher_temperature", teacher_temperature_loss.item(), step)

            # Reduce the per-agent statistics to scalars for logging.
            act_mean_max = max(act_mean_max).item()
            act_mean_min = min(act_mean_min).item()
            act_std_max = max(act_std_max).item()
            act_std_min = min(act_std_min).item()
            mean_entropy = torch.stack(entropys, dim=0).mean().item()
            mean_entropy_reg = sum(entropy_reg_items) / len(entropy_reg_items)
            mean_teacher_entropy = torch.stack(teacher_entropys, dim=0).mean().item()
            mean_teacher_entropy_reg = sum(teacher_entropy_reg_items) / len(teacher_entropy_reg_items)
            
            self.logger.add_scalar("train/mean_entropy", mean_entropy, step)
            self.logger.add_scalar("train/mean_entropy_reg", mean_entropy_reg, step)
            self.logger.add_scalar("train/mean_teacher_entropy", mean_teacher_entropy, step)
            self.logger.add_scalar("train/mean_teacher_entropy_reg", mean_teacher_entropy_reg, step)
            self.logger.add_scalar("train/act_mean_max", act_mean_max, step)
            self.logger.add_scalar("train/act_mean_min", act_mean_min, step)
            self.logger.add_scalar("train/act_std_max", act_std_max, step)
            self.logger.add_scalar("train/act_std_min", act_std_min, step)

        self.logger.add_scalar("loss/loss", loss.item(), step)
        self.logger.add_scalar("loss/loss_act_weighted", act_loss.item(), step)
        self.logger.add_scalar("loss/loss_act_teacher", teacher_act_loss.item(), step)
        self.logger.add_scalar("loss/loss_sd_weighted", sd_loss.item(), step)
        self.logger.add_scalar("loss/loss_feat_sd_weighted", feat_sd_loss.item(), step)
        # Log the LR used for this step, then advance the warmup schedule.
        self.logger.add_scalar("train/lr", self.scheduler.get_last_lr()[0], step)

        self.scheduler.step()