import math
import time
import random
import gym
from onpolicy.debug import debug_print
import torch
import torch.nn as nn
from onpolicy.algorithms.utils.util import init, check
from onpolicy.algorithms.utils.cnn import CNNBase
from onpolicy.algorithms.utils.mlp import MLPBase
from onpolicy.algorithms.utils.rnn import RNNLayer
from onpolicy.algorithms.utils.act import ACTLayer
from onpolicy.algorithms.utils.popart import PopArt
from onpolicy.algorithms.diffusion_ac.sfbc_model import ScoreNet, MlpScoreNet, GaussianFourierProjection, TransformerScoreNet
from onpolicy.algorithms.diffusion_ac.diffusion import Diffusion
from onpolicy.algorithms.diffusion_ac.dppo.diffusion import DiffusionModel
from onpolicy.utils.util import get_shape_from_obs_space
from torch.distributions import Distribution, Independent, Normal
from onpolicy.algorithms.diffusion_ac.dppo.mlp_diffusion import DiffusionMLP, VisionDiffusionMLP
from onpolicy.algorithms.diffusion_ac.common.vit import VitEncoder
from onpolicy.algorithms.diffusion_ac.dppo.mlp import MLP

class ActingLayer(nn.Module):
    """Map latent actions produced by the diffusion model to environment actions.

    Each entry of ``action_space`` contributes one or more sub-actions; for each
    sub-action the matching slice of the latent vector (``act_step`` copies wide)
    is interpreted as:

    * Discrete / MultiDiscrete: logits of a Categorical distribution (softmax with
      a tiny ``eps`` floor), sampled or arg-maxed.
    * Box: means of a Gaussian whose log standard deviation is a learnable
      parameter initialised to ``args.initial_logstd``.

    Entries whose ``available_actions`` value is below 0.5 get logit ``-1e10``.
    """

    def __init__(self, action_space, args):
        super().__init__()
        assert isinstance(action_space, list)

        self.initial_logstd = args.initial_logstd
        # Read from args but not used anywhere in this class.
        self.use_small_std = args.act_with_small_std

        self.action_space = action_space
        # Environment-facing width of each sub-action (1 for discrete choices).
        self.action_sizes = []
        # Latent width of each sub-action, already multiplied by act_step.
        self.latent_action_sizes = []
        self.act_step = args.act_step
        self.sep_logprob = args.sep_logprob

        acting_params = []
        for act_space in self.action_space:
            if isinstance(act_space, gym.spaces.Discrete):
                sub_actions = [(1, act_space.n)]
            elif isinstance(act_space, gym.spaces.MultiDiscrete):
                sub_actions = [(1, action_dim) for action_dim in act_space.nvec]
            elif isinstance(act_space, gym.spaces.Box):
                sub_actions = [(act_space.shape[0], act_space.shape[0])]
            else:
                raise RuntimeError(f"{act_space} is not supported.")
            for action_size, latent_size in sub_actions:
                self.action_sizes.append(action_size)
                self.latent_action_sizes.append(latent_size)
                # One log-std entry per latent unit; only the Box branch reads it.
                acting_params.append(nn.Parameter(torch.ones(latent_size) * self.initial_logstd))
        self.latent_action_sizes = [size * self.act_step for size in self.latent_action_sizes]

        self.acting_params = nn.ParameterList(acting_params)

    def update_logstd(self, log_std):
        """Overwrite every learnable log-std entry, in place, with the scalar ``log_std``."""
        for param in self.acting_params:
            param.data.fill_(log_std)

    @property
    def latent_action_size(self):
        """Total latent width expected from the diffusion model (all sub-actions, all steps)."""
        return sum(self.latent_action_sizes)

    @property
    def action_size(self):
        """Total environment-facing action width for a single step."""
        return sum(self.action_sizes)

    @staticmethod
    def _mask_unavailable(latent_actions, available_actions):
        """Push logits of unavailable entries to -1e10.

        Returns ``(latent_actions, available_actions)``. When ``available_actions``
        is given, the latents are cloned (the caller's tensor is not modified) and
        the mask is reshaped to the latent shape.
        """
        if available_actions is None:
            return latent_actions, None
        latent_actions = latent_actions.clone()
        available_actions = available_actions.reshape(*latent_actions.shape)
        latent_actions[available_actions < 0.5] = -1e10
        return latent_actions, available_actions

    @staticmethod
    def _categorical_probs(logits, available_slice, act_size, eps):
        """Softmax of ``logits`` (shape ``(-1, act_size)``) with a small ``eps`` floor.

        The floor is only added to available entries when a mask slice is given,
        so masked-out choices keep (numerically) zero probability.
        """
        shifted = logits - logits.max(dim=-1, keepdim=True)[0]  # numerical stability
        floor = eps if available_slice is None else eps * available_slice.reshape(-1, act_size)
        probs = shifted.exp() + floor
        return probs / probs.sum(dim=-1, keepdim=True)

    def _gaussian(self, mean, log_std):
        """Build the per-step Gaussian; summed over the last dim unless ``sep_logprob``."""
        normal = Normal(mean, log_std.exp().unsqueeze(-2))
        return normal if self.sep_logprob else Independent(normal, 1)

    def acting(self, latent_actions, available_actions=None, deterministic=False):
        """Sample (or take the mode of) actions from the latent action vector.

        :param latent_actions: latent output of the diffusion model; the last dim
            must be ``latent_action_size`` (leading dims presumably (batch,) --
            TODO confirm).
        :param available_actions: optional mask with the same number of elements
            as ``latent_actions``; entries < 0.5 are treated as unavailable.
        :param deterministic: use the distribution mode/mean instead of sampling.
        :return: ``(actions, sampled_actions, log_probs)`` -- discrete sub-actions
            give indices / one-hot vectors, continuous ones give the sampled values
            for both; log-probs of each sub-action are stacked on the last dim.
        """
        eps = 1e-10
        latent_actions, available_actions = self._mask_unavailable(latent_actions, available_actions)

        actions = []
        final_sampled_actions = []
        log_probs = []
        latent_idx = 0
        for action_idx, (action_size, latent_action_size) in enumerate(zip(self.action_sizes, self.latent_action_sizes)):
            latent_slice = slice(latent_idx, latent_idx + latent_action_size)
            raw_actions = latent_actions[..., latent_slice]
            act_size = latent_action_size // self.act_step

            if act_size > action_size:
                # Discrete sub-action: one Categorical per horizon step.
                raw_actions = raw_actions.reshape(-1, act_size)
                avail = None if available_actions is None else available_actions[..., latent_slice]
                dist = torch.distributions.Categorical(self._categorical_probs(raw_actions, avail, act_size, eps))
                choice = dist.mode if deterministic else dist.sample()
                gt_actions = choice.unsqueeze(-1)

                log_prob = dist.log_prob(choice).reshape(-1, self.act_step)

                sampled_actions = torch.zeros_like(raw_actions)
                sampled_actions.scatter_(-1, gt_actions, 1.0)  # one-hot encoding of the choice
                sampled_actions = sampled_actions.reshape(-1, self.act_step, act_size)
                gt_actions = gt_actions.reshape(-1, self.act_step, 1)
            else:
                # Continuous sub-action: Gaussian around the latent mean.
                # log_std is shaped against raw_actions *before* the act_step split.
                log_std = self.acting_params[action_idx].reshape(*([1] * (raw_actions.dim() - 1)), -1)
                raw_actions = raw_actions.reshape(*raw_actions.shape[:-1], self.act_step, -1)
                dist = self._gaussian(raw_actions, log_std)
                gt_actions = dist.mean if deterministic else dist.sample()
                log_prob = dist.log_prob(gt_actions)
                sampled_actions = gt_actions

            actions.append(gt_actions)
            log_probs.append(log_prob)
            final_sampled_actions.append(sampled_actions)
            latent_idx += latent_action_size

        return torch.cat(actions, dim=-1), torch.cat(final_sampled_actions, dim=-1), torch.stack(log_probs, dim=-1)

    def evaluate_actions(self, latent_actions, final_sampled_actions, available_actions=None, active_masks=None):
        """Log-probability and entropy of previously sampled actions.

        :param latent_actions: latent output of the diffusion model (see ``acting``).
        :param final_sampled_actions: the ``sampled_actions`` returned by ``acting``
            (one-hot for discrete sub-actions), flattened to the latent layout.
        :param available_actions: optional availability mask (see ``acting``).
        :param active_masks: optional mask that reweights the discrete entropy as
            ``entropy * mask / mask.mean()``.
        :return: ``(log_probs, entropy)`` -- per-sub-action log-probs stacked on the
            last dim, and the entropy averaged over sub-actions.
        """
        eps = 1e-10
        latent_actions, available_actions = self._mask_unavailable(latent_actions, available_actions)

        log_probs = []
        entropys = []
        latent_idx = 0
        for action_idx, (action_size, latent_action_size) in enumerate(zip(self.action_sizes, self.latent_action_sizes)):
            latent_slice = slice(latent_idx, latent_idx + latent_action_size)
            raw_actions = latent_actions[..., latent_slice]
            sampled_actions = final_sampled_actions[..., latent_slice]
            act_size = latent_action_size // self.act_step

            if act_size > action_size:
                raw_actions = raw_actions.reshape(-1, act_size)
                sampled_actions = sampled_actions.reshape(-1, act_size)
                avail = None if available_actions is None else available_actions[..., latent_slice]
                probs = self._categorical_probs(raw_actions, avail, act_size, eps)
                dist = torch.distributions.Categorical(probs)

                # Stored discrete actions must be exact one-hot rows.
                assert (sampled_actions.sum(dim=-1) == 1.0).all()

                log_prob = (probs * sampled_actions).sum(dim=-1).log()
                entropy = dist.entropy().unsqueeze(-1)
                if active_masks is not None:
                    entropy = (entropy * active_masks) / active_masks.mean()
                log_prob = log_prob.reshape(-1, self.act_step)
                entropy = entropy.reshape(-1, self.act_step)
            else:
                log_std = self.acting_params[action_idx].reshape(*([1] * (raw_actions.dim() - 1)), -1)
                raw_actions = raw_actions.reshape(*raw_actions.shape[:-1], self.act_step, -1)
                sampled_actions = sampled_actions.reshape(*sampled_actions.shape[:-1], self.act_step, -1)
                dist = self._gaussian(raw_actions, log_std)
                log_prob = dist.log_prob(sampled_actions)
                entropy = dist.entropy().mean().reshape([1, 1])

            log_probs.append(log_prob)
            entropys.append(entropy)
            latent_idx += latent_action_size

        return torch.stack(log_probs, dim=-1), torch.stack(entropys, dim=-1).mean(dim=-1)

