import itertools
import math
import os
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from gym import spaces

from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.td3.policies import BasePolicy
from stable_baselines3.common.policies import BaseModel

import matplotlib.pyplot as plt


"""
Modules
"""

class MultiHeadAttention(nn.Module):
    """Multi-head attention layer with output projection.

    Inputs are laid out as ``[B, N, T, C]`` (batch, entities, sequence length,
    embedding) and are flattened to ``N * T`` tokens before attending.

    ``att_type`` selects where keys/values come from:
      * ``'self'``   - from the queries ``x``;
      * ``'cross'``  - from the context ``c``;
      * ``'hybrid'`` - from ``x`` and ``c`` concatenated along the entity axis.
    """

    def __init__(self, embed_dim, n_head, attn_pdrop=0.0, resid_pdrop=0.0, att_type='hybrid', linear_bias=False):
        super().__init__()
        assert embed_dim % n_head == 0
        assert att_type in ['hybrid', 'cross', 'self']
        self.att_type = att_type
        # key, query, value projections for all heads
        self.key = nn.Linear(embed_dim, embed_dim, bias=linear_bias)
        self.query = nn.Linear(embed_dim, embed_dim, bias=linear_bias)
        self.value = nn.Linear(embed_dim, embed_dim, bias=linear_bias)
        # regularization
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.resid_drop = nn.Dropout(resid_pdrop)
        # output projection
        self.proj = nn.Linear(embed_dim, embed_dim, bias=linear_bias)
        self.n_head = n_head

    def forward(self, x, c=None, mask=None, return_attention=False):
        """Attend from ``x`` to the keys/values selected by ``att_type``.

        :param x: query tokens ``[B, N, T, C]``.
        :param c: context tokens ``[B, N_c, T, C]``; required for ``'cross'`` and ``'hybrid'``.
        :param mask: optional boolean tensor, True where a key must be ignored. A head
            axis is inserted at dim 1, so it must be 3-D and broadcast against
            ``[B, n_head, N*T, N_kv*T]`` after that (e.g. ``[B, 1, N_kv*T]``).
        :param return_attention: if True, also return the post-softmax attention
            weights (taken before attention dropout).
        :return: ``[B, N, T, C]`` output, or ``(output, attention)``.

        Query rows whose keys are *all* masked get all-zero attention weights (so the
        attention output for them is the projection of a zero vector) instead of the
        NaN that a softmax over only ``-inf`` entries would produce.
        """
        B, N, T, C = x.size()  # batch size, n_particles, sequence length, embedding dimensionality (n_embd)
        head_size = C // self.n_head

        if self.att_type == 'hybrid':
            key_value_input = torch.cat([x, c], dim=1)
        elif self.att_type == 'cross':
            key_value_input = c
        else:  # self.att_type == 'self'
            key_value_input = x
        key_value_N = key_value_input.shape[1]

        # calculate query, key, values for all heads and move the head axis forward
        k = self.key(key_value_input).view(B, key_value_N * T, self.n_head, head_size).transpose(1, 2)  # (B, nh, kvN*T, hs)
        q = self.query(x).view(B, N * T, self.n_head, head_size).transpose(1, 2)  # (B, nh, N*T, hs)
        v = self.value(key_value_input).view(B, key_value_N * T, self.n_head, head_size).transpose(1, 2)  # (B, nh, kvN*T, hs)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))  # (B, nh, N*T, kvN*T)
        fully_masked = None
        if mask is not None:
            mask = mask.unsqueeze(1).expand(-1, self.n_head, -1, -1)
            att = att.masked_fill(mask, float('-inf'))
            # A row with every key masked would softmax to NaN (and poison gradients).
            # Give it finite logits here and zero its weights after the softmax.
            fully_masked = mask.all(dim=-1, keepdim=True)
            att = att.masked_fill(fully_masked, 0.0)
        att = F.softmax(att, dim=-1)
        if fully_masked is not None:
            att = att.masked_fill(fully_masked, 0.0)
        attention_matrix = att
        att = self.attn_drop(att)
        y = att @ v  # (B, nh, N*T, kvN*T) x (B, nh, kvN*T, hs) -> (B, nh, N*T, hs)
        y = y.transpose(1, 2).contiguous().view(B, N * T, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_drop(self.proj(y))
        y = y.view(B, N, T, -1)

        if return_attention:
            return y, attention_matrix
        return y

class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block: residual attention, then a residual MLP.

    :param embed_dim: token width.
    :param h_dim: hidden width of the MLP.
    :param n_head: number of attention heads.
    :param attn_pdrop: dropout applied to the attention weights.
    :param resid_pdrop: dropout applied to the attention and MLP outputs.
    :param att_type: ``'self'``, ``'cross'`` or ``'hybrid'``; any value other than
        ``'self'`` also layer-normalises the context ``c`` before attention.
    """

    def __init__(self, embed_dim, h_dim, n_head, attn_pdrop=0.1, resid_pdrop=0.1, att_type='self'):
        super().__init__()
        self.att_type = att_type
        uses_context = att_type != 'self'

        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        if uses_context:
            self.ln_c = nn.LayerNorm(embed_dim)

        self.attn = MultiHeadAttention(embed_dim, n_head, attn_pdrop, resid_pdrop, att_type)

        mlp_layers = [
            nn.Linear(embed_dim, h_dim), nn.ReLU(True),
            nn.Linear(h_dim, h_dim), nn.ReLU(True),
            nn.Linear(h_dim, embed_dim), nn.Dropout(resid_pdrop),
        ]
        self.mlp = nn.Sequential(*mlp_layers)

    def forward(self, x_in, c=None, x_mask=None, c_mask=None, return_attention=False):
        """Apply attention and MLP sub-layers, each with a residual connection.

        :param x_in: query tokens ``[B, N, T, C]``.
        :param c: context tokens; needed unless ``att_type == 'self'``.
        :param x_mask: boolean mask handed to the attention layer (True = ignore that key).
        :param c_mask: accepted for API compatibility; not used.
        :param return_attention: if True, also return the attention weights.
        """
        context = c if self.att_type == 'self' else self.ln_c(c)

        attn_result = self.attn(self.ln1(x_in), context, x_mask, return_attention)
        if return_attention:
            attn_out, weights = attn_result
        else:
            attn_out, weights = attn_result, None

        hidden = x_in + attn_out
        hidden = hidden + self.mlp(self.ln2(hidden))

        if return_attention:
            return hidden, weights
        return hidden

def get_action_dim(action_space):
    """Return the flat action dimension of ``action_space``.

    Duck-typed local replacement that shadows the ``get_action_dim`` imported from
    SB3 above; only ``action_space.shape`` is read. The result is the product of the
    shape entries: ``(n,)`` gives ``n`` (unchanged from the previous ``shape[0]``),
    a multi-dimensional shape gives its total number of elements (as SB3 does for
    ``Box``), and a scalar shape ``()`` gives 1 instead of raising ``IndexError``.
    """
    return int(np.prod(action_space.shape))

"""
SB3 Parent Policy
"""
class CustomTD3Policy(BasePolicy):
    """
    General TD3 Policy class.

    Builds an actor, a critic and their target copies from user-supplied classes.
    The critic is constructed with the actor kwargs (including the observation and
    action spaces), overridden by ``critic_kwargs``, plus ``n_critics``.

    :param observation_space: Observation space
    :param action_space: Action space
    :param lr_schedule: Learning rate schedule (could be constant)
    :param actor_class: architecture class to be used for actor and target networks
    :param actor_kwargs: actor kwargs (copied; the caller's dict is not modified)
    :param critic_class: architecture class to be used for critic and target networks
    :param critic_kwargs: critic kwargs overriding the shared actor kwargs (``None`` = no overrides)
    :param n_critics: Number of critic networks to create.
    :param optimizer_class: The optimizer to use, ``th.optim.Adam`` by default
    :param optimizer_kwargs: Additional keyword arguments, excluding the learning rate, to pass to the optimizer
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        lr_schedule: Schedule,
        actor_class,
        actor_kwargs: Dict[str, Any],
        critic_class,
        critic_kwargs: Optional[Dict[str, Any]] = None,
        n_critics: int = 2,
        optimizer_class: Type[torch.optim.Optimizer] = torch.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        super(CustomTD3Policy, self).__init__(
            observation_space,
            action_space,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            squash_output=True,
        )

        # actor: copy first so the caller's dict (e.g. SB3 policy_kwargs) is not mutated
        self.actor_class = actor_class
        self.actor_kwargs = dict(actor_kwargs)
        self.actor_kwargs.update({
            "observation_space": self.observation_space,
            "action_space": self.action_space,
        })
        self.actor, self.actor_target = None, None

        # critic: shared actor kwargs (incl. spaces) -> critic overrides -> n_critics
        self.critic_class = critic_class
        self.critic_kwargs = self.actor_kwargs.copy()
        self.critic_kwargs.update(critic_kwargs or {})
        self.critic_kwargs.update({"n_critics": n_critics})
        self.critic, self.critic_target = None, None

        # create networks and optimizers
        self._build(lr_schedule)

    def _build(self, lr_schedule: Schedule) -> None:
        # Create actor and target
        self.actor = self.actor_class(**self.actor_kwargs).to(self.device)
        self.actor_target = self.actor_class(**self.actor_kwargs).to(self.device)
        # Initialize the target to have the same weights as the actor
        self.actor_target.load_state_dict(self.actor.state_dict())

        self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)

        # Create critic and target
        self.critic = self.critic_class(**self.critic_kwargs).to(self.device)
        self.critic_target = self.critic_class(**self.critic_kwargs).to(self.device)
        # Initialize the target to have the same weights as the critic
        self.critic_target.load_state_dict(self.critic.state_dict())

        self.critic.optimizer = self.optimizer_class(self.critic.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)

        # Target networks should always be in eval mode
        self.actor_target.set_training_mode(False)
        self.critic_target.set_training_mode(False)

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """Arguments needed to re-create this policy (used when saving/loading the policy alone)."""
        data = super()._get_constructor_parameters()

        data.update(
            dict(
                # required constructor arguments; without them the policy cannot be rebuilt
                actor_class=self.actor_class,
                actor_kwargs=self.actor_kwargs,
                critic_class=self.critic_class,
                # already-merged kwargs: merging them again in __init__ yields the same dict
                critic_kwargs=self.critic_kwargs,
                n_critics=self.critic_kwargs["n_critics"],
                lr_schedule=self._dummy_schedule,  # dummy lr schedule, not needed for loading policy alone
                optimizer_class=self.optimizer_class,
                optimizer_kwargs=self.optimizer_kwargs,
            )
        )
        return data

    def forward(self, observation: torch.Tensor, deterministic: bool = False) -> torch.Tensor:
        return self._predict(observation, deterministic=deterministic)

    def _predict(self, observation: torch.Tensor, deterministic: bool = False) -> torch.Tensor:
        # Note: the deterministic parameter is ignored in the case of TD3. Predictions are always deterministic.
        return self.actor(observation)

    def set_training_mode(self, mode: bool) -> None:
        """
        Put the policy in either training or evaluation mode.

        This affects certain modules, such as batch normalisation and dropout.

        :param mode: if true, set to training mode, else set to evaluation mode
        """
        self.actor.set_training_mode(mode)
        self.critic.set_training_mode(mode)
        self.training = mode

"""
Entity Interaction Transformer Policy
"""
class EITActor(BasePolicy):
    """Entity Interaction Transformer (EIT) actor for TD3.

    Particle tokens from ``obs["achieved_goal"]`` go through self-attention, optional
    cross-attention to the goal particles ``obs["desired_goal"]`` (``goal_cond``), and
    a second self-attention. A learned output token then pools them, and an MLP maps
    the pooled vector to a ``tanh``-squashed action.

    :param masking: if True, the last particle feature is an "obj_on" flag; it is
        stripped from the features and tokens whose flag is negative are masked as keys.
    :param goal_cond: if True, cross-attend to the goal particles.
    """

    def __init__(self, observation_space, action_space,
                 embed_dim=64, h_dim=128, n_head=1, dropout=0.0,
                 masking=False, goal_cond=False, **kwargs):
        super(EITActor, self).__init__(observation_space, action_space, squash_output=True)

        action_dim = get_action_dim(self.action_space)
        observation_shape = observation_space["achieved_goal"].shape
        particle_fdim = observation_shape[-1]

        self.masking = masking
        # exactly two views are supported (view1/view2 encodings below)
        self.multiview = observation_shape[0] > 1 and len(observation_shape) == 3
        self.goal_cond = goal_cond

        particle_dim = particle_fdim - 1 if self.masking else particle_fdim
        self.particle_projection = nn.Linear(particle_dim, embed_dim)
        self.particle_self_att1 = TransformerBlock(embed_dim, h_dim, n_head,
                                                   attn_pdrop=dropout, resid_pdrop=dropout, att_type='self')

        if self.goal_cond:
            self.goal_particle_projection = nn.Linear(particle_dim, embed_dim)
            self.particle_cross_att = TransformerBlock(embed_dim, h_dim, n_head,
                                                       attn_pdrop=dropout, resid_pdrop=dropout, att_type='cross')

        self.particle_self_att2 = TransformerBlock(embed_dim, h_dim, n_head,
                                                   attn_pdrop=dropout, resid_pdrop=dropout, att_type='self')

        self.particle_pool_att = TransformerBlock(embed_dim, h_dim, n_head,
                                                  attn_pdrop=dropout, resid_pdrop=dropout, att_type='cross')

        self.ln = nn.LayerNorm(embed_dim)
        self.linear_out = nn.Linear(embed_dim, embed_dim, bias=True)

        self.output_mlp = nn.Sequential(nn.Linear(embed_dim, h_dim),
                                        nn.ReLU(True),
                                        nn.Linear(h_dim, h_dim),
                                        nn.ReLU(True),
                                        nn.Linear(h_dim, h_dim),
                                        nn.ReLU(True),
                                        nn.Linear(h_dim, action_dim))

        # particle encoding
        if self.multiview:
            self.view1_encoding = nn.Parameter(0.02 * torch.randn(1, 1, embed_dim))
            self.view2_encoding = nn.Parameter(0.02 * torch.randn(1, 1, embed_dim))

        # special particle
        self.out_particle = nn.Parameter(0.02 * torch.randn(1, 1, embed_dim))

    @staticmethod
    def _split_obj_on(particles):
        """Strip the trailing obj_on feature and build a key mask from it.

        :return: ``(features, mask)``; ``mask`` is ``[bs, 1, n_tokens]`` and True where obj_on < 0.
        """
        bs = particles.shape[0]
        obj_on = particles[..., -1].reshape(bs, -1)  # views flattened in order: view1 tokens, then view2
        return particles[..., :-1], (obj_on < 0).unsqueeze(1)

    def _merge_views(self, particles):
        """Fold the view axis of projected particles into the token axis.

        Multiview ``[bs, 2, n, D]`` becomes ``[bs, 2n, D]`` with a learned per-view encoding
        added; ``[bs, 1, n, D]`` loses its view axis; ``[bs, n, D]`` is returned unchanged
        (squeezing it would collapse the token axis when n == 1).
        """
        if self.multiview:
            return torch.cat([particles[:, 0] + self.view1_encoding,
                              particles[:, 1] + self.view2_encoding], dim=1)
        if particles.dim() == 4:
            return particles.squeeze(1)
        return particles

    @staticmethod
    def _apply_block(block, x, c=None, x_mask=None, name=None, attention_dict=None):
        """Run ``block``; when ``attention_dict`` is a dict, store its attention matrix under ``name``."""
        if attention_dict is None:
            return block(x, c, x_mask=x_mask)
        out, attention_matrix = block(x, c, x_mask=x_mask, return_attention=True)
        attention_dict[name] = attention_matrix
        return out

    def forward(self, obs, return_attention=False):
        """Compute actions for a batch of observations.

        :param obs: dict with ``"achieved_goal"`` and ``"desired_goal"`` particle tensors,
            each ``[bs, n_particles, F]`` or ``[bs, n_views, n_particles, F]``.
        :param return_attention: also return attention matrices keyed by ``"self_1"``,
            ``"cross"`` (only with ``goal_cond``), ``"self_2"`` and ``"agg"``.
        :return: ``[bs, action_dim]`` tanh-squashed action, or ``(action, attention_dict)``.
        """
        particles = obs["achieved_goal"]
        goal_particles = obs["desired_goal"]
        bs = particles.shape[0]
        attention_dict = {} if return_attention else None

        # preprocess particles and produce key masks (based on obj_on)
        state_mask, goal_mask = None, None
        if self.masking:
            particles, state_mask = self._split_obj_on(particles)
            if self.goal_cond:
                goal_particles, goal_mask = self._split_obj_on(goal_particles)

        # project particle features -> [bs, n_tokens, 1, embed_dim]
        x = self._merge_views(self.particle_projection(particles)).unsqueeze(2)
        x = self._apply_block(self.particle_self_att1, x, x_mask=state_mask,
                              name="self_1", attention_dict=attention_dict)

        if self.goal_cond:
            c = self._merge_views(self.goal_particle_projection(goal_particles)).unsqueeze(2)
            x = self._apply_block(self.particle_cross_att, x, c, x_mask=goal_mask,
                                  name="cross", attention_dict=attention_dict)

        x = self._apply_block(self.particle_self_att2, x, x_mask=state_mask,
                              name="self_2", attention_dict=attention_dict)

        # pool using the special output particle as the only query
        out_particle = self.out_particle.repeat(bs, 1, 1).unsqueeze(2)  # [bs, 1, 1, embed_dim]
        x_agg = self._apply_block(self.particle_pool_att, out_particle, x, x_mask=state_mask,
                                  name="agg", attention_dict=attention_dict)
        x_agg = x_agg.squeeze(1, 2)  # [bs, embed_dim]
        # final layer norm
        x_agg = self.linear_out(self.ln(x_agg))

        action = torch.tanh(self.output_mlp(x_agg))  # [bs, action_dim]

        if return_attention:
            return action, attention_dict
        return action

    def _predict(self, observation: torch.Tensor, deterministic: bool = False) -> torch.Tensor:
        # Note: the deterministic parameter is ignored in the case of TD3. Predictions are always deterministic.
        return self(observation)


class EITCriticNetwork(BaseModel):
    """Single Entity Interaction Transformer Q-network.

    With ``action_particle`` the projected action is prepended as an extra token (never
    masked). The pooled output token and the pooled action token are then concatenated.
    Without it, the pooled output token is concatenated with the projected action.
    An MLP maps the ``2 * embed_dim`` vector to one Q-value.
    """

    def __init__(self, observation_space, action_space,
                 embed_dim=64, h_dim=128, n_head=1, dropout=0.0,
                 masking=False, goal_cond=False, action_particle=True):
        super(EITCriticNetwork, self).__init__(observation_space, action_space)

        action_dim = get_action_dim(self.action_space)
        observation_shape = observation_space["achieved_goal"].shape
        particle_fdim = observation_shape[-1]

        self.masking = masking
        # exactly two views are supported (view1/view2 encodings below)
        self.multiview = observation_shape[0] > 1 and len(observation_shape) == 3
        self.goal_cond = goal_cond
        self.action_particle = action_particle

        self.action_projection = nn.Sequential(nn.Linear(action_dim, h_dim),
                                               nn.ReLU(True),
                                               nn.Linear(h_dim, embed_dim))

        particle_dim = particle_fdim - 1 if self.masking else particle_fdim
        self.particle_projection = nn.Linear(particle_dim, embed_dim)
        self.particle_self_att1 = TransformerBlock(embed_dim, h_dim, n_head,
                                                   attn_pdrop=dropout, resid_pdrop=dropout, att_type='self')

        if self.goal_cond:
            self.goal_particle_projection = nn.Linear(particle_dim, embed_dim)
            self.particle_cross_att = TransformerBlock(embed_dim, h_dim, n_head,
                                                       attn_pdrop=dropout, resid_pdrop=dropout, att_type='cross')

        self.particle_self_att2 = TransformerBlock(embed_dim, h_dim, n_head,
                                                   attn_pdrop=dropout, resid_pdrop=dropout, att_type='self')

        self.particle_pool_att = TransformerBlock(embed_dim, h_dim, n_head,
                                                  attn_pdrop=dropout, resid_pdrop=dropout, att_type='cross')

        self.ln = nn.LayerNorm(embed_dim)
        self.linear_out = nn.Linear(embed_dim, embed_dim, bias=True)

        self.output_mlp = nn.Sequential(nn.Linear(2 * embed_dim, h_dim),
                                        nn.ReLU(True),
                                        nn.Linear(h_dim, h_dim),
                                        nn.ReLU(True),
                                        nn.Linear(h_dim, h_dim),
                                        nn.ReLU(True),
                                        nn.Linear(h_dim, 1))

        # particle encoding
        if self.multiview:
            self.view1_encoding = nn.Parameter(0.02 * torch.randn(1, 1, embed_dim))
            self.view2_encoding = nn.Parameter(0.02 * torch.randn(1, 1, embed_dim))

        # special particle
        self.out_particle = nn.Parameter(0.02 * torch.randn(1, 1, embed_dim))

    def _merge_views(self, particles):
        """Fold the view axis of projected particles into the token axis.

        Multiview ``[bs, 2, n, D]`` becomes ``[bs, 2n, D]`` with a learned per-view encoding
        added; ``[bs, 1, n, D]`` loses its view axis; ``[bs, n, D]`` is returned unchanged
        (squeezing it would collapse the token axis when n == 1).
        """
        if self.multiview:
            return torch.cat([particles[:, 0] + self.view1_encoding,
                              particles[:, 1] + self.view2_encoding], dim=1)
        if particles.dim() == 4:
            return particles.squeeze(1)
        return particles

    def forward(self, obs, action):
        """Return Q(obs, action) as a ``[bs, 1]`` tensor.

        :param obs: dict with ``"achieved_goal"`` and ``"desired_goal"`` particle tensors,
            each ``[bs, n_particles, F]`` or ``[bs, n_views, n_particles, F]``.
        :param action: action batch fed to ``action_projection`` (``[bs, action_dim]``).
        """
        particles = obs["achieved_goal"]
        goal_particles = obs["desired_goal"]
        bs = particles.shape[0]

        # preprocess particles and produce key masks (based on obj_on)
        state_mask, goal_mask = None, None
        if self.masking:
            obj_on = particles[..., -1].reshape(bs, -1)
            if self.action_particle:
                # the action token is prepended below and must never be masked
                obj_on = torch.cat([obj_on.new_ones([bs, 1]), obj_on], dim=-1)
            particles = particles[..., :-1]  # remove obj_on from features
            state_mask = (obj_on < 0).unsqueeze(1)  # [bs, 1, n_tokens]

            if self.goal_cond:
                goal_obj_on = goal_particles[..., -1].reshape(bs, -1)
                goal_particles = goal_particles[..., :-1]  # remove obj_on from goal features
                goal_mask = (goal_obj_on < 0).unsqueeze(1)

        # project particle features -> [bs, n_particles, embed_dim]
        particles = self._merge_views(self.particle_projection(particles))

        # project action and (optionally) add it as a token
        action_token = self.action_projection(action)  # [bs, embed_dim]
        if self.action_particle:
            x = torch.cat([action_token.unsqueeze(1), particles], dim=1)  # [bs, n_particles + 1, embed_dim]
        else:
            x = particles

        x = self.particle_self_att1(x.unsqueeze(2), x_mask=state_mask)  # [bs, n_tokens, 1, embed_dim]

        if self.goal_cond:
            c = self._merge_views(self.goal_particle_projection(goal_particles)).unsqueeze(2)
            x = self.particle_cross_att(x, c, x_mask=goal_mask)

        x = self.particle_self_att2(x, x_mask=state_mask)

        # pool using the special output particle (and the action token, if present) as queries
        out_particle = self.out_particle.repeat(bs, 1, 1)  # [bs, 1, embed_dim]
        if self.action_particle:
            queries = torch.cat([out_particle, x[:, 0].clone()], dim=1)  # [bs, 2, embed_dim]
        else:
            queries = out_particle
        x_out = self.particle_pool_att(queries.unsqueeze(2), x, x_mask=state_mask).squeeze(2)
        # final layer norm
        x_out = self.linear_out(self.ln(x_out))

        if self.action_particle:
            x_agg = torch.cat([x_out[:, 0], x_out[:, 1]], dim=-1)  # [bs, 2 * embed_dim]
        else:
            x_agg = torch.cat([x_out[:, 0], action_token], dim=-1)  # [bs, 2 * embed_dim]

        return self.output_mlp(x_agg)  # [bs, 1]


class EITCritic(BaseModel):
    """Ensemble of ``n_critics`` independent :class:`EITCriticNetwork` Q-functions.

    Each network is registered as the submodule ``qf{idx}`` and also kept, in order,
    in the plain list ``self.q_networks`` for indexed access.
    """

    def __init__(self, observation_space, action_space, n_critics=2, action_particle=True,
                 embed_dim=64, h_dim=256, n_head=1, dropout=0.0,
                 masking=False, goal_cond=False, **kwargs):
        super().__init__(observation_space, action_space)

        self.n_critics = n_critics
        self.q_networks = []
        net_args = (observation_space, action_space,
                    embed_dim, h_dim, n_head, dropout,
                    masking, goal_cond, action_particle)
        for idx in range(n_critics):
            network = EITCriticNetwork(*net_args)
            self.add_module(f"qf{idx}", network)
            self.q_networks.append(network)

    def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Evaluate every critic; returns one Q-value tensor per critic, in order."""
        return tuple(self.q_networks[i](obs, actions) for i in range(self.n_critics))

    def q1_forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        """
        Evaluate only the first critic.

        Cheaper than ``forward`` when a single estimate suffices
        (e.g. the TD3 actor update).
        """
        first_critic = self.q_networks[0]
        return first_critic(obs, actions)

"""
Block Transformer policy
"""
class BlockEmbedding(nn.Module):
    """Add learned slot-id and block-id embeddings to block features.

    :param num_slots: number of rows in the slot-id embedding table.
    :param num_blocks: number of rows in the block-id embedding table.
    :param block_dim: embedding width; must broadcast against ``block_feats``.
    """

    def __init__(self, num_slots, num_blocks, block_dim):
        super().__init__()
        self.slot_embed = nn.Embedding(num_slots, block_dim)
        self.block_embed = nn.Embedding(num_blocks, block_dim)

    def forward(self, block_feats, slot_ids, block_ids):
        """Return ``block_feats`` plus the embeddings of whichever id tensors are not ``None``."""
        lookups = ((self.slot_embed, slot_ids), (self.block_embed, block_ids))
        offset = sum(table(ids) for table, ids in lookups if ids is not None)
        return block_feats + offset
    
class BlockTransformerActor(BasePolicy):
    """TD3 actor over block-structured slot observations.

    NOTE(review): observation layout inferred from the slicing in ``forward`` — confirm
    against the env. ``obs["achieved_goal"]`` / ``obs["desired_goal"]`` are
    ``[B, n_slots, slot_dim]``. The agent slot is last, and the ``num_fg_slots``
    foreground slots come right before it.

    Processing:
      1. Each foreground slot is split into ``num_blocks`` blocks of ``slot_dim // num_blocks``.
      2. For each slot, the input blocks cross-attend to the corresponding goal slot's
         blocks and are pooled by a learned per-slot query.
      3. The pooled object tokens and the projected agent/goal-agent tokens go through
         self-attention and are pooled by a learned output token.
      4. An MLP with ``tanh`` produces the action.

    NOTE(review): ``forward`` always uses ``goal_ocr_projection_*``, which only exist when
    ``goal_cond=True``; with ``goal_cond=False`` it raises ``AttributeError``.
    """

    def __init__(self, observation_space, action_space,
                 embed_dim=64, h_dim=256, n_head=8, dropout=0.0,
                 num_slots=5, num_bg_slots=1, num_fg_slots=3, num_ag_slots=1, num_blocks=8, goal_cond=False, **kwargs):
        super(BlockTransformerActor, self).__init__(observation_space, action_space, squash_output=True)

        self.goal_cond = goal_cond

        action_dim = get_action_dim(self.action_space)
        observation_shape = observation_space["achieved_goal"].shape
        slot_dim = observation_shape[-1]           # slot_dim
        block_dim = slot_dim // num_blocks         # block_dim

        self.num_slots = num_slots
        self.num_fg_slots = num_fg_slots
        self.num_blocks = num_blocks

        self.ocr_projection_fg = nn.Linear(block_dim, embed_dim)
        self.ocr_projection_ag = nn.Linear(slot_dim, embed_dim)

        if self.goal_cond:
            self.goal_ocr_projection_fg = nn.Linear(block_dim, embed_dim)
            self.goal_ocr_projection_ag = nn.Linear(slot_dim, embed_dim)

        self.obj_cross_att = TransformerBlock(embed_dim, h_dim, n_head,
                                              attn_pdrop=dropout, resid_pdrop=dropout, att_type='cross')

        self.obj_queries = nn.Parameter(0.02 * torch.randn(1, self.num_fg_slots, 1, embed_dim))
        self.obj_pool_attn = TransformerBlock(embed_dim, h_dim, n_head,
                                              attn_pdrop=dropout, resid_pdrop=dropout, att_type='cross')

        self.self_att = TransformerBlock(embed_dim, h_dim, n_head,
                                         attn_pdrop=dropout, resid_pdrop=dropout, att_type='self')
        self.pool_att = TransformerBlock(embed_dim, h_dim, n_head,
                                         attn_pdrop=dropout, resid_pdrop=dropout, att_type='cross')

        self.ln = nn.LayerNorm(embed_dim)
        self.linear_out = nn.Linear(embed_dim, embed_dim, bias=True)

        self.output_mlp = nn.Sequential(nn.Linear(embed_dim, h_dim),
                                        nn.ReLU(True),
                                        nn.Linear(h_dim, h_dim),
                                        nn.ReLU(True),
                                        nn.Linear(h_dim, h_dim),
                                        nn.ReLU(True),
                                        nn.Linear(h_dim, action_dim))

        self.out_particle = nn.Parameter(0.02 * torch.randn(1, 1, embed_dim))

        # Block indices (along the num_blocks axis) used to build slot descriptors for
        # goal-slot matching. Empty disables matching: goal slots keep their given order.
        # (An empty selection used to produce NaN descriptors.)
        self.static_block_idxs = []

    def forward(self, obs):
        """Compute a ``[B, action_dim]`` tanh-squashed action."""
        particles = obs["achieved_goal"]
        goal_particles = obs["desired_goal"]

        B = particles.size(0)
        S = self.num_fg_slots
        K = self.num_blocks
        d_blk = particles.size(-1) // K

        # FG tokens: the S slots right before the trailing agent slot.
        # For an [B, S+1, slot_dim] input this equals particles[:, :-1]; it also skips
        # leading background slots if present.
        fg_blocks = particles[:, -1 - S:-1].reshape(B, S, K, d_blk)          # [B,S,K,d]
        g_fg_blocks = goal_particles[:, -1 - S:-1].reshape(B, S, K, d_blk)   # [B,S,K,d]

        # Goal-slot matching: reorder goal slots so goal slot perms[b, i] pairs with input slot i
        if self.static_block_idxs:
            desc_in = self.compute_static_blocks(fg_blocks, self.static_block_idxs)      # [B,S,d]
            desc_goal = self.compute_static_blocks(g_fg_blocks, self.static_block_idxs)  # [B,S,d]
            perms = self.compute_hungarian_perm(desc_in, desc_goal)                      # [B,S]
            index = perms[:, :, None, None].expand(-1, -1, K, d_blk)
            g_fg_blocks = g_fg_blocks.gather(dim=1, index=index)

        fg_tokens = self.ocr_projection_fg(fg_blocks)              # [B,S,K,D]
        g_fg_tokens = self.goal_ocr_projection_fg(g_fg_blocks)     # [B,S,K,D]

        # AG tokens
        ag_token = self.ocr_projection_ag(particles[:, -1].unsqueeze(1))               # [B,1,D]
        g_ag_token = self.goal_ocr_projection_ag(goal_particles[:, -1].unsqueeze(1))   # [B,1,D]

        # Block-wise cross-attention per foreground slot, then pool each slot to one token
        obj_tokens = []
        for i in range(S):
            q_i = fg_tokens[:, i].unsqueeze(2)     # [B, K, 1, D]
            c_i = g_fg_tokens[:, i].unsqueeze(2)   # [B, K, 1, D]
            o_i = self.obj_cross_att(q_i, c_i, x_mask=None, return_attention=False)  # [B, K, 1, D]

            q_pool = self.obj_queries[:, i:i + 1].expand(B, -1, -1, -1)  # [B, 1, 1, D]
            pooled = self.obj_pool_attn(q_pool, o_i, x_mask=None)         # [B, 1, 1, D]
            obj_tokens.append(pooled.squeeze(2).squeeze(1))               # [B, D]
        O = torch.stack(obj_tokens, dim=1)  # [B, S, D]

        oc_tokens = torch.cat([ag_token, g_ag_token, O], dim=1)  # [B, 2 + S, D]

        # Self-attention, pooling with the output token, and final MLP
        x = self.self_att(oc_tokens.unsqueeze(2), x_mask=None)  # [B, 2 + S, 1, D]

        out_particle = self.out_particle.repeat(B, 1, 1).unsqueeze(2)  # [B,1,1,D]
        x_agg = self.pool_att(out_particle, x, x_mask=None)            # [B,1,1,D]
        x_agg = x_agg.squeeze(1, 2)                                    # [B,D]
        x_agg = self.linear_out(self.ln(x_agg))
        action = torch.tanh(self.output_mlp(x_agg))

        return action

    def _predict(self, observation: torch.Tensor, deterministic: bool = False) -> torch.Tensor:
        # Note: the deterministic parameter is ignored in the case of TD3. Predictions are always deterministic.
        return self(observation)

    def compute_static_blocks(self, blocks, idxs):
        """Build one unit-norm descriptor per slot from the selected blocks.

        :param blocks: ``[B, S, K, d]`` block features.
        :param idxs: indices along the block axis to average; must be non-empty
            (an empty selection averages nothing and yields NaN).
        :return: ``[B, S, d]`` L2-normalised mean of the selected blocks.
        """
        sel = blocks[:, :, idxs, :]
        desc = sel.mean(dim=2)
        return F.normalize(desc, dim=-1)

    @torch.no_grad()
    def compute_hungarian_perm(self, desc_in: torch.Tensor, desc_goal: torch.Tensor) -> torch.Tensor:
        """Per batch element, find the slot assignment with minimal total cosine distance.

        Despite the name, this is an exhaustive search over all ``S!`` permutations (not
        the Hungarian algorithm), so it is only practical for small ``S``.

        :param desc_in:   [B, S, d] input slot descriptors
        :param desc_goal: [B, S, d] goal slot descriptors
        :return: perms [B, S] (long); ``perms[b, i]`` is the goal slot assigned to input
            slot ``i``. Ties resolve to the lexicographically first permutation.
        """
        B, S, _ = desc_in.shape

        x = F.normalize(desc_in, dim=-1)
        y = F.normalize(desc_goal, dim=-1)
        cost = 1.0 - torch.matmul(x, y.transpose(1, 2))  # [B, S, S]; cost[b, i, j] = 1 - cos(in_i, goal_j)

        # All S! permutations in lexicographic order (for S = 3 this is the former fixed table)
        perms = torch.tensor(list(itertools.permutations(range(S))), dtype=torch.long, device=cost.device)  # [P, S]
        rows = torch.arange(S, device=cost.device)
        cost_p = cost[:, rows, perms].sum(-1)  # [B, P]: sum_i cost[b, i, perms[p, i]]
        best = cost_p.argmin(dim=1)            # [B]

        return perms[best]

class BlockTransformerCriticNetwork(BaseModel):
    """Q-network over block-structured object-centric slot observations.

    ``forward`` indexes ``obs["achieved_goal"]`` (and ``obs["desired_goal"]``
    when ``goal_cond`` is True) as ``[B, num_fg_slots + 1, slot_dim]``:
      - The first ``num_fg_slots`` slots are foreground objects. Each one is
        split into ``num_blocks`` blocks of ``slot_dim // num_blocks`` features.
      - The last slot is embedded whole as the agent token.

    Pipeline:
      1. Per object slot, the slot's block embeddings cross-attend to a context.
         With ``goal_cond`` the context is the slot-aligned goal blocks;
         otherwise it is the slot's own blocks. A learned per-slot query then
         pools the result into one object token.
      2. Agent token(s), object tokens and, optionally, an action token pass
         through self-attention.
      3. A learned output token (plus the attended action token, when used)
         pools the sequence. An MLP maps the result to a single Q-value.

    ``num_slots`` is stored but not otherwise used. ``num_bg_slots`` and
    ``num_ag_slots`` are accepted for interface compatibility and ignored.
    """

    def __init__(self, observation_space, action_space,
                 embed_dim=64, h_dim=256, n_head=4, dropout=0.0,
                 num_slots=5, num_bg_slots=1, num_fg_slots=3, num_ag_slots=1, num_blocks=8, goal_cond=False, action_particle=True,
                 static_block_idxs=None):
        """
        Args:
            observation_space: dict space. ``observation_space["achieved_goal"].shape[-1]``
                is used as ``slot_dim``.
            action_space: action space. Its dimensionality sizes the action projection.
            embed_dim: token embedding size ``D``.
            h_dim: hidden width of the MLPs and transformer blocks.
            n_head: number of attention heads.
            dropout: attention / residual dropout probability.
            num_fg_slots: number of foreground object slots ``S``.
            num_blocks: number of blocks ``K`` per slot. Must divide ``slot_dim``.
            goal_cond: if True, condition on ``obs["desired_goal"]`` through
                separate goal projections. If False, the goal is ignored.
            action_particle: if True, the projected action is an extra token in
                the attention sequence. If False, it is concatenated only before
                the output MLP.
            static_block_idxs: block indices used to build per-slot descriptors
                for aligning goal slots to observed slots (``goal_cond`` only).
                ``None`` or empty disables matching, so goal slots are used in
                their given order.

        Raises:
            ValueError: if ``slot_dim`` is not divisible by ``num_blocks``.
        """
        super(BlockTransformerCriticNetwork, self).__init__(observation_space, action_space)

        self.goal_cond = goal_cond
        self.action_particle = action_particle

        action_dim = get_action_dim(self.action_space)
        observation_shape = observation_space["achieved_goal"].shape
        slot_dim = observation_shape[-1]
        if slot_dim % num_blocks != 0:
            # Otherwise the block reshape in forward() can never succeed.
            raise ValueError(
                f"slot_dim ({slot_dim}) must be divisible by num_blocks ({num_blocks})"
            )
        block_dim = slot_dim // num_blocks

        self.num_slots = num_slots
        self.num_fg_slots = num_fg_slots
        self.num_blocks = num_blocks
        # An empty list means "no slot matching" (identity alignment).
        self.static_block_idxs = list(static_block_idxs) if static_block_idxs is not None else []

        self.ocr_projection_fg = nn.Linear(block_dim, embed_dim)
        self.ocr_projection_ag = nn.Linear(slot_dim, embed_dim)

        if self.goal_cond:
            self.goal_ocr_projection_fg = nn.Linear(block_dim, embed_dim)
            self.goal_ocr_projection_ag = nn.Linear(slot_dim, embed_dim)

        # Per-slot block cross-attention, followed by learned-query pooling.
        self.obj_cross_att = TransformerBlock(embed_dim, h_dim, n_head,
                                    attn_pdrop=dropout, resid_pdrop=dropout, att_type='cross')
        self.obj_queries = nn.Parameter(0.02 * torch.randn(1, self.num_fg_slots, 1, embed_dim))
        self.obj_pool_attn = TransformerBlock(embed_dim, h_dim, n_head,
                                    attn_pdrop=dropout, resid_pdrop=dropout, att_type='cross')

        self.self_att = TransformerBlock(embed_dim, h_dim, n_head,
                                           attn_pdrop=dropout, resid_pdrop=dropout, att_type='self')
        self.pool_att = TransformerBlock(embed_dim, h_dim, n_head,
                                          attn_pdrop=dropout, resid_pdrop=dropout, att_type='cross')

        self.action_projection = nn.Sequential(nn.Linear(action_dim, h_dim),
                                               nn.ReLU(True),
                                               nn.Linear(h_dim, embed_dim))

        self.ln = nn.LayerNorm(embed_dim)
        self.linear_out = nn.Linear(embed_dim, embed_dim, bias=True)

        self.output_mlp = nn.Sequential(nn.Linear(2 * embed_dim, h_dim),
                                        nn.ReLU(True),
                                        nn.Linear(h_dim, h_dim),
                                        nn.ReLU(True),
                                        nn.Linear(h_dim, h_dim),
                                        nn.ReLU(True),
                                        nn.Linear(h_dim, 1))

        # Learned query token that pools the attended sequence.
        self.out_particle = nn.Parameter(0.02 * torch.randn(1, 1, embed_dim))

    def forward(self, obs, action):
        """Estimate Q(obs, action).

        Args:
            obs: dict with ``"achieved_goal"`` and, when ``goal_cond`` is True,
                ``"desired_goal"``. Each is ``[B, num_fg_slots + 1, slot_dim]``,
                with the agent slot last.
            action: ``[B, action_dim]`` action batch.

        Returns:
            ``[B, 1]`` Q-values.
        """
        particles = obs["achieved_goal"]

        B = particles.size(0)
        S = self.num_fg_slots
        K = self.num_blocks
        d_blk = particles.size(-1) // K

        # Foreground slots split into blocks: [B, S, K, d_blk].
        fg_blocks = particles[:, :-1].reshape(B, S, K, d_blk)
        fg_emb = self.ocr_projection_fg(fg_blocks)                               # [B, S, K, D]
        ag_token = self.ocr_projection_ag(particles[:, -1].unsqueeze(1))         # [B, 1, D]

        if self.goal_cond:
            goal_particles = obs["desired_goal"]
            g_fg_blocks = goal_particles[:, :-1].reshape(B, S, K, d_blk)
            # Reorder goal slots so that goal slot i corresponds to observed slot i.
            g_fg_blocks = self._align_goal_blocks(fg_blocks, g_fg_blocks)
            ctx_emb = self.goal_ocr_projection_fg(g_fg_blocks)                   # [B, S, K, D]
            g_ag_token = self.goal_ocr_projection_ag(goal_particles[:, -1].unsqueeze(1))  # [B, 1, D]
            agent_tokens = [ag_token, g_ag_token]
        else:
            # Without a goal, each slot's blocks attend to themselves.
            ctx_emb = fg_emb
            agent_tokens = [ag_token]

        obj_tokens = self._encode_objects(fg_emb, ctx_emb)                       # [B, S, D]
        oc_tokens = torch.cat(agent_tokens + [obj_tokens], dim=1)                # [B, N_tok, D]

        # Optional action token at position 0.
        action_token = self.action_projection(action)                           # [B, D]
        if self.action_particle:
            x = torch.cat([action_token.unsqueeze(1), oc_tokens], dim=1)        # [B, 1 + N_tok, D]
        else:
            x = oc_tokens

        x = self.self_att(x.unsqueeze(2), x_mask=None)                          # [B, N, 1, D]

        out_token = self.out_particle.repeat(B, 1, 1)                           # [B, 1, D]
        if self.action_particle:
            # Pool with both the learned output token and the attended action token.
            x_out = torch.cat([out_token, x[:, 0]], dim=1)                      # [B, 2, D]
        else:
            x_out = out_token                                                   # [B, 1, D]
        x_out = self.pool_att(x_out.unsqueeze(2), x, x_mask=None).squeeze(2)    # [B, 1 or 2, D]
        x_out = self.linear_out(self.ln(x_out))

        if self.action_particle:
            x_agg = torch.cat([x_out[:, 0], x_out[:, 1]], dim=-1)               # [B, 2D]
        else:
            # The action enters only here when it is not an attention token.
            x_agg = torch.cat([x_out[:, 0], action_token], dim=-1)              # [B, 2D]
        return self.output_mlp(x_agg)                                           # [B, 1]

    def _encode_objects(self, fg_emb, ctx_emb):
        """Pool each object slot's block embeddings into one token: [B, S, K, D] -> [B, S, D]."""
        B, S = fg_emb.shape[:2]
        tokens = []
        for i in range(S):
            q_i = fg_emb[:, i].unsqueeze(2)                                     # [B, K, 1, D]
            c_i = ctx_emb[:, i].unsqueeze(2)                                    # [B, K, 1, D]
            o_i = self.obj_cross_att(q_i, c_i, x_mask=None)                     # [B, K, 1, D]

            q_pool = self.obj_queries[:, i:i + 1].expand(B, -1, -1, -1)         # [B, 1, 1, D]
            pooled = self.obj_pool_attn(q_pool, o_i, x_mask=None)               # [B, 1, 1, D]
            tokens.append(pooled.squeeze(2).squeeze(1))                         # [B, D]
        return torch.stack(tokens, dim=1)

    def _align_goal_blocks(self, fg_blocks, g_fg_blocks):
        """Permute goal slots ([B, S, K, d_blk]) to best match the observed slots.

        Returns ``g_fg_blocks`` unchanged when ``static_block_idxs`` is empty.
        """
        if not self.static_block_idxs:
            return g_fg_blocks
        desc_in = self.compute_static_blocks(fg_blocks, self.static_block_idxs)      # [B, S, d_blk]
        desc_goal = self.compute_static_blocks(g_fg_blocks, self.static_block_idxs)  # [B, S, d_blk]
        perms = self.compute_hungarian_perm(desc_in, desc_goal)                      # [B, S]
        _, _, K, d_blk = g_fg_blocks.shape
        index = perms[:, :, None, None].expand(-1, -1, K, d_blk)
        return g_fg_blocks.gather(dim=1, index=index)

    def compute_static_blocks(self, blocks, idxs):
        """Select blocks at idxs, mean-pool per slot, and L2 normalize."""
        sel  = blocks[:, :, idxs, :]
        desc = sel.mean(dim=2)
        desc = F.normalize(desc, dim=-1)
        return desc

    @torch.no_grad()
    def compute_hungarian_perm(self, desc_in: torch.Tensor, desc_goal: torch.Tensor) -> torch.Tensor:
        """Find, per batch item, the slot assignment minimizing total cosine distance.

        The optimum is found by scoring every one of the ``S!`` permutations.
        This is exact and cheap for small slot counts (6 candidates for S = 3),
        but the cost grows factorially with ``S``.

        Args:
            desc_in:   [B, S, d] per-slot descriptors of the observation.
            desc_goal: [B, S, d] per-slot descriptors of the goal.

        Returns:
            LongTensor [B, S]. Row ``b`` is a permutation of ``range(S)``, and
            ``perms[b, i]`` is the goal slot matched to input slot ``i``.
        """
        import itertools

        S = desc_in.size(1)

        x = F.normalize(desc_in, dim=-1)
        y = F.normalize(desc_goal, dim=-1)
        # cost[b, i, j]: cosine distance between input slot i and goal slot j.
        cost = 1.0 - torch.matmul(x, y.transpose(1, 2))  # [B, S, S]

        # Lexicographic candidates; argmin ties resolve to the earliest one.
        perms = torch.tensor(
            list(itertools.permutations(range(S))), dtype=torch.long, device=cost.device
        )  # [S!, S]
        rows = torch.arange(S, device=cost.device)

        # cost[b, i, perms[p, i]] for each permutation p -> [B, S!, S]; sum over i.
        cost_p = cost[:, rows, perms].sum(-1)  # [B, S!]
        best = cost_p.argmin(dim=1)            # [B]
        return perms[best]

class BlockTransformerCritic(BaseModel):
    def __init__(self, observation_space, action_space, n_critics=2, action_particle=True,
                 embed_dim=64, h_dim=256, n_head=8, dropout=0.0,
                 num_slots=5, num_bg_slots=1, num_fg_slots=3, num_ag_slots=1, num_blocks=8,
                 goal_cond=False, **kwargs):
        """Build ``n_critics`` independent ``BlockTransformerCriticNetwork`` heads.

        Each head is registered as submodule ``qf<idx>`` so that its parameters
        are tracked. It is also kept in ``self.q_networks`` for indexed access.
        Any extra ``**kwargs`` are accepted and ignored.
        """
        super().__init__(observation_space, action_space)

        self.n_critics = n_critics
        self.q_networks = []

        # Every head gets the same hyper-parameters.
        network_kwargs = dict(
            embed_dim=embed_dim,
            h_dim=h_dim,
            n_head=n_head,
            dropout=dropout,
            num_slots=num_slots,
            num_bg_slots=num_bg_slots,
            num_fg_slots=num_fg_slots,
            num_ag_slots=num_ag_slots,
            num_blocks=num_blocks,
            goal_cond=goal_cond,
            action_particle=action_particle,
        )
        for critic_idx in range(n_critics):
            head = BlockTransformerCriticNetwork(observation_space, action_space, **network_kwargs)
            self.add_module(f"qf{critic_idx}", head)
            self.q_networks.append(head)

    def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Evaluate every critic head on the same (obs, actions) batch.

        NOTE: despite the annotation, ``obs`` is passed straight to each head,
        and ``BlockTransformerCriticNetwork.forward`` indexes it by key
        (``obs["achieved_goal"]``).

        Returns:
            Tuple of ``n_critics`` outputs, ordered by head index.
        """
        return tuple(self.q_networks[head_idx](obs, actions) for head_idx in range(self.n_critics))

    def q1_forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        """Return the Q-value estimate of the first critic head only.

        This is cheaper than ``forward`` when a single estimate is enough
        (e.g. the TD3 policy update).
        """
        first_head = self.q_networks[0]
        return first_head(obs, actions)