import torch as t
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
from torch_geometric.nn.conv import TransformerConv
from torch_geometric.nn.pool import global_max_pool
from typing import Optional, Sequence, Union
from rl.net import NeuralNetworkModule
from model.resnet import ResnetBlockFC
from model.robot_simple_latent_conditioned_moe.encoder import RLEncoder
from model.robot_simple_latent_conditioned_moe.decoder import RLDecoder


def _expand_to_list(
    value: Union[int, Sequence[int]], num_experts: int, name: str
) -> Sequence[int]:
    """Broadcast a scalar or validate a per-expert sequence."""
    if isinstance(value, int):
        return [value] * num_experts
    if isinstance(value, (list, tuple)):
        if len(value) != num_experts:
            raise ValueError(f"{name} must have length {num_experts}.")
        return list(value)
    raise TypeError(f"{name} must be int or sequence.")


class RobotLatentEncoderHeadEncoder(nn.Module):
    """Encoder that maps voxels/kinematics into edge-aligned node features.

    A voxel encoder/decoder pair produces one feature vector per rigid body
    (node); those are concatenated with the raw node features and passed
    through three ``TransformerConv`` layers over the kinematic graph.
    """

    def __init__(
        self,
        normalize_radius: float,
        grid_resolution: int = 64,
        voxel_hidden_dim: int = 4,
        voxel_c_dim: int = 4,
        voxel_decoder_hidden_dim: int = 4,
        voxel_feature_dim: int = 8,
        kinematic_hidden_dims: Sequence[int] = (256, 512, 512),
    ):
        """
        Args:
            normalize_radius: Radius used to normalize spatial inputs.
            grid_resolution: Grid resolution for voxel encoder/decoder.
            voxel_hidden_dim: Hidden width inside the voxel encoder.
            voxel_c_dim: Latent channel count output by the voxel encoder.
            voxel_decoder_hidden_dim: Hidden width of the voxel decoder.
            voxel_feature_dim: Output feature dimension per node from the voxel decoder.
            kinematic_hidden_dims: TransformerConv widths (len==3).

        Raises:
            ValueError: If ``kinematic_hidden_dims`` does not hold three widths.
        """
        super().__init__()
        if len(kinematic_hidden_dims) != 3:
            raise ValueError("kinematic_hidden_dims must contain three widths.")

        self.normalize_radius = normalize_radius
        self.voxel_feature_dim = voxel_feature_dim
        # Width of the per-node (and pooled) output; consumers size their
        # heads from this value.
        self.kinematic_out_dim = kinematic_hidden_dims[-1]

        self.voxel_encoder_net = RLEncoder(
            f_dim=1,
            c_dim=voxel_c_dim,
            hidden_dim=voxel_hidden_dim,
            scatter_type="max",
            grid_resolution=grid_resolution,
            padding=0.02,
            unet3d_kwargs={"layer_order": "cl", "num_levels": 1},
        )
        self.voxel_decoder_net = RLDecoder(
            f_dim=voxel_feature_dim,
            c_dim=voxel_c_dim,
            hidden_dim=voxel_decoder_hidden_dim,
            grid_resolution=grid_resolution,
            sample_mode="bilinear",
            padding=0.02,
        )
        kin_in_channels = voxel_feature_dim + 14  # [node_voxel_feat, node_features]
        self.kinematic_encoder_conv1 = TransformerConv(
            in_channels=kin_in_channels,
            out_channels=kinematic_hidden_dims[0],
            edge_dim=9,
        )
        self.kinematic_encoder_conv2 = TransformerConv(
            in_channels=kinematic_hidden_dims[0],
            out_channels=kinematic_hidden_dims[1],
            edge_dim=9,
        )
        self.kinematic_encoder_conv3 = TransformerConv(
            in_channels=kinematic_hidden_dims[1],
            out_channels=kinematic_hidden_dims[2],
            edge_dim=9,
        )

    def _normalize_positions(self, positions):
        """Scale COM-relative positions by ``2 * normalize_radius`` into [0, 1].

        Values outside the radius are clamped to the boundary.
        """
        return t.clamp(positions / (2 * self.normalize_radius), -0.5, 0.5) + 0.5

    def forward(
        self,
        com,
        relative_voxel_positions,
        voxel_features,
        relative_node_positions,
        node_features,
        edges,
        edge_features,
    ):
        """
        Note:
            Voxel num V* corresponds to voxel num, variable across batch.
            Node num N* corresponds to rigid body num, variable across batch.
            Edge num E* corresponds to joint num, variable across batch.
            There are E* * 2 edges for edges in both directions.
            The first half of each robot's edge columns is treated as the set
            of unique (one-direction) edges.
        Args:
            com: List length B, each [1, 3].
            relative_voxel_positions: List length B, each [V*, 3], relative to COM.
            voxel_features: List length B, each [V*, 1].
            relative_node_positions: List length B, each [N*, 3], relative to COM.
            node_features: List length B, each [N*, 14].
            edges: List length B, each [2, E* * 2].
            edge_features: List length B, each [E* * 2, 9].
        Returns:
            out_node_feature: [sum(N*), kinematic_out_dim].
            pooled_out_node_feature: [B, kinematic_out_dim], max-pooled per robot.
            unique_edge_batch: [sum(E*)] indices mapping unique edges to batch.
            unique_edge_indices: [2, sum(E*)] unique edges with batch offsets applied.
            unique_edge_num_per_robot: list length B with per-robot edge counts.
        """
        com = t.concatenate(com, dim=0)
        device = com.device
        batch_size = com.shape[0]

        # Sample equal number of voxels for each sample in the batch so the
        # variable-length voxel sets can be stacked into one dense tensor.
        sample_num = min(len(v) for v in relative_voxel_positions)
        voxel_sample_idx = [
            t.randperm(len(v))[:sample_num] for v in relative_voxel_positions
        ]
        relative_voxel_positions = t.stack(
            [v[idx] for v, idx in zip(relative_voxel_positions, voxel_sample_idx)],
            dim=0,
        )
        voxel_features = t.stack(
            [v[idx] for v, idx in zip(voxel_features, voxel_sample_idx)], dim=0
        )
        latent = self.voxel_encoder_net(
            self._normalize_positions(relative_voxel_positions), voxel_features
        )

        # Per-robot node counts and the flat node -> robot assignment.
        node_num = [len(p) for p in relative_node_positions]
        max_node_num = max(node_num, default=0)
        node_batch = t.tensor(
            [i for i, n in enumerate(node_num) for _ in range(n)],
            dtype=t.long,
            device=device,
        )

        # Zero-padded dense node positions; padding rows are dropped again
        # after decoding.
        norm_node_positions_batch = t.zeros(
            [batch_size, max_node_num, 3],
            dtype=relative_voxel_positions.dtype,
            device=device,
        )
        for i, (positions, n) in enumerate(zip(relative_node_positions, node_num)):
            norm_node_positions_batch[i, :n] = self._normalize_positions(positions)

        # Add offsets to edge indices so they address the concatenated nodes.
        edge_indices = []
        unique_edge_indices = []
        unique_edge_batch = []
        unique_edge_num_per_robot = []
        offset = 0
        for i, (e, n) in enumerate(zip(edges, node_num)):
            shifted = e + offset
            unique_num = e.shape[1] // 2
            edge_indices.append(shifted)
            unique_edge_indices.append(shifted[:, :unique_num])
            unique_edge_batch += [i] * unique_num
            unique_edge_num_per_robot.append(unique_num)
            offset += n

        edge_indices = t.concatenate(edge_indices, dim=1)
        unique_edge_batch = t.tensor(unique_edge_batch, dtype=t.long, device=device)
        unique_edge_indices = t.concatenate(unique_edge_indices, dim=1)
        edge_features = t.concatenate(edge_features, dim=0)

        node_voxel_features_batch = self.voxel_decoder_net(
            norm_node_positions_batch, latent
        )
        # Strip the padding rows and flatten back to [sum(N*), voxel_feature_dim].
        node_voxel_features = t.concatenate(
            [node_voxel_features_batch[i, :n] for i, n in enumerate(node_num)],
            dim=0,
        )

        # Shape [N_0 + N_1 + ... N_(B-1), voxel_feature_dim + 14]
        node_features = t.concatenate(
            (node_voxel_features, t.concatenate(node_features, dim=0).to(device)),
            dim=1,
        )

        out1 = self.kinematic_encoder_conv1(node_features, edge_indices, edge_features)
        out2 = F.leaky_relu(
            self.kinematic_encoder_conv2(
                F.leaky_relu(out1, negative_slope=0.2), edge_indices, edge_features
            ),
            negative_slope=0.2,
        )
        # No activation after the last layer.
        out_node_feature = self.kinematic_encoder_conv3(
            out2, edge_indices, edge_features
        )

        pooled_out_node_feature = global_max_pool(
            out_node_feature, node_batch, size=batch_size
        )

        return (
            out_node_feature,
            pooled_out_node_feature,
            unique_edge_batch,
            unique_edge_indices,
            unique_edge_num_per_robot,
        )