class Diffusion_R_Actor(nn.Module):
    """Diffusion-policy actor.

    ``self.diffusion`` denoises a latent action sequence conditioned on the
    flattened observation; the final denoised latent is scaled by
    ``logit_scaling``, shifted by ``logit_offset`` and handed to
    :class:`ActingLayer`, which produces environment actions and log-probs.
    Each generated latent spans ``args.act_step`` horizon steps.
    """

    def __init__(self, args, obs_space, action_space, device=torch.device("cpu")):
        super(Diffusion_R_Actor, self).__init__()

        self.n_timesteps = args.n_timesteps
        self.beta_schedule = args.beta_schedule
        self.predict_epsilon = args.predict_epsilon
        self.hidden_size = args.hidden_size
        self.t_dim = args.t_dim
        self.rnum_agents = args.rnum_agents
        self.repeat_num = args.repeat_num
        self.joint_train = args.joint_train
        self.use_attention = args.use_attention
        self.logit_scaling = args.logit_scaling
        self.logit_offset = args.logit_offset
        self.use_latent_prob = args.use_latent_prob
        self._gain = args.gain
        self._use_orthogonal = args.use_orthogonal
        self._use_policy_active_masks = args.use_policy_active_masks
        self._use_naive_recurrent_policy = args.use_naive_recurrent_policy
        self._use_recurrent_policy = args.use_recurrent_policy
        self._recurrent_N = args.recurrent_N
        self.act_step = args.act_step
        self.tpdv = dict(dtype=torch.float32, device=device)
        self.args = args

        obs_shape = get_shape_from_obs_space(obs_space)
        base = CNNBase if len(obs_shape) == 3 else MLPBase
        # NOTE(review): this encoder is built and immediately replaced by an
        # identity, so the diffusion model conditions on raw observations. The
        # construction is kept so parameter initialisation (presumably torch RNG
        # consumption) matches earlier runs -- TODO confirm it can be removed.
        self.base = base(args, obs_shape) if not self.use_attention and self.joint_train else base(args, [obs_shape[0] // self.rnum_agents])
        self.base = nn.Identity()

        if self._use_naive_recurrent_policy or self._use_recurrent_policy:
            # Attention and non-attention configurations use the same RNN layer.
            self.rnn = RNNLayer(self.hidden_size, self.hidden_size, self._recurrent_N, self._use_orthogonal)

        self.act = ActingLayer(action_space, args)
        self.diffusion_action_size = self.act.latent_action_size
        # Latent width of a single horizon step, as expected by the diffusion nets.
        step_action_dim = self.diffusion_action_size // self.act_step
        obs_dim = obs_shape[0]

        if not args.use_image:
            self.diffusion_model = DiffusionMLP(
                action_dim=step_action_dim,
                horizon_steps=self.act_step,
                cond_dim=obs_dim,
                time_dim=32,
                mlp_dims=[args.unet_hidden_size] * args.unet_num_layer,
                cond_mlp_dims=None,
                residual_style=True,
                activation_type=args.activation,
            )
        else:
            cfg = {
                "patch_size": 8,
                "depth": 1,
                "embed_dim": 128,
                "num_heads": 4,
                "embed_style": "embed2",
                "embed_norm": 0,
            }
            vit_encoder = VitEncoder(
                obs_shape=obs_shape,
                cfg=cfg,
                num_channel=3,
                img_h=args.img_size,
                img_w=args.img_size,
            )
            self.diffusion_model = VisionDiffusionMLP(
                backbone=vit_encoder,
                action_dim=step_action_dim,
                horizon_steps=self.act_step,
                time_dim=32,
                mlp_dims=[args.unet_hidden_size] * args.unet_num_layer,
                # Non-image part of the observation: total minus num_img RGB images.
                cond_dim=obs_dim - args.num_img * 3 * args.img_size * args.img_size,
                residual_style=True,
                augment=True,
                num_img=args.num_img,
                img_h=args.img_size,
                img_w=args.img_size,
            )

        if args.use_mlp:
            self.diffusion = MLP(obs_dim, step_action_dim, denoising_steps=self.n_timesteps, horizon_steps=self.act_step, predict_epsilon=False, args=args)
        else:
            # NOTE(review): hard-coded checkpoint path; consider moving it into args.
            self.diffusion = DiffusionModel(obs_dim, step_action_dim, self.diffusion_model, denoising_steps=self.n_timesteps, horizon_steps=self.act_step, predict_epsilon=False, args=args,
                                            network_path='transport-img-pretrained-base.pt')
        self.to(device)
        self.algo = args.algorithm_name

    def _fold_act_step(self, x, compress=False):
        """Group every ``act_step`` consecutive rows of ``x`` into one sample.

        ``x`` of shape ``(N, *rest)`` becomes ``(N // act_step, *rest, act_step)``;
        with ``compress=True`` the step axis is moved before the last feature axis
        and then merged with it.
        """
        x = x.reshape(-1, self.act_step, *x.shape[1:]).moveaxis(1, -2 if compress else -1)
        if compress:
            return x.reshape(*x.shape[:-2], -1)
        return x

    def _encode(self, obs, rnn_states, masks):
        """Run the (identity) base and, if enabled, the RNN over ``obs``.

        Returns ``(actor_features, rnn_states)``. In per-agent mode
        (``use_attention`` or not ``joint_train``) the observation is split into
        ``rnum_agents`` chunks first.
        """
        per_agent = self.use_attention or not self.joint_train
        if per_agent:
            actor_features = self.base(obs.reshape(obs.shape[0], self.rnum_agents, -1))
        else:
            actor_features = self.base(obs)

        if self._use_naive_recurrent_policy or self._use_recurrent_policy:
            if per_agent:
                actor_features, rnn_states = self.rnn(actor_features.reshape(actor_features.shape[0] * self.rnum_agents, -1), rnn_states, masks.repeat(1, self.rnum_agents).reshape(-1, 1))
                actor_features = actor_features.reshape(actor_features.shape[0] // self.rnum_agents, self.rnum_agents, -1)
            else:
                actor_features, rnn_states = self.rnn(actor_features, rnn_states, masks)
        return actor_features, rnn_states

    def forward(self, obs, rnn_states, masks, available_actions=None, deterministic=False):
        """
        Compute actions from the given inputs.
        :param obs: (np.ndarray / torch.Tensor) observation inputs into network.
        :param rnn_states: (np.ndarray / torch.Tensor) if RNN network, hidden states for RNN.
        :param masks: (np.ndarray / torch.Tensor) mask tensor denoting if hidden states should be reinitialized to zeros.
        :param available_actions: (np.ndarray / torch.Tensor) denotes which actions are available to agent
                                                              (if None, all actions available)
        :param deterministic: (bool) whether to sample from action distribution or return the mode.

        :return action_seqs: (torch.Tensor) full denoising trajectory of latent actions.
        :return sampled_actions: (torch.Tensor) sampled actions (one-hot for discrete sub-actions).
        :return actions: (torch.Tensor) actions to take.
        :return latent_probs: (torch.Tensor) per-denoising-step log-probs returned by the diffusion model.
        :return log_probs: (torch.Tensor) log probabilities of taken actions.
        :return rnn_states: (torch.Tensor) updated RNN hidden states.
        :return noise: (torch.Tensor) noise used during sampling (lets evaluate_actions replay it).
        """
        obs = check(obs).to(**self.tpdv)
        rnn_states = check(rnn_states).to(**self.tpdv)
        masks = check(masks).to(**self.tpdv)
        if available_actions is not None:
            available_actions = check(available_actions).to(**self.tpdv)

        # Only the RNN state is carried forward; the diffusion model conditions on raw obs.
        _, rnn_states = self._encode(obs, rnn_states, masks)

        obs = obs.reshape(obs.shape[0], -1)
        _, action_seqs, latent_probs, noise = self.diffusion.p_sample_loop(obs, return_diffusion=True, return_noise=True)

        final_latents = action_seqs[..., -1, :] * self.logit_scaling + self.logit_offset
        actions, sampled_actions, log_probs = self.act.acting(final_latents, available_actions, deterministic=deterministic)
        log_probs = log_probs.squeeze(-1)
        return action_seqs, sampled_actions, actions, latent_probs, log_probs, rnn_states, noise

    def evaluate_actions(self, obs, rnn_states, sampled_actions, action_ts, agent_idx, masks, available_actions=None, active_masks=None, joint_ppo=False, noises=None, r_active_masks_batch=None, latent_actions=None):
        """
        Compute log probability and entropy of given actions.
        :param obs: (torch.Tensor) observation inputs into network (act_step consecutive rows per sample).
        :param sampled_actions: (torch.Tensor) actions whose log probability to evaluate.
        :param noises: (torch.Tensor) optional noise recorded by forward(); when given, the
                       denoising trajectory is replayed with it instead of re-sampled.
        :param available_actions: (torch.Tensor) denotes which actions are available to agent
                                                              (if None, all actions available)
        :param active_masks: (torch.Tensor) denotes whether an agent is active or dead.
        :param r_active_masks_batch: (torch.Tensor) optional per-step mask multiplied into the
                                     log-probs before summing; if None no masking is applied.
        rnn_states, action_ts, agent_idx, masks, joint_ppo and latent_actions are accepted for
        interface compatibility and are not used: the diffusion model conditions on obs directly.

        :return action_log_probs: (torch.Tensor) log probabilities of the input actions.
        :return latent_probs: (torch.Tensor) per-denoising-step log-probs from the diffusion model.
        :return dist_entropy: (torch.Tensor) action distribution entropy for the given inputs.
        """
        obs = check(obs).to(**self.tpdv)
        sampled_actions = check(sampled_actions).to(**self.tpdv)
        if available_actions is not None:
            available_actions = check(available_actions).to(**self.tpdv)
        if active_masks is not None:
            active_masks = check(active_masks).to(**self.tpdv)

        sampled_actions = self._fold_act_step(sampled_actions, compress=True)
        if noises is not None:
            noises = self._fold_act_step(check(noises).to(**self.tpdv), compress=True)

        # One conditioning observation per act_step group: take the first step.
        obs = self._fold_act_step(obs)[..., 0]
        obs = obs.reshape(obs.shape[0], -1)
        if noises is None:
            _, action_seqs, latent_probs, _ = self.diffusion.p_sample_loop(obs, return_diffusion=True, return_noise=True)
        else:
            _, action_seqs, latent_probs = self.diffusion.p_sample_loop_with_noise(obs, return_diffusion=True, return_noise=False, noises=noises, detach=self.use_latent_prob)

        final_latents = action_seqs[..., -1, :] * self.logit_scaling + self.logit_offset
        log_probs, entropys = self.act.evaluate_actions(final_latents, sampled_actions, available_actions, active_masks=active_masks)

        if r_active_masks_batch is not None:
            step_masks = self._fold_act_step(check(r_active_masks_batch).to(**self.tpdv), compress=True).moveaxis(-1, -2)
            if self.args.sep_logprob:
                step_masks = step_masks.unsqueeze(-1)
            log_probs = log_probs * step_masks
        log_probs = log_probs.sum(dim=-1, keepdim=True)
        action_log_probs = log_probs.reshape(-1, *log_probs.shape[2:])
        dist_entropy = entropys.mean()
        return action_log_probs, latent_probs, dist_entropy

    def bc_loss(self, obs, rnn_states, sampled_actions, masks, available_actions=None, active_masks=None):
        """
        Compute the diffusion behaviour-cloning loss for the given target actions.
        :param obs: (torch.Tensor) observation inputs into network (act_step consecutive rows per sample).
        :param rnn_states: (torch.Tensor) if RNN network, hidden states for RNN (may be None).
        :param sampled_actions: (torch.Tensor) target actions, one row per environment step.
        :param masks: (torch.Tensor) mask tensor denoting if hidden states should be reinitialized to zeros.
        available_actions and active_masks are accepted for interface compatibility and are not used.

        :return bc_loss: value returned by ``self.diffusion.loss``.
        """
        obs = check(obs).to(**self.tpdv)
        if rnn_states is not None:
            rnn_states = check(rnn_states).to(**self.tpdv)
        sampled_actions = check(sampled_actions).to(**self.tpdv)
        if masks is not None:
            masks = check(masks).to(**self.tpdv)

        actor_features, _ = self._encode(obs, rnn_states, masks)
        actor_features = self._fold_act_step(actor_features)[..., 0]

        sampled_actions = self._fold_act_step(sampled_actions, compress=True)
        sampled_actions = sampled_actions.reshape(sampled_actions.shape[0], self.act_step, -1)
        return self.diffusion.loss(sampled_actions, actor_features)

class Diffusion_R_Critic(nn.Module):
    """
    Critic network class for MAPPO. Outputs value function predictions given centralized
    input (MAPPO) or local observations (IPPO).

    Only the observation features (base -> optional RNN -> ``net`` -> ``v_out``) contribute to
    the values. The time-embedding (``t_embd``) and action-embedding (``act_fc``) sub-modules are
    constructed but not used by ``forward`` or ``evaluate_states``.
    NOTE(review): presumably they are kept so existing checkpoints still load — confirm.

    :param args: (argparse.Namespace) arguments containing relevant model information.
    :param cent_obs_space: (gym.Space) (centralized) observation space.
    :param action_space: (list) action spaces; used only to derive ``latent_action_size``
                         through a temporary ``ActingLayer``.
    :param device: (torch.device) specifies the device to run on (cpu/gpu).
    """
    def __init__(self, args, cent_obs_space, action_space, device=torch.device("cpu")):
        super().__init__()

        # Diffusion hyper-parameters; only t_dim is used inside this class (by t_embd).
        self.n_timesteps = args.n_timesteps
        self.beta_schedule = args.beta_schedule
        self.predict_epsilon = args.predict_epsilon
        self.t_dim = args.t_dim

        self.hidden_size = args.hidden_size
        self._use_orthogonal = args.use_orthogonal
        self._use_naive_recurrent_policy = args.use_naive_recurrent_policy
        self._use_recurrent_policy = args.use_recurrent_policy
        self._recurrent_N = args.recurrent_N
        self._use_popart = args.use_popart
        self._use_latent_action = args.use_latent_actions
        self.tpdv = dict(dtype=torch.float32, device=device)
        # A bool flag indexes the list: False -> xavier, True -> orthogonal.
        init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][self._use_orthogonal]

        cent_obs_shape = get_shape_from_obs_space(cent_obs_space)
        # Image-like (3-D) observations use the CNN encoder, everything else the MLP encoder.
        base = CNNBase if len(cent_obs_shape) == 3 else MLPBase
        self.base = base(args, cent_obs_shape)

        if self._use_naive_recurrent_policy or self._use_recurrent_policy:
            self.rnn = RNNLayer(self.hidden_size, self.hidden_size, self._recurrent_N, self._use_orthogonal)

        def init_(m):
            return init(m, init_method, lambda x: nn.init.constant_(x, 0))

        # A throw-away ActingLayer is built only to read the latent action size.
        act = ActingLayer(action_space, args)
        self.latent_action_size = act.latent_action_size
        self.diffusion_action_size = act.latent_action_size
        del act
        # NOTE: t_embd / act_fc are not used by forward() or evaluate_states().
        self.t_embd = nn.Sequential(GaussianFourierProjection(embed_dim=self.t_dim),
                                    nn.Linear(self.t_dim, self.t_dim),
                                    nn.LayerNorm(self.t_dim),
                                    nn.LeakyReLU())
        self.act_fc = nn.Sequential(nn.Linear(self.latent_action_size, self.hidden_size),
                                    nn.LayerNorm(self.hidden_size),
                                    nn.LeakyReLU())

        # Only the observation features feed the value MLP.
        inp_dim = self.hidden_size

        self.net = nn.Sequential(
            nn.Linear(inp_dim, self.hidden_size),
            nn.LayerNorm(self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.LayerNorm(self.hidden_size),
            nn.ReLU(),
        )

        if self._use_popart:
            self.v_out = init_(PopArt(self.hidden_size, 1, device=device))
        elif self._use_latent_action:
            # One value per agent.
            self.v_out = init_(nn.Linear(self.hidden_size, args.rnum_agents))
        else:
            self.v_out = init_(nn.Linear(self.hidden_size, 1))

        self.to(device)

    def forward(self, cent_obs, rnn_states, action_seqs, masks):
        """
        Compute value predictions from the given inputs.

        :param cent_obs: (np.ndarray / torch.Tensor) observation inputs into network.
        :param rnn_states: (np.ndarray / torch.Tensor) if RNN network, hidden states for RNN.
        :param action_seqs: unused; kept for interface compatibility.
        :param masks: (np.ndarray / torch.Tensor) mask tensor denoting if RNN states should be reinitialized to zeros.

        :return values: (torch.Tensor) value function predictions. Always 3-D: 2-D features get an
                        inserted axis 1. With latent actions, the last two axes are swapped.
        :return rnn_states: (torch.Tensor) updated RNN hidden states.
        """
        cent_obs = check(cent_obs).to(**self.tpdv)
        rnn_states = check(rnn_states).to(**self.tpdv)
        masks = check(masks).to(**self.tpdv)
        if cent_obs.dim() == 4:
            # Drop the axis at position -2 on all inputs (assumed singleton — TODO confirm).
            cent_obs = cent_obs.squeeze(-2)
            rnn_states = rnn_states.squeeze(-2)
            masks = masks.squeeze(-2)

        critic_features = self.base(cent_obs)

        if self._use_naive_recurrent_policy or self._use_recurrent_policy:
            # Flatten any leading axes into one batch axis, run the RNN, then restore shapes.
            # Hidden states are flattened to (N, recurrent_N, d) rather than (N, 1, d) so that
            # the batch size matches the features when recurrent_N > 1.
            # NOTE(review): assumes RNNLayer takes hxs as (N, recurrent_N, d) — confirm in rnn.py.
            feat_shape = critic_features.shape
            rnn_shape = rnn_states.shape
            critic_features, rnn_states = self.rnn(
                critic_features.reshape(-1, feat_shape[-1]),
                rnn_states.reshape(-1, self._recurrent_N, rnn_shape[-1]),
                masks.reshape(-1, masks.shape[-1]),
            )
            critic_features = critic_features.reshape(feat_shape)
            rnn_states = rnn_states.reshape(rnn_shape)

        # The value head works on 3-D input (bs x T x d); 2-D features get T = 1.
        states = critic_features if critic_features.dim() == 3 else critic_features.unsqueeze(1)

        values = self.v_out(self.net(states))
        if self._use_latent_action:
            # Swap axes 1 and 2 so the per-agent value axis comes second.
            values = values.permute(0, 2, 1)

        return values, rnn_states

    def evaluate_states(self, cent_obs, rnn_states, action_ts, masks):
        """
        Compute value predictions for a flat batch of critic inputs.

        Unlike ``forward``, inputs are not squeezed or flattened before the base/RNN. The
        features must reshape to (batch, feature), and the output is never permuted.

        :param cent_obs: (np.ndarray / torch.Tensor) observation inputs into network.
        :param rnn_states: (np.ndarray / torch.Tensor) if RNN network, hidden states for RNN.
        :param action_ts: unused; kept for interface compatibility.
        :param masks: (np.ndarray / torch.Tensor) mask tensor denoting if RNN states should be reinitialized to zeros.

        :return values: (torch.Tensor) value function predictions.
        :return rnn_states: (torch.Tensor) updated RNN hidden states.
        """
        cent_obs = check(cent_obs).to(**self.tpdv)
        rnn_states = check(rnn_states).to(**self.tpdv)
        masks = check(masks).to(**self.tpdv)

        critic_features = self.base(cent_obs)
        if self._use_naive_recurrent_policy or self._use_recurrent_policy:
            critic_features, rnn_states = self.rnn(critic_features, rnn_states, masks)

        # Enforce a (batch, feature) layout for the value MLP.
        states = critic_features.reshape(critic_features.shape[0], critic_features.shape[1])

        values = self.v_out(self.net(states))

        return values, rnn_states
