import os
from pickletools import StackObject

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from losses import *

# from deepspeed.ops.adam import DeepSpeedCPUAdam
import numpy as np
import time

import gym
from model import Policy, CommPolicy, FCNetwork
from gym.spaces.utils import flatdim
from storage import RolloutStorage, CommRolloutStorage, CommMemRolloutStorage, CommMGRolloutStorage, MGRolloutStorage
from sacred import Ingredient
# from torchviz import make_dot, make_dot_from_trace

# Sacred ingredient holding the algorithm hyper-parameters. Every local name
# assigned inside ``config()`` below becomes a config entry, and functions
# decorated with ``@algorithm.capture`` receive any argument the caller omits
# from the config entry of the same name.
algorithm = Ingredient("algorithm")

# torch.autograd.set_detect_anomaly(True)
# NOTE: because sacred records every local of config() as a config value,
# temporary/helper variables added here would leak into the experiment config.
@algorithm.config
def config():
    # Hyper-parameter presets per environment family. Only the "PP" preset is
    # active; the RWare / CIFAR / TJ presets are kept commented out so they can
    # be swapped in by hand.
    # RWare
    # lr = 3e-4
    # # lr = 3e-8
    # adam_eps = 0.001
    # gamma = 0.99
    # use_gae = False
    # gae_lambda = 0.95
    # entropy_coef = 0.01
    # value_loss_coef = 0.5
    # max_grad_norm = 0.5

    # PP
    lr = 3e-4
    adam_eps = 0.001
    gamma = 0.99
    use_gae = False
    gae_lambda = 0.97
    entropy_coef = 0.01
    value_loss_coef = 0.5
    max_grad_norm = 2500
    reward_scale = 1.0
    num_comm_outputs = 4

    # CIFAR
    # lr = 3e-4
    # adam_eps = 0.001
    # gamma = 0.99
    # use_gae = False
    # gae_lambda = 0.97
    # entropy_coef = 0.01
    # value_loss_coef = 0.5
    # max_grad_norm = 50
    # reward_scale = 1.0
    # num_comm_outputs = 10

    # TJ
    # lr = 3e-4
    # adam_eps = 0.001
    # gamma = 0.99
    # use_gae = False
    # gae_lambda = 0.97
    # entropy_coef = 0.1
    # value_loss_coef = 0.5
    # max_grad_norm = 2500
    # reward_scale = 1.0
    # num_comm_outputs = 4

    # Env Properties
    # env_properties becomes a dict holding both settings when either one is
    # set; otherwise it is None.
    sensor_range = None
    request_queue_size = None
    env_properties = {}
    if(sensor_range != None or request_queue_size != None):
        env_properties['sensor_range'] = sensor_range
        env_properties['request_queue_size'] = request_queue_size
    else:
        env_properties = None

    use_proper_time_limits = True
    recurrent_policy = True
    use_linear_lr_decay = False

    # Regularization parameters
    use_clipping = True
    # Spectral norm is used for RL head
    use_spectral_norm = True
    # Use layer norm after CNN
    use_norm = False

    use_communication = False
    use_comm_gate = False
    use_aligner = False
    share_aligner = False
    # An aligner_type containing 'mi' makes CommA2C build a SupConLoss-based
    # mutual-information aligner loss.
    aligner_type = 'mi' if use_aligner else ''
    use_dial = False
    # DIAL is switched on automatically when communicating without an aligner.
    if(use_aligner == False and use_communication):
        use_dial = True
    # use_dial = False
    aligner_coeff = 1.0
    mi_loss_prop_dict = {}
    if(use_aligner and 'mi' in aligner_type):
        # mi_temperature = 0.1
        # mi_coeff = 0.05
        mi_temperature = 0.1
        mi_coeff = 0.50
        mi_loss_prop_dict['temperature'] = mi_temperature
        mi_loss_prop_dict['mi_coeff'] = mi_coeff
    # NOTE(review): the conditional block above is redundant. The four
    # unconditional lines below always assign the same values, so
    # mi_loss_prop_dict is populated even when the aligner is disabled.
    mi_temperature = 0.1
    mi_coeff = 0.50
    mi_loss_prop_dict['temperature'] = mi_temperature
    mi_loss_prop_dict['mi_coeff'] = mi_coeff
    # NOTE(review): use_projector ends up False on both paths.
    use_projector = False
    if(use_aligner):
        use_projector = False

    # Whether to use learnable coefficient for grounding module
    # (CommA2C rejects this when combined with use_aligner.)
    use_learnable_ground_coeff = False

    use_message_decoder = True
    use_comm_sep_rnn = False
    # comm_embed_size = 128
    # below used for pp
    comm_embed_size = 8
    # below used for tj
    # comm_embed_size = 16
    use_memory = False
    use_pos_listening_loss = False
    comm_interval = 1
    # When True, all other agents' messages are concatenated into one input.
    concat_messages = True
    # Possibilities: SDND, LinMem
    mem_type = "SDND"
    mem_size = 200
    # NOTE(review): mem_policy / mem_key_size are only defined when mem_type
    # contains "DND". CommA2C.__init__ takes mem_key_size, so a non-DND
    # mem_type (e.g. "LinMem") presumably leaves that argument unfilled;
    # confirm before switching.
    if("DND" in mem_type):
        mem_policy = 'kNN'
        mem_key_size = 6
    hidden_size = 32
    seac_coef = 0.0

    num_processes = 12
    # num_processes = 2
    num_steps = 5
    # num_steps = 10
    # For memory
    # num_steps_schedule = [5, 10, 15, 20]
    num_steps_schedule = [5]
    # num_steps_schedule = [10]

    # Derived run name: "seac" prefix when seac_coef > 0, else "iac", followed
    # by the communication variant (gated / memory / plain) if enabled.
    if(seac_coef > 0.0):
        if(use_communication):
            if(use_comm_gate):
                algorithm_name = "seac_comm_gated"
            else:
                if(use_memory):
                    algorithm_name = "seac_comm_mem"
                else:
                    algorithm_name = "seac_comm"
        else:
            algorithm_name = "seac"
    else:
        if(use_communication):
            if(use_comm_gate):
                algorithm_name = "iac_comm_gated"
            else:
                if(use_memory):
                    algorithm_name = "iac_comm_mem"
                else:
                    algorithm_name = "iac_comm"
                    # algorithm_name = "iac_comm_alt"
        else:
            algorithm_name = "iac"

    # recurrent_gru_types = ['gru', 'gru_jit', 'mgu', 'mgu_jit', 'mgu2_jit', 'qrnn', 'qrnn_jit', 'qrnn_2', 'lmu', 'drnn']
    # recurrent_lstm_types = ['gru_jit']
    recurrent_type = 'gru'

    # Passed to model.to(...); an integer selects a CUDA device index.
    device = 'cpu'
    # device = 0