class RobotLatentEncoderHeadCritic(NeuralNetworkModule):
    """Critic using the encoder output (no gating).

    The pooled per-robot encoder feature is concatenated with the robot's
    3-D velocity and fed through a small residual MLP that outputs one value
    per robot.
    """

    def __init__(
        self,
        encoder: RobotLatentEncoderHeadEncoder,
        freeze_encoder: bool = False,
    ):
        """
        Args:
            encoder: Encoder whose pooled output feeds the value head.
            freeze_encoder: Whether to block gradients to the encoder.
        """
        super().__init__()
        self.encoder = encoder
        # Size the value head from the encoder's real output width instead of
        # assuming 512; fall back to 512 for encoders that do not expose it.
        value_in_dim = getattr(encoder, "kinematic_out_dim", 512) + 3
        self.value_resnet = nn.Sequential(
            ResnetBlockFC(value_in_dim, 256),
            ResnetBlockFC(256, 64),
            ResnetBlockFC(64, 1),
        )
        self.freeze_encoder = freeze_encoder
        self.set_input_module(self.value_resnet)
        self.set_output_module(self.value_resnet)

    def forward(
        self,
        com,
        relative_voxel_positions,
        voxel_features,
        relative_node_positions,
        node_features,
        edges,
        edge_features,
        velocity,
    ):
        """Estimate one scalar value per robot.

        Args:
            com: List length B, each [1, 3].
            relative_voxel_positions: List length B, each [V*, 3], relative to COM.
            voxel_features: List length B, each [V*, 1].
            relative_node_positions: List length B, each [N*, 3], relative to COM.
            node_features: List length B, each [N*, 14].
            edges: List length B, each [2, E* * 2].
            edge_features: List length B, each [E* * 2, 9].
            velocity: List length B with per-robot velocity vectors [1, 3].

        Returns:
            Tensor of shape [B] with the value estimates.
        """
        # Disable autograd only when freezing; otherwise keep whatever grad
        # mode the caller is already in (e.g. an outer ``no_grad``).
        grad_enabled = t.is_grad_enabled() and not self.freeze_encoder
        with t.set_grad_enabled(grad_enabled):
            pooled_out_node_features = self.encoder(
                com,
                relative_voxel_positions,
                voxel_features,
                relative_node_positions,
                node_features,
                edges,
                edge_features,
            )[1]

        # Shape [B, kinematic_out_dim + 3]
        value_input = t.cat(
            (pooled_out_node_features, t.cat(list(velocity), dim=0)),
            dim=1,
        )

        value = self.value_resnet(value_input)
        return value.view(-1)


class RobotLatentActorEncoderHeadMoE(NeuralNetworkModule):
    """Mixture-of-encoders robot actor where each encoder owns its actor head.

    Every expert is an (encoder, action head) pair. A linear gate over the
    robot latent produces per-robot expert weights, and the experts' per-edge
    action logits are combined as a weighted sum using those weights.
    """

    def __init__(
        self,
        normalize_radius: float,
        grid_resolution: int = 64,
        freeze_encoder: bool = False,
        num_experts: int = 8,
        encoder_voxel_hidden_dim: int = 4,
        encoder_voxel_c_dim: int = 4,
        encoder_voxel_decoder_hidden_dim: int = 4,
        encoder_voxel_feature_dim: int = 8,
        kinematic_hidden_dims: Sequence[int] = (256, 512, 512),
        resolution: int = 5,
        actor_hidden1: Union[int, Sequence[int]] = 128,
        actor_hidden2: Union[int, Sequence[int]] = 64,
        gating_threshold: Optional[float] = None,
        gating_top_k: Optional[int] = None,
        latent_dim: int = 512,
    ):
        """
        Args:
            normalize_radius: Radius used to normalize spatial inputs.
            grid_resolution: Grid resolution for voxel encoder/decoder.
            freeze_encoder: Whether to block gradients to encoders.
            num_experts: Number of encoder/actor expert pairs.
            encoder_voxel_hidden_dim: Hidden width inside each voxel encoder.
            encoder_voxel_c_dim: Latent channel count output by voxel encoders.
            encoder_voxel_decoder_hidden_dim: Hidden width of voxel decoders.
            encoder_voxel_feature_dim: Output feature dimension per node from voxel decoders.
            kinematic_hidden_dims: TransformerConv widths (len==3) shared across experts.
            resolution: Number of discrete actions.
            actor_hidden1: First FC width per expert (scalar or list length num_experts).
            actor_hidden2: Second FC width per expert (scalar or list length num_experts).
            gating_threshold: Optional minimum weight cutoff before renormalization.
            gating_top_k: Optionally restrict selection to top-k experts.
            latent_dim: Dimension of VAE latent vector passed for gating.

        Raises:
            ValueError: If ``num_experts`` or ``gating_top_k`` is not positive,
                or a per-expert width list has the wrong length.
            TypeError: If a per-expert width is neither an int nor a list/tuple.
        """
        super().__init__()
        # Fail early with a clear message instead of an IndexError further down.
        if num_experts < 1:
            raise ValueError("num_experts must be positive.")
        if gating_top_k is not None and gating_top_k < 1:
            raise ValueError("gating_top_k must be positive.")

        self.resolution = resolution
        self.freeze_encoder = freeze_encoder
        self.num_experts = num_experts
        self.gating_threshold = gating_threshold
        # top-k can never exceed the number of experts available.
        self.gating_top_k = (
            min(gating_top_k, num_experts) if gating_top_k is not None else None
        )
        self.latent_dim = latent_dim

        actor_hidden1 = _expand_to_list(actor_hidden1, num_experts, "actor_hidden1")
        actor_hidden2 = _expand_to_list(actor_hidden2, num_experts, "actor_hidden2")

        self.encoders = nn.ModuleList(
            [
                RobotLatentEncoderHeadEncoder(
                    normalize_radius=normalize_radius,
                    grid_resolution=grid_resolution,
                    voxel_hidden_dim=encoder_voxel_hidden_dim,
                    voxel_c_dim=encoder_voxel_c_dim,
                    voxel_decoder_hidden_dim=encoder_voxel_decoder_hidden_dim,
                    voxel_feature_dim=encoder_voxel_feature_dim,
                    kinematic_hidden_dims=kinematic_hidden_dims,
                )
                for _ in range(num_experts)
            ]
        )

        node_feature_dim = kinematic_hidden_dims[-1]
        # Per-edge input: [first node, second node, pooled robot feature, velocity].
        self.edge_feature_dim = node_feature_dim * 3 + 3
        self.gate_input_dim = latent_dim
        if num_experts == 1:
            # Single-expert mode: disable the gating network entirely so we do not
            # introduce any routing linear when there is nothing to route.
            self.gate = None
        else:
            # Gate maps latent vector to logits for all experts.
            self.gate = nn.Linear(self.gate_input_dim, num_experts)

        self.action_experts = nn.ModuleList(
            [
                nn.Sequential(
                    ResnetBlockFC(self.edge_feature_dim, h1),
                    ResnetBlockFC(h1, h2),
                    ResnetBlockFC(h2, resolution),
                )
                for h1, h2 in zip(actor_hidden1, actor_hidden2)
            ]
        )

        self.set_input_module(self.action_experts[0])
        self.set_output_module(self.action_experts[0])

    def _compute_gate_weights(self, gate_logits: t.Tensor) -> t.Tensor:
        """Convert gate logits [B, num_experts] into normalized weights.

        Softmax weights are optionally restricted to the top-k experts and/or
        to experts whose weight reaches ``gating_threshold``, then renormalized
        per row. The highest-weighted expert is always retained, so a threshold
        that no expert reaches cannot zero out an entire row.
        """
        weights = F.softmax(gate_logits, dim=1)
        if self.gating_top_k is not None:
            _, top_idx = gate_logits.topk(self.gating_top_k, dim=1)
            mask = t.zeros_like(weights).scatter_(1, top_idx, 1.0)
            weights = weights * mask
        if self.gating_threshold is not None:
            keep = (weights >= self.gating_threshold).to(weights.dtype)
            # If any expert passes the threshold the arg-max one does too, so
            # this only changes rows that would otherwise be all zeros.
            best_idx = weights.argmax(dim=1, keepdim=True)
            keep = keep.scatter(1, best_idx, 1.0)
            weights = weights * keep

        # Renormalize weights so that each row sums to 1.0
        weights = weights / (weights.sum(dim=1, keepdim=True) + 1e-8)
        return weights

    def _build_edge_features(
        self,
        out_node_features: t.Tensor,
        pooled_out_node_features: t.Tensor,
        out_unique_edge_batch: t.Tensor,
        out_unique_edge_indices: t.Tensor,
        unique_edge_num_per_robot,
        velocity,
    ) -> t.Tensor:
        """Assemble per-edge features of shape [sum_E, edge_feature_dim].

        Args:
            out_node_features: Node embeddings, shape [sum(N*), F].
            pooled_out_node_features: Pooled node features per batch element, [B, F].
            out_unique_edge_batch: Batch indices per unique edge, [sum(E*)].
            out_unique_edge_indices: Edge index tensor, [2, sum(E*)].
            unique_edge_num_per_robot: List of unique edge counts per robot.
            velocity: List length B with per-robot velocity vectors [1, 3].

        Returns:
            Tensor [sum(E*), 3 * F + 3]: both endpoint embeddings, the pooled
            embedding of the owning robot, and that robot's velocity.
        """
        # Row lookups: endpoints come from the node table, the pooled feature
        # from the per-robot table.
        first_node_features = out_node_features.index_select(
            0, out_unique_edge_indices[0]
        )  # [sum_E, F]
        second_node_features = out_node_features.index_select(
            0, out_unique_edge_indices[1]
        )  # [sum_E, F]
        node_pooled_features = pooled_out_node_features.index_select(
            0, out_unique_edge_batch
        )  # [sum_E, F] pooled feature repeated per edge

        # Velocity expanded per edge to preserve the original action interface.
        velocity_per_edge = t.cat(
            [
                v.expand(edge_num, 3)
                for v, edge_num in zip(velocity, unique_edge_num_per_robot)
            ],
            dim=0,
        )  # [sum_E, 3]

        return t.cat(
            (
                first_node_features,
                second_node_features,
                node_pooled_features,
                velocity_per_edge,
            ),
            dim=1,
        )  # [sum_E, 3 * F + 3]

    def forward(
        self,
        robot_latent,
        com,
        relative_voxel_positions,
        voxel_features,
        relative_node_positions,
        node_features,
        edges,
        edge_features,
        velocity,
        action=None,
    ):
        """Run MoE policy with encoder-level gating. Signature matches RobotActor.

        Note:
            Voxel num V* corresponds to voxel num, variable across batch.
            Node num N* corresponds to rigid body num, variable across batch.
            Edge num E* corresponds to joint num, variable across batch.
            There are E* * 2 edges for edges in both directions.
        Args:
            robot_latent: List length B (each [1, latent_dim]) or batched tensor
                [B, latent_dim]. Not read in single-expert mode.
            com: List length B, each [1, 3].
            relative_voxel_positions: List length B, each [V*, 3], relative to COM.
            voxel_features: List length B, each [V*, 1].
            relative_node_positions: List length B, each [N*, 3], relative to COM.
            node_features: List length B, each [N*, 14].
            edges: List length B, each [2, E* * 2].
            edge_features: List length B, each [E* * 2, 9].
            velocity: List length B with per-robot velocity vectors [1, 3].
            action: Optional list of per-robot action tensors to evaluate; when
                None, actions are sampled from the policy.
        Returns:
            action_out: List length B, each the actions for that robot's edges.
            log_prob: [B], sum of per-edge log-probabilities for each robot.
            entropy: [B], sum of per-edge entropies for each robot.
            gate_weights: [B, num_experts] mixing weights.
        Raises:
            RuntimeError: If experts report different edge structures.
        """
        batch_size = len(com)

        if self.gate is None:
            # Single-expert mode: gate weights are fixed to 1.0 so the sole
            # expert is used directly, and robot_latent is ignored for routing.
            expert_param = next(self.action_experts[0].parameters())
            gate_weights = t.ones(
                (batch_size, 1), device=expert_param.device, dtype=expert_param.dtype
            )
        else:
            # Each robot latent comes in as [1, latent_dim]; concatenating B
            # robots yields [B, latent_dim].
            if isinstance(robot_latent, (list, tuple)):
                robot_latent = t.cat(robot_latent, dim=0)
            # reshape (not view) so non-contiguous latents are accepted too.
            gate_input = robot_latent.reshape(batch_size, -1).to(
                self.gate.weight.device
            )
            gate_logits = self.gate(gate_input)  # [B, num_experts]
            gate_weights = self._compute_gate_weights(gate_logits)

        # Disable autograd only when freezing; otherwise keep whatever grad
        # mode the caller is already in (e.g. an outer ``no_grad``).
        encoder_grad_enabled = t.is_grad_enabled() and not self.freeze_encoder

        expert_action_logits = []
        ref_unique_edge_batch = None
        ref_unique_edge_indices = None

        for encoder, action_head in zip(self.encoders, self.action_experts):
            with t.set_grad_enabled(encoder_grad_enabled):
                (
                    out_node_features,
                    pooled_out_node_features,
                    out_unique_edge_batch,
                    out_unique_edge_indices,
                    unique_edge_num_per_robot,
                ) = encoder(
                    com,
                    relative_voxel_positions,
                    voxel_features,
                    relative_node_positions,
                    node_features,
                    edges,
                    edge_features,
                )

            if ref_unique_edge_batch is None:
                # ref_unique_edge_batch: [E] robot index for each unique edge;
                # ref_unique_edge_indices: [2, E] node indices for those edges.
                ref_unique_edge_batch = out_unique_edge_batch
                ref_unique_edge_indices = out_unique_edge_indices
            elif (
                out_unique_edge_indices.shape != ref_unique_edge_indices.shape
                or not t.equal(out_unique_edge_batch, ref_unique_edge_batch)
            ):
                raise RuntimeError(
                    "Encoder outputs disagree on edge structure; inputs must match."
                )

            edge_features_of_robot = self._build_edge_features(
                out_node_features,
                pooled_out_node_features,
                out_unique_edge_batch,
                out_unique_edge_indices,
                unique_edge_num_per_robot,
                velocity,
            )

            expert_action_logits.append(action_head(edge_features_of_robot))

        # Mix experts: ref_unique_edge_batch stores, for each edge (length E), the
        # corresponding robot index in [0, B). Using it to index gate_weights
        # broadcasts the per-robot gate weights to all edges of that robot.
        expert_action_stack = t.stack(
            expert_action_logits, dim=1
        )  # [E, num_experts, resolution]
        edge_gate_weights = gate_weights[ref_unique_edge_batch]  # [E, num_experts]
        # Weighted sum over experts -> final logits per edge, shape [E, resolution].
        action_logits = (expert_action_stack * edge_gate_weights.unsqueeze(-1)).sum(
            dim=1
        )

        dist = Categorical(logits=action_logits)
        raw_action = dist.sample() if action is None else t.concatenate(action)
        raw_log_prob = dist.log_prob(raw_action)
        raw_entropy = dist.entropy()

        # Split the flat per-edge results back into per-robot groups.
        action_out = []
        log_prob = []
        entropy = []
        for idx in range(batch_size):
            robot_mask = ref_unique_edge_batch == idx
            action_out.append(raw_action[robot_mask])
            log_prob.append(raw_log_prob[robot_mask].sum())
            entropy.append(raw_entropy[robot_mask].sum())
        log_prob = t.stack(log_prob)
        entropy = t.stack(entropy)
        return action_out, log_prob, entropy, gate_weights