def cross_entropy(input, target):
    """Return the batch-mean cross-entropy ``-sum(target * log(input))``.

    The sum runs over dimension 1 and the mean over the remaining entries.
    Both tensors are presumably probability distributions with classes on
    dim 1 (TODO confirm with callers).

    Terms whose ``target`` is 0 contribute exactly 0 (the 0 * log 0 = 0
    convention). Previously such a term evaluated to ``0 * -inf = NaN``
    whenever the matching ``input`` entry was 0, poisoning the loss and its
    gradient.
    """
    # Where the target is zero, replace the input by 1 so log() gives 0. Using
    # torch.where on the *input* (not on the product) also keeps inf/NaN out
    # of the backward pass for those entries.
    safe_input = torch.where(target != 0, input, torch.ones_like(input))
    return torch.mean(-torch.sum(target * torch.log(safe_input), 1))

class A2C:
    """Single-agent advantage actor-critic learner.

    Holds one ``Policy`` network, a rollout storage buffer and an Adam
    optimizer. Any argument of an ``@algorithm.capture``-decorated method that
    the caller omits is filled from the sacred ``algorithm`` config.
    """

    @algorithm.capture()
    def __init__(
        self,
        env_name,
        agent_id,
        obs_space,
        action_space,
        lr,
        adam_eps,
        recurrent_policy,
        recurrent_type,
        hidden_size,
        num_steps,
        num_processes,
        use_clipping,
        use_spectral_norm,
        use_norm,
        device,
    ):
        self.agent_id = agent_id
        self.env_name = env_name
        self.obs_space = obs_space
        # The original guard read ``'cifar' not in env_name or 'MarlGrid' not
        # in env_name``. That is False only when BOTH substrings occur, so the
        # explicit form below is equivalent: obs_size is still computed for
        # cifar and MarlGrid environments.
        if not ('cifar' in env_name and 'MarlGrid' in env_name):
            self.obs_size = flatdim(obs_space)
        self.action_size = flatdim(action_space)
        self.action_space = action_space

        # cifar / MarlGrid environments use the convolutional state encoder.
        use_conv_state_encoder = 'cifar' in env_name or 'MarlGrid' in env_name
        self.model = Policy(
            obs_space,
            action_space,
            hidden_size=hidden_size,
            use_spectral_norm=use_spectral_norm,
            use_conv_state_encoder=use_conv_state_encoder,
            env_name=env_name,
            use_norm=use_norm,
            base_kwargs={"recurrent": recurrent_policy, "recurrent_type": recurrent_type},
        )

        # MarlGrid rollouts keep img_obs / df_obs separately (see _last_obs),
        # so they use a dedicated storage class.
        storage_cls = MGRolloutStorage if 'MarlGrid' in env_name else RolloutStorage
        self.storage = storage_cls(
            obs_space,
            action_space,
            self.model.recurrent_hidden_state_size,
            num_steps,
            num_processes,
        )

        self.use_clipping = use_clipping
        self.model.to(device)
        self.optimizer = optim.Adam(self.model.parameters(), lr, eps=adam_eps)

        # self.intr_stats = RunningStats()
        # Objects pickled by ``save`` and reloaded by ``restore``.
        self.saveables = {
            "model": self.model,
            "optimizer": self.optimizer,
        }

    def save(self, path, step):
        """Pickle the model and optimizer objects to ``<path>/models_<step>.pt``."""
        torch.save(self.saveables, os.path.join(path, "models_{}.pt".format(step)))

    def restore(self, path, step):
        """Load ``<path>/models_<step>.pt`` (as written by ``save``) into this agent.

        The checkpoint is loaded onto the CPU first; ``load_state_dict`` then
        copies the values onto whatever device the live model and optimizer
        use. This lets GPU checkpoints be restored on CPU-only hosts.
        """
        checkpoint_path = os.path.join(path, "models_{}.pt".format(step))
        try:
            # ``save`` pickles whole nn.Module / Optimizer objects, which
            # torch>=2.6 refuses under its default ``weights_only=True``.
            # Only restore checkpoints from trusted sources.
            checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
        except TypeError:
            # torch < 1.13 does not accept the ``weights_only`` keyword.
            checkpoint = torch.load(checkpoint_path, map_location="cpu")
        for k, v in self.saveables.items():
            v.load_state_dict(checkpoint[k].state_dict())

    def _last_obs(self):
        """Return the stored observation of the final (bootstrap) step.

        For MarlGrid this is the ``(img_obs[-1], df_obs[-1])`` pair; otherwise
        it is ``obs[-1]``.
        """
        if "MarlGrid" in self.env_name:
            return (self.storage.img_obs[-1], self.storage.df_obs[-1])
        return self.storage.obs[-1]

    @algorithm.capture
    def compute_returns(self, use_gae, gamma, gae_lambda, use_proper_time_limits):
        """Fill the storage returns, bootstrapping from this agent's critic."""
        with torch.no_grad():
            next_value = self.model.get_value(
                self._last_obs(),
                self.storage.recurrent_hidden_states[-1],
                self.storage.masks[-1],
            ).detach()

        self.storage.compute_returns(
            next_value, use_gae, gamma, gae_lambda, use_proper_time_limits,
        )

    @algorithm.capture
    #@profile
    # Using external model to compute returns
    def compute_returns_shared(self, agent, use_gae, gamma, gae_lambda, use_proper_time_limits):
        """Like ``compute_returns`` but bootstraps with ``agent.model``'s critic.

        The bootstrap inputs still come from this agent's own storage.
        """
        with torch.no_grad():
            next_value = agent.model.get_value(
                self._last_obs(),
                self.storage.recurrent_hidden_states[-1].clone(),
                self.storage.masks[-1],
            ).detach()

        self.storage.compute_returns(
            next_value, use_gae, gamma, gae_lambda, use_proper_time_limits,
        )

    @algorithm.capture
    #@profile
    def update(
        self,
        storages,
        value_loss_coef,
        entropy_coef,
        seac_coef,
        max_grad_norm,
        device,
    ):
        """Run one A2C gradient step on the rollout held in ``self.storage``.

        Args:
            storages: per-agent storages. Only the (currently disabled) SEAC
                terms would read them; see the commented reference below.
            value_loss_coef: weight of the critic loss.
            entropy_coef: weight of the entropy bonus.
            seac_coef: weight of the SEAC terms (which are always 0 at present).
            max_grad_norm: clipping threshold, applied when ``use_clipping``.
            device: unused; kept for the sacred-captured signature.

        Returns:
            dict of Python floats: ``policy_loss``, the weighted ``value_loss``
            and ``dist_entropy``, plus ``seac_policy_loss`` / ``seac_value_loss``
            when ``seac_coef > 0``, or ``grad_norm`` otherwise.
        """
        num_steps, num_processes, _ = self.storage.rewards.size()
        action_shape = self.storage.actions.size()[-1]
        if "MarlGrid" not in self.env_name:
            obs_shape = self.storage.obs.size()[2:]
            batch_obs = self.storage.obs[:-1].view(-1, *obs_shape)
        else:
            img_obs_shape = self.storage.img_obs.size()[2:]
            df_obs_shape = self.model.base.num_direct_features
            batch_obs = (
                self.storage.img_obs[:-1].view(-1, *img_obs_shape),
                self.storage.df_obs[:-1].view(-1, df_obs_shape),
            )

        # Re-evaluate the whole rollout as one flattened batch, starting from
        # the first stored recurrent hidden state.
        values, action_log_probs, dist_entropy, _, _ = self.model.evaluate_actions(
            batch_obs,
            self.storage.recurrent_hidden_states[0].view(
                -1, self.model.recurrent_hidden_state_size
            ),
            self.storage.masks[:-1].view(-1, 1),
            self.storage.actions.view(-1, action_shape),
        )

        values = values.view(num_steps, num_processes, 1)
        action_log_probs = action_log_probs.view(num_steps, num_processes, 1)

        advantages = self.storage.returns[:-1] - values

        # The advantage is detached for the actor term so that only the value
        # loss drives the critic.
        policy_loss = -(advantages.detach() * action_log_probs).mean()
        value_loss = advantages.pow(2).mean()

        # SEAC off-policy terms are disabled and stay as the plain int 0.
        seac_policy_loss = 0
        seac_value_loss = 0
        # Disabled SEAC reference implementation:
        # other_agent_ids = [x for x in range(len(storages)) if x != self.agent_id]
        # if(seac_coef > 0.0):
        #     for oid in other_agent_ids:

        #         other_values, logp, _, _ = self.model.evaluate_actions(
        #             storages[oid].obs[:-1].view(-1, *obs_shape),
        #             storages[oid]
        #             .recurrent_hidden_states[0]
        #             .view(-1, self.model.recurrent_hidden_state_size),
        #             storages[oid].masks[:-1].view(-1, 1),
        #             storages[oid].actions.view(-1, action_shape),
        #         )
        #         other_values = other_values.view(num_steps, num_processes, 1)
        #         logp = logp.view(num_steps, num_processes, 1)
        #         other_advantage = (
        #             storages[oid].returns[:-1] - other_values
        #         )  # or storages[oid].rewards

        #         importance_sampling = (
        #             logp.exp() / (storages[oid].action_log_probs.exp() + 1e-7)
        #         ).detach()
        #         # importance_sampling = 1.0
        #         seac_value_loss += (
        #             importance_sampling * other_advantage.pow(2)
        #         ).mean()
        #         seac_policy_loss += (
        #             -importance_sampling * logp * other_advantage.detach()
        #         ).mean()

        self.optimizer.zero_grad()
        total_loss = (
            policy_loss
            + value_loss_coef * value_loss
            - entropy_coef * dist_entropy
            + seac_coef * seac_policy_loss
            + seac_coef * value_loss_coef * seac_value_loss
        )
        total_loss.backward()

        if self.use_clipping:
            nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)

        self.optimizer.step()

        stats = {
            "policy_loss": policy_loss.item(),
            "value_loss": value_loss_coef * value_loss.item(),
            "dist_entropy": entropy_coef * dist_entropy.item(),
        }
        if seac_coef > 0.0:
            # float() accepts both the int placeholders and scalar tensors.
            # The previous ``.item()`` raised AttributeError on the ints.
            stats["seac_policy_loss"] = seac_coef * float(seac_policy_loss)
            stats["seac_value_loss"] = seac_coef * value_loss_coef * float(seac_value_loss)
        else:
            stats["grad_norm"] = self.get_total_norm()
        return stats

    def get_total_norm(self):
        """Return the global L2 norm of the current parameter gradients.

        Parameters without a gradient, or that do not require grad, are
        skipped. Returns 0.0 if none qualify.
        """
        squared_sum = 0.0
        for p in self.model.parameters():
            if p.grad is None or not p.requires_grad:
                continue
            squared_sum += p.grad.detach().norm(2).item() ** 2
        return squared_sum ** 0.5

    @algorithm.capture
    #@profile
    # Update when the model is shared
    def update_shared(
        self,
        agents,
        value_loss_coef,
        entropy_coef,
        seac_coef,
        max_grad_norm,
        device,
    ):
        """Parameter-shared update across ``agents``; not implemented.

        An earlier commented-out draft, which summed importance-weighted policy
        and value losses over every agent's storage, was removed from this
        method.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Update shared not inuse!")



class CommA2C:
    @algorithm.capture()
    def __init__(
        self,
        env_name,
        agent_id,
        obs_space,
        action_space,
        num_other_agents,
        lr,
        adam_eps,
        recurrent_policy,
        recurrent_type,
        comm_embed_size,
        num_comm_outputs,
        use_comm_gate,
        use_memory,
        mem_size,
        mem_type,
        mem_key_size,
        use_aligner,
        share_aligner,
        aligner_type,
        aligner_coeff,
        use_pos_listening_loss,
        use_message_decoder,
        use_comm_sep_rnn,
        hidden_size,
        num_steps,
        num_processes,
        concat_messages,
        use_spectral_norm,
        mi_loss_prop_dict,
        use_learnable_ground_coeff,
        use_dial,
        use_projector,
        use_norm,
        device,
    ):
        """Build the communicating policy, its rollout storage and optimizer.

        Any argument the caller omits is filled from the sacred ``algorithm``
        config.

        Raises:
            NotImplementedError: if ``use_aligner`` and
                ``use_learnable_ground_coeff`` are both enabled. Previously
                this printed a message and called ``exit()``, which ends the
                process with status 0 and hides the misconfiguration.
        """
        self.agent_id = agent_id
        self.env_name = env_name
        self.action_size = flatdim(action_space)
        self.obs_space = obs_space
        # The original guard read ``'cifar' not in env_name or 'MarlGrid' not
        # in env_name``, which is False only when BOTH substrings occur; this
        # explicit form is equivalent.
        if not ('cifar' in env_name and 'MarlGrid' in env_name):
            self.obs_size = flatdim(obs_space)
        self.action_space = action_space
        self.num_other_agents = num_other_agents
        self.use_pos_listening_loss = use_pos_listening_loss
        # message_space[0] is the width of the message input given to the
        # policy: every other agent's message concatenated (original
        # behaviour), or a single message.
        num_message_slots = num_other_agents if concat_messages else 1
        self.message_space = [
            num_message_slots * num_comm_outputs,
            comm_embed_size,
            num_comm_outputs,
        ]
        self.use_memory = use_memory
        self.share_aligner = share_aligner
        self.use_comm_sep_rnn = use_comm_sep_rnn
        self.use_aligner = use_aligner
        self.aligner_type = aligner_type
        self.use_projector = use_projector
        if use_aligner:
            if use_learnable_ground_coeff:
                raise NotImplementedError(
                    "use_learnable_ground_coeff is not supported when use_aligner is enabled"
                )
            self.aligner_coeff = aligner_coeff
            if 'mi' in aligner_type:
                self.mi_temp = mi_loss_prop_dict['temperature']
                self.mi_coeff = mi_loss_prop_dict['mi_coeff']
                self.mi_loss = SupConLoss(temperature=self.mi_temp, device=device).to(device)

        # cifar / MarlGrid environments use the convolutional state encoder.
        use_conv_state_encoder = 'cifar' in env_name or 'MarlGrid' in env_name
        self.model = CommPolicy(
            obs_space,
            action_space,
            self.message_space,
            use_comm_gate=use_comm_gate,
            use_memory=use_memory,
            use_aligner=use_aligner,
            aligner_type=aligner_type,
            mem_size=mem_size,
            mem_type=mem_type,
            mem_key_size=mem_key_size,
            hidden_size=hidden_size,
            use_comm_sep_rnn=use_comm_sep_rnn,
            use_spectral_norm=use_spectral_norm,
            use_dial=use_dial,
            use_conv_state_encoder=use_conv_state_encoder,
            num_agents=num_other_agents + 1,
            use_projector=use_projector,
            env_name=env_name,
            use_norm=use_norm,
            base_kwargs={
                "recurrent": recurrent_policy,
                "recurrent_type": recurrent_type,
                "use_message_decoder": use_message_decoder,
            },
        )
        self.num_comm_outputs = num_comm_outputs
        self.num_steps = num_steps
        self.num_processes = num_processes
        if use_memory:
            self.mem_hidden = hidden_size
        # replay_reset() builds exactly the storage this constructor used to
        # build inline (memory / MarlGrid / default variants), from the
        # attributes set above.
        self.replay_reset()

        self.model.to(device)
        # Learnable coefficients are rejected above, so only the model
        # parameters are optimised.
        self.optimizer = optim.Adam(self.model.parameters(), lr, eps=adam_eps)

        # Objects pickled by ``save`` and reloaded by ``restore``.
        self.saveables = {
            "model": self.model,
            "optimizer": self.optimizer,
        }

    # def save(self, path, step):
    #     torch.save(self.saveables, os.path.join(path, "models_{}.pt".format(step)))

    def replay_reset(self):
        """Replace ``self.storage`` with a fresh, empty rollout buffer.

        The storage class depends on the agent's setup: memory-augmented
        agents get ``CommMemRolloutStorage`` (which also takes ``mem_hidden``),
        MarlGrid environments get ``CommMGRolloutStorage``, and every other
        agent gets ``CommRolloutStorage``.
        """
        sep_rnn = self.use_comm_sep_rnn
        shared_args = (
            self.num_comm_outputs,
            self.model.recurrent_hidden_state_size,
            self.num_steps,
            self.num_processes,
        )
        if self.use_memory:
            self.storage = CommMemRolloutStorage(
                self.obs_space,
                self.mem_hidden,
                self.action_space,
                *shared_args,
                use_comm_sep_rnn=sep_rnn
            )
            return

        if 'MarlGrid' in self.env_name:
            storage_cls = CommMGRolloutStorage
        else:
            storage_cls = CommRolloutStorage
        self.storage = storage_cls(
            self.obs_space,
            self.action_space,
            *shared_args,
            use_comm_sep_rnn=sep_rnn
        )

    def _checkpoint_name(self, step):
        """Return the checkpoint file name that ``save`` uses for ``step``.

        The name encodes the aligner settings, so runs with different settings
        that share a directory do not overwrite each other.
        """
        save_name = "use_aligner" + str(int(self.use_aligner))
        if self.use_aligner:
            save_name += "_aligner_type" + str(self.aligner_type)
            if 'mi' in self.aligner_type:
                save_name += "_mi_temp" + str(self.mi_temp) + "_mi_coeff" + str(self.mi_coeff)
            else:
                save_name += "_align_coeff" + str(self.aligner_coeff)
        return "models_{}_{}.pt".format(save_name, step)

    def save(self, path, step):
        """Pickle the model and optimizer objects to ``<path>/<_checkpoint_name(step)>``."""
        torch.save(self.saveables, os.path.join(path, self._checkpoint_name(step)))

    def restore(self, path, step):
        """Load the checkpoint that ``save`` wrote for ``step`` into this agent.

        Previously this read ``models_<step>.pt``, a name ``save`` never
        produces. It now reads the tagged name and falls back to that legacy
        name only if the tagged file is missing. Tensors are loaded onto the
        CPU first; ``load_state_dict`` copies them onto the live devices.
        """
        checkpoint_path = os.path.join(path, self._checkpoint_name(step))
        if not os.path.exists(checkpoint_path):
            checkpoint_path = os.path.join(path, "models_{}.pt".format(step))
        try:
            # ``save`` pickles whole nn.Module / Optimizer objects, which
            # torch>=2.6 refuses under its default ``weights_only=True``.
            # Only restore checkpoints from trusted sources.
            checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
        except TypeError:
            # torch < 1.13 does not accept the ``weights_only`` keyword.
            checkpoint = torch.load(checkpoint_path, map_location="cpu")
        for k, v in self.saveables.items():
            v.load_state_dict(checkpoint[k].state_dict())

    @algorithm.capture
    #@profile
    def compute_returns(self, use_gae, gamma, gae_lambda, use_proper_time_limits):
        """Fill the storage returns, bootstrapping from this agent's own critic.

        The critic is queried, without gradient tracking, on the last stored
        observation, message, hidden state(s) and mask. When
        ``use_comm_sep_rnn`` is set, the separate communication hidden state
        is passed as well.
        """
        rollout = self.storage
        with torch.no_grad():
            if "MarlGrid" in self.env_name:
                final_obs = (rollout.img_obs[-1], rollout.df_obs[-1])
            else:
                final_obs = rollout.obs[-1]
            value_inputs = [
                final_obs,
                rollout.msg[-1].clone(),
                rollout.recurrent_hidden_states[-1].clone(),
                rollout.masks[-1],
            ]
            if self.use_comm_sep_rnn:
                value_inputs.append(rollout.comm_recurrent_hidden_states[-1].clone())
            next_value = self.model.get_value(*value_inputs).detach()

        rollout.compute_returns(
            next_value, use_gae, gamma, gae_lambda, use_proper_time_limits,
        )

    @algorithm.capture
    #@profile
    # Using external model to compute returns
    def compute_returns_shared(self, agent, use_gae, gamma, gae_lambda, use_proper_time_limits):
        """Like ``compute_returns``, but bootstraps with ``agent.model``'s critic.

        All critic inputs still come from this agent's own storage; only the
        network that evaluates them is borrowed from ``agent``.
        """
        rollout = self.storage
        critic = agent.model
        with torch.no_grad():
            if "MarlGrid" in self.env_name:
                final_obs = (rollout.img_obs[-1], rollout.df_obs[-1])
            else:
                final_obs = rollout.obs[-1]
            value_inputs = [
                final_obs,
                rollout.msg[-1].clone(),
                rollout.recurrent_hidden_states[-1].clone(),
                rollout.masks[-1],
            ]
            if self.use_comm_sep_rnn:
                value_inputs.append(rollout.comm_recurrent_hidden_states[-1].clone())
            next_value = critic.get_value(*value_inputs).detach()

        rollout.compute_returns(
            next_value, use_gae, gamma, gae_lambda, use_proper_time_limits,
        )

    @algorithm.capture
    #@profile
    def update(
        self,
        agents,
        value_loss_coef,
        entropy_coef,
        seac_coef,
        max_grad_norm,
        device,
    ):
        action_shape = self.storage.actions.size()[-1]
        num_steps, num_processes, _ = self.storage.rewards.size()
        message_inputs = torch.stack(self.storage.msg[:-1]).view(-1, self.message_space[0])
        zero_message_inputs = torch.zeros(message_inputs.size(), device = device)
        # stacked message size - (num time step, num of parallel processes, num_other_agents * message_size)

        pl_loss = 0.0
        if(self.use_memory):
            if(self.use_comm_sep_rnn):
                raise NotImplementedError("No longer used!")
            #     values, action_log_probs, dist_entropy, ms, _, aligned_output, policy, _ = self.model.evaluate_actions(
            #         torch.stack(self.storage.obs[:-1]).view(-1, self.mem_hidden) ,
            #         message_inputs,
            #         self.storage.recurrent_hidden_states[0].view(
            #             -1, self.model.recurrent_hidden_state_size
            #         ),
            #         self.storage.masks[:-1].view(-1, 1),
            #         self.storage.actions.view(-1, action_shape),
            #         self.storage.comm_recurrent_hidden_states[0].view(
            #             -1, self.model.recurrent_hidden_state_size
            #         )
            #     )
            #     if(self.use_pos_listening_loss):
            #         _, cf_action_log_probs, _, _, _, _, cf_policy, _ = self.model.evaluate_actions(
            #             torch.stack(self.storage.obs[:-1]).view(-1, self.mem_hidden),
            #             zero_message_inputs,
            #             self.storage.recurrent_hidden_states[0].view(
            #                 -1, self.model.recurrent_hidden_state_size
            #             ),
            #             self.storage.masks[:-1].view(-1, 1),
            #             self.storage.actions.view(-1, action_shape),
            #             self.storage.comm_recurrent_hidden_states[0].view(
            #                 -1, self.model.recurrent_hidden_state_size
            #             )
            #         )
            #         # pl_loss += -1.0 * torch.norm((policy - cf_policy), 1, -1).mean()
            #         pl_loss += -1.0 * F.l1_loss(policy, cf_policy.detach())
            # else:
            #     values, action_log_probs, dist_entropy, ms, _, aligned_output, policy, _ = self.model.evaluate_actions(
            #         torch.stack(self.storage.obs[:-1]).view(-1, self.mem_hidden),
            #         message_inputs,
            #         self.storage.recurrent_hidden_states[0].view(
            #             -1, self.model.recurrent_hidden_state_size
            #         ),
            #         self.storage.masks[:-1].view(-1, 1),
            #         self.storage.actions.view(-1, action_shape),
            #     )
            #     if(self.use_pos_listening_loss):
            #         _, cf_action_log_probs, _, _, _, _, cf_policy, _ = self.model.evaluate_actions(
            #             torch.stack(self.storage.obs[:-1]).view(-1, self.mem_hidden),
            #             zero_message_inputs,
            #             self.storage.recurrent_hidden_states[0].view(
            #                 -1, self.model.recurrent_hidden_state_size
            #             ),
            #             self.storage.masks[:-1].view(-1, 1),
            #             self.storage.actions.view(-1, action_shape),
            #         )
            #         # pl_loss += -1.0 * torch.norm((policy - cf_policy), 1, -1).mean()
            #         pl_loss += -1.0 * F.l1_loss(policy, cf_policy.detach())
        else:
            if(self.use_comm_sep_rnn):
                raise NotImplementedError("No longer used!")
                # if("MarlGrid" not in self.env_name):
                #     obs_shape = self.storage.obs.size()[2:]
                # else:
                #     img_obs_shape = self.storage.img_obs.size()[2:]
                #     df_obs_shape = self.model.base.num_direct_features
                #
                # values, action_log_probs, dist_entropy, ms, _, aligned_output, policy, _ = self.model.evaluate_actions(
                #     self.storage.obs[:-1].view(-1, *obs_shape) if("MarlGrid" not in self.env_name) else (self.storage.img_obs[:-1].view(-1, *img_obs_shape), self.storage.df_obs[:-1].view(-1, df_obs_shape)),
                #     message_inputs,
                #     self.storage.recurrent_hidden_states[0].view(
                #         -1, self.model.recurrent_hidden_state_size
                #     ),
                #     self.storage.masks[:-1].view(-1, 1),
                #     self.storage.actions.view(-1, action_shape),
                #     self.storage.comm_recurrent_hidden_states[0].view(
                #         -1, self.model.recurrent_hidden_state_size
                #     ),
                # )
                # if(self.use_pos_listening_loss):
                #     _, cf_action_log_probs, _, _, _, _, cf_policy, _ = self.model.evaluate_actions(
                #         self.storage.obs[:-1].view(-1, *obs_shape) if("MarlGrid" not in self.env_name) else (self.storage.img_obs[:-1].view(-1, *img_obs_shape), self.storage.df_obs[:-1].view(-1, df_obs_shape)),
                #         zero_message_inputs,
                #         self.storage.recurrent_hidden_states[0].view(
                #             -1, self.model.recurrent_hidden_state_size
                #         ),
                #         self.storage.masks[:-1].view(-1, 1),
                #         self.storage.actions.view(-1, action_shape),
                #         self.storage.comm_recurrent_hidden_states[0].view(
                #             -1, self.model.recurrent_hidden_state_size
                #         ),
                #     )
                #     # pl_loss += -1.0 * torch.norm((policy - cf_policy), 1, -1).mean()
                #     pl_loss += -1.0 * F.l1_loss(policy, cf_policy.detach())
            else:
                if("MarlGrid" not in self.env_name):
                    obs_shape = self.storage.obs.size()[2:]
                else:
                    img_obs_shape = self.storage.img_obs.size()[2:]
                    df_obs_shape = self.model.base.num_direct_features
                values, action_log_probs, dist_entropy, ms, _, aligned_output, policy, _ = self.model.evaluate_actions(
                    self.storage.obs[:-1].view(-1, *obs_shape)  if("MarlGrid" not in self.env_name) else (self.storage.img_obs[:-1].view(-1, *img_obs_shape), self.storage.df_obs[:-1].view(-1, df_obs_shape)),
                    message_inputs,
                    self.storage.recurrent_hidden_states[0].view(
                        -1, self.model.recurrent_hidden_state_size
                    ),
                    self.storage.masks[:-1].view(-1, 1),
                    self.storage.actions.view(-1, action_shape),
                )
                if(self.use_pos_listening_loss):
                    _, cf_action_log_probs, _, _, _, _, cf_policy, _  = self.model.evaluate_actions(
                        self.storage.obs[:-1].view(-1, *obs_shape)  if("MarlGrid" not in self.env_name) else (self.storage.img_obs[:-1].view(-1, *img_obs_shape), self.storage.df_obs[:-1].view(-1, df_obs_shape)),
                        zero_message_inputs,
                        self.storage.recurrent_hidden_states[0].view(
                            -1, self.model.recurrent_hidden_state_size
                        ),
                        self.storage.masks[:-1].view(-1, 1),
                        self.storage.actions.view(-1, action_shape),
                    )
                    # pl_loss += -1.0 * torch.norm((policy - cf_policy), 1, -1).mean()
                    # Maximize L1 distance
                    pl_loss += -1.0 * F.l1_loss(policy, cf_policy.clone().detach())
                    # Minimize cross entropy for cf_policy to approximate average policy
                    pl_loss += cross_entropy(cf_policy, policy.clone().detach())

        values = values.view(num_steps, num_processes, 1)
        action_log_probs = action_log_probs.view(num_steps, num_processes, 1)

        advantages = self.storage.returns[:-1] - values

        policy_loss = -(advantages.detach() * action_log_probs).mean()
        value_loss = advantages.pow(2).mean()

        aligner_loss = 0.0
        if(self.model.aligner_type != 'mi_split' and self.model.aligner_type != 'mm' and (type(aligned_output) != type(None) or ('mi' in self.model.aligner_type and self.model.aligner_type != 'mi' and self.model.aligner_type != 'mi_intra_traj' and self.model.aligner_type != 'mi_simclr'))):
            if(self.share_aligner):
                # Use the first agent's aligner for everyone else's
                # First item is the one that doesn't backpropagate across agents
                aligned_output = agents[0].model.projector(agents[0].model.aligner(ms[1])) if self.use_projector else agents[0].model.aligner(ms[1])
            if(self.use_memory):
                obs_shape = self.storage.raw_obs.size()[2:]
                if(self.model.aligner_type == 'obs' or self.model.aligner_type == 'mi_obs'):
                    # This compares with raw obs
                    # aligner_loss = F.mse_loss(aligned_output, self.storage.raw_obs[:-1].view(-1, *obs_shape))
                    # This compares with encoded obs
                    with torch.no_grad():
                        encoded_obs = self.model.base.encode_obs(self.storage.raw_obs[:-1].view(-1, *obs_shape))
                    aligner_loss = F.mse_loss(aligned_output, encoded_obs)
                    aligner_loss *= self.aligner_coeff
                elif(self.model.aligner_type == 'act' or self.model.aligner_type == 'mi_act'):
                    # aligner_loss = F.kl_div(aligned_output.log(), policy.clone().detach())
                    aligner_loss = F.cross_entropy(aligned_output.log(), torch.argmax(policy.clone(), dim = -1))
                    aligner_loss *= self.aligner_coeff
                    # aligner_loss = F.kl_div(aligned_output.log(), policy.clone().detach())
                elif(self.model.aligner_type == 'kl_act' or self.model.aligner_type == 'mi_kl_act'):
                    aligner_loss = F.kl_div(aligned_output.log(), policy.clone().log().detach(), log_target = True)
                    aligner_loss *= self.aligner_coeff
                elif(self.model.aligner_type == 'obs_act' or self.model.aligner_type == 'mi_obs_act'):
                    with torch.no_grad():
                        encoded_obs = self.model.base.encode_obs(self.storage.raw_obs[:-1].view(-1, *obs_shape))
                    aligner_loss += F.mse_loss(aligned_output[0], encoded_obs)
                    # aligner_loss += F.mse_loss(aligned_output[0], self.storage.raw_obs[:-1].view(-1, *obs_shape))
                    aligner_loss += F.cross_entropy(aligned_output[1].log(), torch.argmax(policy.clone(), dim = -1))
                    aligner_loss *= self.aligner_coeff
                elif(self.model.aligner_type == 'obs_kl_act' or self.model.aligner_type == 'mi_obs_kl_act'):
                    with torch.no_grad():
                        encoded_obs = self.model.base.encode_obs(self.storage.raw_obs[:-1].view(-1, *obs_shape))
                    aligner_loss += F.mse_loss(aligned_output[0], encoded_obs)
                    # aligner_loss += F.mse_loss(aligned_output[0], self.storage.raw_obs[:-1].view(-1, *obs_shape))
                    aligner_loss += F.kl_div(aligned_output[1].log(), policy.clone().log().detach(), log_target = True)
                    aligner_loss *= self.aligner_coeff
                else:
                    raise NotImplementedError
            else:
                if("MarlGrid" not in self.env_name):
                    obs_shape = self.storage.obs.size()[2:]
                else:
                    img_obs_shape = self.storage.img_obs.size()[2:]
                    df_obs_shape = self.model.base.num_direct_features
                if(self.model.aligner_type == 'obs' or self.model.aligner_type == 'mi_obs'):
                    # This compares with raw obs
                    # aligner_loss = F.mse_loss(aligned_output, self.storage.raw_obs[:-1].view(-1, *obs_shape))
                    # This compares with encoded obs
                    with torch.no_grad():
                        encoded_obs = self.model.base.encode_obs(self.storage.obs[:-1].view(-1, *obs_shape)   if("MarlGrid" not in self.env_name) else (self.storage.img_obs[:-1].view(-1, *img_obs_shape), self.storage.df_obs[:-1].view(-1, df_obs_shape)))
                    aligner_loss = F.mse_loss(aligned_output, encoded_obs)
                    aligner_loss *= self.aligner_coeff
                    # print("policy loss {} value loss {} aligner loss {} pl loss {}".format(policy_loss, value_loss, aligner_loss, pl_loss))
                elif(self.model.aligner_type == 'act' or self.model.aligner_type == 'mi_act'):
                    # Cross-entropy
                    aligner_loss = F.cross_entropy(aligned_output.log(), torch.argmax(policy.clone(), dim = -1))
                    aligner_loss *= self.aligner_coeff
                    # KL Div relative distribution difference - probably not appropriate use cross-entropy
                    # aligner_loss = F.kl_div(aligned_output.log(), policy.clone().detach())
                    # print("policy loss {} value loss {} aligner loss {} pl loss {}".format(policy_loss, value_loss, aligner_loss, pl_loss))
                elif(self.model.aligner_type == 'kl_act' or self.model.aligner_type == 'mi_kl_act'):
                    aligner_loss = F.kl_div(aligned_output.log(), policy.clone().log().detach(), log_target = True)
                    aligner_loss *= self.aligner_coeff
                elif(self.model.aligner_type == 'obs_act' or self.model.aligner_type == 'mi_obs_act'):
                    with torch.no_grad():
                        encoded_obs = self.model.base.encode_obs(self.storage.obs[:-1].view(-1, *obs_shape)   if("MarlGrid" not in self.env_name) else (self.storage.img_obs[:-1].view(-1, *img_obs_shape), self.storage.df_obs[:-1].view(-1, df_obs_shape)))
                    aligner_loss += F.mse_loss(aligned_output[0], encoded_obs)
                    # aligner_loss += F.mse_loss(aligned_output[0], self.storage.raw_obs[:-1].view(-1, *obs_shape))
                    aligner_loss += F.cross_entropy(aligned_output[1].log(), torch.argmax(policy.clone(), dim = -1))
                    aligner_loss *= self.aligner_coeff
                elif(self.model.aligner_type == 'obs_kl_act' or self.model.aligner_type == 'mi_obs_kl_act'):
                    with torch.no_grad():
                        encoded_obs = self.model.base.encode_obs(self.storage.obs[:-1].view(-1, *obs_shape)   if("MarlGrid" not in self.env_name) else (self.storage.img_obs[:-1].view(-1, *img_obs_shape), self.storage.df_obs[:-1].view(-1, df_obs_shape)))
                    aligner_loss += F.mse_loss(aligned_output[0], encoded_obs)
                    # aligner_loss += F.mse_loss(aligned_output[0], self.storage.raw_obs[:-1].view(-1, *obs_shape))
                    aligner_loss += F.kl_div(aligned_output[1].log(), policy.clone().log().detach(), log_target = True)
                    aligner_loss *= self.aligner_coeff
                else:
                    raise NotImplementedError

        if('mi' in self.model.aligner_type):
            # messages for step 1 to num_steps - 1, not using the first step, as messages for the first step is not produced in ms. These are messages received
            # stacked_msgs = torch.swapaxes(torch.stack(self.storage.msg[1:-1]), 0, 1)
            stacked_msgs = torch.transpose(torch.stack(self.storage.msg[1:-1]), 0, 1)
            # The way it was flattened above first flattens num steps then num processes during feedforward therefore, it has to be swapped below to get num_processes, num_steps
            # projector is needed here. NOTE: projected messages are never sent over, it is simply used for learning
            # rebatched_output_ms = torch.swapaxes(ms[1].view(num_steps, num_processes, -1)[:-1], 0, 1)
            rebatched_output_ms = torch.transpose(self.model.projector(ms[1]).view(num_steps, num_processes, -1)[:-1], 0, 1) if self.use_projector else torch.transpose(ms[1].view(num_steps, num_processes, -1)[:-1], 0, 1)
            # No longer use
            # reversed_rebatched_output_ms = torch.flip(rebatched_output_ms.clone(), [0])

            # Get positive and negative samples with labels
            # Shape (num processes, num steps, message length)
            samples = torch.cat((rebatched_output_ms, stacked_msgs), dim = -1)
            labels = torch.stack([torch.full((samples.size(1), ), label)for label in range(samples.size(0))])
            if('mi_simclr' == self.model.aligner_type):
                simclr_samples = torch.cat(torch.split(samples, int(samples.size(2) / (self.num_other_agents + 1)), dim = 2), dim = 1)
            # print(samples.size())
            # print(labels.size())
            samples = samples.view(-1, samples.size(2))
            labels = labels.view(-1)
            # print(labels)
            # print(samples[0])
            # Shuffle data
            rand_index = torch.randperm(samples.size()[0])
            samples = samples[rand_index]
            labels = labels[rand_index]

            # Split into multi-views, each view is an agent's message
            splited_samples = F.normalize(torch.stack(torch.split(samples, int(samples.size(1) / (self.num_other_agents + 1)), dim = 1), dim = 1), dim = 2)
            # splited_samples = torch.stack(torch.split(samples, int(samples.size(1) / (self.num_other_agents + 1)), dim = 1), dim = 1)
            # line below is faulty
            # splited_samples = F.normalize(torch.stack(torch.split(samples, self.num_other_agents + 1, dim = 1), dim = 1), dim = 2)

            if('mi_split' == self.model.aligner_type):
                first_half_splited_samples = splited_samples[:, :, : int(splited_samples.size(2) / 2)]
                second_half_splited_samples = splited_samples[:, :, int(splited_samples.size(2) / 2) :]
                # Get loss
                sup_con_loss = self.mi_loss(first_half_splited_samples, labels)
                aligner_loss += self.mi_coeff * sup_con_loss

                intra_traj_labels = torch.arange(1, labels.size(0) + 1).long()
                intra_sup_con_loss = self.mi_loss(second_half_splited_samples, intra_traj_labels)
                aligner_loss += self.mi_coeff * intra_sup_con_loss
            elif('mi_simclr' == self.model.aligner_type):
                simclr_loss = self.mi_loss(simclr_samples)
                aligner_loss += self.mi_coeff * simclr_loss
            else:
                # Get loss
                sup_con_loss = self.mi_loss(splited_samples, labels)
                # print("size {}, a1 loss: {}".format(splited_samples.size(), aligner_loss))
                # print("a2 loss: {}".format(sup_con_loss))

                if('mi_intra_traj' == self.model.aligner_type):
                    intra_traj_labels = torch.arange(1, labels.size(0) + 1).long()
                    intra_sup_con_loss = self.mi_loss(splited_samples, intra_traj_labels)
                    aligner_loss += self.mi_coeff * intra_sup_con_loss
                    
                if('mi' == self.model.aligner_type):
                    aligner_loss += self.mi_coeff * sup_con_loss

        if('mm' in self.model.aligner_type):
            stacked_msgs = torch.transpose(torch.stack(self.storage.msg[1:-1]), 0, 1)
            rebatched_output_ms = torch.transpose(ms[1].view(num_steps, num_processes, -1)[:-1], 0, 1)
            rand_msg_index = torch.randint(1, len(agents), (stacked_msgs.size(0), stacked_msgs.size(1))).unsqueeze(2).repeat(1, 1,  self.message_space[2])
            rand_mask_index = torch.arange(0, self.message_space[2]).unsqueeze(0).unsqueeze(0).repeat(stacked_msgs.size(0), stacked_msgs.size(1), 1) + rand_msg_index * self.message_space[2]
            samples = torch.cat((rebatched_output_ms, stacked_msgs), dim = -1)
            masked_out_messages = samples.clone().detach().gather(dim = 2, index = rand_mask_index)
            # Mask out a particular message
            samples.scatter_(index = rand_mask_index, dim = 2, value = 0.0)
            # Prediction by aligner
            masked_out_messages_predictions = self.model.aligner(samples)
            mm_loss = F.mse_loss(masked_out_messages_predictions, masked_out_messages)
            aligner_loss += mm_loss

        # Example for gradient computational graph visualization
        # dot = make_dot(msgs, params=dict(agents[1].model.base.named_parameters()))
        # dot.format = 'png'
        # dot.render('./debug_grad')

        # Testing that gradients do flow through agents
        # print(torch.sum(agents[0].model.base.critic[0].weight.grad))
        # print(torch.sum(agents[1].model.base.comm_actor[0].weight.grad))

        # calculate prediction loss for the actor using other actors' experiences
        other_agent_ids = [x for x in range(len(agents)) if x != self.agent_id]
        storages = [a.storage for a in agents]
        seac_policy_loss = 0
        seac_value_loss = 0

        total_loss = policy_loss + value_loss_coef * value_loss - entropy_coef * dist_entropy + seac_coef * seac_policy_loss + seac_coef * value_loss_coef * seac_value_loss + aligner_loss + pl_loss

        # if(len(agents) == self.agent_id + 1):
        #     # Last agent - so it no longer needs the graph
        #     total_loss.backward()
        # else:
        #     total_loss.backward(retain_graph = True)

        total_loss.backward(retain_graph = True)

        # print("aligner coeff: {}, aligner coeff grad: {}".format(self.aligner_coeff, self.aligner_coeff.grad))
        # print("mi aligner coeff: {}, mi aligner coeff grad: {}".format(self.mi_coeff, self.mi_coeff.grad))
        # print("mi aligner coeff: {} mi loss: {}".format(self.mi_coeff, sup_con_loss))
        if(seac_coef > 0.0):
            return {
                "policy_loss": policy_loss.item(),
                "value_loss": value_loss_coef * value_loss.item(),
                "dist_entropy": entropy_coef * dist_entropy.item(),
                # "importance_sampling": importance_sampling.mean().item(),
                "seac_policy_loss": seac_coef * seac_policy_loss.item(),
                "seac_value_loss": seac_coef
                * value_loss_coef
                * seac_value_loss.item(),
                "total_loss": total_loss.item(),
            }
        else:
            return {
            "policy_loss": policy_loss.item(),
            "value_loss": value_loss_coef * value_loss.item(),
            "dist_entropy": entropy_coef * dist_entropy.item(),
            "aligner_loss": aligner_loss.item() if aligner_loss != 0.0 else aligner_loss,
            "pl_loss": pl_loss.item() if pl_loss != 0.0 else pl_loss,
            "total_loss": total_loss.item(),
            "grad_norm": self.get_total_norm()
            }


    def get_total_norm(self):
        """Return the global L2 norm of the current gradients of ``self.model``.

        Computes ``sqrt(sum_p ||p.grad||_2 ** 2)`` over every parameter of
        ``self.model`` that requires grad and currently holds a ``.grad``
        tensor. Gradients are only read, never modified; the value is used for
        logging (``"grad_norm"`` in the update statistics).

        Returns:
            float: the total gradient norm, ``0.0`` when no parameter has a
            gradient yet.
        """
        total_norm_sq = 0.0
        for p in self.model.parameters():
            # Frozen parameters and ones untouched by the last backward pass
            # do not contribute.
            if p.grad is None or not p.requires_grad:
                continue
            # detach() shares storage with the gradient, so no copy is made;
            # cloning first would duplicate every gradient tensor in memory.
            total_norm_sq += p.grad.detach().norm(2).item() ** 2
        return total_norm_sq ** 0.5

    @algorithm.capture
    #@profile
    def update_shared(
        self,
        agents,
        value_loss_coef,
        entropy_coef,
        seac_coef,
        max_grad_norm,
        device,
    ):
        """Shared-experience update; disabled in this code base.

        The signature is retained so existing callers and the
        ``algorithm.capture`` config injection keep resolving. None of the
        arguments are used, because the method unconditionally raises.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Update shared not inuse!")
